diff --git a/mhcflurry/class1_ligandome_predictor.py b/mhcflurry/class1_ligandome_predictor.py
index 5be86637fa093e5325db2bb82cbcf00e2955ccde..74a9e1db4715657448e3cfe4d1fc6f8eb5c139f2 100644
--- a/mhcflurry/class1_ligandome_predictor.py
+++ b/mhcflurry/class1_ligandome_predictor.py
@@ -140,88 +140,29 @@ class Class1LigandomePredictor(object):
             name="ligandome",
         )
         #print('trainable', network.get_layer("td_dense_0").trainable)
-        network.get_layer("td_dense_0").trainable = False
+        #network.get_layer("td_dense_0").trainable = False
         #print('trainable', network.get_layer("td_dense_0").trainable)
 
         return network
 
     @staticmethod
-    def loss(y_true, y_pred, lmbda=0.001):
-        import keras.backend as K
+    def loss(y_true, y_pred, delta=0.2):
+        """
+        Loss function for ligandome prediction.
+        """
         import tensorflow as tf
 
         y_pred = tf.squeeze(y_pred, axis=-1)
-
-        #y_pred = tf.Print(y_pred, [y_pred, tf.shape(y_pred)], "y_pred", summarize=20)
-        #y_true = tf.Print(y_true, [y_true, tf.shape(y_true)], "y_true", summarize=20)
-
         y_true = tf.reshape(tf.cast(y_true, tf.bool), (-1,))
 
         pos = tf.boolean_mask(y_pred, y_true)
-        pos_max = tf.reduce_max(pos, axis=1)
-        #pos_max = tf.reduce_logsumexp(tf.boolean_mask(y_pred, y_true), axis=1)
+        pos_max = tf.reduce_max(pos, axis=1)
+
         neg = tf.boolean_mask(y_pred, tf.logical_not(y_true))
-
         result = tf.reduce_sum(
-            tf.maximum(0.0, tf.reshape(neg, (-1, 1)) - pos_max)**2)
-
-        term2 = tf.reduce_sum(
-            tf.minimum(0.0, tf.reshape(neg, (-1, 1)) - pos_max))
-        result = result + lmbda * term2
-
-        #differences = tf.reshape(neg, (-1, 1)) - pos
-
-        #result = tf.reduce_sum(tf.sign(differences) * differences**2)
-        #result = tf.Print(result, [result], "result", summarize=20)
-
-        #term2 = lmbda * tf.reduce_mean((1 - pos)**2)
-        #result = result + term2
+            tf.maximum(0.0, tf.reshape(neg, (-1, 1)) - pos_max + delta) ** 2)
         return result
 
-        """
-        pos = tf.boolean_mask(y_pred, y_true)
-
-        pos = y_pred[y_true.astype(bool)].max(1)
-        neg = y_pred[~y_true.astype(bool)]
-        expected2 = (numpy.maximum(0,
-            neg.flatten().reshape((-1, 1)) - pos) ** 2).sum()
-        """
-
-
-
-
-    @staticmethod
-    def loss_old(y_true, y_pred):
-        """Binary cross entropy after taking logsumexp over predictions"""
-        import keras.backend as K
-        import tensorflow as tf
-        #y_pred_aggregated = K.logsumexp(y_pred, axis=1, keepdims=True)
-        #y_pred_aggregated = K.sigmoid(y_pred_aggregated)
-        #y_pred = tf.Print(y_pred, [y_pred], "y_pred", summarize=20)
-        #y_true = tf.Print(y_true, [y_true], "y_true", summarize=20)
-
-        y_pred_aggregated = K.max(y_pred, axis=1, keepdims=False)
-        #y_pred_aggregated = tf.Print(y_pred_aggregated, [y_pred_aggregated], "y_pred_aggregated",
-        #    summarize=20)
-
-        y_true = K.squeeze(K.cast(y_true, y_pred_aggregated.dtype), axis=-1)
-        #print("SHAPES", y_pred, K.int_shape(y_pred), y_pred_aggregated, K.int_shape(y_pred_aggregated), y_true, K.int_shape(y_true))
-        #K.print_tensor(y_pred_aggregated, "y_pred_aggregated")
-        #K.print_tensor(y_true, "y_true")
-
-        #y_pred_aggregated = K.print_tensor(y_pred_aggregated, "y_pred_aggregated")
-
-
-        #y_true = K.print_tensor(y_true, "y_true")
-
-        #return K.mean(
-        #    K.binary_crossentropy(y_true, y_pred_aggregated),
-        #    axis=-1)
-        return K.mean(
-            (y_true - y_pred_aggregated)**2,
-            axis=-1
-        )
-
     def peptides_to_network_input(self, peptides):
         """
         Encode peptides to the fixed-length encoding expected by the neural
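
Note on the new loss: it is the squared hinge ranking formulation exercised by the updated test below. A minimal NumPy sketch of the same computation (illustration only; the name `ligandome_margin_loss` is not part of the package):

```python
import numpy

def ligandome_margin_loss(y_true, y_pred, delta=0.2):
    # Sketch of the TF loss above: y_true marks hits (1) vs. decoys (0),
    # y_pred is an (n_samples, n_alleles) array of per-allele predictions.
    y_true = numpy.asarray(y_true).astype(bool)
    y_pred = numpy.asarray(y_pred)
    pos_max = y_pred[y_true].max(axis=1)     # best allele prediction per hit
    neg = y_pred[~y_true].reshape((-1, 1))   # every decoy prediction
    # Quadratic penalty whenever a decoy prediction comes within delta of,
    # or exceeds, the best prediction for some hit.
    return (numpy.maximum(0.0, neg - pos_max + delta) ** 2).sum()
```
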
diff --git a/test/test_class1_ligandome_predictor.py b/test/test_class1_ligandome_predictor.py
index 411d9c066898c63373ca89439810f657fa5dfbb6..7fb5cacf44a36c97e9ab62b1fc4973d6d4761a30 100644
--- a/test/test_class1_ligandome_predictor.py
+++ b/test/test_class1_ligandome_predictor.py
@@ -21,6 +21,7 @@ logging.getLogger('matplotlib').disabled = True
 import pandas
 import argparse
 import sys
+from functools import partial
 
 from numpy.testing import assert_, assert_equal, assert_allclose
 import numpy
@@ -107,7 +108,9 @@ def evaluate_loss(loss, y_true, y_pred):
         raise ValueError("Unsupported backend: %s" % K.backend())
 
 
-def Xtest_loss():
+def test_loss():
+    delta = 0.4
+
     # Hit labels
     y_true = [
         1.0,
@@ -134,7 +137,7 @@ def Xtest_loss():
                 if y_true[j] == 0.0:
                     tightest_i = max(y_pred[i])
                     contribution = sum(
-                        max(0, y_pred[j, k] - tightest_i)**2
+                        max(0, y_pred[j, k] - tightest_i + delta)**2
                         for k in range(y_pred.shape[1])
                     )
                     contributions.append(contribution)
@@ -145,12 +148,12 @@ def Xtest_loss():
     pos = y_pred[y_true.astype(bool)].max(1)
     neg = y_pred[~y_true.astype(bool)]
     expected2 = (
-            numpy.maximum(0, neg.reshape((-1, 1)) - pos)**2).sum()
+            numpy.maximum(0, neg.reshape((-1, 1)) - pos + delta)**2).sum()
 
     numpy.testing.assert_almost_equal(expected1, expected2)
 
     computed = evaluate_loss(
-        Class1LigandomePredictor.loss,
+        partial(Class1LigandomePredictor.loss, delta=delta),
         y_true,
         y_pred.reshape(y_pred.shape + (1,)))
     numpy.testing.assert_almost_equal(computed, expected1)
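
The `evaluate_loss` helper used here is defined earlier in the test file and is not shown in this diff. A minimal sketch of what such a helper might look like, assuming the TensorFlow backend only (the real helper also dispatches on `K.backend()`, per the `ValueError` in the context above):

```python
import keras.backend as K
import numpy

def evaluate_loss_tf(loss, y_true, y_pred):
    # Sketch only: wrap the NumPy inputs as backend constants, apply the
    # symbolic loss, then evaluate the result back to a NumPy value.
    y_true_t = K.constant(numpy.asarray(y_true, dtype="float32"))
    y_pred_t = K.constant(numpy.asarray(y_pred, dtype="float32"))
    return K.eval(loss(y_true_t, y_pred_t))
```
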
@@ -197,20 +200,16 @@ def make_motif(allele, peptides, frac=0.01):
         peptides=peptides,
         allele=allele,
     )
-
     random_predictions_df = pandas.DataFrame({"peptide": peptides.sequences})
     random_predictions_df["prediction"] = predictions
     random_predictions_df = random_predictions_df.sort_values(
         "prediction", ascending=True)
-    #print("Random peptide predictions", allele)
-    #print(random_predictions_df)
     top = random_predictions_df.iloc[:int(len(random_predictions_df) * frac)]
     matrix = positional_frequency_matrix(top.peptide.values)
-    #print("Matrix")
     return matrix
 
 
-def test_synthetic_allele_refinement():
+def test_synthetic_allele_refinement(max_epochs=10):
     refine_allele = "HLA-C*01:02"
     alleles = [
         "HLA-A*02:01", "HLA-B*27:01", "HLA-C*07:01",
@@ -266,8 +265,8 @@ def test_synthetic_allele_refinement():
     predictor = Class1LigandomePredictor(
         PAN_ALLELE_PREDICTOR_NO_MASS_SPEC,
         max_ensemble_size=1,
-        max_epochs=10,
-        learning_rate=0.00001,
+        max_epochs=max_epochs,
+        learning_rate=0.0001,
         patience=5,
         min_delta=0.0)
 
@@ -295,8 +294,6 @@ def test_synthetic_allele_refinement():
     pre_auc = roc_auc_score(train_df.hit.values, train_df.pre_max_prediction.values)
     print("PRE_AUC", pre_auc)
 
-    #import ipdb ; ipdb.set_trace()
-
     assert_allclose(pre_predictions, expected_pre_predictions)
 
     motifs_history = []
@@ -396,13 +393,19 @@ parser.add_argument(
     "--out-motifs-pickle",
     default=None,
     help="Metrics output")
+parser.add_argument(
+    "--max-epochs",
+    default=100,
+    type=int,
+    help="Max epochs")
 
 
 if __name__ == '__main__':
     # If run directly from python, leave the user in a shell to explore results.
     setup()
     args = parser.parse_args(sys.argv[1:])
-    (predictor, predictions, metrics, motifs) = test_synthetic_allele_refinement()
+    (predictor, predictions, metrics, motifs) = (
+        test_synthetic_allele_refinement(max_epochs=args.max_epochs))
 
     if args.out_metrics_csv:
         metrics.to_csv(args.out_metrics_csv)
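
With the new `--max-epochs` flag the module can also be driven interactively. A hypothetical session mirroring the `__main__` block above (assumes the pan-allele models required by `setup()` are available locally; `metrics` is the DataFrame written by `--out-metrics-csv`):

```python
# Hypothetical: run the synthetic refinement test for 25 epochs and
# inspect the resulting metrics.
setup()
predictor, predictions, metrics, motifs = test_synthetic_allele_refinement(
    max_epochs=25)
print(metrics)
```
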