From b82888ef312f9e2d4d6b6f3cc1c8dbcb75d76dd5 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Thu, 22 Aug 2019 12:27:12 -0400
Subject: [PATCH] Fix serial run detection and dispatch in pan-allele training
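
Treat runs as serial only when neither cluster parallelism nor local
worker processes are requested, and restructure the dispatch in
train_models() into explicit serial, cluster, and local worker-pool
branches. Also restore the tqdm progress bar when merging results,
drop the "Saving final predictor" print, and clarify the help text for
the local process count option.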

---
 mhcflurry/local_parallelism.py               |  2 +-
 mhcflurry/train_pan_allele_models_command.py | 47 +++++++++-----------
 2 files changed, 23 insertions(+), 26 deletions(-)

diff --git a/mhcflurry/local_parallelism.py b/mhcflurry/local_parallelism.py
index b471f52a..ac3facaa 100644
--- a/mhcflurry/local_parallelism.py
+++ b/mhcflurry/local_parallelism.py
@@ -21,7 +21,7 @@ def add_local_parallelism_args(parser):
         default=0,
         type=int,
         metavar="N",
-        help="Number of processes to parallelize training over. Experimental. "
+        help="Number of local processes to parallelize training over. "
              "Set to 0 for serial run. Default: %(default)s.")
     group.add_argument(
         "--backend",
diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py
index e61f5bae..cf6a8624 100644
--- a/mhcflurry/train_pan_allele_models_command.py
+++ b/mhcflurry/train_pan_allele_models_command.py
@@ -398,7 +398,7 @@ def train_models(args):
     print("Found %d work items, of which %d are incomplete and will run now." % (
         len(all_work_items), len(work_items)))
 
-    serial_run = args.num_jobs == 0
+    serial_run = not args.cluster_parallelism and args.num_jobs == 0
 
     # The estimated time to completion is more accurate if we randomize
     # the order of the work.
@@ -415,7 +415,19 @@ def train_models(args):
 
     start = time.time()
 
-    if args.cluster_parallelism:
+    worker_pool = None
+    if serial_run:
+        # Run in serial. Each training call is passed the same predictor,
+        # which it adds models to, so no merging is required. The predictor
+        # also saves as it goes, so no saving is required at the end.
+        print("Processing %d work items in serial." % len(work_items))
+        for _ in tqdm.trange(len(work_items)):
+            item = work_items.pop(0)  # pop so memory can be freed as we go
+            work_predictor = train_model(**item)
+            assert work_predictor is predictor
+        assert not work_items
+        results_generator = None
+    elif args.cluster_parallelism:
         # Run using separate processes on an HPC cluster.
         results_generator = cluster_results_from_args(
             args,
@@ -423,35 +435,21 @@ def train_models(args):
             work_items=work_items,
             constant_data=GLOBAL_DATA,
             result_serialization_method="save_predictor")
-        worker_pool = None
     else:
         worker_pool = worker_pool_with_gpu_assignments_from_args(args)
         print("Worker pool", worker_pool)
+        assert worker_pool is not None
 
-        if worker_pool:
-            print("Processing %d work items in parallel." % len(work_items))
-            assert not serial_run
+        print("Processing %d work items in parallel." % len(work_items))
+        assert not serial_run
 
-            results_generator = worker_pool.imap_unordered(
-                partial(call_wrapped_kwargs, train_model),
-                work_items,
-                chunksize=1)
-        else:
-            # Run in serial. In this case, every worker is passed the same predictor,
-            # which it adds models to, so no merging is required. It also saves
-            # as it goes so no saving is required at the end.
-            print("Processing %d work items in serial." % len(work_items))
-            assert serial_run
-            for _ in tqdm.trange(len(work_items)):
-                item = work_items.pop(0)  # want to keep freeing up memory
-                work_predictor = train_model(**item)
-                assert work_predictor is predictor
-            assert not work_items
-            results_generator = None
+        results_generator = worker_pool.imap_unordered(
+            partial(call_wrapped_kwargs, train_model),
+            work_items,
+            chunksize=1)
 
     if results_generator:
-        #for new_predictor in tqdm.tqdm(results_generator, total=len(work_items)):
-        for new_predictor in results_generator:
+        for new_predictor in tqdm.tqdm(results_generator, total=len(work_items)):
             save_start = time.time()
             (new_model_name,) = predictor.merge_in_place([new_predictor])
             predictor.save(
@@ -465,7 +463,6 @@ def train_models(args):
                     time.time() - save_start,
                     args.out_models_dir))
 
-    print("Saving final predictor to: %s" % args.out_models_dir)
     # We want the final predictor to support all alleles with sequences, not
     # just those we actually used for model training.
     predictor.allele_to_sequence = (
-- 
GitLab