From da0b00733c2dcac375a871781424fc334520ffb4 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Sat, 5 Oct 2019 13:06:42 -0400
Subject: [PATCH] Fix network merging for new architectures

---
 mhcflurry/class1_neural_network.py | 44 ++++++++++++++++++++++++++----
 test/test_network_merging.py       |  8 ++++--
 2 files changed, 43 insertions(+), 9 deletions(-)

diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 1299dd8d..447ab47d 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -1129,13 +1129,25 @@ class Class1NeuralNetwork(object):
             for network in networks
         ]
 
-        pan_allele_layer_names = [
+        pan_allele_layer_initial_names = [
             'allele', 'peptide', 'allele_representation', 'flattened_0',
             'allele_flat', 'allele_peptide_merged', 'dense_0', 'dropout_0',
-            'dense_1', 'dropout_1', 'output',
+            #'dense_1', 'dropout_1', 'output',
         ]
+        pan_allele_layer_final_names = [
+            'output'
+        ]
+
+        def startswith(lst, prefix):
+            return lst[:len(prefix)] == prefix
+
+        def endswith(lst, suffix):
+            return lst[-len(suffix):] == suffix
+
+        if all(startswith(names, pan_allele_layer_initial_names) and
+                endswith(names, pan_allele_layer_final_names)
+                for names in layer_names):
 
-        if all(names == pan_allele_layer_names for names in layer_names):
             # Merging an ensemble of pan-allele architectures
             network = networks[0]
             peptide_input = Input(
@@ -1154,15 +1166,35 @@ class Class1NeuralNetwork(object):
             allele_peptide_merged = network.get_layer("allele_peptide_merged")(
                 [peptide_flat, allele_flat])
 
+
             sub_networks = []
             for (i, network) in enumerate(networks):
                 layers = network.layers[
-                    pan_allele_layer_names.index("allele_peptide_merged") + 1:
+                    pan_allele_layer_initial_names.index(
+                        "allele_peptide_merged") + 1:
                 ]
-                node = allele_peptide_merged
                 for layer in layers:
                     layer.name += "_%d" % i
-                    node = layer(node)
+
+                node = allele_peptide_merged
+                layer_name_to_new_node = {
+                    "allele_peptide_merged": allele_peptide_merged,
+                }
+                for layer in layers:
+                    assert layer.name not in layer_name_to_new_node
+                    input_layer_names = []
+                    for inbound_node in layer._inbound_nodes:
+                        for inbound_layer in inbound_node.inbound_layers:
+                            input_layer_names.append(inbound_layer.name)
+                    input_nodes = [
+                        layer_name_to_new_node[name]
+                        for name in input_layer_names
+                    ]
+                    if len(input_nodes) == 1:
+                        node = layer(input_nodes[0])
+                    else:
+                        node = layer(input_nodes)
+                    layer_name_to_new_node[layer.name] = node
                 sub_networks.append(node)
 
             if merge_method == 'average':
diff --git a/test/test_network_merging.py b/test/test_network_merging.py
index 69eab37d..761e138c 100644
--- a/test/test_network_merging.py
+++ b/test/test_network_merging.py
@@ -1,4 +1,7 @@
 import logging
+logging.getLogger('tensorflow').disabled = True
+logging.getLogger('matplotlib').disabled = True
+
 import numpy
 import pandas
 from mhcflurry import Class1AffinityPredictor, Class1NeuralNetwork
@@ -6,7 +9,7 @@ from mhcflurry.common import random_peptides
 from mhcflurry.downloads import get_path
 
 from mhcflurry.testing_utils import cleanup, startup
-logging.getLogger('tensorflow').disabled = True
+
 
 PAN_ALLELE_PREDICTOR = None
 
@@ -16,7 +19,6 @@ def setup():
     startup()
     PAN_ALLELE_PREDICTOR = Class1AffinityPredictor.load(
         get_path("models_class1_pan", "models.with_mass_spec"),
-        max_models=4,
         optimization_level=0,)
 
 
@@ -27,7 +29,7 @@ def teardown():
 
 
 def test_merge():
-    assert len(PAN_ALLELE_PREDICTOR.class1_pan_allele_models) == 4
+    assert len(PAN_ALLELE_PREDICTOR.class1_pan_allele_models) > 1
 
     peptides = random_peptides(100, length=9)
     peptides.extend(random_peptides(100, length=10))
-- 
GitLab