Commit a03267d5 authored by Tim O'Donnell

cleaner losses

parent b5f9a292
@@ -17,7 +17,10 @@ from .random_negative_peptides import RandomNegativePeptides
 from .allele_encoding import MultipleAlleleEncoding, AlleleEncoding
 from .auxiliary_input import AuxiliaryInputEncoder
 from .batch_generator import MultiallelicMassSpecBatchGenerator
-from .custom_loss import MSEWithInequalities, MultiallelicMassSpecLoss
+from .custom_loss import (
+    MSEWithInequalities,
+    TransformPredictionsLossWrapper,
+    MultiallelicMassSpecLoss)
 
 
 class Class1PresentationNeuralNetwork(object):
@@ -412,12 +415,13 @@ class Class1PresentationNeuralNetwork(object):
             y1,
         ])
 
-        def keras_max(matrix):
+        def tensor_max(matrix):
             import keras.backend as K
-            result = K.max(matrix, axis=1)
-            return result
+            return K.max(matrix, axis=1)
 
-        affinities_loss = MSEWithInequalities(transform_function=keras_max)
+        affinities_loss = TransformPredictionsLossWrapper(
+            loss=MSEWithInequalities(),
+            y_pred_transform=tensor_max)
         encoded_y1 = affinities_loss.encode_y(
             y1_with_random_negatives,
             inequalities=adjusted_inequalities_with_random_negative)
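
A quick sketch of what tensor_max does (the values and shapes below are invented for illustration; only the K.max over axis 1 comes from this change): it collapses the per-allele axis so the wrapped affinity loss sees a single prediction per example.

    # Hypothetical standalone example, assuming the standalone Keras backend.
    import numpy as np
    import keras.backend as K

    def tensor_max(matrix):
        # tensor -> tensor transform: max over axis 1 (assumed per-allele axis)
        return K.max(matrix, axis=1)

    preds = K.constant(np.array([[0.1, 0.7, 0.3],
                                 [0.9, 0.2, 0.4]], dtype="float32"))
    print(K.eval(tensor_max(preds)))  # -> [0.7 0.9]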
@@ -82,6 +82,31 @@ class StandardKerasLoss(Loss):
         return y
 
 
+class TransformPredictionsLossWrapper(Loss):
+    """
+    Wrapper that applies an arbitrary transform to y_pred before calling an
+    underlying loss function.
+
+    The y_pred_transform function should be a tensor -> tensor function.
+    """
+    def __init__(
+            self,
+            loss,
+            y_pred_transform=None):
+        self.wrapped_loss = loss
+        self.name = "transformed_%s" % loss.name
+        self.y_pred_transform = y_pred_transform
+        self.supports_inequalities = loss.supports_inequalities
+        self.supports_multiple_outputs = loss.supports_multiple_outputs
+
+    def encode_y(self, *args, **kwargs):
+        return self.wrapped_loss.encode_y(*args, **kwargs)
+
+    def loss(self, y_true, y_pred):
+        y_pred_transformed = self.y_pred_transform(y_pred)
+        return self.wrapped_loss.loss(y_true, y_pred_transformed)
+
+
 class MSEWithInequalities(Loss):
     """
     Supports training a regression model on data that includes inequalities
@@ -111,9 +136,6 @@ class MSEWithInequalities(Loss):
     supports_inequalities = True
     supports_multiple_outputs = False
 
-    def __init__(self, transform_function=None):
-        self.transform_function = transform_function
-
     @staticmethod
     def encode_y(y, inequalities=None):
         y = array(y, dtype="float32")
@@ -142,11 +164,6 @@ class MSEWithInequalities(Loss):
         # We always delay import of Keras so that mhcflurry can be imported
         # initially without tensorflow debug output, etc.
         from keras import backend as K
-        import tensorflow as tf
-
-        if self.transform_function:
-            y_pred = self.transform_function(y_pred)
-
         y_true = K.squeeze(y_true, axis=-1)
 
         # Handle (=) inequalities
@@ -172,8 +189,6 @@ class MSEWithInequalities(Loss):
         return result
-        #return tf.where(tf.is_nan(result), tf.zeros_like(result), result)
-
 
 
 class MSEWithInequalitiesAndMultipleOutputs(Loss):
     """
@@ -200,9 +215,6 @@ class MSEWithInequalitiesAndMultipleOutputs(Loss):
     supports_inequalities = True
     supports_multiple_outputs = True
 
-    def __init__(self, transform_function=None):
-        self.transform_function = transform_function
-
     @staticmethod
     def encode_y(y, inequalities=None, output_indices=None):
         y = array(y, dtype="float32")
@@ -228,10 +240,6 @@ class MSEWithInequalitiesAndMultipleOutputs(Loss):
     def loss(self, y_true, y_pred):
         from keras import backend as K
 
-        if self.transform_function:
-            y_pred = self.transform_function(y_pred)
-
         y_true = K.flatten(y_true)
         output_indices = y_true // 10
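
A minimal usage sketch of the new TransformPredictionsLossWrapper (the data and shapes are invented; the class, constructor arguments, and method names come from this commit): the wrapper delegates encode_y to the wrapped loss and applies y_pred_transform before computing the wrapped loss.

    # Hypothetical usage sketch; not part of the commit.
    import numpy as np
    import keras.backend as K
    from mhcflurry.custom_loss import (
        MSEWithInequalities, TransformPredictionsLossWrapper)

    def tensor_max(matrix):
        return K.max(matrix, axis=1)

    wrapped = TransformPredictionsLossWrapper(
        loss=MSEWithInequalities(),
        y_pred_transform=tensor_max)

    # encode_y is delegated unchanged to MSEWithInequalities.encode_y.
    y_true = wrapped.encode_y([0.2, 0.8], inequalities=["=", ">"])

    # One column per allele (assumed); tensor_max reduces y_pred to a
    # per-example max before MSEWithInequalities computes its loss.
    y_pred = K.constant(np.array([[0.1, 0.3],
                                  [0.6, 0.9]], dtype="float32"))
    loss_tensor = wrapped.loss(K.constant(y_true.reshape((-1, 1))), y_pred)
    print(K.eval(loss_tensor))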