diff --git a/mhcflurry/class1_ligandome_predictor.py b/mhcflurry/class1_ligandome_predictor.py
index 2dc4350932844c8f3e7c395397fd21c0f438e7ec..685f2083ade3af8734e2bdd71a94515724943b1f 100644
--- a/mhcflurry/class1_ligandome_predictor.py
+++ b/mhcflurry/class1_ligandome_predictor.py
@@ -1,5 +1,8 @@
+from __future__ import print_function
+
 import time
 import collections
+from functools import partial
 
 import numpy
 
@@ -43,7 +46,8 @@ class Class1LigandomePredictor(object):
     """
 
     compile_hyperparameter_defaults = HyperparameterDefaults(
-        loss="custom:mse_with_inequalities",
+        loss_delta=0.2,
+        loss_alpha=None,
         optimizer="rmsprop",
         learning_rate=None,
     )
@@ -146,18 +150,26 @@ class Class1LigandomePredictor(object):
         return network
 
     @staticmethod
-    def loss(y_true, y_pred, delta=0.2):
+    def loss(y_true, y_pred, delta=0.2, alpha=None):
         """
         Loss function for ligandome prediction.
         """
         import tensorflow as tf
+        import keras.backend as K
 
         y_pred = tf.squeeze(y_pred, axis=-1)
         y_true = tf.reshape(tf.cast(y_true, tf.bool), (-1,))
 
         pos = tf.boolean_mask(y_pred, y_true)
-        pos_max = tf.reduce_max(pos, axis=1)
 
+        if alpha is None:
+            pos_max = tf.reduce_max(pos, axis=1)
+        else:
+            # Smooth maximum
+            exp_alpha_x = tf.exp(alpha * pos)
+            numerator = tf.reduce_sum(tf.multiply(pos, exp_alpha_x), axis=1)
+            denominator = tf.reduce_sum(exp_alpha_x, axis=1)
+            pos_max = numerator / denominator
         neg = tf.boolean_mask(y_pred, tf.logical_not(y_true))
         result = tf.reduce_sum(
             tf.maximum(0.0, tf.reshape(neg, (-1, 1)) - pos_max + delta) ** 2)
@@ -246,7 +258,10 @@ class Class1LigandomePredictor(object):
 
         self.set_allele_representations(allele_representations)
         self.network.compile(
-            loss=self.loss,
+            loss=partial(
+                self.loss,
+                delta=self.hyperparameters['loss_delta'],
+                alpha=self.hyperparameters['loss_alpha']),
             optimizer=self.hyperparameters['optimizer'])
         if self.hyperparameters['learning_rate'] is not None:
             K.set_value(
diff --git a/test/test_class1_ligandome_predictor.py b/test/test_class1_ligandome_predictor.py
index 7fb5cacf44a36c97e9ab62b1fc4973d6d4761a30..03f31081ebcf93c34edbbc071b39556d53562027 100644
--- a/test/test_class1_ligandome_predictor.py
+++ b/test/test_class1_ligandome_predictor.py
@@ -109,54 +109,74 @@ def evaluate_loss(loss, y_true, y_pred):
 
 
 def test_loss():
-    delta = 0.4
-
-    # Hit labels
-    y_true = [
-        1.0,
-        0.0,
-        1.0,
-        1.0,
-        0.0
-    ]
-    y_true = numpy.array(y_true)
-    y_pred = [
-        [0.3, 0.7, 0.5],
-        [0.2, 0.4, 0.6],
-        [0.1, 0.5, 0.3],
-        [0.1, 0.7, 0.1],
-        [0.8, 0.2, 0.4],
-    ]
-    y_pred = numpy.array(y_pred)
-
-    # reference implementation 1
-    contributions = []
-    for i in range(len(y_true)):
-        if y_true[i] == 1.0:
-            for j in range(len(y_true)):
-                if y_true[j] == 0.0:
-                    tightest_i = max(y_pred[i])
-                    contribution = sum(
-                        max(0, y_pred[j, k] - tightest_i + delta)**2
-                        for k in range(y_pred.shape[1])
-                    )
-                    contributions.append(contribution)
-    contributions = numpy.array(contributions)
-    expected1 = contributions.sum()
-
-    # reference implementation 2: numpy
-    pos = y_pred[y_true.astype(bool)].max(1)
-    neg = y_pred[~y_true.astype(bool)]
-    expected2 = (
-            numpy.maximum(0, neg.reshape((-1, 1)) - pos + delta)**2).sum()
-
-    numpy.testing.assert_almost_equal(expected1, expected2)
-
-    computed = evaluate_loss(
-        partial(Class1LigandomePredictor.loss, delta=delta),
-        y_true,
-        y_pred.reshape(y_pred.shape + (1,)))
-    numpy.testing.assert_almost_equal(computed, expected1)
+    for delta in [0.0, 0.3]:
+        for alpha in [None, 1.0, 20.0]:
+            print("delta", delta)
+            print("alpha", alpha)
+            # Hit labels
+            y_true = [
+                1.0,
+                0.0,
+                1.0,
+                1.0,
+                0.0
+            ]
+            y_true = numpy.array(y_true)
+            y_pred = [
+                [0.3, 0.7, 0.5],
+                [0.2, 0.4, 0.6],
+                [0.1, 0.5, 0.3],
+                [0.1, 0.7, 0.1],
+                [0.8, 0.2, 0.4],
+            ]
+            y_pred = numpy.array(y_pred)
+
+            # reference implementation 1
+
+            def smooth_max(x, alpha):
+                x = numpy.array(x)
+                alpha = numpy.array([alpha])
+                return (x * numpy.exp(x * alpha)).sum() / (
+                    numpy.exp(x * alpha)).sum()
+
+            if alpha is None:
+                max_func = max
+            else:
+                max_func = partial(smooth_max, alpha=alpha)
+
+            contributions = []
+            for i in range(len(y_true)):
+                if y_true[i] == 1.0:
+                    for j in range(len(y_true)):
+                        if y_true[j] == 0.0:
+                            tightest_i = max_func(y_pred[i])
+                            contribution = sum(
+                                max(0, y_pred[j, k] - tightest_i + delta)**2
+                                for k in range(y_pred.shape[1])
+                            )
+                            contributions.append(contribution)
+            contributions = numpy.array(contributions)
+            expected1 = contributions.sum()
+
+            # reference implementation 2: numpy
+            pos = numpy.array([
+                max_func(y_pred[i])
+                for i in range(len(y_pred))
+                if y_true[i] == 1.0
+            ])
+
+            neg = y_pred[~y_true.astype(bool)]
+            expected2 = (
+                    numpy.maximum(0, neg.reshape((-1, 1)) - pos + delta)**2).sum()
+
+            yield numpy.testing.assert_almost_equal, expected1, expected2, 4
+
+            computed = evaluate_loss(
+                partial(Class1LigandomePredictor.loss, delta=delta, alpha=alpha),
+                y_true,
+                y_pred.reshape(y_pred.shape + (1,)))
+
+            yield numpy.testing.assert_almost_equal, computed, expected1, 4
 
 
 AA_DIST = pandas.Series(