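"""
Support for parallel execution across a compute cluster.

Each work item is written to its own working directory along with a
generated shell script, which is submitted via a user-supplied command
(e.g. sh, qsub, sbatch). Results are communicated back through the shared
filesystem.
"""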
import traceback
import sys
import os
import time
import signal
import argparse
import pickle
import subprocess
import shutil

from .local_parallelism import call_wrapped_kwargs
from .class1_affinity_predictor import Class1AffinityPredictor

# shlex.quote is available on Python 3; pipes.quote is the Python 2 fallback.
try:
    from shlex import quote
except ImportError:
    from pipes import quote


def add_cluster_parallelism_args(parser):
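    """
    Add cluster parallelism command-line options to the given
    argparse.ArgumentParser.
    """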
    group = parser.add_argument_group("Cluster parallelism")
    group.add_argument(
        "--cluster-parallelism",
        default=False,
        action="store_true",
        help="Run work items as jobs submitted to a cluster scheduler")
    group.add_argument(
        "--cluster-submit-command",
        default='sh',
        help="Command used to submit each generated job script "
        "(e.g. sh, qsub, sbatch). Default: %(default)s")
    group.add_argument(
        "--cluster-results-workdir",
        default='./cluster-workdir',
        help="Directory (on a shared filesystem) for intermediate data and "
        "results. Default: %(default)s")
    group.add_argument(
        '--cluster-script-prefix-path',
        help="Path to a file whose contents are prepended to each generated "
        "job script. May use the {work_item_num} and {work_dir} "
        "placeholders.")
    group.add_argument(
        '--cluster-max-retries',
        type=int,
        default=3,
        help="Max times to retry a failed work item. Default: %(default)s")


def cluster_results_from_args(
        args,
        work_function,
        work_items,
        constant_data=None,
        result_serialization_method="pickle"):
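    """
    Parallel map over work_items using the cluster options given in args.
    Thin wrapper around cluster_results().
    """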
    return cluster_results(
        work_function=work_function,
        work_items=work_items,
        constant_data=constant_data,
        submit_command=args.cluster_submit_command,
        results_workdir=args.cluster_results_workdir,
        script_prefix_path=args.cluster_script_prefix_path,
        result_serialization_method=result_serialization_method,
        max_retries=args.cluster_max_retries)


def cluster_results(
        work_function,
        work_items,
        constant_data=None,
        submit_command="sh",
        results_workdir="./cluster-workdir",
        script_prefix_path=None,
        result_serialization_method="pickle",
        max_retries=3):
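    """
    Parallel map over work_items on a compute cluster.

    For each work item, writes a pickled payload and a shell script to a
    per-item working directory, submits the script with submit_command, and
    polls the shared filesystem until a COMPLETE marker directory appears.
    Failed items are retried up to max_retries times.

    Returns a generator of results, yielded in completion order (not
    submission order).

    Example (a sketch; `my_function` and the prefix file are illustrative,
    not part of this module):

        results = cluster_results(
            work_function=my_function,
            work_items=[{"num": 0}, {"num": 1}],
            submit_command="sbatch",
            script_prefix_path="sbatch-header.sh")
        for result in results:
            print(result)
    """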

    constant_payload = {
        'constant_data': constant_data,
        'function': work_function,
    }
    work_dir = os.path.join(
        os.path.abspath(results_workdir),
        str(int(time.time())))
    os.makedirs(work_dir)  # also creates results_workdir itself if needed

    constant_payload_path = os.path.join(work_dir, "global_data.pkl")
    with open(constant_payload_path, "wb") as fd:
        pickle.dump(constant_payload, fd, protocol=pickle.HIGHEST_PROTOCOL)
    print("Wrote:", constant_payload_path)

    if script_prefix_path:
        with open(script_prefix_path) as fd:
            script_prefix = fd.read()
    else:
        script_prefix = "#!/bin/bash"

    result_items = []

    for (i, item) in enumerate(work_items):
        item_workdir = os.path.join(
            work_dir, "work-item.%03d-of-%03d" % (i, len(work_items)))
        os.mkdir(item_workdir)

        item_data_path = os.path.join(item_workdir, "data.pkl")
        with open(item_data_path, "wb") as fd:
            pickle.dump(item, fd, protocol=pickle.HIGHEST_PROTOCOL)
        print("Wrote:", item_data_path)

        item_result_path = os.path.join(item_workdir, "result")
        item_error_path = os.path.join(item_workdir, "error.pkl")
        item_finished_path = os.path.join(item_workdir, "COMPLETE")

        # The optional script prefix may use the {work_item_num} and
        # {work_dir} placeholders, filled in here. The generated script then
        # runs the console entry point expected to map to
        # worker_entry_point() below.
        item_script_pieces = [
            script_prefix.format(work_item_num=i, work_dir=item_workdir)
        ]
        item_script_pieces.append(" ".join([
            "_mhcflurry-cluster-worker-entry-point",
            "--constant-data", quote(constant_payload_path),
            "--worker-data", quote(item_data_path),
            "--result-out", quote(item_result_path),
            "--error-out", quote(item_error_path),
            "--complete-dir", quote(item_finished_path),
            "--result-serialization-method", result_serialization_method,
        ]))
        item_script = "\n".join(item_script_pieces)
        item_script_path = os.path.join(
            item_workdir,
            "run.%d.sh" % i)
        with open(item_script_path, "w") as fd:
            fd.write(item_script)
        print("Wrote:", item_script_path)

        launch_command = " ".join([
            submit_command, "<", quote(item_script_path)
        ])
        subprocess.check_call(launch_command, shell=True)
        print("Invoked", launch_command)

        result_items.append({
            'work_dir': item_workdir,
            'finished_path': item_finished_path,
            'result_path': item_result_path,
            'error_path': item_error_path,
            'retry_num': 0,
            'launch_command': launch_command,
        })

    def result_generator():
        start = time.time()
        while result_items:
            print("[%0.1f sec elapsed] waiting on %d / %d items." % (
                time.time() - start, len(result_items), len(work_items)))
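            # Scan for a work item whose COMPLETE marker directory has
            # appeared on the shared filesystem; sleep between scans.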
            while True:
                result_item = None
                for d in result_items:
                    if os.path.exists(d['finished_path']):
                        result_item = d
                        break
                if result_item is None:
                    time.sleep(60)
                else:
                    result_items.remove(result_item)
                    break

            complete_dir = result_item['finished_path']
            result_path = result_item['result_path']
            error_path = result_item['error_path']
            retry_num = result_item['retry_num']
            launch_command = result_item['launch_command']

            print("[%0.1f sec elapsed] processing item %s" % (
                time.time() - start, result_item))

            if os.path.exists(error_path):
                print("Error path exists", error_path)
                with open(error_path, "rb") as fd:
                    exception = pickle.load(fd)
                print(exception)
                if retry_num < max_retries:
                    print("Relaunching", launch_command)
                    attempt_dir = os.path.join(
                        result_item['work_dir'], "attempt.%d" % retry_num)
                    # Archive this attempt's outputs, then resubmit.
                    shutil.move(complete_dir, attempt_dir)
                    shutil.move(error_path, attempt_dir)
                    subprocess.check_call(launch_command, shell=True)
                    print("Invoked", launch_command)
                    result_item['retry_num'] += 1
                    result_items.append(result_item)
                    continue
                else:
                    print("Max retries exceeded", max_retries)
                    raise exception

            if os.path.exists(result_path):
                print("Result path exists", result_path)
                if result_serialization_method == "save_predictor":
                    result = Class1AffinityPredictor.load(result_path)
                else:
                    assert result_serialization_method == "pickle"
                    with open(result_path, "rb") as fd:
                        result = pickle.load(fd)
                yield result
            else:
                raise RuntimeError(
                    "Result does not exist: %s" % result_path)

    return result_generator()


parser = argparse.ArgumentParser(
    description="Entry point for cluster workers")
parser.add_argument(
    "--constant-data",
    required=True,
    help="Path to pickled constant payload shared across work items")
parser.add_argument(
    "--worker-data",
    required=True,
    help="Path to pickled data for this work item")
parser.add_argument(
    "--result-out",
    required=True,
    help="Path to which the result is written")
parser.add_argument(
    "--error-out",
    required=True,
    help="Path to which a pickled exception is written on failure")
parser.add_argument(
    "--complete-dir",
    help="Directory to create when the work item finishes "
    "(whether it succeeded or failed)")
parser.add_argument(
    "--result-serialization-method",
    choices=("pickle", "save_predictor"),
    default="pickle",
    help="How to serialize the result. Default: %(default)s")


def worker_entry_point(argv=sys.argv[1:]):
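    """
    Console entry point for a single cluster work item: load the constant
    payload and per-item data, run the work function, write the result (or
    a pickled exception on failure), and create the COMPLETE marker
    directory.
    """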
    # On sigusr1 print stack trace
    print("To show stack trace, run:\nkill -s USR1 %d" % os.getpid())
    signal.signal(signal.SIGUSR1, lambda sig, frame: traceback.print_stack())

    args = parser.parse_args(argv)

    with open(args.constant_data, "rb") as fd:
        constant_payload = pickle.load(fd)

    with open(args.worker_data, "rb") as fd:
        worker_data = pickle.load(fd)

    kwargs = dict(worker_data)
    if constant_payload['constant_data'] is not None:
        kwargs['constant_data'] = constant_payload['constant_data']

    try:
        result = call_wrapped_kwargs(constant_payload['function'], kwargs)
        if args.result_serialization_method == 'save_predictor':
            result.save(args.result_out)
        else:
            with open(args.result_out, "wb") as fd:
                pickle.dump(result, fd, pickle.HIGHEST_PROTOCOL)
        print("Wrote:", args.result_out)
    except Exception as e:
        print("Exception: ", e)
        with open(args.error_out, "wb") as fd:
            pickle.dump(e, fd, pickle.HIGHEST_PROTOCOL)
        print("Wrote:", args.error_out)
        raise
    finally:
        if args.complete_dir:
            os.mkdir(args.complete_dir)
            print("Created: ", args.complete_dir)