From 79060d38d6b61aff01dfc7bf20b4a3795d8d4645 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Thu, 15 Feb 2018 15:33:54 -0500
Subject: [PATCH] test fixes

---
 mhcflurry/class1_neural_network.py            | 24 ++++++++-----------
 test/test_class1_neural_network.py            |  3 +++
 ...st_train_allele_specific_models_command.py |  2 +-
 3 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 6ff042aa..3aa75abc 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -163,10 +163,7 @@ class Class1NeuralNetwork(object):
         self.network_weights = None
         self.network_weights_loader = None
 
-        self.loss_history = None
-        self.fit_seconds = None
-        self.fit_num_points = []
-
+        self.fit_info = []
         self.prediction_cache = weakref.WeakKeyDictionary()
 
     KERAS_MODELS_CACHE = {}
@@ -310,7 +307,6 @@ class Class1NeuralNetwork(object):
         """
         config = dict(config)
         instance = cls(**config.pop('hyperparameters'))
-        assert all(hasattr(instance, key) for key in config), config.keys()
         instance.__dict__.update(config)
         instance.network_weights = weights
         instance.network_weights_loader = weights_loader
@@ -471,9 +467,6 @@ class Class1NeuralNetwork(object):
             How often (in seconds) to print progress update. Set to None to
             disable.
         """
-
-        self.fit_num_points.append(len(peptides))
-
         encodable_peptides = EncodableSequences.create(peptides)
         peptide_encoding = self.peptides_to_network_input(encodable_peptides)
 
@@ -629,7 +622,7 @@ class Class1NeuralNetwork(object):
         min_val_loss_iteration = None
         min_val_loss = None
 
-        self.loss_history = collections.defaultdict(list)
+        fit_info = collections.defaultdict(list)
         start = time.time()
         last_progress_print = None
         x_dict_with_random_negatives = {}
@@ -692,7 +685,7 @@ class Class1NeuralNetwork(object):
                 sample_weight=sample_weights_with_random_negatives)
 
             for (key, value) in fit_history.history.items():
-                self.loss_history[key].extend(value)
+                fit_info[key].extend(value)
 
             # Print progress no more often than once every few seconds.
             if progress_print_interval is not None and (
@@ -704,13 +697,13 @@ class Class1NeuralNetwork(object):
                        "Min val loss (%s) at epoch %s" % (
                            i,
                            self.hyperparameters['max_epochs'],
-                           self.loss_history['loss'][-1],
+                           fit_info['loss'][-1],
                            str(min_val_loss),
                            min_val_loss_iteration)).strip())
                 last_progress_print = time.time()
 
             if self.hyperparameters['validation_split']:
-                val_loss = self.loss_history['val_loss'][-1]
+                val_loss = fit_info['val_loss'][-1]
                 val_losses.append(val_loss)
 
                 if min_val_loss is None or val_loss <= min_val_loss:
@@ -728,11 +721,14 @@ class Class1NeuralNetwork(object):
                                 "Min val loss (%s) at epoch %s" % (
                                     i,
                                     self.hyperparameters['max_epochs'],
-                                    self.loss_history['loss'][-1],
+                                    fit_info['loss'][-1],
                                     str(min_val_loss),
                                     min_val_loss_iteration)).strip())
                         break
-        self.fit_seconds = time.time() - start
+
+        fit_info["time"] = time.time() - start
+        fit_info["num_points"] = len(peptides)
+        self.fit_info.append(dict(fit_info))
 
     def predict(self, peptides, allele_encoding=None, batch_size=4096):
         """
diff --git a/test/test_class1_neural_network.py b/test/test_class1_neural_network.py
index b082f28c..8ba330cf 100644
--- a/test/test_class1_neural_network.py
+++ b/test/test_class1_neural_network.py
@@ -89,9 +89,12 @@ def test_inequalities():
     # Memorize the dataset.
     hyperparameters = dict(
         loss="custom:mse_with_inequalities",
+        peptide_amino_acid_encoding="one-hot",
         activation="tanh",
         layer_sizes=[16],
         max_epochs=50,
+        minibatch_size=32,
+        random_negative_rate=0.0,
         early_stopping=False,
         validation_split=0.0,
         locally_connected_layers=[
diff --git a/test/test_train_allele_specific_models_command.py b/test/test_train_allele_specific_models_command.py
index fab19ce7..a1e597ca 100644
--- a/test/test_train_allele_specific_models_command.py
+++ b/test/test_train_allele_specific_models_command.py
@@ -12,7 +12,7 @@ from mhcflurry.downloads import get_path
 HYPERPARAMETERS = [
     {
         "n_models": 2,
-        "max_epochs": 2,
+        "max_epochs": 20,
         "patience": 10,
         "early_stopping": True,
         "validation_split": 0.2,
-- 
GitLab