diff --git a/mhcflurry/allele_encoding.py b/mhcflurry/allele_encoding.py
index 06355361dd337c5faaa5aa040932a62a733ece87..fcc02789733e10d76df51996bdd74f3e45a6a295 100644
--- a/mhcflurry/allele_encoding.py
+++ b/mhcflurry/allele_encoding.py
@@ -1,5 +1,8 @@
+import numpy
 import pandas
 
+from copy import copy
+
 from . import amino_acid
 
 
@@ -18,7 +21,8 @@ class AlleleEncoding(object):
         Parameters
         ----------
         alleles : list of string
-            Allele names
+            Allele names. If any allele is None instead of string, it will be
+            mapped to the special index value -1.
 
         allele_to_sequence : dict of str -> str
             Allele name to amino acid sequence
@@ -42,6 +46,7 @@ class AlleleEncoding(object):
             self.allele_to_index = dict(
                 (allele, i)
                 for (i, allele) in enumerate(all_alleles))
+            self.allele_to_index[None] = -1  # special mask value
             unpadded = pandas.Series(
                 [allele_to_sequence[a] for a in all_alleles],
                 index=all_alleles)
@@ -138,3 +143,52 @@ class AlleleEncoding(object):
             result = vector_encoded[self.indices]
             self.encoding_cache[cache_key] = result
         return self.encoding_cache[cache_key]
+
+
+class MultipleAlleleEncoding(object):
+    def __init__(
+            self,
+            experiment_names,
+            experiment_to_allele_list,
+            max_alleles_per_experiment=6,
+            allele_to_sequence=None,
+            borrow_from=None):
+
+        padded_experiment_to_allele_list = {}
+        for (name, alleles) in experiment_to_allele_list.items():
+            assert len(alleles) > 0
+            assert len(alleles) <= max_alleles_per_experiment
+            alleles_with_mask = alleles + [None] * (
+                    max_alleles_per_experiment - len(alleles))
+            padded_experiment_to_allele_list[name] = alleles_with_mask
+
+        flattened_allele_list = []
+        for name in experiment_names:
+            flattened_allele_list.extend(padded_experiment_to_allele_list[name])
+
+        self.allele_encoding = AlleleEncoding(
+            alleles=flattened_allele_list,
+            allele_to_sequence=allele_to_sequence,
+            borrow_from=borrow_from
+        )
+        self.max_alleles_per_experiment = max_alleles_per_experiment
+
+    @property
+    def indices(self):
+        return self.allele_encoding.indices.values.reshape(
+            (-1, self.max_alleles_per_experiment))
+
+    def compact(self):
+        result = copy(self)
+        result.allele_encoding = self.allele_encoding.compact()
+        return result
+
+    def allele_representations(self, encoding_name):
+        return self.allele_encoding.allele_representations(encoding_name)
+
+    @property
+    def allele_to_sequence(self):
+        return self.allele_encoding.allele_to_sequence
+
+    def fixed_length_vector_encoded_sequences(self, encoding_name):
+        raise NotImplementedError()
diff --git a/mhcflurry/class1_ligandome_predictor.py b/mhcflurry/class1_ligandome_predictor.py
index 8312c9169a88dfec1bd7c9a21c4cc7956ad2f3cc..7eba3357ab359a910f0bb680e60b9f06256da0f2 100644
--- a/mhcflurry/class1_ligandome_predictor.py
+++ b/mhcflurry/class1_ligandome_predictor.py
@@ -1,104 +1,411 @@
+import time
+import collections
+
+import numpy
 
 from .hyperparameters import HyperparameterDefaults
-from .class1_neural_network import Class1NeuralNetwork
+from .class1_neural_network import Class1NeuralNetwork, DEFAULT_PREDICT_BATCH_SIZE
+from .encodable_sequences import EncodableSequences
+
 
 class Class1LigandomePredictor(object):
     network_hyperparameter_defaults = HyperparameterDefaults(
-        retrain_mode="all",
+        allele_amino_acid_encoding="BLOSUM62",
+        peptide_encoding={
+            'vector_encoding_name': 'BLOSUM62',
+            'alignment_method': 'left_pad_centered_right_pad',
+            'max_length': 15,
+        },
+    )
+    """
+    Hyperparameters (and their default values) that affect the neural network
+    architecture.
+    """
+
+    fit_hyperparameter_defaults = HyperparameterDefaults(
+        max_epochs=500,
+        validation_split=0.1,
+        early_stopping=True,
+        minibatch_size=128,
+        random_negative_rate=0.0,
+        random_negative_constant=0,
+    )
+    """
+    Hyperparameters for neural network training.
+    """
+
+    early_stopping_hyperparameter_defaults = HyperparameterDefaults(
+        patience=20,
+        min_delta=0.0,
+    )
+    """
+    Hyperparameters for early stopping.
+    """
+
+    compile_hyperparameter_defaults = HyperparameterDefaults(
+        loss="custom:mse_with_inequalities",
+        optimizer="rmsprop",
+        learning_rate=None,
     )
+    """
+    Loss and optimizer hyperparameters. Any values supported by keras may be
+    used.
+    """
 
-    def __init__(self, class1_affinity_predictor):
-        if not class1_affinity_predictor.pan_allele_models:
+    hyperparameter_defaults = network_hyperparameter_defaults.extend(
+        fit_hyperparameter_defaults).extend(
+        early_stopping_hyperparameter_defaults).extend(
+        compile_hyperparameter_defaults)
+
+    def __init__(
+            self,
+            class1_affinity_predictor,
+            max_ensemble_size=None,
+            **hyperparameters):
+        if not class1_affinity_predictor.class1_pan_allele_models:
             raise NotImplementedError("Pan allele models required")
         if class1_affinity_predictor.allele_to_allele_specific_models:
             raise NotImplementedError("Only pan allele models are supported")
-        self.binding_predictors = class1_affinity_predictor.pan_allele_models
-        self.network = None
 
-        self.network = Class1NeuralNetwork.merge(
-            self.binding_predictors, merge_method="sum")
+        self.hyperparameters = self.hyperparameter_defaults.with_defaults(
+            hyperparameters)
+
+        models = class1_affinity_predictor.class1_pan_allele_models
+        if max_ensemble_size is not None:
+            models = models[:max_ensemble_size]
 
-    def make_network(self):
-        import keras
+        self.network = self.make_network(
+            models,
+            self.hyperparameters)
+
+        self.fit_info = []
+
+    @staticmethod
+    def make_network(pan_allele_class1_neural_networks, hyperparameters):
         import keras.backend as K
-        from keras.layers import Input
+        from keras.layers import Input, TimeDistributed, Lambda, Flatten, RepeatVector, concatenate, Dropout, Reshape, Embedding
+        from keras.activations import sigmoid
         from keras.models import Model
 
