Skip to content
Snippets Groups Projects
Commit 03dcdb51 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

fixes

parent a9ec7e1c
No related branches found
No related tags found
No related merge requests found
......@@ -197,6 +197,9 @@ def run(argv=sys.argv[1:]):
worker_init_args = None
if args.gpus:
print("Attempting to round-robin assign each worker a GPU.")
if args.backend != "tensorflow-default":
print("Forcing keras backend to be tensorflow-default")
args.backend = "tensorflow-default"
gpu_assignments_remaining = dict((
(gpu, args.max_workers_per_gpu) for gpu in range(args.gpus)
......@@ -211,6 +214,7 @@ def run(argv=sys.argv[1:]):
gpu_assignments_remaining[gpu_num] -= 1
if not gpu_assignments_remaining[gpu_num]:
del gpu_assignments_remaining[gpu_num]
gpu_assignment = [gpu_num]
else:
# Use CPU
gpu_assignment = []
......@@ -219,9 +223,10 @@ def run(argv=sys.argv[1:]):
'gpu_device_nums': gpu_assignment,
'keras_backend': args.backend
})
print("Worker %d assigned GPUs: %s" % (worker_num, gpu_assignment))
worker_pool = Pool(
initializer=worker_init,
initializer=worker_init_entrypoint,
initargs=(worker_init_args,),
processes=num_workers)
print("Started pool of %d workers: %s" % (num_workers, str(worker_pool)))
......@@ -449,17 +454,16 @@ def calibrate_percentile_ranks(allele, predictor, peptides=None):
def worker_init_entrypoint(arg_queue):
if arg_queue:
(args, kwargs) = arg_queue.get()
kwargs = arg_queue.get()
else:
args = []
kwargs = {}
worker_init(*args, **kwargs)
worker_init(**kwargs)
def worker_init(keras_backend=None, gpu_device_nums=None):
if keras_backend or gpu_device_nums:
print("WORKER pid=%d assigned GPU devices: %s" % (
os.getpgid()), gpu_device_nums)
os.getpid(), gpu_device_nums))
set_keras_backend(
keras_backend, gpu_device_nums=gpu_device_nums)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment