Skip to content
Snippets Groups Projects
Commit 32b6054d authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

Working on ligandome predictor

parent 211b84fd
No related branches found
No related tags found
No related merge requests found
...@@ -127,6 +127,50 @@ class Class1LigandomePredictor(object): ...@@ -127,6 +127,50 @@ class Class1LigandomePredictor(object):
allele_peptide_merged = concatenate([peptides_repeated, allele_flat]) allele_peptide_merged = concatenate([peptides_repeated, allele_flat])
layer_names = [
layer.name for layer in merged_ensemble.layers
]
pan_allele_layer_initial_names = [
'allele', 'peptide',
'allele_representation', 'flattened_0', 'allele_flat',
'allele_peptide_merged', 'dense_0', 'dropout_0',
]
def startswith(lst, prefix):
return lst[:len(prefix)] == prefix
assert startswith(layer_names, pan_allele_layer_initial_names), layer_names
layers = merged_ensemble.layers[
pan_allele_layer_initial_names.index(
"allele_peptide_merged") + 1:
]
node = allele_peptide_merged
layer_name_to_new_node = {
"allele_peptide_merged": allele_peptide_merged,
}
for layer in layers:
assert layer.name not in layer_name_to_new_node
input_layer_names = []
for inbound_node in layer._inbound_nodes:
for inbound_layer in inbound_node.inbound_layers:
input_layer_names.append(inbound_layer.name)
input_nodes = [
layer_name_to_new_node[name]
for name in input_layer_names
]
if len(input_nodes) == 1:
lifted = TimeDistributed(layer)
result = lifted(input_nodes[0])
else:
print(layer, layer.name, node, lifted)
result = layer(input_nodes)
layer_name_to_new_node[layer.name] = result
"""
dense_0 = merged_ensemble.get_layer("dense_0") dense_0 = merged_ensemble.get_layer("dense_0")
td_dense0 = TimeDistributed(dense_0, name="td_dense_0")(allele_peptide_merged) td_dense0 = TimeDistributed(dense_0, name="td_dense_0")(allele_peptide_merged)
td_dense0 = Dropout(0.5)(td_dense0) td_dense0 = Dropout(0.5)(td_dense0)
...@@ -137,10 +181,11 @@ class Class1LigandomePredictor(object): ...@@ -137,10 +181,11 @@ class Class1LigandomePredictor(object):
output = merged_ensemble.get_layer("output") output = merged_ensemble.get_layer("output")
td_output = TimeDistributed(output)(td_dense1) td_output = TimeDistributed(output)(td_dense1)
"""
network = Model( network = Model(
inputs=[input_peptides, input_alleles], inputs=[input_peptides, input_alleles],
outputs=td_output, outputs=node,
name="ligandome", name="ligandome",
) )
#print('trainable', network.get_layer("td_dense_0").trainable) #print('trainable', network.get_layer("td_dense_0").trainable)
...@@ -358,7 +403,7 @@ class Class1LigandomePredictor(object): ...@@ -358,7 +403,7 @@ class Class1LigandomePredictor(object):
'allele': allele_encoding_input, 'allele': allele_encoding_input,
} }
predictions = self.network.predict(x_dict, batch_size=batch_size) predictions = self.network.predict(x_dict, batch_size=batch_size)
return numpy.squeeze(predictions, axis=-1) return numpy.squeeze(predictions)
#def predict(self): #def predict(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment