From 30cc445179b8b51b854f2b7d254f549ce92e761e Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Thu, 5 Dec 2019 16:50:28 -0500
Subject: [PATCH] Implement
 Class1AffinityNeuralNetwork.copy_weights_to_affinity_model

---
 mhcflurry/class1_neural_network.py            |  1 -
 .../class1_presentation_neural_network.py     | 40 ++++++++++++++-----
 setup.py                                      |  2 +-
 test/test_class1_presentation_predictor.py    | 23 ++++++++++-
 4 files changed, 51 insertions(+), 15 deletions(-)

diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 1b5f2eca..4829449e 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -1384,7 +1384,6 @@ class Class1NeuralNetwork(object):
             numpy.zeros(shape=(0,) + existing_weights_shape[1:]),
             force_surgery=True)
 
-
     def set_allele_representations(self, allele_representations, force_surgery=False):
         """
         Set the allele representations in use by this model. This means mutating
diff --git a/mhcflurry/class1_presentation_neural_network.py b/mhcflurry/class1_presentation_neural_network.py
index 503fa4bc..d3a69e93 100644
--- a/mhcflurry/class1_presentation_neural_network.py
+++ b/mhcflurry/class1_presentation_neural_network.py
@@ -89,7 +89,7 @@ class Class1PresentationNeuralNetwork(object):
         self.fit_info = []
         self.allele_representation_hash = None
 
-    def load_from_class1_neural_network(self, class1_neural_network):
+    def load_from_class1_neural_network(self, model):
         import keras.backend as K
         from keras.layers import (
             Input,
@@ -107,11 +107,11 @@ class Class1PresentationNeuralNetwork(object):
         from keras.models import Model
         from keras.initializers import Zeros
 
-        if isinstance(class1_neural_network, Class1NeuralNetwork):
-            class1_neural_network = class1_neural_network.network()
+        assert isinstance(model, Class1NeuralNetwork), model
+        affinity_network = model.network()
 
         peptide_shape = tuple(
-            int(x) for x in K.int_shape(class1_neural_network.inputs[0])[1:])
+            int(x) for x in K.int_shape(affinity_network.inputs[0])[1:])
 
         input_alleles = Input(
             shape=(self.hyperparameters['max_alleles'],), name="allele")
@@ -138,7 +138,7 @@ class Class1PresentationNeuralNetwork(object):
             [peptides_repeated, allele_flat], name="allele_peptide_merged")
 
         layer_names = [
-            layer.name for layer in class1_neural_network.layers
+            layer.name for layer in affinity_network.layers
         ]
 
         pan_allele_layer_initial_names = [
@@ -153,7 +153,7 @@ class Class1PresentationNeuralNetwork(object):
         assert startswith(
             layer_names, pan_allele_layer_initial_names), layer_names
 
-        layers = class1_neural_network.layers[
+        layers = affinity_network.layers[
             pan_allele_layer_initial_names.index(
                 "allele_peptide_merged") + 1:
         ]
@@ -270,7 +270,13 @@ class Class1PresentationNeuralNetwork(object):
             ],
             name="presentation",
         )
-        self.network.summary()
+
+    def copy_weights_to_affinity_model(self, model):
+        # We assume that the other model's layers are a prefix of ours.
+        self.clear_allele_representations()
+        model.clear_allele_representations()
+        model.network().set_weights(
+            self.get_weights()[:len(model.get_weights())])
 
     def peptides_to_network_input(self, peptides):
         """
@@ -659,15 +665,27 @@ class Class1PresentationNeuralNetwork(object):
             self.network.predict(x_dict, batch_size=batch_size))
         return predictions
 
-    def set_allele_representations(self, allele_representations):
+    def clear_allele_representations(self):
+        """
+        Set allele representations to an empty array. Useful before saving to
+        save a smaller version of the model.
+        """
+        layer = self.network.get_layer("allele_representation")
+        existing_weights_shape = (layer.input_dim, layer.output_dim)
+        self.set_allele_representations(
+            numpy.zeros(shape=(0,) + existing_weights_shape[1:]),
+            force_surgery=True)
+
+    def set_allele_representations(self, allele_representations, force_surgery=False):
         """
         """
         from keras.models import clone_model
         import keras.backend as K
         import tensorflow as tf
 
