diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py index c38a0b77c789443b7416414fd02f13efc26315ff..ad1c85f8f346bcafe46569ba2cf7db79e42122f7 100644 --- a/mhcflurry/class1_neural_network.py +++ b/mhcflurry/class1_neural_network.py @@ -790,8 +790,9 @@ class Class1NeuralNetwork(object): self.prediction_cache[peptides] = result return result - def make_allele_subnetwork(allele_sequence_layer): - return allele_sequence_layer + def make_allele_subnetwork(self, allele_sequence_layer): + from keras.layers.core import Flatten + return Flatten(name="allele_flat")(allele_sequence_layer) def make_network( self, @@ -895,8 +896,6 @@ class Class1NeuralNetwork(object): allele_layer = self.make_allele_subnetwork(allele_layer) - allele_layer = Flatten(name="allele_flat")(allele_layer) - for (i, layer_size) in enumerate(allele_dense_layer_sizes): allele_layer = Dense( layer_size,