diff --git a/mhcflurry/class1_presentation_neural_network.py b/mhcflurry/class1_presentation_neural_network.py index 5c33f1f3aa38ea9e67679f12fc8a068ba9d44cec..503fa4bcb2bce288fb474d3cce181ffe8a8534aa 100644 --- a/mhcflurry/class1_presentation_neural_network.py +++ b/mhcflurry/class1_presentation_neural_network.py @@ -17,7 +17,10 @@ from .random_negative_peptides import RandomNegativePeptides from .allele_encoding import MultipleAlleleEncoding, AlleleEncoding from .auxiliary_input import AuxiliaryInputEncoder from .batch_generator import MultiallelicMassSpecBatchGenerator -from .custom_loss import MSEWithInequalities, MultiallelicMassSpecLoss +from .custom_loss import ( + MSEWithInequalities, + TransformPredictionsLossWrapper, + MultiallelicMassSpecLoss) class Class1PresentationNeuralNetwork(object): @@ -412,12 +415,13 @@ class Class1PresentationNeuralNetwork(object): y1, ]) - def keras_max(matrix): + def tensor_max(matrix): import keras.backend as K - result = K.max(matrix, axis=1) - return result + return K.max(matrix, axis=1) - affinities_loss = MSEWithInequalities(transform_function=keras_max) + affinities_loss = TransformPredictionsLossWrapper( + loss=MSEWithInequalities(), + y_pred_transform=tensor_max) encoded_y1 = affinities_loss.encode_y( y1_with_random_negatives, inequalities=adjusted_inequalities_with_random_negative) diff --git a/mhcflurry/custom_loss.py b/mhcflurry/custom_loss.py index 6132ddf695b9cac0d29039a1a72c3864f4d2aab6..dcb8c0a8ee0d64b218842cf8fee7479f08eb7aa8 100644 --- a/mhcflurry/custom_loss.py +++ b/mhcflurry/custom_loss.py @@ -82,6 +82,31 @@ class StandardKerasLoss(Loss): return y +class TransformPredictionsLossWrapper(Loss): + """ + Wrapper that applies an arbitrary transform to y_pred before calling an + underlying loss function. + + The y_pred_transform function should be a tensor -> tensor function. + """ + def __init__( + self, + loss, + y_pred_transform=None): + self.wrapped_loss = loss + self.name = "transformed_%s" % loss.name + self.y_pred_transform = y_pred_transform + self.supports_inequalities = loss.supports_inequalities + self.supports_multiple_outputs = loss.supports_multiple_outputs + + def encode_y(self, *args, **kwargs): + return self.wrapped_loss.encode_y(*args, **kwargs) + + def loss(self, y_true, y_pred): + y_pred_transformed = self.y_pred_transform(y_pred) + return self.wrapped_loss.loss(y_true, y_pred_transformed) + + class MSEWithInequalities(Loss): """ Supports training a regression model on data that includes inequalities @@ -111,9 +136,6 @@ class MSEWithInequalities(Loss): supports_inequalities = True supports_multiple_outputs = False - def __init__(self, transform_function=None): - self.transform_function = transform_function - @staticmethod def encode_y(y, inequalities=None): y = array(y, dtype="float32") @@ -142,11 +164,6 @@ class MSEWithInequalities(Loss): # We always delay import of Keras so that mhcflurry can be imported # initially without tensorflow debug output, etc. from keras import backend as K - import tensorflow as tf - - if self.transform_function: - y_pred = self.transform_function(y_pred) - y_true = K.squeeze(y_true, axis=-1) # Handle (=) inequalities @@ -172,8 +189,6 @@ class MSEWithInequalities(Loss): return result - #return tf.where(tf.is_nan(result), tf.zeros_like(result), result) - class MSEWithInequalitiesAndMultipleOutputs(Loss): """ @@ -200,9 +215,6 @@ class MSEWithInequalitiesAndMultipleOutputs(Loss): supports_inequalities = True supports_multiple_outputs = True - def __init__(self, transform_function=None): - self.transform_function = transform_function - @staticmethod def encode_y(y, inequalities=None, output_indices=None): y = array(y, dtype="float32") @@ -228,10 +240,6 @@ class MSEWithInequalitiesAndMultipleOutputs(Loss): def loss(self, y_true, y_pred): from keras import backend as K - - if self.transform_function: - y_pred = self.transform_function(y_pred) - y_true = K.flatten(y_true) output_indices = y_true // 10