From c7b9df69ed30ed8a4550f03bb4a077fe352d62c2 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Tue, 1 Oct 2019 16:39:36 -0400
Subject: [PATCH] add support for dill

---
 .../data_mass_spec_benchmark/run_mhcflurry.py |  2 +-
 mhcflurry/cluster_parallelism.py              | 51 ++++++++++++++++---
 2 files changed, 44 insertions(+), 9 deletions(-)

diff --git a/downloads-generation/data_mass_spec_benchmark/run_mhcflurry.py b/downloads-generation/data_mass_spec_benchmark/run_mhcflurry.py
index 8cb55779..37000577 100644
--- a/downloads-generation/data_mass_spec_benchmark/run_mhcflurry.py
+++ b/downloads-generation/data_mass_spec_benchmark/run_mhcflurry.py
@@ -177,7 +177,7 @@ def run(argv=sys.argv[1:]):
             work_function=do_predictions,
             work_items=work_items,
             constant_data=GLOBAL_DATA,
-            result_serialization_method="pickle",
+            result_serialization_method="dill",
             clear_constant_data=True)
     else:
         worker_pool = worker_pool_with_gpu_assignments_from_args(args)
diff --git a/mhcflurry/cluster_parallelism.py b/mhcflurry/cluster_parallelism.py
index 0669ac21..ae49f5ee 100644
--- a/mhcflurry/cluster_parallelism.py
+++ b/mhcflurry/cluster_parallelism.py
@@ -64,6 +64,7 @@ def cluster_results_from_args(
         work_function,
         work_items,
         constant_data=None,
+        input_serialization_method="pickle",
         result_serialization_method="pickle",
         clear_constant_data=False):
     """
@@ -95,6 +96,7 @@ def cluster_results_from_args(
         results_workdir=args.cluster_results_workdir,
         additional_complete_file=args.additional_complete_file,
         script_prefix_path=args.cluster_script_prefix_path,
+        input_serialization_method=input_serialization_method,
         result_serialization_method=result_serialization_method,
         max_retries=args.cluster_max_retries,
         clear_constant_data=clear_constant_data
@@ -109,6 +111,7 @@ def cluster_results(
         results_workdir="./cluster-workdir",
         additional_complete_file=None,
         script_prefix_path=None,
+        input_serialization_method="pickle",
         result_serialization_method="pickle",
         max_retries=3,
         clear_constant_data=False):
@@ -156,6 +159,13 @@ def cluster_results(
     generator of B
     """
 
+    if input_serialization_method == "dill":
+        import dill
+        input_serialization_module = dill
+    else:
+        assert input_serialization_method == "pickle"
+        input_serialization_module = pickle
+
     constant_payload = {
         'constant_data': constant_data,
         'function': work_function,
@@ -165,9 +175,14 @@ def cluster_results(
         str(int(time.time())))
     os.mkdir(work_dir)
 
-    constant_payload_path = os.path.join(work_dir, "global_data.pkl")
+    constant_payload_path = os.path.join(
+        work_dir,
+        "global_data." + input_serialization_method)
     with open(constant_payload_path, "wb") as fd:
-        pickle.dump(constant_payload, fd, protocol=pickle.HIGHEST_PROTOCOL)
+        input_serialization_module.dump(
+            constant_payload,
+            fd,
+            protocol=input_serialization_module.HIGHEST_PROTOCOL)
     print("Wrote:", constant_payload_path)
     if clear_constant_data:
         constant_data.clear()
@@ -186,9 +201,11 @@ def cluster_results(
             work_dir, "work-item.%03d-of-%03d" % (i, len(work_items)))
         os.mkdir(item_workdir)
 
-        item_data_path = os.path.join(item_workdir, "data.pkl")
+        item_data_path = os.path.join(
+            item_workdir, "data." + input_serialization_method)
         with open(item_data_path, "wb") as fd:
-            pickle.dump(item, fd, protocol=pickle.HIGHEST_PROTOCOL)
+            input_serialization_module.dump(
+                item, fd, protocol=input_serialization_module.HIGHEST_PROTOCOL)
         print("Wrote:", item_data_path)
 
         item_result_path = os.path.join(item_workdir, "result")
@@ -205,6 +222,7 @@ def cluster_results(
             "--result-out", quote(item_result_path),
             "--error-out", quote(item_error_path),
             "--complete-dir", quote(item_finished_path),
+            "--input-serialization-method", input_serialization_method,
             "--result-serialization-method", result_serialization_method,
         ]))
         item_script = "\n".join(item_script_pieces)
@@ -271,6 +289,8 @@ def cluster_results(
                 with open(error_path, "rb") as fd:
                     exception = pickle.load(fd)
                 print(exception)
+            else:
+                exception = RuntimeError("Error, but no exception saved")
 
             if not os.path.exists(result_path):
                 print("Result path does NOT exist", result_path)
@@ -297,10 +317,14 @@ def cluster_results(
                 print("Result path exists", result_path)
                 if result_serialization_method == "save_predictor":
                     result = Class1AffinityPredictor.load(result_path)
-                else:
-                    assert result_serialization_method == "pickle"
+                elif result_serialization_method == "pickle":
                     with open(result_path, "rb") as fd:
                         result = pickle.load(fd)
+                else:
+                    raise ValueError(
+                        "Unsupported serialization method",
+                        result_serialization_method)
+
                 yield result
             else:
                 raise RuntimeError("Results do not exist", result_path)
@@ -329,6 +353,10 @@ parser.add_argument(
 parser.add_argument(
     "--complete-dir",
 )
+parser.add_argument(
+    "--input-serialization-method",
+    choices=("pickle", "dill"),
+    default="pickle")
 parser.add_argument(
     "--result-serialization-method",
     choices=("pickle", "save_predictor"),
@@ -349,11 +377,18 @@ def worker_entry_point(argv=sys.argv[1:]):
 
     args = parser.parse_args(argv)
 
+    if args.input_serialization_method == "dill":
+        import dill
+        input_serialization_module = dill
+    else:
+        assert args.input_serialization_method == "pickle"
+        input_serialization_module = pickle
+
    with open(args.constant_data, "rb") as fd:
-        constant_payload = pickle.load(fd)
+        constant_payload = input_serialization_module.load(fd)
 
     with open(args.worker_data, "rb") as fd:
-        worker_data = pickle.load(fd)
+        worker_data = input_serialization_module.load(fd)
 
     kwargs = dict(worker_data)
     if constant_payload['constant_data'] is not None:
--
GitLab
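
Note on the motivation, with a minimal sketch (not part of the patch): the constant payload written by cluster_results includes the work_function itself, and the stdlib pickle can only serialize functions by reference to a module-level name, so lambdas and closures fail to pickle. dill serializes them by value while keeping pickle's interface (dump/load/HIGHEST_PROTOCOL), which is what lets the patch dispatch through a single module variable. The function make_adder below is a hypothetical example, not from the repository.

    import pickle

    import dill  # third-party; pip install dill

    def make_adder(offset):
        # Returns a closure; pickle cannot serialize local functions.
        return lambda x: x + offset

    work_function = make_adder(10)

    try:
        pickle.dumps(work_function)
    except (pickle.PicklingError, AttributeError) as exc:
        print("pickle failed:", exc)

    # The same swap-the-module pattern the patch uses for its
    # input_serialization_module variable:
    serialization_module = dill
    blob = serialization_module.dumps(
        work_function, protocol=serialization_module.HIGHEST_PROTOCOL)
    print(serialization_module.loads(blob)(5))  # prints 15

Because dill deliberately mirrors the pickle module's API, the patch can select the serializer once and use identical dump/load calls everywhere afterward, on both the submitting side and in worker_entry_point.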