diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 7dd79edb788d8ef670c53e22366bb07170aafc72..7cbe72f03dcb0785876e2edb37ed1a2944118b9c 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -950,12 +950,12 @@ class Class1NeuralNetwork(object):
         allele_representations
 
         """
+        reshaped = allele_representations.reshape((allele_representations.shape[0], -1))
         layer = self.network().get_layer("allele_representation")
         (existing,) = layer.get_weights()
-        if existing.shape == allele_representations.shape:
-            layer.set_weights([
-                allele_representations.reshape((allele_representations.shape[0], -1))])
+        if existing.shape == reshaped.shape:
+            layer.set_weights([reshaped])
         else:
             raise NotImplementedError(
                 "Network surgery required: %s != %s" % (
-                    str(existing.shape), str(allele_representations.shape)))
+                    str(existing.shape), str(reshaped.shape)))