diff --git a/mhcflurry/ensemble_centrality.py b/mhcflurry/ensemble_centrality.py
index 8a812d7656b40099f8bec67e7a0d5832c9cc677f..54dc05500b2d65b80e8947729bb55be54252bd28 100644
--- a/mhcflurry/ensemble_centrality.py
+++ b/mhcflurry/ensemble_centrality.py
@@ -25,10 +25,11 @@ def robust_mean(log_values):
     if log_values.shape[1] <= 3:
         # Too few values to use robust mean.
         return numpy.nanmean(log_values, axis=1)
+    without_nans = numpy.nan_to_num(log_values)  # replace nan with 0
     mask = (
-        (log_values <= numpy.nanpercentile(log_values, 75, axis=1).reshape((-1, 1))) &
-        (log_values >= numpy.nanpercentile(log_values, 25, axis=1).reshape((-1, 1))))
-    return (log_values * mask.astype(float)).sum(1) / mask.sum(1)
+        (without_nans <= numpy.nanpercentile(log_values, 75, axis=1).reshape((-1, 1))) &
+        (without_nans >= numpy.nanpercentile(log_values, 25, axis=1).reshape((-1, 1))))
+    return (without_nans * mask.astype(float)).sum(1) / mask.sum(1)
 
 CENTRALITY_MEASURES = {
     "mean": partial(numpy.nanmean, axis=1),
diff --git a/test/test_ensemble_centrality.py b/test/test_ensemble_centrality.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f5188488312404a042ca20a09ba9c8623e3a22a
--- /dev/null
+++ b/test/test_ensemble_centrality.py
@@ -0,0 +1,24 @@
+import numpy
+
+from numpy.testing import assert_equal
+
+from mhcflurry import ensemble_centrality
+
+
+def test_robust_mean():
+    arr1 = numpy.array([
+        [1, 2, 3, 4, 5],
+        [-10000, 2, 3, 4, 100000],
+    ])
+
+    results = ensemble_centrality.robust_mean(arr1)
+    assert_equal(results, [3, 3])
+
+    # Should ignore nans.
+    arr2 = numpy.array([
+        [1, 2, 3, 4, 5],
+        [numpy.nan, 2, 3, 4, numpy.nan],
+    ])
+
+    results = ensemble_centrality.robust_mean(arr2)
+    assert_equal(results, [3, 3])