diff --git a/mhcflurry/batch_generator.py b/mhcflurry/batch_generator.py
index 4b000ac4d1e9dbf956ecfe7366f03ca60f438ff4..eb44260a4b66aff772af9d557ea03f5b663aefe6 100644
--- a/mhcflurry/batch_generator.py
+++ b/mhcflurry/batch_generator.py
@@ -116,28 +116,23 @@ class MultiallelicMassSpecBatchGenerator(object):
     def plan_from_dataframe(df, hyperparameters):
         affinity_fraction = hyperparameters["batch_generator_affinity_fraction"]
         batch_size = hyperparameters["batch_generator_batch_size"]
-        equivalence_columns = ["is_affinity", "is_binder", "experiment_name"]
-        df["equivalence_key"] = df[equivalence_columns].astype(str).sum(1)
-        equivalence_map = dict(
-            (v, i)
-            for (i, v) in zip(*df.equivalence_key.factorize()))
-        df["equivalence_class"] = df.equivalence_key.map(equivalence_map)
         df["first_allele"] = df.alleles.str.get(0)
+        df["equivalence_key"] = numpy.where(
+            df.is_affinity,
+            df.first_allele,
+            df.experiment_name,
+        ) + " " + df.is_binder.map({True: "binder", False: "nonbinder"})
+
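+        # factorize() maps each distinct key to an integer class and returns
+        # the distinct keys themselves, used below as the class labels.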
+        (df["equivalence_class"], equivalence_class_labels) = (
+            df.equivalence_key.factorize())
         df["unused"] = True
         df["idx"] = df.index
-        equivalence_class_to_label = dict(
-            (idx, (
-                "{first_allele} {binder}" if row.is_affinity else
-                "{experiment_name} {binder}"
-                ).format(
-                    binder="binder" if row.is_binder else "nonbinder",
-                    **row.to_dict()))
-            for (idx, row) in df.drop_duplicates(
-                "equivalence_class").set_index("equivalence_class").iterrows())
         df = df.sample(frac=1.0)
-        #df["key"] = df.is_binder ^ (numpy.arange(len(df)) % 2).astype(bool)
-        #df = df.sort_values("key")
-        #del df["key"]
 
         affinities_per_batch = int(affinity_fraction * batch_size)
 
@@ -206,10 +202,7 @@ class MultiallelicMassSpecBatchGenerator(object):
         return BatchPlan(
             equivalence_classes=equivalence_classes,
             batch_compositions=batch_compositions,
-            equivalence_class_labels=[
-                equivalence_class_to_label[i] for i in
-                range(len(class_to_indices))
-            ])
+            equivalence_class_labels=equivalence_class_labels)
 
     def plan(
             self,
diff --git a/test/test_batch_generator.py b/test/test_batch_generator.py
index f5aac2aa491eb9986bc50b1a5ef135d09eab9b61..ac32ee0a9280de98591c77e87a0917bf68e1e5fd 100644
--- a/test/test_batch_generator.py
+++ b/test/test_batch_generator.py
@@ -137,7 +137,7 @@ def test_large(sample_rate=0.01):
     planner = MultiallelicMassSpecBatchGenerator(
         hyperparameters=dict(
             batch_generator_validation_split=0.2,
-            batch_generator_batch_size=1024,
+            batch_generator_batch_size=128,
             batch_generator_affinity_fraction=0.5))
 
     s = time.time()
diff --git a/test/test_class1_ligandome_predictor.py b/test/test_class1_ligandome_predictor.py
index 0561998ab4d114a320bd80d7b529baba4de766f0..f808438d822de635d736f325639c83805833694d 100644
--- a/test/test_class1_ligandome_predictor.py
+++ b/test/test_class1_ligandome_predictor.py
@@ -22,9 +22,10 @@ import pandas
 import argparse
 import sys
 import copy
-from functools import partial
+import os
 
 from numpy.testing import assert_, assert_equal, assert_allclose
+from nose.tools import assert_greater, assert_less
 import numpy
 from random import shuffle
 
@@ -47,6 +48,13 @@ PAN_ALLELE_PREDICTOR_NO_MASS_SPEC = None
 PAN_ALLELE_MOTIFS_WITH_MASS_SPEC_DF = None
 PAN_ALLELE_MOTIFS_NO_MASS_SPEC_DF = None
 
+def data_path(name):
+    '''
+    Return the absolute path to a file in the test/data directory.
+    The name specified should be relative to test/data.
+    '''
+    return os.path.join(os.path.dirname(__file__), "data", name)
+
 
 def setup():
     global PAN_ALLELE_PREDICTOR_NO_MASS_SPEC
@@ -341,7 +350,7 @@ def Xtest_real_data_multiallelic_refinement(max_epochs=10):
     import ipdb ; ipdb.set_trace()
 
 
-def test_synthetic_allele_refinement_with_affinity_data(max_epochs=10):
+def Xtest_synthetic_allele_refinement_with_affinity_data(max_epochs=10):
     refine_allele = "HLA-C*01:02"
     alleles = [
         "HLA-A*02:01", "HLA-B*27:01", "HLA-C*07:01",
@@ -744,6 +753,86 @@ def Xtest_synthetic_allele_refinement(max_epochs=10):
     return (predictor, predictions, metrics, motifs)
 
 
+def test_refinement_large(sample_rate=0.1):
+    multi_train_df = pandas.read_csv(
+        data_path("multiallelic_ms.benchmark1.csv.bz2"))
+    multi_train_df["label"] = multi_train_df.hit
+    multi_train_df["is_affinity"] = False
+
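+    # One row per mass spec sample: keep only columns that are constant
+    # within each sample, and split the space-separated HLA string into a
+    # list of alleles.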
+    sample_table = multi_train_df.loc[
+        multi_train_df.label == True
+    ].drop_duplicates("sample_id").set_index("sample_id").loc[
+        multi_train_df.sample_id.unique()
+    ]
+    grouped = multi_train_df.groupby("sample_id").nunique()
+    for col in sample_table.columns:
+        if (grouped[col] > 1).any():
+            del sample_table[col]
+    sample_table["alleles"] = sample_table.hla.str.split()
+
+    pan_train_df = pandas.read_csv(
+        get_path(
+            "models_class1_pan", "models.with_mass_spec/train_data.csv.bz2"))
+    pan_sub_train_df = pan_train_df
+    pan_sub_train_df["label"] = pan_sub_train_df["measurement_value"]
+    del pan_sub_train_df["measurement_value"]
+    pan_sub_train_df["is_affinity"] = True
+
+    pan_sub_train_df = pan_sub_train_df.sample(frac=sample_rate)
+    multi_train_df = multi_train_df.sample(frac=sample_rate)
+
+    pan_predictor = Class1AffinityPredictor.load(
+        get_path("models_class1_pan", "models.with_mass_spec"),
+        optimization_level=0,
+        max_models=1)
+
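+    # The allele encoding covers the multiallelic experiments first, then the
+    # single alleles from the affinity data, matching combined_train_df below.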
+    allele_encoding = MultipleAlleleEncoding(
+        experiment_names=multi_train_df.sample_id.values,
+        experiment_to_allele_list=sample_table.alleles.to_dict(),
+        max_alleles_per_experiment=sample_table.alleles.str.len().max(),
+        allele_to_sequence=pan_predictor.allele_to_sequence,
+    )
+    allele_encoding.append_alleles(pan_sub_train_df.allele.values)
+    allele_encoding = allele_encoding.compact()
+
+    combined_train_df = pandas.concat(
+        [multi_train_df, pan_sub_train_df], ignore_index=True, sort=True)
+
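+    # max_epochs=0: this test only exercises the batch planning done in
+    # fit(), not actual training.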
+    ligandome_predictor = Class1LigandomePredictor(
+        pan_predictor,
+        auxiliary_input_features=[],
+        max_ensemble_size=1,
+        max_epochs=0,
+        batch_generator_batch_size=128,
+        learning_rate=0.0001,
+        patience=5,
+        min_delta=0.0,
+        random_negative_rate=1.0)
+
+    fit_results = ligandome_predictor.fit(
+        peptides=combined_train_df.peptide.values,
+        labels=combined_train_df.label.values,
+        allele_encoding=allele_encoding,
+        affinities_mask=combined_train_df.is_affinity.values,
+        inequalities=combined_train_df.measurement_inequality.values,
+    )
+
+    batch_generator = fit_results['batch_generator']
+    train_batch_plan = batch_generator.train_batch_plan
+
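+    # Classes are keyed by experiment (mass spec) or allele (affinity), each
+    # split into binder/nonbinder, so the total should land in the hundreds.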
+    assert_greater(len(train_batch_plan.equivalence_class_labels), 100)
+    assert_less(len(train_batch_plan.equivalence_class_labels), 1000)
+
+
 parser = argparse.ArgumentParser(usage=__doc__)
 parser.add_argument(
     "--out-metrics-csv",