From 492aa09d7dccc2cc579ee33725cfd59c789a6402 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Tue, 20 Mar 2018 16:39:33 -0400 Subject: [PATCH] Fix bug that led to NaN mhcflurry_prediction_high and mhcflurry_prediction_low uncertainty estimates in some cases --- mhcflurry/class1_affinity_predictor.py | 4 ++-- test/test_class1_affinity_predictor.py | 8 ++++++++ test/test_predict_command.py | 1 + 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py index 63cf7960..85ecc1a1 100644 --- a/mhcflurry/class1_affinity_predictor.py +++ b/mhcflurry/class1_affinity_predictor.py @@ -1005,8 +1005,8 @@ class Class1AffinityPredictor(object): df["prediction"] = numpy.exp(log_centers) if include_confidence_intervals: - df["prediction_low"] = numpy.exp(numpy.percentile(logs, 5.0, axis=1)) - df["prediction_high"] = numpy.exp(numpy.percentile(logs, 95.0, axis=1)) + df["prediction_low"] = numpy.exp(numpy.nanpercentile(logs, 5.0, axis=1)) + df["prediction_high"] = numpy.exp(numpy.nanpercentile(logs, 95.0, axis=1)) if include_individual_model_predictions: for i in range(num_pan_models): diff --git a/test/test_class1_affinity_predictor.py b/test/test_class1_affinity_predictor.py index 30faaf43..3df548a8 100644 --- a/test/test_class1_affinity_predictor.py +++ b/test/test_class1_affinity_predictor.py @@ -218,6 +218,14 @@ def test_class1_affinity_predictor_a0205_memorize_training_data(): assert numpy.isnan(ic50_pred[2]) +def test_no_nans(): + df = DOWNLOADED_PREDICTOR.predict_to_dataframe( + alleles=["A02:01", "A02:02"], + peptides=["SIINFEKL", "SIINFEKLL"]) + print(df) + assert not df.isnull().any().any() + + def test_predict_implementations_equivalent(): for allele in ["HLA-A02:01", "A02:02"]: for centrality_measure in ["mean", "robust_mean"]: diff --git a/test/test_predict_command.py b/test/test_predict_command.py index cfaf2a0a..f4631020 100644 --- a/test/test_predict_command.py +++ b/test/test_predict_command.py @@ -28,6 +28,7 @@ def test_csv(): predict_command.run(full_args) result = pandas.read_csv(fd_out.name) print(result) + assert not result.isnull().any().any() finally: for delete in deletes: os.unlink(delete) -- GitLab