From d9cbbcd3ca507b776697399846364b4b0c088118 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Thu, 12 Sep 2019 14:34:17 -0400
Subject: [PATCH] starting ligandome predictor

---
 mhcflurry/class1_ligandome_predictor.py | 104 +++++++++++
 test/test_class1_ligandome_predictor.py | 221 ++++++++++++++++++++++++
 test/test_network_merging.py            |   1 -
 test/test_speed.py                      |   8 +-
 4 files changed, 329 insertions(+), 5 deletions(-)
 create mode 100644 mhcflurry/class1_ligandome_predictor.py
 create mode 100644 test/test_class1_ligandome_predictor.py

diff --git a/mhcflurry/class1_ligandome_predictor.py b/mhcflurry/class1_ligandome_predictor.py
new file mode 100644
index 00000000..8312c916
--- /dev/null
+++ b/mhcflurry/class1_ligandome_predictor.py
@@ -0,0 +1,104 @@
+
+from .hyperparameters import HyperparameterDefaults
+from .class1_neural_network import Class1NeuralNetwork
+
+class Class1LigandomePredictor(object):
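+    """
+    Work-in-progress predictor of multi-allele mass spec (ligandome) hits,
+    built on top of an ensemble of pan-allele Class1 binding affinity models.
+    """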
+    network_hyperparameter_defaults = HyperparameterDefaults(
+        retrain_mode="all",
+    )
+
+    def __init__(self, class1_affinity_predictor):
+        if not class1_affinity_predictor.class1_pan_allele_models:
+            raise NotImplementedError("Pan allele models required")
+        if class1_affinity_predictor.allele_to_allele_specific_models:
+            raise NotImplementedError("Only pan allele models are supported")
+        self.binding_predictors = (
+            class1_affinity_predictor.class1_pan_allele_models)
+
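+        # Collapse the ensemble into a single Keras model by summing the
+        # members' outputs.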
+        self.network = Class1NeuralNetwork.merge(
+            self.binding_predictors, merge_method="sum")
+
+    def make_network(self, merge_method="sum"):
+        import keras
+        import keras.backend as K
+        from keras.layers import Input
+        from keras.models import Model
+
+        models = self.binding_predictors
+
+        if len(models) == 1:
+            return models[0]
+        assert len(models) > 1
+
+        result = Class1NeuralNetwork(**dict(models[0].hyperparameters))
+
+        # Remove hyperparameters that are not shared by all models.
+        for model in models:
+            for (key, value) in model.hyperparameters.items():
+                if result.hyperparameters.get(key, value) != value:
+                    del result.hyperparameters[key]
+
+        assert result._network is None
+
+        networks = [model.network() for model in models]
+
+        layer_names = [[layer.name for layer in network.layers] for network in
+            networks]
+
+        pan_allele_layer_names = ['allele', 'peptide', 'allele_representation',
+            'flattened_0', 'allele_flat', 'allele_peptide_merged', 'dense_0',
+            'dropout_0', 'dense_1', 'dropout_1', 'output', ]
+
+        if all(names == pan_allele_layer_names for names in layer_names):
+            # Merging an ensemble of pan-allele architectures
+            network = networks[0]
+            peptide_input = Input(
+                shape=tuple(int(x) for x in K.int_shape(network.inputs[0])[1:]),
+                dtype='float32', name='peptide')
+            allele_input = Input(shape=(1,), dtype='float32', name='allele')
+
+            allele_embedding = network.get_layer("allele_representation")(
+                allele_input)
+            peptide_flat = network.get_layer("flattened_0")(peptide_input)
+            allele_flat = network.get_layer("allele_flat")(allele_embedding)
+            allele_peptide_merged = network.get_layer("allele_peptide_merged")(
+                [peptide_flat, allele_flat])
+
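+            # Re-apply each ensemble member's layers downstream of
+            # "allele_peptide_merged" on top of the shared peptide/allele
+            # trunk, renaming layers to avoid name collisions.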
+            sub_networks = []
+            for (i, network) in enumerate(networks):
+                layers = network.layers[
+                    pan_allele_layer_names.index("allele_peptide_merged") + 1:]
+                node = allele_peptide_merged
+                for layer in layers:
+                    layer.name += "_%d" % i
+                    node = layer(node)
+                sub_networks.append(node)
+
+            if merge_method == 'average':
+                output = keras.layers.average(sub_networks)
+            elif merge_method == 'sum':
+                output = keras.layers.add(sub_networks)
+            elif merge_method == 'concatenate':
+                output = keras.layers.concatenate(sub_networks)
+            else:
+                raise NotImplementedError("Unsupported merge method",
+                    merge_method)
+
+            result._network = Model(inputs=[peptide_input, allele_input],
+                outputs=[output], name="merged_predictor")
+            result.update_network_description()
+        else:
+            raise NotImplementedError(
+                "Don't know merge_method to merge networks with layer names: ",
+                layer_names)
+        return result
+
+
+    def fit(self, peptides, labels, experiment_names,
+            experiment_name_to_alleles):
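+        """
+        Placeholder. Intended to fit on multi-allele mass spec data: peptides
+        with hit/decoy labels, where each experiment name maps (via
+        experiment_name_to_alleles) to the alleles present in that sample.
+        Not implemented yet.
+        """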
+        pass
+
+    def predict(self, peptides, alleles, output_format="concatenate"):
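+        """
+        Placeholder. Intended to return hit scores for each peptide across
+        the given alleles. Not implemented yet.
+        """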
+        pass
diff --git a/test/test_class1_ligandome_predictor.py b/test/test_class1_ligandome_predictor.py
new file mode 100644
index 00000000..041e9e68
--- /dev/null
+++ b/test/test_class1_ligandome_predictor.py
@@ -0,0 +1,221 @@
+"""
+
+Idea:
+
+- take an allele where MS vs. no-MS trained predictors are very different. One
+    possiblility is DLA-88*501:01 but human would be better
+- generate synethetic multi-allele MS by combining single-allele MS for differnet
+   alleles, including the selected allele
+- train ligandome predictor based on the no-ms pan-allele models on theis
+  synthetic dataset
+- see if the pan-allele predictor learns the "correct" motif for the selected
+  allele, i.e. updates to become more similar to the with-ms pan allele predictor.
+
+
+"""
+
+from sklearn.metrics import roc_auc_score
+import pandas
+import argparse
+import sys
+
+from numpy.testing import assert_, assert_equal
+import numpy
+from random import shuffle
+
+from mhcflurry import Class1AffinityPredictor, Class1NeuralNetwork
+from mhcflurry.allele_encoding import AlleleEncoding
+from mhcflurry.class1_ligandome_predictor import Class1LigandomePredictor
+from mhcflurry.downloads import get_path
+
+from mhcflurry.testing_utils import cleanup, startup
+from mhcflurry.amino_acid import COMMON_AMINO_ACIDS
+
+COMMON_AMINO_ACIDS = sorted(COMMON_AMINO_ACIDS)
+
+PAN_ALLELE_PREDICTOR_NO_MASS_SPEC = None
+PAN_ALLELE_MOTIFS_WITH_MASS_SPEC_DF = None
+PAN_ALLELE_MOTIFS_NO_MASS_SPEC_DF = None
+
+
+def setup():
+    global PAN_ALLELE_PREDICTOR_NO_MASS_SPEC
+    global PAN_ALLELE_MOTIFS_WITH_MASS_SPEC_DF
+    global PAN_ALLELE_MOTIFS_NO_MASS_SPEC_DF
+    startup()
+    PAN_ALLELE_PREDICTOR_NO_MASS_SPEC = Class1AffinityPredictor.load(
+            get_path("models_class1_pan", "models.no_mass_spec"))
+    PAN_ALLELE_MOTIFS_WITH_MASS_SPEC_DF = pandas.read_csv(
+        get_path(
+            "models_class1_pan",
+            "models.with_mass_spec/frequency_matrices.csv.bz2"))
+    PAN_ALLELE_MOTIFS_NO_MASS_SPEC_DF = pandas.read_csv(
+        get_path(
+            "models_class1_pan",
+            "models.no_mass_spec/frequency_matrices.csv.bz2"))
+
+
+def teardown():
+    global PAN_ALLELE_PREDICTOR_NO_MASS_SPEC
+    global PAN_ALLELE_MOTIFS_WITH_MASS_SPEC_DF
+    global PAN_ALLELE_MOTIFS_NO_MASS_SPEC_DF
+
+    PAN_ALLELE_PREDICTOR_NO_MASS_SPEC = None
+    PAN_ALLELE_MOTIFS_WITH_MASS_SPEC_DF = None
+    PAN_ALLELE_MOTIFS_NO_MASS_SPEC_DF = None
+    cleanup()
+
+
+def sample_peptides_from_pssm(pssm, count):
+    result = pandas.DataFrame(
+        index=numpy.arange(count),
+        columns=pssm.index,
+        dtype=object,
+    )
+
+    for (position, vector) in pssm.iterrows():
+        result.loc[:, position] = numpy.random.choice(
+            pssm.columns,
+            size=count,
+            replace=True,
+            p=vector.values)
+
+    return result.apply("".join, axis=1)
+
+
+def scramble_peptide(peptide):
+    lst = list(peptide)
+    shuffle(lst)
+    return "".join(lst)
+
+
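+# The last step of the idea above (checking whether the refined predictor
+# recovers the "correct" motif) is not wired into the test yet. Rough sketch,
+# assuming the frequency_matrices.csv.bz2 layout also used by the commented-out
+# test_simple_synthetic below (allele, cutoff_fraction, length, position plus
+# one column per amino acid): extract reference motifs for the refined allele
+# from the with-MS and no-MS matrices and measure which one a motif derived
+# from the refined predictor is closer to.
+def extract_pssm(motifs_df, allele, length=9):
+    # Position frequency matrix for one allele/length at the smallest cutoff.
+    cutoff = motifs_df.cutoff_fraction.min()
+    sub_df = motifs_df.loc[
+        (motifs_df.allele == allele) &
+        (motifs_df.cutoff_fraction == cutoff) &
+        (motifs_df.length == length)
+    ]
+    return sub_df.set_index("position")[COMMON_AMINO_ACIDS].astype(float)
+
+
+def pssm_distance(pssm1, pssm2):
+    # Mean absolute difference between two aligned position frequency matrices.
+    return numpy.abs(pssm1.values - pssm2.values).mean()
+
+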
+def test_synthetic_allele_refinement():
+    refine_allele = "HLA-C*01:02"
+    alleles = [
+        "HLA-A*02:01", "HLA-B*27:01", "HLA-C*07:01",
+        "HLA-A*03:01", "HLA-B*15:01", refine_allele
+    ]
+    peptides_per_allele = [
+        2000, 1000, 500,
+        1500, 1200, 800,
+    ]
+
+    allele_to_peptides = dict(zip(alleles, peptides_per_allele))
+
+    length = 9
+
+    train_with_ms = pandas.read_csv(
+        get_path("data_curated", "curated_training_data.with_mass_spec.csv.bz2"))
+    train_no_ms = pandas.read_csv(get_path("data_curated",
+        "curated_training_data.no_mass_spec.csv.bz2"))
+
+    def filter_df(df):
+        df = df.loc[
+            (df.allele.isin(alleles)) &
+            (df.peptide.str.len() == length)
+        ]
+        return df
+
+    train_with_ms = filter_df(train_with_ms)
+    train_no_ms = filter_df(train_no_ms)
+
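+    # Restrict to peptides that appear only in the with-mass-spec data, so the
+    # hits are (mostly) mass spec derived rather than affinity measurements.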
+    ms_specific = train_with_ms.loc[
+        ~train_with_ms.peptide.isin(train_no_ms.peptide)
+    ]
+
+    train_peptides = []
+    train_true_alleles = []
+    for allele in alleles:
+        peptides = ms_specific.loc[ms_specific.allele == allele].peptide.sample(
+            n=allele_to_peptides[allele])
+        train_peptides.extend(peptides)
+        train_true_alleles.extend([allele] * len(peptides))
+
+    hits_df = pandas.DataFrame({"peptide": train_peptides})
+    hits_df["true_allele"] = train_true_alleles
+    hits_df["hit"] = 1.0
+
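+    # Decoys: the same peptides with residues shuffled, labeled as non-hits.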
+    decoys_df = hits_df.copy()
+    decoys_df["peptide"] = decoys_df.peptide.map(scramble_peptide)
+    decoys_df["true_allele"] = ""
+    decoys_df["hit"] = 0.0
+
+    train_df = pandas.concat([hits_df, decoys_df], ignore_index=True)
+
+    predictor = Class1LigandomePredictor(PAN_ALLELE_PREDICTOR_NO_MASS_SPEC)
+    predictor.fit(
+        peptides=train_df.peptide.values,
+        labels=train_df.hit.values,
+        experiment_names=["experiment1"] * len(train_df),
+        experiment_name_to_alleles={
+            "experiment1": alleles,
+        }
+    )
+
+    predictions = predictor.predict(
+        peptides=train_df.peptide.values,
+        alleles=alleles,
+        output_format="concatenate"
+    )
+
+    print(predictions)
+
+    import ipdb ; ipdb.set_trace()
+
+
+"""
+def test_simple_synthetic(
+        num_peptide_per_allele_and_length=100, lengths=[8,9,10,11]):
+    alleles = [
+        "HLA-A*02:01", "HLA-B*52:01", "HLA-C*07:01",
+        "HLA-A*03:01", "HLA-B*57:02", "HLA-C*03:01",
+    ]
+    cutoff = PAN_ALLELE_MOTIFS_DF.cutoff_fraction.min()
+    peptides_and_alleles = []
+    for allele in alleles:
+        sub_df = PAN_ALLELE_MOTIFS_DF.loc[
+            (PAN_ALLELE_MOTIFS_DF.allele == allele) &
+            (PAN_ALLELE_MOTIFS_DF.cutoff_fraction == cutoff)
+        ]
+        assert len(sub_df) > 0, allele
+        for length in lengths:
+            pssm = sub_df.loc[
+                sub_df.length == length
+            ].set_index("position")[COMMON_AMINO_ACIDS]
+            peptides = sample_peptides_from_pssm(pssm, num_peptide_per_allele_and_length)
+            for peptide in peptides:
+                peptides_and_alleles.append((peptide, allele))
+
+    hits_df = pandas.DataFrame(
+        peptides_and_alleles,
+        columns=["peptide", "allele"]
+    )
+    hits_df["hit"] = 1
+
+    decoys = hits_df.copy()
+    decoys["peptide"] = decoys.peptide.map(scramble_peptide)
+    decoys["hit"] = 0.0
+
+    train_df = pandas.concat([hits_df, decoys], ignore_index=True)
+
+    return train_df
+"""
+
+parser = argparse.ArgumentParser(usage=__doc__)
+parser.add_argument(
+    "--alleles",
+    nargs="+",
+    default=None,
+    help="Which alleles to test")
+
+if __name__ == '__main__':
+    # If run directly from python, leave the user in a shell to explore results.
+    setup()
+    args = parser.parse_args(sys.argv[1:])
+    result = test_synthetic_allele_refinement()
+
+    # Drop into an ipdb shell to explore results
+    import ipdb  # pylint: disable=import-error
+    ipdb.set_trace()
diff --git a/test/test_network_merging.py b/test/test_network_merging.py
index 761e138c..31c63377 100644
--- a/test/test_network_merging.py
+++ b/test/test_network_merging.py
@@ -30,7 +30,6 @@ def teardown():
 
 def test_merge():
     assert len(PAN_ALLELE_PREDICTOR.class1_pan_allele_models) > 1
-
     peptides = random_peptides(100, length=9)
     peptides.extend(random_peptides(100, length=10))
     peptides = pandas.Series(peptides).sample(frac=1.0)
diff --git a/test/test_speed.py b/test/test_speed.py
index 037ae61c..3f75579b 100644
--- a/test/test_speed.py
+++ b/test/test_speed.py
@@ -22,11 +22,11 @@ from mhcflurry.testing_utils import cleanup, startup
 
 
 ALLELE_SPECIFIC_PREDICTOR = None
-PAN_ALLELE_PREDICTOR = None
+PAN_ALLELE_PREDICTOR_NO_MASS_SPEC = None
 
 
 def setup():
-    global ALLELE_SPECIFIC_PREDICTOR, PAN_ALLELE_PREDICTOR
+    global ALLELE_SPECIFIC_PREDICTOR, PAN_ALLELE_PREDICTOR_NO_MASS_SPEC
     startup()
     ALLELE_SPECIFIC_PREDICTOR = Class1AffinityPredictor.load(
         get_path("models_class1", "models"))
@@ -36,7 +36,7 @@ def setup():
 
 
 def teardown():
-    global ALLELE_SPECIFIC_PREDICTOR, PAN_ALLELE_PREDICTOR
+    global ALLELE_SPECIFIC_PREDICTOR, PAN_ALLELE_PREDICTOR_NO_MASS_SPEC
     ALLELE_SPECIFIC_PREDICTOR = None
     PAN_ALLELE_PREDICTOR = None
     cleanup()
@@ -97,7 +97,7 @@ def test_speed_allele_specific(profile=False, num=DEFAULT_NUM_PREDICTIONS):
 
 
 def test_speed_pan_allele(profile=False, num=DEFAULT_NUM_PREDICTIONS):
-    global PAN_ALLELE_PREDICTOR
+    global PAN_ALLELE_PREDICTOR_NO_MASS_SPEC
     starts = collections.OrderedDict()
     timings = collections.OrderedDict()
     profilers = collections.OrderedDict()
-- 
GitLab