Newer
Older
from multiprocessing import Pool, Queue, cpu_count
from multiprocessing.util import Finalize
from pprint import pprint
import random
import numpy
from .common import set_keras_backend
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def make_worker_pool(
        processes=None,
        initializer=None,
        initializer_kwargs_per_process=None,
        max_tasks_per_worker=None):
    """
    Build a multiprocessing.Pool whose workers can each receive their own
    initializer keyword arguments.

    multiprocessing has no native notion of per-worker initializer arguments;
    this helper adds one, mainly so each worker can be pinned to a (different)
    GPU.

    IMPLEMENTATION NOTE:

    Per-worker arguments travel through a Queue. A worker takes one kwargs
    dict from the queue when it starts and puts it back when it exits, so a
    replacement worker can pick the same arguments up again.

    A worker that crashes never returns its arguments, which would starve
    later workers. A second "backup" queue covers that case: it always holds
    the full set of kwargs, because every read from it is immediately followed
    by pushing the same item back onto its end. Workers fall back to the
    backup queue whenever the primary queue is empty.

    Parameters
    ----------
    processes : int
        Number of workers. Default: num CPUs.

    initializer : function, optional
        Init function to call in each worker

    initializer_kwargs_per_process : list of dict, optional
        Arguments to pass to initializer function for each worker. Length of
        list must equal the number of workers. Only consulted when
        `initializer` is given.

    max_tasks_per_worker : int, optional
        Restart workers after this many tasks. Requires Python >=3.2.

    Returns
    -------
    multiprocessing.Pool
    """
    processes = processes or cpu_count()

    pool_kwargs = {'processes': processes}
    if max_tasks_per_worker:
        pool_kwargs["maxtasksperchild"] = max_tasks_per_worker

    if initializer and initializer_kwargs_per_process:
        assert len(initializer_kwargs_per_process) == processes

        # Same kwargs go into both queues: the primary one hands them out,
        # the backup one is the round-robin fallback.
        primary_queue = Queue()
        fallback_queue = Queue()
        for worker_kwargs in initializer_kwargs_per_process:
            primary_queue.put(worker_kwargs)
            fallback_queue.put(worker_kwargs)

        pool_kwargs["initializer"] = worker_init_entry_point
        pool_kwargs["initargs"] = (
            initializer, primary_queue, fallback_queue)
    elif initializer:
        # No per-worker arguments: let the pool call the initializer directly.
        pool_kwargs["initializer"] = initializer

    worker_pool = Pool(**pool_kwargs)
    print("Started pool: %s" % str(worker_pool))
    pprint(pool_kwargs)
    return worker_pool
def worker_init_entry_point(
        init_function, arg_queue=None, backup_arg_queue=None):
    """
    Pool initializer that selects this worker's keyword arguments and then
    runs the real initializer with them.

    Parameters
    ----------
    init_function : function
        Called as ``init_function(**kwargs)`` once the kwargs are chosen.

    arg_queue : queue, optional
        Primary queue of kwargs dicts, read without blocking. If not given,
        `init_function` is called with no arguments.

    backup_arg_queue : queue, optional
        Fallback queue used only when `arg_queue` is empty. Must be provided
        if `arg_queue` can be empty.
    """
    # Function-scope import so this block does not depend on a module-level
    # import of the queue module.
    from queue import Empty

    kwargs = {}
    if arg_queue:
        try:
            kwargs = arg_queue.get(block=False)
        except Empty:
            print("Argument queue empty. Using round robin arg queue.")
            kwargs = backup_arg_queue.get(block=True)
            # Put the item straight back so the backup queue keeps cycling
            # through the full set of arguments.
            backup_arg_queue.put(kwargs)

        # On exit we add the init args back to the queue so restarted workers
        # (e.g. when running with maxtasksperchild) will pick up init
        # arguments from a previously exited worker.
        Finalize(None, arg_queue.put, (kwargs,), exitpriority=1)

    print("Initializing worker: %s" % str(kwargs))
    init_function(**kwargs)
def worker_init(keras_backend=None, gpu_device_nums=None):
    """
    Per-worker setup: reseed the random number generators and, if requested,
    configure the keras backend / GPU devices.

    Parameters
    ----------
    keras_backend : optional
        Forwarded to set_keras_backend.

    gpu_device_nums : optional
        Forwarded to set_keras_backend as `gpu_device_nums`.
    """
    # Each worker needs distinct random numbers
    numpy.random.seed()
    random.seed()

    backend_requested = keras_backend or gpu_device_nums
    if not backend_requested:
        return

    message = "WORKER pid=%d assigned GPU devices: %s" % (
        os.getpid(), gpu_device_nums)
    print(message)
    set_keras_backend(keras_backend, gpu_device_nums=gpu_device_nums)
# Solution suggested in https://bugs.python.org/issue13831
class WrapException(Exception):
    """
    Carries the formatted traceback of the exception being handled at
    construction time, so an exception raised in a worker process can still
    show its traceback when re-raised in the parent.

    Attributes
    ----------
    exception : the exception value reported by sys.exc_info()
    formatted : string, the joined output of traceback.format_exception
    """

    def __init__(self):
        # Must be constructed inside an `except` block: the details come from
        # whatever exception is currently being handled.
        current = sys.exc_info()
        self.exception = current[1]
        rendered_lines = traceback.format_exception(*current)
        self.formatted = ''.join(rendered_lines)

    def __str__(self):
        own_message = Exception.__str__(self)
        return '%s\nOriginal traceback:\n%s' % (own_message, self.formatted)
def call_wrapped(function, *args, **kwargs):
    """
    Call ``function(*args, **kwargs)`` and return its result.

    Anything raised by the call is replaced with a WrapException, which
    records the original exception and its formatted traceback.
    """
    try:
        outcome = function(*args, **kwargs)
    except BaseException:
        # Deliberately catches everything, like a bare `except:`.
        raise WrapException()
    return outcome
def call_wrapped_kwargs(function, kwargs):
    """
    Variant of call_wrapped that takes the keyword arguments as a single
    dict, for callers that can only pass positional arguments.
    """
    wrapped_result = call_wrapped(function, **kwargs)
    return wrapped_result