Skip to content
Snippets Groups Projects
Commit a03267d5 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

cleaner losses

parent b5f9a292
No related branches found
No related tags found
No related merge requests found
...@@ -17,7 +17,10 @@ from .random_negative_peptides import RandomNegativePeptides ...@@ -17,7 +17,10 @@ from .random_negative_peptides import RandomNegativePeptides
from .allele_encoding import MultipleAlleleEncoding, AlleleEncoding from .allele_encoding import MultipleAlleleEncoding, AlleleEncoding
from .auxiliary_input import AuxiliaryInputEncoder from .auxiliary_input import AuxiliaryInputEncoder
from .batch_generator import MultiallelicMassSpecBatchGenerator from .batch_generator import MultiallelicMassSpecBatchGenerator
from .custom_loss import MSEWithInequalities, MultiallelicMassSpecLoss from .custom_loss import (
MSEWithInequalities,
TransformPredictionsLossWrapper,
MultiallelicMassSpecLoss)
class Class1PresentationNeuralNetwork(object): class Class1PresentationNeuralNetwork(object):
...@@ -412,12 +415,13 @@ class Class1PresentationNeuralNetwork(object): ...@@ -412,12 +415,13 @@ class Class1PresentationNeuralNetwork(object):
y1, y1,
]) ])
def keras_max(matrix): def tensor_max(matrix):
import keras.backend as K import keras.backend as K
result = K.max(matrix, axis=1) return K.max(matrix, axis=1)
return result
affinities_loss = MSEWithInequalities(transform_function=keras_max) affinities_loss = TransformPredictionsLossWrapper(
loss=MSEWithInequalities(),
y_pred_transform=tensor_max)
encoded_y1 = affinities_loss.encode_y( encoded_y1 = affinities_loss.encode_y(
y1_with_random_negatives, y1_with_random_negatives,
inequalities=adjusted_inequalities_with_random_negative) inequalities=adjusted_inequalities_with_random_negative)
......
...@@ -82,6 +82,31 @@ class StandardKerasLoss(Loss): ...@@ -82,6 +82,31 @@ class StandardKerasLoss(Loss):
return y return y
class TransformPredictionsLossWrapper(Loss):
"""
Wrapper that applies an arbitrary transform to y_pred before calling an
underlying loss function.
The y_pred_transform function should be a tensor -> tensor function.
"""
def __init__(
self,
loss,
y_pred_transform=None):
self.wrapped_loss = loss
self.name = "transformed_%s" % loss.name
self.y_pred_transform = y_pred_transform
self.supports_inequalities = loss.supports_inequalities
self.supports_multiple_outputs = loss.supports_multiple_outputs
def encode_y(self, *args, **kwargs):
return self.wrapped_loss.encode_y(*args, **kwargs)
def loss(self, y_true, y_pred):
y_pred_transformed = self.y_pred_transform(y_pred)
return self.wrapped_loss.loss(y_true, y_pred_transformed)
class MSEWithInequalities(Loss): class MSEWithInequalities(Loss):
""" """
Supports training a regression model on data that includes inequalities Supports training a regression model on data that includes inequalities
...@@ -111,9 +136,6 @@ class MSEWithInequalities(Loss): ...@@ -111,9 +136,6 @@ class MSEWithInequalities(Loss):
supports_inequalities = True supports_inequalities = True
supports_multiple_outputs = False supports_multiple_outputs = False
def __init__(self, transform_function=None):
self.transform_function = transform_function
@staticmethod @staticmethod
def encode_y(y, inequalities=None): def encode_y(y, inequalities=None):
y = array(y, dtype="float32") y = array(y, dtype="float32")
...@@ -142,11 +164,6 @@ class MSEWithInequalities(Loss): ...@@ -142,11 +164,6 @@ class MSEWithInequalities(Loss):
# We always delay import of Keras so that mhcflurry can be imported # We always delay import of Keras so that mhcflurry can be imported
# initially without tensorflow debug output, etc. # initially without tensorflow debug output, etc.
from keras import backend as K from keras import backend as K
import tensorflow as tf
if self.transform_function:
y_pred = self.transform_function(y_pred)
y_true = K.squeeze(y_true, axis=-1) y_true = K.squeeze(y_true, axis=-1)
# Handle (=) inequalities # Handle (=) inequalities
...@@ -172,8 +189,6 @@ class MSEWithInequalities(Loss): ...@@ -172,8 +189,6 @@ class MSEWithInequalities(Loss):
return result return result
#return tf.where(tf.is_nan(result), tf.zeros_like(result), result)
class MSEWithInequalitiesAndMultipleOutputs(Loss): class MSEWithInequalitiesAndMultipleOutputs(Loss):
""" """
...@@ -200,9 +215,6 @@ class MSEWithInequalitiesAndMultipleOutputs(Loss): ...@@ -200,9 +215,6 @@ class MSEWithInequalitiesAndMultipleOutputs(Loss):
supports_inequalities = True supports_inequalities = True
supports_multiple_outputs = True supports_multiple_outputs = True
def __init__(self, transform_function=None):
self.transform_function = transform_function
@staticmethod @staticmethod
def encode_y(y, inequalities=None, output_indices=None): def encode_y(y, inequalities=None, output_indices=None):
y = array(y, dtype="float32") y = array(y, dtype="float32")
...@@ -228,10 +240,6 @@ class MSEWithInequalitiesAndMultipleOutputs(Loss): ...@@ -228,10 +240,6 @@ class MSEWithInequalitiesAndMultipleOutputs(Loss):
def loss(self, y_true, y_pred): def loss(self, y_true, y_pred):
from keras import backend as K from keras import backend as K
if self.transform_function:
y_pred = self.transform_function(y_pred)
y_true = K.flatten(y_true) y_true = K.flatten(y_true)
output_indices = y_true // 10 output_indices = y_true // 10
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment