From da0b00733c2dcac375a871781424fc334520ffb4 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Sat, 5 Oct 2019 13:06:42 -0400 Subject: [PATCH] Fix network merging for new architectures --- mhcflurry/class1_neural_network.py | 44 ++++++++++++++++++++++++++---- test/test_network_merging.py | 8 ++++-- 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py index 1299dd8d..447ab47d 100644 --- a/mhcflurry/class1_neural_network.py +++ b/mhcflurry/class1_neural_network.py @@ -1129,13 +1129,25 @@ class Class1NeuralNetwork(object): for network in networks ] - pan_allele_layer_names = [ + pan_allele_layer_initial_names = [ 'allele', 'peptide', 'allele_representation', 'flattened_0', 'allele_flat', 'allele_peptide_merged', 'dense_0', 'dropout_0', - 'dense_1', 'dropout_1', 'output', + #'dense_1', 'dropout_1', 'output', ] + pan_allele_layer_final_names = [ + 'output' + ] + + def startswith(lst, prefix): + return lst[:len(prefix)] == prefix + + def endswith(lst, suffix): + return lst[-len(suffix):] == suffix + + if all(startswith(names, pan_allele_layer_initial_names) and + endswith(names, pan_allele_layer_final_names) + for names in layer_names): - if all(names == pan_allele_layer_names for names in layer_names): # Merging an ensemble of pan-allele architectures network = networks[0] peptide_input = Input( @@ -1154,15 +1166,35 @@ class Class1NeuralNetwork(object): allele_peptide_merged = network.get_layer("allele_peptide_merged")( [peptide_flat, allele_flat]) + sub_networks = [] for (i, network) in enumerate(networks): layers = network.layers[ - pan_allele_layer_names.index("allele_peptide_merged") + 1: + pan_allele_layer_initial_names.index( + "allele_peptide_merged") + 1: ] - node = allele_peptide_merged for layer in layers: layer.name += "_%d" % i - node = layer(node) + + node = allele_peptide_merged + layer_name_to_new_node = { + "allele_peptide_merged": allele_peptide_merged, + } + for layer in layers: + assert layer.name not in layer_name_to_new_node + input_layer_names = [] + for inbound_node in layer._inbound_nodes: + for inbound_layer in inbound_node.inbound_layers: + input_layer_names.append(inbound_layer.name) + input_nodes = [ + layer_name_to_new_node[name] + for name in input_layer_names + ] + if len(input_nodes) == 1: + node = layer(input_nodes[0]) + else: + node = layer(input_nodes) + layer_name_to_new_node[layer.name] = node sub_networks.append(node) if merge_method == 'average': diff --git a/test/test_network_merging.py b/test/test_network_merging.py index 69eab37d..761e138c 100644 --- a/test/test_network_merging.py +++ b/test/test_network_merging.py @@ -1,4 +1,7 @@ import logging +logging.getLogger('tensorflow').disabled = True +logging.getLogger('matplotlib').disabled = True + import numpy import pandas from mhcflurry import Class1AffinityPredictor, Class1NeuralNetwork @@ -6,7 +9,7 @@ from mhcflurry.common import random_peptides from mhcflurry.downloads import get_path from mhcflurry.testing_utils import cleanup, startup -logging.getLogger('tensorflow').disabled = True + PAN_ALLELE_PREDICTOR = None @@ -16,7 +19,6 @@ def setup(): startup() PAN_ALLELE_PREDICTOR = Class1AffinityPredictor.load( get_path("models_class1_pan", "models.with_mass_spec"), - max_models=4, optimization_level=0,) @@ -27,7 +29,7 @@ def teardown(): def test_merge(): - assert len(PAN_ALLELE_PREDICTOR.class1_pan_allele_models) == 4 + assert len(PAN_ALLELE_PREDICTOR.class1_pan_allele_models) > 1 peptides = random_peptides(100, length=9) peptides.extend(random_peptides(100, length=10)) -- GitLab