From 4082613a54250adf8bc77033b780f09047eb1907 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Fri, 13 Mar 2020 12:08:37 -0400
Subject: [PATCH] add test/test_predict_scan_command.py

---
 mhcflurry/class1_presentation_predictor.py |  10 +-
 mhcflurry/predict_command.py               |  19 +++-
 mhcflurry/predict_scan_command.py          |  42 +++----
 test/data/example.fasta                    |  18 +--
 test/test_predict_command.py               |   4 +
 test/test_predict_scan_command.py          | 125 +++++++++++++++++++++
 6 files changed, 175 insertions(+), 43 deletions(-)
 create mode 100644 test/test_predict_scan_command.py

diff --git a/mhcflurry/class1_presentation_predictor.py b/mhcflurry/class1_presentation_predictor.py
index 124ec621..812ffae6 100644
--- a/mhcflurry/class1_presentation_predictor.py
+++ b/mhcflurry/class1_presentation_predictor.py
@@ -93,11 +93,13 @@ class Class1PresentationPredictor(object):
         Parameters
         ----------
         peptides : list of string
+        alleles : dict of string -> list of string
+            Keys are experiment names, values are the alleles (genotype) for
+            that sample
         experiment_names : list of string [same length as peptides]
             Sample names corresponding to each peptide. These are used to
-            lookup the alleles for each peptide in the alleles dict.
-        alleles : dict of string -> list of string
-            Keys are experiment names, values are the alleles for that sample
+            lookup the alleles for each peptide in the alleles dict. If not
+            specified, all experiment name/peptide combinations are used.
         include_affinity_percentile : bool
             Whether to include affinity percentile ranks
         verbose : int
