Skip to content
Snippets Groups Projects
Commit 864cabaa authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

better loss

parent 288e0db7
No related branches found
No related tags found
No related merge requests found
...@@ -140,88 +140,29 @@ class Class1LigandomePredictor(object): ...@@ -140,88 +140,29 @@ class Class1LigandomePredictor(object):
name="ligandome", name="ligandome",
) )
#print('trainable', network.get_layer("td_dense_0").trainable) #print('trainable', network.get_layer("td_dense_0").trainable)
network.get_layer("td_dense_0").trainable = False #network.get_layer("td_dense_0").trainable = False
#print('trainable', network.get_layer("td_dense_0").trainable) #print('trainable', network.get_layer("td_dense_0").trainable)
return network return network
@staticmethod @staticmethod
def loss(y_true, y_pred, lmbda=0.001): def loss(y_true, y_pred, delta=0.2):
import keras.backend as K """
Loss function for ligandome prediction.
"""
import tensorflow as tf import tensorflow as tf
y_pred = tf.squeeze(y_pred, axis=-1) y_pred = tf.squeeze(y_pred, axis=-1)
#y_pred = tf.Print(y_pred, [y_pred, tf.shape(y_pred)], "y_pred", summarize=20)
#y_true = tf.Print(y_true, [y_true, tf.shape(y_true)], "y_true", summarize=20)
y_true = tf.reshape(tf.cast(y_true, tf.bool), (-1,)) y_true = tf.reshape(tf.cast(y_true, tf.bool), (-1,))
pos = tf.boolean_mask(y_pred, y_true) pos = tf.boolean_mask(y_pred, y_true)
pos_max = tf.reduce_max(pos, axis=1) #pos_max = tf.reduce_max(pos, axis=1)
#pos_max = tf.reduce_logsumexp(tf.boolean_mask(y_pred, y_true), axis=1)
neg = tf.boolean_mask(y_pred, tf.logical_not(y_true)) neg = tf.boolean_mask(y_pred, tf.logical_not(y_true))
result = tf.reduce_sum( result = tf.reduce_sum(
tf.maximum(0.0, tf.reshape(neg, (-1, 1)) - pos_max)**2) tf.maximum(0.0, tf.reshape(neg, (-1, 1)) - pos_max + delta) ** 2)
term2 = tf.reduce_sum(
tf.minimum(0.0, tf.reshape(neg, (-1, 1)) - pos_max))
result = result + lmbda * term2
#differences = tf.reshape(neg, (-1, 1)) - pos
#result = tf.reduce_sum(tf.sign(differences) * differences**2)
#result = tf.Print(result, [result], "result", summarize=20)
#term2 = lmbda * tf.reduce_mean((1 - pos)**2)
#result = result + term2
return result return result
"""
pos = tf.boolean_mask(y_pred, y_true)
pos = y_pred[y_true.astype(bool)].max(1)
neg = y_pred[~y_true.astype(bool)]
expected2 = (numpy.maximum(0,
neg.flatten().reshape((-1, 1)) - pos) ** 2).sum()
"""
@staticmethod
def loss_old(y_true, y_pred):
"""Binary cross entropy after taking logsumexp over predictions"""
import keras.backend as K
import tensorflow as tf
#y_pred_aggregated = K.logsumexp(y_pred, axis=1, keepdims=True)
#y_pred_aggregated = K.sigmoid(y_pred_aggregated)
#y_pred = tf.Print(y_pred, [y_pred], "y_pred", summarize=20)
#y_true = tf.Print(y_true, [y_true], "y_true", summarize=20)
y_pred_aggregated = K.max(y_pred, axis=1, keepdims=False)
#y_pred_aggregated = tf.Print(y_pred_aggregated, [y_pred_aggregated], "y_pred_aggregated",
# summarize=20)
y_true = K.squeeze(K.cast(y_true, y_pred_aggregated.dtype), axis=-1)
#print("SHAPES", y_pred, K.int_shape(y_pred), y_pred_aggregated, K.int_shape(y_pred_aggregated), y_true, K.int_shape(y_true))
#K.print_tensor(y_pred_aggregated, "y_pred_aggregated")
#K.print_tensor(y_true, "y_true")
#y_pred_aggregated = K.print_tensor(y_pred_aggregated, "y_pred_aggregated")
#y_true = K.print_tensor(y_true, "y_true")
#return K.mean(
# K.binary_crossentropy(y_true, y_pred_aggregated),
# axis=-1)
return K.mean(
(y_true - y_pred_aggregated)**2,
axis=-1
)
def peptides_to_network_input(self, peptides): def peptides_to_network_input(self, peptides):
""" """
Encode peptides to the fixed-length encoding expected by the neural Encode peptides to the fixed-length encoding expected by the neural
......
...@@ -21,6 +21,7 @@ logging.getLogger('matplotlib').disabled = True ...@@ -21,6 +21,7 @@ logging.getLogger('matplotlib').disabled = True
import pandas import pandas
import argparse import argparse
import sys import sys
from functools import partial
from numpy.testing import assert_, assert_equal, assert_allclose from numpy.testing import assert_, assert_equal, assert_allclose
import numpy import numpy
...@@ -107,7 +108,9 @@ def evaluate_loss(loss, y_true, y_pred): ...@@ -107,7 +108,9 @@ def evaluate_loss(loss, y_true, y_pred):
raise ValueError("Unsupported backend: %s" % K.backend()) raise ValueError("Unsupported backend: %s" % K.backend())
def Xtest_loss(): def test_loss():
delta = 0.4
# Hit labels # Hit labels
y_true = [ y_true = [
1.0, 1.0,
...@@ -134,7 +137,7 @@ def Xtest_loss(): ...@@ -134,7 +137,7 @@ def Xtest_loss():
if y_true[j] == 0.0: if y_true[j] == 0.0:
tightest_i = max(y_pred[i]) tightest_i = max(y_pred[i])
contribution = sum( contribution = sum(
max(0, y_pred[j, k] - tightest_i)**2 max(0, y_pred[j, k] - tightest_i + delta)**2
for k in range(y_pred.shape[1]) for k in range(y_pred.shape[1])
) )
contributions.append(contribution) contributions.append(contribution)
...@@ -145,12 +148,12 @@ def Xtest_loss(): ...@@ -145,12 +148,12 @@ def Xtest_loss():
pos = y_pred[y_true.astype(bool)].max(1) pos = y_pred[y_true.astype(bool)].max(1)
neg = y_pred[~y_true.astype(bool)] neg = y_pred[~y_true.astype(bool)]
expected2 = ( expected2 = (
numpy.maximum(0, neg.reshape((-1, 1)) - pos)**2).sum() numpy.maximum(0, neg.reshape((-1, 1)) - pos + delta)**2).sum()
numpy.testing.assert_almost_equal(expected1, expected2) numpy.testing.assert_almost_equal(expected1, expected2)
computed = evaluate_loss( computed = evaluate_loss(
Class1LigandomePredictor.loss, partial(Class1LigandomePredictor.loss, delta=delta),
y_true, y_true,
y_pred.reshape(y_pred.shape + (1,))) y_pred.reshape(y_pred.shape + (1,)))
numpy.testing.assert_almost_equal(computed, expected1) numpy.testing.assert_almost_equal(computed, expected1)
...@@ -197,20 +200,16 @@ def make_motif(allele, peptides, frac=0.01): ...@@ -197,20 +200,16 @@ def make_motif(allele, peptides, frac=0.01):
peptides=peptides, peptides=peptides,
allele=allele, allele=allele,
) )
random_predictions_df = pandas.DataFrame({"peptide": peptides.sequences}) random_predictions_df = pandas.DataFrame({"peptide": peptides.sequences})
random_predictions_df["prediction"] = predictions random_predictions_df["prediction"] = predictions
random_predictions_df = random_predictions_df.sort_values( random_predictions_df = random_predictions_df.sort_values(
"prediction", ascending=True) "prediction", ascending=True)
#print("Random peptide predictions", allele)
#print(random_predictions_df)
top = random_predictions_df.iloc[:int(len(random_predictions_df) * frac)] top = random_predictions_df.iloc[:int(len(random_predictions_df) * frac)]
matrix = positional_frequency_matrix(top.peptide.values) matrix = positional_frequency_matrix(top.peptide.values)
#print("Matrix")
return matrix return matrix
def test_synthetic_allele_refinement(): def test_synthetic_allele_refinement(max_epochs=10):
refine_allele = "HLA-C*01:02" refine_allele = "HLA-C*01:02"
alleles = [ alleles = [
"HLA-A*02:01", "HLA-B*27:01", "HLA-C*07:01", "HLA-A*02:01", "HLA-B*27:01", "HLA-C*07:01",
...@@ -266,8 +265,8 @@ def test_synthetic_allele_refinement(): ...@@ -266,8 +265,8 @@ def test_synthetic_allele_refinement():
predictor = Class1LigandomePredictor( predictor = Class1LigandomePredictor(
PAN_ALLELE_PREDICTOR_NO_MASS_SPEC, PAN_ALLELE_PREDICTOR_NO_MASS_SPEC,
max_ensemble_size=1, max_ensemble_size=1,
max_epochs=10, max_epochs=max_epochs,
learning_rate=0.00001, learning_rate=0.0001,
patience=5, patience=5,
min_delta=0.0) min_delta=0.0)
...@@ -295,8 +294,6 @@ def test_synthetic_allele_refinement(): ...@@ -295,8 +294,6 @@ def test_synthetic_allele_refinement():
pre_auc = roc_auc_score(train_df.hit.values, train_df.pre_max_prediction.values) pre_auc = roc_auc_score(train_df.hit.values, train_df.pre_max_prediction.values)
print("PRE_AUC", pre_auc) print("PRE_AUC", pre_auc)
#import ipdb ; ipdb.set_trace()
assert_allclose(pre_predictions, expected_pre_predictions) assert_allclose(pre_predictions, expected_pre_predictions)
motifs_history = [] motifs_history = []
...@@ -396,13 +393,19 @@ parser.add_argument( ...@@ -396,13 +393,19 @@ parser.add_argument(
"--out-motifs-pickle", "--out-motifs-pickle",
default=None, default=None,
help="Metrics output") help="Metrics output")
parser.add_argument(
"--max-epochs",
default=100,
type=int,
help="Max epochs")
if __name__ == '__main__': if __name__ == '__main__':
# If run directly from python, leave the user in a shell to explore results. # If run directly from python, leave the user in a shell to explore results.
setup() setup()
args = parser.parse_args(sys.argv[1:]) args = parser.parse_args(sys.argv[1:])
(predictor, predictions, metrics, motifs) = test_synthetic_allele_refinement() (predictor, predictions, metrics, motifs) = (
test_synthetic_allele_refinement(max_epochs=args.max_epochs))
if args.out_metrics_csv: if args.out_metrics_csv:
metrics.to_csv(args.out_metrics_csv) metrics.to_csv(args.out_metrics_csv)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment