Skip to content
Snippets Groups Projects
Commit 32b6054d authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

Working on ligandome predictor

parent 211b84fd
No related merge requests found
......@@ -127,6 +127,50 @@ class Class1LigandomePredictor(object):
allele_peptide_merged = concatenate([peptides_repeated, allele_flat])
layer_names = [
layer.name for layer in merged_ensemble.layers
]
pan_allele_layer_initial_names = [
'allele', 'peptide',
'allele_representation', 'flattened_0', 'allele_flat',
'allele_peptide_merged', 'dense_0', 'dropout_0',
]
def startswith(lst, prefix):
return lst[:len(prefix)] == prefix
assert startswith(layer_names, pan_allele_layer_initial_names), layer_names
layers = merged_ensemble.layers[
pan_allele_layer_initial_names.index(
"allele_peptide_merged") + 1:
]
node = allele_peptide_merged
layer_name_to_new_node = {
"allele_peptide_merged": allele_peptide_merged,
}
for layer in layers:
assert layer.name not in layer_name_to_new_node
input_layer_names = []
for inbound_node in layer._inbound_nodes:
for inbound_layer in inbound_node.inbound_layers:
input_layer_names.append(inbound_layer.name)
input_nodes = [
layer_name_to_new_node[name]
for name in input_layer_names
]
if len(input_nodes) == 1:
lifted = TimeDistributed(layer)
result = lifted(input_nodes[0])
else:
print(layer, layer.name, node, lifted)
result = layer(input_nodes)
layer_name_to_new_node[layer.name] = result
"""
dense_0 = merged_ensemble.get_layer("dense_0")
td_dense0 = TimeDistributed(dense_0, name="td_dense_0")(allele_peptide_merged)
td_dense0 = Dropout(0.5)(td_dense0)
......@@ -137,10 +181,11 @@ class Class1LigandomePredictor(object):
output = merged_ensemble.get_layer("output")
td_output = TimeDistributed(output)(td_dense1)
"""
network = Model(
inputs=[input_peptides, input_alleles],
outputs=td_output,
outputs=node,
name="ligandome",
)
#print('trainable', network.get_layer("td_dense_0").trainable)
......@@ -358,7 +403,7 @@ class Class1LigandomePredictor(object):
'allele': allele_encoding_input,
}
predictions = self.network.predict(x_dict, batch_size=batch_size)
return numpy.squeeze(predictions, axis=-1)
return numpy.squeeze(predictions)
#def predict(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment