From 67987dafef56d559f028236f9a33a9f4e9ad0d9e Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Fri, 6 Dec 2019 09:37:32 -0500
Subject: [PATCH] better tests

---
 mhcflurry/custom_loss.py                      |   5 +-
 ...test_class1_presentation_neural_network.py | 430 +++++++++
 test/test_class1_presentation_predictor.py    | 840 +-----------------
 test/test_custom_loss.py                      |  74 +-
 4 files changed, 508 insertions(+), 841 deletions(-)
 create mode 100644 test/test_class1_presentation_neural_network.py

diff --git a/mhcflurry/custom_loss.py b/mhcflurry/custom_loss.py
index 3660ae32..6cec8a8b 100644
--- a/mhcflurry/custom_loss.py
+++ b/mhcflurry/custom_loss.py
@@ -164,7 +164,8 @@ class MSEWithInequalities(Loss):
         # We always delay import of Keras so that mhcflurry can be imported
         # initially without tensorflow debug output, etc.
         from keras import backend as K
-        y_true = K.squeeze(y_true, axis=-1)
+        y_true = K.flatten(y_true)
+        y_pred = K.flatten(y_pred)
 
         # Handle (=) inequalities
         diff1 = y_pred - y_true
@@ -286,7 +287,7 @@ class MultiallelicMassSpecLoss(Loss):
 
     def loss(self, y_true, y_pred):
         import tensorflow as tf
-        y_true = tf.squeeze(y_true, axis=-1)
+        y_true = tf.reshape(y_true, (-1,))
         pos = tf.boolean_mask(y_pred, tf.math.equal(y_true, 1.0))
         pos_max = tf.reduce_max(pos, axis=1)
         neg = tf.boolean_mask(y_pred, tf.math.equal(y_true, 0.0))
diff --git a/test/test_class1_presentation_neural_network.py b/test/test_class1_presentation_neural_network.py
new file mode 100644
index 00000000..70823bce
--- /dev/null
+++ b/test/test_class1_presentation_neural_network.py
@@ -0,0 +1,430 @@
+import logging
+logging.getLogger('tensorflow').disabled = True
+logging.getLogger('matplotlib').disabled = True
+
+import pandas
+import argparse
+import sys
+import copy
+import os
+import tempfile
+import pickle
+
+from numpy.testing import assert_, assert_equal, assert_allclose, assert_array_equal
+from nose.tools import assert_greater, assert_less
+import numpy
+from random import shuffle
+
+from sklearn.metrics import roc_auc_score
+
+from mhcflurry import Class1AffinityPredictor
+from mhcflurry.allele_encoding import MultipleAlleleEncoding
+from mhcflurry.class1_presentation_neural_network import Class1PresentationNeuralNetwork
+from mhcflurry.class1_presentation_predictor import Class1PresentationPredictor
+from mhcflurry.encodable_sequences import EncodableSequences
+from mhcflurry.downloads import get_path
+from mhcflurry.regression_target import from_ic50
+from mhcflurry.common import random_peptides, positional_frequency_matrix
+from mhcflurry.testing_utils import cleanup, startup
+from mhcflurry.amino_acid import COMMON_AMINO_ACIDS
+from mhcflurry.custom_loss import MultiallelicMassSpecLoss
+from mhcflurry.regression_target import to_ic50
+
+
+###################################################
+# SETUP
+###################################################
+
+COMMON_AMINO_ACIDS = sorted(COMMON_AMINO_ACIDS)  # rebind the import as a sorted list
+
+PAN_ALLELE_PREDICTOR_NO_MASS_SPEC = None  # assigned in setup(), reset in teardown()
+
+def setup():  # loads the shared no-mass-spec affinity predictor into the global
+    global PAN_ALLELE_PREDICTOR_NO_MASS_SPEC
+    startup()
+    PAN_ALLELE_PREDICTOR_NO_MASS_SPEC = Class1AffinityPredictor.load(
+        get_path("models_class1_pan", "models.no_mass_spec"),
+        optimization_level=0,
+        max_models=1)  # the refinement test unpacks a single pan-allele model
+
+
+def teardown():  # counterpart of setup()
+    global PAN_ALLELE_PREDICTOR_NO_MASS_SPEC
+    PAN_ALLELE_PREDICTOR_NO_MASS_SPEC = None  # drop the module-level reference
+    cleanup()
+
+
+def data_path(name):
+    '''
+    Return the path of `name` inside the "data" directory next to this file.
+    Built from dirname(__file__), so it is absolute only if __file__ is.
+    '''
+    return os.path.join(os.path.dirname(__file__), "data", name)
+
+
+###################################################
+# UTILITY FUNCTIONS
+###################################################
+
+def scramble_peptide(peptide):  # same characters as `peptide`, randomly reordered
+    lst = list(peptide)
+    shuffle(lst)  # random.shuffle reorders the list in place
+    return "".join(lst)
+
+
+def make_motif(presentation_predictor, allele, peptides, frac=0.01):  # motif of top-scoring peptides
+    peptides = EncodableSequences.create(peptides)
+    predictions = presentation_predictor.predict(
+        peptides=peptides,
+        alleles=[allele],
+    )
+    random_predictions_df = pandas.DataFrame({"peptide": peptides.sequences})
+    random_predictions_df["prediction"] = predictions
+    random_predictions_df = random_predictions_df.sort_values(
+        "prediction", ascending=False)  # highest predictions first
+    top = random_predictions_df.iloc[:int(len(random_predictions_df) * frac)]  # top `frac` of rows
+    matrix = positional_frequency_matrix(top.peptide.values)
+    return matrix
+
+
+###################################################
+# TESTS
+###################################################
+
+def test_synthetic_allele_refinement_with_affinity_data():
+    test_synthetic_allele_refinement(include_affinities=True)  # same scenario plus affinity rows
+
+
+def test_synthetic_allele_refinement(max_epochs=10, include_affinities=False):
+    """
+    Refine a no-mass-spec pan-allele model on synthetic multi-allelic MS data.
+
+    Idea:
+    - take an allele where MS vs. no-MS trained predictors are very different.
+      One possibility is DLA-88*501:01 but human would be better
+    - generate synthetic multi-allele MS by combining single-allele MS for
+      different alleles, including the selected allele
+    - train presentation predictor based on the no-ms pan-allele models on
+      this synthetic dataset
+    - see if the pan-allele predictor learns the "correct" motif for the
+      selected allele, i.e. updates to become more similar to the with-ms
+      pan allele predictor.
+
+    For now only the pre-fit predictions are asserted; AUCs/motifs are printed.
+    """
+    refine_allele = "HLA-C*01:02"
+    alleles = ["HLA-A*02:01", "HLA-B*27:01", "HLA-C*07:01", "HLA-A*03:01",
+        "HLA-B*15:01", refine_allele]
+    peptides_per_allele = [2000, 1000, 500, 1500, 1200, 800, ]
+
+    allele_to_peptides = dict(zip(alleles, peptides_per_allele))
+
+    length = 9
+
+    train_with_ms = pandas.read_csv(get_path("data_curated",
+        "curated_training_data.with_mass_spec.csv.bz2"))
+    train_no_ms = pandas.read_csv(
+        get_path("data_curated", "curated_training_data.no_mass_spec.csv.bz2"))
+
+    def filter_df(df):
+        return df.loc[
+            (df.allele.isin(alleles)) & (df.peptide.str.len() == length)]
+
+    train_with_ms = filter_df(train_with_ms)
+    train_no_ms = filter_df(train_no_ms)
+
+    ms_specific = train_with_ms.loc[
+        ~train_with_ms.peptide.isin(train_no_ms.peptide)]  # MS-only peptides
+
+    train_peptides = []
+    train_true_alleles = []
+    for allele in alleles:
+        peptides = ms_specific.loc[ms_specific.allele == allele].peptide.sample(
+            n=allele_to_peptides[allele])
+        train_peptides.extend(peptides)
+        train_true_alleles.extend([allele] * len(peptides))
+
+    hits_df = pandas.DataFrame({"peptide": train_peptides})
+    hits_df["true_allele"] = train_true_alleles
+    hits_df["hit"] = 1.0
+
+    decoys_df = hits_df.copy()
+    decoys_df["peptide"] = decoys_df.peptide.map(scramble_peptide)  # shuffled hits
+    decoys_df["true_allele"] = ""
+    decoys_df["hit"] = 0.0
+
+    mms_train_df = pandas.concat([hits_df, decoys_df], ignore_index=True)
+    mms_train_df["label"] = mms_train_df.hit
+    mms_train_df["is_affinity"] = False
+    mms_train_df["measurement_inequality"] = None
+
+    if include_affinities:
+        affinity_train_df = pandas.read_csv(get_path("models_class1_pan",
+            "models.with_mass_spec/train_data.csv.bz2"))
+        affinity_train_df = affinity_train_df.loc[
+            affinity_train_df.allele.isin(alleles), ["peptide", "allele",
+                "measurement_inequality", "measurement_value"]]
+
+        affinity_train_df["label"] = affinity_train_df["measurement_value"]
+        del affinity_train_df["measurement_value"]
+        affinity_train_df["is_affinity"] = True
+    else:
+        affinity_train_df = None
+
+    # setup() loads the predictor with max_models=1, so one model is expected.
+    (affinity_model,) = PAN_ALLELE_PREDICTOR_NO_MASS_SPEC.class1_pan_allele_models
+    presentation_model = Class1PresentationNeuralNetwork(
+        auxiliary_input_features=["gene"], batch_generator_batch_size=1024,
+        max_epochs=max_epochs, learning_rate=0.001, patience=5, min_delta=0.0,
+        random_negative_rate=1.0, random_negative_constant=25)
+    presentation_model.load_from_class1_neural_network(affinity_model)
+
+    presentation_predictor = Class1PresentationPredictor(
+        models=[presentation_model],
+        allele_to_sequence=PAN_ALLELE_PREDICTOR_NO_MASS_SPEC.allele_to_sequence)
+
+    mms_allele_encoding = MultipleAlleleEncoding(
+        experiment_names=["experiment1"] * len(mms_train_df),
+        experiment_to_allele_list={
+            "experiment1": alleles,
+        }, max_alleles_per_experiment=6,
+        allele_to_sequence=PAN_ALLELE_PREDICTOR_NO_MASS_SPEC.allele_to_sequence, )
+    allele_encoding = copy.deepcopy(mms_allele_encoding)
+    if affinity_train_df is not None:
+        allele_encoding.append_alleles(affinity_train_df.allele.values)
+        train_df = pandas.concat([mms_train_df, affinity_train_df],
+            ignore_index=True, sort=False)
+    else:
+        train_df = mms_train_df
+
+    allele_encoding = allele_encoding.compact()
+
+    pre_predictions = presentation_model.predict(
+        peptides=mms_train_df.peptide.values,
+        allele_encoding=mms_allele_encoding).score
+
+    expected_pre_predictions = from_ic50(affinity_model.predict(
+        peptides=numpy.repeat(mms_train_df.peptide.values, len(alleles)),
+        allele_encoding=mms_allele_encoding.allele_encoding, )).reshape(
+        (-1, len(alleles)))
+    mms_train_df["pre_max_prediction"] = pre_predictions.max(1)
+    pre_auc = roc_auc_score(
+        mms_train_df.hit.values,
+        mms_train_df.pre_max_prediction.values)
+    print("PRE_AUC", pre_auc)
+
+    assert_allclose(pre_predictions, expected_pre_predictions, rtol=1e-4)  # pre-fit check
+
+    random_peptides_encodable = EncodableSequences.create(
+        random_peptides(10000, 9))
+
+    original_motif = make_motif(
+        presentation_predictor=presentation_predictor,
+        peptides=random_peptides_encodable,
+        allele=refine_allele)
+    print("Original motif proline-3 rate: ", original_motif.loc[3, "P"])
+
+    metric_rows = []
+
+    # NOTE(review): unused; `predictor` and `update_motifs` are undefined here.
+    def progress():
+        (_, presentation_prediction, affinities_predictions) = (
+            predictor.predict(
+                output="all",
+                peptides=mms_train_df.peptide.values,
+                alleles=mms_allele_encoding))
+        affinities_predictions = from_ic50(affinities_predictions)
+        for (kind, predictions) in [
+                ("affinities", affinities_predictions),
+                ("presentation", presentation_prediction)]:
+            mms_train_df["max_prediction"] = predictions.max(1)
+            mms_train_df["predicted_allele"] = pandas.Series(alleles).loc[
+                predictions.argmax(1).flatten()
+            ].values
+
+            print(kind)
+            print(predictions)
+
+            mean_predictions_for_hit = mms_train_df.loc[
+                mms_train_df.hit == 1.0
+            ].max_prediction.mean()
+            mean_predictions_for_decoy = mms_train_df.loc[
+                mms_train_df.hit == 0.0
+            ].max_prediction.mean()
+            correct_allele_fraction = (
+                    mms_train_df.loc[mms_train_df.hit == 1.0].predicted_allele ==
+                    mms_train_df.loc[mms_train_df.hit == 1.0].true_allele
+            ).mean()
+            auc = roc_auc_score(mms_train_df.hit.values, mms_train_df.max_prediction.values)
+
+            print(kind, "Mean prediction for hit", mean_predictions_for_hit)
+            print(kind, "Mean prediction for decoy", mean_predictions_for_decoy)
+            print(kind, "Correct predicted allele fraction", correct_allele_fraction)
+            print(kind, "AUC", auc)
+
+            metric_rows.append((
+                kind,
+                mean_predictions_for_hit,
+                mean_predictions_for_decoy,
+                correct_allele_fraction,
+                auc,
+            ))
+
+            update_motifs()
+
+        return (presentation_prediction, auc)
+
+
+    print("Pre fitting:")
+    #progress()
+
+    presentation_model.fit(peptides=train_df.peptide.values,
+        labels=train_df.label.values,
+        inequalities=train_df.measurement_inequality.values,
+        affinities_mask=train_df.is_affinity.values,
+        allele_encoding=allele_encoding, )
+    post_predictions = presentation_model.predict(
+        peptides=mms_train_df.peptide.values,
+        allele_encoding=mms_allele_encoding).score
+    mms_train_df["post_max_prediction"] = post_predictions.max(1)
+    post_auc = roc_auc_score(
+        mms_train_df.hit.values,
+        mms_train_df.post_max_prediction.values)
+    print("POST_AUC", post_auc)
+
+    final_motif = make_motif(
+        presentation_predictor=presentation_predictor,
+        peptides=random_peptides_encodable,
+        allele=refine_allele)
+    print("Final motif proline-3 rate: ", final_motif.loc[3, "P"])
+    # NOTE(review): nothing is asserted on post_auc or final_motif yet.
+
+    # (predictions, final_auc) = progress()
+    # print("Final AUC", final_auc)
+
+    """
+    update_motifs()
+
+    metrics = pandas.DataFrame(
+        metric_rows,
+        columns=[
+            "output",
+            "mean_predictions_for_hit",
+            "mean_predictions_for_decoy",
+            "correct_allele_fraction",
+            "auc"
+        ])
+    """
+
+
+def Xtest_real_data_multiallelic_refinement(max_epochs=10):  # NOTE(review): presumably disabled by the "X" prefix
+    ms_df = pandas.read_csv(
+        get_path("data_mass_spec_annotated", "annotated_ms.csv.bz2"))
+    ms_df = ms_df.loc[
+        (ms_df.mhc_class == "I") & (~ms_df.protein_ensembl.isnull())].copy()
+
+    sample_table = ms_df.drop_duplicates(
+        "sample_id").set_index("sample_id").loc[ms_df.sample_id.unique()]
+    grouped = ms_df.groupby("sample_id").nunique()
+    for col in sample_table.columns:
+        if (grouped[col] > 1).any():
+            del sample_table[col]  # column takes several values within some sample
+    sample_table["alleles"] = sample_table.hla.str.split()
+
+    multi_train_hit_df = ms_df.loc[
+        ms_df.sample_id  == "RA957"
+    ].drop_duplicates("peptide")[["peptide", "sample_id"]].reset_index(drop=True)
+    multi_train_hit_df["label"] = 1.0
+
+    multi_train_decoy_df = ms_df.loc[
+        (ms_df.sample_id  == "CD165") &
+        (~ms_df.peptide.isin(multi_train_hit_df.peptide.unique()))
+    ].drop_duplicates("peptide")[["peptide"]]  # decoys: other sample's peptides not among the hits
+    (multi_train_decoy_df["sample_id"],) = multi_train_hit_df.sample_id.unique()  # needs exactly one id
+    multi_train_decoy_df["label"] = 0.0
+
+    multi_train_df = pandas.concat(
+        [multi_train_hit_df, multi_train_decoy_df], ignore_index=True)
+    multi_train_df["is_affinity"] = False
+
+    multi_train_alleles = set()
+    for alleles in sample_table.loc[multi_train_df.sample_id.unique()].alleles:  # `alleles` is reused below
+        multi_train_alleles.update(alleles)
+    multi_train_alleles = sorted(multi_train_alleles)
+
+    pan_train_df = pandas.read_csv(
+        get_path(
+            "models_class1_pan", "models.with_mass_spec/train_data.csv.bz2"))
+
+    pan_sub_train_df = pan_train_df.loc[
+        pan_train_df.allele.isin(multi_train_alleles),
+        ["peptide", "allele", "measurement_inequality", "measurement_value"]
+    ]
+    pan_sub_train_df["label"] = pan_sub_train_df["measurement_value"]
+    del pan_sub_train_df["measurement_value"]
+    pan_sub_train_df["is_affinity"] = True
+
+    pan_predictor = Class1AffinityPredictor.load(
+        get_path("models_class1_pan", "models.with_mass_spec"),
+        optimization_level=0,
+        max_models=1)
+
+    allele_encoding = MultipleAlleleEncoding(
+        experiment_names=multi_train_df.sample_id.values,
+        experiment_to_allele_list=sample_table.alleles.to_dict(),
+        max_alleles_per_experiment=sample_table.alleles.str.len().max(),
+        allele_to_sequence=pan_predictor.allele_to_sequence,
+    )
+    allele_encoding.append_alleles(pan_sub_train_df.allele.values)
+    allele_encoding =  allele_encoding.compact()
+
+    combined_train_df = pandas.concat([multi_train_df, pan_sub_train_df])
+
+    presentation_predictor = Class1PresentationNeuralNetwork(  # NOTE(review): args differ from the call in test_synthetic_allele_refinement; confirm API
+        pan_predictor,
+        auxiliary_input_features=[],
+        max_ensemble_size=1,
+        max_epochs=max_epochs,
+        learning_rate=0.0001,
+        patience=5,
+        min_delta=0.0,
+        random_negative_rate=1.0)
+
+    pre_predictions = from_ic50(
+        presentation_predictor.predict(
+            output="affinities",
+            peptides=combined_train_df.peptide.values,
+            alleles=allele_encoding))
+
+    (model,) = pan_predictor.class1_pan_allele_models
+    expected_pre_predictions = from_ic50(
+        model.predict(
+            peptides=numpy.repeat(combined_train_df.peptide.values, len(alleles)),  # NOTE(review): `alleles` is the leftover loop variable
+            allele_encoding=allele_encoding.allele_encoding,
+    )).reshape((-1, len(alleles)))[:,0]
+
+    assert_allclose(pre_predictions, expected_pre_predictions, rtol=1e-4)
+
+    motifs_history = []
+    random_peptides_encodable = make_random_peptides(10000, [9])  # NOTE(review): make_random_peptides is not defined in this file
+
+
+    def update_motifs():
+        for allele in multi_train_alleles:
+            motif = make_motif(allele, random_peptides_encodable)  # NOTE(review): make_motif here takes (presentation_predictor, allele, peptides)
+            motifs_history.append((allele, motif))
+
+    print("Pre fitting:")
+    update_motifs()
+    print("Fitting...")
+
+    presentation_predictor.fit(
+        peptides=combined_train_df.peptide.values,
+        labels=combined_train_df.label.values,
+        allele_encoding=allele_encoding,
+        affinities_mask=combined_train_df.is_affinity.values,
+        inequalities=combined_train_df.measurement_inequality.values,
+        progress_callback=update_motifs,
+    )
+
+    import ipdb ; ipdb.set_trace()  # NOTE(review): debugger breakpoint left in; remove before enabling
\ No newline at end of file
diff --git a/test/test_class1_presentation_predictor.py b/test/test_class1_presentation_predictor.py
index 37837571..80fa1f43 100644
--- a/test/test_class1_presentation_predictor.py
+++ b/test/test_class1_presentation_predictor.py
@@ -1,97 +1,42 @@
-"""
-
-Idea:
-
-- take an allele where MS vs. no-MS trained predictors are very different. One
-    possiblility is DLA-88*501:01 but human would be better
-- generate synethetic multi-allele MS by combining single-allele MS for differnet
-   alleles, including the selected allele
-- train presentation predictor based on the no-ms pan-allele models on theis
-  synthetic dataset
-- see if the pan-allele predictor learns the "correct" motif for the selected
-  allele, i.e. updates to become more similar to the with-ms pan allele predictor.
-
-
-"""
-
 import logging
 logging.getLogger('tensorflow').disabled = True
 logging.getLogger('matplotlib').disabled = True
 
 import pandas
