From b9fd472a9e77aa6797743938ae8bc67ff04a1759 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Sun, 11 Feb 2018 13:40:43 -0500
Subject: [PATCH] Add parallelism.py. Expand hyperparameter search.

---
 .../models_class1/generate_hyperparameters.py |   4 +-
 mhcflurry/common.py                           |   2 +-
 mhcflurry/parallelism.py                      | 100 ++++++++++++++++++
 .../train_allele_specific_models_command.py   |  81 +-------------
 4 files changed, 104 insertions(+), 83 deletions(-)
 create mode 100644 mhcflurry/parallelism.py

diff --git a/downloads-generation/models_class1/generate_hyperparameters.py b/downloads-generation/models_class1/generate_hyperparameters.py
index c9f54646..e958e101 100644
--- a/downloads-generation/models_class1/generate_hyperparameters.py
+++ b/downloads-generation/models_class1/generate_hyperparameters.py
@@ -64,8 +64,8 @@ base_hyperparameters = {
 
 grid = []
 for train_subset in ["all", "quantitative"]:
-    for minibatch_size in [128]:
-        for dense_layer_size in [8, 16, 32, 64]:
+    for minibatch_size in [128, 512]:
+        for dense_layer_size in [8, 16, 32, 64, 128]:
             for l1 in [0.0, 0.001]:
                 for num_lc in [0, 1, 2]:
                     for lc_kernel_size in [3, 5]:
diff --git a/mhcflurry/common.py b/mhcflurry/common.py
index 3bd2390c..4b6f6d56 100644
--- a/mhcflurry/common.py
+++ b/mhcflurry/common.py
@@ -32,7 +32,7 @@ def set_keras_backend(backend=None, gpu_device_nums=None):
         os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
             [str(i) for i in gpu_device_nums])
 
-    if backend == "tensorflow-cpu":
+    if backend == "tensorflow-cpu" or not gpu_device_nums:
         print("Forcing tensorflow/CPU backend.")
         os.environ["CUDA_VISIBLE_DEVICES"] = ""
         device_count = {'CPU': 1, 'GPU': 0}
diff --git a/mhcflurry/parallelism.py b/mhcflurry/parallelism.py
new file mode 100644
index 00000000..dcce96e5
--- /dev/null
+++ b/mhcflurry/parallelism.py
@@ -0,0 +1,100 @@
+from multiprocessing import Pool, Queue, cpu_count
+from queue import Empty
+from multiprocessing.util import Finalize
+from pprint import pprint
+
+
+def make_worker_pool(
+        processes=None,
+        initializer=None,
+        initializer_kwargs_per_process=None,
+        max_tasks_per_worker=None):
+    """
+    Convenience wrapper to create a multiprocessing.Pool.
+
+    This function adds support for per-worker initializer arguments, which are
+    not natively supported by the multiprocessing module. The motivation for
+    this feature is to support allocating each worker to a (different) GPU.
+
+    IMPLEMENTATION NOTE:
+    The per-worker initializer arguments are implemented using a Queue. Each
+    worker reads its arguments from this queue when it starts. When it
+    terminates, it adds its initializer arguments back to the queue, so a
+    future process can initialize itself using these arguments.
+
+    There is one issue with this approach, however. If a worker crashes, it
+    never repopulates the queue of initializer arguments. This will prevent
+    any future worker from re-using those arguments. To deal with this
+    issue we add a second 'backup queue'. This queue always contains the
+    full set of initializer arguments: whenever a worker reads from it, it
+    always pushes the popped args back to the end of the queue immediately.
+    If the primary arg queue is ever empty, then workers will read
+    from this backup queue.
+
+    Parameters
+    ----------
+    processes : int
+        Number of workers. Default: num CPUs.
+
+    initializer : function, optional
+        Init function to call in each worker
+
+    initializer_kwargs_per_process : list of dict, optional
+        Arguments to pass to initializer function for each worker. Length of
+        list must equal the number of workers.
+
+    max_tasks_per_worker : int, optional
+        Restart workers after this many tasks. Requires Python >=3.2.
+
+    Returns
+    -------
+    multiprocessing.Pool
+    """
+
+    if not processes:
+        processes = cpu_count()
+
+    pool_kwargs = {
+        'processes': processes,
+    }
+    if max_tasks_per_worker:
+        pool_kwargs["maxtasksperchild"] = max_tasks_per_worker
+
+    if initializer:
+        if initializer_kwargs_per_process:
+            assert len(initializer_kwargs_per_process) == processes
+            kwargs_queue = Queue()
+            kwargs_queue_backup = Queue()
+            for kwargs in initializer_kwargs_per_process:
+                kwargs_queue.put(kwargs)
+                kwargs_queue_backup.put(kwargs)
+            pool_kwargs["initializer"] = worker_init_entry_point
+            pool_kwargs["initargs"] = (
+                initializer, kwargs_queue, kwargs_queue_backup)
+        else:
+            pool_kwargs["initializer"] = initializer
+
+    worker_pool = Pool(**pool_kwargs)
+    print("Started pool: %s" % str(worker_pool))
+    pprint(pool_kwargs)
+    return worker_pool
+
+
+def worker_init_entry_point(
+        init_function, arg_queue=None, backup_arg_queue=None):
+    kwargs = {}
+    if arg_queue:
+        try:
+            kwargs = arg_queue.get(block=False)
+        except Empty:
+            print("Argument queue empty. Using round robin arg queue.")
+            kwargs = backup_arg_queue.get(block=True)
+            backup_arg_queue.put(kwargs)
+
+    # On exit we add the init args back to the queue so restarted workers
+    # (e.g. when running with maxtasksperchild) will pick up init
+    # arguments from a previously exited worker.
+    Finalize(None, arg_queue.put, (kwargs,), exitpriority=1)
+
+    print("Initializing worker: %s" % str(kwargs))
+    init_function(**kwargs)
diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index 98dfc408..eac3f9c3 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -8,11 +8,7 @@ import sys
 import time
 import traceback
 import random
-from multiprocessing import Pool, Queue, cpu_count
-from queue import Empty
-from multiprocessing.util import Finalize
 from functools import partial
-from pprint import pprint
 
 import pandas
 import yaml
@@ -22,6 +18,7 @@ tqdm.monitor_interval = 0  # see https://github.com/tqdm/tqdm/issues/481
 
 from .class1_affinity_predictor import Class1AffinityPredictor
 from .common import configure_logging, set_keras_backend
+from .parallelism import make_worker_pool, cpu_count
 
 
 # To avoid pickling large matrices to send to child processes when running in
@@ -142,82 +139,6 @@ parser.add_argument(
     "leaks. Requires Python >=3.2.")
 
 
-def make_worker_pool(
-        processes=None,
-        initializer=None,
-        initializer_kwargs_per_process=None,
-        max_tasks_per_worker=None):
-    """
-    Convenience wrapper to create a multiprocessing.Pool.
-
-    Parameters
-    ----------
-    processes : int
-        Number of workers. Default: num CPUs.
-
-    initializer : function, optional
-        Init function to call in each worker
-
-    initializer_kwargs_per_process : list of dict, optional
-        Arguments to pass to initializer function for each worker. Length of
-        list must equal the number of workers.
-
-    max_tasks_per_worker : int, optional
-        Restart workers after this many tasks. Requires Python >=3.2.
-
-    Returns
-    -------
-    multiprocessing.Pool
-    """
-
-    if not processes:
-        processes = cpu_count()
-
-    pool_kwargs = {
-        'processes': processes,
-    }
-    if max_tasks_per_worker:
-        pool_kwargs["maxtasksperchild"] = max_tasks_per_worker
-
-    if initializer:
-        if initializer_kwargs_per_process:
-            assert len(initializer_kwargs_per_process) == processes
-            kwargs_queue = Queue()
-            kwargs_queue2 = Queue()
-            for kwargs in initializer_kwargs_per_process:
-                kwargs_queue.put(kwargs)
-                kwargs_queue2.put(kwargs)
-            pool_kwargs["initializer"] = worker_init_entry_point
-            pool_kwargs["initargs"] = (initializer, kwargs_queue, kwargs_queue2)
-        else:
-            pool_kwargs["initializer"] = initializer
-
-    worker_pool = Pool(**pool_kwargs)
-    print("Started pool: %s" % str(worker_pool))
-    pprint(pool_kwargs)
-    return worker_pool
-
-
-def worker_init_entry_point(
-        init_function, arg_queue=None, round_robin_arg_queue=None):
-    kwargs = {}
-    if arg_queue:
-        try:
-            kwargs = arg_queue.get(block=False)
-        except Empty:
-            print("Argument queue empty. Using round robin arg queue.")
-            kwargs = round_robin_arg_queue.get(block=True)
-            round_robin_arg_queue.put(kwargs)
-
-    # On exit we add the init args back to the queue so restarted workers
-    # (e.g. when when running with maxtasksperchild) will pickup init
-    # arguments from a previously exited worker.
-    Finalize(None, arg_queue.put, (kwargs,), exitpriority=1)
-
-    print("Initializing worker: %s" % str(kwargs))
-    init_function(**kwargs)
-
-
 def run(argv=sys.argv[1:]):
     global GLOBAL_DATA
 
--
GitLab
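
Hypothetical usage sketch (not part of the patch): one way a caller might drive the new
make_worker_pool helper to give each worker its own GPU via set_keras_backend, which is the
motivation stated in the parallelism.py docstring. The worker function square(), the device
numbers, and passing only gpu_device_nums to set_keras_backend are illustrative assumptions,
not confirmed by this diff.

    from mhcflurry.common import set_keras_backend
    from mhcflurry.parallelism import make_worker_pool

    def square(x):
        # Placeholder task; a real caller would run training or prediction here.
        return x * x

    if __name__ == "__main__":
        gpu_device_nums = [0, 1]  # assumed two-GPU machine
        pool = make_worker_pool(
            processes=len(gpu_device_nums),
            initializer=set_keras_backend,
            # One kwargs dict per worker: each worker pins itself to one GPU.
            initializer_kwargs_per_process=[
                {"gpu_device_nums": [i]} for i in gpu_device_nums
            ],
            max_tasks_per_worker=None)
        print(pool.map(square, range(8)))
        pool.close()
        pool.join()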