-        models = self.binding_predictors
-
-        if len(models) == 1:
-            return models[0]
-        assert len(models) > 1
-
-        result = Class1NeuralNetwork(**dict(models[0].hyperparameters))
-
-        # Remove hyperparameters that are not shared by all models.
-        for model in models:
-            for (key, value) in model.hyperparameters.items():
-                if result.hyperparameters.get(key, value) != value:
-                    del result.hyperparameters[key]
-
-        assert result._network is None
-
-        networks = [model.network() for model in models]
-
-        layer_names = [[layer.name for layer in network.layers] for network in
-            networks]
-
-        pan_allele_layer_names = ['allele', 'peptide', 'allele_representation',
-            'flattened_0', 'allele_flat', 'allele_peptide_merged', 'dense_0',
-            'dropout_0', 'dense_1', 'dropout_1', 'output', ]
-
-        if all(names == pan_allele_layer_names for names in layer_names):
-            # Merging an ensemble of pan-allele architectures
-            network = networks[0]
-            peptide_input = Input(
-                shape=tuple(int(x) for x in K.int_shape(network.inputs[0])[1:]),
-                dtype='float32', name='peptide')
-            allele_input = Input(shape=(1,), dtype='float32', name='allele')
-
-            allele_embedding = network.get_layer("allele_representation")(
-                allele_input)
-            peptide_flat = network.get_layer("flattened_0")(peptide_input)
-            allele_flat = network.get_layer("allele_flat")(allele_embedding)
-            allele_peptide_merged = network.get_layer("allele_peptide_merged")(
-                [peptide_flat, allele_flat])
-
-            sub_networks = []
-            for (i, network) in enumerate(networks):
-                layers = network.layers[
-                pan_allele_layer_names.index("allele_peptide_merged") + 1:]
-                node = allele_peptide_merged
-                for layer in layers:
-                    layer.name += "_%d" % i
-                    node = layer(node)
-                sub_networks.append(node)
-
-            if merge_method == 'average':
-                output = keras.layers.average(sub_networks)
-            elif merge_method == 'sum':
-                output = keras.layers.add(sub_networks)
-            elif merge_method == 'concatenate':
-                output = keras.layers.concatenate(sub_networks)
-            else:
-                raise NotImplementedError("Unsupported merge method",
-                    merge_method)
-
-            result._network = Model(inputs=[peptide_input, allele_input],
-                outputs=[output], name="merged_predictor")
-            result.update_network_description()
-        else:
-            raise NotImplementedError(
-                "Don't know merge_method to merge networks with layer names: ",
-                layer_names)
-        return result
-
-
-    def fit(self, peptides, labels, experiment_names,
-            experiment_name_to_alleles):
-
-
-        pass
-
-    def predict(self, allele_lists, peptides):
-        pass
+        networks = [model.network() for model in pan_allele_class1_neural_networks]
+        merged_ensemble = Class1NeuralNetwork.merge(
+            networks,
+            merge_method="average")
+
+        peptide_shape = tuple(
+            int(x) for x in K.int_shape(merged_ensemble.inputs[0])[1:])
+
+        input_alleles = Input(shape=(6,), name="allele")  # up to 6 alleles
+        input_peptides = Input(
+            shape=peptide_shape,
+            dtype='float32',
+            name='peptide')
+
+        #peptides_broadcasted = Lambda(
+        #    lambda x:
+        #        K.reshape(
+        #            K.repeat(
+        #                K.reshape(x, (-1, numpy.product(peptide_shape))), 6),
+        #         (-1, 6) + peptide_shape)
+        #)(input_peptides)
+
+        peptides_flattened = Flatten()(input_peptides)
+        peptides_repeated = RepeatVector(6)(peptides_flattened)
+
+        allele_representation = Embedding(
+            name="allele_representation",
+            input_dim=64,  # arbitrary, how many alleles to have room for
+            output_dim=1029,
+            input_length=6,
+            trainable=False)(input_alleles)
+
+        allele_flat = Reshape((6, -1))(allele_representation)
+
+        allele_peptide_merged = concatenate([peptides_repeated, allele_flat])
+
+        dense_0 = merged_ensemble.get_layer("dense_0")
+        td_dense0 = TimeDistributed(dense_0, name="td_dense_0")(allele_peptide_merged)
+        td_dense0 = Dropout(0.5)(td_dense0)
+
+        dense_1 = merged_ensemble.get_layer("dense_1")
+        td_dense1 = TimeDistributed(dense_1, name="td_dense_1")(td_dense0)
+        td_dense1 = Dropout(0.5)(td_dense1)
+
+        output = merged_ensemble.get_layer("output")
+        td_output = TimeDistributed(output)(td_dense1)
+
+        network = Model(
+            inputs=[input_peptides, input_alleles],
+            outputs=td_output,
+            name="ligandome",
+        )
+        #print('trainable', network.get_layer("td_dense_0").trainable)
+        network.get_layer("td_dense_0").trainable = False
+        #print('trainable', network.get_layer("td_dense_0").trainable)
+
+        return network
+
+    @staticmethod
+    def loss(y_true, y_pred):
+        """Binary cross entropy after taking logsumexp over predictions"""
+        import keras.backend as K
+        import tensorflow as tf
+        #y_pred_aggregated = K.logsumexp(y_pred, axis=1, keepdims=True)
+        #y_pred_aggregated = K.sigmoid(y_pred_aggregated)
+        #y_pred = tf.Print(y_pred, [y_pred], "y_pred", summarize=20)
+        #y_true = tf.Print(y_true, [y_true], "y_true", summarize=20)
+
+        y_pred_aggregated = K.max(y_pred, axis=1, keepdims=False)
+        #y_pred_aggregated = tf.Print(y_pred_aggregated, [y_pred_aggregated], "y_pred_aggregated",
+        #    summarize=20)
+
+        y_true = K.squeeze(K.cast(y_true, y_pred_aggregated.dtype), axis=-1)
+        #print("SHAPES", y_pred, K.int_shape(y_pred), y_pred_aggregated, K.int_shape(y_pred_aggregated), y_true, K.int_shape(y_true))
+        #K.print_tensor(y_pred_aggregated, "y_pred_aggregated")
+        #K.print_tensor(y_true, "y_true")
+
+        #y_pred_aggregated = K.print_tensor(y_pred_aggregated, "y_pred_aggregated")
+
+
+        #y_true = K.print_tensor(y_true, "y_true")
+
+        #return K.mean(
+        #    K.binary_crossentropy(y_true, y_pred_aggregated),
+        #    axis=-1)
+        return K.mean(
+            (y_true - y_pred_aggregated)**2,
+            axis=-1
+        )
+
+    def peptides_to_network_input(self, peptides):
+        """
+        Encode peptides to the fixed-length encoding expected by the neural
+        network (which depends on the architecture).
+
+        Parameters
+        ----------
+        peptides : EncodableSequences or list of string
+
+        Returns
+        -------
+        numpy.array
+        """
+        encoder = EncodableSequences.create(peptides)
+        encoded = encoder.variable_length_to_fixed_length_vector_encoding(
+            **self.hyperparameters['peptide_encoding'])
+        assert len(encoded) == len(peptides)
+        return encoded
+
+    def allele_encoding_to_network_input(self, allele_encoding):
+        """
+        Encode alleles to the fixed-length encoding expected by the neural
+        network (which depends on the architecture).
+
+        Parameters
+        ----------
+        allele_encoding : AlleleEncoding
+
+        Returns
+        -------
+        (numpy.array, numpy.array)
+
+        Indices and allele representations.
+
+        """
+        return (
+            allele_encoding.indices,
+            allele_encoding.allele_representations(
+                self.hyperparameters['allele_amino_acid_encoding']))
+
+    def fit(
+            self,
+            peptides,
+            labels,
+            allele_encoding,
+            shuffle_permutation=None,
+            verbose=1,
+            progress_callback=None,
+            progress_preamble="",
+            progress_print_interval=5.0):
+
+        import keras.backend as K
+
+        peptides = EncodableSequences.create(peptides)
+        peptide_encoding = self.peptides_to_network_input(peptides)
+
+        # Optional optimization
+        allele_encoding = allele_encoding.compact()
+
+        (allele_encoding_input, allele_representations) = (
+            self.allele_encoding_to_network_input(allele_encoding))
+
+        # Shuffle
+        if shuffle_permutation is None:
+            shuffle_permutation = numpy.random.permutation(len(labels))
+        peptide_encoding = peptide_encoding[shuffle_permutation]
+        allele_encoding_input = allele_encoding_input[shuffle_permutation]
+        labels = labels[shuffle_permutation]
+
+        x_dict = {
+            'peptide': peptide_encoding,
+            'allele': allele_encoding_input,
+        }
+
+        fit_info = collections.defaultdict(list)
+
+        self.set_allele_representations(allele_representations)
+        self.network.compile(
+            loss=self.loss,
+            optimizer=self.hyperparameters['optimizer'])
+        if self.hyperparameters['learning_rate'] is not None:
+            K.set_value(
+                self.network.optimizer.lr,
+                self.hyperparameters['learning_rate'])
+        fit_info["learning_rate"] = float(
+            K.get_value(self.network.optimizer.lr))
+
+        if verbose:
+            self.network.summary()
+
+        min_val_loss_iteration = None
+        min_val_loss = None
+        last_progress_print = 0
+        start = time.time()
+        for i in range(self.hyperparameters['max_epochs']):
+            epoch_start = time.time()
+            fit_history = self.network.fit(
+                x_dict,
+                labels,
+                shuffle=True,
+                batch_size=self.hyperparameters['minibatch_size'],
+                verbose=verbose,
+                epochs=i + 1,
+                initial_epoch=i,
+                validation_split=self.hyperparameters['validation_split'])
+            epoch_time = time.time() - epoch_start
+
+            for (key, value) in fit_history.history.items():
+                fit_info[key].extend(value)
+
+            # Print progress no more often than once every few seconds.
+            if progress_print_interval is not None and (
+                    not last_progress_print or (
+                        time.time() - last_progress_print
+                        > progress_print_interval)):
+                print((progress_preamble + " " +
+                       "Epoch %3d / %3d [%0.2f sec]: loss=%g. "
+                       "Min val loss (%s) at epoch %s" % (
+                           i,
+                           self.hyperparameters['max_epochs'],
+                           epoch_time,
+                           fit_info['loss'][-1],
+                           str(min_val_loss),
+                           min_val_loss_iteration)).strip())
+                last_progress_print = time.time()
+
+            if self.hyperparameters['validation_split']:
+                val_loss = fit_info['val_loss'][-1]
+                if min_val_loss is None or (
+                        val_loss < min_val_loss -
+                        self.hyperparameters['min_delta']):
+                    min_val_loss = val_loss
+                    min_val_loss_iteration = i
+
+                if self.hyperparameters['early_stopping']:
+                    threshold = (
+                        min_val_loss_iteration +
+                        self.hyperparameters['patience'])
+                    if i > threshold:
+                        if progress_print_interval is not None:
+                            print((progress_preamble + " " +
+                                "Stopping at epoch %3d / %3d: loss=%g. "
+                                "Min val loss (%g) at epoch %s" % (
+                                    i,
+                                    self.hyperparameters['max_epochs'],
+                                    fit_info['loss'][-1],
+                                    (
+                                        min_val_loss if min_val_loss is not None
+                                        else numpy.nan),
+                                    min_val_loss_iteration)).strip())
+                        break
+
+            if progress_callback:
+                progress_callback()
+
+        fit_info["time"] = time.time() - start
+        fit_info["num_points"] = len(peptides)
+        self.fit_info.append(dict(fit_info))
+
+    def predict(
+            self,
+            peptides,
+            allele_encoding,
+            batch_size=DEFAULT_PREDICT_BATCH_SIZE):
+        (allele_encoding_input, allele_representations) = (
+                self.allele_encoding_to_network_input(allele_encoding.compact()))
+        self.set_allele_representations(allele_representations)
+        x_dict = {
+            'peptide': self.peptides_to_network_input(peptides),
+            'allele': allele_encoding_input,
+        }
+        predictions = self.network.predict(x_dict, batch_size=batch_size)
+        return numpy.squeeze(predictions, axis=-1)
+
+    #def predict(self):
+
+
+
+    def set_allele_representations(self, allele_representations):
+        """
+        """
+        from keras.models import clone_model
+        import keras.backend as K
+        import tensorflow as tf
+
+        reshaped = allele_representations.reshape(
+            (allele_representations.shape[0], -1))
+        original_model = self.network
+
+        layer = original_model.get_layer("allele_representation")
+        existing_weights_shape = (layer.input_dim, layer.output_dim)
+
+        # Only changes to the number of supported alleles (not the length of
+        # the allele sequences) are allowed.
+        assert existing_weights_shape[1:] == reshaped.shape[1:]
+
+        if existing_weights_shape[0] > reshaped.shape[0]:
+            # Extend with NaNs so we can avoid having to reshape the weights
+            # matrix, which is expensive.
+            reshaped = numpy.append(
+                reshaped,
+                numpy.ones([
+                    existing_weights_shape[0] - reshaped.shape[0],
+                    reshaped.shape[1]
+                ]) * numpy.nan,
+                axis=0)
+
+        if existing_weights_shape != reshaped.shape:
+            print("Performing network surgery", existing_weights_shape, reshaped.shape)
+            # Network surgery required. Make a new network with this layer's
+            # dimensions changed. Kind of a hack.
+            layer.input_dim = reshaped.shape[0]
+            new_model = clone_model(original_model)
+
+            # copy weights for other layers over
+            for layer in new_model.layers:
+                if layer.name != "allele_representation":
+                    layer.set_weights(
+                        original_model.get_layer(name=layer.name).get_weights())
+
+            self.network = new_model
+
+            layer = new_model.get_layer("allele_representation")
+
+            # Disable the old model to catch bugs.
+            def throw(*args, **kwargs):
+                raise RuntimeError("Using a disabled model!")
+            original_model.predict = \
+                original_model.fit = \
+                original_model.fit_generator = throw
+
+        layer.set_weights([reshaped])
diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 447ab47dcf622fcf8e8ee1cc2ddf3a21175d2240..1b5f2eca6770e5a60284cb5e4a87dd007a122c2b 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -82,8 +82,7 @@ class Class1NeuralNetwork(object):
         learning_rate=None,
     )
     """
-    Loss and optimizer hyperparameters. Any values supported by keras may be
-    used.
+    Loss and optimizer hyperparameters.
     """
 
     fit_hyperparameter_defaults = HyperparameterDefaults(
diff --git a/test/test_class1_ligandome_predictor.py b/test/test_class1_ligandome_predictor.py
index 041e9e6862906cd14a38ccd43085dd96ce3a381a..568ff97ef07279c4843243962b7264730e85413c 100644
--- a/test/test_class1_ligandome_predictor.py
+++ b/test/test_class1_ligandome_predictor.py
@@ -19,14 +19,17 @@ import pandas
 import argparse
 import sys
 
-from numpy.testing import assert_, assert_equal
+from numpy.testing import assert_, assert_equal, assert_allclose
 import numpy
 from random import shuffle
 
+from sklearn.metrics import roc_auc_score
+
 from mhcflurry import Class1AffinityPredictor,Class1NeuralNetwork
-from mhcflurry.allele_encoding import AlleleEncoding
+from mhcflurry.allele_encoding import MultipleAlleleEncoding
 from mhcflurry.class1_ligandome_predictor import Class1LigandomePredictor
 from mhcflurry.downloads import get_path
+from mhcflurry.regression_target import from_ic50
 
 from mhcflurry.testing_utils import cleanup, startup
 from mhcflurry.amino_acid import COMMON_AMINO_ACIDS
@@ -44,7 +47,10 @@ def setup():
     global PAN_ALLELE_MOTIFS_NO_MASS_SPEC_DF
     startup()
     PAN_ALLELE_PREDICTOR_NO_MASS_SPEC = Class1AffinityPredictor.load(
-            get_path("models_class1_pan", "models.no_mass_spec"))
+        get_path("models_class1_pan", "models.no_mass_spec"),
+        optimization_level=0,
+        max_models=1)
+
     PAN_ALLELE_MOTIFS_WITH_MASS_SPEC_DF = pandas.read_csv(
         get_path(
             "models_class1_pan",
@@ -142,26 +148,66 @@ def test_synthetic_allele_refinement():
 
     train_df = pandas.concat([hits_df, decoys_df], ignore_index=True)
 
-    predictor = Class1LigandomePredictor(PAN_ALLELE_PREDICTOR_NO_MASS_SPEC)
+    predictor = Class1LigandomePredictor(
+        PAN_ALLELE_PREDICTOR_NO_MASS_SPEC,
+        max_ensemble_size=1,
+        max_epochs=100,
+        patience=5)
+
+    allele_encoding = MultipleAlleleEncoding(
+        experiment_names=["experiment1"] * len(train_df),
+        experiment_to_allele_list={
+            "experiment1": alleles,
+        },
+        max_alleles_per_experiment=6,
+        allele_to_sequence=PAN_ALLELE_PREDICTOR_NO_MASS_SPEC.allele_to_sequence,
+    ).compact()
+
+    pre_predictions = predictor.predict(
+        peptides=train_df.peptide.values,
+        allele_encoding=allele_encoding)
+
+    (model,) = PAN_ALLELE_PREDICTOR_NO_MASS_SPEC.class1_pan_allele_models
+    expected_pre_predictions = from_ic50(
+        model.predict(
+            peptides=numpy.repeat(train_df.peptide.values, len(alleles)),
+            allele_encoding=allele_encoding.allele_encoding,
+    )).reshape((-1, len(alleles)))
+
+    train_df["pre_max_prediction"] = pre_predictions.max(1)
+    pre_auc = roc_auc_score(train_df.hit.values, train_df.pre_max_prediction.values)
+    print("PRE_AUC", pre_auc)
+
+    #import ipdb ; ipdb.set_trace()
+
+    assert_allclose(pre_predictions, expected_pre_predictions)
+
     predictor.fit(
         peptides=train_df.peptide.values,
         labels=train_df.hit.values,
-        experiment_names=["experiment1"] * len(train_df),
-        experiment_name_to_alleles={
-            "experiment1": alleles,
-        }
+        allele_encoding=allele_encoding
     )
 
     predictions = predictor.predict(
         peptides=train_df.peptide.values,
-        alleles=alleles,
-        output_format="concatenate"
+        allele_encoding=allele_encoding,
     )
 
+    train_df["max_prediction"] = predictions.max(1)
+    train_df["predicted_allele"] = pandas.Series(alleles).loc[
+        predictions.argmax(1).flatten()
+    ].values
+
     print(predictions)
 
+    auc = roc_auc_score(train_df.hit.values, train_df.max_prediction.values)
+    print("AUC", auc)
+
     import ipdb ; ipdb.set_trace()
 
+    return predictions
+
+
 
 """
 def test_simple_synethetic(