-import argparse
-import sys
-import copy
-import os
 import tempfile
 import pickle
 
 from numpy.testing import assert_, assert_equal, assert_allclose, assert_array_equal
 from nose.tools import assert_greater, assert_less
 import numpy
-from random import shuffle
-
-from sklearn.metrics import roc_auc_score
 
 from mhcflurry import Class1AffinityPredictor
 from mhcflurry.allele_encoding import MultipleAlleleEncoding
 from mhcflurry.class1_presentation_neural_network import Class1PresentationNeuralNetwork
 from mhcflurry.class1_presentation_predictor import Class1PresentationPredictor
-from mhcflurry.encodable_sequences import EncodableSequences
 from mhcflurry.downloads import get_path
-from mhcflurry.regression_target import from_ic50
-from mhcflurry.common import random_peptides, positional_frequency_matrix
+from mhcflurry.common import random_peptides
 from mhcflurry.testing_utils import cleanup, startup
-from mhcflurry.amino_acid import COMMON_AMINO_ACIDS
-from mhcflurry.custom_loss import MultiallelicMassSpecLoss
 from mhcflurry.regression_target import to_ic50
 
-COMMON_AMINO_ACIDS = sorted(COMMON_AMINO_ACIDS)
-
 PAN_ALLELE_PREDICTOR_NO_MASS_SPEC = None
-PAN_ALLELE_MOTIFS_WITH_MASS_SPEC_DF = None
-PAN_ALLELE_MOTIFS_NO_MASS_SPEC_DF = None
-
-def data_path(name):
-    '''
-    Return the absolute path to a file in the test/data directory.
-    The name specified should be relative to test/data.
-    '''
-    return os.path.join(os.path.dirname(__file__), "data", name)
-
 
 def setup():
     global PAN_ALLELE_PREDICTOR_NO_MASS_SPEC
-    global PAN_ALLELE_MOTIFS_WITH_MASS_SPEC_DF
-    global PAN_ALLELE_MOTIFS_NO_MASS_SPEC_DF
     startup()
     PAN_ALLELE_PREDICTOR_NO_MASS_SPEC = Class1AffinityPredictor.load(
         get_path("models_class1_pan", "models.no_mass_spec"),
         optimization_level=0,
         max_models=1)
 
-    PAN_ALLELE_MOTIFS_WITH_MASS_SPEC_DF = pandas.read_csv(
-        get_path(
-            "models_class1_pan",
-            "models.with_mass_spec/frequency_matrices.csv.bz2"))
-    PAN_ALLELE_MOTIFS_NO_MASS_SPEC_DF = pandas.read_csv(
-        get_path(
-            "models_class1_pan",
-            "models.no_mass_spec/frequency_matrices.csv.bz2"))
-
 
 def teardown():
     global PAN_ALLELE_PREDICTOR_NO_MASS_SPEC
-    global PAN_ALLELE_MOTIFS_WITH_MASS_SPEC_DF
-    global PAN_ALLELE_MOTIFS_NO_MASS_SPEC_DF
-
     PAN_ALLELE_PREDICTOR_NO_MASS_SPEC = None
-    PAN_ALLELE_MOTIFS_WITH_MASS_SPEC_DF = None
-    PAN_ALLELE_MOTIFS_NO_MASS_SPEC_DF = None
     cleanup()
 
 
-def Xtest_basic():
+def test_basic():
     affinity_predictor = PAN_ALLELE_PREDICTOR_NO_MASS_SPEC
     models = []
     for affinity_network in affinity_predictor.class1_pan_allele_models:
@@ -222,784 +167,3 @@ def Xtest_basic():
     print(median_affinity_change)
     assert_less(median_affinity_change[1.0], median_affinity_change[0.0])
 
