diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index c38a0b77c789443b7416414fd02f13efc26315ff..ad1c85f8f346bcafe46569ba2cf7db79e42122f7 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -790,8 +790,9 @@ class Class1NeuralNetwork(object):
             self.prediction_cache[peptides] = result
         return result
 
-    def make_allele_subnetwork(allele_sequence_layer):
-        return allele_sequence_layer
+    def make_allele_subnetwork(self, allele_sequence_layer):
+        from keras.layers.core import Flatten
+        return Flatten(name="allele_flat")(allele_sequence_layer)
 
     def make_network(
             self,
@@ -895,8 +896,6 @@ class Class1NeuralNetwork(object):
 
             allele_layer = self.make_allele_subnetwork(allele_layer)
 
-            allele_layer = Flatten(name="allele_flat")(allele_layer)
-
             for (i, layer_size) in enumerate(allele_dense_layer_sizes):
                 allele_layer = Dense(
                     layer_size,