diff --git a/downloads-generation/models_class1_pan_refined/hyperparameters.yaml b/downloads-generation/models_class1_pan_refined/hyperparameters.yaml
index a1747f7ada068f5f37c10560b020e168e706222b..f165687660c9d36d217d32422a582d41faab0cf1 100644
--- a/downloads-generation/models_class1_pan_refined/hyperparameters.yaml
+++ b/downloads-generation/models_class1_pan_refined/hyperparameters.yaml
@@ -1,13 +1,14 @@
 #########################
 # Batch generation
 #########################
+batch_generator: simple
 batch_generator_validation_split: 0.1
 batch_generator_batch_size: 1024
 batch_generator_affinity_fraction: 0.5
 max_epochs: 500
 random_negative_rate: 1.0
 random_negative_constant: 25
-learning_rate: 0.0001
+learning_rate: 0.001
 patience: 5
 min_delta: 0.0
 loss_multiallelic_mass_spec_multiplier: 10
diff --git a/mhcflurry/batch_generator.py b/mhcflurry/batch_generator.py
index 83d495f02c9d79ed27a7bae3d5849c2de990fd00..9a6d3789daebc6013e7eea279c79ed5429ffe3ea 100644
--- a/mhcflurry/batch_generator.py
+++ b/mhcflurry/batch_generator.py
@@ -71,9 +71,9 @@ class BatchPlan(object):
         lines = []
         equivalence_class_labels = self.equivalence_class_labels
         if equivalence_class_labels is None:
-            equivalence_class_labels = (
-                "class-" + numpy.arange(self.equivalence_classes).astype("str"))
-
+            equivalence_class_labels = numpy.array([
+                "class-%d" % i for i in range(len(self.equivalence_classes))
+            ])
         i = 0
         while i < len(self.batch_compositions):
             composition = self.batch_compositions[i]
@@ -100,18 +100,111 @@ class BatchPlan(object):
         return max(len(b) for b in self.batch_compositions)
 
 
-class MultiallelicMassSpecBatchGenerator(object):
+class BatchGenerator(object):
+    implementations = {}
     hyperparameter_defaults = HyperparameterDefaults(
+        batch_generator="simple",
         batch_generator_validation_split=0.1,
-        batch_generator_batch_size=128,
+        batch_generator_batch_size=128)
+
+    @staticmethod
+    def register_implementation(name, klass):
+        BatchGenerator.implementations[name] = klass
+        BatchGenerator.hyperparameter_defaults = (
+            BatchGenerator.hyperparameter_defaults.extend(
+                klass.hyperparameter_defaults))
+
+    @staticmethod
+    def create(hyperparameters):
+        name = hyperparameters['batch_generator']
+        return BatchGenerator.implementations[name](hyperparameters)
+
+    def __init__(self, hyperparameters):
+        self.hyperparameters = BatchGenerator.hyperparameter_defaults.with_defaults(
+            hyperparameters)
+        self.train_batch_plan = None
+        self.test_batch_plan = None
+
+    def plan(self, *args, **kwargs):
+        raise NotImplementedError()
+
+    def summary(self):
+        return (
+            "Train:\n" + self.train_batch_plan.summary(indent=1) +
+            "\n***\nTest: " + self.test_batch_plan.summary(indent=1))
+
+    def get_train_and_test_generators(self, x_dict, y_list, epochs=1):
+        train_generator = self.train_batch_plan.batches_generator(
+            x_dict, y_list, epochs=epochs)
+        test_generator = self.test_batch_plan.batches_generator(
+            x_dict, y_list, epochs=epochs)
+        return (train_generator, test_generator)
+
+    @property
+    def num_train_batches(self):
+        return self.train_batch_plan.num_batches
+
+    @property
+    def num_test_batches(self):
+        return self.test_batch_plan.num_batches
+
+
+class SimpleBatchGenerator(BatchGenerator):
+    hyperparameter_defaults = HyperparameterDefaults()
+
+    def __init__(self, hyperparameters):
+        BatchGenerator.__init__(self, hyperparameters)
+
+    def plan(self, num, validation_weights=None, **kwargs):
+        if validation_weights is not None:
+            validation_weights = numpy.array(
+                validation_weights, copy=True, dtype=float)
+            numpy.testing.assert_equal(len(validation_weights), num)
+            validation_weights /= validation_weights.sum()
+
+        validation_items = numpy.random.choice(
+            num,
+            int((self.hyperparameters['batch_generator_validation_split']) * num),
+            replace=False,
+            p=validation_weights)
+        validation_items_set = set(validation_items)
+        numpy.testing.assert_equal(
+            len(validation_items), len(validation_items_set))
+        training_items = numpy.array([
+            x for x in range(num) if x not in validation_items_set
+        ], dtype=int)
+        numpy.testing.assert_equal(
+            len(validation_items) + len(training_items), num)
+
+        def simple_compositions(
+                num,
+                num_per_batch=self.hyperparameters['batch_generator_batch_size']):
+            full_batch = numpy.zeros(num_per_batch, dtype=int)
+            result = [full_batch] * int(numpy.floor(num / num_per_batch))
+            if num % num_per_batch != 0:
+                result.append(numpy.zeros(num % num_per_batch, dtype=int))
+            numpy.testing.assert_equal(sum(len(x) for x in result), num)
+            return result
+
+        self.train_batch_plan = BatchPlan(
+            equivalence_classes=[training_items],
+            batch_compositions=simple_compositions(len(training_items)))
+        self.test_batch_plan = BatchPlan(
+            equivalence_classes=[validation_items],
+            batch_compositions=simple_compositions(len(validation_items)))
+
+
+BatchGenerator.register_implementation("simple", SimpleBatchGenerator)
+
+class MultiallelicMassSpecBatchGenerator(BatchGenerator):
+    hyperparameter_defaults = HyperparameterDefaults(
         batch_generator_affinity_fraction=0.5)
     """
     Hyperperameters for batch generation for the presentation predictor.
     """
 
     def __init__(self, hyperparameters):
-        self.hyperparameters = self.hyperparameter_defaults.with_defaults(
-            hyperparameters)
+        BatchGenerator.__init__(self, hyperparameters)
         self.equivalence_classes = None
         self.batch_indices = None
 
@@ -194,18 +287,19 @@ class MultiallelicMassSpecBatchGenerator(object):
             experiment_names,
             alleles_matrix,
             is_binder,
-            validation_weights=None):
+            validation_weights=None,
+            num=None):
         affinities_mask = numpy.array(affinities_mask, copy=False, dtype=bool)
         experiment_names = numpy.array(experiment_names, copy=False)
         alleles_matrix = numpy.array(alleles_matrix, copy=False)
         is_binder = numpy.array(is_binder, copy=False, dtype=bool)
         n = len(experiment_names)
+        if num is not None:
+            numpy.testing.assert_equal(num, n)
 
         numpy.testing.assert_equal(len(affinities_mask), n)
         numpy.testing.assert_equal(len(alleles_matrix), n)
         numpy.testing.assert_equal(len(is_binder), n)
-        numpy.testing.assert_equal(
-            affinities_mask, pandas.isnull(experiment_names))
 
         if validation_weights is not None:
             validation_weights = numpy.array(
@@ -238,22 +332,5 @@ class MultiallelicMassSpecBatchGenerator(object):
         self.test_batch_plan = self.plan_from_dataframe(
             test_df, self.hyperparameters)
 
-    def summary(self):
-        return (
-            "Train:\n" + self.train_batch_plan.summary(indent=1) +
-            "\n***\nTest: " + self.test_batch_plan.summary(indent=1))
-
-    def get_train_and_test_generators(self, x_dict, y_list, epochs=1):
-        train_generator = self.train_batch_plan.batches_generator(
-            x_dict, y_list, epochs=epochs)
-        test_generator = self.test_batch_plan.batches_generator(
-            x_dict, y_list, epochs=epochs)
-        return (train_generator, test_generator)
-
-    @property
-    def num_train_batches(self):
-        return self.train_batch_plan.num_batches
-
-    @property
-    def num_test_batches(self):
-        return self.test_batch_plan.num_batches
+BatchGenerator.register_implementation(
+    "multiallelic_mass_spec", MultiallelicMassSpecBatchGenerator)
diff --git a/mhcflurry/class1_presentation_neural_network.py b/mhcflurry/class1_presentation_neural_network.py
index 2e1156e619deaa4762a0a1146b12a157f5f50422..970c52867a9e687512da8a6d6a12b38d7dc610ac 100644
--- a/mhcflurry/class1_presentation_neural_network.py
+++ b/mhcflurry/class1_presentation_neural_network.py
@@ -16,7 +16,7 @@ from .regression_target import from_ic50, to_ic50
 from .random_negative_peptides import RandomNegativePeptides
 from .allele_encoding import MultipleAlleleEncoding, AlleleEncoding
 from .auxiliary_input import AuxiliaryInputEncoder
-from .batch_generator import MultiallelicMassSpecBatchGenerator
+from .batch_generator import BatchGenerator
 from .custom_loss import (
     MSEWithInequalities,
     TransformPredictionsLossWrapper,
@@ -43,7 +43,7 @@ class Class1PresentationNeuralNetwork(object):
         early_stopping=True,
         random_negative_affinity_min=20000.0,).extend(
         RandomNegativePeptides.hyperparameter_defaults).extend(
-        MultiallelicMassSpecBatchGenerator.hyperparameter_defaults
+        BatchGenerator.hyperparameter_defaults
     )
     """
     Hyperparameters for neural network training.
@@ -470,11 +470,12 @@ class Class1PresentationNeuralNetwork(object):
         if verbose:
             self.network.summary()
 
-        batch_generator = MultiallelicMassSpecBatchGenerator(
-            MultiallelicMassSpecBatchGenerator.hyperparameter_defaults.subselect(
+        batch_generator = BatchGenerator.create(
+            hyperparameters=BatchGenerator.hyperparameter_defaults.subselect(
                 self.hyperparameters))
         start = time.time()
         batch_generator.plan(
+            num=len(peptides) + num_random_negatives,
             affinities_mask=numpy.concatenate([
                 numpy.tile(True, num_random_negatives),
                 affinities_mask
@@ -669,7 +670,7 @@ class Class1PresentationNeuralNetwork(object):
         assert isinstance(allele_encoding, MultipleAlleleEncoding)
 
         (allele_encoding_input, allele_representations) = (
-                self.allele_encoding_to_network_input(allele_encoding.compact()))
+                self.allele_encoding_to_network_input(allele_encoding))
         self.set_allele_representations(allele_representations)
         x_dict = {
             'peptide': self.peptides_to_network_input(peptides),
diff --git a/mhcflurry/downloads.yml b/mhcflurry/downloads.yml
index 286121fd3e8a21953cf9fb495bd6a008d1807283..789efcd20bddf88d85b32975e8176bc0c29fbedf 100644
--- a/mhcflurry/downloads.yml
+++ b/mhcflurry/downloads.yml
@@ -30,7 +30,7 @@ releases:
               default: false
 
             - name: models_class1_pan_refined
-              url: https://github.com/openvax/mhcflurry/releases/download/1.4.0/models_class1_pan_refined.20191211b.tar.bz2
+              url: https://github.com/openvax/mhcflurry/releases/download/1.4.0/models_class1_pan_refined.20191211c.tar.bz2
               default: false
 
             - name: models_class1_pan_variants
diff --git a/test/test_class1_presentation_neural_network.py b/test/test_class1_presentation_neural_network.py
index 75cd751d69b322700f0180e147b4e6a7d8c3c994..6ea4528ade00fd3e40f3a014f9ef0d8d4efcf618 100644
--- a/test/test_class1_presentation_neural_network.py
+++ b/test/test_class1_presentation_neural_network.py
@@ -72,11 +72,17 @@ def scramble_peptide(peptide):
     return "".join(lst)
 
 
-def make_motif(presentation_predictor, allele, peptides, frac=0.01):
+def make_motif(presentation_predictor, allele, peptides, frac=0.01, master_allele_encoding=None):
+    if master_allele_encoding is not None:
+        alleles = MultipleAlleleEncoding(borrow_from=master_allele_encoding)
+        alleles.append_alleles([allele] * len(peptides))
+    else:
+        alleles = [allele]
+
     peptides = EncodableSequences.create(peptides)
     predictions = presentation_predictor.predict(
         peptides=peptides,
-        alleles=[allele],
+        alleles=alleles,
     )
     random_predictions_df = pandas.DataFrame({"peptide": peptides.sequences})
     random_predictions_df["prediction"] = predictions
@@ -91,6 +97,176 @@ def make_motif(presentation_predictor, allele, peptides, frac=0.01):
 # TESTS
 ###################################################
 
+def Xtest_synthetic_allele_refinement_max_affinity(include_affinities=True):
+    """
+    Test that in a synthetic example the model is able to learn that HLA-C*01:02
+    prefers P at position 3.
+    """
+    refine_allele = "HLA-C*01:02"
+    alleles = ["HLA-A*02:01", "HLA-B*27:01", "HLA-C*07:01", "HLA-A*03:01",
+        "HLA-B*15:01", refine_allele]
+    peptides_per_allele = [2000, 1000, 500, 1500, 1200, 800, ]
+
+    allele_to_peptides = dict(zip(alleles, peptides_per_allele))
+
+    length = 9
+
+    train_with_ms = pandas.read_csv(get_path("data_curated",
+        "curated_training_data.with_mass_spec.csv.bz2"))
+    train_no_ms = pandas.read_csv(
+        get_path("data_curated", "curated_training_data.no_mass_spec.csv.bz2"))
+
+    def filter_df(df):
+        return df.loc[
+            (df.allele.isin(alleles)) & (df.peptide.str.len() == length)]
+
+    train_with_ms = filter_df(train_with_ms)
+    train_no_ms = filter_df(train_no_ms)
+
+    ms_specific = train_with_ms.loc[
+        ~train_with_ms.peptide.isin(train_no_ms.peptide)]
+
+    train_peptides = []
+    train_true_alleles = []
+    for allele in alleles:
+        peptides = ms_specific.loc[ms_specific.allele == allele].peptide.sample(
+            n=allele_to_peptides[allele])
+        train_peptides.extend(peptides)
+        train_true_alleles.extend([allele] * len(peptides))
+
+    hits_df = pandas.DataFrame({"peptide": train_peptides})
+    hits_df["true_allele"] = train_true_alleles
+    hits_df["hit"] = 1.0
+    hits_df["label"] = 500
+    hits_df["measurement_inequality"] = "<"
+
+    decoys_df = hits_df.copy()
+    decoys_df["peptide"] = decoys_df.peptide.map(scramble_peptide)
+    decoys_df["true_allele"] = ""
+    decoys_df["hit"] = 0.0
+    decoys_df["label"] = 500
+    decoys_df["measurement_inequality"] = ">"
+
+    mms_train_df = pandas.concat([hits_df, decoys_df], ignore_index=True)
+    mms_train_df["label"] = mms_train_df.hit
+    mms_train_df["is_affinity"] = True
+
+    if include_affinities:
+        affinity_train_df = pandas.read_csv(get_path("models_class1_pan",
+            "models.with_mass_spec/train_data.csv.bz2"))
+        affinity_train_df = affinity_train_df.loc[
+            affinity_train_df.allele.isin(alleles),
+            ["peptide", "allele", "measurement_inequality", "measurement_value"]
+        ]
+
+        affinity_train_df["label"] = affinity_train_df["measurement_value"]
+        del affinity_train_df["measurement_value"]
+        affinity_train_df["is_affinity"] = True
+    else:
+        affinity_train_df = None
+
+    (affinity_model,) = AFFINITY_PREDICTOR.class1_pan_allele_models
+    presentation_model = Class1PresentationNeuralNetwork(
+        auxiliary_input_features=["gene"],
+        batch_generator_batch_size=1024,
+        max_epochs=10,
+        learning_rate=0.001,
+        patience=5,
+        min_delta=0.0,
+        random_negative_rate=1.0,
+        random_negative_constant=25)
+    presentation_model.load_from_class1_neural_network(affinity_model)
+
+    presentation_predictor = Class1PresentationPredictor(
+        models=[presentation_model],
+        allele_to_sequence=AFFINITY_PREDICTOR.allele_to_sequence)
+
+    mms_allele_encoding = MultipleAlleleEncoding(
+        experiment_names=["experiment1"] * len(mms_train_df),
+        experiment_to_allele_list={
+            "experiment1": alleles,
+        }, max_alleles_per_experiment=6,
+        allele_to_sequence=AFFINITY_PREDICTOR.allele_to_sequence, )
+    allele_encoding = copy.deepcopy(mms_allele_encoding)
+    if affinity_train_df is not None:
+        allele_encoding.append_alleles(affinity_train_df.allele.values)
+        train_df = pandas.concat([mms_train_df, affinity_train_df],
+            ignore_index=True, sort=False)
+    else:
+        train_df = mms_train_df
+
+    allele_encoding = allele_encoding.compact()
+    mms_allele_encoding = mms_allele_encoding.compact()
+
+    pre_predictions = presentation_model.predict(
+        peptides=mms_train_df.peptide.values,
+        allele_encoding=mms_allele_encoding).score
+
+    expected_pre_predictions = from_ic50(affinity_model.predict(
+        peptides=numpy.repeat(mms_train_df.peptide.values, len(alleles)),
+        allele_encoding=mms_allele_encoding.allele_encoding, )).reshape(
+        (-1, len(alleles)))
+    assert_allclose(pre_predictions, expected_pre_predictions, rtol=1e-4)
+
+    random_peptides_encodable = EncodableSequences.create(
+        random_peptides(10000, 9))
+
+    original_motif = make_motif(
+        presentation_predictor=presentation_predictor,
+        peptides=random_peptides_encodable,
+        allele=refine_allele)
+    print("Original motif proline-3 rate: ", original_motif.loc[3, "P"])
+    assert_less(original_motif.loc[3, "P"], 0.1)
+
+    iteration_box = [0]
+
+    def progress(label = None):
+        if label is None:
+            label = str(iteration_box[0])
+            iteration_box[0] += 1
+        print("*** iteration ", label, "***")
+        predictions_df = presentation_predictor.predict_to_dataframe(
+            peptides=mms_train_df.peptide.values,
+            alleles=mms_allele_encoding)
+        merged_df = pandas.merge(mms_train_df, predictions_df, on="peptide")
+        merged_hit_df = merged_df.loc[merged_df.hit == 1.0]
+        correct_allele_fraction =  (
+            merged_hit_df.allele == merged_hit_df.true_allele).mean()
+        print("Correct allele fraction", correct_allele_fraction)
+        print(
+            "Mean score/affinity for hit",
+            merged_df.loc[merged_df.hit == 1.0].score.mean(),
+            merged_df.loc[merged_df.hit == 1.0].affinity.mean())
+        print(
+            "Mean score/affinity for decoy",
+            merged_df.loc[merged_df.hit == 0.0].score.mean(),
+            merged_df.loc[merged_df.hit == 0.0].affinity.mean())
+        auc = roc_auc_score(merged_df.hit.values, merged_df.score.values)
+        print("AUC", auc)
+        return (auc, correct_allele_fraction)
+
+    (pre_auc, pre_correct_allele_fraction) = progress(label="Pre fitting")
+    presentation_model.fit(peptides=train_df.peptide.values,
+        labels=train_df.label.values,
+        inequalities=train_df.measurement_inequality.values,
+        affinities_mask=train_df.is_affinity.values,
+        allele_encoding=allele_encoding,
+        progress_callback=progress)
+    (post_auc, post_correct_allele_fraction) = progress(label="Done fitting")
+
+    final_motif = make_motif(
+        presentation_predictor=presentation_predictor,
+        peptides=random_peptides_encodable,
+        allele=refine_allele)
+    print("Final motif proline-3 rate: ", final_motif.loc[3, "P"])
+
+    assert_greater(post_auc, pre_auc)
+    assert_greater(
+        post_correct_allele_fraction, pre_correct_allele_fraction - 0.05)
+    assert_greater(final_motif.loc[3, "P"], original_motif.loc[3, "P"])
+
+
+
 def test_synthetic_allele_refinement(include_affinities=True):
     """
     Test that in a synthetic example the model is able to learn that HLA-C*01:02
@@ -157,6 +333,8 @@ def test_synthetic_allele_refinement(include_affinities=True):
 
     (affinity_model,) = AFFINITY_PREDICTOR.class1_pan_allele_models
     presentation_model = Class1PresentationNeuralNetwork(
+        #batch_generator="multiallelic_mass_spec",
+        batch_generator="simple",
         auxiliary_input_features=["gene"],
         batch_generator_batch_size=1024,
         max_epochs=10,
@@ -176,7 +354,7 @@ def test_synthetic_allele_refinement(include_affinities=True):
         experiment_to_allele_list={
             "experiment1": alleles,
         }, max_alleles_per_experiment=6,
-        allele_to_sequence=AFFINITY_PREDICTOR.allele_to_sequence, )
+        allele_to_sequence=AFFINITY_PREDICTOR.allele_to_sequence)
     allele_encoding = copy.deepcopy(mms_allele_encoding)
     if affinity_train_df is not None:
         allele_encoding.append_alleles(affinity_train_df.allele.values)
@@ -186,6 +364,7 @@ def test_synthetic_allele_refinement(include_affinities=True):
         train_df = mms_train_df
 
     allele_encoding = allele_encoding.compact()
+    mms_allele_encoding = mms_allele_encoding.compact()
 
     pre_predictions = presentation_model.predict(
         peptides=mms_train_df.peptide.values,
@@ -232,6 +411,14 @@ def test_synthetic_allele_refinement(include_affinities=True):
             merged_df.loc[merged_df.hit == 0.0].affinity.mean())
         auc = roc_auc_score(merged_df.hit.values, merged_df.score.values)
         print("AUC", auc)
+
+        motif = make_motif(
+            presentation_predictor=presentation_predictor,
+            peptides=random_peptides_encodable,
+            allele=refine_allele,
+            master_allele_encoding=allele_encoding.allele_encoding)
+        print("Proline-3 rate: ", motif.loc[3, "P"])
+
         return (auc, correct_allele_fraction)
 
     (pre_auc, pre_correct_allele_fraction) = progress(label="Pre fitting")
@@ -255,7 +442,7 @@ def test_synthetic_allele_refinement(include_affinities=True):
     assert_greater(final_motif.loc[3, "P"], original_motif.loc[3, "P"])
 
 
-def test_real_data_multiallelic_refinement(max_epochs=10):
+def Xtest_real_data_multiallelic_refinement(max_epochs=10):
     """
     Test on real data that we can learn that HLA-A*02:20 has a preference K at
     position 1.
@@ -265,7 +452,7 @@ def test_real_data_multiallelic_refinement(max_epochs=10):
         auxiliary_input_features=["gene"],
         batch_generator_batch_size=1024,
         max_epochs=max_epochs,
-        learning_rate=0.0001,
+        learning_rate=0.001,
         patience=5,
         min_delta=0.0,
         random_negative_rate=0,
@@ -315,7 +502,8 @@ def test_real_data_multiallelic_refinement(max_epochs=10):
             "models_class1_pan", "models.with_mass_spec/train_data.csv.bz2"))
 
     pan_sub_train_df = pan_train_df.loc[
-        pan_train_df.allele.isin(multi_train_alleles),
+        #pan_train_df.allele.isin(multi_train_alleles),
+        :,
         ["peptide", "allele", "measurement_inequality", "measurement_value"]
     ]
     pan_sub_train_df["label"] = pan_sub_train_df["measurement_value"]
@@ -346,12 +534,24 @@ def test_real_data_multiallelic_refinement(max_epochs=10):
         "original motif lysine-1 rate: ",
         original_motif.loc[1, "K"])
 
+    def progress():
+        motif = make_motif(
+            presentation_predictor=presentation_predictor,
+            peptides=random_peptides_encodable,
+            allele=refine_allele,
+            master_allele_encoding=allele_encoding.allele_encoding)
+        print(
+            refine_allele,
+            "current motif lysine-1 rate: ",
+            motif.loc[1, "K"])
+
     presentation_model.fit(
         peptides=combined_train_df.peptide.values,
         labels=combined_train_df.label.values,
         allele_encoding=allele_encoding,
         affinities_mask=combined_train_df.is_affinity.values,
-        inequalities=combined_train_df.measurement_inequality.values)
+        inequalities=combined_train_df.measurement_inequality.values,
+        progress_callback=progress)
 
     final_motif = make_motif(
         presentation_predictor=presentation_predictor,
@@ -360,3 +560,8 @@ def test_real_data_multiallelic_refinement(max_epochs=10):
     print(refine_allele, "final motif lysine-1 rate: ", final_motif.loc[1, "K"])
 
     assert_greater(final_motif.loc[1, "K"], original_motif.loc[1, "K"])
+
+if __name__ == "__main__":
+    setup()
+    Xtest_real_data_multiallelic_refinement()
+    teardown()
\ No newline at end of file