Skip to content
Snippets Groups Projects
Commit b82888ef authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

fix

parent 7df3a99e
No related branches found
No related tags found
No related merge requests found
...@@ -21,7 +21,7 @@ def add_local_parallelism_args(parser): ...@@ -21,7 +21,7 @@ def add_local_parallelism_args(parser):
default=0, default=0,
type=int, type=int,
metavar="N", metavar="N",
help="Number of processes to parallelize training over. Experimental. " help="Number of local processes to parallelize training over. "
"Set to 0 for serial run. Default: %(default)s.") "Set to 0 for serial run. Default: %(default)s.")
group.add_argument( group.add_argument(
"--backend", "--backend",
......
...@@ -398,7 +398,7 @@ def train_models(args): ...@@ -398,7 +398,7 @@ def train_models(args):
print("Found %d work items, of which %d are incomplete and will run now." % ( print("Found %d work items, of which %d are incomplete and will run now." % (
len(all_work_items), len(work_items))) len(all_work_items), len(work_items)))
serial_run = args.num_jobs == 0 serial_run = not args.cluster_parallelism and args.num_jobs == 0
# The estimated time to completion is more accurate if we randomize # The estimated time to completion is more accurate if we randomize
# the order of the work. # the order of the work.
...@@ -415,7 +415,19 @@ def train_models(args): ...@@ -415,7 +415,19 @@ def train_models(args):
start = time.time() start = time.time()
if args.cluster_parallelism: worker_pool = None
if serial_run:
# Run in serial. Every worker is passed the same predictor,
# which it adds models to, so no merging is required. It also saves
# as it goes so no saving is required at the end.
print("Processing %d work items in serial." % len(work_items))
for _ in tqdm.trange(len(work_items)):
item = work_items.pop(0) # want to keep freeing up memory
work_predictor = train_model(**item)
assert work_predictor is predictor
assert not work_items
results_generator = None
elif args.cluster_parallelism:
# Run using separate processes HPC cluster. # Run using separate processes HPC cluster.
results_generator = cluster_results_from_args( results_generator = cluster_results_from_args(
args, args,
...@@ -423,35 +435,21 @@ def train_models(args): ...@@ -423,35 +435,21 @@ def train_models(args):
work_items=work_items, work_items=work_items,
constant_data=GLOBAL_DATA, constant_data=GLOBAL_DATA,
result_serialization_method="save_predictor") result_serialization_method="save_predictor")
worker_pool = None
else: else:
worker_pool = worker_pool_with_gpu_assignments_from_args(args) worker_pool = worker_pool_with_gpu_assignments_from_args(args)
print("Worker pool", worker_pool) print("Worker pool", worker_pool)
assert worker_pool is not None
if worker_pool: print("Processing %d work items in parallel." % len(work_items))
print("Processing %d work items in parallel." % len(work_items)) assert not serial_run
assert not serial_run
results_generator = worker_pool.imap_unordered( results_generator = worker_pool.imap_unordered(
partial(call_wrapped_kwargs, train_model), partial(call_wrapped_kwargs, train_model),
work_items, work_items,
chunksize=1) chunksize=1)
else:
# Run in serial. In this case, every worker is passed the same predictor,
# which it adds models to, so no merging is required. It also saves
# as it goes so no saving is required at the end.
print("Processing %d work items in serial." % len(work_items))
assert serial_run
for _ in tqdm.trange(len(work_items)):
item = work_items.pop(0) # want to keep freeing up memory
work_predictor = train_model(**item)
assert work_predictor is predictor
assert not work_items
results_generator = None
if results_generator: if results_generator:
#for new_predictor in tqdm.tqdm(results_generator, total=len(work_items)): for new_predictor in tqdm.tqdm(results_generator, total=len(work_items)):
for new_predictor in results_generator:
save_start = time.time() save_start = time.time()
(new_model_name,) = predictor.merge_in_place([new_predictor]) (new_model_name,) = predictor.merge_in_place([new_predictor])
predictor.save( predictor.save(
...@@ -465,7 +463,6 @@ def train_models(args): ...@@ -465,7 +463,6 @@ def train_models(args):
time.time() - save_start, time.time() - save_start,
args.out_models_dir)) args.out_models_dir))
print("Saving final predictor to: %s" % args.out_models_dir)
# We want the final predictor to support all alleles with sequences, not # We want the final predictor to support all alleles with sequences, not
# just those we actually used for model training. # just those we actually used for model training.
predictor.allele_to_sequence = ( predictor.allele_to_sequence = (
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment