From f4d66885cf238d8bf64a575ac5d654415902ba72 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Tue, 13 Feb 2018 16:43:26 -0500
Subject: [PATCH] Split work items per model for more parallelism
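
Each work item now trains a single model ('n_models': 1) for one allele
instead of all n_models for that allele, so the number of items per
hyperparameter set grows from the number of alleles to
len(alleles) * n_models, and the worker pool gets finer-grained units to
schedule across processes.

Illustrative sketch only, not part of this change (the item layout, the
allele names, and fit_one are made up): with per-model items, a pool can
overlap the fits for one allele with fits for other alleles instead of
running them back to back inside a single worker.

    from multiprocessing import Pool

    def fit_one(item):
        # Stand-in for the real per-item training call.
        return (item['allele'], item['model_num'])

    if __name__ == '__main__':
        # Two hypothetical alleles with four models each -> eight items.
        items = [
            {'allele': allele, 'model_num': m}
            for allele in ['HLA-A*02:01', 'HLA-B*07:02']
            for m in range(4)
        ]
        with Pool(4) as pool:
            print(pool.map(fit_one, items))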

---
 .../train_allele_specific_models_command.py   | 31 ++++++++++---------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index 694e4a47..7e4fdeb3 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -258,21 +258,22 @@ def run(argv=sys.argv[1:]):
             hyperparameters['max_epochs'] = args.max_epochs
 
         for (i, allele) in enumerate(df.allele.unique()):
-            work_dict = {
-                'n_models': n_models,
-                'allele_num': i,
-                'n_alleles': len(alleles),
-                'hyperparameter_set_num': h,
-                'num_hyperparameter_sets': len(hyperparameters_lst),
-                'allele': allele,
-                'data': None,  # subselect from GLOBAL_DATA["train_data"]
-                'hyperparameters': hyperparameters,
-                'verbose': args.verbosity,
-                'progress_print_interval': None if worker_pool else 5.0,
-                'predictor': predictor if not worker_pool else None,
-                'save_to': args.out_models_dir if not worker_pool else None,
-            }
-            work_items.append(work_dict)
+            for model_num in range(n_models):
+                work_dict = {
+                    'n_models': 1,
+                    'allele_num': i,
+                    'n_alleles': len(alleles),
+                    'hyperparameter_set_num': h,
+                    'num_hyperparameter_sets': len(hyperparameters_lst),
+                    'allele': allele,
+                    'data': None,  # subselect from GLOBAL_DATA["train_data"]
+                    'hyperparameters': hyperparameters,
+                    'verbose': args.verbosity,
+                    'progress_print_interval': None if worker_pool else 5.0,
+                    'predictor': predictor if not worker_pool else None,
+                    'save_to': args.out_models_dir if not worker_pool else None,
+                }
+                work_items.append(work_dict)
 
     if worker_pool:
         print("Processing %d work items in parallel." % len(work_items))
-- 
GitLab