diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py index 7cbe72f03dcb0785876e2edb37ed1a2944118b9c..bbf3e18eb9d455b6b18bcecc25c1e743cd7e3b84 100644 --- a/mhcflurry/class1_neural_network.py +++ b/mhcflurry/class1_neural_network.py @@ -820,7 +820,7 @@ class Class1NeuralNetwork(object): from keras.layers import Input import keras.layers - from keras.layers.core import Dense, Flatten, Dropout + from keras.layers.core import Dense, Flatten, Reshape, Dropout from keras.layers.embeddings import Embedding from keras.layers.normalization import BatchNormalization @@ -886,7 +886,10 @@ class Class1NeuralNetwork(object): input_length=1, trainable=False)(allele_input) - allele_layer = Flatten(name="allele_flat")(allele_representation) + allele_layer = Reshape( + target_shape=allele_representations.shape[1:], + name="allele_reshaped")(allele_representation) + allele_layer = Flatten(name="allele_flat")(allele_layer) for (i, layer_size) in enumerate(allele_dense_layer_sizes): allele_layer = Dense(