diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 40743c1b46384cd41800fc2dc779e8ba722134f1..ffb3435bf810b5bca11ee91354cc71a6d1d23f97 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -13,7 +13,7 @@ from .encodable_sequences import EncodableSequences, EncodingError
 from .amino_acid import available_vector_encodings, vector_encoding_length
 from .regression_target import to_ic50, from_ic50
 from .common import random_peptides, amino_acid_distribution
-from .custom_loss import CUSTOM_LOSSES
+from .custom_loss import get_loss
 
 
 class Class1NeuralNetwork(object):
@@ -419,6 +419,116 @@ class Class1NeuralNetwork(object):
             allele_encoding.allele_representations(
                 self.hyperparameters['allele_amino_acid_encoding']))
 
+
+    def fit_generator(
+            self,
+            generator,
+            validation_peptide_encoding,
+            validation_affinities,
+            validation_allele_encoding=None,
+            validation_inequalities=None,
+            validation_output_indices=None,
+            steps_per_epoch=10,
+            epochs=1000,
+            patience=10,
+            verbose=1):
+        """
+        Fit using a generator. Does not support many of the features of fit(),
+        such as random negative peptides.
+
+        Parameters
+        ----------
+        generator : generator yielding (alleles, peptides, affinities) tuples
+            where alleles and peptides are lists of strings, and affinities
+            is list of floats.
+
+        validation_peptide_encoding
+        validation_affinities
+        validation_allele_encoding
+        validation_inequalities
+        validation_output_indices
+        steps_per_epoch
+        epochs
+        patience
+        verbose
+
+        Returns
+        -------
+
+        """
+        import keras
+
+        loss = get_loss(self.hyperparameters['loss'])
+
+        (validation_allele_input, allele_representations) = (
+            self.allele_encoding_to_network_input(validation_allele_encoding))
+
+        if self.network() is None:
+            self._network = self.make_network(
+                allele_representations=allele_representations,
+                **self.network_hyperparameter_defaults.subselect(
+                    self.hyperparameters))
+            if verbose > 0:
+                self.network().summary()
+        network = self.network()
+
+        network.compile(
+            loss=loss.loss, optimizer=self.hyperparameters['optimizer'])
+        network._make_predict_function()
+        self.set_allele_representations(allele_representations)
+
+        validation_x_dict = {
+            'peptide': self.peptides_to_network_input(
+                validation_peptide_encoding),
+            'allele': validation_allele_input,
+        }
+        encode_y_kwargs = {}
+        if validation_inequalities is not None:
+            encode_y_kwargs["inequalities"] = validation_inequalities
+        if validation_output_indices is not None:
+            encode_y_kwargs["output_indices"] = validation_output_indices
+
+        output = loss.encode_y(
+            from_ic50(validation_affinities), **encode_y_kwargs)
+
+        validation_y_dict = {
+            'output': output,
+        }
+
+        yielded_values_box = [0]
+
+        def wrapped_generator():
+            for (alleles, peptides, affinities) in generator:
+                (allele_encoding_input, _) = (
+                    self.allele_encoding_to_network_input(alleles))
+                x_dict = {
+                    'peptide': self.peptides_to_network_input(peptides),
+                    'allele': allele_encoding_input,
+                }
+                y_dict = {
+                    'output': from_ic50(affinities)
+                }
+                yield (x_dict, y_dict)
+                yielded_values_box[0] += len(affinities)
+
+        start = time.time()
+        result = network.fit_generator(
+            wrapped_generator(),
+            steps_per_epoch=steps_per_epoch,
+            epochs=epochs,
+            use_multiprocessing=False,
+            workers=1,
+            validation_data=(validation_x_dict, validation_y_dict),
+            callbacks=[keras.callbacks.EarlyStopping(
+                monitor="val_loss",
+                patience=patience,
+                verbose=1)]
+        )
+        if verbose > 0:
+            print("fit_generator completed in %0.2f sec (%d total points)" % (
+                time.time() - start, yielded_values_box[0]))
+
+
     def fit(
             self,
             peptides,
@@ -539,41 +649,15 @@ class Class1NeuralNetwork(object):
         if output_indices is not None:
             output_indices = output_indices[shuffle_permutation]
 
-        if self.hyperparameters['loss'].startswith("custom:"):
-            # Using a custom loss
-            try:
-                custom_loss = CUSTOM_LOSSES[
-                    self.hyperparameters['loss'].replace("custom:", "")
-                ]
-            except KeyError:
-                raise ValueError(
-                    "No such custom loss function: %s. Supported losses are: %s" % (
-                        self.hyperparameters['loss'],
-                        ", ".join([
-                            "custom:" + loss_name for loss_name in CUSTOM_LOSSES
-                        ])))
-            loss_name_or_function = custom_loss.loss
-            loss_supports_inequalities = custom_loss.supports_inequalities
-            loss_supports_multiple_outputs = custom_loss.supports_multiple_outputs
-            loss_encode_y_function = custom_loss.encode_y
-        else:
-            # Using a regular keras loss.
-            loss_name_or_function = self.hyperparameters['loss']
-            loss_supports_inequalities = False
-            loss_supports_multiple_outputs = False
-            loss_encode_y_function = None
+        loss = get_loss(self.hyperparameters['loss'])
 
-        if not loss_supports_inequalities and (
+        if not loss.supports_inequalities and (
                 any(inequality != "=" for inequality in adjusted_inequalities)):
-            raise ValueError("Loss %s does not support inequalities" % (
-                loss_name_or_function))
+            raise ValueError("Loss %s does not support inequalities" % loss)
 
-        if (
-                not loss_supports_multiple_outputs and
-                output_indices is not None and
-                (output_indices != 0).any()):
-            raise ValueError("Loss %s does not support multiple outputs" % (
-                output_indices))
+        if (not loss.supports_multiple_outputs and output_indices is not None
+                and (output_indices != 0).any()):
+            raise ValueError("Loss %s does not support multiple outputs" % loss)
 
         if self.hyperparameters['num_outputs'] != 1:
             if output_indices is None:
@@ -592,8 +676,7 @@ class Class1NeuralNetwork(object):
             self.set_allele_representations(allele_representations)
 
         self.network().compile(
-            loss=loss_name_or_function,
-            optimizer=self.hyperparameters['optimizer'])
+            loss=loss.loss, optimizer=self.hyperparameters['optimizer'])
 
         if self.hyperparameters['learning_rate'] is not None:
             from keras import backend as K
@@ -601,7 +684,7 @@ class Class1NeuralNetwork(object):
                 self.network().optimizer.lr,
                 self.hyperparameters['learning_rate'])
 
-        if loss_supports_inequalities:
+        if loss.supports_inequalities:
             # Do not sample negative affinities: just use an inequality.
             random_negative_ic50 = self.hyperparameters['random_negative_affinity_min']
             random_negative_target = from_ic50(random_negative_ic50)
@@ -647,18 +730,17 @@ class Class1NeuralNetwork(object):
         else:
             output_indices_with_random_negatives = None
 
-        if loss_encode_y_function is not None:
-            encode_y_kwargs = {}
-            if adjusted_inequalities_with_random_negatives is not None:
-                encode_y_kwargs["inequalities"] = (
-                    adjusted_inequalities_with_random_negatives)
-            if output_indices_with_random_negatives is not None:
-                encode_y_kwargs["output_indices"] = (
-                    output_indices_with_random_negatives)
-
-            y_dict_with_random_negatives['output'] = loss_encode_y_function(
-                y_dict_with_random_negatives['output'],
-                **encode_y_kwargs)
+        encode_y_kwargs = {}
+        if adjusted_inequalities_with_random_negatives is not None:
+            encode_y_kwargs["inequalities"] = (
+                adjusted_inequalities_with_random_negatives)
+        if output_indices_with_random_negatives is not None:
+            encode_y_kwargs["output_indices"] = (
+                output_indices_with_random_negatives)
+
+        y_dict_with_random_negatives['output'] = loss.encode_y(
+            y_dict_with_random_negatives['output'],
+            **encode_y_kwargs)
 
         val_losses = []
         min_val_loss_iteration = None
diff --git a/mhcflurry/custom_loss.py b/mhcflurry/custom_loss.py
index 4612bf60b1bbcecfc1c12a6ee4edc0c0329e178e..641fd09969700670b22128059ef078151cf7d30e 100644
--- a/mhcflurry/custom_loss.py
+++ b/mhcflurry/custom_loss.py
@@ -13,7 +13,43 @@ from numpy import isnan, array
 CUSTOM_LOSSES = {}
 
 
-class MSEWithInequalities(object):
+def get_loss(name):
+    if name.startswith("custom:"):
+        try:
+            custom_loss = CUSTOM_LOSSES[name.replace("custom:", "")]
+        except KeyError:
+            raise ValueError(
+                "No such custom loss: %s. Supported losses are: %s" % (
+                    name,
+                    ", ".join([
+                        "custom:" + loss_name for loss_name in CUSTOM_LOSSES
+                    ])))
+        return custom_loss
+    return StandardKerasLoss(name)
+
+
+class Loss(object):
+    def __init__(self, name=None):
+        self.name = name if name else self.name  # use name from class instance
+
+    def __str__(self):
+        return "<Loss: %s>" % self.name
+
+
+class StandardKerasLoss(Loss):
+    supports_inequalities = False
+    supports_multiple_outputs = False
+
+    def __init__(self, loss_name="mse"):
+        self.loss = loss_name
+        Loss.__init__(self, loss_name)
+
+    @staticmethod
+    def encode_y(y):
+        return y
+
+
+class MSEWithInequalities(Loss):
     """
     Supports training a regressor on data that includes inequalities
     (e.g. x < 100). Mean square error is used as the loss for elements with
@@ -96,7 +132,7 @@ class MSEWithInequalities(object):
         return result
 
 
-class MSEWithInequalitiesAndMultipleOutputs(object):
+class MSEWithInequalitiesAndMultipleOutputs(Loss):
     name = "mse_with_inequalities_and_multiple_outputs"
     supports_inequalities = True
     supports_multiple_outputs = True
diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py
index 56e4d60b7cbb9f3d22c96a85413ce01d643e7563..cb56f80a1bf4424177194882cdb5e35f3b668f22 100644
--- a/mhcflurry/train_pan_allele_models_command.py
+++ b/mhcflurry/train_pan_allele_models_command.py
@@ -31,6 +31,7 @@ from .hyperparameters import HyperparameterDefaults
 from .allele_encoding import AlleleEncoding
 from .encodable_sequences import EncodableSequences
 from .regression_target import to_ic50, from_ic50
+from .import custom_loss
 
 
 # To avoid pickling large matrices to send to child processes when running in
@@ -417,6 +418,7 @@ def train_model(
         predictor,
         save_to):
     import keras.backend as K
+    import keras
 
     df = GLOBAL_DATA["train_data"]
     folds_df = GLOBAL_DATA["folds_df"]
@@ -436,7 +438,6 @@ def train_model(
     train_peptides = EncodableSequences(train_data.peptide.values)
     train_alleles = AlleleEncoding(
         train_data.allele.values, borrow_from=allele_encoding)
-    train_target = from_ic50(train_data.measurement_value.values)
 
     model = Class1NeuralNetwork(**hyperparameters)
 
@@ -453,61 +454,25 @@ def train_model(
 
     assert model.network() is None
     if hyperparameters.get("train_data", {}).get("pretrain", False):
-        iterator = pretrain_data_iterator(pretrain_data_filename, allele_encoding)
-        original_hyperparameters = dict(model.hyperparameters)
-        model.hyperparameters['minibatch_size'] = int(len(next(iterator)[-1]) / 100)
-        model.hyperparameters['max_epochs'] = 1
-        model.hyperparameters['validation_split'] = 0.0
-        model.hyperparameters['random_negative_rate'] = 0.0
-        model.hyperparameters['random_negative_constant'] = 0
-        pretrain_patience = hyperparameters["train_data"]["pretrain_patience"]
-        scores = []
-        best_score = float('inf')
-        best_score_epoch = 0
-        for (epoch, (alleles, peptides, affinities)) in enumerate(iterator):
-            # Fit one epoch.
-            start = time.time()
-            model.fit(
-                peptides=peptides,
-                affinities=affinities,
-                allele_encoding=alleles)
-
-            fit_time = time.time() - start
-            start = time.time()
-            predictions = model.predict(
-                train_peptides,
-                allele_encoding=train_alleles)
-            assert len(predictions) == len(train_data)
-
-            print("Prediction histogram:")
-            print(
-                pandas.Series(
-                    dict([k, v] for (v, k) in zip(*numpy.histogram(predictions)))))
-
-            for (inequality, func) in [(">", numpy.minimum), ("<", numpy.maximum)]:
-                mask = train_data.measurement_inequality == inequality
-                predictions[mask.values] = func(
-                    predictions[mask.values],
-                    train_data.loc[mask].measurement_value.values)
-            score_mse = numpy.mean((from_ic50(predictions) - train_target)**2)
-            score_time = time.time() - start
-            print(
-                progress_preamble,
-                "PRETRAIN epoch %d [%d values, %0.2f sec]. "
-                "MSE [%0.2f sec.]: %10f" % (
-                    epoch, len(affinities), fit_time, score_time, score_mse))
-            scores.append(score_mse)
-
-            if score_mse < best_score:
-                print("New best score_mse", score_mse)
-                best_score = score_mse
-                best_score_epoch = epoch
-
-            if epoch - best_score_epoch > pretrain_patience:
-                print("Stopping pretraining")
-                break
-
-        model.hyperparameters = original_hyperparameters
+        generator = pretrain_data_iterator(pretrain_data_filename, allele_encoding)
+        pretrain_patience = hyperparameters["train_data"].get(
+            "pretrain_patience", 10)
+        pretrain_steps_per_epoch = hyperparameters["train_data"].get(
+            "pretrain_steps_per_epoch", 10)
+        pretrain_max_epochs = hyperparameters["train_data"].get(
+            "pretrain_max_epochs", 1000)
+
+        model.fit_generator(
+            generator,
+            validation_peptide_encoding=train_peptides,
+            validation_affinities=train_data.measurement_value.values,
+            validation_allele_encoding=train_alleles,
+            validation_inequalities=train_data.measurement_inequality.values,
+            patience=pretrain_patience,
+            steps_per_epoch=pretrain_steps_per_epoch,
+            epochs=pretrain_max_epochs,
+            verbose=verbose,
+        )
         if model.hyperparameters['learning_rate']:
             model.hyperparameters['learning_rate'] /= 10
         else:
diff --git a/test/test_train_pan_allele_models_command.py b/test/test_train_pan_allele_models_command.py
index 25a9bdc7ed81c87c5a83ba8dd86e0f5bb186cf5e..d2c303886e9906ecf2230ed3b5de598e62b6feef 100644
--- a/test/test_train_pan_allele_models_command.py
+++ b/test/test_train_pan_allele_models_command.py
@@ -55,7 +55,7 @@ HYPERPARAMETERS_LIST = [
     'random_negative_distribution_smoothing': 0.0,
     'random_negative_match_distribution': True,
     'random_negative_rate': 0.2,
-    'train_data': {},
+    'train_data': {"pretrain": False},
     'validation_split': 0.1,
 },
 {
@@ -91,10 +91,14 @@ HYPERPARAMETERS_LIST = [
     'random_negative_distribution_smoothing': 0.0,
     'random_negative_match_distribution': True,
     'random_negative_rate': 0.2,
-    'train_data': {},
+    'train_data': {
+        "pretrain": True,
+        'pretrain_peptides_per_epoch': 128,
+        'pretrain_max_epochs': 2,
+    },
     'validation_split': 0.1,
 },
-]
+][1:]
 
 
 def run_and_check(n_jobs=0):
@@ -119,6 +123,8 @@ def run_and_check(n_jobs=0):
         "--num-jobs", str(n_jobs),
         "--ensemble-size", "2",
         "--verbosity", "1",
+        "--pretrain-data", get_path(
+            "random_peptide_predictions", "predictions.csv.bz2"),
     ]
     print("Running with args: %s" % args)
     subprocess.check_call(args)
@@ -147,8 +153,4 @@ def test_run_serial():
     run_and_check(n_jobs=1)
 
 if __name__ == "__main__":
-    test_run_serial()
-    #for (name, value) in list(globals().items()):
-    #    if name.startswith("test_"):
-    #        print("Running test", name)
-    #        value()
\ No newline at end of file
+    test_run_serial()
\ No newline at end of file