@@ -585,6 +587,8 @@ class Class1PresentationPredictor(object):
             peptide, n_flank, c_flank, sequence_name, affinity, best_allele,
             processing_score, presentation_score
         """
+        if comparison_quantity is None:
+            comparison_quantity = "presentation_score"
 
         processing_predictor = self.processing_predictor_with_flanks
         if not use_flanks or processing_predictor is None:
diff --git a/mhcflurry/predict_command.py b/mhcflurry/predict_command.py
index cf23a02b..7a84c4ad 100644
--- a/mhcflurry/predict_command.py
+++ b/mhcflurry/predict_command.py
@@ -11,7 +11,7 @@ Examples:
 Write a CSV file containing the contents of INPUT.csv plus additional columns
 giving MHCflurry predictions:
 
-    $ mhcflurry-predict INPUT.csv --out RESULT.csv
+$ mhcflurry-predict INPUT.csv --out RESULT.csv
 
 The input CSV file is expected to contain columns "allele", "peptide", and,
 optionally, "n_flank", and "c_flank".
@@ -19,10 +19,23 @@ optionally, "n_flank", and "c_flank".
 If `--out` is not specified, results are written to stdout.
 
 You can also run on alleles and peptides specified on the commandline, in
-which case predictions are written for all combinations of alleles and
+which case predictions are written for *all combinations* of alleles and
 peptides:
 
-    $ mhcflurry-predict --alleles HLA-A0201 H-2Kb --peptides SIINFEKL DENDREKLLL
+$ mhcflurry-predict --alleles HLA-A0201 H-2Kb --peptides SIINFEKL DENDREKLLL
+
+Instead of individual alleles (in a CSV or on the command line), you can also
+pass a comma-separated list of alleles specifying a sample genotype. In this
+case, the tightest binding affinity across the alleles for the sample will be
+returned. For example:
+
+$ mhcflurry-predict --peptides SIINFEKL DENDREKLLL \
+    --alleles \
+        HLA-A*02:01,HLA-A*03:01,HLA-B*57:01,HLA-B*45:01,HLA-C*02:01,HLA-C*07:02 \
+        HLA-A*01:01,HLA-A*02:06,HLA-B*44:02,HLA-B*07:02,HLA-C*01:01,HLA-C*03:01
+
+will give the tightest predicted affinities across alleles for each of the two
+genotypes specified for each peptide.
 '''
 from __future__ import (
     print_function,
diff --git a/mhcflurry/predict_scan_command.py b/mhcflurry/predict_scan_command.py
index f535e6c4..b2a334ac 100644
--- a/mhcflurry/predict_scan_command.py
+++ b/mhcflurry/predict_scan_command.py
@@ -1,37 +1,39 @@
 '''
 Scan protein sequences using the MHCflurry presentation predictor.
 
-By default, subsequences with affinity percentile ranks less than 2.0 are
-returned. You can also specify --results-all to return predictions for all
-subsequences, or --results-best to return the top subsequence for each sequence.
+By default, sub-sequences (peptides) with affinity percentile ranks less than
+2.0 are returned. You can also specify --results-all to return predictions for
+all peptides, or --results-best to return the top peptide for each sequence.
 
 Examples:
 
 Scan a set of sequences in a FASTA file for binders to any alleles in a MHC I
 genotype:
 
-    mhcflurry-predict-scan \
-        test/data/example.fasta \
-        --alleles HLA-A*02:01,HLA-A*03:01,HLA-B*57:01,HLA-B*45:01,HLA-C*02:01,HLA-C*07:02
+$ mhcflurry-predict-scan \
+    test/data/example.fasta \
+    --alleles HLA-A*02:01,HLA-A*03:01,HLA-B*57:01,HLA-B*45:01,HLA-C*02:01,HLA-C*07:02
 
 Instead of a FASTA, you can also pass a CSV that has "sequence_id" and "sequence"
 columns.
 
-You can also specify multiple MHC I genotypes to scan:
+You can also specify multiple MHC I genotypes to scan as space-separated
+arguments to the --alleles option:
 
-    mhcflurry-predict-scan \
-        test/data/example.fasta \
-        --alleles \
-            HLA-A*02:01,HLA-A*03:01,HLA-B*57:01,HLA-B*45:01,HLA-C*02:01,HLA-C*07:02 \
-            HLA-A*01:01,HLA-A*02:06,HLA-B*68:01,HLA-B*07:02,HLA-C*01:01,HLA-C*03:01
+$ mhcflurry-predict-scan \
+    test/data/example.fasta \
+    --alleles \
+        HLA-A*02:01,HLA-A*03:01,HLA-B*57:01,HLA-B*45:01,HLA-C*02:02,HLA-C*07:02 \
+        HLA-A*01:01,HLA-A*02:06,HLA-B*44:02,HLA-B*07:02,HLA-C*01:02,HLA-C*03:01
 
 If `--out` is not specified, results are written to standard out.
 
-You can also run on sequences specified on the commandline:
+You can also specify sequences on the command line:
 
 mhcflurry-predict-scan \
     --sequences MGYINVFAFPFTIYSLLLCRMNSRNYIAQVDVVNFNLT \
-    --alleles HLA-A0201 H-2Kb
+    --alleles HLA-A*02:01,HLA-A*03:01,HLA-B*57:01,HLA-B*45:01,HLA-C*02:02,HLA-C*07:02
+
 '''
 from __future__ import (
     print_function,
@@ -48,7 +50,6 @@ import os
 import pandas
 
 from .downloads import get_default_class1_presentation_models_dir
-from .class1_affinity_predictor import Class1AffinityPredictor
 from .class1_presentation_predictor import Class1PresentationPredictor
 from .fasta import read_fasta_to_dataframe
 from .version import __version__
@@ -70,13 +71,13 @@ helper_args.add_argument(
     "--list-supported-alleles",
     action="store_true",
     default=False,
-    help="Prints the list of supported alleles and exits"
+    help="Print the list of supported alleles and exit"
 )
 helper_args.add_argument(
     "--list-supported-peptide-lengths",
     action="store_true",
     default=False,
-    help="Prints the list of supported peptide lengths and exits"
+    help="Print the list of supported peptide lengths and exit"
 )
 helper_args.add_argument(
     "--version",
@@ -87,7 +88,7 @@ helper_args.add_argument(
 input_args = parser.add_argument_group(title="Input options")
 input_args.add_argument(
     "input",
-    metavar="INPUT.csv",
+    metavar="INPUT",
     nargs="?",
     help="Input CSV or FASTA")
 input_args.add_argument(
@@ -139,7 +140,7 @@ results_args.add_argument(
     "--results-all",
     action="store_true",
     default=False,
-    help="")
+    help="Return results for all peptides regardless of affinity, etc.")
 results_args.add_argument(
     "--results-best",
     choices=comparison_quantities,
@@ -229,7 +230,8 @@ def run(argv=sys.argv[1:]):
             "--results-filtered")
 
     (result,) = [key for (key, value) in result_args.items() if value]
-    result_comparison_quantity = result_args[result]
+    result_comparison_quantity = (
+        None if result == "all" else result_args[result])
     result_filter_value = None if result != "filtered" else {
         "presentation_score": args.threshold_presentation_score,
         "processing_score": args.threshold_processing_score,
diff --git a/test/data/example.fasta b/test/data/example.fasta
index 56b0ab32..ef0ad8ca 100644
--- a/test/data/example.fasta
+++ b/test/data/example.fasta
@@ -1,23 +1,7 @@
->QHN73810.1 surface glycoprotein [Severe acute respiratory syndrome coronavirus 2]
+>QHN73810.1 surface glycoprotein [Severe acute respiratory syndrome coronavirus 2] prefix
 MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDLFLPFFSNVTWFHAIHV
 SGTNGTKRFDNPVLPFNDGVYFASTEKSNIIRGWIFGTTLDSKTQSLLIVNNATNVVIKVCEFQFCNDPF
 LGVYYHKNNKSWMESEFRVYSSANNCTFEYVSQPFLMDLEGKQGNFKNLREFVFKNIDGYFKIYSKHTPI
-NLVRDLPQGFSALEPLVDLPIGINITRFQTLLALHRSYLTPGDSSSGWTAGAAAYYVGYLQPRTFLLKYN
-ENGTITDAVDCALDPLSETKCTLKSFTVEKGIYQTSNFRVQPTESIVRFPNITNLCPFGEVFNATRFASV
-YAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIAD
-YNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYF
-PLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNFNFNGLTGTGVLTESNKKFL
-PFQQFGRDIADTTDAVRDPQTLEILDITPCSFGGVSVITPGTNTSNQVAVLYQDVNCTEVPVAIHADQLT
-PTWRVYSTGSNVFQTRAGCLIGAEHVNNSYECDIPIGAGICASYQTQTNSPRRARSVASQSIIAYTMSLG
-AENSVAYSNNSIAIPTNFTISVTTEILPVSMTKTSVDCTMYICGDSTECSNLLLQYGSFCTQLNRALTGI
-AVEQDKNTQEVFAQVKQIYKTPPIKDFGGFNFSQILPDPSKPSKRSFIEDLLFNKVTLADAGFIKQYGDC
-LGDIAARDLICAQKFNGLTVLPPLLTDEMIAQYTSALLAGTITSGWTFGAGAALQIPFAMQMAYRFNGIG
-VTQNVLYENQKLIANQFNSAIGKIQDSLSSTASALGKLQDVVNQNAQALNTLVKQLSSNFGAISSVLNDI
-LSRLDKVEAEVQIDRLITGRLQSLQTYVTQQLIRAAEIRASANLAATKMSECVLGQSKRVDFCGKGYHLM
-SFPQSAPHGVVFLHVTYVPAQEKNFTTAPAICHDGKAHFPREGVFVSNGTHWFVTQRNFYEPQIITTDNT
-FVSGNCDVVIGIVNNTVYDPLQPELDSFKEELDKYFKNHTSPDVDLGDISGINASVVNIQKEIDRLNEVA
-KNLNESLIDLQELGKYEQYIKWPWYIWLGFIAGLIAIVMVTIMLCCMTSCCSCLKGCCSCGSCCKFDEDD
-SEPVLKGVKLHYT
 >protein1
 MDSKGSSQKGSRLLLLLVVSNLLLCQGVVSTPVCPNGPGNCQV
 EMFNEFDKRYAQGKGFITMALNSCHTSSLPTPEDKEQAQQTHH
diff --git a/test/test_predict_command.py b/test/test_predict_command.py
index 8d73fa7a..b0327bf7 100644
--- a/test/test_predict_command.py
+++ b/test/test_predict_command.py
@@ -1,3 +1,7 @@
+import logging
+logging.getLogger('matplotlib').disabled = True
+logging.getLogger('tensorflow').disabled = True
+
 import tempfile
 import os
 
diff --git a/test/test_predict_scan_command.py b/test/test_predict_scan_command.py
new file mode 100644
index 00000000..a945245f
--- /dev/null
+++ b/test/test_predict_scan_command.py
@@ -0,0 +1,125 @@
+import logging
+logging.getLogger('matplotlib').disabled = True
+logging.getLogger('tensorflow').disabled = True
+
+import tempfile
+import os
+
+import pandas
+from numpy.testing import assert_equal, assert_array_less, assert_array_equal
+
+from mhcflurry import predict_scan_command
+
+
+from mhcflurry.testing_utils import cleanup, startup
+teardown = cleanup
+setup = startup
+
+from . import data_path
+
+
+def test_fasta():
+    args = [
+        data_path("example.fasta"),
+        "--alleles",
+        "HLA-A*02:01,HLA-A*03:01,HLA-B*57:01,HLA-B*45:01,HLA-C*02:02,HLA-C*07:02",
+    ]
+    deletes = []
+    try:
+        fd_out = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
+        deletes.append(fd_out.name)
+        full_args = args + ["--out", fd_out.name]
+        print("Running with args: %s" % full_args)
+        predict_scan_command.run(full_args)
+        result = pandas.read_csv(fd_out.name)
+        print(result)
+        assert not result.isnull().any().any()
+    finally:
+        for delete in deletes:
+            os.unlink(delete)
+
+    assert_equal(result.best_allele.nunique(), 6)
+    assert_equal(result.sequence_name.nunique(), 3)
+    assert_array_less(result.affinity_percentile, 2.0)
+
+
+def test_fasta_50nm():
+    args = [
+        data_path("example.fasta"),
+        "--alleles",
+        "HLA-A*02:01,HLA-A*03:01,HLA-B*57:01,HLA-B*45:01,HLA-C*02:02,HLA-C*07:02",
+        "--results-filtered", "affinity",
+        "--threshold-affinity", "50",
+    ]
+    deletes = []
+    try:
+        fd_out = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
+        deletes.append(fd_out.name)
+        full_args = args + ["--out", fd_out.name]
+        print("Running with args: %s" % full_args)
+        predict_scan_command.run(full_args)
+        result = pandas.read_csv(fd_out.name)
+        print(result)
+        assert not result.isnull().any().any()
+    finally:
+        for delete in deletes:
+            os.unlink(delete)
+
+    assert len(result) > 0
+    assert_array_less(result.affinity, 50)
+
+
+def test_fasta_best():
+    args = [
+        data_path("example.fasta"),
+        "--alleles",
+        "HLA-A*02:01,HLA-A*03:01,HLA-B*57:01,HLA-B*45:01,HLA-C*02:02,HLA-C*07:02",
+        "--results-best", "affinity_percentile",
+    ]
+    deletes = []
+    try:
+        fd_out = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
+        deletes.append(fd_out.name)
+        full_args = args + ["--out", fd_out.name]
+        print("Running with args: %s" % full_args)
+        predict_scan_command.run(full_args)
+        result = pandas.read_csv(fd_out.name)
+        print(result)
+        assert not result.isnull().any().any()
+    finally:
+        for delete in deletes:
+            os.unlink(delete)
+
+    assert len(result) > 0
+    assert_array_equal(
+        result.groupby(["sequence_name"]).peptide.count().values, 1)
+
+
+def test_commandline_sequences():
+    args = [
+        "--sequences", "ASDFGHKL", "QWERTYIPCVNM",
+        "--alleles", "HLA-A0201,HLA-A0301", "H-2-Kb",
+        "--peptide-lengths", "8",
+        "--results-all",
+    ]
+
+    deletes = []
+    try:
+        fd_out = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
+        deletes.append(fd_out.name)
+        full_args = args + ["--out", fd_out.name]
+        print("Running with args: %s" % full_args)
+        predict_scan_command.run(full_args)
+        result = pandas.read_csv(fd_out.name)
+        print(result)
+    finally:
+        for delete in deletes:
+            os.unlink(delete)
+
+    print(result)
+
+    assert_equal(result.sequence_name.nunique(), 2)
+    assert_equal(result.best_allele.nunique(), 3)
+    assert_equal(result.experiment_name.nunique(), 2)
+    assert_equal((result.peptide == "ASDFGHKL").sum(), 2)
+    assert_equal((result.peptide != "ASDFGHKL").sum(), 10)
\ No newline at end of file
-- 
GitLab