"""
Simple, naive parallel map implementation for HPC clusters.

Work items are pickled to a shared filesystem, a run script is generated
for each item, and the scripts are submitted with a user-specified command.
Results are collected by polling the filesystem.
"""
import traceback
import sys
import os
import time
import signal
import argparse
import pickle
import subprocess
import shutil

from .local_parallelism import call_wrapped_kwargs
from .class1_affinity_predictor import Class1AffinityPredictor

try:
    from shlex import quote
except ImportError:
    from pipes import quote  # Python 2 fallback


def add_cluster_parallelism_args(parser):
    """
    Add commandline arguments controlling cluster parallelism to an argparse
    ArgumentParser.

    Parameters
    ----------
    parser : argparse.ArgumentParser
    """
    group = parser.add_argument_group("Cluster parallelism")
    group.add_argument(
        "--cluster-parallelism",
        default=False,
        action="store_true")
    group.add_argument(
        "--cluster-submit-command",
        default='sh',
        help="Command used to submit the generated run scripts, e.g. sbatch "
        "or qsub. Default: %(default)s")
    group.add_argument(
        "--cluster-results-workdir",
        default='./cluster-workdir',
        help="Directory (on a filesystem shared across the cluster) in which "
        "to write intermediate results. Default: %(default)s")
    group.add_argument(
        '--cluster-script-prefix-path',
        help="Path to a script fragment prepended to each generated run "
        "script, e.g. scheduler directives. May use the {work_item_num} and "
        "{work_dir} format placeholders.")
    group.add_argument(
        '--cluster-max-retries',
        type=int,
        default=3,
        help="How many times to re-launch a failed work item. "
        "Default: %(default)s")


def cluster_results_from_args(
        args,
        work_function,
        work_items,
        constant_data=None,
        result_serialization_method="pickle"):
    """
    Parallel map over a cluster, configured from commandline arguments as
    added by add_cluster_parallelism_args(). See cluster_results() for
    details.
    """
    return cluster_results(
        work_function=work_function,
        work_items=work_items,
        constant_data=constant_data,
        submit_command=args.cluster_submit_command,
        results_workdir=args.cluster_results_workdir,
        script_prefix_path=args.cluster_script_prefix_path,
        result_serialization_method=result_serialization_method,
        max_retries=args.cluster_max_retries)
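# An illustrative sketch (assumptions, not part of this module) of a script
# prefix file that could be passed via --cluster-script-prefix-path on a
# SLURM cluster. The prefix is rendered with str.format, so it may use the
# {work_item_num} and {work_dir} placeholders; any literal braces in it
# would need to be doubled. The SBATCH directives below are hypothetical:
#
#     #!/bin/bash
#     #SBATCH --job-name=worker-{work_item_num}
#     #SBATCH --output={work_dir}/stdout.txt
#     #SBATCH --error={work_dir}/stderr.txt
#
# Submitted with "--cluster-submit-command sbatch", each generated run
# script then consists of this prefix followed by a
# _mhcflurry-cluster-worker-entry-point invocation.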
def cluster_results(
        work_function,
        work_items,
        constant_data=None,
        submit_command="sh",
        results_workdir="./cluster-workdir",
        script_prefix_path=None,
        result_serialization_method="pickle",
        max_retries=3):
    """
    Parallel map over a cluster.

    Pickles the work function and each work item to a new timestamped
    directory under results_workdir, writes one run script per item, submits
    each script with submit_command, and returns a generator that yields
    results as they complete. Failed items are re-launched up to max_retries
    times.
    """
    constant_payload = {
        'constant_data': constant_data,
        'function': work_function,
    }
    work_dir = os.path.join(
        os.path.abspath(results_workdir),
        str(int(time.time())))
    os.makedirs(work_dir)  # also creates results_workdir if needed

    constant_payload_path = os.path.join(work_dir, "global_data.pkl")
    with open(constant_payload_path, "wb") as fd:
        pickle.dump(constant_payload, fd, protocol=pickle.HIGHEST_PROTOCOL)
    print("Wrote:", constant_payload_path)

    if script_prefix_path:
        with open(script_prefix_path) as fd:
            script_prefix = fd.read()
    else:
        script_prefix = "#!/bin/bash"

    result_items = []
    for (i, item) in enumerate(work_items):
        item_workdir = os.path.join(
            work_dir, "work-item.%03d-of-%03d" % (i, len(work_items)))
        os.mkdir(item_workdir)

        item_data_path = os.path.join(item_workdir, "data.pkl")
        with open(item_data_path, "wb") as fd:
            pickle.dump(item, fd, protocol=pickle.HIGHEST_PROTOCOL)
        print("Wrote:", item_data_path)

        item_result_path = os.path.join(item_workdir, "result")
        item_error_path = os.path.join(item_workdir, "error.pkl")
        item_finished_path = os.path.join(item_workdir, "COMPLETE")

        item_script_pieces = [
            script_prefix.format(work_item_num=i, work_dir=item_workdir)
        ]
        item_script_pieces.append(" ".join([
            "_mhcflurry-cluster-worker-entry-point",
            "--constant-data", quote(constant_payload_path),
            "--worker-data", quote(item_data_path),
            "--result-out", quote(item_result_path),
            "--error-out", quote(item_error_path),
            "--complete-dir", quote(item_finished_path),
            "--result-serialization-method", result_serialization_method,
        ]))
        item_script = "\n".join(item_script_pieces)
        item_script_path = os.path.join(item_workdir, "run.%d.sh" % i)
        with open(item_script_path, "w") as fd:
            fd.write(item_script)
        print("Wrote:", item_script_path)

        launch_command = " ".join([
            submit_command, "<", quote(item_script_path)
        ])
        subprocess.check_call(launch_command, shell=True)
        print("Invoked", launch_command)

        result_items.append({
            'work_dir': item_workdir,
            'finished_path': item_finished_path,
            'result_path': item_result_path,
            'error_path': item_error_path,
            'retry_num': 0,
            'launch_command': launch_command,
        })

    def result_generator():
        start = time.time()
        while result_items:
            print("[%0.1f sec elapsed] waiting on %d / %d items." % (
                time.time() - start, len(result_items), len(work_items)))

            # Poll the filesystem until some item's COMPLETE marker appears.
            while True:
                result_item = None
                for d in result_items:
                    if os.path.exists(d['finished_path']):
                        result_item = d
                        break
                if result_item is None:
                    time.sleep(60)
                else:
                    result_items.remove(result_item)
                    break

            complete_dir = result_item['finished_path']
            result_path = result_item['result_path']
            error_path = result_item['error_path']
            retry_num = result_item['retry_num']
            launch_command = result_item['launch_command']

            print("[%0.1f sec elapsed] processing item %s" % (
                time.time() - start, result_item))

            if os.path.exists(error_path):
                print("Error path exists", error_path)
                with open(error_path, "rb") as fd:
                    exception = pickle.load(fd)
                print(exception)
                if retry_num < max_retries:
                    # Move this attempt's markers aside and re-launch.
                    print("Relaunching", launch_command)
                    attempt_dir = os.path.join(
                        result_item['work_dir'], "attempt.%d" % retry_num)
                    shutil.move(complete_dir, attempt_dir)
                    shutil.move(error_path, attempt_dir)
                    subprocess.check_call(launch_command, shell=True)
                    print("Invoked", launch_command)
                    result_item['retry_num'] += 1
                    result_items.append(result_item)
                    continue
                else:
                    print("Max retries exceeded", max_retries)
                    raise exception

            if os.path.exists(result_path):
                print("Result path exists", result_path)
                if result_serialization_method == "save_predictor":
                    result = Class1AffinityPredictor.load(result_path)
                else:
                    assert result_serialization_method == "pickle"
                    with open(result_path, "rb") as fd:
                        result = pickle.load(fd)
                yield result
            else:
                raise RuntimeError("Result does not exist", result_path)

    return result_generator()


parser = argparse.ArgumentParser(
    description="Entry point for cluster workers")
parser.add_argument(
    "--constant-data",
    required=True)
parser.add_argument(
    "--worker-data",
    required=True)
parser.add_argument(
    "--result-out",
    required=True)
parser.add_argument(
    "--error-out",
    required=True)
parser.add_argument(
    "--complete-dir")
parser.add_argument(
    "--result-serialization-method",
    choices=("pickle", "save_predictor"),
    default="pickle")


def worker_entry_point(argv=sys.argv[1:]):
    """
    Entry point for the worker command
    (_mhcflurry-cluster-worker-entry-point).
    """
    # On SIGUSR1, print a stack trace (useful for debugging hung workers).
    print("To show stack trace, run:\nkill -s USR1 %d" % os.getpid())
    signal.signal(signal.SIGUSR1, lambda sig, frame: traceback.print_stack())

    args = parser.parse_args(argv)

    with open(args.constant_data, "rb") as fd:
        constant_payload = pickle.load(fd)

    with open(args.worker_data, "rb") as fd:
        worker_data = pickle.load(fd)

    kwargs = dict(worker_data)
    if constant_payload['constant_data'] is not None:
        kwargs['constant_data'] = constant_payload['constant_data']

    try:
        result = call_wrapped_kwargs(constant_payload['function'], kwargs)
        if args.result_serialization_method == 'save_predictor':
            result.save(args.result_out)
        else:
            with open(args.result_out, "wb") as fd:
                pickle.dump(result, fd, pickle.HIGHEST_PROTOCOL)
        print("Wrote:", args.result_out)
    except Exception as e:
        print("Exception: ", e)
        with open(args.error_out, "wb") as fd:
            pickle.dump(e, fd, pickle.HIGHEST_PROTOCOL)
        print("Wrote:", args.error_out)
        raise
    finally:
        # Signal completion by creating the COMPLETE marker directory. The
        # parent polls for this path regardless of success or failure.
        if args.complete_dir:
            os.mkdir(args.complete_dir)
            print("Created: ", args.complete_dir)
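# Usage sketch (illustrative; the names train_one_model, grid, and train_df
# below are hypothetical, not part of this module). A command-line tool would
# add the cluster arguments to its own parser, then map its work items
# through cluster_results_from_args:
#
#     import argparse
#     from mhcflurry.cluster_parallelism import (
#         add_cluster_parallelism_args, cluster_results_from_args)
#
#     def train_one_model(hyperparameters, constant_data=None):
#         # Runs on a cluster node. Because the function is pickled by
#         # reference, it must be importable on the worker as well.
#         ...
#
#     tool_parser = argparse.ArgumentParser()
#     add_cluster_parallelism_args(tool_parser)
#     tool_args = tool_parser.parse_args()
#
#     results = cluster_results_from_args(
#         tool_args,
#         work_function=train_one_model,
#         work_items=[{"hyperparameters": h} for h in grid],
#         constant_data={"train_data": train_df})
#     for result in results:
#         print("Got result:", result)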