From 098f5e71454ae74f9cfa5fefc1d65452fd69a640 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Sun, 28 Jan 2018 13:08:43 -0500
Subject: [PATCH] redo parallelization implementation of percentile rank
 calibration

---
 mhcflurry/class1_affinity_predictor.py        | 68 ++++--------------
 .../train_allele_specific_models_command.py   | 70 +++++++++++++++----
 ...st_train_allele_specific_models_command.py |  3 +-
 3 files changed, 70 insertions(+), 71 deletions(-)

diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py
index 7f0d248f..1bbefd41 100644
--- a/mhcflurry/class1_affinity_predictor.py
+++ b/mhcflurry/class1_affinity_predictor.py
@@ -889,8 +889,7 @@ class Class1AffinityPredictor(object):
             peptides=None,
             num_peptides_per_length=int(1e5),
             alleles=None,
-            bins=None,
-            worker_pool=None):
+            bins=None):
         """
         Compute the cumulative distribution of ic50 values for a set of alleles
         over a large universe of random peptides, to enable computing quantiles in
@@ -898,7 +897,7 @@ class Class1AffinityPredictor(object):
 
         Parameters
         ----------
-        peptides : sequence of string, optional
+        peptides : sequence of string or EncodableSequences, optional
             Peptides to use
         num_peptides_per_length : int, optional
             If peptides argument is not specified, then num_peptides_per_length
@@ -911,8 +910,10 @@ class Class1AffinityPredictor(object):
             Anything that can be passed to numpy.histogram's "bins" argument
             can be used here, i.e. either an integer or a sequence giving bin
             edges. This is in ic50 space.
-        worker_pool : multiprocessing.Pool, optional
-            If specified multiple alleles will be calibrated in parallel
+
+        Returns
+        -------
+        EncodableSequences : peptides used for calibration
         """
         if bins is None:
             bins = to_ic50(numpy.linspace(1, 0, 1000))
@@ -931,57 +932,12 @@ class Class1AffinityPredictor(object):
 
         encoded_peptides = EncodableSequences.create(peptides)
 
-        if worker_pool and len(alleles) > 1:
-            # Run in parallel
-
-            # Performance hack.
-            self.neural_networks[0].peptides_to_network_input(encoded_peptides)
-
-            do_work = partial(
-                _calibrate_percentile_ranks,
-                predictor=self,
-                peptides=encoded_peptides,
-                bins=bins)
-            list_of_singleton_alleles = [ [allele] for allele in alleles ]
-            results = worker_pool.imap_unordered(
-                do_work, list_of_singleton_alleles, chunksize=1)
-
-            # Add progress bar
-            results = tqdm.tqdm(results, ascii=True, total=len(alleles))
+        for allele in alleles:
+            predictions = self.predict(encoded_peptides, allele=allele)
+            transform = PercentRankTransform()
+            transform.fit(predictions, bins=bins)
+            self.allele_to_percent_rank_transform[allele] = transform
 
-            # Merge results
-            for partial_dict in results:
-                self.allele_to_percent_rank_transform.update(partial_dict)
-        else:
-            # Run in serial
-            self.allele_to_percent_rank_transform.update(
-                _calibrate_percentile_ranks(
-                    alleles=alleles,
-                    predictor=self,
-                    peptides=encoded_peptides,
-                    bins=bins))
-
-
-def _calibrate_percentile_ranks(alleles, predictor, peptides, bins):
-    """
-    Private helper function.
-
-    Parameters
-    ----------
-    alleles : list of string
-    predictor : Class1AffinityPredictor
-    peptides : list of string or EncodableSequences
-    bins : object
+        return encoded_peptides
 
-    Returns
-    -------
-    dict : allele -> percentile rank transform
 
-    """
-    result = {}
-    for (i, allele) in enumerate(alleles):
-        predictions = predictor.predict(peptides, allele=allele)
-        transform = PercentRankTransform()
-        transform.fit(predictions, bins=bins)
-        result[allele] = transform
-    return result
diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index 583fe013..e60d6a87 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -8,6 +8,7 @@ import sys
 import time
 import traceback
 from multiprocessing import Pool
+from functools import partial
 
 import pandas
 import yaml
@@ -17,6 +18,9 @@ import tqdm  # progress bar
 from .class1_affinity_predictor import Class1AffinityPredictor
 from .common import configure_logging, set_keras_backend
 
+GLOBAL_DATA = {}
+
+
 parser = argparse.ArgumentParser(usage=__doc__)
 
 parser.add_argument(
@@ -106,6 +110,8 @@ parser.add_argument(
     help="Keras backend. If not specified will use system default.")
 
 def run(argv=sys.argv[1:]):
+    global GLOBAL_DATA
+
     # On sigusr1 print stack trace
     print("To show stack trace, run:\nkill -s USR1 %d" % os.getpid())
     signal.signal(signal.SIGUSR1, lambda sig, frame: traceback.print_stack())
@@ -210,7 +216,7 @@ def run(argv=sys.argv[1:]):
             predictors = list(
                 tqdm.tqdm(
                     worker_pool.imap_unordered(
-                        work_entrypoint, work_items, chunksize=1),
+                        train_model_entrypoint, work_items, chunksize=1),
                     ascii=True,
                     total=len(work_items)))
 
