diff --git a/mhcflurry/class1_ligandome_predictor.py b/mhcflurry/class1_ligandome_predictor.py index 9a8038826cb679d1bb8216a9941072097405646d..498972992f0b8d2cb116e3dc45908f89f32edd37 100644 --- a/mhcflurry/class1_ligandome_predictor.py +++ b/mhcflurry/class1_ligandome_predictor.py @@ -127,6 +127,50 @@ class Class1LigandomePredictor(object): allele_peptide_merged = concatenate([peptides_repeated, allele_flat]) + layer_names = [ + layer.name for layer in merged_ensemble.layers + ] + + pan_allele_layer_initial_names = [ + 'allele', 'peptide', + 'allele_representation', 'flattened_0', 'allele_flat', + 'allele_peptide_merged', 'dense_0', 'dropout_0', + ] + + def startswith(lst, prefix): + return lst[:len(prefix)] == prefix + + assert startswith(layer_names, pan_allele_layer_initial_names), layer_names + + layers = merged_ensemble.layers[ + pan_allele_layer_initial_names.index( + "allele_peptide_merged") + 1: + ] + node = allele_peptide_merged + layer_name_to_new_node = { + "allele_peptide_merged": allele_peptide_merged, + } + for layer in layers: + assert layer.name not in layer_name_to_new_node + input_layer_names = [] + for inbound_node in layer._inbound_nodes: + for inbound_layer in inbound_node.inbound_layers: + input_layer_names.append(inbound_layer.name) + input_nodes = [ + layer_name_to_new_node[name] + for name in input_layer_names + ] + + if len(input_nodes) == 1: + lifted = TimeDistributed(layer) + result = lifted(input_nodes[0]) + else: + print(layer, layer.name, node, lifted) + result = layer(input_nodes) + + layer_name_to_new_node[layer.name] = result + + """ dense_0 = merged_ensemble.get_layer("dense_0") td_dense0 = TimeDistributed(dense_0, name="td_dense_0")(allele_peptide_merged) td_dense0 = Dropout(0.5)(td_dense0) @@ -137,10 +181,11 @@ class Class1LigandomePredictor(object): output = merged_ensemble.get_layer("output") td_output = TimeDistributed(output)(td_dense1) + """ network = Model( inputs=[input_peptides, input_alleles], - outputs=td_output, + outputs=node, name="ligandome", ) #print('trainable', network.get_layer("td_dense_0").trainable) @@ -358,7 +403,7 @@ class Class1LigandomePredictor(object): 'allele': allele_encoding_input, } predictions = self.network.predict(x_dict, batch_size=batch_size) - return numpy.squeeze(predictions, axis=-1) + return numpy.squeeze(predictions) #def predict(self):