From 0e5723ee1e63b4ffc0dbcd3fe05eafd1f99d9139 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Tue, 19 Jun 2018 12:45:39 -0400
Subject: [PATCH] fix

---
 mhcflurry/class1_neural_network.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 0a0e7d99..7dd79edb 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -882,7 +882,7 @@ class Class1NeuralNetwork(object):
             allele_representation = Embedding(
                 name="allele_representation",
                 input_dim=allele_representations.shape[0],
-                output_dim=allele_representations.shape[1],
+                output_dim=allele_representations.shape[1] * allele_representations.shape[2],
                 input_length=1,
                 trainable=False)(allele_input)
 
@@ -953,7 +953,8 @@ class Class1NeuralNetwork(object):
         layer = self.network().get_layer("allele_representation")
         (existing,) = layer.get_weights()
         if existing.shape == allele_representations.shape:
-            layer.set_weights([allele_representations])
+            layer.set_weights([
+                allele_representations.reshape((allele_representations.shape[0], -1))])
         else:
             raise NotImplementedError(
                 "Network surgery required: %s != %s" % (
-- 
GitLab