Skip to content
Snippets Groups Projects
Commit c0479deb authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

Multi-GPU training

parent d2b44b38
No related merge requests found
...@@ -37,7 +37,7 @@ time mhcflurry-class1-train-allele-specific-models \ ...@@ -37,7 +37,7 @@ time mhcflurry-class1-train-allele-specific-models \
--out-models-dir models \ --out-models-dir models \
--percent-rank-calibration-num-peptides-per-length 0 \ --percent-rank-calibration-num-peptides-per-length 0 \
--min-measurements-per-allele 75 \ --min-measurements-per-allele 75 \
--num-jobs 32 16 --num-jobs 32 --gpus 4 --backend tensorflow-default
cp $SCRIPT_ABSOLUTE_PATH . cp $SCRIPT_ABSOLUTE_PATH .
bzip2 LOG.txt bzip2 LOG.txt
......
...@@ -27,6 +27,9 @@ def set_keras_backend(backend): ...@@ -27,6 +27,9 @@ def set_keras_backend(backend):
elif backend == "tensorflow-gpu": elif backend == "tensorflow-gpu":
print("Forcing tensorflow/GPU backend.") print("Forcing tensorflow/GPU backend.")
device_count = {'CPU': 0, 'GPU': 1} device_count = {'CPU': 0, 'GPU': 1}
elif backend == "tensorflow-default":
print("Forcing tensorflow backend.")
device_count = None
else: else:
raise ValueError("Unsupported backend: %s" % backend) raise ValueError("Unsupported backend: %s" % backend)
...@@ -34,6 +37,7 @@ def set_keras_backend(backend): ...@@ -34,6 +37,7 @@ def set_keras_backend(backend):
from keras import backend as K from keras import backend as K
config = tensorflow.ConfigProto( config = tensorflow.ConfigProto(
device_count=device_count) device_count=device_count)
config.gpu_options.allow_growth=True
session = tensorflow.Session(config=config) session = tensorflow.Session(config=config)
K.set_session(session) K.set_session(session)
......
...@@ -7,7 +7,8 @@ import signal ...@@ -7,7 +7,8 @@ import signal
import sys import sys
import time import time
import traceback import traceback
from multiprocessing import Pool import itertools
from multiprocessing import Pool, Queue
from functools import partial from functools import partial
from pprint import pprint from pprint import pprint
...@@ -114,9 +115,11 @@ parser.add_argument( ...@@ -114,9 +115,11 @@ parser.add_argument(
"Set to 1 for serial run. Set to 0 to use number of cores. Default: %(default)s.") "Set to 1 for serial run. Set to 0 to use number of cores. Default: %(default)s.")
parser.add_argument( parser.add_argument(
"--backend", "--backend",
choices=("tensorflow-gpu", "tensorflow-cpu"), choices=("tensorflow-gpu", "tensorflow-cpu", "tensorflow-default"),
help="Keras backend. If not specified will use system default.") help="Keras backend. If not specified will use system default.")
parser.add_argument(
"--gpus",
type=int)
def run(argv=sys.argv[1:]): def run(argv=sys.argv[1:]):
global GLOBAL_DATA global GLOBAL_DATA
...@@ -127,9 +130,6 @@ def run(argv=sys.argv[1:]): ...@@ -127,9 +130,6 @@ def run(argv=sys.argv[1:]):
args = parser.parse_args(argv) args = parser.parse_args(argv)
if args.backend:
set_keras_backend(args.backend)
configure_logging(verbose=args.verbosity > 1) configure_logging(verbose=args.verbosity > 1)
hyperparameters_lst = yaml.load(open(args.hyperparameters)) hyperparameters_lst = yaml.load(open(args.hyperparameters))
...@@ -170,14 +170,37 @@ def run(argv=sys.argv[1:]): ...@@ -170,14 +170,37 @@ def run(argv=sys.argv[1:]):
print("Training data: %s" % (str(df.shape))) print("Training data: %s" % (str(df.shape)))
GLOBAL_DATA["train_data"] = df GLOBAL_DATA["train_data"] = df
GLOBAL_DATA["args"] = args
predictor = Class1AffinityPredictor() predictor = Class1AffinityPredictor()
if args.num_jobs[0] == 1: if args.num_jobs[0] == 1:
# Serial run # Serial run
print("Running in serial.") print("Running in serial.")
worker_pool = None worker_pool = None
if args.backend:
set_keras_backend(args.backend)
else: else:
env_queue = None
if args.gpus:
next_device = itertools.cycle([
"%d" % num
for num in range(args.gpus)
])
queue_items = []
for num in range(args.num_jobs[0]):
queue_items.append([
("CUDA_VISIBLE_DEVICES", next(next_device)),
])
print("Attempting to round-robin assign each worker a GPU", queue_items)
env_queue = Queue()
for item in queue_items:
env_queue.put(item)
worker_pool = Pool( worker_pool = Pool(
initializer=worker_init,
initargs=(env_queue,),
processes=( processes=(
args.num_jobs[0] args.num_jobs[0]
if args.num_jobs[0] else None)) if args.num_jobs[0] else None))
...@@ -401,5 +424,17 @@ def calibrate_percentile_ranks(allele, predictor, peptides=None): ...@@ -401,5 +424,17 @@ def calibrate_percentile_ranks(allele, predictor, peptides=None):
} }
def worker_init(env_queue=None):
    """
    Initializer run once in each worker process of the multiprocessing Pool.

    If ``env_queue`` is given, one item is consumed from it: a list of
    ``(name, value)`` environment-variable pairs (e.g.
    ``("CUDA_VISIBLE_DEVICES", "2")`` for round-robin GPU assignment) which
    is applied to this worker's environment. This happens before
    ``set_keras_backend`` is called, so the CUDA settings are in place when
    tensorflow is initialized.

    Parameters
    ----------
    env_queue : multiprocessing.Queue, optional
        Queue of per-worker environment settings; exactly one item is
        consumed per worker.
    """
    global GLOBAL_DATA
    # Explicit None check: a multiprocessing.Queue is always truthy (its
    # truth value does not reflect emptiness), so `if env_queue:` only
    # obscures the intent "was a queue provided at all".
    if env_queue is not None:
        settings = env_queue.get()
        print("Setting: ", settings)
        # settings is an iterable of (key, value) pairs, which
        # os.environ.update accepts directly.
        os.environ.update(settings)
    # The parent process stashed the parsed CLI args in GLOBAL_DATA before
    # forking the pool; configure the keras backend accordingly.
    command_args = GLOBAL_DATA['args']
    if command_args.backend:
        set_keras_backend(command_args.backend)
if __name__ == '__main__': if __name__ == '__main__':
run() run()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment