From e4b4dbb04ec7b9546c46f4fab77fa2f6783506cf Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Tue, 4 Feb 2020 11:40:58 -0500
Subject: [PATCH] calibrate percentile ranks for all alleles

---
 .../models_class1_pan/GENERATE.sh             | 17 +++--
 .../calibrate_percentile_ranks_command.py     | 66 +++++++++++++++----
 mhcflurry/class1_affinity_predictor.py        | 14 ++++
 3 files changed, 74 insertions(+), 23 deletions(-)

diff --git a/downloads-generation/models_class1_pan/GENERATE.sh b/downloads-generation/models_class1_pan/GENERATE.sh
index 87261dc5..112c03b7 100755
--- a/downloads-generation/models_class1_pan/GENERATE.sh
+++ b/downloads-generation/models_class1_pan/GENERATE.sh
@@ -99,11 +99,10 @@ for kind in combined
 do
     MODELS_DIR="models.unselected.${kind}"
 
-    # For now we calibrate percentile ranks only for alleles for which there
-    # is training data. Calibrating all alleles would be too slow.
-    # This could be improved though.
-    ALLELE_LIST=$(bzcat "$MODELS_DIR/train_data.csv.bz2" | cut -f 1 -d , | grep -v allele | uniq | sort | uniq)
-    ALLELE_LIST+=$(echo " " $(cat additional_alleles.txt | grep -v '#') )
+    # Older method calibrated only particular alleles. We are now calibrating
+    # all alleles, so this is commented out.
+    #ALLELE_LIST=$(bzcat "$MODELS_DIR/train_data.csv.bz2" | cut -f 1 -d , | grep -v allele | uniq | sort | uniq)
+    #ALLELE_LIST+=$(echo " " $(cat additional_alleles.txt | grep -v '#') )
 
     mhcflurry-class1-select-pan-allele-models \
         --data "$MODELS_DIR/train_data.csv.bz2" \
@@ -114,17 +113,17 @@ do
         $PARALLELISM_ARGS
     cp "$MODELS_DIR/train_data.csv.bz2" "models.${kind}/train_data.csv.bz2"
 
-    # For now we calibrate percentile ranks only for alleles for which there
-    # is training data. Calibrating all alleles would be too slow.
-    # This could be improved though.
+    # We are now calibrating all alleles.
+    # Previously had argument:  --allele $ALLELE_LIST \
     time mhcflurry-calibrate-percentile-ranks \
         --models-dir models.${kind} \
         --match-amino-acid-distribution-data "$MODELS_DIR/train_data.csv.bz2" \
         --motif-summary \
         --num-peptides-per-length 100000 \
-        --allele $ALLELE_LIST \
+        --alleles-per-work-chunk 10 \
         --verbosity 1 \
         $PARALLELISM_ARGS
+
 done
 
 cp $SCRIPT_ABSOLUTE_PATH .
diff --git a/mhcflurry/calibrate_percentile_ranks_command.py b/mhcflurry/calibrate_percentile_ranks_command.py
index c11dd4b1..c28df1ff 100644
--- a/mhcflurry/calibrate_percentile_ranks_command.py
+++ b/mhcflurry/calibrate_percentile_ranks_command.py
@@ -85,6 +85,12 @@ parser.add_argument(
     type=int,
     default=4096,
     help="Keras batch size for predictions")
+parser.add_argument(
+    "--alleles-per-work-chunk",
+    type=int,
+    metavar="N",
+    default=1,
+    help="Number of alleles per work chunk. Default: %(default)s.")
 parser.add_argument(
     "--verbosity",
     type=int,
@@ -120,7 +126,26 @@ def run(argv=sys.argv[1:]):
     else:
         alleles = predictor.supported_alleles
 
-    alleles = sorted(set(alleles))
+    allele_set = set(alleles)
+
+    if predictor.allele_to_sequence:
+        # Remove alleles that have the same sequence.
+        new_allele_set = set()
+        sequence_to_allele = collections.defaultdict(set)
+        for allele in list(allele_set):
+            sequence_to_allele[predictor.allele_to_sequence[allele]].add(allele)
+        for equivalent_alleles in sequence_to_allele.values():
+            equivalent_alleles = sorted(equivalent_alleles)
+            keep = equivalent_alleles.pop(0)
+            new_allele_set.add(keep)
+        print(
+            "Sequence comparison reduced num alleles from",
+            len(allele_set),
+            "to",
+            len(new_allele_set))
+        allele_set = new_allele_set
+
+    alleles = sorted(allele_set)
 
     distribution = None
     if args.match_amino_acid_distribution_data:
@@ -171,7 +196,14 @@ def run(argv=sys.argv[1:]):
     serial_run = not args.cluster_parallelism and args.num_jobs == 0
     worker_pool = None
     start = time.time()
-    work_items = [{"allele": allele} for allele in alleles]
+
+    work_items = []
+    for allele in alleles:
+        if not work_items or len(
+                work_items[-1]['alleles']) >= args.alleles_per_work_chunk:
+            work_items.append({"alleles": []})
+        work_items[-1]['alleles'].append(allele)
+
     if serial_run:
         # Serial run
         print("Running in serial.")
@@ -197,12 +229,13 @@ def run(argv=sys.argv[1:]):
             chunksize=1)
 
     summary_results_lists = collections.defaultdict(list)
-    for (transforms, summary_results) in tqdm.tqdm(results, total=len(work_items)):
-        predictor.allele_to_percent_rank_transform.update(transforms)
-        if summary_results is not None:
-            for (item, value) in summary_results.items():
-                summary_results_lists[item].append(value)
-    print("Done calibrating %d alleles." % len(work_items))
+    for work_item in tqdm.tqdm(results, total=len(work_items)):
+        for (transforms, summary_results) in work_item:
+            predictor.allele_to_percent_rank_transform.update(transforms)
+            if summary_results is not None:
+                for (item, value) in summary_results.items():
+                    summary_results_lists[item].append(value)
+    print("Done calibrating %d alleles." % len(alleles))
     if summary_results_lists:
         for (name, lst) in summary_results_lists.items():
             df = pandas.concat(lst, ignore_index=True)
@@ -223,12 +256,17 @@ def run(argv=sys.argv[1:]):
     print("Predictor written to: %s" % args.models_dir)
 
 
-def do_calibrate_percentile_ranks(allele, constant_data=GLOBAL_DATA):
-    return calibrate_percentile_ranks(
-        allele,
-        constant_data['predictor'],
-        peptides=constant_data['calibration_peptides'],
-        **constant_data["args"])
+def do_calibrate_percentile_ranks(alleles, constant_data=GLOBAL_DATA):
+    result_list = []
+    for (i, allele) in enumerate(alleles):
+        print("Processing allele", i + 1, "of", len(alleles))
+        result_item = calibrate_percentile_ranks(
+            allele,
+            constant_data['predictor'],
+            peptides=constant_data['calibration_peptides'],
+            **constant_data["args"])
+        result_list.append(result_item)
+    return result_list
 
 
 def calibrate_percentile_ranks(
diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py
index 9a3258cf..2b08a6ec 100644
--- a/mhcflurry/class1_affinity_predictor.py
+++ b/mhcflurry/class1_affinity_predictor.py
@@ -897,6 +897,20 @@ class Class1AffinityPredictor(object):
                 transform = self.allele_to_percent_rank_transform[normalized_allele]
                 return transform.transform(affinities)
             except KeyError:
+                if self.allele_to_sequence:
+                    # See if we have information for an equivalent allele
+                    sequence = self.allele_to_sequence[normalized_allele]
+                    other_alleles = [
+                        other_allele for (other_allele, other_sequence)
+                        in self.allele_to_sequence.items()
+                        if other_sequence == sequence
+                    ]
+                    for other_allele in other_alleles:
+                        if other_allele in self.allele_to_percent_rank_transform:
+                            transform = self.allele_to_percent_rank_transform[
+                                other_allele]
+                            return transform.transform(affinities)
+
                 msg = "Allele %s has no percentile rank information" % (
                     allele + (
                         "" if allele == normalized_allele
-- 
GitLab