From f4d66885cf238d8bf64a575ac5d654415902ba72 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Tue, 13 Feb 2018 16:43:26 -0500
Subject: [PATCH] more parallelism

---
 .../train_allele_specific_models_command.py   | 31 ++++++++++---------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index 694e4a47..7e4fdeb3 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -258,21 +258,22 @@ def run(argv=sys.argv[1:]):
             hyperparameters['max_epochs'] = args.max_epochs
 
         for (i, allele) in enumerate(df.allele.unique()):
-            work_dict = {
-                'n_models': n_models,
-                'allele_num': i,
-                'n_alleles': len(alleles),
-                'hyperparameter_set_num': h,
-                'num_hyperparameter_sets': len(hyperparameters_lst),
-                'allele': allele,
-                'data': None,  # subselect from GLOBAL_DATA["train_data"]
-                'hyperparameters': hyperparameters,
-                'verbose': args.verbosity,
-                'progress_print_interval': None if worker_pool else 5.0,
-                'predictor': predictor if not worker_pool else None,
-                'save_to': args.out_models_dir if not worker_pool else None,
-            }
-            work_items.append(work_dict)
+            for model_num in range(n_models):
+                work_dict = {
+                    'n_models': 1,
+                    'allele_num': i,
+                    'n_alleles': len(alleles),
+                    'hyperparameter_set_num': h,
+                    'num_hyperparameter_sets': len(hyperparameters_lst),
+                    'allele': allele,
+                    'data': None,  # subselect from GLOBAL_DATA["train_data"]
+                    'hyperparameters': hyperparameters,
+                    'verbose': args.verbosity,
+                    'progress_print_interval': None if worker_pool else 5.0,
+                    'predictor': predictor if not worker_pool else None,
+                    'save_to': args.out_models_dir if not worker_pool else None,
+                }
+                work_items.append(work_dict)
 
     if worker_pool:
         print("Processing %d work items in parallel." % len(work_items))
-- 
GitLab
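
Note on the change: before this patch, one work item covered an entire allele and trained all n_models ensemble members serially inside a single worker; after it, each (allele, model) pair becomes its own work item with 'n_models': 1, so the worker pool can run up to len(alleles) * n_models tasks concurrently. The sketch below (not part of the patch) illustrates that work-splitting pattern in isolation; train_one_model, the allele list, and the pool size are hypothetical placeholders, not mhcflurry's actual API.

    # Minimal sketch of the finer-grained work splitting the diff adopts.
    # All names here are illustrative assumptions, not mhcflurry code.
    from multiprocessing import Pool

    def train_one_model(work_item):
        # In the real command this would train a single neural network for
        # one allele; here we just echo the item to show the dispatch shape.
        return (work_item['allele'], work_item['model_num'])

    if __name__ == '__main__':
        alleles = ['HLA-A*02:01', 'HLA-B*07:02']  # hypothetical allele list
        n_models = 4  # ensemble size per allele

        # One item per (allele, model) pair instead of one item per allele:
        # the pool can now parallelize across individual models.
        work_items = [
            {'allele': allele, 'model_num': m, 'n_models': 1}
            for allele in alleles
            for m in range(n_models)
        ]

        with Pool(processes=4) as pool:
            results = pool.map(train_one_model, work_items)
        print("Processed %d work items." % len(results))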