From eeabbafdf6a7593f3aaba38a4e73ce112aaa6393 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Thu, 16 May 2019 11:45:06 -0400
Subject: [PATCH] Fix download name; configure pretraining via train_data hyperparameters

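Three related changes:

* GENERATE.sh: set DOWNLOAD_NAME to models_class1_pan (it previously
  pointed at the non-pan models_class1_unselected download).
* generate_hyperparameters.py: add pretraining settings to the
  'train_data' hyperparameters.
* train_pan_allele_models_command.py: gate pretraining on those
  hyperparameters instead of on train_model() keyword arguments, and
  fail early when pretraining is requested without --pretrain-data.

A hyperparameters entry can now enable pretraining directly. A minimal
sketch, using the defaults written by generate_hyperparameters.py
(pretrain_peptides_per_epoch is not consumed in this patch and is
presumably read elsewhere):

    'train_data': {
        'pretrain': True,                     # turn on the pretraining phase
        'pretrain_peptides_per_epoch': 1024,
        'pretrain_patience': 10,              # early-stopping patience while pretraining
    },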
---
 .../models_class1_pan_unselected/GENERATE.sh     |  2 +-
 .../generate_hyperparameters.py                  |  7 ++++++-
 mhcflurry/train_pan_allele_models_command.py     | 18 ++++++++++++------
 3 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/downloads-generation/models_class1_pan_unselected/GENERATE.sh b/downloads-generation/models_class1_pan_unselected/GENERATE.sh
index 8475d01d..8724ed3f 100755
--- a/downloads-generation/models_class1_pan_unselected/GENERATE.sh
+++ b/downloads-generation/models_class1_pan_unselected/GENERATE.sh
@@ -5,7 +5,7 @@
 set -e
 set -x
 
-DOWNLOAD_NAME=models_class1_unselected
+DOWNLOAD_NAME=models_class1_pan
 SCRATCH_DIR=${TMPDIR-/tmp}/mhcflurry-downloads-generation
 SCRIPT_ABSOLUTE_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/$(basename "${BASH_SOURCE[0]}")"
 SCRIPT_DIR=$(dirname "$SCRIPT_ABSOLUTE_PATH")
diff --git a/downloads-generation/models_class1_pan_unselected/generate_hyperparameters.py b/downloads-generation/models_class1_pan_unselected/generate_hyperparameters.py
index 6a9c2782..bfbd4432 100644
--- a/downloads-generation/models_class1_pan_unselected/generate_hyperparameters.py
+++ b/downloads-generation/models_class1_pan_unselected/generate_hyperparameters.py
@@ -39,7 +39,12 @@ base_hyperparameters = {
     'random_negative_distribution_smoothing': 0.0,
     'random_negative_match_distribution': True,
     'random_negative_rate': 0.2,
-    'train_data': {},
+    'train_data': {
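+        # Pretraining configuration, read in train_pan_allele_models_command.py.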
+        'pretrain': True,
+        'pretrain_peptides_per_epoch': 1024,
+        'pretrain_patience': 10,
+    },
     'validation_split': 0.1,
 }
 
diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py
index 1520502a..3bdd9963 100644
--- a/mhcflurry/train_pan_allele_models_command.py
+++ b/mhcflurry/train_pan_allele_models_command.py
@@ -291,6 +291,11 @@ def main(args):
         if args.max_epochs:
             hyperparameters['max_epochs'] = args.max_epochs
 
+        if hyperparameters.get("train_data", {}).get("pretrain", False):
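+            # Pretraining is enabled via hyperparameters, so the data file is mandatory.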
+            if not args.pretrain_data:
+                raise ValueError("--pretrain-data is required")
+
         for fold in range(args.ensemble_size):
             for replicate in range(args.num_replicates):
                 work_dict = {
@@ -386,11 +391,10 @@ def train_model(
         num_replicates,
         hyperparameters,
         pretrain_data_filename,
-        pretrain_patience=1,
-        verbose=None,
-        progress_print_interval=None,
-        predictor=None,
-        save_to=None):
+        verbose,
+        progress_print_interval,
+        predictor,
+        save_to):
 
     if predictor is None:
         predictor = Class1AffinityPredictor()
@@ -424,7 +428,7 @@ def train_model(
             replicate_num + 1,
             num_replicates))
 
-    if pretrain_data_filename:
+    if hyperparameters.get("train_data", {}).get("pretrain", False):
         iterator = pretrain_data_iterator(pretrain_data_filename, allele_encoding)
         original_hyperparameters = dict(model.hyperparameters)
         model.hyperparameters['minibatch_size'] = int(len(next(iterator)[-1]) / 100)
@@ -432,6 +436,8 @@ def train_model(
         model.hyperparameters['validation_split'] = 0.0
         model.hyperparameters['random_negative_rate'] = 0.0
         model.hyperparameters['random_negative_constant'] = 0
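+        # Patience for pretraining early stopping now comes from the hyperparameters.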
+        pretrain_patience = hyperparameters["train_data"]["pretrain_patience"]
         scores = []
         best_score = float('inf')
         best_score_epoch = 0
-- 
GitLab