From 0b3a9b9fd7bf8f3afa78c562771bba1508a43b0d Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Sat, 27 Jan 2018 22:26:33 -0500 Subject: [PATCH] fix --- .../train_allele_specific_models_command.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py index 406da3e0..583fe013 100644 --- a/mhcflurry/train_allele_specific_models_command.py +++ b/mhcflurry/train_allele_specific_models_command.py @@ -155,14 +155,14 @@ def run(argv=sys.argv[1:]): print("Training data: %s" % (str(df.shape))) predictor = Class1AffinityPredictor() - if args.parallelization_num_jobs == 1: + if args.train_num_jobs == 1: # Serial run worker_pool = None else: worker_pool = Pool( processes=( - args.parallelization_num_jobs - if args.parallelization_num_jobs else None)) + args.train_num_jobs + if args.train_num_jobs else None)) print("Using worker pool: %s" % str(worker_pool)) if args.out_models_dir and not os.path.exists(args.out_models_dir): @@ -235,7 +235,21 @@ def run(argv=sys.argv[1:]): len(predictor.neural_networks), training_time / 60.0)) print("*" * 30) + if worker_pool: + worker_pool.close() + worker_pool.join() + if args.percent_rank_calibration_num_peptides_per_length > 0: + if args.calibration_num_jobs == 1: + # Serial run + worker_pool = None + else: + worker_pool = Pool( + processes=( + args.calibration_num_jobs + if args.train_num_jobs else None)) + print("Using worker pool: %s" % str(worker_pool)) + print("Performing percent rank calibration.") start = time.time() predictor.calibrate_percentile_ranks( -- GitLab