-
-def scramble_peptide(peptide):
-    lst = list(peptide)
-    shuffle(lst)
-    return "".join(lst)
-
-
-def evaluate_loss(loss, y_true, y_pred):
-    import keras.backend as K
-
-    y_true = numpy.array(y_true)
-    y_pred = numpy.array(y_pred)
-    if y_pred.ndim == 1:
-        y_pred = y_pred.reshape((len(y_pred), 1))
-    if y_true.ndim == 1:
-        y_true = y_true.reshape((len(y_true), 1))
-
-    if K.backend() == "tensorflow":
-        session = K.get_session()
-        y_true_var = K.constant(y_true, name="y_true")
-        y_pred_var = K.constant(y_pred, name="y_pred")
-        result = loss(y_true_var, y_pred_var)
-        return result.eval(session=session)
-    elif K.backend() == "theano":
-        y_true_var = K.constant(y_true, name="y_true")
-        y_pred_var = K.constant(y_pred, name="y_pred")
-        result = loss(y_true_var, y_pred_var)
-        return result.eval()
-    else:
-        raise ValueError("Unsupported backend: %s" % K.backend())
-
-
-def Xtest_loss():
-    for delta in [0.0, 0.3]:
-        print("delta", delta)
-        # Hit labels
-        y_true = [
-            1.0,
-            0.0,
-            1.0,
-            -1.0,  # ignored
-            1.0,
-            0.0,
-            1.0,
-        ]
-        y_true = numpy.array(y_true)
-        y_pred = [
-            [0.3, 0.7, 0.5],
-            [0.2, 0.4, 0.6],
-            [0.1, 0.5, 0.3],
-            [0.9, 0.1, 0.2],
-            [0.1, 0.7, 0.1],
-            [0.8, 0.2, 0.4],
-            [0.1, 0.2, 0.4],
-        ]
-        y_pred = numpy.array(y_pred)
-
-        # reference implementation 1
-
-        def smooth_max(x, alpha):
-            x = numpy.array(x)
-            alpha = numpy.array([alpha])
-            return (x * numpy.exp(x * alpha)).sum() / (
-                numpy.exp(x * alpha)).sum()
-
-        contributions = []
-        for i in range(len(y_true)):
-            if y_true[i] == 1.0:
-                for j in range(len(y_true)):
-                    if y_true[j] == 0.0:
-                        tightest_i = max(y_pred[i])
-                        for k in range(y_pred.shape[1]):
-                            contribution = max(
-                                0, y_pred[j, k] - tightest_i + delta)**2
-                            contributions.append(contribution)
-        contributions = numpy.array(contributions)
-        expected1 = contributions.sum() / len(contributions)
-
-        # reference implementation 2: numpy
-        pos = numpy.array([
-            max(y_pred[i])
-            for i in range(len(y_pred))
-            if y_true[i] == 1.0
-        ])
-
-        neg = y_pred[(y_true == 0.0).astype(bool)]
-        term = neg.reshape((-1, 1)) - pos + delta
-        expected2 = (
-                numpy.maximum(0, term)**2).sum() / (
-            len(pos) * neg.shape[0] * neg.shape[1])
-
-        numpy.testing.assert_almost_equal(expected1, expected2)
-
-        computed = evaluate_loss(
-            MultiallelicMassSpecLoss(delta=delta).loss,
-            y_true,
-            y_pred.reshape(y_pred.shape))
-
-        numpy.testing.assert_almost_equal(computed, expected1, 4)
-
-
-AA_DIST = pandas.Series(
-    dict((line.split()[0], float(line.split()[1])) for line in """
-A    0.071732
-E    0.060102
-N    0.034679
-D    0.039601
-T    0.055313
-L    0.115337
-V    0.070498
-S    0.071882
-Q    0.040436
-F    0.050178
-G    0.053176
-C    0.005429
-H    0.025487
-I    0.056312
-W    0.013593
-K    0.057832
-M    0.021079
-Y    0.043372
-R    0.060330
-P    0.053632
-""".strip().split("\n")))
-print(AA_DIST)
-
-
-def make_random_peptides(num_peptides_per_length=10000, lengths=[9]):
-    peptides = []
-    for length in lengths:
-        peptides.extend(
-            random_peptides
-                (num_peptides_per_length, length=length, distribution=AA_DIST))
-    return EncodableSequences.create(peptides)
-
-
-def make_motif(allele, peptides, frac=0.01):
-    peptides = EncodableSequences.create(peptides)
-    predictions = PAN_ALLELE_PREDICTOR_NO_MASS_SPEC.predict(
-        peptides=peptides,
-        allele=allele,
-    )
-    random_predictions_df = pandas.DataFrame({"peptide": peptides.sequences})
-    random_predictions_df["prediction"] = predictions
-    random_predictions_df = random_predictions_df.sort_values(
-        "prediction", ascending=True)
-    top = random_predictions_df.iloc[:int(len(random_predictions_df) * frac)]
-    matrix = positional_frequency_matrix(top.peptide.values)
-    return matrix
-
-
-def Xtest_real_data_multiallelic_refinement(max_epochs=10):
-    ms_df = pandas.read_csv(
-        get_path("data_mass_spec_annotated", "annotated_ms.csv.bz2"))
-    ms_df = ms_df.loc[
-        (ms_df.mhc_class == "I") & (~ms_df.protein_ensembl.isnull())].copy()
-
-    sample_table = ms_df.drop_duplicates(
-        "sample_id").set_index("sample_id").loc[ms_df.sample_id.unique()]
-    grouped = ms_df.groupby("sample_id").nunique()
-    for col in sample_table.columns:
-        if (grouped[col] > 1).any():
-            del sample_table[col]
-    sample_table["alleles"] = sample_table.hla.str.split()
-
-    multi_train_hit_df = ms_df.loc[
-        ms_df.sample_id  == "RA957"
-    ].drop_duplicates("peptide")[["peptide", "sample_id"]].reset_index(drop=True)
-    multi_train_hit_df["label"] = 1.0
-
-    multi_train_decoy_df = ms_df.loc[
-        (ms_df.sample_id  == "CD165") &
-        (~ms_df.peptide.isin(multi_train_hit_df.peptide.unique()))
-    ].drop_duplicates("peptide")[["peptide"]]
-    (multi_train_decoy_df["sample_id"],) = multi_train_hit_df.sample_id.unique()
-    multi_train_decoy_df["label"] = 0.0
-
-    multi_train_df = pandas.concat(
-        [multi_train_hit_df, multi_train_decoy_df], ignore_index=True)
-    multi_train_df["is_affinity"] = False
-
-    multi_train_alleles = set()
-    for alleles in sample_table.loc[multi_train_df.sample_id.unique()].alleles:
-        multi_train_alleles.update(alleles)
-    multi_train_alleles = sorted(multi_train_alleles)
-
-    pan_train_df = pandas.read_csv(
-        get_path(
-            "models_class1_pan", "models.with_mass_spec/train_data.csv.bz2"))
-
-    pan_sub_train_df = pan_train_df.loc[
-        pan_train_df.allele.isin(multi_train_alleles),
-        ["peptide", "allele", "measurement_inequality", "measurement_value"]
-    ]
-    pan_sub_train_df["label"] = pan_sub_train_df["measurement_value"]
-    del pan_sub_train_df["measurement_value"]
-    pan_sub_train_df["is_affinity"] = True
-
-    pan_predictor = Class1AffinityPredictor.load(
-        get_path("models_class1_pan", "models.with_mass_spec"),
-        optimization_level=0,
-        max_models=1)
-
-    allele_encoding = MultipleAlleleEncoding(
-        experiment_names=multi_train_df.sample_id.values,
-        experiment_to_allele_list=sample_table.alleles.to_dict(),
-        max_alleles_per_experiment=sample_table.alleles.str.len().max(),
-        allele_to_sequence=pan_predictor.allele_to_sequence,
-    )
-    allele_encoding.append_alleles(pan_sub_train_df.allele.values)
-    allele_encoding =  allele_encoding.compact()
-
-    combined_train_df = pandas.concat([multi_train_df, pan_sub_train_df])
-
-    presentation_predictor = Class1PresentationNeuralNetwork(
-        pan_predictor,
-        auxiliary_input_features=[],
-        max_ensemble_size=1,
-        max_epochs=max_epochs,
-        learning_rate=0.0001,
-        patience=5,
-        min_delta=0.0,
-        random_negative_rate=1.0)
-
-    pre_predictions = from_ic50(
-        presentation_predictor.predict(
-            output="affinities",
-            peptides=combined_train_df.peptide.values,
-            alleles=allele_encoding))
-
-    (model,) = pan_predictor.class1_pan_allele_models
-    expected_pre_predictions = from_ic50(
-        model.predict(
-            peptides=numpy.repeat(combined_train_df.peptide.values, len(alleles)),
-            allele_encoding=allele_encoding.allele_encoding,
-    )).reshape((-1, len(alleles)))[:,0]
-
-    assert_allclose(pre_predictions, expected_pre_predictions, rtol=1e-4)
-
-    motifs_history = []
-    random_peptides_encodable = make_random_peptides(10000, [9])
-
-
-    def update_motifs():
-        for allele in multi_train_alleles:
-            motif = make_motif(allele, random_peptides_encodable)
-            motifs_history.append((allele, motif))
-
-    print("Pre fitting:")
-    update_motifs()
-    print("Fitting...")
-
-    presentation_predictor.fit(
-        peptides=combined_train_df.peptide.values,
-        labels=combined_train_df.label.values,
-        allele_encoding=allele_encoding,
-        affinities_mask=combined_train_df.is_affinity.values,
-        inequalities=combined_train_df.measurement_inequality.values,
-        progress_callback=update_motifs,
-    )
-
-    import ipdb ; ipdb.set_trace()
-
-def test_synthetic_allele_refinement_with_affinity_data():
-    test_synthetic_allele_refinement(include_affinities=True)
-
-
-def test_synthetic_allele_refinement(max_epochs=10, include_affinities=False):
-    refine_allele = "HLA-C*01:02"
-    alleles = [
-        "HLA-A*02:01", "HLA-B*27:01", "HLA-C*07:01",
-        "HLA-A*03:01", "HLA-B*15:01", refine_allele
-    ]
-    peptides_per_allele = [
-        2000, 1000, 500,
-        1500, 1200, 800,
-    ]
-
-    allele_to_peptides = dict(zip(alleles, peptides_per_allele))
-
-    length = 9
-
-    train_with_ms = pandas.read_csv(
-        get_path("data_curated", "curated_training_data.with_mass_spec.csv.bz2"))
-    train_no_ms = pandas.read_csv(get_path("data_curated",
-        "curated_training_data.no_mass_spec.csv.bz2"))
-
-    def filter_df(df):
-        return df.loc[
-            (df.allele.isin(alleles)) &
-            (df.peptide.str.len() == length)
-        ]
-
-    train_with_ms = filter_df(train_with_ms)
-    train_no_ms = filter_df(train_no_ms)
-
-    ms_specific = train_with_ms.loc[
-        ~train_with_ms.peptide.isin(train_no_ms.peptide)
-    ]
-
-    train_peptides = []
-    train_true_alleles = []
-    for allele in alleles:
-        peptides = ms_specific.loc[ms_specific.allele == allele].peptide.sample(
-            n=allele_to_peptides[allele])
-        train_peptides.extend(peptides)
-        train_true_alleles.extend([allele] * len(peptides))
-
-    hits_df = pandas.DataFrame({"peptide": train_peptides})
-    hits_df["true_allele"] = train_true_alleles
-    hits_df["hit"] = 1.0
-
-    decoys_df = hits_df.copy()
-    decoys_df["peptide"] = decoys_df.peptide.map(scramble_peptide)
-    decoys_df["true_allele"] = ""
-    decoys_df["hit"] = 0.0
-
-    mms_train_df = pandas.concat([hits_df, decoys_df], ignore_index=True)
-    mms_train_df["label"] =  mms_train_df.hit
-    mms_train_df["is_affinity"] = False
-    mms_train_df["measurement_inequality"] = None
-
-    if include_affinities:
-        affinity_train_df = pandas.read_csv(
-            get_path(
-                "models_class1_pan", "models.with_mass_spec/train_data.csv.bz2"))
-        affinity_train_df = affinity_train_df.loc[
-            affinity_train_df.allele.isin(alleles),
-            ["peptide", "allele",  "measurement_inequality", "measurement_value"]]
-
-        affinity_train_df["label"] = affinity_train_df["measurement_value"]
-        del affinity_train_df["measurement_value"]
-        affinity_train_df["is_affinity"] = True
-    else:
-        affinity_train_df = None
-
-    (affinity_model,) = PAN_ALLELE_PREDICTOR_NO_MASS_SPEC.class1_pan_allele_models
-    presentation_model = Class1PresentationNeuralNetwork(
-        auxiliary_input_features=["gene"],
-        batch_generator_batch_size=1024,
-        max_epochs=max_epochs,
-        learning_rate=0.001,
-        patience=5,
-        min_delta=0.0,
-        random_negative_rate=1.0,
-        random_negative_constant=25)
-    presentation_model.load_from_class1_neural_network(affinity_model)
-
-    presentation_model = pickle.loads(pickle.dumps(presentation_model))
-
-    mms_allele_encoding = MultipleAlleleEncoding(
-        experiment_names=["experiment1"] * len(mms_train_df),
-        experiment_to_allele_list={
-            "experiment1": alleles,
-        },
-        max_alleles_per_experiment=6,
-        allele_to_sequence=PAN_ALLELE_PREDICTOR_NO_MASS_SPEC.allele_to_sequence,
-    )
-    allele_encoding = copy.deepcopy(mms_allele_encoding)
-    if affinity_train_df is not None:
-        allele_encoding.append_alleles(affinity_train_df.allele.values)
-        train_df = pandas.concat(
-            [mms_train_df, affinity_train_df], ignore_index=True, sort=False)
-    else:
-        train_df = mms_train_df
-
-    allele_encoding = allele_encoding.compact()
-
-    pre_predictions = presentation_model.predict(
-            peptides=mms_train_df.peptide.values,
-            allele_encoding=mms_allele_encoding).score
-
-    expected_pre_predictions = from_ic50(
-        affinity_model.predict(
-            peptides=numpy.repeat(mms_train_df.peptide.values, len(alleles)),
-            allele_encoding=mms_allele_encoding.allele_encoding,
-    )).reshape((-1, len(alleles)))
-    mms_train_df["pre_max_prediction"] = pre_predictions.max(1)
-    pre_auc = roc_auc_score(mms_train_df.hit.values, mms_train_df.pre_max_prediction.values)
-    print("PRE_AUC", pre_auc)
-
-    assert_allclose(pre_predictions, expected_pre_predictions, rtol=1e-4)
-
-    motifs_history = []
-    random_peptides_encodable = make_random_peptides(10000, [9])
-
-
-    def update_motifs():
-        for allele in alleles:
-            motif = make_motif(allele, random_peptides_encodable)
-            motifs_history.append((allele, motif))
-
-    metric_rows = []
-
-    """
-    def progress():
-        (_, presentation_prediction, affinities_predictions) = (
-            predictor.predict(
-                output="all",
-                peptides=mms_train_df.peptide.values,
-                alleles=mms_allele_encoding))
-        affinities_predictions = from_ic50(affinities_predictions)
-        for (kind, predictions) in [
-                ("affinities", affinities_predictions),
-                ("presentation", presentation_prediction)]:
-
-            mms_train_df["max_prediction"] = predictions.max(1)
-            mms_train_df["predicted_allele"] = pandas.Series(alleles).loc[
-                predictions.argmax(1).flatten()
-            ].values
-
-            print(kind)
-            print(predictions)
-
-            mean_predictions_for_hit = mms_train_df.loc[
-                mms_train_df.hit == 1.0
-            ].max_prediction.mean()
-            mean_predictions_for_decoy = mms_train_df.loc[
-                mms_train_df.hit == 0.0
-            ].max_prediction.mean()
-            correct_allele_fraction = (
-                    mms_train_df.loc[mms_train_df.hit == 1.0].predicted_allele ==
-                    mms_train_df.loc[mms_train_df.hit == 1.0].true_allele
-            ).mean()
-            auc = roc_auc_score(mms_train_df.hit.values, mms_train_df.max_prediction.values)
-
-            print(kind, "Mean prediction for hit", mean_predictions_for_hit)
-            print(kind, "Mean prediction for decoy", mean_predictions_for_decoy)
-            print(kind, "Correct predicted allele fraction", correct_allele_fraction)
-            print(kind, "AUC", auc)
-
-            metric_rows.append((
-                kind,
-                mean_predictions_for_hit,
-                mean_predictions_for_decoy,
-                correct_allele_fraction,
-                auc,
-            ))
-
-            update_motifs()
-
-        return (presentation_prediction, auc)
-    
-
-    print("Pre fitting:")
-    progress()
-    update_motifs()
-    print("Fitting...")
-    """
-    presentation_model.fit(
-        peptides=train_df.peptide.values,
-        labels=train_df.label.values,
-        inequalities=train_df.measurement_inequality.values,
-        affinities_mask=train_df.is_affinity.values,
-        allele_encoding=allele_encoding,
-        #progress_callback=progress,
-    )
-
-    #(predictions, final_auc) = progress()
-    #print("Final AUC", final_auc)
-
-    """
-    update_motifs()
-
-    motifs = pandas.DataFrame(
-        motifs_history,
-        columns=[
-            "allele",
-            "motif",
-        ]
-    )
-
-    metrics = pandas.DataFrame(
-        metric_rows,
-        columns=[
-            "output",
-            "mean_predictions_for_hit",
-            "mean_predictions_for_decoy",
-            "correct_allele_fraction",
-            "auc"
-        ])
-
-    return (predictor, predictions, metrics, motifs)
-    """
-
-
-
-def Xtest_synthetic_allele_refinement(max_epochs=10):
-    refine_allele = "HLA-C*01:02"
-    alleles = [
-        "HLA-A*02:01", "HLA-B*27:01", "HLA-C*07:01",
-        "HLA-A*03:01", "HLA-B*15:01", refine_allele
-    ]
-    peptides_per_allele = [
-        2000, 1000, 500,
-        1500, 1200, 800,
-    ]
-
-    allele_to_peptides = dict(zip(alleles, peptides_per_allele))
-
-    length = 9
-
-    train_with_ms = pandas.read_csv(
-        get_path("data_curated", "curated_training_data.with_mass_spec.csv.bz2"))
-    train_no_ms = pandas.read_csv(get_path("data_curated",
-        "curated_training_data.no_mass_spec.csv.bz2"))
-
-    def filter_df(df):
-        df = df.loc[
-            (df.allele.isin(alleles)) &
-            (df.peptide.str.len() == length)
-        ]
-        return df
-
-    train_with_ms = filter_df(train_with_ms)
-    train_no_ms = filter_df(train_no_ms)
-
-    ms_specific = train_with_ms.loc[
-        ~train_with_ms.peptide.isin(train_no_ms.peptide)
-    ]
-
-    train_peptides = []
-    train_true_alleles = []
-    for allele in alleles:
-        peptides = ms_specific.loc[ms_specific.allele == allele].peptide.sample(
-            n=allele_to_peptides[allele])
-        train_peptides.extend(peptides)
-        train_true_alleles.extend([allele] * len(peptides))
-
-    hits_df = pandas.DataFrame({"peptide": train_peptides})
-    hits_df["true_allele"] = train_true_alleles
-    hits_df["hit"] = 1.0
-
-    decoys_df = hits_df.copy()
-    decoys_df["peptide"] = decoys_df.peptide.map(scramble_peptide)
-    decoys_df["true_allele"] = ""
-    decoys_df["hit"] = 0.0
-
-    train_df = pandas.concat([hits_df, decoys_df], ignore_index=True)
-
-    predictor = Class1PresentationNeuralNetwork(
-        PAN_ALLELE_PREDICTOR_NO_MASS_SPEC,
-        max_ensemble_size=1,
-        max_epochs=max_epochs,
-        learning_rate=0.0001,
-        patience=5,
-        min_delta=0.0,
-        random_negative_rate=0.0,
-        random_negative_constant=0)
-
-    allele_encoding = MultipleAlleleEncoding(
-        experiment_names=["experiment1"] * len(train_df),
-        experiment_to_allele_list={
-            "experiment1": alleles,
-        },
-        max_alleles_per_experiment=6,
-        allele_to_sequence=PAN_ALLELE_PREDICTOR_NO_MASS_SPEC.allele_to_sequence,
-    ).compact()
-
-    pre_predictions = from_ic50(
-        predictor.predict(
-            output="affinities_matrix",
-            peptides=train_df.peptide.values,
-            alleles=allele_encoding))
-
-    (model,) = PAN_ALLELE_PREDICTOR_NO_MASS_SPEC.class1_pan_allele_models
-    expected_pre_predictions = from_ic50(
-        model.predict(
-            peptides=numpy.repeat(train_df.peptide.values, len(alleles)),
-            allele_encoding=allele_encoding.allele_encoding,
-    )).reshape((-1, len(alleles)))
-
-    #import ipdb ; ipdb.set_trace()
-
-    train_df["pre_max_prediction"] = pre_predictions.max(1)
-    pre_auc = roc_auc_score(train_df.hit.values, train_df.pre_max_prediction.values)
-    print("PRE_AUC", pre_auc)
-
-    assert_allclose(pre_predictions, expected_pre_predictions, rtol=1e-4)
-
-    motifs_history = []
-    random_peptides_encodable = make_random_peptides(10000, [9])
-
-
-    def update_motifs():
-        for allele in alleles:
-            motif = make_motif(allele, random_peptides_encodable)
-            motifs_history.append((allele, motif))
-
-    metric_rows = []
-
-    def progress():
-        (_, presentation_prediction, affinities_predictions) = (
-            predictor.predict(
-                output="all",
-                peptides=train_df.peptide.values,
-                alleles=allele_encoding))
-        affinities_predictions = from_ic50(affinities_predictions)
-        for (kind, predictions) in [
-                ("affinities", affinities_predictions),
-                ("presentation", presentation_prediction)]:
-
-            train_df["max_prediction"] = predictions.max(1)
-            train_df["predicted_allele"] = pandas.Series(alleles).loc[
-                predictions.argmax(1).flatten()
-            ].values
-
-            print(kind)
-            print(predictions)
-
-            mean_predictions_for_hit = train_df.loc[
-                train_df.hit == 1.0
-            ].max_prediction.mean()
-            mean_predictions_for_decoy = train_df.loc[
-                train_df.hit == 0.0
-            ].max_prediction.mean()
-            correct_allele_fraction = (
-                    train_df.loc[train_df.hit == 1.0].predicted_allele ==
-                    train_df.loc[train_df.hit == 1.0].true_allele
-            ).mean()
-            auc = roc_auc_score(train_df.hit.values, train_df.max_prediction.values)
-
-            print(kind, "Mean prediction for hit", mean_predictions_for_hit)
-            print(kind, "Mean prediction for decoy", mean_predictions_for_decoy)
-            print(kind, "Correct predicted allele fraction", correct_allele_fraction)
-            print(kind, "AUC", auc)
-
-            metric_rows.append((
-                kind,
-                mean_predictions_for_hit,
-                mean_predictions_for_decoy,
-                correct_allele_fraction,
-                auc,
-            ))
-
-            update_motifs()
-
-        return (presentation_prediction, auc)
-
-    print("Pre fitting:")
-    progress()
-    update_motifs()
-    print("Fitting...")
-
-    predictor.fit(
-        peptides=train_df.peptide.values,
-        labels=train_df.hit.values,
-        allele_encoding=allele_encoding,
-        progress_callback=progress,
-    )
-
-    (predictions, final_auc) = progress()
-    print("Final AUC", final_auc)
-
-    update_motifs()
-
-    motifs = pandas.DataFrame(
-        motifs_history,
-        columns=[
-            "allele",
-            "motif",
-        ]
-    )
-
-    metrics = pandas.DataFrame(
-        metric_rows,
-        columns=[
-            "output",
-            "mean_predictions_for_hit",
-            "mean_predictions_for_decoy",
-            "correct_allele_fraction",
-            "auc"
-        ])
-
-    return (predictor, predictions, metrics, motifs)
-
-
-def Xtest_batch_generator(sample_rate=0.1):
-    multi_train_df = pandas.read_csv(
-        data_path("multiallelic_ms.benchmark1.csv.bz2"))
-    multi_train_df["label"] = multi_train_df.hit
-    multi_train_df["is_affinity"] = False
-
-    sample_table = multi_train_df.loc[
-        multi_train_df.label == True
-    ].drop_duplicates("sample_id").set_index("sample_id").loc[
-        multi_train_df.sample_id.unique()
-    ]
-    grouped = multi_train_df.groupby("sample_id").nunique()
-    for col in sample_table.columns:
-        if (grouped[col] > 1).any():
-            del sample_table[col]
-    sample_table["alleles"] = sample_table.hla.str.split()
-
-    pan_train_df = pandas.read_csv(
-        get_path(
-            "models_class1_pan", "models.with_mass_spec/train_data.csv.bz2"))
-    pan_sub_train_df = pan_train_df
-    pan_sub_train_df["label"] = pan_sub_train_df["measurement_value"]
-    del pan_sub_train_df["measurement_value"]
-    pan_sub_train_df["is_affinity"] = True
-
-    pan_sub_train_df = pan_sub_train_df.sample(frac=sample_rate)
-    multi_train_df = multi_train_df.sample(frac=sample_rate)
-
-    pan_predictor = Class1AffinityPredictor.load(
-        get_path("models_class1_pan", "models.with_mass_spec"),
-        optimization_level=0,
-        max_models=1)
-
-    allele_encoding = MultipleAlleleEncoding(
-        experiment_names=multi_train_df.sample_id.values,
-        experiment_to_allele_list=sample_table.alleles.to_dict(),
-        max_alleles_per_experiment=sample_table.alleles.str.len().max(),
-        allele_to_sequence=pan_predictor.allele_to_sequence,
-    )
-    allele_encoding.append_alleles(pan_sub_train_df.allele.values)
-    allele_encoding = allele_encoding.compact()
-
-    combined_train_df = pandas.concat(
-        [multi_train_df, pan_sub_train_df], ignore_index=True, sort=True)
-
-    presentation_predictor = Class1PresentationNeuralNetwork(
-        pan_predictor,
-        auxiliary_input_features=[],
-        max_ensemble_size=1,
-        max_epochs=0,
-        batch_generator_batch_size=128,
-        learning_rate=0.0001,
-        patience=5,
-        min_delta=0.0,
-        random_negative_rate=1.0)
-
-    fit_results = presentation_predictor.fit(
-        peptides=combined_train_df.peptide.values,
-        labels=combined_train_df.label.values,
-        allele_encoding=allele_encoding,
-        affinities_mask=combined_train_df.is_affinity.values,
-        inequalities=combined_train_df.measurement_inequality.values,
-    )
-
-    batch_generator = fit_results['batch_generator']
-    train_batch_plan = batch_generator.train_batch_plan
-
-    assert_greater(len(train_batch_plan.equivalence_class_labels), 100)
-    assert_less(len(train_batch_plan.equivalence_class_labels), 1000)
-
-
-parser = argparse.ArgumentParser(usage=__doc__)
-parser.add_argument(
-    "--out-metrics-csv",
-    default=None,
-    help="Metrics output")
-parser.add_argument(
-    "--out-motifs-pickle",
-    default=None,
-    help="Metrics output")
-parser.add_argument(
-    "--max-epochs",
-    default=100,
-    type=int,
-    help="Max epochs")
-
-
-
-
-if __name__ == '__main__':
-    # If run directly from python, leave the user in a shell to explore results.
-    setup()
-    args = parser.parse_args(sys.argv[1:])
-    (predictor, predictions, metrics, motifs) = (
-        test_synthetic_allele_refinement(max_epochs=args.max_epochs))
-
-    if args.out_metrics_csv:
-        metrics.to_csv(args.out_metrics_csv)
-    if args.out_motifs_pickle:
-        motifs.to_pickle(args.out_motifs_pickle)
-
-    # Leave in ipython
-    import ipdb  # pylint: disable=import-error
-    ipdb.set_trace()
diff --git a/test/test_custom_loss.py b/test/test_custom_loss.py
index 98ee4ab5..ed833af6 100644
--- a/test/test_custom_loss.py
+++ b/test/test_custom_loss.py
@@ -1,3 +1,7 @@
+import logging
+logging.getLogger('tensorflow').disabled = True
+logging.getLogger('matplotlib').disabled = True
+
 from nose.tools import eq_, assert_less, assert_greater, assert_almost_equal
 
 import numpy
