From c749b93089ddf950200ef9d6106443446855638c Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Sat, 13 Jul 2019 12:58:34 -0400
Subject: [PATCH] Add min_delta to early stopping; retry pre-training on high val loss
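
Add a "min_delta" hyperparameter to early stopping: an epoch counts as
an improvement only when validation loss drops by more than min_delta
below the best value seen so far, i.e.

    val_loss < min_val_loss - min_delta

matching the semantics of keras.callbacks.EarlyStopping(min_delta=...).
The same threshold is threaded through fit_generator() for pre-training
as "pretrain_min_delta".

Also:
* retry pre-training (up to 10 attempts) when the final validation loss
  does not fall below the new "pretrain_max_val_loss" hyperparameter
* pass initial_epoch so the manual one-epoch-at-a-time fit loop reports
  the true epoch number
* print progress every 60 seconds in parallel runs instead of never
* re-enable the previously commented-out parallel/serial tests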

---
 .../generate_hyperparameters.py               |  3 ++
 mhcflurry/class1_neural_network.py            |  9 +++-
 mhcflurry/train_pan_allele_models_command.py  | 47 ++++++++++++++-----
 test/test_train_pan_allele_models_command.py  |  2 -
 4 files changed, 45 insertions(+), 16 deletions(-)

diff --git a/downloads-generation/models_class1_pan_unselected/generate_hyperparameters.py b/downloads-generation/models_class1_pan_unselected/generate_hyperparameters.py
index ccffc6b0..4ed6e35c 100644
--- a/downloads-generation/models_class1_pan_unselected/generate_hyperparameters.py
+++ b/downloads-generation/models_class1_pan_unselected/generate_hyperparameters.py
@@ -24,6 +24,7 @@ base_hyperparameters = {
     'optimizer': 'rmsprop',
     'output_activation': 'sigmoid',
     "patience": 20,
+    "min_delta": 0.0,
     'peptide_encoding': {
         'vector_encoding_name': 'BLOSUM62',
         'alignment_method': 'left_pad_centered_right_pad',
@@ -44,6 +45,8 @@ base_hyperparameters = {
         'pretrain_peptides_per_epoch': 1024,
         'pretrain_steps_per_epoch': 16,
         'pretrain_patience': 10,
+        'pretrain_min_delta': 0.0001,
+        'pretrain_max_val_loss': 0.10,
     },
     'validation_split': 0.1,
 }
diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index d1add2b7..3773f7d4 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -89,6 +89,7 @@ class Class1NeuralNetwork(object):
 
     early_stopping_hyperparameter_defaults = HyperparameterDefaults(
         patience=20,
+        min_delta=0.0,
     )
     """
     Hyperparameters for early stopping.
@@ -429,6 +430,7 @@ class Class1NeuralNetwork(object):
             steps_per_epoch=10,
             epochs=1000,
             patience=10,
+            min_delta=0.0,
             verbose=1):
         """
         Fit using a generator. Does not support many of the features of fit(),
@@ -532,6 +534,7 @@ class Class1NeuralNetwork(object):
             callbacks=[keras.callbacks.EarlyStopping(
                 monitor="val_loss",
                 patience=patience,
+                min_delta=min_delta,
                 verbose=verbose)]
         )
         for (key, value) in fit_history.history.items():
@@ -831,7 +834,8 @@ class Class1NeuralNetwork(object):
                 shuffle=True,
                 batch_size=self.hyperparameters['minibatch_size'],
                 verbose=verbose,
-                epochs=1,
+                epochs=i + 1,
+                initial_epoch=i,
                 validation_split=self.hyperparameters['validation_split'],
                 sample_weight=sample_weights_with_random_negatives)
 
@@ -857,7 +861,8 @@ class Class1NeuralNetwork(object):
                 val_loss = fit_info['val_loss'][-1]
                 val_losses.append(val_loss)
 
-                if min_val_loss is None or val_loss <= min_val_loss:
+                if min_val_loss is None or (
+                        val_loss < min_val_loss - self.hyperparameters['min_delta']):
                     min_val_loss = val_loss
                     min_val_loss_iteration = i
 
diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py
index b5aaf36e..1ab50a0f 100644
--- a/mhcflurry/train_pan_allele_models_command.py
+++ b/mhcflurry/train_pan_allele_models_command.py
@@ -332,7 +332,7 @@ def main(args):
                     'hyperparameters': hyperparameters,
                     'pretrain_data_filename': args.pretrain_data,
                     'verbose': args.verbosity,
-                    'progress_print_interval': None if not serial_run else 5.0,
+                    'progress_print_interval': 60.0 if not serial_run else 5.0,
                     'predictor': predictor if serial_run else None,
                     'save_to': args.out_models_dir if serial_run else None,
                 }
@@ -484,22 +484,45 @@ def train_model(
         generator = pretrain_data_iterator(pretrain_data_filename, allele_encoding)
         pretrain_patience = hyperparameters["train_data"].get(
             "pretrain_patience", 10)
+        pretrain_min_delta = hyperparameters["train_data"].get(
+            "pretrain_min_delta", 0.0)
         pretrain_steps_per_epoch = hyperparameters["train_data"].get(
             "pretrain_steps_per_epoch", 10)
         pretrain_max_epochs = hyperparameters["train_data"].get(
             "pretrain_max_epochs", 1000)
 
-        model.fit_generator(
-            generator,
-            validation_peptide_encoding=train_peptides,
-            validation_affinities=train_data.measurement_value.values,
-            validation_allele_encoding=train_alleles,
-            validation_inequalities=train_data.measurement_inequality.values,
-            patience=pretrain_patience,
-            steps_per_epoch=pretrain_steps_per_epoch,
-            epochs=pretrain_max_epochs,
-            verbose=verbose,
-        )
+        max_val_loss = hyperparameters["train_data"].get("pretrain_max_val_loss")
+
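+        # If configured, keep re-running pre-training until the final
+        # validation loss clears pretrain_max_val_loss (at most 10 attempts).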
+        attempt = 0
+        while True:
+            attempt += 1
+            print("Pre-training attempt %d" % attempt)
+            if attempt > 10:
+                print("Too many pre-training attempts! Stopping pretraining.")
+                break
+            model.fit_generator(
+                generator,
+                validation_peptide_encoding=train_peptides,
+                validation_affinities=train_data.measurement_value.values,
+                validation_allele_encoding=train_alleles,
+                validation_inequalities=train_data.measurement_inequality.values,
+                patience=pretrain_patience,
+                min_delta=pretrain_min_delta,
+                steps_per_epoch=pretrain_steps_per_epoch,
+                epochs=pretrain_max_epochs,
+                verbose=verbose,
+            )
+            if not max_val_loss:
+                break
+            if model.fit_info[-1]["val_loss"][-1] >= max_val_loss:
+                print("Val loss %f >= max val loss %f. Pre-training again." % (
+                    model.fit_info[-1]["val_loss"][-1], max_val_loss))
+            else:
+                print("Val loss %f < max val loss %f. Done pre-training." % (
+                    model.fit_info[-1]["val_loss"][-1], max_val_loss))
+                break
 
         # Use a smaller learning rate for training on real data
         learning_rate = model.fit_info[-1]["learning_rate"]
diff --git a/test/test_train_pan_allele_models_command.py b/test/test_train_pan_allele_models_command.py
index 1ecbe85f..f4255f76 100644
--- a/test/test_train_pan_allele_models_command.py
+++ b/test/test_train_pan_allele_models_command.py
@@ -144,7 +144,6 @@ def run_and_check(n_jobs=0, delete=True, additional_args=[]):
         print("Deleting: %s" % models_dir)
         shutil.rmtree(models_dir)
 
-"""
 if os.environ.get("KERAS_BACKEND") != "theano":
     def test_run_parallel():
         run_and_check(n_jobs=1)
@@ -153,7 +152,6 @@ if os.environ.get("KERAS_BACKEND") != "theano":
 
 def test_run_serial():
     run_and_check(n_jobs=0)
-"""
 
 
 def test_run_cluster_parallelism():
-- 
GitLab