From eef907330c2c14829906b7c8996f2ec0a4dcb180 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Mon, 9 Sep 2019 12:19:50 -0400
Subject: [PATCH] better test cleanups

---
 test/test_class1_affinity_predictor.py        | 15 ++++++++----
 test/test_download_models_class1.py           | 14 +++++++++--
 test/test_network_merging.py                  | 20 ++++++++++++----
 ...test_released_predictors_on_hpv_dataset.py | 24 +++++++++++++------
 ...est_released_predictors_well_correlated.py | 24 ++++++++++++-------
 test/test_speed.py                            | 22 ++++++++++++++++-
 6 files changed, 92 insertions(+), 27 deletions(-)

diff --git a/test/test_class1_affinity_predictor.py b/test/test_class1_affinity_predictor.py
index 277e78cb..563b6dcf 100644
--- a/test/test_class1_affinity_predictor.py
+++ b/test/test_class1_affinity_predictor.py
@@ -15,14 +15,21 @@ from nose.tools import eq_, assert_raises
 from numpy import testing
 
 from mhcflurry.downloads import get_path
-
 import mhcflurry.testing_utils
-teardown = mhcflurry.testing_utils.module_cleanup
-
 
 DOWNLOADED_PREDICTOR = Class1AffinityPredictor.load()
 
-logging.basicConfig(level=logging.DEBUG)
+
+def setup():
+    global DOWNLOADED_PREDICTOR
+    DOWNLOADED_PREDICTOR = Class1AffinityPredictor.load()
+    logging.basicConfig(level=logging.DEBUG)
+
+
+def teardown():
+    global DOWNLOADED_PREDICTOR
+    DOWNLOADED_PREDICTOR = None
+    mhcflurry.testing_utils.module_cleanup()
 
 
 # To hunt down a weird warning we were seeing in pandas.
diff --git a/test/test_download_models_class1.py b/test/test_download_models_class1.py
index 23b25c3d..5e433272 100644
--- a/test/test_download_models_class1.py
+++ b/test/test_download_models_class1.py
@@ -6,10 +6,20 @@ from numpy.testing import assert_equal
 from mhcflurry import Class1AffinityPredictor, Class1NeuralNetwork
 
 from mhcflurry.testing_utils import module_cleanup
-teardown = module_cleanup
 
 
-DOWNLOADED_PREDICTOR = Class1AffinityPredictor.load()
+DOWNLOADED_PREDICTOR = None
+
+
+def setup():
+    global DOWNLOADED_PREDICTOR
+    DOWNLOADED_PREDICTOR = Class1AffinityPredictor.load()
+
+
+def teardown():
+    global DOWNLOADED_PREDICTOR
+    DOWNLOADED_PREDICTOR = None
+    module_cleanup()
 
 
 def predict_and_check(
diff --git a/test/test_network_merging.py b/test/test_network_merging.py
index a180df49..8d1aca24 100644
--- a/test/test_network_merging.py
+++ b/test/test_network_merging.py
@@ -6,14 +6,24 @@ from mhcflurry.common import random_peptides
 from mhcflurry.downloads import get_path
 
 from mhcflurry.testing_utils import module_cleanup
-teardown = module_cleanup
 
 logging.getLogger('tensorflow').disabled = True
 
-PAN_ALLELE_PREDICTOR = Class1AffinityPredictor.load(
-    get_path("models_class1_pan", "models.with_mass_spec"),
-    max_models=4,
-    optimization_level=0,)
+PAN_ALLELE_PREDICTOR = None
+
+
+def setup():
+    global PAN_ALLELE_PREDICTOR
+    PAN_ALLELE_PREDICTOR = Class1AffinityPredictor.load(
+        get_path("models_class1_pan", "models.with_mass_spec"),
+        max_models=4,
+        optimization_level=0,)
+
+
+def teardown():
+    global PAN_ALLELE_PREDICTOR
+    PAN_ALLELE_PREDICTOR = None
+    module_cleanup()
 
 
 def test_merge():
diff --git a/test/test_released_predictors_on_hpv_dataset.py b/test/test_released_predictors_on_hpv_dataset.py
index 10f0fc23..02928f6d 100644
--- a/test/test_released_predictors_on_hpv_dataset.py
+++ b/test/test_released_predictors_on_hpv_dataset.py
@@ -14,7 +14,6 @@ from mhcflurry import Class1AffinityPredictor
 from mhcflurry.downloads import get_path
 
 from mhcflurry.testing_utils import module_cleanup
-teardown = module_cleanup
 
 
 def data_path(name):
@@ -25,14 +24,24 @@ def data_path(name):
     return os.path.join(os.path.dirname(__file__), "data", name)
 
 
-PREDICTORS = {
-    'allele-specific': Class1AffinityPredictor.load(
-        get_path("models_class1", "models")),
-    'pan-allele': Class1AffinityPredictor.load(
-        get_path("models_class1_pan", "models.with_mass_spec"))
+DF = pandas.read_csv(data_path("hpv_predictions.csv"))
+PREDICTORS = None
+
+
+def setup():
+    global PREDICTORS
+    PREDICTORS = {
+        'allele-specific': Class1AffinityPredictor.load(
+            get_path("models_class1", "models")),
+        'pan-allele': Class1AffinityPredictor.load(
+            get_path("models_class1_pan", "models.with_mass_spec"))
 }
 
-DF = pandas.read_csv(data_path("hpv_predictions.csv"))
+
+def teardown():
+    global PREDICTORS
+    PREDICTORS = None
+    module_cleanup()
 
 
 def test_on_hpv(df=DF):
@@ -63,6 +72,7 @@ def test_on_hpv(df=DF):
 
 if __name__ == '__main__':
     # If run directly from python, leave the user in a shell to explore results.
+    setup()
     result = test_on_hpv()
 
     # Leave in ipython
diff --git a/test/test_released_predictors_well_correlated.py b/test/test_released_predictors_well_correlated.py
index 6b8a6a7e..67d0e98e 100644
--- a/test/test_released_predictors_well_correlated.py
+++ b/test/test_released_predictors_well_correlated.py
@@ -16,17 +16,24 @@ from mhcflurry.downloads import get_path
 from mhcflurry.common import random_peptides
 
 from mhcflurry.testing_utils import module_cleanup
-teardown = module_cleanup
 
+PREDICTORS = None
 
-PREDICTORS = {
-    'allele-specific': Class1AffinityPredictor.load(
-        get_path("models_class1", "models")),
-    'pan-allele': Class1AffinityPredictor.load(
-        get_path("models_class1_pan", "models.with_mass_spec"))
-}
 
-# PREDICTORS["pan-allele"].optimize()
+def setup():
+    global PREDICTORS
+    PREDICTORS = {
+        'allele-specific': Class1AffinityPredictor.load(
+            get_path("models_class1", "models")),
+        'pan-allele': Class1AffinityPredictor.load(
+            get_path("models_class1_pan", "models.with_mass_spec"))
+    }
+
+
+def teardown():
+    global PREDICTORS
+    PREDICTORS = None
+    module_cleanup()
 
 
 def test_correlation(
@@ -83,6 +90,7 @@ parser.add_argument(
 
 if __name__ == '__main__':
     # If run directly from python, leave the user in a shell to explore results.
+    setup()
     args = parser.parse_args(sys.argv[1:])
     result = test_correlation(alleles=args.alleles, debug=True)
 
diff --git a/test/test_speed.py b/test/test_speed.py
index b35327fe..4f0d473e 100644
--- a/test/test_speed.py
+++ b/test/test_speed.py
@@ -19,7 +19,6 @@ from mhcflurry.common import random_peptides
 from mhcflurry.downloads import get_path
 
 from mhcflurry.testing_utils import module_cleanup
-teardown = module_cleanup
 
 ALLELE_SPECIFIC_PREDICTOR = Class1AffinityPredictor.load(
     get_path("models_class1", "models"))
@@ -27,6 +26,26 @@ ALLELE_SPECIFIC_PREDICTOR = Class1AffinityPredictor.load(
 PAN_ALLELE_PREDICTOR = Class1AffinityPredictor.load(
     get_path("models_class1_pan", "models.with_mass_spec"))
 
+
+PREDICTORS = None
+
+
+def setup():
+    global PREDICTORS
+    PREDICTORS = {
+        'allele-specific': Class1AffinityPredictor.load(
+            get_path("models_class1", "models")),
+        'pan-allele': Class1AffinityPredictor.load(
+            get_path("models_class1_pan", "models.with_mass_spec"))
+    }
+
+
+def teardown():
+    global PREDICTORS
+    PREDICTORS = None
+    module_cleanup()
+
+
 DEFAULT_NUM_PREDICTIONS = 10000
 
 
@@ -137,6 +156,7 @@ if __name__ == '__main__':
     # to explore results.
 
     args = parser.parse_args(sys.argv[1:])
+    setup()
 
     if "allele-specific" in args.predictor:
         print("Running allele-specific test")
-- 
GitLab