diff --git a/mhcflurry/batch_generator.py b/mhcflurry/batch_generator.py
index eb44260a4b66aff772af9d557ea03f5b663aefe6..570cc8adcf430892e305d726bad8691ec0e8f9e9 100644
--- a/mhcflurry/batch_generator.py
+++ b/mhcflurry/batch_generator.py
@@ -117,82 +117,63 @@ class MultiallelicMassSpecBatchGenerator(object):
         affinity_fraction = hyperparameters["batch_generator_affinity_fraction"]
         batch_size = hyperparameters["batch_generator_batch_size"]
         df["first_allele"] = df.alleles.str.get(0)
-        equivalence_columns = [
-            "is_affinity",
-            "is_binder",
-            "experiment_name",
-            "first_allele",
-        ]
         df["equivalence_key"] = numpy.where(
             df.is_affinity,
             df.first_allele,
             df.experiment_name,
         ) + " " + df.is_binder.map({True: "binder", False: "nonbinder"})
-
         (df["equivalence_class"], equivalence_class_labels) = (
             df.equivalence_key.factorize())
-        df["unused"] = True
         df["idx"] = df.index
         df = df.sample(frac=1.0)
 
         affinities_per_batch = int(affinity_fraction * batch_size)
 
+        remaining_affinities_df = df.loc[df.is_affinity].copy()
+
         # First do mixed affinity / multiallelic ms batches_generator.
         batch_compositions = []
-        for experiment in df.loc[~df.is_affinity].experiment_name.unique():
-            if experiment is None:
-                continue
-            while True:
-                experiment_df = df.loc[
-                    df.unused & (df.experiment_name == experiment)]
-                if len(experiment_df) == 0:
-                    break
-                (experiment_alleles,) = experiment_df.alleles.unique()
-                affinities_df = df.loc[df.unused & df.is_affinity].copy()
-                affinities_df["matches_allele"] = (
-                    affinities_df.first_allele.isin(experiment_alleles))
-
-                # Whenever possible we try to use affinities with the same
-                # alleles as the mass spec experiment
-                affinities_df = affinities_df.sort_values(
-                    "matches_allele", ascending=False)
-
+        for (experiment, experiment_df) in df.loc[~df.is_affinity].groupby(
+                "experiment_name"):
+            (experiment_alleles,) = experiment_df.alleles.unique()
+            remaining_affinities_df["matches_allele"] = (
+                remaining_affinities_df.first_allele.isin(experiment_alleles))
+            # Whenever possible we try to use affinities with the same
+            # alleles as the mass spec experiment
+            remaining_affinities_df = remaining_affinities_df.sort_values(
+                "matches_allele", ascending=False)
+            while len(experiment_df) > 0:
                 affinities_for_this_batch = min(
-                    affinities_per_batch, len(affinities_df))
+                    affinities_per_batch, len(remaining_affinities_df))
                 mass_spec_for_this_batch = (
                     batch_size - affinities_for_this_batch)
                 if len(experiment_df) < mass_spec_for_this_batch:
                     mass_spec_for_this_batch = len(experiment_df)
                     affinities_for_this_batch = (
                             batch_size - mass_spec_for_this_batch)
-                    if affinities_for_this_batch < len(affinities_df):
-                        # For mass spec, we only do whole batches_generator, since it's
-                        # unclear how our pairwise loss would interact with
-                        # a smaller batch.
-                        break
 
-                to_use_list = []
+                batch_composition = []
 
-                # sample mass spec
-                to_use = experiment_df.head(mass_spec_for_this_batch)
-                to_use_list.append(to_use.index.values)
+                # take mass spec
+                to_use = experiment_df.iloc[:mass_spec_for_this_batch]
+                experiment_df = experiment_df.iloc[mass_spec_for_this_batch:]
+                batch_composition.extend(to_use.equivalence_class.values)
 
-                # sample affinities
-                to_use = affinities_df.head(affinities_for_this_batch)
-                to_use_list.append(to_use.index.values)
-
-                to_use_indices = numpy.concatenate(to_use_list)
-                df.loc[to_use_indices, "unused"] = False
-                batch_compositions.append(
-                    df.loc[to_use_indices].equivalence_class.values)
+                # take affinities
+                to_use = remaining_affinities_df.iloc[
+                    :affinities_for_this_batch
+                ]
+                remaining_affinities_df = remaining_affinities_df.iloc[
+                    affinities_for_this_batch:
+                ]
+                batch_composition.extend(to_use.equivalence_class.values)
+                batch_compositions.append(batch_composition)
 
         # Affinities-only batches
-        affinities_df = df.loc[df.unused & df.is_affinity]
-        while len(affinities_df) > 0:
-            to_use = affinities_df.head(batch_size)
-            df.loc[to_use.index, "unused"] = False
+        while len(remaining_affinities_df) > 0:
+            to_use = remaining_affinities_df.iloc[:batch_size]
+            remaining_affinities_df = remaining_affinities_df.iloc[batch_size:]
             batch_compositions.append(to_use.equivalence_class.values)
-            affinities_df = df.loc[df.unused & df.is_affinity]
 
         class_to_indices = df.groupby("equivalence_class").idx.unique()
         equivalence_classes = [
@@ -228,7 +209,8 @@ class MultiallelicMassSpecBatchGenerator(object):
         validation_items = numpy.random.choice(
             n if potential_validation_mask is None
                 else numpy.where(potential_validation_mask)[0],
-            int(self.hyperparameters['batch_generator_validation_split'] * n))
+            int(self.hyperparameters['batch_generator_validation_split'] * n),
+            replace=False)
         validation_mask = numpy.zeros(n, dtype=bool)
         validation_mask[validation_items] = True
 
diff --git a/test/test_batch_generator.py b/test/test_batch_generator.py
index ac32ee0a9280de98591c77e87a0917bf68e1e5fd..87f79b8ddf312cfc10df0bd0f1069d17be8a6f38 100644
--- a/test/test_batch_generator.py
+++ b/test/test_batch_generator.py
@@ -19,6 +19,7 @@ from mhcflurry.regression_target import to_ic50
 from mhcflurry import Class1AffinityPredictor
 
 from numpy.testing import assert_equal
+from nose.tools import assert_greater, assert_less
 
 
 def data_path(name):
@@ -29,20 +30,27 @@ def data_path(name):
     return os.path.join(os.path.dirname(__file__), "data", name)
 
 
+def test_basic_repeat():
+    for _ in range(100):
+        test_basic()
+
+
 def test_basic():
+    batch_size = 7
+    validation_split = 0.2
     planner = MultiallelicMassSpecBatchGenerator(
         hyperparameters=dict(
-            batch_generator_validation_split=0.2,
-            batch_generator_batch_size=10,
+            batch_generator_validation_split=validation_split,
+            batch_generator_batch_size=batch_size,
             batch_generator_affinity_fraction=0.5))
 
     exp1_alleles = ["HLA-A*03:01", "HLA-B*07:02", "HLA-C*02:01"]
     exp2_alleles = ["HLA-A*02:01", "HLA-B*27:01", "HLA-C*02:01"]
 
     df = pandas.DataFrame(dict(
-        affinities_mask=([True] * 4) + ([False] * 6),
-        experiment_names=([None] * 4) + (["exp1"] * 2) + (["exp2"] * 4),
-        alleles_matrix=[
+        affinities_mask=([True] * 14) + ([False] * 6),
+        experiment_names=([None] * 14) + (["exp1"] * 2) + (["exp2"] * 4),
+        alleles_matrix=[["HLA-C*07:01", None, None]] * 10 + [
             ["HLA-A*02:01", None, None],
             ["HLA-A*02:01", None, None],
             ["HLA-A*03:01", None, None],
@@ -54,11 +62,20 @@ def test_basic():
             exp2_alleles,
             exp2_alleles,
         ],
-        is_binder=[
+        is_binder=[False, True] * 5 + [
             True, True, False, False, True, False, True, False, True, False,
         ]))
+    df = pandas.concat([df, df], ignore_index=True)
+    df = pandas.concat([df, df], ignore_index=True)
+
     planner.plan(**df.to_dict("list"))
-    print(planner.summary())
+
+    assert_equal(
+        planner.num_train_batches,
+        numpy.ceil(len(df) * (1 - validation_split) / batch_size))
+    assert_equal(
+        planner.num_test_batches,
+        numpy.ceil(len(df) * validation_split / batch_size))
 
     (train_iter, test_iter) = planner.get_train_and_test_generators(
         x_dict={
@@ -74,20 +91,36 @@ def test_basic():
             df.loc[idx, "batch"] = i
     df["idx"] = df.idx.astype(int)
     df["batch"] = df.batch.astype(int)
-    print(df)
 
+    assert_equal(df.kind.value_counts()["test"], len(df) * validation_split)
+    assert_equal(df.kind.value_counts()["train"], len(df) * (1 - validation_split))
+
+    experiment_allele_colocations = collections.defaultdict(int)
     for ((kind, batch_num), batch_df) in df.groupby(["kind", "batch"]):
         if not batch_df.affinities_mask.all():
             # Test each batch has at most one multiallelic ms experiment.
-            assert_equal(
-                batch_df.loc[
-                    ~batch_df.affinities_mask
-                ].experiment_names.nunique(), 1)
-
-    #import ipdb;ipdb.set_trace()
-
-
-def test_large(sample_rate=0.01):
+            names = batch_df.loc[
+                ~batch_df.affinities_mask
+            ].experiment_names.unique()
+            assert_equal(len(names), 1)
+            (experiment,) = names
+            if batch_df.affinities_mask.any():
+                # Test experiments are matched to the correct affinity alleles.
+                affinity_alleles = batch_df.loc[
+                    batch_df.affinities_mask
+                ].alleles_matrix.str.get(0).values
+                for allele in affinity_alleles:
+                    experiment_allele_colocations[(experiment, allele)] += 1
+
+    assert_greater(
+        experiment_allele_colocations[('exp1', 'HLA-A*03:01')],
+        experiment_allele_colocations[('exp1', 'HLA-A*02:01')])
+    assert_less(
+        experiment_allele_colocations[('exp2', 'HLA-A*03:01')],
+        experiment_allele_colocations[('exp2', 'HLA-A*02:01')])
+
+
+def test_large(sample_rate=1.0):
     multi_train_df = pandas.read_csv(
         data_path("multiallelic_ms.benchmark1.csv.bz2"))
     multi_train_df["label"] = multi_train_df.hit
@@ -151,6 +184,7 @@ def test_large(sample_rate=0.01):
             combined_train_df.is_affinity.values,
             combined_train_df.label.values,
             to_ic50(combined_train_df.label.values)) < 1000.0)
+    profiler.disable()
     stats = pstats.Stats(profiler)
     stats.sort_stats("cumtime").reverse_order().print_stats()
     print(planner.summary())
@@ -162,20 +196,25 @@ def test_large(sample_rate=0.01):
         },
         y_list=[])
 
+    train_batch_sizes = []
+    indices_total = numpy.zeros(len(combined_train_df))
     for (kind, it) in [("train", train_iter), ("test", test_iter)]:
         for (i, (x_item, y_item)) in enumerate(it):
             idx = x_item["idx"]
-            combined_train_df.loc[idx, "kind"] = kind
-            combined_train_df.loc[idx, "idx"] = idx
-            combined_train_df.loc[idx, "batch"] = i
-    import ipdb ; ipdb.set_trace()
-    combined_train_df["idx"] = combined_train_df.idx.astype(int)
-    combined_train_df["batch"] = combined_train_df.batch.astype(int)
-
-    for ((kind, batch_num), batch_df) in combined_train_df.groupby(["kind", "batch"]):
-        if not batch_df.is_affinity.all():
-            # Test each batch has at most one multiallelic ms experiment.
-            assert_equal(
-                batch_df.loc[
-                    ~batch_df.is_affinity
-                ].sample_id.nunique(), 1)
\ No newline at end of file
+            indices_total[idx] += 1
+            batch_df = combined_train_df.iloc[idx]
+            if not batch_df.is_affinity.all():
+                # Test each batch has at most one multiallelic ms experiment.
+                assert_equal(
+                    batch_df.loc[~batch_df.is_affinity].sample_id.nunique(), 1)
+            if kind == "train":
+                train_batch_sizes.append(len(batch_df))
+
+    # At most one short batch.
+    assert_less(sum(b != 128 for b in train_batch_sizes), 2)
+    assert_greater(
+        sum(b == 128 for b in train_batch_sizes), len(train_batch_sizes) - 2)
+
+    # Each point used exactly once.
+    assert_equal(
+        indices_total, numpy.ones(len(combined_train_df)))
diff --git a/test/test_class1_ligandome_predictor.py b/test/test_class1_ligandome_predictor.py
index f808438d822de635d736f325639c83805833694d..3682476ac2f987ab26a69a8b73dc8abc17ec5266 100644
--- a/test/test_class1_ligandome_predictor.py
+++ b/test/test_class1_ligandome_predictor.py
@@ -753,7 +753,7 @@ def Xtest_synthetic_allele_refinement(max_epochs=10):
     return (predictor, predictions, metrics, motifs)
 
 
-def test_refinemeent_large(sample_rate=0.1):
+def test_batch_generator(sample_rate=0.1):
     multi_train_df = pandas.read_csv(
         data_path("multiallelic_ms.benchmark1.csv.bz2"))
     multi_train_df["label"] = multi_train_df.hit