diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index 406da3e0ea647e741cdbe5dc95ee243496c5b7aa..583fe01308bd5bf03b4b982590b14a4d50b943fd 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -155,14 +155,14 @@ def run(argv=sys.argv[1:]):
     print("Training data: %s" % (str(df.shape)))
 
     predictor = Class1AffinityPredictor()
-    if args.parallelization_num_jobs == 1:
+    if args.train_num_jobs == 1:
         # Serial run
         worker_pool = None
     else:
         worker_pool = Pool(
             processes=(
-                args.parallelization_num_jobs
-                if args.parallelization_num_jobs else None))
+                args.train_num_jobs
+                if args.train_num_jobs else None))
     print("Using worker pool: %s" % str(worker_pool))
 
     if args.out_models_dir and not os.path.exists(args.out_models_dir):
@@ -235,7 +235,21 @@ def run(argv=sys.argv[1:]):
             len(predictor.neural_networks), training_time / 60.0))
         print("*" * 30)
 
+    if worker_pool:
+        worker_pool.close()
+        worker_pool.join()
+    if args.percent_rank_calibration_num_peptides_per_length > 0:
+        if args.calibration_num_jobs == 1:
+            # Serial run
+            worker_pool = None
+        else:
+            worker_pool = Pool(
+                processes=(
+                    args.calibration_num_jobs
+                    if args.calibration_num_jobs else None))
+        print("Using worker pool: %s" % str(worker_pool))
+
     print("Performing percent rank calibration.")
     start = time.time()
     predictor.calibrate_percentile_ranks(