From d2ff112bcb28ab81032a09bdf7e37de84a9ff6ec Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Sun, 28 Jan 2018 13:20:15 -0500
Subject: [PATCH] better parallelization of model training

---
 .../train_allele_specific_models_command.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index e60d6a87..4f48bb62 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -18,6 +18,12 @@ import tqdm  # progress bar
 
 from .class1_affinity_predictor import Class1AffinityPredictor
 from .common import configure_logging, set_keras_backend
+
+# To avoid pickling large matrices to send to child processes when running in
+# parallel, we use this global variable as a place to store data. Data stored
+# here before the worker pool is created is inherited by the child processes
+# when fork() is called, letting us share large data with the workers
+# efficiently.
 GLOBAL_DATA = {}
 
 
@@ -160,6 +166,8 @@ def run(argv=sys.argv[1:]):
     print("Selected %d alleles: %s" % (len(alleles), ' '.join(alleles)))
     print("Training data: %s" % (str(df.shape)))
 
+    GLOBAL_DATA["train_data"] = df
+
     predictor = Class1AffinityPredictor()
     if args.train_num_jobs == 1:
         # Serial run
@@ -190,9 +198,7 @@ def run(argv=sys.argv[1:]):
             hyperparameters['max_epochs'] = args.max_epochs
 
         work_items = []
-        total_data_to_train_on = 0
-        for (i, (allele, sub_df)) in enumerate(df.groupby("allele")):
-            total_data_to_train_on += len(sub_df) * n_models
+        for (i, allele) in enumerate(df.allele.unique()):
             for model_group in range(n_models):
                 work_dict = {
                     'model_group': model_group,
@@ -202,7 +208,7 @@ def run(argv=sys.argv[1:]):
                     'hyperparameter_set_num': h,
                     'num_hyperparameter_sets': len(hyperparameters_lst),
                     'allele': allele,
-                    'data': sub_df,
+                    'data': None,  # subselect from GLOBAL_DATA["train_data"]
                     'hyperparameters': hyperparameters,
                     'verbose': args.verbosity,
                     'predictor': predictor if not worker_pool else None,
@@ -318,6 +324,10 @@ def train_model(
     if predictor is None:
         predictor = Class1AffinityPredictor()
 
+    if data is None:
+        full_data = GLOBAL_DATA["train_data"]
+        data = full_data.loc[full_data.allele == allele]
+
     progress_preamble = (
         "[%2d / %2d hyperparameters] "
         "[%4d / %4d alleles] "
--
GitLab
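
For readers unfamiliar with the pattern this patch introduces, here is a
minimal standalone sketch of sharing a large object with multiprocessing
workers through a module-level global that is populated before the pool is
created. This is not mhcflurry code: train_one() and the toy DataFrame are
hypothetical, and the sketch assumes the POSIX fork start method (the default
on Linux).

# Minimal sketch of the fork()-inheritance pattern used by this patch.
# train_one() and the toy DataFrame below are illustrative, not mhcflurry API.
import multiprocessing

import pandas as pd

GLOBAL_DATA = {}


def train_one(allele):
    # Workers inherit GLOBAL_DATA from the parent via fork(); only the
    # small 'allele' string is pickled and sent over the pipe.
    df = GLOBAL_DATA["train_data"]
    sub_df = df.loc[df.allele == allele]
    return (allele, len(sub_df))


if __name__ == "__main__":
    # Populate the global *before* creating the pool, so the data already
    # exists in the parent's address space when each worker is forked.
    GLOBAL_DATA["train_data"] = pd.DataFrame({
        "allele": ["HLA-A*02:01", "HLA-A*02:01", "HLA-B*07:02"],
        "measurement_value": [25.0, 100.0, 5000.0],
    })
    ctx = multiprocessing.get_context("fork")  # POSIX-only; see note below
    with ctx.Pool(processes=2) as pool:
        print(pool.map(train_one, ["HLA-A*02:01", "HLA-B*07:02"]))

Note that this only works with the fork start method; under spawn (the
default on Windows, and on macOS since Python 3.8) each worker re-imports the
module and GLOBAL_DATA would be empty. It also explains why the patch passes
'data': None in each work item and re-subselects per allele inside
train_model(): shipping sub_df through the pool would pickle a slice of the
training data for every (allele, model_group) work item.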