From 31a33adb8adb7e883dcfb5c811e8b72b398d81d4 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Tue, 19 Jun 2018 12:46:52 -0400 Subject: [PATCH] fix --- mhcflurry/class1_neural_network.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py index 7dd79edb..7cbe72f0 100644 --- a/mhcflurry/class1_neural_network.py +++ b/mhcflurry/class1_neural_network.py @@ -950,12 +950,12 @@ class Class1NeuralNetwork(object): allele_representations """ + reshaped = allele_representations.reshape((allele_representations.shape[0], -1)) layer = self.network().get_layer("allele_representation") (existing,) = layer.get_weights() - if existing.shape == allele_representations.shape: - layer.set_weights([ - allele_representations.reshape((allele_representations.shape[0], -1))]) + if existing.shape == reshaped.shape: + layer.set_weights([reshaped]) else: raise NotImplementedError( "Network surgery required: %s != %s" % ( - str(existing.shape), str(allele_representations.shape))) + str(existing.shape), str(reshaped.shape))) -- GitLab