From fc26e6875813678cfb49a387b6ef5a2714deeb61 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Mon, 9 Jan 2017 17:52:22 -0500
Subject: [PATCH] Swith to using parallel_backend.map instead of
 parallel_backend.submit, to be compatable with kubeface

---
 .../cross_validation.py                       | 34 ++++++------
 .../cv_and_train_command.py                   | 48 +++++++++--------
 mhcflurry/parallelism.py                      | 54 +++++++++++--------
 3 files changed, 76 insertions(+), 60 deletions(-)

diff --git a/mhcflurry/class1_allele_specific/cross_validation.py b/mhcflurry/class1_allele_specific/cross_validation.py
index 5ceeb4b7..4d3dc41a 100644
--- a/mhcflurry/class1_allele_specific/cross_validation.py
+++ b/mhcflurry/class1_allele_specific/cross_validation.py
@@ -142,6 +142,7 @@ def cross_validation_folds(
         alleles = train_data.unique_alleles()
 
     result_folds = []
+    imputation_args = []
     for allele in alleles:
         logging.info("Allele: %s" % allele)
         cv_iter = train_data.cross_validation_iterator(
@@ -165,27 +166,30 @@ def cross_validation_folds(
                 test_split = full_test_split
 
             if imputer is not None:
-                imputation_future = parallel_backend.submit(
-                    impute_and_select_allele,
-                    all_allele_train_split,
+                base_args = dict(impute_kwargs)
+                base_args.update(dict(
+                    dataset=all_allele_train_split,
                     imputer=imputer,
-                    allele=allele,
-                    **impute_kwargs)
-            else:
-                imputation_future = None
+                    allele=allele))
+                imputation_args.append(base_args)
 
             train_split = all_allele_train_split.get_allele(allele)
             fold = AlleleSpecificTrainTestFold(
+                imputed_train=None,  # updated later
                 allele=allele,
                 train=train_split,
-                imputed_train=imputation_future,
                 test=test_split)
             result_folds.append(fold)
 
-    return [
-        result_fold._replace(imputed_train=(
-            result_fold.imputed_train.result()
-            if result_fold.imputed_train is not None
-            else None))
-        for result_fold in result_folds
-    ]
+    if imputation_args:
+        assert len(imputation_args) == len(result_folds)
+        imputation_results = parallel_backend.map(
+            lambda kwargs: impute_and_select_allele(**kwargs),
+            imputation_args)
+
+        return [
+            result_fold._replace(imputed_train=imputation_result)
+            for (result_fold, imputation_result) in zip(
+                result_folds, imputation_results)
+        ]
+    return result_fold
diff --git a/mhcflurry/class1_allele_specific/cv_and_train_command.py b/mhcflurry/class1_allele_specific/cv_and_train_command.py
index b1256ac0..dac1c4e1 100644
--- a/mhcflurry/class1_allele_specific/cv_and_train_command.py
+++ b/mhcflurry/class1_allele_specific/cv_and_train_command.py
@@ -314,23 +314,24 @@ def go(args):
     logging.info("")
     train_folds = []
     train_models = []
+    imputation_args_list = []
+    best_architectures = []
     for (allele_num, allele) in enumerate(cv_results.allele.unique()):
         best_index = best_architectures_by_allele[allele]
         architecture = model_architectures[best_index]
+        best_architectures.append(architecture)
         train_models.append(architecture)
         logging.info(
             "Allele: %s best architecture is index %d: %s" %
             (allele, best_index, architecture))
 
         if architecture['impute']:
-            imputation_future = backend.submit(
-                impute_and_select_allele,
-                train_data,
+            imputation_args = dict(impute_kwargs)
+            imputation_args.update(dict(
+                dataset=train_data,
                 imputer=imputer,
-                allele=allele,
-                **impute_kwargs)
-        else:
-            imputation_future = None
+                allele=allele))
+            imputation_args_list.append(imputation_args)
 
         test_data_this_allele = None
         if test_data is not None:
@@ -338,25 +339,26 @@ def go(args):
         fold = AlleleSpecificTrainTestFold(
             allele=allele,
             train=train_data.get_allele(allele),
-
-            # Here we set imputed_train to the imputation *task* if
-            # imputation was used on this fold. We set this to the actual
-            # imputed training dataset a few lines farther down. This
-            # complexity is because we want to be able to parallelize
-            # the imputations so we have to queue up the tasks first.
-            # If we are not doing imputation then the imputation_task
-            # is None.
-            imputed_train=imputation_future,
+            imputed_train=None,
             test=test_data_this_allele)
         train_folds.append(fold)
 
-    train_folds = [
-        result_fold._replace(imputed_train=(
-            result_fold.imputed_train.result()
-            if result_fold.imputed_train is not None
-            else None))
-        for result_fold in train_folds
-    ]
+    if imputation_args_list:
+        imputation_results = list(backend.map(
+            lambda kwargs: impute_and_select_allele(**kwargs),
+            imputation_args_list))
+
+        new_train_folds = []
+        for (best_architecture, train_fold) in zip(
+                best_architectures, train_folds):
+            imputed_train = None
+            if best_architecture['impute']:
+                imputed_train = imputation_results.pop(0)
+            new_train_folds.append(
+                train_fold._replace(imputed_train=imputed_train))
+        assert not imputation_results
+
+        train_folds = new_train_folds
 
     logging.info("Training %d production models" % len(train_folds))
     start = time.time()
diff --git a/mhcflurry/parallelism.py b/mhcflurry/parallelism.py
index 4d0c47ac..c82b40e0 100644
--- a/mhcflurry/parallelism.py
+++ b/mhcflurry/parallelism.py
@@ -14,28 +14,6 @@ class ParallelBackend(object):
         self.module = module
         self.verbose = verbose
 
-    def submit(self, func, *args, **kwargs):
-        if self.verbose > 0:
-            logging.debug("Submitting: %s %s %s" % (func, args, kwargs))
-        return self.executor.submit(func, *args, **kwargs)
-
-    def map(self, func, iterable):
-        fs = [
-            self.executor.submit(func, arg) for arg in iterable
-        ]
-        return self.wait(fs)
-
-    def wait(self, fs):
-        result_dict = {}
-        for finished_future in self.module.as_completed(fs):
-            result = finished_future.result()
-            logging.info("%3d / %3d tasks completed" % (
-                len(result_dict), len(fs)))
-            result_dict[finished_future] = result
-
-        return [result_dict[future] for future in fs]
-
-
 class KubefaceParallelBackend(ParallelBackend):
     """
     ParallelBackend that uses kubeface
@@ -61,6 +39,22 @@ class DaskDistributedParallelBackend(ParallelBackend):
         ParallelBackend.__init__(self, executor, distributed, verbose=verbose)
         self.scheduler_ip_and_port = scheduler_ip_and_port
 
+    def map(self, func, iterable):
+        fs = [
+            self.executor.submit(func, arg) for arg in iterable
+        ]
+        return self.wait(fs)
+
+    def wait(self, fs):
+        result_dict = {}
+        for finished_future in self.module.as_completed(fs):
+            result = finished_future.result()
+            logging.info("%3d / %3d tasks completed" % (
+                len(result_dict), len(fs)))
+            result_dict[finished_future] = result
+
+        return [result_dict[future] for future in fs]
+
     def __str__(self):
         return "<Dask distributed backend, scheduler=%s, total_cores=%d>" % (
             self.scheduler_ip_and_port,
@@ -85,6 +79,22 @@ class ConcurrentFuturesParallelBackend(ParallelBackend):
         return "<Concurrent futures %s parallel backend, num workers = %d>" % (
             ("processes" if self.processes else "threads"), self.num_workers)
 
+    def map(self, func, iterable):
+        fs = [
+            self.executor.submit(func, arg) for arg in iterable
+        ]
+        return self.wait(fs)
+
+    def wait(self, fs):
+        result_dict = {}
+        for finished_future in self.module.as_completed(fs):
+            result = finished_future.result()
+            logging.info("%3d / %3d tasks completed" % (
+                len(result_dict), len(fs)))
+            result_dict[finished_future] = result
+
+        return [result_dict[future] for future in fs]
+
 
 def set_default_backend(backend):
     global DEFAULT_BACKEND
-- 
GitLab