diff --git a/mhcflurry/antigen_presentation/presentation_component_models/mhcflurry_trained_on_hits.py b/mhcflurry/antigen_presentation/presentation_component_models/mhcflurry_trained_on_hits.py
index cbb9b8d391324738261934a90fc42051c1c1c310..0f0177e03cf4abbd344b4935cefb8aa01d15a4e3 100644
--- a/mhcflurry/antigen_presentation/presentation_component_models/mhcflurry_trained_on_hits.py
+++ b/mhcflurry/antigen_presentation/presentation_component_models/mhcflurry_trained_on_hits.py
@@ -1,16 +1,13 @@
-import logging
 from copy import copy
 
 import pandas
-from numpy import log, exp, nanmean, array
+from numpy import log, exp, nanmean
 
 from ...dataset import Dataset
 from ...class1_allele_specific import Class1BindingPredictor
-from ...common import normalize_allele_name, dataframe_cryptographic_hash
+from ...common import normalize_allele_name
 
-from .presentation_component_model import PresentationComponentModel
-from ..decoy_strategies import SameTranscriptsAsHits
-from ..percent_rank_transform import PercentRankTransform
+from .mhc_binding_component_model_base import MHCBindingComponentModelBase
 
 
 MHCFLURRY_DEFAULT_HYPERPARAMETERS = dict(
@@ -18,44 +15,16 @@ MHCFLURRY_DEFAULT_HYPERPARAMETERS = dict(
     dropout_probability=0.25)
 
 
-class MHCflurryTrainedOnHits(PresentationComponentModel):
+class MHCflurryTrainedOnHits(MHCBindingComponentModelBase):
     """
     Final model input that is a mhcflurry predictor trained on mass-spec
     hits and, optionally, affinity measurements (for example from IEDB).
 
     Parameters
     ------------
-
-    predictor_name : string
-        used on column name. Example: 'vanilla'
-
-    experiment_to_alleles : dict: string -> string list
-        Normalized allele names for each experiment.
-
-    experiment_to_expression_group : dict of string -> string
-        Maps experiment names to expression groups.
-
-    transcripts : pandas.DataFrame
-        Index is peptide, columns are expression groups, values are
-        which transcript to use for the given peptide.
-        Not required if decoy_strategy specified.
-
-    peptides_and_transcripts : pandas.DataFrame
-        Dataframe with columns 'peptide' and 'transcript'
-        Not required if decoy_strategy specified.
-
-    decoy_strategy : decoy_strategy.DecoyStrategy
-        how to pick decoys. If not specified peptides_and_transcripts and
-        transcripts must be specified.
-
-    fallback_predictor : function: (allele, peptides) -> predictions
-        Used when missing an allele.
-
     iedb_dataset : mhcflurry.Dataset
         IEDB data for this allele. If not specified no iedb data is used.
 
-    decoys_per_hit : int
-
     mhcflurry_hyperparameters : dict
 
     hit_affinity : float
@@ -64,122 +33,43 @@ class MHCflurryTrainedOnHits(PresentationComponentModel):
     decoy_affinity : float
         nM affinity to use for decoys
 
-    random_peptides_for_percent_rank : list of string
-        If specified, then percentile rank will be calibrated and emitted
-        using the given peptides.
+    **kwargs : dict
+        Passed to MHCBindingComponentModel()
     """
 
     def __init__(
             self,
-            predictor_name,
-            experiment_to_alleles,
-            experiment_to_expression_group=None,
-            transcripts=None,
-            peptides_and_transcripts=None,
-            decoy_strategy=None,
-            fallback_predictor=None,
             iedb_dataset=None,
-            decoys_per_hit=10,
             mhcflurry_hyperparameters=MHCFLURRY_DEFAULT_HYPERPARAMETERS,
             hit_affinity=100,
             decoy_affinity=20000,
-            random_peptides_for_percent_rank=None,
             **kwargs):
 
-        PresentationComponentModel.__init__(self, **kwargs)
-        self.predictor_name = predictor_name
-        self.experiment_to_alleles = experiment_to_alleles
-        self.fallback_predictor = fallback_predictor
+        MHCBindingComponentModelBase.__init__(self, **kwargs)
         self.iedb_dataset = iedb_dataset
         self.mhcflurry_hyperparameters = mhcflurry_hyperparameters
         self.hit_affinity = hit_affinity
         self.decoy_affinity = decoy_affinity
 
-        self.allele_to_model = None
-
-        if decoy_strategy is None:
-            assert peptides_and_transcripts is not None
-            assert transcripts is not None
-            self.decoy_strategy = SameTranscriptsAsHits(
-                experiment_to_expression_group=experiment_to_expression_group,
-                peptides_and_transcripts=peptides_and_transcripts,
-                peptide_to_expression_group_to_transcript=transcripts,
-                decoys_per_hit=decoys_per_hit)
-        else:
-            self.decoy_strategy = decoy_strategy
-
-        if random_peptides_for_percent_rank is None:
-            self.percent_rank_transforms = None
-            self.random_peptides_for_percent_rank = None
-        else:
-            self.percent_rank_transforms = {}
-            self.random_peptides_for_percent_rank = array(
-                random_peptides_for_percent_rank)
+        self.allele_to_model = {}
 
     def combine_ensemble_predictions(self, column_name, values):
         # Geometric mean
         return exp(nanmean(log(values), axis=1))
 
-    def stratification_groups(self, hits_df):
-        return [
-            self.experiment_to_alleles[e][0]
-            for e in hits_df.experiment_name
-        ]
-
-    def column_name_affinity(self):
-        return "mhcflurry_%s_affinity" % self.predictor_name
-
-    def column_name_percentile_rank(self):
-        return "mhcflurry_%s_percentile_rank" % self.predictor_name
-
-    def column_names(self):
-        columns = [self.column_name_affinity()]
-        if self.percent_rank_transforms is not None:
-            columns.append(self.column_name_percentile_rank())
-        return columns
-
-    def requires_fitting(self):
-        return True
-
-    def fit_percentile_rank_if_needed(self, alleles):
-        for allele in alleles:
-            if allele not in self.percent_rank_transforms:
-                logging.info('fitting percent rank for allele: %s' % allele)
-                self.percent_rank_transforms[allele] = PercentRankTransform()
-                self.percent_rank_transforms[allele].fit(
-                    self.predict_affinity_for_allele(
-                        allele,
-                        self.random_peptides_for_percent_rank))
-
-    def fit(self, hits_df):
-        assert 'experiment_name' in hits_df.columns
-        assert 'peptide' in hits_df.columns
-        if 'hit' in hits_df.columns:
-            assert (hits_df.hit == 1).all()
+    def supports_predicting_allele(self, allele):
+        return allele in self.allele_to_model
 
-        grouped = hits_df.groupby("experiment_name")
-        for (experiment_name, sub_df) in grouped:
-            self.fit_to_experiment(experiment_name, sub_df.peptide.values)
-
-        # No longer required after fitting.
-        self.decoy_strategy = None
-        self.iedb_dataset = None
-
-    def fit_to_experiment(self, experiment_name, hit_list):
-        assert len(hit_list) > 0
+    def fit_allele(self, allele, hit_list, decoys_list):
         if self.allele_to_model is None:
             self.allele_to_model = {}
 
-        alleles = self.experiment_to_alleles[experiment_name]
-        if len(alleles) != 1:
-            raise ValueError("Monoallelic data required")
-
-        (allele,) = alleles
-        mhcflurry_allele = normalize_allele_name(allele)
         assert allele not in self.allele_to_model, \
             "TODO: Support training on >1 experiments with same allele " \
             + str(self.allele_to_model)
 
+        mhcflurry_allele = normalize_allele_name(allele)
+
         extra_hits = hit_list = set(hit_list)
 
         iedb_dataset_df = None
@@ -191,10 +81,9 @@ class MHCflurryTrainedOnHits(PresentationComponentModel):
                 len(extra_hits),
                 len(hit_list)))
 