@@ -9,7 +13,7 @@ logging.getLogger('tensorflow').disabled = True
 
 import keras.backend as K
 
-from mhcflurry.custom_loss import CUSTOM_LOSSES
+from mhcflurry.custom_loss import CUSTOM_LOSSES, MultiallelicMassSpecLoss
 
 from mhcflurry.testing_utils import cleanup, startup
 teardown = cleanup
@@ -144,3 +148,71 @@ def test_mse_with_inequalities_and_multiple_outputs():
         ])
     assert_almost_equal(loss0, 0.02 / 4)
 
+
+def test_multiallelic_mass_spec_loss():
+    """
+    MultiallelicMassSpecLoss should match two independent reference
+    implementations (explicit loops and vectorized numpy).
+    """
+    for delta in [0.0, 0.3]:
+        print("delta", delta)
+        # Hit labels: 1.0 marks a hit, 0.0 a decoy.
+        y_true = numpy.array([
+            1.0,
+            0.0,
+            1.0,
+            -1.0,  # ignored
+            1.0,
+            0.0,
+            1.0,
+        ])
+        # One row per entry of y_true; three predictions per row.
+        y_pred = numpy.array([
+            [0.3, 0.7, 0.5],
+            [0.2, 0.4, 0.6],
+            [0.1, 0.5, 0.3],
+            [0.9, 0.1, 0.2],
+            [0.1, 0.7, 0.1],
+            [0.8, 0.2, 0.4],
+            [0.1, 0.2, 0.4],
+        ])
+
+        # Reference implementation 1: explicit loops over every
+        # (hit, decoy, decoy column) triple.
+        contributions = []
+        for i in range(len(y_true)):
+            if y_true[i] != 1.0:
+                continue
+            # Loop-invariant for the decoy loops, so compute it once per hit.
+            tightest_i = max(y_pred[i])
+            for j in range(len(y_true)):
+                if y_true[j] != 0.0:
+                    continue
+                for k in range(y_pred.shape[1]):
+                    contribution = max(
+                        0, y_pred[j, k] - tightest_i + delta)**2
+                    contributions.append(contribution)
+        contributions = numpy.array(contributions)
+        expected1 = contributions.sum() / len(contributions)
+
+        # Reference implementation 2: vectorized numpy.
+        pos = numpy.array([
+            max(y_pred[i])
+            for i in range(len(y_pred))
+            if y_true[i] == 1.0
+        ])
+
+        neg = y_pred[(y_true == 0.0).astype(bool)]
+        term = neg.reshape((-1, 1)) - pos + delta
+        expected2 = (
+                numpy.maximum(0, term)**2).sum() / (
+            len(pos) * neg.shape[0] * neg.shape[1])
+
+        numpy.testing.assert_almost_equal(expected1, expected2)
+
+        computed = evaluate_loss(
+            MultiallelicMassSpecLoss(delta=delta).loss,
+            y_true,
+            y_pred)
+
+        numpy.testing.assert_almost_equal(computed, expected1, 4)
-- 
GitLab