-        reshaped = allele_representations.reshape(
-            (allele_representations.shape[0], -1))
+        reshaped = allele_representations.reshape((
+            allele_representations.shape[0],
+            numpy.product(allele_representations.shape[1:])))
         original_model = self.network
 
         layer = original_model.get_layer("allele_representation")
@@ -677,7 +695,7 @@ class Class1PresentationNeuralNetwork(object):
         # the allele sequences) are allowed.
         assert existing_weights_shape[1:] == reshaped.shape[1:]
 
-        if existing_weights_shape[0] > reshaped.shape[0]:
+        if existing_weights_shape[0] > reshaped.shape[0] and not force_surgery:
             # Extend with NaNs so we can avoid having to reshape the weights
             # matrix, which is expensive.
             reshaped = numpy.append(
diff --git a/setup.py b/setup.py
index 78809cd4..9275775e 100644
--- a/setup.py
+++ b/setup.py
@@ -53,7 +53,7 @@ if __name__ == '__main__':
         'pandas>=0.20.3',
         'Keras>=2.2.5',
         'appdirs',
-        'tensorflow>=1.1.0,<2.0.0',
+        'tensorflow>=1.15.0,<2.0.0',
         'scikit-learn',
         'mhcnames',
         'pyyaml',
diff --git a/test/test_class1_presentation_predictor.py b/test/test_class1_presentation_predictor.py
index b54bee2a..eeb574d3 100644
--- a/test/test_class1_presentation_predictor.py
+++ b/test/test_class1_presentation_predictor.py
@@ -99,7 +99,7 @@ def test_basic():
             optimizer="adam",
             random_negative_rate=0.0,
             random_negative_constant=0,
-            max_epochs=100,
+            max_epochs=25,
             learning_rate=0.001,
             batch_generator_batch_size=256)
         presentation_network.load_from_class1_neural_network(affinity_network)
@@ -190,6 +190,9 @@ def test_basic():
     train_df["updated_score"] = new_predictor.predict(
         train_df.peptide.values,
         alleles=["HLA-A*02:20"])
+    train_df["updated_affinity"] = new_predictor.predict_to_dataframe(
+        train_df.peptide.values,
+        alleles=["HLA-A*02:20"]).affinity.values
     train_df["score_diff"] = train_df.updated_score - train_df.original_score
     mean_change = train_df.groupby("label").score_diff.mean()
     print("Mean change:")
@@ -203,6 +206,22 @@ def test_basic():
         train_df.pre_train_affinity_prediction.values,
         train_df.post_train_affinity_prediction.values)
 
+    (affinity_model,) = affinity_predictor.class1_pan_allele_models
+    model.copy_weights_to_affinity_model(affinity_model)
+    train_df["post_copy_weights_prediction"] = affinity_predictor.predict(
+        train_df.peptide.values, alleles=train_df.allele.values)
+    assert_allclose(
+        train_df.updated_affinity.values,
+        train_df.post_copy_weights_prediction.values,
+        rtol=1e-5)
+    train_df["affinity_diff"] = (
+        train_df.post_copy_weights_prediction -
+        train_df.post_train_affinity_prediction)
+    median_affinity_change = train_df.groupby("label").affinity_diff.median()
+    print("Median affinity change:")
+    print(median_affinity_change)
+    assert_less(median_affinity_change[1.0], median_affinity_change[0.0])
+
 
 def scramble_peptide(peptide):
     lst = list(peptide)
@@ -235,7 +254,7 @@ def evaluate_loss(loss, y_true, y_pred):
         raise ValueError("Unsupported backend: %s" % K.backend())
 
 
-def test_loss():
+def Xtest_loss():
     for delta in [0.0, 0.3]:
         print("delta", delta)
         # Hit labels
-- 
GitLab