From 32d0bd8e32da7fe459cb12f2bbd20701c0437c46 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Tue, 18 Jun 2019 10:44:55 -0400 Subject: [PATCH] fix --- mhcflurry/train_allele_specific_models_command.py | 2 +- mhcflurry/train_pan_allele_models_command.py | 7 +++---- test/test_train_pan_allele_models_command.py | 13 ++++++++++--- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py index 270acb4a..2b5a5212 100644 --- a/mhcflurry/train_allele_specific_models_command.py +++ b/mhcflurry/train_allele_specific_models_command.py @@ -20,7 +20,7 @@ import tqdm # progress bar tqdm.monitor_interval = 0 # see https://github.com/tqdm/tqdm/issues/481 from .class1_affinity_predictor import Class1AffinityPredictor -from .common import configure_logging, set_keras_backend +from .common import configure_logging from .parallelism import ( add_worker_pool_args, worker_pool_with_gpu_assignments_from_args, diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py index ae5bb8d8..56e4d60b 100644 --- a/mhcflurry/train_pan_allele_models_command.py +++ b/mhcflurry/train_pan_allele_models_command.py @@ -22,7 +22,7 @@ tqdm.monitor_interval = 0 # see https://github.com/tqdm/tqdm/issues/481 from .class1_affinity_predictor import Class1AffinityPredictor from .class1_neural_network import Class1NeuralNetwork -from .common import configure_logging, set_keras_backend +from .common import configure_logging from .parallelism import ( add_worker_pool_args, worker_pool_with_gpu_assignments_from_args, @@ -451,6 +451,7 @@ def train_model( replicate_num + 1, num_replicates)) + assert model.network() is None if hyperparameters.get("train_data", {}).get("pretrain", False): iterator = pretrain_data_iterator(pretrain_data_filename, allele_encoding) original_hyperparameters = dict(model.hyperparameters) @@ -530,11 +531,9 @@ def train_model( predictor.manifest_df.shape[0], len(predictor.class1_pan_allele_models)) predictor.clear_cache() - # Delete the network and release memory + # Delete the network to release memory model.update_network_description() # save weights and config model._network = None # release tensorflow network - K.clear_session() # release graph - return predictor diff --git a/test/test_train_pan_allele_models_command.py b/test/test_train_pan_allele_models_command.py index 397ccee6..870c3694 100644 --- a/test/test_train_pan_allele_models_command.py +++ b/test/test_train_pan_allele_models_command.py @@ -34,7 +34,7 @@ HYPERPARAMETERS_LIST = [ 'locally_connected_layers': [], 'loss': 'custom:mse_with_inequalities', 'max_epochs': 5, - 'minibatch_size': 128, + 'minibatch_size': 256, 'optimizer': 'rmsprop', 'output_activation': 'sigmoid', 'patience': 10, @@ -70,7 +70,7 @@ HYPERPARAMETERS_LIST = [ 'locally_connected_layers': [], 'loss': 'custom:mse_with_inequalities', 'max_epochs': 5, - 'minibatch_size': 128, + 'minibatch_size': 256, 'optimizer': 'rmsprop', 'output_activation': 'sigmoid', 'patience': 10, @@ -102,14 +102,21 @@ def run_and_check(n_jobs=0): with open(hyperparameters_filename, "w") as fd: json.dump(HYPERPARAMETERS_LIST, fd) + data_df = pandas.read_csv( + get_path("data_curated", "curated_training_data.no_mass_spec.csv.bz2")) + selected_data_df = data_df.loc[data_df.allele.str.startswith("HLA-A")] + selected_data_df.to_csv( + os.path.join(models_dir, "train_data.csv"), index=False) + args = [ "mhcflurry-class1-train-pan-allele-models", - "--data", get_path("data_curated", "curated_training_data.no_mass_spec.csv.bz2"), + "--data", os.path.join(models_dir, "train_data.csv"), "--allele-sequences", get_path("allele_sequences", "allele_sequences.csv"), "--hyperparameters", hyperparameters_filename, "--out-models-dir", models_dir, "--num-jobs", str(n_jobs), "--ensemble-size", "2", + "--verbosity", "1", ] print("Running with args: %s" % args) subprocess.check_call(args) -- GitLab