From 31de3617805acb37812223bb11ebeea8da758207 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Tue, 19 Jun 2018 12:51:34 -0400
Subject: [PATCH] fix

---
 mhcflurry/class1_neural_network.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 7cbe72f0..bbf3e18e 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -820,7 +820,7 @@ class Class1NeuralNetwork(object):
 
         from keras.layers import Input
         import keras.layers
-        from keras.layers.core import Dense, Flatten, Dropout
+        from keras.layers.core import Dense, Flatten, Reshape, Dropout
         from keras.layers.embeddings import Embedding
         from keras.layers.normalization import BatchNormalization
 
@@ -886,7 +886,10 @@ class Class1NeuralNetwork(object):
                 input_length=1,
                 trainable=False)(allele_input)
 
-            allele_layer = Flatten(name="allele_flat")(allele_representation)
+            allele_layer = Reshape(
+                target_shape=allele_representations.shape[1:],
+                name="allele_reshaped")(allele_representation)
+            allele_layer = Flatten(name="allele_flat")(allele_layer)
 
             for (i, layer_size) in enumerate(allele_dense_layer_sizes):
                 allele_layer = Dense(
-- 
GitLab