Skip to content
Snippets Groups Projects
Commit 864cabaa authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

better loss

parent 288e0db7
No related branches found
No related tags found
No related merge requests found
......@@ -140,88 +140,29 @@ class Class1LigandomePredictor(object):
name="ligandome",
)
#print('trainable', network.get_layer("td_dense_0").trainable)
network.get_layer("td_dense_0").trainable = False
#network.get_layer("td_dense_0").trainable = False
#print('trainable', network.get_layer("td_dense_0").trainable)
return network
@staticmethod
def loss(y_true, y_pred, lmbda=0.001):
import keras.backend as K
def loss(y_true, y_pred, delta=0.2):
"""
Loss function for ligandome prediction.
"""
import tensorflow as tf
y_pred = tf.squeeze(y_pred, axis=-1)
#y_pred = tf.Print(y_pred, [y_pred, tf.shape(y_pred)], "y_pred", summarize=20)
#y_true = tf.Print(y_true, [y_true, tf.shape(y_true)], "y_true", summarize=20)
y_true = tf.reshape(tf.cast(y_true, tf.bool), (-1,))
pos = tf.boolean_mask(y_pred, y_true)
pos_max = tf.reduce_max(pos, axis=1)
#pos_max = tf.reduce_logsumexp(tf.boolean_mask(y_pred, y_true), axis=1)
#pos_max = tf.reduce_max(pos, axis=1)
neg = tf.boolean_mask(y_pred, tf.logical_not(y_true))
result = tf.reduce_sum(
tf.maximum(0.0, tf.reshape(neg, (-1, 1)) - pos_max)**2)
term2 = tf.reduce_sum(
tf.minimum(0.0, tf.reshape(neg, (-1, 1)) - pos_max))
result = result + lmbda * term2
#differences = tf.reshape(neg, (-1, 1)) - pos
#result = tf.reduce_sum(tf.sign(differences) * differences**2)
#result = tf.Print(result, [result], "result", summarize=20)
#term2 = lmbda * tf.reduce_mean((1 - pos)**2)
#result = result + term2
tf.maximum(0.0, tf.reshape(neg, (-1, 1)) - pos_max + delta) ** 2)
return result
"""
pos = tf.boolean_mask(y_pred, y_true)
pos = y_pred[y_true.astype(bool)].max(1)
neg = y_pred[~y_true.astype(bool)]
expected2 = (numpy.maximum(0,
neg.flatten().reshape((-1, 1)) - pos) ** 2).sum()
"""
@staticmethod
def loss_old(y_true, y_pred):
"""Binary cross entropy after taking logsumexp over predictions"""
import keras.backend as K
import tensorflow as tf
#y_pred_aggregated = K.logsumexp(y_pred, axis=1, keepdims=True)
#y_pred_aggregated = K.sigmoid(y_pred_aggregated)
#y_pred = tf.Print(y_pred, [y_pred], "y_pred", summarize=20)
#y_true = tf.Print(y_true, [y_true], "y_true", summarize=20)
y_pred_aggregated = K.max(y_pred, axis=1, keepdims=False)
#y_pred_aggregated = tf.Print(y_pred_aggregated, [y_pred_aggregated], "y_pred_aggregated",
# summarize=20)
y_true = K.squeeze(K.cast(y_true, y_pred_aggregated.dtype), axis=-1)
#print("SHAPES", y_pred, K.int_shape(y_pred), y_pred_aggregated, K.int_shape(y_pred_aggregated), y_true, K.int_shape(y_true))
#K.print_tensor(y_pred_aggregated, "y_pred_aggregated")
#K.print_tensor(y_true, "y_true")
#y_pred_aggregated = K.print_tensor(y_pred_aggregated, "y_pred_aggregated")
#y_true = K.print_tensor(y_true, "y_true")
#return K.mean(
# K.binary_crossentropy(y_true, y_pred_aggregated),
# axis=-1)
return K.mean(
(y_true - y_pred_aggregated)**2,
axis=-1
)
def peptides_to_network_input(self, peptides):
"""
Encode peptides to the fixed-length encoding expected by the neural
......
......@@ -21,6 +21,7 @@ logging.getLogger('matplotlib').disabled = True
import pandas
import argparse
import sys
from functools import partial
from numpy.testing import assert_, assert_equal, assert_allclose
import numpy
......@@ -107,7 +108,9 @@ def evaluate_loss(loss, y_true, y_pred):
raise ValueError("Unsupported backend: %s" % K.backend())
def Xtest_loss():
def test_loss():
delta = 0.4
# Hit labels
y_true = [
1.0,
......@@ -134,7 +137,7 @@ def Xtest_loss():
if y_true[j] == 0.0:
tightest_i = max(y_pred[i])
contribution = sum(
max(0, y_pred[j, k] - tightest_i)**2
max(0, y_pred[j, k] - tightest_i + delta)**2
for k in range(y_pred.shape[1])
)
contributions.append(contribution)
......@@ -145,12 +148,12 @@ def Xtest_loss():
pos = y_pred[y_true.astype(bool)].max(1)
neg = y_pred[~y_true.astype(bool)]
expected2 = (
numpy.maximum(0, neg.reshape((-1, 1)) - pos)**2).sum()
numpy.maximum(0, neg.reshape((-1, 1)) - pos + delta)**2).sum()
numpy.testing.assert_almost_equal(expected1, expected2)
computed = evaluate_loss(
Class1LigandomePredictor.loss,
partial(Class1LigandomePredictor.loss, delta=delta),
y_true,
y_pred.reshape(y_pred.shape + (1,)))
numpy.testing.assert_almost_equal(computed, expected1)
......@@ -197,20 +200,16 @@ def make_motif(allele, peptides, frac=0.01):
peptides=peptides,
allele=allele,
)
random_predictions_df = pandas.DataFrame({"peptide": peptides.sequences})
random_predictions_df["prediction"] = predictions
random_predictions_df = random_predictions_df.sort_values(
"prediction", ascending=True)
#print("Random peptide predictions", allele)
#print(random_predictions_df)
top = random_predictions_df.iloc[:int(len(random_predictions_df) * frac)]
matrix = positional_frequency_matrix(top.peptide.values)
#print("Matrix")
return matrix
def test_synthetic_allele_refinement():
def test_synthetic_allele_refinement(max_epochs=10):
refine_allele = "HLA-C*01:02"
alleles = [
"HLA-A*02:01", "HLA-B*27:01", "HLA-C*07:01",
......@@ -266,8 +265,8 @@ def test_synthetic_allele_refinement():
predictor = Class1LigandomePredictor(
PAN_ALLELE_PREDICTOR_NO_MASS_SPEC,
max_ensemble_size=1,
max_epochs=10,
learning_rate=0.00001,
max_epochs=max_epochs,
learning_rate=0.0001,
patience=5,
min_delta=0.0)
......@@ -295,8 +294,6 @@ def test_synthetic_allele_refinement():
pre_auc = roc_auc_score(train_df.hit.values, train_df.pre_max_prediction.values)
print("PRE_AUC", pre_auc)
#import ipdb ; ipdb.set_trace()
assert_allclose(pre_predictions, expected_pre_predictions)
motifs_history = []
......@@ -396,13 +393,19 @@ parser.add_argument(
"--out-motifs-pickle",
default=None,
help="Metrics output")
parser.add_argument(
"--max-epochs",
default=100,
type=int,
help="Max epochs")
if __name__ == '__main__':
# If run directly from python, leave the user in a shell to explore results.
setup()
args = parser.parse_args(sys.argv[1:])
(predictor, predictions, metrics, motifs) = test_synthetic_allele_refinement()
(predictor, predictions, metrics, motifs) = (
test_synthetic_allele_refinement(max_epochs=args.max_epochs))
if args.out_metrics_csv:
metrics.to_csv(args.out_metrics_csv)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment