Skip to content
Snippets Groups Projects
Commit 11c2c332 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

add support for dill

parent 792a9aa3
No related merge requests found
...@@ -177,7 +177,7 @@ def run(argv=sys.argv[1:]): ...@@ -177,7 +177,7 @@ def run(argv=sys.argv[1:]):
work_function=do_predictions, work_function=do_predictions,
work_items=work_items, work_items=work_items,
constant_data=GLOBAL_DATA, constant_data=GLOBAL_DATA,
result_serialization_method="pickle", result_serialization_method="dill",
clear_constant_data=True) clear_constant_data=True)
else: else:
worker_pool = worker_pool_with_gpu_assignments_from_args(args) worker_pool = worker_pool_with_gpu_assignments_from_args(args)
......
...@@ -64,6 +64,7 @@ def cluster_results_from_args( ...@@ -64,6 +64,7 @@ def cluster_results_from_args(
work_function, work_function,
work_items, work_items,
constant_data=None, constant_data=None,
input_serialization_method="pickle",
result_serialization_method="pickle", result_serialization_method="pickle",
clear_constant_data=False): clear_constant_data=False):
""" """
...@@ -95,6 +96,7 @@ def cluster_results_from_args( ...@@ -95,6 +96,7 @@ def cluster_results_from_args(
results_workdir=args.cluster_results_workdir, results_workdir=args.cluster_results_workdir,
additional_complete_file=args.additional_complete_file, additional_complete_file=args.additional_complete_file,
script_prefix_path=args.cluster_script_prefix_path, script_prefix_path=args.cluster_script_prefix_path,
input_serialization_method=input_serialization_method,
result_serialization_method=result_serialization_method, result_serialization_method=result_serialization_method,
max_retries=args.cluster_max_retries, max_retries=args.cluster_max_retries,
clear_constant_data=clear_constant_data clear_constant_data=clear_constant_data
...@@ -109,6 +111,7 @@ def cluster_results( ...@@ -109,6 +111,7 @@ def cluster_results(
results_workdir="./cluster-workdir", results_workdir="./cluster-workdir",
additional_complete_file=None, additional_complete_file=None,
script_prefix_path=None, script_prefix_path=None,
input_serialization_method="pickle",
result_serialization_method="pickle", result_serialization_method="pickle",
max_retries=3, max_retries=3,
clear_constant_data=False): clear_constant_data=False):
...@@ -156,6 +159,13 @@ def cluster_results( ...@@ -156,6 +159,13 @@ def cluster_results(
generator of B generator of B
""" """
if input_serialization_method == "dill":
import dill
input_serialization_module = dill
else:
assert input_serialization_method == "pickle"
input_serialization_module = pickle
constant_payload = { constant_payload = {
'constant_data': constant_data, 'constant_data': constant_data,
'function': work_function, 'function': work_function,
...@@ -165,9 +175,14 @@ def cluster_results( ...@@ -165,9 +175,14 @@ def cluster_results(
str(int(time.time()))) str(int(time.time())))
os.mkdir(work_dir) os.mkdir(work_dir)
constant_payload_path = os.path.join(work_dir, "global_data.pkl") constant_payload_path = os.path.join(
work_dir,
"global_data." + input_serialization_method)
with open(constant_payload_path, "wb") as fd: with open(constant_payload_path, "wb") as fd:
pickle.dump(constant_payload, fd, protocol=pickle.HIGHEST_PROTOCOL) input_serialization_module.dump(
constant_payload,
fd,
protocol=input_serialization_module.HIGHEST_PROTOCOL)
print("Wrote:", constant_payload_path) print("Wrote:", constant_payload_path)
if clear_constant_data: if clear_constant_data:
constant_data.clear() constant_data.clear()
...@@ -186,9 +201,11 @@ def cluster_results( ...@@ -186,9 +201,11 @@ def cluster_results(
work_dir, "work-item.%03d-of-%03d" % (i, len(work_items))) work_dir, "work-item.%03d-of-%03d" % (i, len(work_items)))
os.mkdir(item_workdir) os.mkdir(item_workdir)
item_data_path = os.path.join(item_workdir, "data.pkl") item_data_path = os.path.join(
item_workdir, "data." + input_serialization_method)
with open(item_data_path, "wb") as fd: with open(item_data_path, "wb") as fd:
pickle.dump(item, fd, protocol=pickle.HIGHEST_PROTOCOL) input_serialization_module.dump(
item, fd, protocol=input_serialization_module.HIGHEST_PROTOCOL)
print("Wrote:", item_data_path) print("Wrote:", item_data_path)
item_result_path = os.path.join(item_workdir, "result") item_result_path = os.path.join(item_workdir, "result")
...@@ -205,6 +222,7 @@ def cluster_results( ...@@ -205,6 +222,7 @@ def cluster_results(
"--result-out", quote(item_result_path), "--result-out", quote(item_result_path),
"--error-out", quote(item_error_path), "--error-out", quote(item_error_path),
"--complete-dir", quote(item_finished_path), "--complete-dir", quote(item_finished_path),
"--input-serialization-method", input_serialization_method,
"--result-serialization-method", result_serialization_method, "--result-serialization-method", result_serialization_method,
])) ]))
item_script = "\n".join(item_script_pieces) item_script = "\n".join(item_script_pieces)
...@@ -271,6 +289,8 @@ def cluster_results( ...@@ -271,6 +289,8 @@ def cluster_results(
with open(error_path, "rb") as fd: with open(error_path, "rb") as fd:
exception = pickle.load(fd) exception = pickle.load(fd)
print(exception) print(exception)
else:
exception = RuntimeError("Error, but no exception saved")
if not os.path.exists(result_path): if not os.path.exists(result_path):
print("Result path does NOT exist", result_path) print("Result path does NOT exist", result_path)
...@@ -297,10 +317,14 @@ def cluster_results( ...@@ -297,10 +317,14 @@ def cluster_results(
print("Result path exists", result_path) print("Result path exists", result_path)
if result_serialization_method == "save_predictor": if result_serialization_method == "save_predictor":
result = Class1AffinityPredictor.load(result_path) result = Class1AffinityPredictor.load(result_path)
else: elif result_serialization_method == "pickle":
assert result_serialization_method == "pickle"
with open(result_path, "rb") as fd: with open(result_path, "rb") as fd:
result = pickle.load(fd) result = pickle.load(fd)
else:
raise ValueError(
"Unsupported serialization method",
result_serialization_method)
yield result yield result
else: else:
raise RuntimeError("Results do not exist", result_path) raise RuntimeError("Results do not exist", result_path)
...@@ -329,6 +353,10 @@ parser.add_argument( ...@@ -329,6 +353,10 @@ parser.add_argument(
parser.add_argument( parser.add_argument(
"--complete-dir", "--complete-dir",
) )
parser.add_argument(
"--input-serialization-method",
choices=("pickle", "dill"),
default="pickle")
parser.add_argument( parser.add_argument(
"--result-serialization-method", "--result-serialization-method",
choices=("pickle", "save_predictor"), choices=("pickle", "save_predictor"),
...@@ -349,11 +377,18 @@ def worker_entry_point(argv=sys.argv[1:]): ...@@ -349,11 +377,18 @@ def worker_entry_point(argv=sys.argv[1:]):
args = parser.parse_args(argv) args = parser.parse_args(argv)
if args.input_serialization_method == "dill":
import dill
input_serialization_module = dill
else:
assert args.input_serialization_method == "pickle"
input_serialization_module = pickle
with open(args.constant_data, "rb") as fd: with open(args.constant_data, "rb") as fd:
constant_payload = pickle.load(fd) constant_payload = input_serialization_module.load(fd)
with open(args.worker_data, "rb") as fd: with open(args.worker_data, "rb") as fd:
worker_data = pickle.load(fd) worker_data = input_serialization_module.load(fd)
kwargs = dict(worker_data) kwargs = dict(worker_data)
if constant_payload['constant_data'] is not None: if constant_payload['constant_data'] is not None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment