From c0479deb06d309266a6e3c775ce5d477cb4bd41f Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Fri, 9 Feb 2018 15:47:21 -0500
Subject: [PATCH] Multi-GPU training

---
 .../models_class1_unselected/GENERATE.sh      |  2 +-
 mhcflurry/common.py                           |  4 ++
 .../train_allele_specific_models_command.py   | 47 ++++++++++++++++---
 3 files changed, 46 insertions(+), 7 deletions(-)

diff --git a/downloads-generation/models_class1_unselected/GENERATE.sh b/downloads-generation/models_class1_unselected/GENERATE.sh
index 4890da0c..3a3af72e 100755
--- a/downloads-generation/models_class1_unselected/GENERATE.sh
+++ b/downloads-generation/models_class1_unselected/GENERATE.sh
@@ -37,7 +37,7 @@ time mhcflurry-class1-train-allele-specific-models \
     --out-models-dir models \
     --percent-rank-calibration-num-peptides-per-length 0 \
     --min-measurements-per-allele 75 \
-    --num-jobs 32 16
+    --num-jobs 32 --gpus 4 --backend tensorflow-default
 
 cp $SCRIPT_ABSOLUTE_PATH .
 bzip2 LOG.txt
diff --git a/mhcflurry/common.py b/mhcflurry/common.py
index 942e9fea..a699235f 100644
--- a/mhcflurry/common.py
+++ b/mhcflurry/common.py
@@ -27,6 +27,9 @@ def set_keras_backend(backend):
     elif backend == "tensorflow-gpu":
         print("Forcing tensorflow/GPU backend.")
         device_count = {'CPU': 0, 'GPU': 1}
+    elif backend == "tensorflow-default":
+        print("Forcing tensorflow backend.")
+        device_count = None
     else:
         raise ValueError("Unsupported backend: %s" % backend)
 
@@ -34,6 +37,7 @@ def set_keras_backend(backend):
     from keras import backend as K
     config = tensorflow.ConfigProto(
         device_count=device_count)
+    config.gpu_options.allow_growth = True
     session = tensorflow.Session(config=config)
     K.set_session(session)
 
diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index c05ce479..4b525127 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -7,7 +7,8 @@ import signal
 import sys
 import time
 import traceback
-from multiprocessing import Pool
+import itertools
+from multiprocessing import Pool, Queue
 from functools import partial
 from pprint import pprint
 
@@ -114,9 +115,11 @@ parser.add_argument(
     "Set to 1 for serial run. Set to 0 to use number of cores. Default: %(default)s.")
 parser.add_argument(
     "--backend",
-    choices=("tensorflow-gpu", "tensorflow-cpu"),
+    choices=("tensorflow-gpu", "tensorflow-cpu", "tensorflow-default"),
     help="Keras backend. If not specified will use system default.")
-
+parser.add_argument(
+    "--gpus",
+    type=int, help="Number of GPUs to round-robin assign to workers.")
 
 def run(argv=sys.argv[1:]):
     global GLOBAL_DATA
@@ -127,9 +130,6 @@ def run(argv=sys.argv[1:]):
 
     args = parser.parse_args(argv)
 
-    if args.backend:
-        set_keras_backend(args.backend)
-
     configure_logging(verbose=args.verbosity > 1)
 
     hyperparameters_lst = yaml.load(open(args.hyperparameters))
@@ -170,14 +170,37 @@ def run(argv=sys.argv[1:]):
     print("Training data: %s" % (str(df.shape)))
 
     GLOBAL_DATA["train_data"] = df
+    GLOBAL_DATA["args"] = args
 
     predictor = Class1AffinityPredictor()
     if args.num_jobs[0] == 1:
         # Serial run
         print("Running in serial.")
         worker_pool = None
+        if args.backend:
+            set_keras_backend(args.backend)
+
     else:
+        env_queue = None
+        if args.gpus:
+            next_device = itertools.cycle([
+                "%d" % num
+                for num in range(args.gpus)
+            ])
+            queue_items = []
+            for num in range(args.num_jobs[0]):
+                queue_items.append([
+                    ("CUDA_VISIBLE_DEVICES", next(next_device)),
+                ])
+
+            print("Attempting to round-robin assign each worker a GPU", queue_items)
+            env_queue = Queue()
+            for item in queue_items:
+                env_queue.put(item)
+
         worker_pool = Pool(
+            initializer=worker_init,
+            initargs=(env_queue,),
             processes=(
                 args.num_jobs[0]
                 if args.num_jobs[0] else None))
@@ -401,5 +424,17 @@ def calibrate_percentile_ranks(allele, predictor, peptides=None):
     }
 
 
+def worker_init(env_queue=None):
+    global GLOBAL_DATA
+
+    if env_queue:
+        settings = env_queue.get()
+        print("Setting: ", settings)
+        os.environ.update(settings)
+
+    command_args = GLOBAL_DATA['args']
+    if command_args.backend:
+        set_keras_backend(command_args.backend)
+
 if __name__ == '__main__':
     run()
-- 
GitLab