From b549dadaa718c47ca5cfd316c45222d1133cd5da Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Thu, 8 Feb 2018 18:46:10 -0500
Subject: [PATCH] Fix robust_mean implementation

robust_mean averages, per row, the values lying between that row's 25th
and 75th nan-percentiles.

The previous code multiplied the raw values by the interquartile mask.
A NaN entry is masked out (comparisons with NaN are False), but
NaN * 0.0 is still NaN, so any row containing a NaN summed to NaN.
Substitute 0 for NaN before masking and summing.

Also exclude the NaN positions from the mask explicitly. After the
substitution they hold 0, which would otherwise pass both percentile
comparisons and be averaged in whenever 0 lies inside the row's
interquartile range (e.g. [-3, -1, 2, 4, nan] gave 1/3 instead of 0.5).
---
Reviewer note: the hunk line counts and diffstat below were updated by
hand for this revision; the post-image blob ids on the "index" lines
were not recomputed. Regenerate with `git format-patch` before sending.

 mhcflurry/ensemble_centrality.py |  8 +++++---
 test/test_ensemble_centrality.py | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 38 insertions(+), 3 deletions(-)
 create mode 100644 test/test_ensemble_centrality.py

diff --git a/mhcflurry/ensemble_centrality.py b/mhcflurry/ensemble_centrality.py
index 8a812d76..54dc0550 100644
--- a/mhcflurry/ensemble_centrality.py
+++ b/mhcflurry/ensemble_centrality.py
@@ -25,10 +25,12 @@ def robust_mean(log_values):
     if log_values.shape[1] <= 3:
         # Too few values to use robust mean.
         return numpy.nanmean(log_values, axis=1)
+    without_nans = numpy.nan_to_num(log_values)  # replace nan with 0
     mask = (
-        (log_values <= numpy.nanpercentile(log_values, 75, axis=1).reshape((-1, 1))) &
-        (log_values >= numpy.nanpercentile(log_values, 25, axis=1).reshape((-1, 1))))
-    return (log_values * mask.astype(float)).sum(1) / mask.sum(1)
+        (~numpy.isnan(log_values)) &
+        (without_nans <= numpy.nanpercentile(log_values, 75, axis=1).reshape((-1, 1))) &
+        (without_nans >= numpy.nanpercentile(log_values, 25, axis=1).reshape((-1, 1))))
+    return (without_nans * mask.astype(float)).sum(1) / mask.sum(1)
 
 CENTRALITY_MEASURES = {
     "mean": partial(numpy.nanmean, axis=1),
diff --git a/test/test_ensemble_centrality.py b/test/test_ensemble_centrality.py
new file mode 100644
index 00000000..3f518848
--- /dev/null
+++ b/test/test_ensemble_centrality.py
@@ -0,0 +1,33 @@
+import numpy
+
+from numpy.testing import assert_equal
+
+from mhcflurry import ensemble_centrality
+
+
+def test_robust_mean():
+    arr1 = numpy.array([
+        [1, 2, 3, 4, 5],
+        [-10000, 2, 3, 4, 100000],
+    ])
+
+    results = ensemble_centrality.robust_mean(arr1)
+    assert_equal(results, [3, 3])
+
+    # Should ignore nans.
+    arr2 = numpy.array([
+        [1, 2, 3, 4, 5],
+        [numpy.nan, 2, 3, 4, numpy.nan],
+    ])
+
+    results = ensemble_centrality.robust_mean(arr2)
+    assert_equal(results, [3, 3])
+
+    # A nan must not be averaged in as 0 when 0 lies inside the
+    # interquartile range of the remaining values.
+    arr3 = numpy.array([
+        [-3, -1, 2, 4, numpy.nan],
+    ])
+
+    results = ensemble_centrality.robust_mean(arr3)
+    assert_equal(results, [0.5])
-- 
GitLab