From c07d80cfeb45cd815aae91dc4d4bf356f387254b Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Thu, 5 Sep 2019 23:32:38 -0400 Subject: [PATCH] fix --- .travis.yml | 2 +- mhcflurry/train_pan_allele_models_command.py | 2 ++ test/test_train_pan_allele_models_command.py | 6 +++--- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index 05cb0e24..20711113 100644 --- a/.travis.yml +++ b/.travis.yml @@ -43,7 +43,7 @@ env: - KERAS_BACKEND=tensorflow script: # download data and models, then run tests - - mhcflurry-downloads fetch data_curated models_class1 models_class1_pan allele_sequences random_peptide_predictions + - mhcflurry-downloads fetch data_curated models_class1 models_class1_pan allele_sequences - mhcflurry-downloads info # just to test this command works - nosetests test -sv - ./lint.sh diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py index a676bb38..a985539d 100644 --- a/mhcflurry/train_pan_allele_models_command.py +++ b/mhcflurry/train_pan_allele_models_command.py @@ -655,6 +655,8 @@ def train_model( # Use a smaller learning rate for training on real data learning_rate = model.fit_info[-1]["learning_rate"] model.hyperparameters['learning_rate'] = learning_rate / 10 + else: + model = Class1NeuralNetwork(**hyperparameters) model.fit( peptides=train_peptides, diff --git a/test/test_train_pan_allele_models_command.py b/test/test_train_pan_allele_models_command.py index 2c600346..ea5261ec 100644 --- a/test/test_train_pan_allele_models_command.py +++ b/test/test_train_pan_allele_models_command.py @@ -89,7 +89,7 @@ HYPERPARAMETERS_LIST = [ 'random_negative_match_distribution': True, 'random_negative_rate': 0.2, 'train_data': { - "pretrain": True, + "pretrain": False, 'pretrain_peptides_per_epoch': 128, 'pretrain_max_epochs': 2, 'pretrain_max_val_loss': 0.2, @@ -121,8 +121,8 @@ def run_and_check(n_jobs=0, delete=True, additional_args=[]): "--num-jobs", str(n_jobs), "--ensemble-size", "2", "--verbosity", "1", - "--pretrain-data", get_path( - "random_peptide_predictions", "predictions.csv.bz2"), + # "--pretrain-data", get_path( + # "random_peptide_predictions", "predictions.csv.bz2"), ] + additional_args print("Running with args: %s" % args) subprocess.check_call(args) -- GitLab