Skip to content
Snippets Groups Projects
Commit 75e6bbd1 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

fix

parent 4d08fb3c
No related branches found
No related tags found
No related merge requests found
......@@ -98,7 +98,7 @@ class Class1CleavageNeuralNetwork(object):
convolutional_filters=16,
convolutional_kernel_size=8,
convolutional_activation="relu",
convolutional_kernel_l1_l2=(0.001, 0.001),
convolutional_kernel_l1_l2=[0.001, 0.001],
dropout_rate=0.5,
post_convolutional_dense_layer_sizes=[],
)
......@@ -410,22 +410,22 @@ class Class1CleavageNeuralNetwork(object):
if flank == "n_flank":
peptide_input = "peptide_right_pad"
concat_order = [flank, peptide_input]
cleavage_position = n_flank_length
noncleaved_peptide_extractor = lambda x: x[
:, (n_flank_length + 1):]
flanking_extractor = lambda x: x[
:, : n_flank_length
]
cleavage_position_extractor = lambda x: x[:, n_flank_length]
else:
assert flank == "c_flank"
peptide_input = "peptide_left_pad"
concat_order = [peptide_input, flank]
cleavage_position = peptide_max_length - 1
noncleaved_peptide_extractor = lambda x: x[
:, 0 : peptide_max_length - 1]
flanking_extractor = lambda x: x[
:, peptide_max_length :
]
cleavage_position_extractor = lambda x: x[:, peptide_max_length - 1]
if include_flank:
current_layer = Concatenate(
......@@ -465,7 +465,7 @@ class Class1CleavageNeuralNetwork(object):
# Single output at cleavage position
single_output_at_cleavage_position = keras.layers.Lambda(
lambda x: x[:, cleavage_position])(single_output_result)
cleavage_position_extractor)(single_output_result)
outputs_for_final_dense.append(single_output_at_cleavage_position)
# Max of single-output at non-cleaved (peptide) positions.
......
......@@ -40,6 +40,13 @@ def test_basic():
})
df["score"] = predictor.predict(df.peptide, df.n_flank, df.c_flank)
# Test predictions are deterministic
df1b = predictor.predict_to_dataframe(
peptides=df.peptide.values,
n_flanks=df.n_flank.values,
c_flanks=df.c_flank.values)
assert_array_equal(df.score.values, df1b.score.values)
# Test saving and loading
models_dir = tempfile.mkdtemp("_models")
print(models_dir)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment