From 32d0bd8e32da7fe459cb12f2bbd20701c0437c46 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Tue, 18 Jun 2019 10:44:55 -0400
Subject: [PATCH] Remove unused imports and K.clear_session() call; update pan-allele test

Remove the unused set_keras_backend import from the allele-specific and
pan-allele training commands. In train_model, assert that the model's
network has not been instantiated yet, and release memory by saving the
network description and clearing model._network rather than calling
K.clear_session(). Update the pan-allele training test to train on an
HLA-A-only subset of the curated data written into the models directory,
raise the minibatch size to 256, and pass --verbosity 1.
---
 mhcflurry/train_allele_specific_models_command.py |  2 +-
 mhcflurry/train_pan_allele_models_command.py      |  7 +++----
 test/test_train_pan_allele_models_command.py      | 13 ++++++++++---
 3 files changed, 14 insertions(+), 8 deletions(-)

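Reviewer note (outside the commit message, skipped when the patch is
applied): train_model now releases memory by saving the network
description and dropping its own reference instead of calling the
global K.clear_session(). A minimal stand-alone sketch of that
save-then-drop pattern in plain Keras; the model, its input shape, and
the variable names below are illustrative, not mhcflurry code:

    from tensorflow import keras

    # Throwaway model standing in for a trained Class1 network.
    model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])

    config = model.get_config()    # keep the architecture (cf. update_network_description())
    weights = model.get_weights()  # keep the weights
    model = None                   # drop the live Keras object so it can be garbage collected
    # By contrast, keras.backend.clear_session() resets graph state shared
    # by every model in the process, which this patch stops doing.
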
diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index 270acb4a..2b5a5212 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -20,7 +20,7 @@ import tqdm  # progress bar
 tqdm.monitor_interval = 0  # see https://github.com/tqdm/tqdm/issues/481
 
 from .class1_affinity_predictor import Class1AffinityPredictor
-from .common import configure_logging, set_keras_backend
+from .common import configure_logging
 from .parallelism import (
     add_worker_pool_args,
     worker_pool_with_gpu_assignments_from_args,
diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py
index ae5bb8d8..56e4d60b 100644
--- a/mhcflurry/train_pan_allele_models_command.py
+++ b/mhcflurry/train_pan_allele_models_command.py
@@ -22,7 +22,7 @@ tqdm.monitor_interval = 0  # see https://github.com/tqdm/tqdm/issues/481
 
 from .class1_affinity_predictor import Class1AffinityPredictor
 from .class1_neural_network import Class1NeuralNetwork
-from .common import configure_logging, set_keras_backend
+from .common import configure_logging
 from .parallelism import (
     add_worker_pool_args,
     worker_pool_with_gpu_assignments_from_args,
@@ -451,6 +451,7 @@ def train_model(
             replicate_num + 1,
             num_replicates))
 
+    assert model.network() is None  # network must not already be instantiated
     if hyperparameters.get("train_data", {}).get("pretrain", False):
         iterator = pretrain_data_iterator(pretrain_data_filename, allele_encoding)
         original_hyperparameters = dict(model.hyperparameters)
@@ -530,11 +531,9 @@ def train_model(
         predictor.manifest_df.shape[0], len(predictor.class1_pan_allele_models))
     predictor.clear_cache()
 
-    # Delete the network and release memory
+    # Delete the network to release memory
     model.update_network_description()  # save weights and config
     model._network = None  # release tensorflow network
-    K.clear_session()  # release graph
-
     return predictor
 
 
diff --git a/test/test_train_pan_allele_models_command.py b/test/test_train_pan_allele_models_command.py
index 397ccee6..870c3694 100644
--- a/test/test_train_pan_allele_models_command.py
+++ b/test/test_train_pan_allele_models_command.py
@@ -34,7 +34,7 @@ HYPERPARAMETERS_LIST = [
     'locally_connected_layers': [],
     'loss': 'custom:mse_with_inequalities',
     'max_epochs': 5,
-    'minibatch_size': 128,
+    'minibatch_size': 256,
     'optimizer': 'rmsprop',
     'output_activation': 'sigmoid',
     'patience': 10,
@@ -70,7 +70,7 @@ HYPERPARAMETERS_LIST = [
     'locally_connected_layers': [],
     'loss': 'custom:mse_with_inequalities',
     'max_epochs': 5,
-    'minibatch_size': 128,
+    'minibatch_size': 256,
     'optimizer': 'rmsprop',
     'output_activation': 'sigmoid',
     'patience': 10,
@@ -102,14 +102,21 @@ def run_and_check(n_jobs=0):
     with open(hyperparameters_filename, "w") as fd:
         json.dump(HYPERPARAMETERS_LIST, fd)
 
+    data_df = pandas.read_csv(
+        get_path("data_curated", "curated_training_data.no_mass_spec.csv.bz2"))
+    selected_data_df = data_df.loc[data_df.allele.str.startswith("HLA-A")]
+    selected_data_df.to_csv(
+        os.path.join(models_dir, "train_data.csv"), index=False)
+
     args = [
         "mhcflurry-class1-train-pan-allele-models",
-        "--data", get_path("data_curated", "curated_training_data.no_mass_spec.csv.bz2"),
+        "--data", os.path.join(models_dir, "train_data.csv"),
         "--allele-sequences", get_path("allele_sequences", "allele_sequences.csv"),
         "--hyperparameters", hyperparameters_filename,
         "--out-models-dir", models_dir,
         "--num-jobs", str(n_jobs),
         "--ensemble-size", "2",
+        "--verbosity", "1",
     ]
     print("Running with args: %s" % args)
     subprocess.check_call(args)
-- 
GitLab