From c07d80cfeb45cd815aae91dc4d4bf356f387254b Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Thu, 5 Sep 2019 23:32:38 -0400
Subject: [PATCH] fix

---
 .travis.yml                                  | 2 +-
 mhcflurry/train_pan_allele_models_command.py | 2 ++
 test/test_train_pan_allele_models_command.py | 6 +++---
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index 05cb0e24..20711113 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -43,7 +43,7 @@ env:
     - KERAS_BACKEND=tensorflow
 script:
   # download data and models, then run tests
-  - mhcflurry-downloads fetch data_curated models_class1 models_class1_pan allele_sequences random_peptide_predictions
+  - mhcflurry-downloads fetch data_curated models_class1 models_class1_pan allele_sequences
   - mhcflurry-downloads info  # just to test this command works
   - nosetests test -sv
   - ./lint.sh
diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py
index a676bb38..a985539d 100644
--- a/mhcflurry/train_pan_allele_models_command.py
+++ b/mhcflurry/train_pan_allele_models_command.py
@@ -655,6 +655,8 @@ def train_model(
         # Use a smaller learning rate for training on real data
         learning_rate = model.fit_info[-1]["learning_rate"]
         model.hyperparameters['learning_rate'] = learning_rate / 10
+    else:
+        model = Class1NeuralNetwork(**hyperparameters)
 
     model.fit(
         peptides=train_peptides,
diff --git a/test/test_train_pan_allele_models_command.py b/test/test_train_pan_allele_models_command.py
index 2c600346..ea5261ec 100644
--- a/test/test_train_pan_allele_models_command.py
+++ b/test/test_train_pan_allele_models_command.py
@@ -89,7 +89,7 @@ HYPERPARAMETERS_LIST = [
     'random_negative_match_distribution': True,
     'random_negative_rate': 0.2,
     'train_data': {
-        "pretrain": True,
+        "pretrain": False,
         'pretrain_peptides_per_epoch': 128,
         'pretrain_max_epochs': 2,
         'pretrain_max_val_loss': 0.2,
@@ -121,8 +121,8 @@ def run_and_check(n_jobs=0, delete=True, additional_args=[]):
         "--num-jobs", str(n_jobs),
         "--ensemble-size", "2",
         "--verbosity", "1",
-        "--pretrain-data", get_path(
-            "random_peptide_predictions", "predictions.csv.bz2"),
+        # "--pretrain-data", get_path(
+        #      "random_peptide_predictions", "predictions.csv.bz2"),
     ] + additional_args
     print("Running with args: %s" % args)
     subprocess.check_call(args)
-- 
GitLab