Skip to content
Snippets Groups Projects
Commit 03dcdb51 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

fixes

parent a9ec7e1c
No related merge requests found
......@@ -197,6 +197,9 @@ def run(argv=sys.argv[1:]):
worker_init_args = None
if args.gpus:
print("Attempting to round-robin assign each worker a GPU.")
if args.backend != "tensorflow-default":
print("Forcing keras backend to be tensorflow-default")
args.backend = "tensorflow-default"
gpu_assignments_remaining = dict((
(gpu, args.max_workers_per_gpu) for gpu in range(args.gpus)
......@@ -211,6 +214,7 @@ def run(argv=sys.argv[1:]):
gpu_assignments_remaining[gpu_num] -= 1
if not gpu_assignments_remaining[gpu_num]:
del gpu_assignments_remaining[gpu_num]
gpu_assignment = [gpu_num]
else:
# Use CPU
gpu_assignment = []
......@@ -219,9 +223,10 @@ def run(argv=sys.argv[1:]):
'gpu_device_nums': gpu_assignment,
'keras_backend': args.backend
})
print("Worker %d assigned GPUs: %s" % (worker_num, gpu_assignment))
worker_pool = Pool(
initializer=worker_init,
initializer=worker_init_entrypoint,
initargs=(worker_init_args,),
processes=num_workers)
print("Started pool of %d workers: %s" % (num_workers, str(worker_pool)))
......@@ -449,17 +454,16 @@ def calibrate_percentile_ranks(allele, predictor, peptides=None):
def worker_init_entrypoint(arg_queue):
if arg_queue:
(args, kwargs) = arg_queue.get()
kwargs = arg_queue.get()
else:
args = []
kwargs = {}
worker_init(*args, **kwargs)
worker_init(**kwargs)
def worker_init(keras_backend=None, gpu_device_nums=None):
if keras_backend or gpu_device_nums:
print("WORKER pid=%d assigned GPU devices: %s" % (
os.getpgid()), gpu_device_nums)
os.getpid(), gpu_device_nums))
set_keras_backend(
keras_backend, gpu_device_nums=gpu_device_nums)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment