Skip to content
Snippets Groups Projects
Commit 3c0e8626 authored by Tim O'Donnell
Browse files

more output in tests

parent 253304d0
No related merge requests found
import logging
from copy import copy from copy import copy
import pandas import pandas
from numpy import log, exp, nanmean, array from numpy import log, exp, nanmean
from ...dataset import Dataset from ...dataset import Dataset
from ...class1_allele_specific import Class1BindingPredictor from ...class1_allele_specific import Class1BindingPredictor
from ...common import normalize_allele_name, dataframe_cryptographic_hash from ...common import normalize_allele_name
from .presentation_component_model import PresentationComponentModel from .mhc_binding_component_model_base import MHCBindingComponentModelBase
from ..decoy_strategies import SameTranscriptsAsHits
from ..percent_rank_transform import PercentRankTransform
MHCFLURRY_DEFAULT_HYPERPARAMETERS = dict( MHCFLURRY_DEFAULT_HYPERPARAMETERS = dict(
...@@ -18,44 +15,16 @@ MHCFLURRY_DEFAULT_HYPERPARAMETERS = dict( ...@@ -18,44 +15,16 @@ MHCFLURRY_DEFAULT_HYPERPARAMETERS = dict(
dropout_probability=0.25) dropout_probability=0.25)
class MHCflurryTrainedOnHits(PresentationComponentModel): class MHCflurryTrainedOnHits(MHCBindingComponentModelBase):
""" """
Final model input that is a mhcflurry predictor trained on mass-spec Final model input that is a mhcflurry predictor trained on mass-spec
hits and, optionally, affinity measurements (for example from IEDB). hits and, optionally, affinity measurements (for example from IEDB).
Parameters Parameters
------------ ------------
predictor_name : string
used on column name. Example: 'vanilla'
experiment_to_alleles : dict: string -> string list
Normalized allele names for each experiment.
experiment_to_expression_group : dict of string -> string
Maps experiment names to expression groups.
transcripts : pandas.DataFrame
Index is peptide, columns are expression groups, values are
which transcript to use for the given peptide.
Not required if decoy_strategy specified.
peptides_and_transcripts : pandas.DataFrame
Dataframe with columns 'peptide' and 'transcript'
Not required if decoy_strategy specified.
decoy_strategy : decoy_strategy.DecoyStrategy
how to pick decoys. If not specified peptides_and_transcripts and
transcripts must be specified.
fallback_predictor : function: (allele, peptides) -> predictions
Used when missing an allele.
iedb_dataset : mhcflurry.Dataset iedb_dataset : mhcflurry.Dataset
IEDB data for this allele. If not specified no iedb data is used. IEDB data for this allele. If not specified no iedb data is used.
decoys_per_hit : int
mhcflurry_hyperparameters : dict mhcflurry_hyperparameters : dict
hit_affinity : float hit_affinity : float
...@@ -64,122 +33,43 @@ class MHCflurryTrainedOnHits(PresentationComponentModel): ...@@ -64,122 +33,43 @@ class MHCflurryTrainedOnHits(PresentationComponentModel):
decoy_affinity : float decoy_affinity : float
nM affinity to use for decoys nM affinity to use for decoys
random_peptides_for_percent_rank : list of string **kwargs : dict
If specified, then percentile rank will be calibrated and emitted Passed to MHCBindingComponentModel()
using the given peptides.
""" """
def __init__( def __init__(
self, self,
predictor_name,
experiment_to_alleles,
experiment_to_expression_group=None,
transcripts=None,
peptides_and_transcripts=None,
decoy_strategy=None,
fallback_predictor=None,
iedb_dataset=None, iedb_dataset=None,
decoys_per_hit=10,
mhcflurry_hyperparameters=MHCFLURRY_DEFAULT_HYPERPARAMETERS, mhcflurry_hyperparameters=MHCFLURRY_DEFAULT_HYPERPARAMETERS,
hit_affinity=100, hit_affinity=100,
decoy_affinity=20000, decoy_affinity=20000,
random_peptides_for_percent_rank=None,
**kwargs): **kwargs):
PresentationComponentModel.__init__(self, **kwargs) MHCBindingComponentModelBase.__init__(self, **kwargs)
self.predictor_name = predictor_name
self.experiment_to_alleles = experiment_to_alleles
self.fallback_predictor = fallback_predictor
self.iedb_dataset = iedb_dataset self.iedb_dataset = iedb_dataset
self.mhcflurry_hyperparameters = mhcflurry_hyperparameters self.mhcflurry_hyperparameters = mhcflurry_hyperparameters
self.hit_affinity = hit_affinity self.hit_affinity = hit_affinity
self.decoy_affinity = decoy_affinity self.decoy_affinity = decoy_affinity
self.allele_to_model = None self.allele_to_model = {}
if decoy_strategy is None:
assert peptides_and_transcripts is not None
assert transcripts is not None
self.decoy_strategy = SameTranscriptsAsHits(
experiment_to_expression_group=experiment_to_expression_group,
peptides_and_transcripts=peptides_and_transcripts,
peptide_to_expression_group_to_transcript=transcripts,
decoys_per_hit=decoys_per_hit)
else:
self.decoy_strategy = decoy_strategy
if random_peptides_for_percent_rank is None:
self.percent_rank_transforms = None
self.random_peptides_for_percent_rank = None
else:
self.percent_rank_transforms = {}
self.random_peptides_for_percent_rank = array(
random_peptides_for_percent_rank)
def combine_ensemble_predictions(self, column_name, values): def combine_ensemble_predictions(self, column_name, values):
# Geometric mean # Geometric mean
return exp(nanmean(log(values), axis=1)) return exp(nanmean(log(values), axis=1))
def stratification_groups(self, hits_df): def supports_predicting_allele(self, allele):
return [ return allele in self.allele_to_model
self.experiment_to_alleles[e][0]
for e in hits_df.experiment_name
]
def column_name_affinity(self):
    """
    Name of the output column that holds predicted affinities.

    The name embeds self.predictor_name, for example
    'mhcflurry_vanilla_affinity' when the predictor name is 'vanilla'.
    """
    template = "mhcflurry_%s_affinity"
    return template % self.predictor_name
def column_name_percentile_rank(self):
    """
    Name of the output column that holds percentile ranks.

    The name embeds self.predictor_name, for example
    'mhcflurry_vanilla_percentile_rank' when the predictor name is
    'vanilla'.
    """
    template = "mhcflurry_%s_percentile_rank"
    return template % self.predictor_name
def column_names(self):
    """
    List the names of the columns this component emits.

    The affinity column is always present. The percentile rank column is
    added only when self.percent_rank_transforms is not None.
    """
    if self.percent_rank_transforms is None:
        return [self.column_name_affinity()]
    return [
        self.column_name_affinity(),
        self.column_name_percentile_rank(),
    ]
def requires_fitting(self):
    """
    Report whether this component must be fit before use.

    Always returns True.
    """
    return True
def fit_percentile_rank_if_needed(self, alleles):
    """
    Fit a percent rank transform for each allele that lacks one.

    Alleles already present in self.percent_rank_transforms are left
    untouched. For every other allele a new PercentRankTransform is stored
    and fit on the affinities predicted for
    self.random_peptides_for_percent_rank.

    Parameters
    ------------
    alleles : iterable of string
        Alleles that are about to be used for percentile rank output.
    """
    for allele in alleles:
        if allele in self.percent_rank_transforms:
            # Already calibrated for this allele.
            continue
        logging.info('fitting percent rank for allele: %s' % allele)
        transform = PercentRankTransform()
        # Register before fitting, matching the original statement order.
        self.percent_rank_transforms[allele] = transform
        transform.fit(
            self.predict_affinity_for_allele(
                allele,
                self.random_peptides_for_percent_rank))
def fit(self, hits_df):
assert 'experiment_name' in hits_df.columns
assert 'peptide' in hits_df.columns
if 'hit' in hits_df.columns:
assert (hits_df.hit == 1).all()
grouped = hits_df.groupby("experiment_name") def fit_allele(self, allele, hit_list, decoys_list):
for (experiment_name, sub_df) in grouped:
self.fit_to_experiment(experiment_name, sub_df.peptide.values)
# No longer required after fitting.
self.decoy_strategy = None
self.iedb_dataset = None
def fit_to_experiment(self, experiment_name, hit_list):
assert len(hit_list) > 0
if self.allele_to_model is None: if self.allele_to_model is None:
self.allele_to_model = {} self.allele_to_model = {}
alleles = self.experiment_to_alleles[experiment_name]
if len(alleles) != 1:
raise ValueError("Monoallelic data required")
(allele,) = alleles
mhcflurry_allele = normalize_allele_name(allele)
assert allele not in self.allele_to_model, \ assert allele not in self.allele_to_model, \
"TODO: Support training on >1 experiments with same allele " \ "TODO: Support training on >1 experiments with same allele " \
+ str(self.allele_to_model) + str(self.allele_to_model)
mhcflurry_allele = normalize_allele_name(allele)
extra_hits = hit_list = set(hit_list) extra_hits = hit_list = set(hit_list)
iedb_dataset_df = None iedb_dataset_df = None
...@@ -191,10 +81,9 @@ class MHCflurryTrainedOnHits(PresentationComponentModel): ...@@ -191,10 +81,9 @@ class MHCflurryTrainedOnHits(PresentationComponentModel):
len(extra_hits), len(extra_hits),
len(hit_list))) len(hit_list)))
decoys = self.decoy_strategy.decoys_for_experiment( df = pandas.DataFrame({
experiment_name, hit_list) "peptide": sorted(set(hit_list).union(decoys_list))
})
df = pandas.DataFrame({"peptide": sorted(set(hit_list).union(decoys))})
df["allele"] = mhcflurry_allele df["allele"] = mhcflurry_allele
df["species"] = "human" df["species"] = "human"
df["affinity"] = (( df["affinity"] = ((
...@@ -215,72 +104,9 @@ class MHCflurryTrainedOnHits(PresentationComponentModel): ...@@ -215,72 +104,9 @@ class MHCflurryTrainedOnHits(PresentationComponentModel):
model.fit_dataset(dataset, verbose=True) model.fit_dataset(dataset, verbose=True)
self.allele_to_model[allele] = model self.allele_to_model[allele] = model
def predict_affinity_for_allele(self, allele, peptides): def predict_allele(self, allele, peptides_list):
if self.cached_predictions is None: assert self.allele_to_model, "Must fit first"
cache_key = None return self.allele_to_model[allele].predict(peptides_list)
cached_result = None
else:
cache_key = (
allele,
dataframe_cryptographic_hash(pandas.Series(peptides)))
cached_result = self.cached_predictions.get(cache_key)
if cached_result is not None:
print("Cache hit in predict_affinity_for_allele: %s %s %s" % (
allele, str(self), id(cached_result)))
return cached_result
else:
print("Cache miss in predict_affinity_for_allele: %s %s" % (
allele, str(self)))
if allele in self.allele_to_model:
result = self.allele_to_model[allele].predict(peptides)
elif self.fallback_predictor:
print(
"MHCflurry: falling back on %s, "
"available alleles: %s" % (
allele, ' '.join(self.allele_to_model)))
result = self.fallback_predictor(allele, peptides)
else:
raise ValueError("No model for allele: %s" % allele)
if self.cached_predictions is not None:
self.cached_predictions[cache_key] = result
return result
def predict_for_experiment(self, experiment_name, peptides):
    """
    Predict affinity, and optionally percentile rank, for the peptides of
    one experiment.

    Every allele listed for the experiment is predicted on the
    de-duplicated peptides. The reported value for a peptide is the
    minimum over those alleles, and the output is re-expanded so it lines
    up with `peptides` (duplicates included).

    Parameters
    ------------
    experiment_name : string
        Key into self.experiment_to_alleles.

    peptides : sequence of string
        Peptides to predict. May contain duplicates.

    Returns
    ------------
    dict of column name -> numpy array, one value per entry of `peptides`.
    The percentile rank column is present only when
    self.percent_rank_transforms is not None.
    """
    assert self.allele_to_model is not None, "Must fit first"

    peptides_deduped = pandas.unique(peptides)
    # Formerly a bare print() of the count; kept as debug-level logging so
    # library callers do not get unsolicited stdout output.
    logging.debug(
        "predict_for_experiment: %d unique peptides", len(peptides_deduped))

    alleles = self.experiment_to_alleles[experiment_name]
    predictions = pandas.DataFrame(index=peptides_deduped)
    for allele in alleles:
        predictions[allele] = self.predict_affinity_for_allele(
            allele, peptides_deduped)

    # Label-based .loc replaces .ix, which was removed in pandas 1.0.
    # The index holds the unique peptides, so looking up `peptides`
    # re-expands the result to the caller's (possibly duplicated) order.
    result = {
        self.column_name_affinity(): (
            predictions.min(axis=1).loc[peptides].values)
    }
    if self.percent_rank_transforms is not None:
        self.fit_percentile_rank_if_needed(alleles)
        percentile_ranks = pandas.DataFrame(index=peptides_deduped)
        for allele in alleles:
            percentile_ranks[allele] = (
                self.percent_rank_transforms[allele]
                .transform(predictions[allele].values))
        result[self.column_name_percentile_rank()] = (
            percentile_ranks.min(axis=1).loc[peptides].values)

    assert all(len(x) == len(peptides) for x in result.values()), (
        "Result lengths don't match peptide lengths. peptides=%d, "
        "peptides_deduped=%d, %s" % (
            len(peptides),
            len(peptides_deduped),
            ", ".join(
                "%s=%d" % (key, len(value))
                for (key, value) in result.items())))
    return result
def get_fit(self): def get_fit(self):
return { return {
......
...@@ -96,21 +96,27 @@ def test_percent_rank_transform(): ...@@ -96,21 +96,27 @@ def test_percent_rank_transform():
err_msg=str(model.__dict__)) err_msg=str(model.__dict__))
def test_mhcflurry_trained_on_hits(): def mhcflurry_basic_model():
mhcflurry_model = presentation_component_models.MHCflurryTrainedOnHits( return presentation_component_models.MHCflurryTrainedOnHits(
"basic", predictor_name="mhcflurry_affinity",
experiment_to_alleles=EXPERIMENT_TO_ALLELES, experiment_to_alleles=EXPERIMENT_TO_ALLELES,
experiment_to_expression_group=EXPERIMENT_TO_EXPRESSION_GROUP, experiment_to_expression_group=EXPERIMENT_TO_EXPRESSION_GROUP,
transcripts=TRANSCIPTS_DF, transcripts=TRANSCIPTS_DF,
peptides_and_transcripts=PEPTIDES_AND_TRANSCRIPTS_DF, peptides_and_transcripts=PEPTIDES_AND_TRANSCRIPTS_DF,
random_peptides_for_percent_rank=OTHER_PEPTIDES, random_peptides_for_percent_rank=OTHER_PEPTIDES,
) )
def test_mhcflurry_trained_on_hits():
mhcflurry_model = mhcflurry_basic_model()
mhcflurry_model.fit(HITS_DF) mhcflurry_model.fit(HITS_DF)
peptides = PEPTIDES_DF.copy() peptides = PEPTIDES_DF.copy()
predictions = mhcflurry_model.predict(peptides) predictions = mhcflurry_model.predict(peptides)
peptides["affinity"] = predictions["mhcflurry_basic_affinity"] peptides["affinity"] = predictions["mhcflurry_affinity_value"]
peptides["percent_rank"] = predictions["mhcflurry_basic_percentile_rank"] peptides["percent_rank"] = predictions[
"mhcflurry_affinity_percentile_rank"
]
assert_less( assert_less(
peptides.affinity[peptides.hit].mean(), peptides.affinity[peptides.hit].mean(),
peptides.affinity[~peptides.hit].mean()) peptides.affinity[~peptides.hit].mean())
...@@ -119,15 +125,25 @@ def test_mhcflurry_trained_on_hits(): ...@@ -119,15 +125,25 @@ def test_mhcflurry_trained_on_hits():
peptides.percent_rank[~peptides.hit].mean()) peptides.percent_rank[~peptides.hit].mean())
def compare_predictions(peptides_df, model1, model2):
    """
    Assert that two models give identical predictions on peptides_df.

    Every mismatching row is printed (so it shows up in test output) and
    the same descriptions are attached to the AssertionError, so a failure
    is self-explanatory even when stdout is not captured.

    Parameters
    ------------
    peptides_df : pandas.DataFrame
        Rows to predict on; passed unchanged to both models.

    model1, model2 : objects with a predict(peptides_df) method
        The result of predict() is indexed by integer position, one value
        per row of peptides_df.

    Raises
    ------------
    AssertionError if any pair of predictions differs.
    """
    predictions1 = model1.predict(peptides_df)
    predictions2 = model2.predict(peptides_df)
    mismatches = []
    for (i, (_, row)) in enumerate(peptides_df.iterrows()):
        (value1, value2) = (predictions1[i], predictions2[i])
        if value1 == value2:
            continue
        if value1 != value1 and value2 != value2:
            # Both values are NaN (NaN is the only value unequal to
            # itself). Treat the pair as matching, as the assert_allclose
            # comparison this helper replaced did.
            continue
        message = (
            "Compare predictions: mismatch at index %d: "
            "%f != %f, row: %s" % (i, value1, value2, str(row)))
        print(message)
        mismatches.append(message)
    assert not mismatches, "%d prediction mismatch(es):\n%s" % (
        len(mismatches), "\n".join(mismatches))
def test_presentation_model(): def test_presentation_model():
mhcflurry_model = presentation_component_models.MHCflurryTrainedOnHits( mhcflurry_model = mhcflurry_basic_model()
"basic",
experiment_to_alleles=EXPERIMENT_TO_ALLELES,
experiment_to_expression_group=EXPERIMENT_TO_EXPRESSION_GROUP,
transcripts=TRANSCIPTS_DF,
peptides_and_transcripts=PEPTIDES_AND_TRANSCRIPTS_DF,
random_peptides_for_percent_rank=OTHER_PEPTIDES,
)
aa_content_model = ( aa_content_model = (
presentation_component_models.FixedPerPeptideQuantity( presentation_component_models.FixedPerPeptideQuantity(
...@@ -141,7 +157,7 @@ def test_presentation_model(): ...@@ -141,7 +157,7 @@ def test_presentation_model():
terms = { terms = {
'A_ms': ( 'A_ms': (
[mhcflurry_model], [mhcflurry_model],
["log1p(mhcflurry_basic_affinity)"]), ["log1p(mhcflurry_affinity_value)"]),
'P': ( 'P': (
[aa_content_model], [aa_content_model],
list(AA_COMPOSITION_DF.columns)), list(AA_COMPOSITION_DF.columns)),
...@@ -170,14 +186,12 @@ def test_presentation_model(): ...@@ -170,14 +186,12 @@ def test_presentation_model():
peptides.prediction[peptides.hit].mean()) peptides.prediction[peptides.hit].mean())
model2 = pickle.loads(pickle.dumps(model)) model2 = pickle.loads(pickle.dumps(model))
assert_allclose( compare_predictions(peptides, model, model2)
model.predict(peptides), model2.predict(peptides))
model3 = unfit_model.clone() model3 = unfit_model.clone()
assert not model3.has_been_fit assert not model3.has_been_fit
model3.restore_fit(model2.get_fit()) model3.restore_fit(model2.get_fit())
assert_allclose( compare_predictions(peptides, model, model3)
model.predict(peptides), model3.predict(peptides))
better_unfit_model = models["A_ms + P"] better_unfit_model = models["A_ms + P"]
model = better_unfit_model.clone() model = better_unfit_model.clone()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment