From b910d1e70368329a321354ce100be7808ff666ab Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Sat, 14 Sep 2019 19:33:41 -0400
Subject: [PATCH] fix

---
 mhcflurry/class1_affinity_predictor.py | 13 ++++++++++++-
 mhcflurry/class1_neural_network.py     | 14 ++++++++++----
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py
index 4eacb9a1..57637888 100644
--- a/mhcflurry/class1_affinity_predictor.py
+++ b/mhcflurry/class1_affinity_predictor.py
@@ -360,13 +360,24 @@ class Class1AffinityPredictor(object):
 
         sub_manifest_df = self.manifest_df.loc[
             self.manifest_df.model_name.isin(model_names_to_write)
-        ]
+        ].copy()
 
+        # Network JSON configs may have changed since the models were added,
+        # for example due to changes to the allele representation layer.
+        # So we update the JSON configs here also.
+        updated_network_config_jsons = []
         for (_, row) in sub_manifest_df.iterrows():
+            updated_network_config_jsons.append(
+                json.dumps(row.model.get_config()))
             weights_path = self.weights_path(models_dir, row.model_name)
             Class1AffinityPredictor.save_weights(
                 row.model.get_weights(), weights_path)
             logging.info("Wrote: %s", weights_path)
+        sub_manifest_df["config_json"] = updated_network_config_jsons
+        self.manifest_df.loc[
+            sub_manifest_df.index,
+            "config_json"
+        ] = updated_network_config_jsons
 
         write_manifest_df = self.manifest_df[[
             c for c in self.manifest_df.columns if c != "model"
diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 411ba54d..78e7ad69 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -1369,7 +1369,8 @@ class Class1NeuralNetwork(object):
         layer = original_model.get_layer("allele_representation")
         existing_weights_shape = (layer.input_dim, layer.output_dim)
         self.set_allele_representations(
-            numpy.zeros(shape=(0,) + existing_weights_shape.shape[1:]))
+            numpy.zeros(shape=(0,) + existing_weights_shape[1:]),
+            force_surgery=True)
 
 
     def set_allele_representations(self, allele_representations, force_surgery=False):
@@ -1397,15 +1398,18 @@ class Class1NeuralNetwork(object):
         import keras.backend as K
         import tensorflow as tf
 
-        reshaped = allele_representations.reshape(
-            (allele_representations.shape[0], -1))
+        reshaped = allele_representations.reshape((
+            allele_representations.shape[0],
+            numpy.product(allele_representations.shape[1:])
+        ))
         original_model = self.network()
         layer = original_model.get_layer("allele_representation")
         existing_weights_shape = (layer.input_dim, layer.output_dim)
 
         # Only changes to the number of supported alleles (not the length of
         # the allele sequences) are allowed.
-        assert existing_weights_shape[1:] == reshaped.shape[1:]
+        assert existing_weights_shape[1:] == reshaped.shape[1:], (
+            existing_weights_shape, reshaped.shape)
 
         if existing_weights_shape[0] > reshaped.shape[0] and not force_surgery:
             # Extend with NaNs so we can avoid having to reshape the weights
@@ -1439,6 +1443,8 @@ class Class1NeuralNetwork(object):
             def throw(*args, **kwargs):
                 raise RuntimeError("Using a disabled model!")
             original_model.predict = \
+                original_model.to_json = \
+                original_model.get_weights = \
                 original_model.fit = \
                 original_model.fit_generator = throw
 
-- 
GitLab