diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py index 0a0e7d9959f305290bcda3050f6208585ef83f90..7dd79edb788d8ef670c53e22366bb07170aafc72 100644 --- a/mhcflurry/class1_neural_network.py +++ b/mhcflurry/class1_neural_network.py @@ -882,7 +882,7 @@ class Class1NeuralNetwork(object): allele_representation = Embedding( name="allele_representation", input_dim=allele_representations.shape[0], - output_dim=allele_representations.shape[1], + output_dim=allele_representations.shape[1] * allele_representations.shape[2], input_length=1, trainable=False)(allele_input) @@ -953,7 +953,8 @@ class Class1NeuralNetwork(object): layer = self.network().get_layer("allele_representation") (existing,) = layer.get_weights() if existing.shape == allele_representations.shape: - layer.set_weights([allele_representations]) + layer.set_weights([ + allele_representations.reshape((allele_representations.shape[0], -1))]) else: raise NotImplementedError( "Network surgery required: %s != %s" % (