Skip to content
Snippets Groups Projects
Commit c6076446 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

fixes

parent 553502bb
No related merge requests found
......@@ -27,6 +27,11 @@ from .ensemble_centrality import CENTRALITY_MEASURES
from .allele_encoding import AlleleEncoding
# Default function for combining predictions across models in an ensemble.
# See ensemble_centrality.py for other options.
DEFAULT_CENTRALITY_MEASURE = "mean"
class Class1AffinityPredictor(object):
"""
High-level interface for peptide/MHC I binding affinity prediction.
......@@ -722,7 +727,7 @@ class Class1AffinityPredictor(object):
alleles=None,
allele=None,
throw=True,
centrality_measure="robust_mean"):
centrality_measure=DEFAULT_CENTRALITY_MEASURE):
"""
Predict nM binding affinities.
......@@ -757,6 +762,7 @@ class Class1AffinityPredictor(object):
allele=allele,
throw=throw,
include_percentile_ranks=False,
include_confidence_intervals=False,
centrality_measure=centrality_measure,
)
return df.prediction.values
......@@ -769,7 +775,8 @@ class Class1AffinityPredictor(object):
throw=True,
include_individual_model_predictions=False,
include_percentile_ranks=True,
centrality_measure="mean"):
include_confidence_intervals=True,
centrality_measure=DEFAULT_CENTRALITY_MEASURE):
"""
Predict nM binding affinities. Gives more detailed output than `predict`
method, including 5-95% prediction intervals.
......@@ -812,18 +819,22 @@ class Class1AffinityPredictor(object):
raise TypeError("alleles must be a list or array, not a string")
if allele is None and alleles is None:
raise ValueError("Must specify 'allele' or 'alleles'.")
if allele is not None:
if alleles is not None:
raise ValueError("Specify exactly one of allele or alleles")
alleles = [allele] * len(peptides)
alleles = numpy.array(alleles)
peptides = EncodableSequences.create(peptides)
df = pandas.DataFrame({
'peptide': peptides.sequences,
'allele': alleles,
})
if allele is not None:
if alleles is not None:
raise ValueError("Specify exactly one of allele or alleles")
df["allele"] = allele
df["normalized_allele"] = mhcnames.normalize_allele_name(allele)
else:
df["allele"] = numpy.array(alleles)
df["normalized_allele"] = df.allele.map(
mhcnames.normalize_allele_name)
if len(df) == 0:
# No predictions.
logging.warning("Predicting for 0 peptides.")
......@@ -837,9 +848,6 @@ class Class1AffinityPredictor(object):
])
return empty_result
df["normalized_allele"] = df.allele.map(
mhcnames.normalize_allele_name)
(min_peptide_length, max_peptide_length) = (
self.supported_peptide_lengths)
df["supported_peptide_length"] = (
......@@ -928,8 +936,9 @@ class Class1AffinityPredictor(object):
logs = numpy.log(df_predictions)
log_centers = centrality_function(logs.values)
df["prediction"] = numpy.exp(log_centers)
df["prediction_low"] = numpy.exp(logs.quantile(0.05, axis=1))
df["prediction_high"] = numpy.exp(logs.quantile(0.95, axis=1))
if include_confidence_intervals:
df["prediction_low"] = numpy.exp(logs.quantile(0.05, axis=1))
df["prediction_high"] = numpy.exp(logs.quantile(0.95, axis=1))
if include_individual_model_predictions:
columns = sorted(df.columns, key=lambda c: c.startswith('model_'))
......
......@@ -27,10 +27,12 @@ def robust_mean(log_values):
return numpy.nanmean(log_values, axis=1)
without_nans = numpy.nan_to_num(log_values) # replace nan with 0
mask = (
(~numpy.isnan(log_values)) &
(without_nans <= numpy.nanpercentile(log_values, 75, axis=1).reshape((-1, 1))) &
(without_nans >= numpy.nanpercentile(log_values, 25, axis=1).reshape((-1, 1))))
return (without_nans * mask.astype(float)).sum(1) / mask.sum(1)
CENTRALITY_MEASURES = {
"mean": partial(numpy.nanmean, axis=1),
"median": partial(numpy.nanmedian, axis=1),
......
......@@ -217,3 +217,30 @@ def test_class1_affinity_predictor_a0205_memorize_training_data():
assert not numpy.isnan(ic50_pred[1])
assert numpy.isnan(ic50_pred[2])
def test_predict_implementations_equivalent():
for allele in ["HLA-A02:01", "A02:02"]:
for centrality_measure in ["mean", "robust_mean"]:
peptides = ["SIINFEKL", "SYYNFIIIKL", "SIINKFELQY"]
pred1 = DOWNLOADED_PREDICTOR.predict(
allele=allele,
peptides=peptides + ["SSSN"],
throw=False,
centrality_measure=centrality_measure)
pred2 = DOWNLOADED_PREDICTOR.predict_to_dataframe(
allele=allele,
peptides=peptides + ["SSSN"],
throw=False,
centrality_measure=centrality_measure).prediction.values
testing.assert_equal(pred1, pred2)
pred1 = DOWNLOADED_PREDICTOR.predict(
allele=allele,
peptides=peptides,
centrality_measure=centrality_measure)
pred2 = DOWNLOADED_PREDICTOR.predict_to_dataframe(
allele=allele,
peptides=peptides,
centrality_measure=centrality_measure).prediction.values
testing.assert_equal(pred1, pred2)
......@@ -8,7 +8,7 @@ from mhcflurry import ensemble_centrality
def test_robust_mean():
arr1 = numpy.array([
[1, 2, 3, 4, 5],
[-10000, 2, 3, 4, 100000],
[-10000, 2, 3, 4, 100],
])
results = ensemble_centrality.robust_mean(arr1)
......@@ -17,8 +17,12 @@ def test_robust_mean():
# Should ignore nans.
arr2 = numpy.array([
[1, 2, 3, 4, 5],
[numpy.nan, 2, 3, 4, numpy.nan],
[numpy.nan, 1, 2, 3, numpy.nan],
[numpy.nan, numpy.nan, numpy.nan, numpy.nan, numpy.nan],
])
results = ensemble_centrality.robust_mean(arr2)
assert_equal(results, [3, 3])
results = ensemble_centrality.CENTRALITY_MEASURES["robust_mean"](arr2)
assert_equal(results, [3, 2, numpy.nan])
results = ensemble_centrality.CENTRALITY_MEASURES["mean"](arr2)
assert_equal(results, [3, 2, numpy.nan])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment