diff --git a/mhcflurry/antigen_presentation/presentation_component_models/mhcflurry_trained_on_hits.py b/mhcflurry/antigen_presentation/presentation_component_models/mhcflurry_trained_on_hits.py
index 70f933c23727b8e739cc5da4813059a9c64f8a1f..cbb9b8d391324738261934a90fc42051c1c1c310 100644
--- a/mhcflurry/antigen_presentation/presentation_component_models/mhcflurry_trained_on_hits.py
+++ b/mhcflurry/antigen_presentation/presentation_component_models/mhcflurry_trained_on_hits.py
@@ -2,7 +2,7 @@ import logging
 from copy import copy
 
 import pandas
-import numpy
+from numpy import log, exp, nanmean, array
 
 from ...dataset import Dataset
 from ...class1_allele_specific import Class1BindingPredictor
@@ -113,9 +113,19 @@ class MHCflurryTrainedOnHits(PresentationComponentModel):
             self.random_peptides_for_percent_rank = None
         else:
             self.percent_rank_transforms = {}
-            self.random_peptides_for_percent_rank = numpy.array(
+            self.random_peptides_for_percent_rank = array(
                 random_peptides_for_percent_rank)
 
+    def combine_ensemble_predictions(self, column_name, values):
+        # Geometric mean
+        return exp(nanmean(log(values), axis=1))
+
+    def stratification_groups(self, hits_df):
+        return [
+            self.experiment_to_alleles[e][0]
+            for e in hits_df.experiment_name
+        ]
+
     def column_name_affinity(self):
         return "mhcflurry_%s_affinity" % self.predictor_name
 
diff --git a/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py b/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py
index de7fbc8172f51bacc603e8da999196baa7b8a1e6..308cde88c5a4e32fcbd5a794e395a72b6cc10238 100644
--- a/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py
+++ b/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py
@@ -48,6 +48,12 @@ class PresentationComponentModel(object):
         self.__dict__.update(state)
         self.reset_cache()
 
+    def combine_ensemble_predictions(self, column_name, values):
+        return numpy.nanmean(values, axis=1)
+
+    def stratification_groups(self, hits_df):
+        return hits_df.experiment_name
+
     def column_names(self):
         """
         Names for the values this final model input emits.
@@ -113,7 +119,7 @@ class PresentationComponentModel(object):
 
     def fit(self, hits_df):
         """
-        Train the final model input.
+        Train the model.
 
         Parameters
         -----------
@@ -229,9 +235,6 @@ class PresentationComponentModel(object):
             self.cached_predictions[cache_key] = return_value
         return return_value
 
-    def fit_ensemble_and_predict(self, peptides_df):
-        raise NotImplementedError
-
     def clone(self):
         """
         Copy this object so that the original and copy can be fit
diff --git a/mhcflurry/antigen_presentation/presentation_model.py b/mhcflurry/antigen_presentation/presentation_model.py
index a24899de1a83d5112dd4b684b6dd9da1d071df87..f64b5000bb21f438235e3754c6f7ce0586cc7d34 100644
--- a/mhcflurry/antigen_presentation/presentation_model.py
+++ b/mhcflurry/antigen_presentation/presentation_model.py
@@ -135,8 +135,9 @@ class PresentationModel(object):
             model.reset_cache()
         if self.trained_component_models is not None:
             for models in self.trained_component_models:
-                for model in models:
-                    model.reset_cache()
+                for ensemble_group in models:
+                    for model in ensemble_group:
+                        model.reset_cache()
 
     def fit(self, hits_df):
         """
@@ -155,11 +156,15 @@ class PresentationModel(object):
         assert self.trained_component_models is None
         assert self.presentation_models_predictors is None
 
-        hits_df = hits_df.reset_index(drop=True)
+        hits_df = hits_df.reset_index(drop=True).copy()
         self.fit_experiments = set(hits_df.experiment_name.unique())
 
         if self.component_models_require_fitting and not self.ensemble_size:
             # Use two fold CV to train model inputs then final models.
+            # In this strategy, we fit the component models on half the data,
+            # and train the final predictor (usually logistic regression) on
+            # the other half. We do this twice to end up with two final.
+            # At prediction time, the results of these predictors are averaged.
             cv = StratifiedKFold(
                 n_splits=2, shuffle=True, random_state=self.random_state)
 
@@ -173,107 +178,122 @@ class PresentationModel(object):
                 assert len(fold2) > 0
                 model_input_training_hits_df = hits_df.iloc[fold1]
 
-                self.trained_component_models.append([])
-                for sub_model in self.component_models:
-                    sub_model = sub_model.clone_and_fit(
-                        model_input_training_hits_df)
-                    self.trained_component_models[-1].append(sub_model)
-
-                final_predictor = self.fit_final_predictor(
-                    hits_df.iloc[fold2],
-                    self.trained_component_models[-1])
-                self.presentation_models_predictors.append(final_predictor)
-        elif self.component_models_require_fitting:
-            print("Using ensemble fit, ensemble size: %d" % self.ensemble_size)
-            raise NotImplementedError()
-
-            '''
-            hits_in_train = pandas.DataFrame(index=hits_df.index)
-            out_of_sample_predictions = [
-                []
-                for _ in self.component_models
-            ]
-            for i in range(self.ensemble_size):
-                print("Training ensemble %d / %d" % (
-                    i + 1, self.ensemble_size))
-
-                train_mask = numpy.random.randint(2, size=len(hits_df))
-
-                model_input_training_hits_df = hits_df.ix[train_mask]
-                presentation_model_training_hits_df = hits_df.ix[~train_mask]
-
                 hits_and_decoys_df = make_hits_and_decoys_df(
-                    presentation_model_training_hits_df,
+                    hits_df.iloc[fold2],
                     self.decoy_strategy)
 
                 self.trained_component_models.append([])
-                out_of_sample_predictions.append([])
                 for sub_model in self.component_models:
                     sub_model = sub_model.clone_and_fit(
                         model_input_training_hits_df)
-                    self.trained_component_models[-1].append(sub_model)
-
-                    predictions = sub_model.predict(presentation_model_training_hits_df)
-                    for (col, values) in predictions.items():
-                        presentation_model_training_hits_df[col] = values
-                    out_of_sample_predictions[-1].append()
-
+                    self.trained_component_models[-1].append((sub_model,))
+                    predictions = sub_model.predict(hits_and_decoys_df)
+                    for (col, values) in predictions.iteritems():
+                        hits_and_decoys_df[col] = values
+                final_predictor = self.fit_final_predictor(hits_and_decoys_df)
+                self.presentation_models_predictors.append(final_predictor)
+        else:
+            # Use an ensemble of component predictors. Each component model is
+            # trained on a random half of the data (self.ensemble_size folds
+            # in total). Predictions are generated using the out of bag
+            # predictors. A single final model predictor is trained.
+            if self.component_models_require_fitting:
+                print("Using ensemble fit, ensemble size: %d" % (
+                    self.ensemble_size))
+            else:
+                print("Using single fold fit.")
+
+            component_model_index_to_stratification_groups = []
+            stratification_groups_to_ensemble_folds = {}
+            for (i, component_model) in enumerate(self.component_models):
+                if component_model.requires_fitting():
+                    stratification_groups = tuple(
+                        component_model.stratification_groups(hits_df))
+                    component_model_index_to_stratification_groups.append(
+                        stratification_groups)
+                    stratification_groups_to_ensemble_folds[
+                        stratification_groups
+                    ] = []
+
+            for (i, (stratification_groups, ensemble_folds)) in enumerate(
+                    stratification_groups_to_ensemble_folds.items()):
+                print("Preparing folds for stratification group %d / %d" % (
+                    i + 1, len(stratification_groups_to_ensemble_folds)))
+                while len(ensemble_folds) < self.ensemble_size:
+                    cv = StratifiedKFold(
+                        n_splits=2,
+                        shuffle=True,
+                        random_state=self.random_state + len(ensemble_folds))
+                    for (indices, _) in cv.split(
+                            hits_df, stratification_groups):
+                        ensemble_folds.append(indices)
+
+                # We may have one extra fold.
+                if len(ensemble_folds) == self.ensemble_size + 1:
+                    ensemble_folds.pop()
+
+            def fit_and_predict_component(model, fit_df, predict_df):
+                assert component_model.requires_fitting()
+                model = component_model.clone_and_fit(fit_df)
+                predictions = model.predict(predict_df)
+                return (model, predictions)
+
+            # Note: we depend on hits coming before decoys here, so that
+            # indices into hits_df are also indices into hits_and_decoys_df.
             hits_and_decoys_df = make_hits_and_decoys_df(
-                hits_df,
-                self.decoy_strategy)
-
-            for sub_model in component_models:
-                predictions = sub_model.predict(hits_and_decoys_df)
+                hits_df, self.decoy_strategy)
+
+            self.trained_component_models = [[]]
+            for (i, component_model) in enumerate(self.component_models):
+                if component_model.requires_fitting():
+                    print("Training component model %d / %d: %s" % (
+                        i + 1, len(self.component_models), component_model))
+                    stratification_groups = (
+                        component_model_index_to_stratification_groups[i])
+                    ensemble_folds = stratification_groups_to_ensemble_folds[
+                        stratification_groups
+                    ]
+                    (models, predictions) = train_and_predict_ensemble(
+                        component_model,
+                        hits_and_decoys_df,
+                        ensemble_folds)
+                else:
+                    models = (component_model,)
+                    predictions = component_model.predict(hits_and_decoys_df)
+
+                self.trained_component_models[0].append(models)
                 for (col, values) in predictions.items():
                     hits_and_decoys_df[col] = values
 
-            (x, y) = self.make_features_and_target(hits_and_decoys_df)
-            print("Training final model predictor on data of shape %s" % (
-                str(x.shape)))
-            final_predictor = clone(self.predictor)
-            final_predictor.fit(x.values, y.values)
-            self.presentation_models_predictors.append(final_predictor)
-            '''
-        else:
-            print("Using single-fold fit.")
-            # Use full data set to train final model.
-            final_predictor = self.fit_final_predictor(
-                hits_df,
-                self.component_models)
-
-            assert not self.presentation_models_predictors
+            final_predictor = self.fit_final_predictor(hits_and_decoys_df)
             self.presentation_models_predictors = [final_predictor]
-            self.trained_component_models = [
-                self.component_models
-            ]
 
         assert len(self.presentation_models_predictors) == \
             len(self.trained_component_models)
 
+        for models_group in self.trained_component_models:
+            assert isinstance(models_group, list)
+            assert len(models_group) == len(self.component_models)
+            assert all(
+                isinstance(ensemble_group, tuple)
+                for ensemble_group in models_group)
+
         print("Fit final model in %0.1f sec." % (time.time() - start))
 
         # Decoy strategy is no longer required after fitting.
         self.decoy_strategy = None
 
-    def fit_final_predictor(self, hits_df, component_models):
+    def fit_final_predictor(
+            self, hits_and_decoys_with_component_predictions_df):
         """
         Private helper method.
         """
-        hits_and_decoys_df = make_hits_and_decoys_df(
-            hits_df,
-            self.decoy_strategy)
-
-        for sub_model in component_models:
-            predictions = sub_model.predict(hits_and_decoys_df)
-            for (col, values) in predictions.iteritems():
-                hits_and_decoys_df[col] = values
-
-        (x, y) = self.make_features_and_target(hits_and_decoys_df)
+        (x, y) = self.make_features_and_target(
+            hits_and_decoys_with_component_predictions_df)
         print("Training final model predictor on data of shape %s" % (
             str(x.shape)))
         final_predictor = clone(self.predictor)
         final_predictor.fit(x.values, y.values)
-
         return final_predictor
 
     def evaluate_expressions(self, input_df):
@@ -329,12 +349,18 @@ class PresentationModel(object):
                 self.presentation_models_predictors))
         for (i, (component_models, presentation_model_predictor)) in zipped:
             df = pandas.DataFrame()
-            for sub_model in component_models:
+            for ensemble_models in component_models:
                 start_t = time.time()
-                predictions = sub_model.predict(peptides_df)
-                print("Input '%s' generated %d predictions in %0.2f sec." % (
-                    sub_model, len(peptides_df), (time.time() - start_t)))
-                for (col, values) in predictions.iteritems():
+                predictions = ensemble_predictions(
+                    ensemble_models, peptides_df)
+                print(
+                    "Component '%s' (ensemble size=%d) generated %d "
+                    "predictions in %0.2f sec." % (
+                        ensemble_models[0],
+                        len(ensemble_models),
+                        len(peptides_df),
+                        (time.time() - start_t)))
+                for (col, values) in predictions.items():
                     values = pandas.Series(values)
                     assert_no_null(values)
                     df[col] = values
@@ -447,11 +473,11 @@ class PresentationModel(object):
             'fit_experiments': self.fit_experiments,
             'feature_expressions': self.feature_expressions,
         }
-        for models in self.trained_component_models:
-            result['trained_component_model_fits'].append([
-                component_model.get_fit()
-                for component_model in models
-            ])
+        for final_predictor_models_group in self.trained_component_models:
+            fits = []
+            for ensemble_group in final_predictor_models_group:
+                fits.append(tuple(model.get_fit() for model in ensemble_group))
+            result['trained_component_model_fits'].append(fits)
         return result
 
     def restore_fit(self, fit):
@@ -484,16 +510,13 @@ class PresentationModel(object):
         self.trained_component_models = []
         for model_input_fits_for_fold in model_input_fits:
             self.trained_component_models.append([])
-            for (sub_model, sub_model_fit) in zip(
+            for (sub_model, sub_model_fits) in zip(
                     self.component_models,
                     model_input_fits_for_fold):
-                sub_model = sub_model.clone_and_restore_fit(sub_model_fit)
-                self.trained_component_models[-1].append(
-                    sub_model)
-
-        assert len(self.trained_component_models) == (
-            2 if self.component_models_require_fitting else 1), (
-            "Wrong length: %s" % self.trained_component_models)
+                restored_models = tuple(
+                    sub_model.clone_and_restore_fit(sub_model_fit)
+                    for sub_model_fit in sub_model_fits)
+                self.trained_component_models[-1].append(restored_models)
 
 
 def make_hits_and_decoys_df(hits_df, decoy_strategy):
@@ -512,3 +535,40 @@ def make_hits_and_decoys_df(hits_df, decoy_strategy):
         [hits_df, decoys_df],
         ignore_index=True)
     return peptides_df
+
+
+# TODO: paralellize this.
+def train_and_predict_ensemble(model, peptides_df, ensemble_folds):
+    assert model.requires_fitting()
+    fit_models = tuple(
+        model.clone_and_fit(peptides_df.iloc[indices])
+        for indices in ensemble_folds)
+    return (
+        fit_models,
+        ensemble_predictions(fit_models, peptides_df, ensemble_folds))
+
+
+def ensemble_predictions(models, peptides_df, mask_indices_list=None):
+    typical_model = models[0]
+    panel = pandas.Panel(
+        items=numpy.arange(len(models)),
+        major_axis=peptides_df.index,
+        minor_axis=typical_model.column_names(),
+        dtype=numpy.float32)
+
+    for (i, model) in enumerate(models):
+        predictions = model.predict(peptides_df)
+        for (key, values) in predictions.items():
+            panel.loc[i, :, key] = values
+
+    if mask_indices_list is not None:
+        for (i, indices) in enumerate(mask_indices_list):
+            panel.iloc[i, indices] = numpy.nan
+
+    result = {}
+    for col in typical_model.column_names():
+        values = panel.ix[:, :, col]
+        assert values.shape == (len(peptides_df), len(models))
+        result[col] = model.combine_ensemble_predictions(col, values.values)
+        assert_no_null(result[col])
+    return result
diff --git a/mhcflurry/common.py b/mhcflurry/common.py
index 939624b6961fb788fe7a3da966bc7742036097f3..887d73b3ecc38f4575032a09a46b3e881593da2c 100644
--- a/mhcflurry/common.py
+++ b/mhcflurry/common.py
@@ -212,15 +212,17 @@ def raise_or_debug(exception):
 
 def assert_no_null(df, message=''):
     """
-    Raise an assertino error if the given DataFrame has any nan or inf values.
-    """
-    # start = time.time()
-    with pandas.option_context('mode.use_inf_as_null', True):
-        if df.count().sum() != df.size:
-            raise_or_debug(
-                AssertionError(
-                    "%s %s" % (message, describe_nulls(df))))
-    # print("Null check completed in %0.2f sec" % (time.time() - start))
+    Raise an assertion error if the given DataFrame has any nan or inf values.
+    """
+    if hasattr(df, 'count'):
+        with pandas.option_context('mode.use_inf_as_null', True):
+            failed = df.count().sum() != df.size
+    else:
+        failed = np.isnan(df).sum() > 0
+    if failed:
+        raise_or_debug(
+            AssertionError(
+                "%s %s" % (message, describe_nulls(df))))
 
 
 def drop_nulls_and_warn(df, related_df_with_same_index_to_describe=None):
diff --git a/test/test_antigen_presentation.py b/test/test_antigen_presentation.py
index d46d0f8b620ce6f5c52a34ff2e380e0a41841a4e..df59722626e501d3a61506c61c06ae99ed1641f8 100644
--- a/test/test_antigen_presentation.py
+++ b/test/test_antigen_presentation.py
@@ -125,17 +125,19 @@ def test_presentation_model():
             ["log1p(mhcflurry_basic_affinity)"]),
     }
 
-    models = presentation_model.build_presentation_models(
-        terms,
-        ["A_ms"],
-        decoy_strategy=decoys)
-    eq_(len(models), 1)
-
-    model = models["A_ms"]
-    model.fit(HITS_DF.ix[HITS_DF.experiment_name == "exp1"])
-
-    peptides = PEPTIDES_DF.copy()
-    peptides["prediction"] = model.predict(peptides)
-    assert_less(
-        peptides.prediction[~peptides.hit].mean(),
-        peptides.prediction[peptides.hit].mean())
+    for kwargs in [{}, {'ensemble_size': 6}]:
+        models = presentation_model.build_presentation_models(
+            terms,
+            ["A_ms"],
+            decoy_strategy=decoys,
+            **kwargs)
+        eq_(len(models), 1)
+
+        model = models["A_ms"]
+        model.fit(HITS_DF.ix[HITS_DF.experiment_name == "exp1"])
+
+        peptides = PEPTIDES_DF.copy()
+        peptides["prediction"] = model.predict(peptides)
+        assert_less(
+            peptides.prediction[~peptides.hit].mean(),
+            peptides.prediction[peptides.hit].mean())