From 0e5723ee1e63b4ffc0dbcd3fe05eafd1f99d9139 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Tue, 19 Jun 2018 12:45:39 -0400 Subject: [PATCH] fix --- mhcflurry/class1_neural_network.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py index 0a0e7d99..7dd79edb 100644 --- a/mhcflurry/class1_neural_network.py +++ b/mhcflurry/class1_neural_network.py @@ -882,7 +882,7 @@ class Class1NeuralNetwork(object): allele_representation = Embedding( name="allele_representation", input_dim=allele_representations.shape[0], - output_dim=allele_representations.shape[1], + output_dim=allele_representations.shape[1] * allele_representations.shape[2], input_length=1, trainable=False)(allele_input) @@ -953,7 +953,8 @@ class Class1NeuralNetwork(object): layer = self.network().get_layer("allele_representation") (existing,) = layer.get_weights() if existing.shape == allele_representations.shape: - layer.set_weights([allele_representations]) + layer.set_weights([ + allele_representations.reshape((allele_representations.shape[0], -1))]) else: raise NotImplementedError( "Network surgery required: %s != %s" % ( -- GitLab