diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 582dac78e335d7a70b8d61b2b8ec0eddcb813817..6a8bb96bb5212cf6b9368eda3fa18d42ac0806fe 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -884,7 +884,7 @@ class Class1NeuralNetwork(object):
                 input_dim=allele_representations.shape[0],
                 output_dim=allele_representations.shape[1],
                 input_length=1,
-                trainable=False)
+                trainable=False)(allele_input)
 
             allele_layer = Flatten(name="allele_flat")(allele_representation)