Commit e7a59d70 authored by Tim O'Donnell

fix

parent 86c91f9e
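"""
Filesystem-based cluster parallelism: pickle a work function and its work
items to a shared working directory, launch one job per item with a
user-configurable submit command, then poll for per-item COMPLETE markers
and collect the results.
"""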
import traceback
import sys
import os
import time
import signal
import argparse
import pickle
import subprocess

from .local_parallelism import call_wrapped_kwargs
from .class1_affinity_predictor import Class1AffinityPredictor

try:
    from shlex import quote
except ImportError:
    from pipes import quote  # Python 2 fallback
def add_cluster_parallelism_args(parser):
    """Add cluster parallelism options to the given argparse parser."""
    group = parser.add_argument_group("Cluster parallelism")
    group.add_argument(
        "--cluster-parallelism",
        default=False,
        action="store_true")
    group.add_argument(
        "--cluster-submit-command",
        default='sh',
        help="Default: %(default)s")
    group.add_argument(
        "--cluster-results-workdir",
        default='./cluster-workdir',
        help="Default: %(default)s")
    group.add_argument(
        '--cluster-script-prefix-path',
        help="Path to file whose contents are prepended to each generated "
        "launch script",
    )
def cluster_results_from_args(
        args,
        work_function,
        work_items,
        constant_data=None,
        result_serialization_method="pickle"):
    """Invoke cluster_results() using parameters from parsed arguments."""
    return cluster_results(
        work_function=work_function,
        work_items=work_items,
        constant_data=constant_data,
        submit_command=args.cluster_submit_command,
        results_workdir=args.cluster_results_workdir,
        script_prefix_path=args.cluster_script_prefix_path,
        result_serialization_method=result_serialization_method)
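# A minimal usage sketch (hypothetical caller; "my_train_function" and
# "items" are illustrative names, not part of this module):
#
#     parser = argparse.ArgumentParser()
#     add_cluster_parallelism_args(parser)
#     args = parser.parse_args()
#     if args.cluster_parallelism:
#         results = cluster_results_from_args(
#             args,
#             work_function=my_train_function,
#             work_items=items)
#         for result in results:  # yields as each cluster job finishes
#             ...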
def cluster_results(
        work_function,
        work_items,
        constant_data=None,
        submit_command="sh",
        results_workdir="./cluster-workdir",
        script_prefix_path=None,
        result_serialization_method="pickle"):
    """
    Parallel map over a cluster: write each work item to its own
    subdirectory, submit one job per item using submit_command, and return a
    generator that yields results as the jobs finish.
    """
    constant_payload = {
        'constant_data': constant_data,
        'function': work_function,
    }
    # Unique working directory for this run, named by the current timestamp.
    work_dir = os.path.join(
        os.path.abspath(results_workdir),
        str(int(time.time())))
    os.makedirs(work_dir)  # also creates results_workdir if it doesn't exist
    constant_payload_path = os.path.join(work_dir, "global_data.pkl")
    with open(constant_payload_path, "wb") as fd:
        pickle.dump(constant_payload, fd, protocol=pickle.HIGHEST_PROTOCOL)
    print("Wrote:", constant_payload_path)

    if script_prefix_path:
        with open(script_prefix_path) as fd:
            script_prefix = fd.read()
    else:
        script_prefix = "#!/bin/bash"

    result_paths = []
    for (i, item) in enumerate(work_items):
        item_workdir = os.path.join(
            work_dir, "work-item.%03d-of-%03d" % (i, len(work_items)))
        os.mkdir(item_workdir)

        item_data_path = os.path.join(item_workdir, "data.pkl")
        with open(item_data_path, "wb") as fd:
            pickle.dump(item, fd, protocol=pickle.HIGHEST_PROTOCOL)
        print("Wrote:", item_data_path)

        item_result_path = os.path.join(item_workdir, "result")
        item_error_path = os.path.join(item_workdir, "error.pkl")
        item_finished_path = os.path.join(item_workdir, "COMPLETE")

        # The launch script is the user-supplied prefix (with work_item_num
        # and work_dir interpolated) followed by the worker invocation.
        item_script_pieces = [
            script_prefix.format(work_item_num=i, work_dir=item_workdir)
        ]
        item_script_pieces.append(" ".join([
            "_mhcflurry-cluster-worker-entry-point",
            "--constant-data", quote(constant_payload_path),
            "--worker-data", quote(item_data_path),
            "--result-out", quote(item_result_path),
            "--error-out", quote(item_error_path),
            "--complete-dir", quote(item_finished_path),
            "--result-serialization-method", result_serialization_method,
        ]))
        item_script = "\n".join(item_script_pieces)
        item_script_path = os.path.join(
            item_workdir,
            "run.%d.sh" % i)
        with open(item_script_path, "w") as fd:
            fd.write(item_script)
        print("Wrote:", item_script_path)

        launch_command = " ".join([
            submit_command, "<", quote(item_script_path)
        ])
        subprocess.check_call(launch_command, shell=True)
        print("Invoked", launch_command)

        result_paths.append(
            (item_finished_path, item_result_path, item_error_path))

    def result_generator():
        """Yield results in submission order, polling until each completes."""
        start = time.time()
        for (complete_dir, result_path, error_path) in result_paths:
            # The worker creates complete_dir when it finishes, whether it
            # succeeded or raised.
            while not os.path.exists(complete_dir):
                print("[%0.1f sec elapsed] waiting on" % (time.time() - start),
                      complete_dir)
                time.sleep(60)
            print("Complete", complete_dir)
            if os.path.exists(error_path):
                print("Error path exists", error_path)
                with open(error_path, "rb") as fd:
                    exception = pickle.load(fd)
                raise exception
            if os.path.exists(result_path):
                print("Result path exists", result_path)
                if result_serialization_method == "save_predictor":
                    result = Class1AffinityPredictor.load(result_path)
                else:
                    assert result_serialization_method == "pickle"
                    with open(result_path, "rb") as fd:
                        result = pickle.load(fd)
                yield result
            else:
                raise RuntimeError("Results do not exist", result_path)

    return result_generator()
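# Example contents for --cluster-script-prefix-path (a hypothetical Slurm
# prefix; any scheduler directives work). {work_item_num} and {work_dir} are
# interpolated via str.format above, so literal braces in the prefix must be
# escaped as {{ and }}:
#
#     #!/bin/bash
#     #SBATCH --job-name=work-item-{work_item_num}
#     #SBATCH --output={work_dir}/stdout.txt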
# Argument parser for the worker entry point (parsed in worker_entry_point).
parser = argparse.ArgumentParser(
    usage="Entry point for cluster workers")
parser.add_argument(
    "--constant-data",
    required=True,
)
parser.add_argument(
    "--worker-data",
    required=True,
)
parser.add_argument(
    "--result-out",
    required=True,
)
parser.add_argument(
    "--error-out",
    required=True,
)
parser.add_argument(
    "--complete-dir",
)
parser.add_argument(
    "--result-serialization-method",
    choices=("pickle", "save_predictor"),
    default="pickle")
def worker_entry_point(argv=sys.argv[1:]):
    """Run a single work item in a cluster worker process."""
    # On SIGUSR1, print a stack trace (useful for debugging hung workers).
    print("To show stack trace, run:\nkill -s USR1 %d" % os.getpid())
    signal.signal(signal.SIGUSR1, lambda sig, frame: traceback.print_stack())

    args = parser.parse_args(argv)

    with open(args.constant_data, "rb") as fd:
        constant_payload = pickle.load(fd)
    with open(args.worker_data, "rb") as fd:
        worker_data = pickle.load(fd)

    kwargs = dict(worker_data)
    if constant_payload['constant_data'] is not None:
        kwargs['constant_data'] = constant_payload['constant_data']

    try:
        result = call_wrapped_kwargs(constant_payload['function'], kwargs)
        if args.result_serialization_method == 'save_predictor':
            result.save(args.result_out)
        else:
            with open(args.result_out, "wb") as fd:
                pickle.dump(result, fd, pickle.HIGHEST_PROTOCOL)
        print("Wrote:", args.result_out)
    except Exception as e:
        print("Exception: ", e)
        with open(args.error_out, "wb") as fd:
            pickle.dump(e, fd, pickle.HIGHEST_PROTOCOL)
        print("Wrote:", args.error_out)
        raise
    finally:
        # Creating this directory signals completion to the result poller.
        if args.complete_dir:
            os.mkdir(args.complete_dir)
            print("Created: ", args.complete_dir)