From 9cce08f2133efab9e3cdbbba84b60ebfdce73a40 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Tue, 4 Feb 2020 11:00:19 -0500
Subject: [PATCH] add sequence scanning to presentation predictor

---
 mhcflurry/class1_presentation_predictor.py    | 118 +++++++++++++++++-
 mhcflurry/class1_processing_neural_network.py |   8 ++
 mhcflurry/class1_processing_predictor.py      |   9 ++
 test/test_class1_presentation_predictor.py    |  94 ++++++++++++++
 test/test_class1_processing_predictor.py      |   1 +
 5 files changed, 229 insertions(+), 1 deletion(-)

diff --git a/mhcflurry/class1_presentation_predictor.py b/mhcflurry/class1_presentation_predictor.py
index dbdca3b0..d95ca0c9 100644
--- a/mhcflurry/class1_presentation_predictor.py
+++ b/mhcflurry/class1_presentation_predictor.py
@@ -240,6 +240,111 @@ class Class1PresentationPredictor(object):
             c_flanks=c_flanks,
             verbose=verbose).presentation_score.values
 
+    def predict_scan(
+            self,
+            sequences,
+            alleles,
+            result="best",  # or "all" or "filtered"
+            comparison_quantity="presentation_score",
+            comparison_value=None,
+            peptide_lengths=(8, 9, 10, 11),
+            use_flanks=True,
+            include_affinity_percentile=False,
+            verbose=1,
+            throw=True):
+
+        processing_predictor = self.processing_predictor_with_flanks
+        if not use_flanks or processing_predictor is None:
+            processing_predictor = self.processing_predictor_without_flanks
+
+        supported_sequence_lengths = processing_predictor.sequence_lengths
+        n_flank_length = supported_sequence_lengths["n_flank"]
+        c_flank_length = supported_sequence_lengths["c_flank"]
+
+        sequence_names = []
+        n_flanks = [] if use_flanks else None
+        c_flanks = [] if use_flanks else None
+        peptides = []
+
+        if isinstance(sequences, string_types):
+            sequences = [sequences]
+
+        if not isinstance(sequences, dict):
+            sequences = collections.OrderedDict(
+                ("sequence_%04d" % (i + 1), sequence)
+                for (i, sequence) in enumerate(sequences))
+
+        if not isinstance(alleles, dict):
+            alleles = dict((name, alleles) for name in sequences.keys())
+
+        missing = [key for key in sequences if key not in alleles]
+        if missing:
+            raise ValueError(
+                "Sequence names missing from alleles dict: %s" % missing)
+
+        for (name, sequence) in sequences.items():
+            if not isinstance(sequence, string_types):
+                raise ValueError("Expected string, not %s (%s)" % (
+                    sequence, type(sequence)))
+            for peptide_start in range(len(sequence) - min(peptide_lengths) + 1):
+                n_flank_start = max(0, peptide_start - n_flank_length)
+                for peptide_length in peptide_lengths:
+                    peptide_end = peptide_start + peptide_length
+                    if peptide_end > len(sequence):
+                        continue  # skip truncated peptides at sequence end
+                    sequence_names.append(name)
+                    peptides.append(sequence[peptide_start : peptide_end])
+                    if use_flanks:
+                        n_flanks.append(
+                            sequence[n_flank_start : peptide_start])
+                        c_flanks.append(
+                            sequence[peptide_end : peptide_end + c_flank_length])
+
+        result_df = self.predict_to_dataframe(
+            peptides=peptides,
+            alleles=alleles,
+            n_flanks=n_flanks,
+            c_flanks=c_flanks,
+            experiment_names=sequence_names,
+            include_affinity_percentile=include_affinity_percentile,
+            verbose=verbose,
+            throw=throw)
+
+        result_df = result_df.rename(
+            columns={"experiment_name": "sequence_name"})
+
+        comparison_is_score = comparison_quantity.endswith("score")
+
+        result_df = result_df.sort_values(
+            comparison_quantity,
+            ascending=not comparison_is_score)
+
+        if result == "best":
+            result_df = result_df.drop_duplicates(
+                "sequence_name", keep="first").sort_values("sequence_name")
+        elif result == "filtered":
+            if comparison_is_score:
+                result_df = result_df.loc[
+                    result_df[comparison_quantity] >= comparison_value
+                ]
+            else:
+                result_df = result_df.loc[
+                    result_df[comparison_quantity] <= comparison_value
+                ]
+        elif result == "all":
+            pass
+        else:
+            raise ValueError(
+                "Unknown result: %s. Valid choices are: best, filtered, all"
+                % result)
+
+        result_df = result_df.reset_index(drop=True)
+        result_df = result_df.copy()
+
+        return result_df
+
+
+
     def predict_to_dataframe(
             self,
             peptides,
@@ -298,7 +403,10 @@ class Class1PresentationPredictor(object):
             throw=throw)
         df["affinity_score"] = from_ic50(df.affinity)
         df["processing_score"] = processing_scores
