diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py index 7dd79edb788d8ef670c53e22366bb07170aafc72..7cbe72f03dcb0785876e2edb37ed1a2944118b9c 100644 --- a/mhcflurry/class1_neural_network.py +++ b/mhcflurry/class1_neural_network.py @@ -950,12 +950,12 @@ class Class1NeuralNetwork(object): allele_representations """ + reshaped = allele_representations.reshape((allele_representations.shape[0], -1)) layer = self.network().get_layer("allele_representation") (existing,) = layer.get_weights() - if existing.shape == allele_representations.shape: - layer.set_weights([ - allele_representations.reshape((allele_representations.shape[0], -1))]) + if existing.shape == reshaped.shape: + layer.set_weights([reshaped]) else: raise NotImplementedError( "Network surgery required: %s != %s" % ( - str(existing.shape), str(allele_representations.shape))) + str(existing.shape), str(reshaped.shape)))