@@ -225,7 +231,7 @@ def run(argv=sys.argv[1:]):
             start = time.time()
             for _ in tqdm.trange(len(work_items)):
                 item = work_items.pop(0)  # want to keep freeing up memory
-                work_predictor = work_entrypoint(item)
+                work_predictor = train_model_entrypoint(item)
                 assert work_predictor is predictor
             assert not work_items
 
@@ -240,24 +246,46 @@ def run(argv=sys.argv[1:]):
         worker_pool.join()
 
     if args.percent_rank_calibration_num_peptides_per_length > 0:
+        alleles = list(predictor.supported_alleles)
+        first_allele = alleles.pop(0)
+
+        print("Performing percent rank calibration. Calibrating first allele.")
+        start = time.time()
+        encoded_peptides = predictor.calibrate_percentile_ranks(
+            alleles=[first_allele],
+            num_peptides_per_length=args.percent_rank_calibration_num_peptides_per_length)
+        percent_rank_calibration_time = time.time() - start
+        print("Finished calibrating percent ranks for first allele in %0.2f sec." % (
+            percent_rank_calibration_time))
+        print("Calibrating %d additional alleles." % len(alleles))
+
         if args.calibration_num_jobs == 1:
             # Serial run
             worker_pool = None
+            results = (
+                calibrate_percentile_ranks(
+                    allele=allele,
+                    predictor=predictor,
+                    peptides=encoded_peptides)
+                for allele in alleles)
         else:
+            # Parallel run
+            # Store peptides in global variable so they are in shared memory
+            # after fork, instead of needing to be pickled.
+            GLOBAL_DATA["calibration_peptides"] = encoded_peptides
             worker_pool = Pool(
                 processes=(
                     args.calibration_num_jobs
                     if args.train_num_jobs else None))
             print("Using worker pool: %s" % str(worker_pool))
-
-        print("Performing percent rank calibration.")
-        start = time.time()
-        predictor.calibrate_percentile_ranks(
-            num_peptides_per_length=args.percent_rank_calibration_num_peptides_per_length,
-            worker_pool=worker_pool)
-        percent_rank_calibration_time = time.time() - start
-        print("Finished calibrating percent ranks in %0.2f sec." % (
-            percent_rank_calibration_time))
+            results = worker_pool.imap_unordered(
+                partial(
+                    calibrate_percentile_ranks,
+                    predictor=predictor), alleles, chunksize=1)
+
+        for result in tqdm.tqdm(results, ascii=True, total=len(alleles)):
+            predictor.allele_to_percent_rank_transform.update(result)
+        print("Done calibrating %d additional alleles." % len(alleles))
         predictor.save(args.out_models_dir, model_names_to_write=[])
 
     if worker_pool:
@@ -269,11 +297,11 @@ def run(argv=sys.argv[1:]):
     print("Predictor written to: %s" % args.out_models_dir)
 
 
-def work_entrypoint(item):
-    return process_work(**item)
+def train_model_entrypoint(item):
+    return train_model(**item)
 
 
-def process_work(
+def train_model(
         model_group,
         n_models,
         allele_num,
@@ -325,5 +353,19 @@ def process_work(
     return predictor
 
 
+def calibrate_percentile_ranks(allele, predictor, peptides=None):
+    """
+    Worker entrypoint: calibrate percent ranks for one allele; returns {allele: transform}.
+    """
+    if peptides is None:
+        peptides = GLOBAL_DATA["calibration_peptides"]
+    predictor.calibrate_percentile_ranks(
+        peptides=peptides,
+        alleles=[allele])
+    return {
+        allele: predictor.allele_to_percent_rank_transform[allele],
+    }
+
+
 if __name__ == '__main__':
     run()
diff --git a/test/test_train_allele_specific_models_command.py b/test/test_train_allele_specific_models_command.py
index db6ef9a8..537aac7a 100644
--- a/test/test_train_allele_specific_models_command.py
+++ b/test/test_train_allele_specific_models_command.py
@@ -62,7 +62,8 @@ def run_and_check(n_jobs=0):
         "--allele", "HLA-A*02:01", "HLA-A*01:01", "HLA-A*03:01",
         "--out-models-dir", models_dir,
         "--percent-rank-calibration-num-peptides-per-length", "10000",
-        "--parallelization-num-jobs", str(n_jobs),
+        "--train-num-jobs", str(n_jobs),
+        "--calibration-num-jobs", str(n_jobs),
         "--ignore-inequalities",
     ]
     print("Running with args: %s" % args)
-- 
GitLab