diff --git a/mhcflurry/custom_loss.py b/mhcflurry/custom_loss.py
index dcb8c0a8ee0d64b218842cf8fee7479f08eb7aa8..12b2c669245e280431278842af856957c74497e9 100644
--- a/mhcflurry/custom_loss.py
+++ b/mhcflurry/custom_loss.py
@@ -286,14 +286,14 @@ class MultiallelicMassSpecLoss(Loss):
 
     def loss(self, y_true, y_pred):
         import tensorflow as tf
-        y_true = tf.reshape(y_true, (-1,))
+        y_true = tf.squeeze(y_true, axis=-1)
         pos = tf.boolean_mask(y_pred, tf.math.equal(y_true, 1.0))
         pos_max = tf.reduce_max(pos, axis=1)
         neg = tf.boolean_mask(y_pred, tf.math.equal(y_true, 0.0))
         term = tf.reshape(neg, (-1, 1)) - pos_max + self.delta
         result = tf.reduce_sum(tf.maximum(0.0, term) ** 2) / tf.cast(
-            tf.shape(term)[0], tf.float32) * self.multiplier
-        return tf.where(tf.is_nan(result), 0.0, result)
+            tf.size(term), tf.float32) * self.multiplier
+        return result
 
 
 def check_shape(name, arr, expected_shape):
diff --git a/test/test_class1_presentation_predictor.py b/test/test_class1_presentation_predictor.py
index 499b7bd0a1f26dd16a8e800c2d6f0491341273a6..69a8da8085bf0462673aa7d1854afeec5afd4182 100644
--- a/test/test_class1_presentation_predictor.py
+++ b/test/test_class1_presentation_predictor.py
@@ -91,7 +91,7 @@ def teardown():
     cleanup()
 
 
-def test_basic():
+def Xtest_basic():
     affinity_predictor = PAN_ALLELE_PREDICTOR_NO_MASS_SPEC
     models = []
     for affinity_network in affinity_predictor.class1_pan_allele_models:
@@ -224,7 +224,7 @@ def evaluate_loss(loss, y_true, y_pred):
         raise ValueError("Unsupported backend: %s" % K.backend())
 
 
-def Xtest_loss():
+def test_loss():
     for delta in [0.0, 0.3]:
         print("delta", delta)
         # Hit labels
@@ -234,7 +234,8 @@ def Xtest_loss():
             1.0,
             -1.0,  # ignored
             1.0,
-            0.0
+            0.0,
+            1.0,
         ]
         y_true = numpy.array(y_true)
         y_pred = [
@@ -244,7 +245,7 @@ def Xtest_loss():
             [0.9, 0.1, 0.2],
             [0.1, 0.7, 0.1],
             [0.8, 0.2, 0.4],
-
+            [0.1, 0.2, 0.4],
         ]
         y_pred = numpy.array(y_pred)
 
@@ -262,11 +263,10 @@ def Xtest_loss():
                 for j in range(len(y_true)):
                     if y_true[j] == 0.0:
                         tightest_i = max(y_pred[i])
-                        contribution = sum(
-                            max(0, y_pred[j, k] - tightest_i + delta)**2
-                            for k in range(y_pred.shape[1])
-                        )
-                        contributions.append(contribution)
+                        for k in range(y_pred.shape[1]):
+                            contribution = max(
+                                0, y_pred[j, k] - tightest_i + delta)**2
+                            contributions.append(contribution)
         contributions = numpy.array(contributions)
         expected1 = contributions.sum() / len(contributions)
 
@@ -278,18 +278,21 @@ def Xtest_loss():
         ])
 
         neg = y_pred[(y_true == 0.0).astype(bool)]
+        term = neg.reshape((-1, 1)) - pos + delta
+        print("Term:")
+        print(term)
         expected2 = (
-                numpy.maximum(0, neg.reshape((-1, 1)) - pos + delta)**2).sum() / (
-            len(pos) * len(neg))
+                numpy.maximum(0, term)**2).sum() / (
+            len(pos) * neg.shape[0] * neg.shape[1])
 
-        yield numpy.testing.assert_almost_equal, expected1, expected2, 4
+        numpy.testing.assert_almost_equal(expected1, expected2)
 
         computed = evaluate_loss(
             MultiallelicMassSpecLoss(delta=delta).loss,
             y_true,
-            y_pred.reshape(y_pred.shape + (1,)))
+            y_pred.reshape(y_pred.shape))
 
-        yield numpy.testing.assert_almost_equal, computed, expected1, 4
+        numpy.testing.assert_almost_equal(computed, expected1, 4)
 
 
 AA_DIST = pandas.Series(