Skip to content
Snippets Groups Projects
Commit b549dada authored by Tim O'Donnell
Browse files

Fix robust_mean implementation

parent 612fe5c8
No related merge requests found
......@@ -25,10 +25,11 @@ def robust_mean(log_values):
if log_values.shape[1] <= 3:
# Too few values to use robust mean.
return numpy.nanmean(log_values, axis=1)
without_nans = numpy.nan_to_num(log_values) # replace nan with 0
mask = (
(log_values <= numpy.nanpercentile(log_values, 75, axis=1).reshape((-1, 1))) &
(log_values >= numpy.nanpercentile(log_values, 25, axis=1).reshape((-1, 1))))
return (log_values * mask.astype(float)).sum(1) / mask.sum(1)
(without_nans <= numpy.nanpercentile(log_values, 75, axis=1).reshape((-1, 1))) &
(without_nans >= numpy.nanpercentile(log_values, 25, axis=1).reshape((-1, 1))))
return (without_nans * mask.astype(float)).sum(1) / mask.sum(1)
CENTRALITY_MEASURES = {
"mean": partial(numpy.nanmean, axis=1),
......
import numpy
from numpy.testing import assert_equal
from mhcflurry import ensemble_centrality
def test_robust_mean():
    """Check robust_mean on a row with extreme values and a row with NaNs.

    In both inputs the first row is the plain sequence 1..5; the second row
    is perturbed (extreme values, then NaNs). Every row is expected to
    yield 3.
    """
    expected = [3, 3]

    # Second row carries extreme values at both ends.
    with_outliers = numpy.array([
        [1, 2, 3, 4, 5],
        [-10000, 2, 3, 4, 100000],
    ])
    assert_equal(ensemble_centrality.robust_mean(with_outliers), expected)

    # Should ignore nans.
    with_nans = numpy.array([
        [1, 2, 3, 4, 5],
        [numpy.nan, 2, 3, 4, numpy.nan],
    ])
    assert_equal(ensemble_centrality.robust_mean(with_nans), expected)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment