diff --git a/mhcflurry/class1_presentation_neural_network.py b/mhcflurry/class1_presentation_neural_network.py
index 5c33f1f3aa38ea9e67679f12fc8a068ba9d44cec..503fa4bcb2bce288fb474d3cce181ffe8a8534aa 100644
--- a/mhcflurry/class1_presentation_neural_network.py
+++ b/mhcflurry/class1_presentation_neural_network.py
@@ -17,7 +17,10 @@ from .random_negative_peptides import RandomNegativePeptides
 from .allele_encoding import MultipleAlleleEncoding, AlleleEncoding
 from .auxiliary_input import AuxiliaryInputEncoder
 from .batch_generator import MultiallelicMassSpecBatchGenerator
-from .custom_loss import MSEWithInequalities, MultiallelicMassSpecLoss
+from .custom_loss import (
+    MSEWithInequalities,
+    TransformPredictionsLossWrapper,
+    MultiallelicMassSpecLoss)
 
 
 class Class1PresentationNeuralNetwork(object):
@@ -412,12 +415,13 @@ class Class1PresentationNeuralNetwork(object):
             y1,
         ])
 
-        def keras_max(matrix):
+        def tensor_max(matrix):
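+            # Delay the Keras import so importing mhcflurry stays lightweight.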
             import keras.backend as K
-            result = K.max(matrix, axis=1)
-            return result
+            return K.max(matrix, axis=1)
 
-        affinities_loss = MSEWithInequalities(transform_function=keras_max)
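+        # Reduce predictions with a max over axis 1 (tensor_max), then apply
+        # the MSE-with-inequalities loss to the reduced values.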
+        affinities_loss = TransformPredictionsLossWrapper(
+            loss=MSEWithInequalities(),
+            y_pred_transform=tensor_max)
         encoded_y1 = affinities_loss.encode_y(
             y1_with_random_negatives,
             inequalities=adjusted_inequalities_with_random_negative)
diff --git a/mhcflurry/custom_loss.py b/mhcflurry/custom_loss.py
index 6132ddf695b9cac0d29039a1a72c3864f4d2aab6..dcb8c0a8ee0d64b218842cf8fee7479f08eb7aa8 100644
--- a/mhcflurry/custom_loss.py
+++ b/mhcflurry/custom_loss.py
@@ -82,6 +82,31 @@ class StandardKerasLoss(Loss):
         return y
 
 
+class TransformPredictionsLossWrapper(Loss):
+    """
+    Wrapper that applies an arbitrary transform to y_pred before calling an
+    underlying loss function.
+
+    The y_pred_transform argument should map a tensor to a tensor.
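+
+    For example (a sketch; K is the Keras backend)::
+
+        wrapped = TransformPredictionsLossWrapper(
+            loss=MSEWithInequalities(),
+            y_pred_transform=lambda tensor: K.max(tensor, axis=1))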
+    """
+    def __init__(
+            self,
+            loss,
+            y_pred_transform):
+        self.wrapped_loss = loss
+        self.name = "transformed_%s" % loss.name
+        self.y_pred_transform = y_pred_transform
+        self.supports_inequalities = loss.supports_inequalities
+        self.supports_multiple_outputs = loss.supports_multiple_outputs
+
+    def encode_y(self, *args, **kwargs):
+        return self.wrapped_loss.encode_y(*args, **kwargs)
+
+    def loss(self, y_true, y_pred):
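+        # Apply the transform to predictions only; y_true passes through to
+        # the wrapped loss unchanged.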
+        y_pred_transformed = self.y_pred_transform(y_pred)
+        return self.wrapped_loss.loss(y_true, y_pred_transformed)
+
+
 class MSEWithInequalities(Loss):
     """
     Supports training a regression model on data that includes inequalities
@@ -111,9 +136,6 @@ class MSEWithInequalities(Loss):
     supports_inequalities = True
     supports_multiple_outputs = False
 
-    def __init__(self, transform_function=None):
-        self.transform_function = transform_function
-
     @staticmethod
     def encode_y(y, inequalities=None):
         y = array(y, dtype="float32")
@@ -142,11 +164,6 @@ class MSEWithInequalities(Loss):
         # We always delay import of Keras so that mhcflurry can be imported
         # initially without tensorflow debug output, etc.
         from keras import backend as K
-        import tensorflow as tf
-
-        if self.transform_function:
-            y_pred = self.transform_function(y_pred)
-
         y_true = K.squeeze(y_true, axis=-1)
 
         # Handle (=) inequalities
@@ -172,8 +189,6 @@ class MSEWithInequalities(Loss):
 
         return result
 
-        #return tf.where(tf.is_nan(result), tf.zeros_like(result), result)
-
 
 class MSEWithInequalitiesAndMultipleOutputs(Loss):
     """
@@ -200,9 +215,6 @@ class MSEWithInequalitiesAndMultipleOutputs(Loss):
     supports_inequalities = True
     supports_multiple_outputs = True
 
-    def __init__(self, transform_function=None):
-        self.transform_function = transform_function
-
     @staticmethod
     def encode_y(y, inequalities=None, output_indices=None):
         y = array(y, dtype="float32")
@@ -228,10 +240,6 @@ class MSEWithInequalitiesAndMultipleOutputs(Loss):
 
     def loss(self, y_true, y_pred):
         from keras import backend as K
-
-        if self.transform_function:
-            y_pred = self.transform_function(y_pred)
-
         y_true = K.flatten(y_true)
 
         output_indices = y_true // 10