From 182551af1dba6a04a8ce737d996641c1d71cae5b Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Mon, 15 Jul 2019 17:34:19 -0400
Subject: [PATCH] Log pretraining progress and validate train_data params

Print the cumulative number of training points yielded by the generator
at the end of each epoch, via a LambdaCallback. Read "train_data"
hyperparameters through a get_train_param helper that logs each value
used and reports any leftover (unconsumed) keys. Make the pretraining
chunk size configurable through a new pretrain_peptides_per_step
parameter, fix a comment typo, and update the expensive pretraining
test's hyperparameters.

---
 mhcflurry/class1_neural_network.py            | 19 +++++++---
 .../data_dependent_weights_initialization.py  |  2 +-
 mhcflurry/train_pan_allele_models_command.py  | 36 ++++++++++++++-----
 test/expensive_test_pretrain_optimizable.py   |  9 ++---
 4 files changed, 47 insertions(+), 19 deletions(-)

diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index ec5feb07..40efdbc2 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -569,6 +569,12 @@ class Class1NeuralNetwork(object):
                 verbose=verbose)
             iterator = itertools.chain([first_chunk], iterator)
 
+        def progress_update(epoch, logs):
+            if verbose:
+                print(
+                    "Cumulative training points:",
+                    mutable_generator_state['yielded_values'])
+
         fit_history = network.fit_generator(
             iterator,
             steps_per_epoch=steps_per_epoch,
@@ -577,11 +583,14 @@ class Class1NeuralNetwork(object):
             workers=1,
             validation_data=(validation_x_dict, validation_y_dict),
             verbose=verbose,
-            callbacks=[keras.callbacks.EarlyStopping(
-                monitor="val_loss",
-                patience=patience,
-                min_delta=min_delta,
-                verbose=verbose)]
+            callbacks=[
+                keras.callbacks.EarlyStopping(
+                    monitor="val_loss",
+                    patience=patience,
+                    min_delta=min_delta,
+                    verbose=verbose),
+                keras.callbacks.LambdaCallback(on_epoch_end=progress_update),
+            ]
         )
         for (key, value) in fit_history.history.items():
             fit_info[key].extend(value)
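
Note on the hunks above: fit_generator consumes a wrapped iterator that
updates the shared mutable_generator_state dict as it yields batches, and
the new LambdaCallback reads that state after every epoch. A minimal,
self-contained sketch of the pattern (the counting wrapper and variable
names below are illustrative, not mhcflurry's exact code):

    import keras

    state = {"yielded_values": 0}  # shared counter, mutated by the wrapper

    def counting_iterator(inner_iterator):
        # Wrap a (x, y) batch iterator, accumulating the number of
        # training points it has yielded so far.
        for (x_dict, y_dict) in inner_iterator:
            state["yielded_values"] += len(next(iter(y_dict.values())))
            yield (x_dict, y_dict)

    progress_update = keras.callbacks.LambdaCallback(
        on_epoch_end=lambda epoch, logs: print(
            "Cumulative training points:", state["yielded_values"]))

    # Passed alongside EarlyStopping, e.g.:
    #   network.fit_generator(counting_iterator(batches), ...,
    #       callbacks=[early_stopping, progress_update])
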
diff --git a/mhcflurry/data_dependent_weights_initialization.py b/mhcflurry/data_dependent_weights_initialization.py
index 165c1e6d..c6cfd6c8 100644
--- a/mhcflurry/data_dependent_weights_initialization.py
+++ b/mhcflurry/data_dependent_weights_initialization.py
@@ -35,7 +35,7 @@ import numpy
 
 
 def svd_orthonormal(shape):
-    # Orthonorm init code is from Lasagne
+    # Orthonormal init code is from Lasagne
     # https://github.com/Lasagne/Lasagne/blob/master/lasagne/init.py
     if len(shape) < 2:
         raise RuntimeError("Only shapes of length 2 or more are supported.")
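
For reference, the Lasagne-derived orthonormal initialization that the
corrected comment points to takes the SVD of a Gaussian random matrix and
keeps the orthonormal factor matching the requested shape. A sketch
consistent with the linked Lasagne code (mhcflurry's actual function body
may differ in details):

    import numpy

    def svd_orthonormal(shape):
        # Orthonormal init code is from Lasagne
        # https://github.com/Lasagne/Lasagne/blob/master/lasagne/init.py
        if len(shape) < 2:
            raise RuntimeError("Only shapes of length 2 or more are supported.")
        flat_shape = (shape[0], int(numpy.prod(shape[1:])))
        a = numpy.random.standard_normal(flat_shape)
        u, _, v = numpy.linalg.svd(a, full_matrices=False)
        # u has orthonormal columns, v orthonormal rows; keep the one
        # whose shape matches.
        q = u if u.shape == flat_shape else v
        return q.reshape(shape)
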
diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py
index 1fcc9c04..ca5a5a75 100644
--- a/mhcflurry/train_pan_allele_models_command.py
+++ b/mhcflurry/train_pan_allele_models_command.py
@@ -475,17 +475,32 @@ def train_model(
     print("%s [pid %d]. Hyperparameters:" % (progress_preamble, os.getpid()))
     pprint.pprint(hyperparameters)
 
-    if hyperparameters.get("train_data", {}).get("pretrain", False):
-        pretrain_patience = hyperparameters["train_data"].get(
-            "pretrain_patience", 10)
-        pretrain_min_delta = hyperparameters["train_data"].get(
-            "pretrain_min_delta", 0.0)
-        pretrain_steps_per_epoch = hyperparameters["train_data"].get(
+    train_params = dict(hyperparameters.get("train_data", {}))
+
+    def get_train_param(param, default):
+        if param in train_params:
+            result = train_params.pop(param)
+            if verbose:
+                print("Train param", param, "=", result)
+        else:
+            result = default
+            if verbose:
+                print("Train param", param, "=", result, "[default]")
+        return result
+
+    if get_train_param("pretrain", False):
+        pretrain_patience = get_train_param("pretrain_patience", 10)
+        pretrain_min_delta = get_train_param("pretrain_min_delta", 0.0)
+        pretrain_steps_per_epoch = get_train_param(
             "pretrain_steps_per_epoch", 10)
-        pretrain_max_epochs = hyperparameters["train_data"].get(
+        pretrain_max_epochs = get_train_param(
             "pretrain_max_epochs", 1000)
+        pretrain_peptides_per_step = get_train_param(
+            "pretrain_peptides_per_step", 1024)
+        max_val_loss = get_train_param("pretrain_max_val_loss", None)
 
-        max_val_loss = hyperparameters["train_data"].get("pretrain_max_val_loss")
+        if verbose:
+            print("Unused train params", train_params)
 
         attempt = 0
         while True:
@@ -498,7 +513,10 @@ def train_model(
             model = Class1NeuralNetwork(**hyperparameters)
             assert model.network() is None
             generator = pretrain_data_iterator(
-                pretrain_data_filename, allele_encoding)
+                pretrain_data_filename,
+                allele_encoding,
+                peptides_per_chunk=pretrain_peptides_per_step)
+
             model.fit_generator(
                 generator,
                 validation_peptide_encoding=train_peptides,
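
The get_train_param helper introduced above pops each key out of a copy of
the "train_data" dict, so anything left over after all parameters are read
was never consumed and is printed as unused. That surfaces misspelled
hyperparameter names instead of letting them silently fall back to
defaults. An illustrative example (the typo'd key is hypothetical):

    train_params = dict({"pretrain": True, "pretrain_patienc": 5})  # typo'd key
    assert train_params.pop("pretrain", False)
    patience = train_params.pop("pretrain_patience", 10)  # silently defaults to 10
    print("Unused train params", train_params)  # {'pretrain_patienc': 5}

The second hunk threads the new pretrain_peptides_per_step value into
pretrain_data_iterator as its peptides_per_chunk argument, making the
generator's chunk size tunable from the hyperparameters.
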
diff --git a/test/expensive_test_pretrain_optimizable.py b/test/expensive_test_pretrain_optimizable.py
index c2b0e461..027a88bb 100644
--- a/test/expensive_test_pretrain_optimizable.py
+++ b/test/expensive_test_pretrain_optimizable.py
@@ -33,7 +33,7 @@ FOLDS_DF["fold_0"] = True
 HYPERPARAMTERS = {
     'activation': 'tanh', 'allele_dense_layer_sizes': [],
     'batch_normalization': False,
-    'dense_layer_l1_regularization': 9.999999999999999e-11,
+    'dense_layer_l1_regularization': 0.0,
     'dense_layer_l2_regularization': 0.0, 'dropout_probability': 0.5,
     'early_stopping': True, 'init': 'glorot_uniform',
     'layer_sizes': [1024, 512], 'learning_rate': None,
@@ -50,9 +50,10 @@ HYPERPARAMTERS = {
     'random_negative_distribution_smoothing': 0.0,
     'random_negative_match_distribution': True, 'random_negative_rate': 0.2,
     'train_data': {'pretrain': True,
-                   'pretrain_max_epochs': 3,
-                   'pretrain_peptides_per_epoch': 1024,
-                   'pretrain_steps_per_epoch': 16},
+                   'pretrain_max_epochs': 30,
+                   'pretrain_patience': 5,
+                   'pretrain_peptides_per_step': 32,
+                   'pretrain_steps_per_epoch': 256},
     'validation_split': 0.1,
     'data_dependent_initialization_method': "lsuv",
 }
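
On the test hyperparameters: the old pretrain_peptides_per_epoch key does
not correspond to any parameter read in the training command hunks above,
so under the new get_train_param scheme it would be reported as unused; the
renamed pretrain_peptides_per_step is the value actually consumed. Rough
arithmetic on the new settings (illustrative only):

    peptides_per_step = 32
    steps_per_epoch = 256
    print(peptides_per_step * steps_per_epoch)  # 8192 peptides per epoch
    print(8192 * 30)  # at most 245760 peptides over pretrain_max_epochs=30
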
-- 
GitLab