diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py index 01774c3cc17fe30c555a834360fcc554f1b18016..c6521046efe8273e2ace3d0de1595cac1b75adea 100644 --- a/mhcflurry/class1_affinity_predictor.py +++ b/mhcflurry/class1_affinity_predictor.py @@ -27,6 +27,11 @@ from .ensemble_centrality import CENTRALITY_MEASURES from .allele_encoding import AlleleEncoding +# Default function for combining predictions across models in an ensemble. +# See ensemble_centrality.py for other options. +DEFAULT_CENTRALITY_MEASURE = "mean" + + class Class1AffinityPredictor(object): """ High-level interface for peptide/MHC I binding affinity prediction. @@ -722,7 +727,7 @@ class Class1AffinityPredictor(object): alleles=None, allele=None, throw=True, - centrality_measure="robust_mean"): + centrality_measure=DEFAULT_CENTRALITY_MEASURE): """ Predict nM binding affinities. @@ -757,6 +762,7 @@ class Class1AffinityPredictor(object): allele=allele, throw=throw, include_percentile_ranks=False, + include_confidence_intervals=False, centrality_measure=centrality_measure, ) return df.prediction.values @@ -769,7 +775,8 @@ class Class1AffinityPredictor(object): throw=True, include_individual_model_predictions=False, include_percentile_ranks=True, - centrality_measure="mean"): + include_confidence_intervals=True, + centrality_measure=DEFAULT_CENTRALITY_MEASURE): """ Predict nM binding affinities. Gives more detailed output than `predict` method, including 5-95% prediction intervals. @@ -812,18 +819,22 @@ class Class1AffinityPredictor(object): raise TypeError("alleles must be a list or array, not a string") if allele is None and alleles is None: raise ValueError("Must specify 'allele' or 'alleles'.") - if allele is not None: - if alleles is not None: - raise ValueError("Specify exactly one of allele or alleles") - alleles = [allele] * len(peptides) - alleles = numpy.array(alleles) peptides = EncodableSequences.create(peptides) - df = pandas.DataFrame({ 'peptide': peptides.sequences, - 'allele': alleles, }) + + if allele is not None: + if alleles is not None: + raise ValueError("Specify exactly one of allele or alleles") + df["allele"] = allele + df["normalized_allele"] = mhcnames.normalize_allele_name(allele) + else: + df["allele"] = numpy.array(alleles) + df["normalized_allele"] = df.allele.map( + mhcnames.normalize_allele_name) + if len(df) == 0: # No predictions. logging.warning("Predicting for 0 peptides.") @@ -837,9 +848,6 @@ class Class1AffinityPredictor(object): ]) return empty_result - df["normalized_allele"] = df.allele.map( - mhcnames.normalize_allele_name) - (min_peptide_length, max_peptide_length) = ( self.supported_peptide_lengths) df["supported_peptide_length"] = ( @@ -928,8 +936,9 @@ class Class1AffinityPredictor(object): logs = numpy.log(df_predictions) log_centers = centrality_function(logs.values) df["prediction"] = numpy.exp(log_centers) - df["prediction_low"] = numpy.exp(logs.quantile(0.05, axis=1)) - df["prediction_high"] = numpy.exp(logs.quantile(0.95, axis=1)) + if include_confidence_intervals: + df["prediction_low"] = numpy.exp(logs.quantile(0.05, axis=1)) + df["prediction_high"] = numpy.exp(logs.quantile(0.95, axis=1)) if include_individual_model_predictions: columns = sorted(df.columns, key=lambda c: c.startswith('model_')) diff --git a/mhcflurry/ensemble_centrality.py b/mhcflurry/ensemble_centrality.py index 54dc05500b2d65b80e8947729bb55be54252bd28..e370a39d66f31d8e343605a22c38e62bd12160b0 100644 --- a/mhcflurry/ensemble_centrality.py +++ b/mhcflurry/ensemble_centrality.py @@ -27,10 +27,12 @@ def robust_mean(log_values): return numpy.nanmean(log_values, axis=1) without_nans = numpy.nan_to_num(log_values) # replace nan with 0 mask = ( + (~numpy.isnan(log_values)) & (without_nans <= numpy.nanpercentile(log_values, 75, axis=1).reshape((-1, 1))) & (without_nans >= numpy.nanpercentile(log_values, 25, axis=1).reshape((-1, 1)))) return (without_nans * mask.astype(float)).sum(1) / mask.sum(1) + CENTRALITY_MEASURES = { "mean": partial(numpy.nanmean, axis=1), "median": partial(numpy.nanmedian, axis=1), diff --git a/test/test_class1_affinity_predictor.py b/test/test_class1_affinity_predictor.py index 40ac8c2f9ea22b1b6ad80495da061bd6b0acf5c5..afd36f7fc139b7d0e50805b85baa81e6b87811ca 100644 --- a/test/test_class1_affinity_predictor.py +++ b/test/test_class1_affinity_predictor.py @@ -217,3 +217,30 @@ def test_class1_affinity_predictor_a0205_memorize_training_data(): assert not numpy.isnan(ic50_pred[1]) assert numpy.isnan(ic50_pred[2]) + +def test_predict_implementations_equivalent(): + for allele in ["HLA-A02:01", "A02:02"]: + for centrality_measure in ["mean", "robust_mean"]: + peptides = ["SIINFEKL", "SYYNFIIIKL", "SIINKFELQY"] + + pred1 = DOWNLOADED_PREDICTOR.predict( + allele=allele, + peptides=peptides + ["SSSN"], + throw=False, + centrality_measure=centrality_measure) + pred2 = DOWNLOADED_PREDICTOR.predict_to_dataframe( + allele=allele, + peptides=peptides + ["SSSN"], + throw=False, + centrality_measure=centrality_measure).prediction.values + testing.assert_equal(pred1, pred2) + + pred1 = DOWNLOADED_PREDICTOR.predict( + allele=allele, + peptides=peptides, + centrality_measure=centrality_measure) + pred2 = DOWNLOADED_PREDICTOR.predict_to_dataframe( + allele=allele, + peptides=peptides, + centrality_measure=centrality_measure).prediction.values + testing.assert_equal(pred1, pred2) diff --git a/test/test_ensemble_centrality.py b/test/test_ensemble_centrality.py index 3f5188488312404a042ca20a09ba9c8623e3a22a..69ed4bbb82689dd00ed6d0f9f6bd9a8e0652b77c 100644 --- a/test/test_ensemble_centrality.py +++ b/test/test_ensemble_centrality.py @@ -8,7 +8,7 @@ from mhcflurry import ensemble_centrality def test_robust_mean(): arr1 = numpy.array([ [1, 2, 3, 4, 5], - [-10000, 2, 3, 4, 100000], + [-10000, 2, 3, 4, 100], ]) results = ensemble_centrality.robust_mean(arr1) @@ -17,8 +17,12 @@ def test_robust_mean(): # Should ignore nans. arr2 = numpy.array([ [1, 2, 3, 4, 5], - [numpy.nan, 2, 3, 4, numpy.nan], + [numpy.nan, 1, 2, 3, numpy.nan], + [numpy.nan, numpy.nan, numpy.nan, numpy.nan, numpy.nan], ]) - results = ensemble_centrality.robust_mean(arr2) - assert_equal(results, [3, 3]) + results = ensemble_centrality.CENTRALITY_MEASURES["robust_mean"](arr2) + assert_equal(results, [3, 2, numpy.nan]) + + results = ensemble_centrality.CENTRALITY_MEASURES["mean"](arr2) + assert_equal(results, [3, 2, numpy.nan])