From 11632aad78a96f436bd889a88b290785a19b5e38 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Wed, 19 Sep 2018 14:54:00 -0400
Subject: [PATCH] fixes

---
 mhcflurry/class1_neural_network.py | 6 +++++-
 requirements.txt                   | 1 +
 setup.py                           | 1 +
 3 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index adeb0850..7e3e667f 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -435,6 +435,7 @@ class Class1NeuralNetwork(object):
             sample_weights=None,
             shuffle_permutation=None,
             verbose=1,
+            progress_callback=None,
             progress_preamble="",
             progress_print_interval=5.0):
         """
@@ -739,6 +740,9 @@ class Class1NeuralNetwork(object):
                                     min_val_loss_iteration)).strip())
                         break
 
+            if progress_callback:
+                progress_callback()
+
         fit_info["time"] = time.time() - start
         fit_info["num_points"] = len(peptides)
         self.fit_info.append(dict(fit_info))
@@ -823,7 +827,7 @@ class Class1NeuralNetwork(object):
 
         from keras.layers import Input
         import keras.layers
-        from keras.layers.core import Dense, Flatten, Reshape, Dropout
+        from keras.layers.core import Dense, Flatten, Dropout
         from keras.layers.embeddings import Embedding
         from keras.layers.normalization import BatchNormalization
 
diff --git a/requirements.txt b/requirements.txt
index da0ceda3..f9de6c97 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,4 @@ scikit-learn
 mhcnames
 pyyaml
 tqdm
+np_utils
\ No newline at end of file
diff --git a/setup.py b/setup.py
index fe12d963..97224549 100644
--- a/setup.py
+++ b/setup.py
@@ -58,6 +58,7 @@ if __name__ == '__main__':
         'mhcnames',
         'pyyaml',
         'tqdm',
+        'np_utils',
     ]
     if PY2:
         # concurrent.futures is a standard library in Py3 but Py2
-- 
GitLab