From ce82c72a901af7be889e67c8ccb96fcbca2453e9 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Mon, 9 Sep 2019 14:36:50 -0400
Subject: [PATCH] fix

---
 mhcflurry/testing_utils.py                    |  8 ++++
 ...test_calibrate_percentile_ranks_command.py |  4 +-
 test/test_changing_allele_representations.py  |  3 +-
 test/test_class1_affinity_predictor.py        |  5 +-
 test/test_class1_neural_network.py            |  3 +-
 test/test_class1_pan.py                       |  3 +-
 test/test_custom_loss.py                      |  3 +-
 test/test_download_models_class1.py           |  4 +-
 test/test_multi_output.py                     |  3 +-
 test/test_network_merging.py                  |  4 +-
 test/test_predict_command.py                  |  3 +-
 ...test_released_predictors_on_hpv_dataset.py |  4 +-
 ...est_released_predictors_well_correlated.py |  5 +-
 test/test_speed.py                            | 48 ++++++++-----------
 test/test_train_and_related_commands.py       |  3 +-
 test/test_train_pan_allele_models_command.py  |  3 +-
 16 files changed, 59 insertions(+), 47 deletions(-)

diff --git a/mhcflurry/testing_utils.py b/mhcflurry/testing_utils.py
index 157a11af..64e4965b 100644
--- a/mhcflurry/testing_utils.py
+++ b/mhcflurry/testing_utils.py
@@ -2,6 +2,14 @@
 Utilities used in MHCflurry unit tests.
 """
 from . import Class1NeuralNetwork
+from .common import set_keras_backend
+
+
+def startup():
+    """
+    Configure Keras backend for running unit tests.
+    """
+    set_keras_backend("tensorflow-cpu", num_threads=2)
 
 
 def cleanup():
diff --git a/test/test_calibrate_percentile_ranks_command.py b/test/test_calibrate_percentile_ranks_command.py
index 2dc44a0b..a5656ac35 100644
--- a/test/test_calibrate_percentile_ranks_command.py
+++ b/test/test_calibrate_percentile_ranks_command.py
@@ -14,9 +14,9 @@ from mhcflurry.downloads import get_path
 
 os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
-from mhcflurry.testing_utils import cleanup
+from mhcflurry.testing_utils import cleanup, startup
 teardown = cleanup
-
+setup = startup
 
 def run_and_check(n_jobs=0, delete=True, additional_args=[]):
     source_models_dir = get_path("models_class1_pan", "models.with_mass_spec")
diff --git a/test/test_changing_allele_representations.py b/test/test_changing_allele_representations.py
index 38b9e12c..271cb3d9 100644
--- a/test/test_changing_allele_representations.py
+++ b/test/test_changing_allele_representations.py
@@ -8,8 +8,9 @@ from mhcflurry.downloads import get_path
 
 from numpy.testing import assert_equal
 
-from mhcflurry.testing_utils import cleanup
+from mhcflurry.testing_utils import cleanup, startup
 teardown = cleanup
+setup = startup
 
 ALLELE_TO_SEQUENCE = pandas.read_csv(
     get_path(
diff --git a/test/test_class1_affinity_predictor.py b/test/test_class1_affinity_predictor.py
index 8f2b4aa0..aa70a25d 100644
--- a/test/test_class1_affinity_predictor.py
+++ b/test/test_class1_affinity_predictor.py
@@ -15,13 +15,14 @@ from nose.tools import eq_, assert_raises
 from numpy import testing
 
 from mhcflurry.downloads import get_path
-import mhcflurry.testing_utils
+from mhcflurry.testing_utils import cleanup, startup
 
 DOWNLOADED_PREDICTOR = Class1AffinityPredictor.load()
 
 
 def setup():
     global DOWNLOADED_PREDICTOR
+    startup()
     DOWNLOADED_PREDICTOR = Class1AffinityPredictor.load()
     logging.basicConfig(level=logging.DEBUG)
 
@@ -29,7 +30,7 @@ def setup():
 def teardown():
     global DOWNLOADED_PREDICTOR
     DOWNLOADED_PREDICTOR = None
-    mhcflurry.testing_utils.cleanup()
+    cleanup()
 
 
 # To hunt down a weird warning we were seeing in pandas.
diff --git a/test/test_class1_neural_network.py b/test/test_class1_neural_network.py
index 4620334a..47be495c 100644
--- a/test/test_class1_neural_network.py
+++ b/test/test_class1_neural_network.py
@@ -13,8 +13,9 @@ from mhcflurry.class1_neural_network import Class1NeuralNetwork
 from mhcflurry.downloads import get_path
 from mhcflurry.common import random_peptides
 
-from mhcflurry.testing_utils import cleanup
+from mhcflurry.testing_utils import cleanup, startup
 teardown = cleanup
+setup = startup
 
 
 def test_class1_neural_network_a0205_training_accuracy():
diff --git a/test/test_class1_pan.py b/test/test_class1_pan.py
index 9ad87d69..0528ef60 100644
--- a/test/test_class1_pan.py
+++ b/test/test_class1_pan.py
@@ -11,8 +11,9 @@ from mhcflurry import Class1AffinityPredictor,Class1NeuralNetwork
 from mhcflurry.allele_encoding import AlleleEncoding
 from mhcflurry.downloads import get_path
 
-from mhcflurry.testing_utils import cleanup
+from mhcflurry.testing_utils import cleanup, startup
 teardown = cleanup
+setup = startup
 
 
 HYPERPARAMETERS = {
diff --git a/test/test_custom_loss.py b/test/test_custom_loss.py
index 2426f48d..98ee4ab5 100644
--- a/test/test_custom_loss.py
+++ b/test/test_custom_loss.py
@@ -11,8 +11,9 @@ import keras.backend as K
 
 from mhcflurry.custom_loss import CUSTOM_LOSSES
 
-from mhcflurry.testing_utils import cleanup
+from mhcflurry.testing_utils import cleanup, startup
 teardown = cleanup
+setup = startup
 
 
 def evaluate_loss(loss, y_true, y_pred):
diff --git a/test/test_download_models_class1.py b/test/test_download_models_class1.py
index ffbf8b3c..29a4d47a 100644
--- a/test/test_download_models_class1.py
+++ b/test/test_download_models_class1.py
@@ -5,14 +5,14 @@ from numpy.testing import assert_equal
 
 from mhcflurry import Class1AffinityPredictor, Class1NeuralNetwork
 
-from mhcflurry.testing_utils import cleanup
-
+from mhcflurry.testing_utils import cleanup, startup
 
 DOWNLOADED_PREDICTOR = None
 
 
 def setup():
     global DOWNLOADED_PREDICTOR
+    startup()
     DOWNLOADED_PREDICTOR = Class1AffinityPredictor.load()
 
 
diff --git a/test/test_multi_output.py b/test/test_multi_output.py
index b4254b56..452f7a76 100644
--- a/test/test_multi_output.py
+++ b/test/test_multi_output.py
@@ -12,8 +12,9 @@ logging.getLogger('tensorflow').disabled = True
 from mhcflurry.class1_neural_network import Class1NeuralNetwork
 from mhcflurry.common import random_peptides
 
-from mhcflurry.testing_utils import cleanup
+from mhcflurry.testing_utils import cleanup, startup
 teardown = cleanup
+setup = startup
 
 
 def test_multi_output():
diff --git a/test/test_network_merging.py b/test/test_network_merging.py
index 56683a52..69eab37d 100644
--- a/test/test_network_merging.py
+++ b/test/test_network_merging.py
@@ -5,8 +5,7 @@ from mhcflurry import Class1AffinityPredictor, Class1NeuralNetwork
 from mhcflurry.common import random_peptides
 from mhcflurry.downloads import get_path
 
-from mhcflurry.testing_utils import cleanup
-
+from mhcflurry.testing_utils import cleanup, startup
 logging.getLogger('tensorflow').disabled = True
 
 PAN_ALLELE_PREDICTOR = None
@@ -14,6 +13,7 @@ PAN_ALLELE_PREDICTOR = None
 
 def setup():
     global PAN_ALLELE_PREDICTOR
+    startup()
     PAN_ALLELE_PREDICTOR = Class1AffinityPredictor.load(
         get_path("models_class1_pan", "models.with_mass_spec"),
         max_models=4,
diff --git a/test/test_predict_command.py b/test/test_predict_command.py
index 507d7c4a..c3f0a5c1 100644
--- a/test/test_predict_command.py
+++ b/test/test_predict_command.py
@@ -6,8 +6,9 @@ from numpy.testing import assert_equal
 
 from mhcflurry import predict_command
 
-from mhcflurry.testing_utils import cleanup
+from mhcflurry.testing_utils import cleanup, startup
 teardown = cleanup
+setup = startup
 
 TEST_CSV = '''
 Allele,Peptide,Experiment
diff --git a/test/test_released_predictors_on_hpv_dataset.py b/test/test_released_predictors_on_hpv_dataset.py
index 33692574..26bf0227 100644
--- a/test/test_released_predictors_on_hpv_dataset.py
+++ b/test/test_released_predictors_on_hpv_dataset.py
@@ -13,8 +13,7 @@ from nose.tools import eq_, assert_less, assert_greater, assert_almost_equal
 from mhcflurry import Class1AffinityPredictor
 from mhcflurry.downloads import get_path
 
-from mhcflurry.testing_utils import cleanup
-
+from mhcflurry.testing_utils import cleanup, startup
 
 def data_path(name):
     '''
@@ -30,6 +29,7 @@ PREDICTORS = None
 
 def setup():
     global PREDICTORS
+    startup()
     PREDICTORS = {
         'allele-specific': Class1AffinityPredictor.load(
             get_path("models_class1", "models")),
diff --git a/test/test_released_predictors_well_correlated.py b/test/test_released_predictors_well_correlated.py
index 82071d65..5040843f 100644
--- a/test/test_released_predictors_well_correlated.py
+++ b/test/test_released_predictors_well_correlated.py
@@ -15,13 +15,14 @@ from mhcflurry.encodable_sequences import EncodableSequences
 from mhcflurry.downloads import get_path
 from mhcflurry.common import random_peptides
 
-from mhcflurry.testing_utils import cleanup
+from mhcflurry.testing_utils import cleanup, startup
 
 PREDICTORS = None
 
 
 def setup():
     global PREDICTORS
+    startup()
     PREDICTORS = {
         'allele-specific': Class1AffinityPredictor.load(
             get_path("models_class1", "models")),
@@ -38,7 +39,7 @@ def teardown():
 
 def test_correlation(
         alleles=None,
-        num_peptides_per_length=500,
+        num_peptides_per_length=100,
         lengths=[8, 9, 10],
         debug=False):
     peptides = []
diff --git a/test/test_speed.py b/test/test_speed.py
index eb604a50..037ae61c 100644
--- a/test/test_speed.py
+++ b/test/test_speed.py
@@ -18,46 +18,41 @@ from mhcflurry.encodable_sequences import EncodableSequences
 from mhcflurry.common import random_peptides
 from mhcflurry.downloads import get_path
 
-from mhcflurry.testing_utils import cleanup
+from mhcflurry.testing_utils import cleanup, startup
 
-ALLELE_SPECIFIC_PREDICTOR = Class1AffinityPredictor.load(
-    get_path("models_class1", "models"))
 
-PAN_ALLELE_PREDICTOR = Class1AffinityPredictor.load(
-    get_path("models_class1_pan", "models.with_mass_spec"))
-
-
-PREDICTORS = None
+ALLELE_SPECIFIC_PREDICTOR = None
+PAN_ALLELE_PREDICTOR = None
 
 
 def setup():
-    global PREDICTORS
-    PREDICTORS = {
-        'allele-specific': Class1AffinityPredictor.load(
-            get_path("models_class1", "models")),
-        'pan-allele': Class1AffinityPredictor.load(
-            get_path("models_class1_pan", "models.with_mass_spec"))
-    }
+    global ALLELE_SPECIFIC_PREDICTOR, PAN_ALLELE_PREDICTOR
+    startup()
+    ALLELE_SPECIFIC_PREDICTOR = Class1AffinityPredictor.load(
+        get_path("models_class1", "models"))
+
+    PAN_ALLELE_PREDICTOR = Class1AffinityPredictor.load(
+        get_path("models_class1_pan", "models.with_mass_spec"))
 
 
 def teardown():
-    global PREDICTORS
-    PREDICTORS = None
+    global ALLELE_SPECIFIC_PREDICTOR, PAN_ALLELE_PREDICTOR
+    ALLELE_SPECIFIC_PREDICTOR = None
+    PAN_ALLELE_PREDICTOR = None
     cleanup()
 
 
 DEFAULT_NUM_PREDICTIONS = 10000
 
 
-def test_speed_allele_specific(
-        profile=False,
-        predictor=ALLELE_SPECIFIC_PREDICTOR,
-        num=DEFAULT_NUM_PREDICTIONS):
-
+def test_speed_allele_specific(profile=False, num=DEFAULT_NUM_PREDICTIONS):
+    global ALLELE_SPECIFIC_PREDICTOR
     starts = collections.OrderedDict()
     timings = collections.OrderedDict()
     profilers = collections.OrderedDict()
 
+    predictor = ALLELE_SPECIFIC_PREDICTOR
+
     def start(name):
         starts[name] = time.time()
         if profile:
@@ -101,15 +96,14 @@ def test_speed_allele_specific(
         (key, pstats.Stats(value)) for (key, value) in profilers.items())
 
 
-def test_speed_pan_allele(
-        profile=False,
-        predictor=PAN_ALLELE_PREDICTOR,
-        num=DEFAULT_NUM_PREDICTIONS):
-
+def test_speed_pan_allele(profile=False, num=DEFAULT_NUM_PREDICTIONS):
+    global PAN_ALLELE_PREDICTOR
     starts = collections.OrderedDict()
     timings = collections.OrderedDict()
     profilers = collections.OrderedDict()
 
+    predictor = PAN_ALLELE_PREDICTOR
+
     def start(name):
         starts[name] = time.time()
         if profile:
diff --git a/test/test_train_and_related_commands.py b/test/test_train_and_related_commands.py
index 7b893f7b..92efbdbd 100644
--- a/test/test_train_and_related_commands.py
+++ b/test/test_train_and_related_commands.py
@@ -14,8 +14,9 @@ from numpy.testing import assert_array_less, assert_equal
 from mhcflurry import Class1AffinityPredictor
 from mhcflurry.downloads import get_path
 
-from mhcflurry.testing_utils import cleanup
+from mhcflurry.testing_utils import cleanup, startup
 teardown = cleanup
+setup = startup
 
 os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
diff --git a/test/test_train_pan_allele_models_command.py b/test/test_train_pan_allele_models_command.py
index 51c7c3ff..53c5bb25 100644
--- a/test/test_train_pan_allele_models_command.py
+++ b/test/test_train_pan_allele_models_command.py
@@ -15,8 +15,9 @@ from numpy.testing import assert_equal, assert_array_less
 from mhcflurry import Class1AffinityPredictor,Class1NeuralNetwork
 from mhcflurry.downloads import get_path
 
-from mhcflurry.testing_utils import cleanup
+from mhcflurry.testing_utils import cleanup, startup
 teardown = cleanup
+setup = startup
 
 os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
-- 
GitLab