From 32bec557799b519660bb693331eba4d276622b9b Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Wed, 19 Jun 2019 12:26:50 -0400
Subject: [PATCH] Add --worker-log-dir option for per-worker log files

---
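Notes: this patch adds a --worker-log-dir option so that each worker process
redirects its stdout/stderr to its own log file in the given directory. A
minimal usage sketch follows; the work function and inputs are illustrative,
and it assumes worker_pool_with_gpu_assignments returns a standard
multiprocessing Pool, as the call to make_worker_pool suggests:

    from mhcflurry.parallelism import worker_pool_with_gpu_assignments

    # Hypothetical example: four CPU-only workers, each writing its own
    # LOG-worker.<pid>.txt into /tmp/worker-logs (directory must already exist).
    pool = worker_pool_with_gpu_assignments(
        num_jobs=4,
        num_gpus=0,
        worker_log_dir="/tmp/worker-logs")
    results = pool.map(my_task_function, work_items)  # illustrative names
    pool.close()
    pool.join()
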
 mhcflurry/parallelism.py | 28 +++++++++++++++++++++-------
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/mhcflurry/parallelism.py b/mhcflurry/parallelism.py
index 55d7c450..f1bd1493 100644
--- a/mhcflurry/parallelism.py
+++ b/mhcflurry/parallelism.py
@@ -1,6 +1,7 @@
 import traceback
 import sys
 import os
+import functools
 from multiprocessing import Pool, Queue, cpu_count
 from six.moves import queue
 from multiprocessing.util import Finalize
@@ -46,6 +47,10 @@ def add_worker_pool_args(parser):
         default=None,
         help="Restart workers after N tasks. Workaround for tensorflow memory "
              "leaks. Requires Python >=3.2.")
+    group.add_argument(
+        "--worker-log-dir",
+        default=None,
+        help="Write worker stdout and stderr logs to given directory.")
 
 
 def worker_pool_with_gpu_assignments_from_args(args):
@@ -54,7 +59,8 @@ def worker_pool_with_gpu_assignments_from_args(args):
         num_gpus=args.gpus,
         backend=args.backend,
         max_workers_per_gpu=args.max_workers_per_gpu,
-        max_tasks_per_worker=args.max_tasks_per_worker
+        max_tasks_per_worker=args.max_tasks_per_worker,
+        worker_log_dir=args.worker_log_dir,
     )
 
 
@@ -63,7 +69,8 @@ def worker_pool_with_gpu_assignments(
         num_gpus=0,
         backend=None,
         max_workers_per_gpu=1,
-        max_tasks_per_worker=None):
+        max_tasks_per_worker=None,
+        worker_log_dir=None):
 
     num_workers = num_jobs if num_jobs else cpu_count()
 
@@ -72,7 +79,7 @@ def worker_pool_with_gpu_assignments(
             set_keras_backend(backend)
         return None
 
-    worker_init_kwargs = None
+    worker_init_kwargs = [{} for _ in range(num_workers)]
     if num_gpus:
         print("Attempting to round-robin assign each worker a GPU.")
         if backend != "tensorflow-default":
@@ -82,8 +89,7 @@ def worker_pool_with_gpu_assignments(
         gpu_assignments_remaining = dict((
             (gpu, max_workers_per_gpu) for gpu in range(num_gpus)
         ))
-        worker_init_kwargs = []
-        for worker_num in range(num_workers):
+        for (worker_num, kwargs) in enumerate(worker_init_kwargs):
             if gpu_assignments_remaining:
                 # Use a GPU
                 gpu_num = sorted(
@@ -97,13 +103,17 @@ def worker_pool_with_gpu_assignments(
                 # Use CPU
                 gpu_assignment = []
 
-            worker_init_kwargs.append({
+            kwargs.update({
                 'gpu_device_nums': gpu_assignment,
                 'keras_backend': backend
             })
             print("Worker %d assigned GPUs: %s" % (
                 worker_num, gpu_assignment))
 
+    if worker_log_dir:
+        for kwargs in worker_init_kwargs:
+            kwargs["worker_log_dir"] = worker_log_dir
+
     worker_pool = make_worker_pool(
         processes=num_workers,
         initializer=worker_init,
@@ -208,7 +218,11 @@ def worker_init_entry_point(
     init_function(**kwargs)
 
 
-def worker_init(keras_backend=None, gpu_device_nums=None):
+def worker_init(keras_backend=None, gpu_device_nums=None, worker_log_dir=None):
+    if worker_log_dir:
+        sys.stderr = sys.stdout = open(
+            os.path.join(worker_log_dir, "LOG-worker.%d.txt" % os.getpid()), "w")
+
     # Each worker needs distinct random numbers
     numpy.random.seed()
     random.seed()
-- 
GitLab