diff --git a/mhcflurry/ensemble_centrality.py b/mhcflurry/ensemble_centrality.py
index 8a812d7656b40099f8bec67e7a0d5832c9cc677f..54dc05500b2d65b80e8947729bb55be54252bd28 100644
--- a/mhcflurry/ensemble_centrality.py
+++ b/mhcflurry/ensemble_centrality.py
@@ -25,10 +25,16 @@ def robust_mean(log_values):
     if log_values.shape[1] <= 3:
         # Too few values to use robust mean.
         return numpy.nanmean(log_values, axis=1)
-    mask = (
-        (log_values <= numpy.nanpercentile(log_values, 75, axis=1).reshape((-1, 1))) &
-        (log_values >= numpy.nanpercentile(log_values, 25, axis=1).reshape((-1, 1))))
-    return (log_values * mask.astype(float)).sum(1) / mask.sum(1)
+    # Zero out NaNs so that masked-out entries add nothing to the sum
+    # (NaN * 0 would otherwise poison the whole row).
+    without_nans = numpy.nan_to_num(log_values)
+    mask = (
+        # Drop NaN positions explicitly: their 0 placeholder could fall
+        # inside the interquartile range and be averaged in as a real value.
+        (~numpy.isnan(log_values)) &
+        (without_nans <= numpy.nanpercentile(log_values, 75, axis=1).reshape((-1, 1))) &
+        (without_nans >= numpy.nanpercentile(log_values, 25, axis=1).reshape((-1, 1))))
+    return (without_nans * mask.astype(float)).sum(1) / mask.sum(1)
 
 CENTRALITY_MEASURES = {
     "mean": partial(numpy.nanmean, axis=1),
diff --git a/test/test_ensemble_centrality.py b/test/test_ensemble_centrality.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f5188488312404a042ca20a09ba9c8623e3a22a
--- /dev/null
+++ b/test/test_ensemble_centrality.py
@@ -0,0 +1,34 @@
+import numpy
+
+from numpy.testing import assert_equal
+
+from mhcflurry import ensemble_centrality
+
+
+def test_robust_mean():
+    arr1 = numpy.array([
+        [1, 2, 3, 4, 5],
+        [-10000, 2, 3, 4, 100000],
+    ])
+
+    results = ensemble_centrality.robust_mean(arr1)
+    assert_equal(results, [3, 3])
+
+    # Should ignore nans.
+    arr2 = numpy.array([
+        [1, 2, 3, 4, 5],
+        [numpy.nan, 2, 3, 4, numpy.nan],
+    ])
+
+    results = ensemble_centrality.robust_mean(arr2)
+    assert_equal(results, [3, 3])
+
+    # NaNs must stay excluded even when their 0 placeholder would fall
+    # inside the interquartile range of the remaining values.
+    arr3 = numpy.array([
+        [1, 2, 3, 4, 5],
+        [numpy.nan, -3, -1, 2, 4],
+    ])
+
+    results = ensemble_centrality.robust_mean(arr3)
+    assert_equal(results, [3, 0.5])