From 0b3a9b9fd7bf8f3afa78c562771bba1508a43b0d Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Sat, 27 Jan 2018 22:26:33 -0500
Subject: [PATCH] fix: use separate worker pools for training and percent rank
 calibration

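Size the training worker pool from args.train_num_jobs rather than the
old args.parallelization_num_jobs, and shut that pool down (close/join)
once training finishes. Percent rank calibration then gets its own
worker pool, sized from args.calibration_num_jobs: a value of 1 runs
calibration serially, and a falsy value lets multiprocessing.Pool
choose the number of processes.
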
---
 .../train_allele_specific_models_command.py   | 20 ++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index 406da3e0..583fe013 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -155,14 +155,14 @@ def run(argv=sys.argv[1:]):
     print("Training data: %s" % (str(df.shape)))
 
     predictor = Class1AffinityPredictor()
-    if args.parallelization_num_jobs == 1:
+    if args.train_num_jobs == 1:
         # Serial run
         worker_pool = None
     else:
         worker_pool = Pool(
             processes=(
-                args.parallelization_num_jobs
-                if args.parallelization_num_jobs else None))
+                args.train_num_jobs
+                if args.train_num_jobs else None))
         print("Using worker pool: %s" % str(worker_pool))
 
     if args.out_models_dir and not os.path.exists(args.out_models_dir):
@@ -235,7 +235,21 @@ def run(argv=sys.argv[1:]):
         len(predictor.neural_networks), training_time / 60.0))
     print("*" * 30)
 
+    if worker_pool:
+        worker_pool.close()
+        worker_pool.join()
+
     if args.percent_rank_calibration_num_peptides_per_length > 0:
+        if args.calibration_num_jobs == 1:
+            # Serial run
+            worker_pool = None
+        else:
+            worker_pool = Pool(
+                processes=(
+                    args.calibration_num_jobs
+                    if args.calibration_num_jobs else None))
+            print("Using worker pool: %s" % str(worker_pool))
+
         print("Performing percent rank calibration.")
         start = time.time()
         predictor.calibrate_percentile_ranks(
-- 
GitLab
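
Note (not part of the patch): a minimal, self-contained sketch of the
worker-pool lifecycle the hunks above set up, for review convenience.
The num-jobs semantics mirror the patched code (1 = serial, falsy =
Pool(None), i.e. all available CPUs); make_worker_pool, run_phase and
the square task are illustrative stand-ins, not names from mhcflurry.

from multiprocessing import Pool


def make_worker_pool(num_jobs):
    """Return None for a serial run, otherwise a multiprocessing Pool.

    Mirrors the patched logic: 1 -> serial, falsy -> Pool(None) (all CPUs),
    anything else -> that many worker processes.
    """
    if num_jobs == 1:
        return None
    return Pool(processes=num_jobs if num_jobs else None)


def square(x):
    # Stand-in for a real per-allele training / calibration task.
    return x * x


def run_phase(worker_pool, items):
    # Use the pool when we have one, otherwise fall back to a serial loop.
    if worker_pool:
        return worker_pool.map(square, items)
    return [square(x) for x in items]


if __name__ == "__main__":
    # Training phase: pool sized like args.train_num_jobs in the patch.
    train_pool = make_worker_pool(2)
    trained = run_phase(train_pool, range(8))

    # Release the training workers before calibration, as the patch does.
    if train_pool:
        train_pool.close()
        train_pool.join()

    # Calibration phase: separate pool, like args.calibration_num_jobs.
    calibration_pool = make_worker_pool(0)  # falsy -> use all available CPUs
    calibrated = run_phase(calibration_pool, range(8))
    if calibration_pool:
        calibration_pool.close()
        calibration_pool.join()

    print(trained, calibrated)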