From d39b3269c3816ad993e79d2bb57b9b3a8a1804b1 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Tue, 19 Jun 2018 14:54:23 -0400
Subject: [PATCH] fix

---
 mhcflurry/class1_neural_network.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index bbf3e18e..c38a0b77 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -790,9 +790,11 @@ class Class1NeuralNetwork(object):
             self.prediction_cache[peptides] = result
         return result
 
-    @staticmethod
+    def make_allele_subnetwork(allele_sequence_layer):
+        return allele_sequence_layer
+
     def make_network(
-            allele_representations,
+            self,
             kmer_size,
             peptide_amino_acid_encoding,
             embedding_input_dim,
@@ -810,7 +812,8 @@ class Class1NeuralNetwork(object):
             dropout_probability,
             batch_normalization,
             embedding_init_method,
-            locally_connected_layers):
+            locally_connected_layers,
+            allele_representations=None):
         """
         Helper function to make a keras network for class1 affinity prediction.
         """
@@ -889,6 +892,9 @@ class Class1NeuralNetwork(object):
             allele_layer = Reshape(
                 target_shape=allele_representations.shape[1:],
                 name="allele_reshaped")(allele_representation)
+
+            allele_layer = self.make_allele_subnetwork(allele_layer)
+
             allele_layer = Flatten(name="allele_flat")(allele_layer)
 
             for (i, layer_size) in enumerate(allele_dense_layer_sizes):
-- 
GitLab