diff --git a/mhcflurry/antigen_presentation/presentation_component_models/mhcflurry_trained_on_hits.py b/mhcflurry/antigen_presentation/presentation_component_models/mhcflurry_trained_on_hits.py index cbb9b8d391324738261934a90fc42051c1c1c310..0f0177e03cf4abbd344b4935cefb8aa01d15a4e3 100644 --- a/mhcflurry/antigen_presentation/presentation_component_models/mhcflurry_trained_on_hits.py +++ b/mhcflurry/antigen_presentation/presentation_component_models/mhcflurry_trained_on_hits.py @@ -1,16 +1,13 @@ -import logging from copy import copy import pandas -from numpy import log, exp, nanmean, array +from numpy import log, exp, nanmean from ...dataset import Dataset from ...class1_allele_specific import Class1BindingPredictor -from ...common import normalize_allele_name, dataframe_cryptographic_hash +from ...common import normalize_allele_name -from .presentation_component_model import PresentationComponentModel -from ..decoy_strategies import SameTranscriptsAsHits -from ..percent_rank_transform import PercentRankTransform +from .mhc_binding_component_model_base import MHCBindingComponentModelBase MHCFLURRY_DEFAULT_HYPERPARAMETERS = dict( @@ -18,44 +15,16 @@ MHCFLURRY_DEFAULT_HYPERPARAMETERS = dict( dropout_probability=0.25) -class MHCflurryTrainedOnHits(PresentationComponentModel): +class MHCflurryTrainedOnHits(MHCBindingComponentModelBase): """ Final model input that is a mhcflurry predictor trained on mass-spec hits and, optionally, affinity measurements (for example from IEDB). Parameters ------------ - - predictor_name : string - used on column name. Example: 'vanilla' - - experiment_to_alleles : dict: string -> string list - Normalized allele names for each experiment. - - experiment_to_expression_group : dict of string -> string - Maps experiment names to expression groups. - - transcripts : pandas.DataFrame - Index is peptide, columns are expression groups, values are - which transcript to use for the given peptide. - Not required if decoy_strategy specified. - - peptides_and_transcripts : pandas.DataFrame - Dataframe with columns 'peptide' and 'transcript' - Not required if decoy_strategy specified. - - decoy_strategy : decoy_strategy.DecoyStrategy - how to pick decoys. If not specified peptides_and_transcripts and - transcripts must be specified. - - fallback_predictor : function: (allele, peptides) -> predictions - Used when missing an allele. - iedb_dataset : mhcflurry.Dataset IEDB data for this allele. If not specified no iedb data is used. - decoys_per_hit : int - mhcflurry_hyperparameters : dict hit_affinity : float @@ -64,122 +33,43 @@ class MHCflurryTrainedOnHits(PresentationComponentModel): decoy_affinity : float nM affinity to use for decoys - random_peptides_for_percent_rank : list of string - If specified, then percentile rank will be calibrated and emitted - using the given peptides. + **kwargs : dict + Passed to MHCBindingComponentModel() """ def __init__( self, - predictor_name, - experiment_to_alleles, - experiment_to_expression_group=None, - transcripts=None, - peptides_and_transcripts=None, - decoy_strategy=None, - fallback_predictor=None, iedb_dataset=None, - decoys_per_hit=10, mhcflurry_hyperparameters=MHCFLURRY_DEFAULT_HYPERPARAMETERS, hit_affinity=100, decoy_affinity=20000, - random_peptides_for_percent_rank=None, **kwargs): - PresentationComponentModel.__init__(self, **kwargs) - self.predictor_name = predictor_name - self.experiment_to_alleles = experiment_to_alleles - self.fallback_predictor = fallback_predictor + MHCBindingComponentModelBase.__init__(self, **kwargs) self.iedb_dataset = iedb_dataset self.mhcflurry_hyperparameters = mhcflurry_hyperparameters self.hit_affinity = hit_affinity self.decoy_affinity = decoy_affinity - self.allele_to_model = None - - if decoy_strategy is None: - assert peptides_and_transcripts is not None - assert transcripts is not None - self.decoy_strategy = SameTranscriptsAsHits( - experiment_to_expression_group=experiment_to_expression_group, - peptides_and_transcripts=peptides_and_transcripts, - peptide_to_expression_group_to_transcript=transcripts, - decoys_per_hit=decoys_per_hit) - else: - self.decoy_strategy = decoy_strategy - - if random_peptides_for_percent_rank is None: - self.percent_rank_transforms = None - self.random_peptides_for_percent_rank = None - else: - self.percent_rank_transforms = {} - self.random_peptides_for_percent_rank = array( - random_peptides_for_percent_rank) + self.allele_to_model = {} def combine_ensemble_predictions(self, column_name, values): # Geometric mean return exp(nanmean(log(values), axis=1)) - def stratification_groups(self, hits_df): - return [ - self.experiment_to_alleles[e][0] - for e in hits_df.experiment_name - ] - - def column_name_affinity(self): - return "mhcflurry_%s_affinity" % self.predictor_name - - def column_name_percentile_rank(self): - return "mhcflurry_%s_percentile_rank" % self.predictor_name - - def column_names(self): - columns = [self.column_name_affinity()] - if self.percent_rank_transforms is not None: - columns.append(self.column_name_percentile_rank()) - return columns - - def requires_fitting(self): - return True - - def fit_percentile_rank_if_needed(self, alleles): - for allele in alleles: - if allele not in self.percent_rank_transforms: - logging.info('fitting percent rank for allele: %s' % allele) - self.percent_rank_transforms[allele] = PercentRankTransform() - self.percent_rank_transforms[allele].fit( - self.predict_affinity_for_allele( - allele, - self.random_peptides_for_percent_rank)) - - def fit(self, hits_df): - assert 'experiment_name' in hits_df.columns - assert 'peptide' in hits_df.columns - if 'hit' in hits_df.columns: - assert (hits_df.hit == 1).all() + def supports_predicting_allele(self, allele): + return allele in self.allele_to_model - grouped = hits_df.groupby("experiment_name") - for (experiment_name, sub_df) in grouped: - self.fit_to_experiment(experiment_name, sub_df.peptide.values) - - # No longer required after fitting. - self.decoy_strategy = None - self.iedb_dataset = None - - def fit_to_experiment(self, experiment_name, hit_list): - assert len(hit_list) > 0 + def fit_allele(self, allele, hit_list, decoys_list): if self.allele_to_model is None: self.allele_to_model = {} - alleles = self.experiment_to_alleles[experiment_name] - if len(alleles) != 1: - raise ValueError("Monoallelic data required") - - (allele,) = alleles - mhcflurry_allele = normalize_allele_name(allele) assert allele not in self.allele_to_model, \ "TODO: Support training on >1 experiments with same allele " \ + str(self.allele_to_model) + mhcflurry_allele = normalize_allele_name(allele) + extra_hits = hit_list = set(hit_list) iedb_dataset_df = None @@ -191,10 +81,9 @@ class MHCflurryTrainedOnHits(PresentationComponentModel): len(extra_hits), len(hit_list))) - decoys = self.decoy_strategy.decoys_for_experiment( - experiment_name, hit_list) - - df = pandas.DataFrame({"peptide": sorted(set(hit_list).union(decoys))}) + df = pandas.DataFrame({ + "peptide": sorted(set(hit_list).union(decoys_list)) + }) df["allele"] = mhcflurry_allele df["species"] = "human" df["affinity"] = (( @@ -215,72 +104,9 @@ class MHCflurryTrainedOnHits(PresentationComponentModel): model.fit_dataset(dataset, verbose=True) self.allele_to_model[allele] = model - def predict_affinity_for_allele(self, allele, peptides): - if self.cached_predictions is None: - cache_key = None - cached_result = None - else: - cache_key = ( - allele, - dataframe_cryptographic_hash(pandas.Series(peptides))) - cached_result = self.cached_predictions.get(cache_key) - if cached_result is not None: - print("Cache hit in predict_affinity_for_allele: %s %s %s" % ( - allele, str(self), id(cached_result))) - return cached_result - else: - print("Cache miss in predict_affinity_for_allele: %s %s" % ( - allele, str(self))) - - if allele in self.allele_to_model: - result = self.allele_to_model[allele].predict(peptides) - elif self.fallback_predictor: - print( - "MHCflurry: falling back on %s, " - "available alleles: %s" % ( - allele, ' '.join(self.allele_to_model))) - result = self.fallback_predictor(allele, peptides) - else: - raise ValueError("No model for allele: %s" % allele) - - if self.cached_predictions is not None: - self.cached_predictions[cache_key] = result - return result - - def predict_for_experiment(self, experiment_name, peptides): - assert self.allele_to_model is not None, "Must fit first" - - peptides_deduped = pandas.unique(peptides) - print(len(peptides_deduped)) - - alleles = self.experiment_to_alleles[experiment_name] - predictions = pandas.DataFrame(index=peptides_deduped) - for allele in alleles: - predictions[allele] = self.predict_affinity_for_allele( - allele, peptides_deduped) - - result = { - self.column_name_affinity(): ( - predictions.min(axis=1).ix[peptides].values) - } - if self.percent_rank_transforms is not None: - self.fit_percentile_rank_if_needed(alleles) - percentile_ranks = pandas.DataFrame(index=peptides_deduped) - for allele in alleles: - percentile_ranks[allele] = ( - self.percent_rank_transforms[allele] - .transform(predictions[allele].values)) - result[self.column_name_percentile_rank()] = ( - percentile_ranks.min(axis=1).ix[peptides].values) - assert all(len(x) == len(peptides) for x in result.values()), ( - "Result lengths don't match peptide lengths. peptides=%d, " - "peptides_deduped=%d, %s" % ( - len(peptides), - len(peptides_deduped), - ", ".join( - "%s=%d" % (key, len(value)) - for (key, value) in result.items()))) - return result + def predict_allele(self, allele, peptides_list): + assert self.allele_to_model, "Must fit first" + return self.allele_to_model[allele].predict(peptides_list) def get_fit(self): return { diff --git a/test/test_antigen_presentation.py b/test/test_antigen_presentation.py index c154460ea339869331a2ce6bbabb5b04aed1bcc3..96bc03627c580eea3e008d59c4b7140403684c70 100644 --- a/test/test_antigen_presentation.py +++ b/test/test_antigen_presentation.py @@ -96,21 +96,27 @@ def test_percent_rank_transform(): err_msg=str(model.__dict__)) -def test_mhcflurry_trained_on_hits(): - mhcflurry_model = presentation_component_models.MHCflurryTrainedOnHits( - "basic", +def mhcflurry_basic_model(): + return presentation_component_models.MHCflurryTrainedOnHits( + predictor_name="mhcflurry_affinity", experiment_to_alleles=EXPERIMENT_TO_ALLELES, experiment_to_expression_group=EXPERIMENT_TO_EXPRESSION_GROUP, transcripts=TRANSCIPTS_DF, peptides_and_transcripts=PEPTIDES_AND_TRANSCRIPTS_DF, random_peptides_for_percent_rank=OTHER_PEPTIDES, ) + + +def test_mhcflurry_trained_on_hits(): + mhcflurry_model = mhcflurry_basic_model() mhcflurry_model.fit(HITS_DF) peptides = PEPTIDES_DF.copy() predictions = mhcflurry_model.predict(peptides) - peptides["affinity"] = predictions["mhcflurry_basic_affinity"] - peptides["percent_rank"] = predictions["mhcflurry_basic_percentile_rank"] + peptides["affinity"] = predictions["mhcflurry_affinity_value"] + peptides["percent_rank"] = predictions[ + "mhcflurry_affinity_percentile_rank" + ] assert_less( peptides.affinity[peptides.hit].mean(), peptides.affinity[~peptides.hit].mean()) @@ -119,15 +125,25 @@ def test_mhcflurry_trained_on_hits(): peptides.percent_rank[~peptides.hit].mean()) +def compare_predictions(peptides_df, model1, model2): + predictions1 = model1.predict(peptides_df) + predictions2 = model2.predict(peptides_df) + failed = False + for i in range(len(peptides_df)): + if predictions1[i] != predictions2[i]: + failed = True + print( + "Compare predictions: mismatch at index %d: " + "%f != %f, row: %s" % ( + i, + predictions1[i], + predictions2[i], + str(peptides_df.iloc[i]))) + assert not failed + + def test_presentation_model(): - mhcflurry_model = presentation_component_models.MHCflurryTrainedOnHits( - "basic", - experiment_to_alleles=EXPERIMENT_TO_ALLELES, - experiment_to_expression_group=EXPERIMENT_TO_EXPRESSION_GROUP, - transcripts=TRANSCIPTS_DF, - peptides_and_transcripts=PEPTIDES_AND_TRANSCRIPTS_DF, - random_peptides_for_percent_rank=OTHER_PEPTIDES, - ) + mhcflurry_model = mhcflurry_basic_model() aa_content_model = ( presentation_component_models.FixedPerPeptideQuantity( @@ -141,7 +157,7 @@ def test_presentation_model(): terms = { 'A_ms': ( [mhcflurry_model], - ["log1p(mhcflurry_basic_affinity)"]), + ["log1p(mhcflurry_affinity_value)"]), 'P': ( [aa_content_model], list(AA_COMPOSITION_DF.columns)), @@ -170,14 +186,12 @@ def test_presentation_model(): peptides.prediction[peptides.hit].mean()) model2 = pickle.loads(pickle.dumps(model)) - assert_allclose( - model.predict(peptides), model2.predict(peptides)) + compare_predictions(peptides, model, model2) model3 = unfit_model.clone() assert not model3.has_been_fit model3.restore_fit(model2.get_fit()) - assert_allclose( - model.predict(peptides), model3.predict(peptides)) + compare_predictions(peptides, model, model3) better_unfit_model = models["A_ms + P"] model = better_unfit_model.clone()