From a9bac0c6c618f7f1fca7c20d357582ec31717719 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Wed, 1 Mar 2017 21:34:04 -0500
Subject: [PATCH] Pass (allele, hyperparameters) pairs to fit_and_test; fix
 imputer_args type
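
Restructure fit_and_test to take a single list of (allele,
hyperparameters) pairs instead of separate allele and hyperparameter
lists, so the work assigned to a task is simply
len(task_allele_model_pairs). Also pass imputer_args as a dict rather
than a one-element list, assert that tasks and manifest_rows are
non-empty before mapping and building the manifest, and drop the unused
kubeface import check from test_ensemble.py.

A rough sketch of the new task shape and per-pair results, as implied
by this patch (the allele names and hyperparameter variables below are
illustrative placeholders, not values from the code):

    # One task = one cross-validation fold plus a batch of
    # (allele, hyperparameters) pairs to fit and score.
    task = (
        fold_num,          # int: index of the train/test fold
        train_broadcast,   # broadcast training MeasurementCollection
        test_broadcast,    # broadcast test MeasurementCollection
        [("HLA-A0201", hyperparameters_a),
         ("HLA-A0201", hyperparameters_b),
         ("HLA-B0702", hyperparameters_a)])

    # fit_and_test returns one dict per (allele, hyperparameters) pair,
    # with keys: fold_num, allele, hyperparameters, model, scores.
    results = fit_and_test(*task)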

---
 .../class1_ensemble_multi_allele_predictor.py | 82 +++++++++----------
 test/test_ensemble.py                         |  6 --
 2 files changed, 39 insertions(+), 49 deletions(-)

diff --git a/mhcflurry/class1_allele_specific_ensemble/class1_ensemble_multi_allele_predictor.py b/mhcflurry/class1_allele_specific_ensemble/class1_ensemble_multi_allele_predictor.py
index 8c23c6d4..19021c35 100644
--- a/mhcflurry/class1_allele_specific_ensemble/class1_ensemble_multi_allele_predictor.py
+++ b/mhcflurry/class1_allele_specific_ensemble/class1_ensemble_multi_allele_predictor.py
@@ -45,7 +45,7 @@ IMPUTE_HYPERPARAMETER_DEFAULTS = HyperparameterDefaults(
     impute_method='mice',
     impute_min_observations_per_peptide=1,
     impute_min_observations_per_allele=1,
-    imputer_args=[{"n_burn_in": 5, "n_imputations": 25}])
+    imputer_args={"n_burn_in": 5, "n_imputations": 25})
 
 HYPERPARAMETER_DEFAULTS = (
     HyperparameterDefaults(
@@ -64,14 +64,13 @@ def fit_and_test(
         fold_num,
         train_measurement_collection_broadcast,
         test_measurement_collection_broadcast,
-        alleles,
-        hyperparameters_list):
+        allele_and_hyperparameter_pairs):
 
     assert len(train_measurement_collection_broadcast.value.df) > 0
     assert len(test_measurement_collection_broadcast.value.df) > 0
 
     results = []
-    for all_hyperparameters in hyperparameters_list:
+    for (allele, all_hyperparameters) in allele_and_hyperparameter_pairs:
         measurement_collection_hyperparameters = (
             MEASUREMENT_COLLECTION_HYPERPARAMETER_DEFAULTS.subselect(
                 all_hyperparameters))
@@ -79,34 +78,33 @@ def fit_and_test(
             Class1BindingPredictor.hyperparameter_defaults.subselect(
                 all_hyperparameters))
 
-        for allele in alleles:
-            train_dataset = (
-                train_measurement_collection_broadcast
-                .value
-                .select_allele(allele)
-                .to_dataset(**measurement_collection_hyperparameters))
-            test_dataset = (
-                test_measurement_collection_broadcast
-                .value
-                .select_allele(allele)
-                .to_dataset(**measurement_collection_hyperparameters))
-
-            assert len(train_dataset) > 0
-            assert len(test_dataset) > 0
-
-            model = Class1BindingPredictor(**model_hyperparameters)
-
-            model.fit_dataset(train_dataset)
-            predictions = model.predict(test_dataset.peptides)
-            scores = scoring.make_scores(
-                test_dataset.affinities, predictions)
-            results.append({
-                'fold_num': fold_num,
-                'allele': allele,
-                'hyperparameters': all_hyperparameters,
-                'model': model,
-                'scores': scores
-            })
+        train_dataset = (
+            train_measurement_collection_broadcast
+            .value
+            .select_allele(allele)
+            .to_dataset(**measurement_collection_hyperparameters))
+        test_dataset = (
+            test_measurement_collection_broadcast
+            .value
+            .select_allele(allele)
+            .to_dataset(**measurement_collection_hyperparameters))
+
+        assert len(train_dataset) > 0
+        assert len(test_dataset) > 0
+
+        model = Class1BindingPredictor(**model_hyperparameters)
+
+        model.fit_dataset(train_dataset)
+        predictions = model.predict(test_dataset.peptides)
+        scores = scoring.make_scores(
+            test_dataset.affinities, predictions)
+        results.append({
+            'fold_num': fold_num,
+            'allele': allele,
+            'hyperparameters': all_hyperparameters,
+            'model': model,
+            'scores': scores
+        })
     return results
 
 
@@ -321,19 +319,16 @@ class Class1EnsembleMultiAllelePredictor(object):
             train_broadcast = parallel_backend.broadcast(train_split)
             test_broadcast = parallel_backend.broadcast(test_split)
 
-            task_alleles = []
-            task_models = []
+            task_allele_model_pairs = []
 
             def make_task():
-                if task_alleles and task_models:
+                if task_allele_model_pairs:
                     tasks.append((
                         fold_num,
                         train_broadcast,
                         test_broadcast,
-                        list(task_alleles),
-                        list(task_models)))
-                task_alleles.clear()
-                task_models.clear()
+                        list(task_allele_model_pairs)))
+                    task_allele_model_pairs.clear()
 
             assert all(
                 allele in set(train_split.df.allele.unique())
@@ -348,12 +343,11 @@ class Class1EnsembleMultiAllelePredictor(object):
 
             for allele in alleles:
                 for model in self.hyperparameters_to_search:
-                    task_models.append(model)
-                    if len(task_alleles) * len(task_models) > work_per_task:
+                    task_allele_model_pairs.append((allele, model))
+                    if len(task_allele_model_pairs) > work_per_task:
                         make_task()
                 make_task()
-            assert not task_alleles
-            assert not task_models
+            assert not task_allele_model_pairs
 
         logging.info(
             "Training and scoring models: %d tasks (target was %d), "
@@ -365,6 +359,7 @@ class Class1EnsembleMultiAllelePredictor(object):
                 len(self.hyperparameters_to_search),
                 total_work))
 
+        assert len(tasks) > 0
         results = parallel_backend.map(call_fit_and_test, tasks)
 
         # fold number -> allele -> best model
@@ -404,6 +399,7 @@ class Class1EnsembleMultiAllelePredictor(object):
                         manifest_entry["%s_%s" % (key, sub_key)] = value
                 manifest_rows.append(manifest_entry)
 
+        assert len(manifest_rows) > 0
         manifest_df = pandas.DataFrame(manifest_rows)
         manifest_df.index = manifest_df.model_name
         del manifest_df["model_name"]
diff --git a/test/test_ensemble.py b/test/test_ensemble.py
index 099aeaae..275d5bf9 100644
--- a/test/test_ensemble.py
+++ b/test/test_ensemble.py
@@ -21,12 +21,6 @@ from mhcflurry \
         Class1EnsembleMultiAllelePredictor,
         HYPERPARAMETER_DEFAULTS)
 
-try:
-    import kubeface
-    KUBEFACE_INSTALLED = True
-except ImportError:
-    KUBEFACE_INSTALLED = False
-
 
 def test_basic():
     model_hyperparameters = HYPERPARAMETER_DEFAULTS.models_grid(
-- 
GitLab