-
+        if c_flanks is not None:
+            df.insert(1, "c_flank", c_flanks)
+        if n_flanks is not None:
+            df.insert(1, "n_flank", n_flanks)
 
         model_name = 'with_flanks' if n_flanks is not None else "without_flanks"
         model = self.get_model(model_name)
@@ -383,12 +491,20 @@ class Class1PresentationPredictor(object):
             processing_predictor_with_flanks = Class1ProcessingPredictor.load(
                 join(models_dir, "processing_predictor_with_flanks"),
                 max_models=max_models)
+        else:
+            logging.warning(
+                "Presentation predictor is missing processing predictor: %s",
+                join(models_dir, "processing_predictor_with_flanks"))
 
         processing_predictor_without_flanks = None
         if exists(join(models_dir, "processing_predictor_without_flanks")):
             processing_predictor_without_flanks = Class1ProcessingPredictor.load(
                 join(models_dir, "processing_predictor_without_flanks"),
                 max_models=max_models)
+        else:
+            logging.warning(
+                "Presentation predictor is missing processing predictor: %s",
+                join(models_dir, "processing_predictor_without_flanks"))
 
         weights_dataframe = pandas.read_csv(
             join(models_dir, "weights.csv"),
diff --git a/mhcflurry/class1_processing_neural_network.py b/mhcflurry/class1_processing_neural_network.py
index 8c992527..1a7643b7 100644
--- a/mhcflurry/class1_processing_neural_network.py
+++ b/mhcflurry/class1_processing_neural_network.py
@@ -80,6 +80,14 @@ class Class1ProcessingNeuralNetwork(object):
         self.network_weights = None
         self.fit_info = []
 
+    @property
+    def sequence_lengths(self):
+        return {
+            "peptide": self.hyperparameters['peptide_max_length'],
+            "n_flank": self.hyperparameters['n_flank_length'],
+            "c_flank": self.hyperparameters['c_flank_length'],
+        }
+
     def network(self):
         """
         Return the keras model associated with this network.
diff --git a/mhcflurry/class1_processing_predictor.py b/mhcflurry/class1_processing_predictor.py
index 4b8620e8..5e29e651 100644
--- a/mhcflurry/class1_processing_predictor.py
+++ b/mhcflurry/class1_processing_predictor.py
@@ -34,6 +34,15 @@ class Class1ProcessingPredictor(object):
         self.metadata_dataframes = (
             dict(metadata_dataframes) if metadata_dataframes else {})
 
+    @property
+    def sequence_lengths(self):
+        df = pandas.DataFrame([model.sequence_lengths for model in self.models])
+        return {
+            "peptide": df.peptide.min(),  # min: anything greater is error
+            "n_flank": df.n_flank.max(),  # max: anything greater is ignored
+            "c_flank": df.c_flank.max(),
+        }
+
     def add_models(self, models):
         new_model_names = []
         original_manifest = self.manifest_df
diff --git a/test/test_class1_presentation_predictor.py b/test/test_class1_presentation_predictor.py
index 99856f62..363ee84b 100644
--- a/test/test_class1_presentation_predictor.py
+++ b/test/test_class1_presentation_predictor.py
@@ -26,12 +26,14 @@ from . import data_path
 AFFINITY_PREDICTOR = None
 CLEAVAGE_PREDICTOR = None
 CLEAVAGE_PREDICTOR_NO_FLANKING = None
+PRESENTATION_PREDICTOR = None
 
 
 def setup():
     global AFFINITY_PREDICTOR
     global CLEAVAGE_PREDICTOR
     global CLEAVAGE_PREDICTOR_NO_FLANKING
+    global PRESENTATION_PREDICTOR
     startup()
     AFFINITY_PREDICTOR = Class1AffinityPredictor.load(
         get_path("models_class1_pan", "models.combined"),
@@ -42,15 +44,18 @@ def setup():
     CLEAVAGE_PREDICTOR_NO_FLANKING = Class1ProcessingPredictor.load(
         get_path("models_class1_processing_variants", "models.selected.no_flank"),
         max_models=1)
+    PRESENTATION_PREDICTOR = Class1PresentationPredictor.load()
 
 
 def teardown():
     global AFFINITY_PREDICTOR
     global CLEAVAGE_PREDICTOR
     global CLEAVAGE_PREDICTOR_NO_FLANKING
+    global PRESENTATION_PREDICTOR
     AFFINITY_PREDICTOR = None
     CLEAVAGE_PREDICTOR = None
     CLEAVAGE_PREDICTOR_NO_FLANKING = None
+    PRESENTATION_PREDICTOR = None
     cleanup()
 
 
@@ -126,3 +131,92 @@ def test_basic():
             test_df["prediction1"], other_test_df["prediction1"], decimal=6)
         numpy.testing.assert_array_almost_equal(
             test_df["prediction2"], other_test_df["prediction2"], decimal=6)
+
+def test_downloaded_predictor():
+    global PRESENTATION_PREDICTOR
+
+    # Test sequence scanning
+    scan_results1 = PRESENTATION_PREDICTOR.predict_scan(
+        sequences=[
+            "MESLVPGFNEKTHVQLSLPVLQVRDVLVRGFGDSVEEVLSEARQHLKDGTCGLVEVEKGVLPQLE",
+            "QPYVFIKRSDARTAPHGHVMVELVAELEGIQYGRSGETLGVLVPHVGEIPVAYRKVLLRKNGNKG",
+            "AGGHSYGADLKSFDLGDELGTDPYEDFQENWNTKHSSGVTRELMRELNGGAYTRYVDNNFCGPDG",
+        ],
+        alleles=[
+            "HLA-A*02:01",
+            "HLA-A*03:01",
+            "HLA-B*57:01",
+            "HLA-B*44:02",
+            "HLA-C*02:01",
+            "HLA-C*07:01",
+        ])
+    print(scan_results1)
+
+    assert_equal(len(scan_results1), 3)
+    assert (scan_results1.affinity < 200).all()
+    assert (scan_results1.presentation_score > 0.7).all()
+
+    scan_results2 = PRESENTATION_PREDICTOR.predict_scan(
+        result="filtered",
+        comparison_value=500,
+        comparison_quantity="affinity",
+        sequences={
+            "seq1": "MESLVPGFNEKTHVQLSLPVLQVRDVLVRGFGDSVEEVLSEARQHLKDGTCGLVEVEKGVLPQLE",
+            "seq2": "QPYVFIKRSDARTAPHGHVMVELVAELEGIQYGRSGETLGVLVPHVGEIPVAYRKVLLRKNGNKG",
+            "seq3": "AGGHSYGADLKSFDLGDELGTDPYEDFQENWNTKHSSGVTRELMRELNGGAYTRYVDNNFCGPDG",
+        },
+        alleles=[
+            "HLA-A*02:01",
+            "HLA-A*03:01",
+            "HLA-B*57:01",
+            "HLA-B*44:02",
+            "HLA-C*02:01",
+            "HLA-C*07:01",
+        ])
+    print(scan_results2)
+
+    assert len(scan_results2) > 10
+    assert (scan_results2.affinity <= 500).all()
+
+    scan_results3 = PRESENTATION_PREDICTOR.predict_scan(
+        result="filtered",
+        comparison_value=0.9,
+        comparison_quantity="presentation_score",
+        sequences={
+            "seq1": "MESLVPGFNEKTHVQLSLPVLQVRDVLVRGFGDSVEEVLSEARQHLKDGTCGLVEVEKGVLPQLE",
+            "seq2": "QPYVFIKRSDARTAPHGHVMVELVAELEGIQYGRSGETLGVLVPHVGEIPVAYRKVLLRKNGNKG",
+            "seq3": "AGGHSYGADLKSFDLGDELGTDPYEDFQENWNTKHSSGVTRELMRELNGGAYTRYVDNNFCGPDG",
+        },
+        alleles=[
+            "HLA-A*02:01",
+            "HLA-A*03:01",
+            "HLA-B*57:01",
+            "HLA-B*44:02",
+            "HLA-C*02:01",
+            "HLA-C*07:01",
+        ])
+    print(scan_results3)
+
+    assert len(scan_results3) > 5, len(scan_results3)
+    assert (scan_results3.presentation_score >= 0.9).all()
+
+    scan_results4 = PRESENTATION_PREDICTOR.predict_scan(
+        result="all",
+        comparison_quantity="affinity",
+        sequences={
+            "seq1": "MESLVPGFNEKTHVQLSLPVLQVRDVLVRGFGDSVEEVLSEARQHLKDGTCGLVEVEKGVLPQLE",
+            "seq2": "QPYVFIKRSDARTAPHGHVMVELVAELEGIQYGRSGETLGVLVPHVGEIPVAYRKVLLRKNGNKG",
+            "seq3": "AGGHSYGADLKSFDLGDELGTDPYEDFQENWNTKHSSGVTRELMRELNGGAYTRYVDNNFCGPDG",
+        },
+        alleles=[
+            "HLA-A*02:01",
+            "HLA-A*03:01",
+            "HLA-B*57:01",
+            "HLA-B*44:02",
+            "HLA-C*02:01",
+            "HLA-C*07:01",
+        ])
+    print(scan_results4)
+
+    assert len(scan_results4) > 200, len(scan_results4)
+    assert_less(scan_results4.iloc[0].affinity, 100)
diff --git a/test/test_class1_processing_predictor.py b/test/test_class1_processing_predictor.py
index b43df286..352b5d27 100644
--- a/test/test_class1_processing_predictor.py
+++ b/test/test_class1_processing_predictor.py
@@ -67,3 +67,4 @@ def test_basic():
         n_flanks=df.n_flank.values,
         c_flanks=df.c_flank.values)
     assert_array_equal(df.score.values, df3.score.values)
+
-- 
GitLab