Skip to content
Snippets Groups Projects
Commit c4593f16 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

fix

parent 88364fe8
No related branches found
No related tags found
No related merge requests found
......@@ -72,8 +72,6 @@ class MSEWithInequalities(object):
# without tensorflow debug output, etc.
from keras import backend as K
y_pred = K.flatten(y_pred)
# Handle (=) inequalities
diff1 = y_pred - y_true
diff1 *= K.cast(y_true >= 0.0, "float32")
......@@ -93,7 +91,7 @@ class MSEWithInequalities(object):
result = (
K.sum(K.square(diff1)) +
K.sum(K.square(diff2)) +
K.sum(K.square(diff3))) / K.cast(K.shape(y_true)[0], "float32")
K.sum(K.square(diff3))) / K.cast(K.shape(y_pred)[0], "float32")
return result
......
......@@ -88,7 +88,6 @@ def test_class1_neural_network_a0205_training_accuracy():
def test_inequalities():
# Memorize the dataset.
hyperparameters = dict(
loss="custom:mse_with_inequalities",
peptide_amino_acid_encoding="one-hot",
activation="tanh",
layer_sizes=[16],
......@@ -105,7 +104,8 @@ def test_inequalities():
}
],
dense_layer_l1_regularization=0.0,
dropout_probability=0.0)
dropout_probability=0.0,
loss="custom:mse_with_inequalities_and_multiple_outputs")
df = pandas.DataFrame()
df["peptide"] = random_peptides(1000, length=9)
......
......@@ -17,8 +17,10 @@ def evaluate_loss(loss, y_true, y_pred):
y_pred = numpy.array(y_pred)
if y_pred.ndim == 1:
y_pred = y_pred.reshape((len(y_pred), 1))
if y_true.ndim == 1:
y_true = y_true.reshape((len(y_true), 1))
assert y_true.ndim == 1
assert y_true.ndim == 2
assert y_pred.ndim == 2
if K.backend() == "tensorflow":
......
......@@ -10,12 +10,10 @@ import logging
logging.getLogger('tensorflow').disabled = True
from mhcflurry.class1_neural_network import Class1NeuralNetwork
from mhcflurry.downloads import get_path
from mhcflurry.common import random_peptides
def test_multi_output():
# Memorize the dataset.
hyperparameters = dict(
loss="custom:mse_with_inequalities_and_multiple_outputs",
activation="tanh",
......@@ -87,62 +85,3 @@ def test_multi_output():
assert sub_correlation.iloc[1, 1] > 0.99, correlation
assert sub_correlation.iloc[2, 2] > 0.99, correlation
import ipdb ; ipdb.set_trace()
# Prediction2 has a (<) inequality on binders and an (=) on non-binders
predictor = Class1NeuralNetwork(**hyperparameters)
predictor.fit(
df.peptide.values,
df.value.values,
inequalities=df.inequality2.values,
**fit_kwargs)
df["prediction2"] = predictor.predict(df.peptide.values)
# Prediction3 has a (=) inequality on binders and an (>) on non-binders
predictor = Class1NeuralNetwork(**hyperparameters)
predictor.fit(
df.peptide.values,
df.value.values,
inequalities=df.inequality3.values,
**fit_kwargs)
df["prediction3"] = predictor.predict(df.peptide.values)
df_binders = df.loc[df.binder]
df_nonbinders = df.loc[~df.binder]
print("***** Binders: *****")
print(df_binders.head(5))
print("***** Non-binders: *****")
print(df_nonbinders.head(5))
# Binders should always be given tighter predicted affinity than non-binders
assert_less(df_binders.prediction1.mean(), df_nonbinders.prediction1.mean())
assert_less(df_binders.prediction2.mean(), df_nonbinders.prediction2.mean())
assert_less(df_binders.prediction3.mean(), df_nonbinders.prediction3.mean())
# prediction2 binders should be tighter on average than prediction1
# binders, since prediction2 has a (<) inequality for binders.
# Non-binders should be about the same between prediction2 and prediction1
assert_less(df_binders.prediction2.mean(), df_binders.prediction1.mean())
assert_almost_equal(
df_nonbinders.prediction2.mean(),
df_nonbinders.prediction1.mean(),
delta=3000)
# prediction3 non-binders should be weaker on average than prediction2 (or 1)
# non-binders, since prediction3 has a (>) inequality for these peptides.
# Binders should be about the same.
assert_greater(
df_nonbinders.prediction3.mean(),
df_nonbinders.prediction2.mean())
assert_greater(
df_nonbinders.prediction3.mean(),
df_nonbinders.prediction1.mean())
assert_almost_equal(
df_binders.prediction3.mean(),
df_binders.prediction1.mean(),
delta=3000)
0% — Loading, or try again.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment