diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index fa4f4e41000b85444680e9644f1f60f3972a69d6..40743c1b46384cd41800fc2dc779e8ba722134f1 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -536,6 +536,8 @@ class Class1NeuralNetwork(object):
                 x_dict_without_random_negatives[key][shuffle_permutation])
         if sample_weights is not None:
             sample_weights = sample_weights[shuffle_permutation]
+        if output_indices is not None:
+            output_indices = output_indices[shuffle_permutation]
 
         if self.hyperparameters['loss'].startswith("custom:"):
             # Using a custom loss
diff --git a/mhcflurry/custom_loss.py b/mhcflurry/custom_loss.py
index 7cad898b11e423e3969cf09b95b90fa6f3027b51..8bd1b9537ce868f71f1a27bea5739fb98bad32cf 100644
--- a/mhcflurry/custom_loss.py
+++ b/mhcflurry/custom_loss.py
@@ -90,10 +90,12 @@ class MSEWithInequalities(object):
         diff3 *= K.cast(y_true >= 4.0, "float32")
         diff3 *= K.cast(diff3 > 0.0, "float32")
 
-        return (
-            K.sum(K.square(diff1), axis=-1) +
-            K.sum(K.square(diff2), axis=-1) +
-            K.sum(K.square(diff3), axis=-1)) / K.cast(K.shape(y_true)[0], "float32")
+        result = (
+            K.sum(K.square(diff1)) +
+            K.sum(K.square(diff2)) +
+            K.sum(K.square(diff3))) / K.cast(K.shape(y_true)[0], "float32")
+
+        return result
 
 
 class MSEWithInequalitiesAndMultipleOutputs(object):
@@ -128,6 +130,11 @@ class MSEWithInequalitiesAndMultipleOutputs(object):
     def loss(y_true, y_pred):
         from keras import backend as K
 
+        #y_true = K.print_tensor(y_true, "y_true1")
+        #y_pred = K.print_tensor(y_pred, "y_pred1")
+
+        y_true = K.flatten(y_true)
+
         output_indices = y_true // 10
         updated_y_true = y_true - (10 * output_indices)
 
@@ -136,6 +143,8 @@ class MSEWithInequalitiesAndMultipleOutputs(object):
         ordinals = K.arange(K.shape(y_true)[0])
         flattened_indices = (
             ordinals * y_pred.shape[1] + K.cast(output_indices, "int32"))
+        import tensorflow
+        #flattened_indices = tensorflow.Print(flattened_indices, [flattened_indices], "flattened_indices", summarize=1000)
         updated_y_pred = K.gather(K.flatten(y_pred), flattened_indices)
 
         # Alternative implementation using tensorflow, which could be used if
diff --git a/test/test_multi_output.py b/test/test_multi_output.py
index 436c4c9de792ac0413d95d9803bde11287981cc9..31e75b0d7a4132881f4d4f0acaea42462f100e10 100644
--- a/test/test_multi_output.py
+++ b/test/test_multi_output.py
@@ -20,8 +20,8 @@ def test_multi_output():
         loss="custom:mse_with_inequalities_and_multiple_outputs",
         activation="tanh",
         layer_sizes=[16],
-        max_epochs=50,
-        minibatch_size=32,
+        max_epochs=500,
+        minibatch_size=250,
         random_negative_rate=0.0,
         random_negative_constant=0.0,
         early_stopping=False,
@@ -30,12 +30,14 @@ def test_multi_output():
         ],
         dense_layer_l1_regularization=0.0,
         dropout_probability=0.0,
-        num_outputs=2)
+        optimizer="adam",
+        num_outputs=3)
 
     df = pandas.DataFrame()
     df["peptide"] = random_peptides(10000, length=9)
-    df["output1"] = df.peptide.map(lambda s: s[4] == 'K').astype(int) * 10000 + 0.01
-    df["output2"] = df.peptide.map(lambda s: s[3] == 'Q').astype(int) * 10000 + 0.01
+    df["output1"] = df.peptide.map(lambda s: s[4] == 'K').astype(int) * 49000 + 1
+    df["output2"] = df.peptide.map(lambda s: s[3] == 'Q').astype(int) * 49000 + 1
+    df["output3"] = df.peptide.map(lambda s: s[4] == 'K' or s[3] == 'Q').astype(int) * 49000 + 1
 
     print("output1 mean", df.output1.mean())
     print("output2 mean", df.output2.mean())
@@ -45,6 +47,7 @@ def test_multi_output():
     stacked["output_index"] = stacked.output_name.map({
         "output1": 0,
         "output2": 1,
+        "output3": 2,
     })
     assert not stacked.output_index.isnull().any(), stacked
 
@@ -53,17 +56,36 @@ def test_multi_output():
     }
 
     predictor = Class1NeuralNetwork(**hyperparameters)
+    stacked_train = stacked
     predictor.fit(
-        stacked.peptide.values,
-        stacked.value.values,
-        output_indices=stacked.output_index.values,
+        stacked_train.peptide.values,
+        stacked_train.value.values,
+        output_indices=stacked_train.output_index.values,
         **fit_kwargs)
+
     result = predictor.predict(df.peptide.values, output_index=None)
     print(df.shape, result.shape)
     print(result)
 
     df["prediction1"] = result[:,0]
     df["prediction2"] = result[:,1]
+    df["prediction3"] = result[:,2]
+
+    df_by_peptide = df.set_index("peptide")
+
+    correlation = pandas.DataFrame(
+        numpy.corrcoef(df_by_peptide.T),
+        columns=df_by_peptide.columns,
+        index=df_by_peptide.columns)
+    print(correlation)
+
+    sub_correlation = correlation.loc[
+        ["output1", "output2", "output3"],
+        ["prediction1", "prediction2", "prediction3"],
+    ]
+    assert sub_correlation.iloc[0, 0] > 0.99, correlation
+    assert sub_correlation.iloc[1, 1] > 0.99, correlation
+    assert sub_correlation.iloc[2, 2] > 0.99, correlation
 
     import ipdb ; ipdb.set_trace()