From 8796764ca129a9e29e98a38e8c6d31c864d118ae Mon Sep 17 00:00:00 2001
From: Alex Rubinsteyn <alex.rubinsteyn@gmail.com>
Date: Thu, 22 Oct 2015 18:42:13 -0400
Subject: [PATCH] upgrading keras, trying to figure out why output changes

---
 mhcflurry/feedforward.py            | 2 ++
 mhcflurry/mhc1_binding_predictor.py | 3 +++
 2 files changed, 5 insertions(+)

diff --git a/mhcflurry/feedforward.py b/mhcflurry/feedforward.py
index fa37df23..7f09c73d 100644
--- a/mhcflurry/feedforward.py
+++ b/mhcflurry/feedforward.py
@@ -81,11 +81,13 @@ def make_network(
         model.add(Embedding(
             input_dim=embedding_input_dim,
             output_dim=embedding_output_dim,
+            input_length=input_size,
             init=init))
         model.add(Flatten())
         input_size = input_size * embedding_output_dim
 
     layer_sizes = (input_size,) + tuple(layer_sizes)
+
     for i, dim in enumerate(layer_sizes):
         if i == 0:
             # input is only conceptually a layer of the network,
diff --git a/mhcflurry/mhc1_binding_predictor.py b/mhcflurry/mhc1_binding_predictor.py
index baa567a9..983a8293 100644
--- a/mhcflurry/mhc1_binding_predictor.py
+++ b/mhcflurry/mhc1_binding_predictor.py
@@ -59,6 +59,7 @@ class Mhc1BindingPredictor(object):
         else:
             filename = self.allele + ".hdf"
             path = join(model_directory, filename)
+            print("HDF path: %s" % path)
             if not exists(path):
                 raise ValueError("Unsupported allele: %s" % (
                     original_allele_name,))
@@ -71,7 +72,9 @@ class Mhc1BindingPredictor(object):
                 init=INITIALIZATION_METHOD,
                 dropout_probability=DROPOUT_PROBABILITY,
                 compile_for_training=True)
+            print("before", len(self.model.get_weights()), self.model.get_weights()[0][0])
             self.model.load_weights(path)
+            print("after", len(self.model.get_weights()), self.model.get_weights()[0][0])
             _allele_model_cache[self.allele] = self.model
 
     def __repr__(self):
-- 
GitLab