From b549dadaa718c47ca5cfd316c45222d1133cd5da Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Thu, 8 Feb 2018 18:46:10 -0500
Subject: [PATCH] Fix robust_mean implementation

---
 mhcflurry/ensemble_centrality.py |  7 ++++---
 test/test_ensemble_centrality.py | 24 ++++++++++++++++++++++++
 2 files changed, 28 insertions(+), 3 deletions(-)
 create mode 100644 test/test_ensemble_centrality.py

diff --git a/mhcflurry/ensemble_centrality.py b/mhcflurry/ensemble_centrality.py
index 8a812d76..54dc0550 100644
--- a/mhcflurry/ensemble_centrality.py
+++ b/mhcflurry/ensemble_centrality.py
@@ -25,10 +25,11 @@ def robust_mean(log_values):
     if log_values.shape[1] <= 3:
         # Too few values to use robust mean.
         return numpy.nanmean(log_values, axis=1)
+    without_nans = numpy.nan_to_num(log_values)  # replace nan with 0
     mask = (
-        (log_values <= numpy.nanpercentile(log_values, 75, axis=1).reshape((-1, 1))) &
-        (log_values >= numpy.nanpercentile(log_values, 25, axis=1).reshape((-1, 1))))
-    return (log_values * mask.astype(float)).sum(1) / mask.sum(1)
+        (without_nans <= numpy.nanpercentile(log_values, 75, axis=1).reshape((-1, 1))) &
+        (without_nans >= numpy.nanpercentile(log_values, 25, axis=1).reshape((-1, 1))))
+    return (without_nans * mask.astype(float)).sum(1) / mask.sum(1)
 
 CENTRALITY_MEASURES = {
     "mean": partial(numpy.nanmean, axis=1),
diff --git a/test/test_ensemble_centrality.py b/test/test_ensemble_centrality.py
new file mode 100644
index 00000000..3f518848
--- /dev/null
+++ b/test/test_ensemble_centrality.py
@@ -0,0 +1,24 @@
+import numpy
+
+from numpy.testing import assert_equal
+
+from mhcflurry import ensemble_centrality
+
+
+def test_robust_mean():
+    arr1 = numpy.array([
+        [1, 2, 3, 4, 5],
+        [-10000, 2, 3, 4, 100000],  # extreme values at both ends
+    ])
+    # Both rows must give 3: the extremes in row two must not shift it.
+    results = ensemble_centrality.robust_mean(arr1)
+    assert_equal(results, [3, 3])
+
+    # NaN entries should be ignored rather than propagated to the result.
+    arr2 = numpy.array([
+        [1, 2, 3, 4, 5],
+        [numpy.nan, 2, 3, 4, numpy.nan],
+    ])
+    # NOTE(review): no row here has 0 between its quartiles; add such a case.
+    results = ensemble_centrality.robust_mean(arr2)
+    assert_equal(results, [3, 3])
-- 
GitLab