-        decoys = self.decoy_strategy.decoys_for_experiment(
-            experiment_name, hit_list)
-
-        df = pandas.DataFrame({"peptide": sorted(set(hit_list).union(decoys))})
+        df = pandas.DataFrame({
+            "peptide": sorted(set(hit_list).union(decoys_list))
+        })
         df["allele"] = mhcflurry_allele
         df["species"] = "human"
         df["affinity"] = ((
@@ -215,72 +104,9 @@ class MHCflurryTrainedOnHits(PresentationComponentModel):
         model.fit_dataset(dataset, verbose=True)
         self.allele_to_model[allele] = model
 
-    def predict_affinity_for_allele(self, allele, peptides):
-        if self.cached_predictions is None:
-            cache_key = None
-            cached_result = None
-        else:
-            cache_key = (
-                allele,
-                dataframe_cryptographic_hash(pandas.Series(peptides)))
-            cached_result = self.cached_predictions.get(cache_key)
-        if cached_result is not None:
-            print("Cache hit in predict_affinity_for_allele: %s %s %s" % (
-                allele, str(self), id(cached_result)))
-            return cached_result
-        else:
-            print("Cache miss in predict_affinity_for_allele: %s %s" % (
-                allele, str(self)))
-
-        if allele in self.allele_to_model:
-            result = self.allele_to_model[allele].predict(peptides)
-        elif self.fallback_predictor:
-            print(
-                "MHCflurry: falling back on %s, "
-                "available alleles: %s" % (
-                    allele, ' '.join(self.allele_to_model)))
-            result = self.fallback_predictor(allele, peptides)
-        else:
-            raise ValueError("No model for allele: %s" % allele)
-
-        if self.cached_predictions is not None:
-            self.cached_predictions[cache_key] = result
-        return result
-
-    def predict_for_experiment(self, experiment_name, peptides):
-        assert self.allele_to_model is not None, "Must fit first"
-
-        peptides_deduped = pandas.unique(peptides)
-        print(len(peptides_deduped))
-
-        alleles = self.experiment_to_alleles[experiment_name]
-        predictions = pandas.DataFrame(index=peptides_deduped)
-        for allele in alleles:
-            predictions[allele] = self.predict_affinity_for_allele(
-                allele, peptides_deduped)
-
-        result = {
-            self.column_name_affinity(): (
-                predictions.min(axis=1).ix[peptides].values)
-        }
-        if self.percent_rank_transforms is not None:
-            self.fit_percentile_rank_if_needed(alleles)
-            percentile_ranks = pandas.DataFrame(index=peptides_deduped)
-            for allele in alleles:
-                percentile_ranks[allele] = (
-                    self.percent_rank_transforms[allele]
-                    .transform(predictions[allele].values))
-            result[self.column_name_percentile_rank()] = (
-                percentile_ranks.min(axis=1).ix[peptides].values)
-        assert all(len(x) == len(peptides) for x in result.values()), (
-            "Result lengths don't match peptide lengths. peptides=%d, "
-            "peptides_deduped=%d, %s" % (
-                len(peptides),
-                len(peptides_deduped),
-                ", ".join(
-                    "%s=%d" % (key, len(value))
-                    for (key, value) in result.items())))
-        return result
+    def predict_allele(self, allele, peptides_list):
+        assert self.allele_to_model, "Must fit first"
+        return self.allele_to_model[allele].predict(peptides_list)
 
     def get_fit(self):
         return {
diff --git a/test/test_antigen_presentation.py b/test/test_antigen_presentation.py
index c154460ea339869331a2ce6bbabb5b04aed1bcc3..96bc03627c580eea3e008d59c4b7140403684c70 100644
--- a/test/test_antigen_presentation.py
+++ b/test/test_antigen_presentation.py
@@ -96,21 +96,27 @@ def test_percent_rank_transform():
         err_msg=str(model.__dict__))
 
 
-def test_mhcflurry_trained_on_hits():
-    mhcflurry_model = presentation_component_models.MHCflurryTrainedOnHits(
-        "basic",
+def mhcflurry_basic_model():
+    return presentation_component_models.MHCflurryTrainedOnHits(
+        predictor_name="mhcflurry_affinity",
         experiment_to_alleles=EXPERIMENT_TO_ALLELES,
         experiment_to_expression_group=EXPERIMENT_TO_EXPRESSION_GROUP,
         transcripts=TRANSCIPTS_DF,
         peptides_and_transcripts=PEPTIDES_AND_TRANSCRIPTS_DF,
         random_peptides_for_percent_rank=OTHER_PEPTIDES,
     )
+
+
+def test_mhcflurry_trained_on_hits():
+    mhcflurry_model = mhcflurry_basic_model()
     mhcflurry_model.fit(HITS_DF)
 
     peptides = PEPTIDES_DF.copy()
     predictions = mhcflurry_model.predict(peptides)
-    peptides["affinity"] = predictions["mhcflurry_basic_affinity"]
-    peptides["percent_rank"] = predictions["mhcflurry_basic_percentile_rank"]
+    peptides["affinity"] = predictions["mhcflurry_affinity_value"]
+    peptides["percent_rank"] = predictions[
+        "mhcflurry_affinity_percentile_rank"
+    ]
     assert_less(
         peptides.affinity[peptides.hit].mean(),
         peptides.affinity[~peptides.hit].mean())
@@ -119,15 +125,25 @@ def test_mhcflurry_trained_on_hits():
         peptides.percent_rank[~peptides.hit].mean())
 
 
+def compare_predictions(peptides_df, model1, model2):
+    predictions1 = model1.predict(peptides_df)
+    predictions2 = model2.predict(peptides_df)
+    failed = False
+    for i in range(len(peptides_df)):
+        if predictions1[i] != predictions2[i]:
+            failed = True
+            print(
+                "Compare predictions: mismatch at index %d: "
+                "%f != %f, row: %s" % (
+                    i,
+                    predictions1[i],
+                    predictions2[i],
+                    str(peptides_df.iloc[i])))
+    assert not failed
+
+
 def test_presentation_model():
-    mhcflurry_model = presentation_component_models.MHCflurryTrainedOnHits(
-        "basic",
-        experiment_to_alleles=EXPERIMENT_TO_ALLELES,
-        experiment_to_expression_group=EXPERIMENT_TO_EXPRESSION_GROUP,
-        transcripts=TRANSCIPTS_DF,
-        peptides_and_transcripts=PEPTIDES_AND_TRANSCRIPTS_DF,
-        random_peptides_for_percent_rank=OTHER_PEPTIDES,
-    )
+    mhcflurry_model = mhcflurry_basic_model()
 
     aa_content_model = (
         presentation_component_models.FixedPerPeptideQuantity(
@@ -141,7 +157,7 @@ def test_presentation_model():
     terms = {
         'A_ms': (
             [mhcflurry_model],
-            ["log1p(mhcflurry_basic_affinity)"]),
+            ["log1p(mhcflurry_affinity_value)"]),
         'P': (
             [aa_content_model],
             list(AA_COMPOSITION_DF.columns)),
@@ -170,14 +186,12 @@ def test_presentation_model():
             peptides.prediction[peptides.hit].mean())
 
         model2 = pickle.loads(pickle.dumps(model))
-        assert_allclose(
-            model.predict(peptides), model2.predict(peptides))
+        compare_predictions(peptides, model, model2)
 
         model3 = unfit_model.clone()
         assert not model3.has_been_fit
         model3.restore_fit(model2.get_fit())
-        assert_allclose(
-            model.predict(peptides), model3.predict(peptides))
+        compare_predictions(peptides, model, model3)
 
         better_unfit_model = models["A_ms + P"]
         model = better_unfit_model.clone()