From cc1543cfc9d4d9c2ea54b608399cc20bc130fb39 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Sun, 11 Feb 2018 12:39:00 -0500
Subject: [PATCH] Restart workers every N tasks

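Add a --max-tasks-per-worker option (default: no limit) that recycles
each worker process after it has run N tasks, as a workaround for
tensorflow memory leaks. Pool construction is factored into a
make_worker_pool() helper, which forwards the setting to
multiprocessing.Pool's maxtasksperchild and hands per-worker init
kwargs to (re)started workers through a shared Queue, so replacement
workers pick up GPU assignments round-robin. The unselected class1
GENERATE.sh now passes --max-tasks-per-worker 20. Also renames
train_model_entrypoint to train_model_entry_point and replaces
worker_init_entrypoint with a generic worker_init_entry_point.

Illustrative sketch of the restart mechanism (not part of the diff
below; the pool size, task counts, and init function are made up for
the example, and only the gpu_device_nums key mirrors the real code):

    from multiprocessing import Pool, Queue, current_process

    def init(arg_queue):
        # Take one kwargs dict and put it back so the next (re)started
        # worker picks it up, as worker_init_entry_point does below.
        kwargs = arg_queue.get()
        arg_queue.put(kwargs)
        print("worker %s init kwargs: %s" % (
            current_process().name, kwargs))

    if __name__ == "__main__":
        queue = Queue()
        for gpu in range(2):
            queue.put({"gpu_device_nums": [gpu]})
        with Pool(processes=2, maxtasksperchild=5,
                  initializer=init, initargs=(queue,)) as pool:
            # Workers are recycled every 5 tasks; each replacement
            # re-runs init() and pulls kwargs from the shared queue.
            pool.map(abs, range(100))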
---
 .../models_class1_unselected/GENERATE.sh      |   2 +-
 .../train_allele_specific_models_command.py   | 118 ++++++++++++++----
 2 files changed, 94 insertions(+), 26 deletions(-)

diff --git a/downloads-generation/models_class1_unselected/GENERATE.sh b/downloads-generation/models_class1_unselected/GENERATE.sh
index f04e23dd..a3cdb5c0 100755
--- a/downloads-generation/models_class1_unselected/GENERATE.sh
+++ b/downloads-generation/models_class1_unselected/GENERATE.sh
@@ -43,7 +43,7 @@ time mhcflurry-class1-train-allele-specific-models \
     --out-models-dir models \
     --percent-rank-calibration-num-peptides-per-length 0 \
     --min-measurements-per-allele 75 \
-    --num-jobs $(expr $PROCESSORS \* 2) --gpus $GPUS --max-workers-per-gpu 2
+    --num-jobs $(expr $PROCESSORS \* 2) --gpus $GPUS --max-workers-per-gpu 2 --max-tasks-per-worker 20
 
 cp $SCRIPT_ABSOLUTE_PATH .
 bzip2 LOG.txt
diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index 91dcedfc..1f7395f8 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -29,7 +29,6 @@ from .common import configure_logging, set_keras_backend
 # via shared memory.
 GLOBAL_DATA = {}
 
-
 # Note on parallelization:
 # It seems essential currently (tensorflow==1.4.1) that no processes are forked
 # after tensorflow has been used at all, which includes merely importing
@@ -132,6 +131,81 @@ parser.add_argument(
     default=60,
     help="Write models to disk every N seconds. Only affects parallel runs; "
     "serial runs write each model to disk as it is trained.")
+parser.add_argument(
+    "--max-tasks-per-worker",
+    type=int,
+    metavar="N",
+    default=None,
+    help="Restart workers after N tasks. Workaround for tensorflow memory "
+    "leaks. Requires Python >=3.2.")
+
+
+def make_worker_pool(
+        processes=None,
+        initializer=None,
+        initializer_kwargs_per_process=None,
+        max_tasks_per_worker=None):
+    """
+    Convenience wrapper to create a multiprocessing.Pool.
+
+    Parameters
+    ----------
+    processes : int
+        Number of workers. Default: num CPUs.
+
+    initializer : function, optional
+        Init function to call in each worker.
+
+    initializer_kwargs_per_process : list of dict, optional
+        Arguments to pass to initializer function for each worker. Length of
+        list must equal the number of workers.
+
+    max_tasks_per_worker : int, optional
+        Restart workers after this many tasks. Requires Python >=3.2.
+
+    Returns
+    -------
+    multiprocessing.Pool
+    """
+
+    if not processes:
+        processes = cpu_count()
+
+    pool_kwargs = {
+        'processes': processes,
+    }
+    if max_tasks_per_worker:
+        pool_kwargs["maxtasksperchild"] = max_tasks_per_worker
+
+    if initializer:
+        if initializer_kwargs_per_process:
+            assert len(initializer_kwargs_per_process) == processes
+            kwargs_queue = Queue()
+            for kwargs in initializer_kwargs_per_process:
+                kwargs_queue.put(kwargs)
+            pool_kwargs["initializer"] = worker_init_entry_point
+            pool_kwargs["initargs"] = (initializer, kwargs_queue)
+        else:
+            pool_kwargs["initializer"] = initializer
+
+    worker_pool = Pool(**pool_kwargs)
+    print("Started pool: %s" % str(worker_pool))
+    pprint(pool_kwargs)
+    return worker_pool
+
+
+def worker_init_entry_point(init_function, arg_queue=None):
+    if arg_queue:
+        kwargs = arg_queue.get()
+
+        # We add the init args back to the queue so that restarted workers
+        # (e.g. when running with maxtasksperchild) will pick up init
+        # arguments in a round-robin style.
+        arg_queue.put(kwargs)
+    else:
+        kwargs = {}
+    init_function(**kwargs)
+
 
 def run(argv=sys.argv[1:]):
     global GLOBAL_DATA
@@ -193,8 +267,7 @@ def run(argv=sys.argv[1:]):
     else:
         # Parallel run.
         num_workers = args.num_jobs[0] if args.num_jobs[0] else cpu_count()
-
-        worker_init_args = None
+        worker_init_kwargs = None
         if args.gpus:
             print("Attempting to round-robin assign each worker a GPU.")
             if args.backend != "tensorflow-default":
@@ -204,7 +277,7 @@ def run(argv=sys.argv[1:]):
             gpu_assignments_remaining = dict((
                 (gpu, args.max_workers_per_gpu) for gpu in range(args.gpus)
             ))
-            worker_init_args = Queue()
+            worker_init_kwargs = []
             for worker_num in range(num_workers):
                 if gpu_assignments_remaining:
                     # Use a GPU
@@ -219,17 +292,18 @@ def run(argv=sys.argv[1:]):
                     # Use CPU
                     gpu_assignment = []
 
-                worker_init_args.put({
+                worker_init_kwargs.append({
                     'gpu_device_nums': gpu_assignment,
                     'keras_backend': args.backend
                 })
-                print("Worker %d assigned GPUs: %s" % (worker_num, gpu_assignment))
+                print("Worker %d assigned GPUs: %s" % (
+                    worker_num, gpu_assignment))
 
-        worker_pool = Pool(
-            initializer=worker_init_entrypoint,
-            initargs=(worker_init_args,),
-            processes=num_workers)
-        print("Started pool of %d workers: %s" % (num_workers, str(worker_pool)))
+        worker_pool = make_worker_pool(
+            processes=num_workers,
+            initializer=worker_init,
+            initializer_kwargs_per_process=worker_init_kwargs,
+            max_tasks_per_worker=args.max_tasks_per_worker)
 
     if not os.path.exists(args.out_models_dir):
         print("Attempting to create directory: %s" % args.out_models_dir)
@@ -275,7 +349,7 @@ def run(argv=sys.argv[1:]):
         random.shuffle(work_items)
 
         results_generator = worker_pool.imap_unordered(
-            train_model_entrypoint, work_items, chunksize=1)
+            train_model_entry_point, work_items, chunksize=1)
 
         unsaved_predictors = []
         last_save_time = time.time()
@@ -309,7 +383,7 @@ def run(argv=sys.argv[1:]):
         # as it goes so no saving is required at the end.
         for _ in tqdm.trange(len(work_items)):
             item = work_items.pop(0)  # want to keep freeing up memory
-            work_predictor = train_model_entrypoint(item)
+            work_predictor = train_model_entry_point(item)
             assert work_predictor is predictor
         assert not work_items
 
@@ -357,11 +431,13 @@ def run(argv=sys.argv[1:]):
             # Store peptides in global variable so they are in shared memory
             # after fork, instead of needing to be pickled.
             GLOBAL_DATA["calibration_peptides"] = encoded_peptides
-            worker_pool = Pool(
+
+            worker_pool = make_worker_pool(
                 processes=(
                     args.num_jobs[-1]
-                    if args.num_jobs[-1] else None))
-            print("Using worker pool: %s" % str(worker_pool))
+                    if args.num_jobs[-1] else None),
+                max_tasks_per_worker=args.max_tasks_per_worker)
+
             results = worker_pool.imap_unordered(
                 partial(
                     calibrate_percentile_ranks,
@@ -386,7 +462,7 @@ def run(argv=sys.argv[1:]):
     print("Predictor written to: %s" % args.out_models_dir)
 
 
-def train_model_entrypoint(item):
+def train_model_entry_point(item):
     return train_model(**item)
 
 
@@ -463,14 +539,6 @@ def calibrate_percentile_ranks(allele, predictor, peptides=None):
     }
 
 
-def worker_init_entrypoint(arg_queue):
-    if arg_queue:
-        kwargs = arg_queue.get()
-    else:
-        kwargs = {}
-    worker_init(**kwargs)
-
-
 def worker_init(keras_backend=None, gpu_device_nums=None):
     if keras_backend or gpu_device_nums:
         print("WORKER pid=%d assigned GPU devices: %s" % (
-- 
GitLab