diff --git a/mhcflurry/calibrate_percentile_ranks_command.py b/mhcflurry/calibrate_percentile_ranks_command.py index bf8621f25fcc69f86c9eac73c6d3bdd5fcab2baa..c0a019658afaa734381da33a58738641119975d3 100644 --- a/mhcflurry/calibrate_percentile_ranks_command.py +++ b/mhcflurry/calibrate_percentile_ranks_command.py @@ -22,7 +22,7 @@ from .common import configure_logging, random_peptides, amino_acid_distribution from .local_parallelism import ( add_local_parallelism_args, worker_pool_with_gpu_assignments_from_args, - call_wrapped) + call_wrapped_kwargs) from .cluster_parallelism import ( add_cluster_parallelism_args, cluster_results_from_args) @@ -152,18 +152,19 @@ def run(argv=sys.argv[1:]): serial_run = not args.cluster_parallelism and args.num_jobs == 0 worker_pool = None start = time.time() + work_items = [{"allele": allele} for allele in alleles] if serial_run: # Serial run print("Running in serial.") results = ( - do_calibrate_percentile_ranks(allele) for allele in alleles) + do_calibrate_percentile_ranks(**item) for item in work_items) elif args.cluster_parallelism: # Run using separate processes HPC cluster. print("Running on cluster.") results = cluster_results_from_args( args, work_function=do_calibrate_percentile_ranks, - work_items=alleles, + work_items=work_items, constant_data=GLOBAL_DATA, result_serialization_method="pickle") else: @@ -171,17 +172,17 @@ def run(argv=sys.argv[1:]): print("Worker pool", worker_pool) assert worker_pool is not None results = worker_pool.imap_unordered( - partial(call_wrapped, do_calibrate_percentile_ranks), - alleles, + partial(call_wrapped_kwargs, do_calibrate_percentile_ranks), + work_items, chunksize=1) summary_results_lists = collections.defaultdict(list) - for (transforms, summary_results) in tqdm.tqdm(results, total=len(alleles)): + for (transforms, summary_results) in tqdm.tqdm(results, total=len(work_items)): predictor.allele_to_percent_rank_transform.update(transforms) if summary_results is not None: for (item, value) in summary_results.items(): summary_results_lists[item].append(value) - print("Done calibrating %d alleles." % len(alleles)) + print("Done calibrating %d alleles." % len(work_items)) if summary_results_lists: for (name, lst) in summary_results_lists.items(): df = pandas.concat(lst, ignore_index=True)