From 31de3617805acb37812223bb11ebeea8da758207 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Tue, 19 Jun 2018 12:51:34 -0400 Subject: [PATCH] fix --- mhcflurry/class1_neural_network.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py index 7cbe72f0..bbf3e18e 100644 --- a/mhcflurry/class1_neural_network.py +++ b/mhcflurry/class1_neural_network.py @@ -820,7 +820,7 @@ class Class1NeuralNetwork(object): from keras.layers import Input import keras.layers - from keras.layers.core import Dense, Flatten, Dropout + from keras.layers.core import Dense, Flatten, Reshape, Dropout from keras.layers.embeddings import Embedding from keras.layers.normalization import BatchNormalization @@ -886,7 +886,10 @@ class Class1NeuralNetwork(object): input_length=1, trainable=False)(allele_input) - allele_layer = Flatten(name="allele_flat")(allele_representation) + allele_layer = Reshape( + target_shape=allele_representations.shape[1:], + name="allele_reshaped")(allele_representation) + allele_layer = Flatten(name="allele_flat")(allele_layer) for (i, layer_size) in enumerate(allele_dense_layer_sizes): allele_layer = Dense( -- GitLab