From 492aa09d7dccc2cc579ee33725cfd59c789a6402 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Tue, 20 Mar 2018 16:39:33 -0400
Subject: [PATCH] Fix bug that led to NaN mhcflurry_prediction_high and
 mhcflurry_prediction_low uncertainty estimates in some cases

---
 mhcflurry/class1_affinity_predictor.py | 4 ++--
 test/test_class1_affinity_predictor.py | 8 ++++++++
 test/test_predict_command.py           | 1 +
 3 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py
index 63cf7960..85ecc1a1 100644
--- a/mhcflurry/class1_affinity_predictor.py
+++ b/mhcflurry/class1_affinity_predictor.py
@@ -1005,8 +1005,8 @@ class Class1AffinityPredictor(object):
         df["prediction"] = numpy.exp(log_centers)
 
         if include_confidence_intervals:
-            df["prediction_low"] = numpy.exp(numpy.percentile(logs, 5.0, axis=1))
-            df["prediction_high"] = numpy.exp(numpy.percentile(logs, 95.0, axis=1))
+            df["prediction_low"] = numpy.exp(numpy.nanpercentile(logs, 5.0, axis=1))
+            df["prediction_high"] = numpy.exp(numpy.nanpercentile(logs, 95.0, axis=1))
 
         if include_individual_model_predictions:
             for i in range(num_pan_models):
diff --git a/test/test_class1_affinity_predictor.py b/test/test_class1_affinity_predictor.py
index 30faaf43..3df548a8 100644
--- a/test/test_class1_affinity_predictor.py
+++ b/test/test_class1_affinity_predictor.py
@@ -218,6 +218,14 @@ def test_class1_affinity_predictor_a0205_memorize_training_data():
     assert numpy.isnan(ic50_pred[2])
 
 
+def test_no_nans():
+    df = DOWNLOADED_PREDICTOR.predict_to_dataframe(
+            alleles=["A02:01", "A02:02"],
+            peptides=["SIINFEKL", "SIINFEKLL"])
+    print(df)
+    assert not df.isnull().any().any()
+
+
 def test_predict_implementations_equivalent():
     for allele in ["HLA-A02:01", "A02:02"]:
         for centrality_measure in ["mean", "robust_mean"]:
diff --git a/test/test_predict_command.py b/test/test_predict_command.py
index cfaf2a0a..f4631020 100644
--- a/test/test_predict_command.py
+++ b/test/test_predict_command.py
@@ -28,6 +28,7 @@ def test_csv():
         predict_command.run(full_args)
         result = pandas.read_csv(fd_out.name)
         print(result)
+        assert not result.isnull().any().any()
     finally:
         for delete in deletes:
             os.unlink(delete)
-- 
GitLab