diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index 270acb4a92e760fbfe58c2804fb98b57d673fbf1..2b5a5212b5f640c48a87c5a6ee1cb36a4218a871 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -20,7 +20,7 @@ import tqdm  # progress bar
 tqdm.monitor_interval = 0  # see https://github.com/tqdm/tqdm/issues/481
 
 from .class1_affinity_predictor import Class1AffinityPredictor
-from .common import configure_logging, set_keras_backend
+from .common import configure_logging
 from .parallelism import (
     add_worker_pool_args,
     worker_pool_with_gpu_assignments_from_args,
diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py
index ae5bb8d8e261bf2c1a92973ad914814a0d975d3a..56e4d60b7cbb9f3d22c96a85413ce01d643e7563 100644
--- a/mhcflurry/train_pan_allele_models_command.py
+++ b/mhcflurry/train_pan_allele_models_command.py
@@ -22,7 +22,7 @@ tqdm.monitor_interval = 0  # see https://github.com/tqdm/tqdm/issues/481
 
 from .class1_affinity_predictor import Class1AffinityPredictor
 from .class1_neural_network import Class1NeuralNetwork
-from .common import configure_logging, set_keras_backend
+from .common import configure_logging
 from .parallelism import (
     add_worker_pool_args,
     worker_pool_with_gpu_assignments_from_args,
@@ -451,6 +451,7 @@ def train_model(
             replicate_num + 1,
             num_replicates))
 
+    assert model.network() is None
     if hyperparameters.get("train_data", {}).get("pretrain", False):
         iterator = pretrain_data_iterator(pretrain_data_filename, allele_encoding)
         original_hyperparameters = dict(model.hyperparameters)
@@ -530,11 +531,9 @@ def train_model(
         predictor.manifest_df.shape[0], len(predictor.class1_pan_allele_models))
     predictor.clear_cache()
 
-    # Delete the network and release memory
+    # Delete the network to release memory
     model.update_network_description()  # save weights and config
     model._network = None  # release tensorflow network
-    K.clear_session()  # release graph
-
     return predictor
 
 
diff --git a/test/test_train_pan_allele_models_command.py b/test/test_train_pan_allele_models_command.py
index 397ccee6498496037251471799ece305bb6e2e00..870c3694fb0a9f0c45b1c3f4f2fb45b0428be0a5 100644
--- a/test/test_train_pan_allele_models_command.py
+++ b/test/test_train_pan_allele_models_command.py
@@ -34,7 +34,7 @@ HYPERPARAMETERS_LIST = [
     'locally_connected_layers': [],
     'loss': 'custom:mse_with_inequalities',
     'max_epochs': 5,
-    'minibatch_size': 128,
+    'minibatch_size': 256,
     'optimizer': 'rmsprop',
     'output_activation': 'sigmoid',
     'patience': 10,
@@ -70,7 +70,7 @@ HYPERPARAMETERS_LIST = [
     'locally_connected_layers': [],
     'loss': 'custom:mse_with_inequalities',
     'max_epochs': 5,
-    'minibatch_size': 128,
+    'minibatch_size': 256,
     'optimizer': 'rmsprop',
     'output_activation': 'sigmoid',
     'patience': 10,
@@ -102,14 +102,21 @@ def run_and_check(n_jobs=0):
     with open(hyperparameters_filename, "w") as fd:
         json.dump(HYPERPARAMETERS_LIST, fd)
 
+    data_df = pandas.read_csv(
+        get_path("data_curated", "curated_training_data.no_mass_spec.csv.bz2"))
+    selected_data_df = data_df.loc[data_df.allele.str.startswith("HLA-A")]
+    selected_data_df.to_csv(
+        os.path.join(models_dir, "train_data.csv"), index=False)
+
     args = [
         "mhcflurry-class1-train-pan-allele-models",
-        "--data", get_path("data_curated", "curated_training_data.no_mass_spec.csv.bz2"),
+        "--data", os.path.join(models_dir, "train_data.csv"),
         "--allele-sequences", get_path("allele_sequences", "allele_sequences.csv"),
         "--hyperparameters", hyperparameters_filename,
         "--out-models-dir", models_dir,
         "--num-jobs", str(n_jobs),
         "--ensemble-size", "2",
+        "--verbosity", "1",
     ]
     print("Running with args: %s" % args)
     subprocess.check_call(args)