Skip to content
Snippets Groups Projects
Commit 11c2c332 authored by Tim O'Donnell
Browse files

add support for dill

parent 792a9aa3
No related merge requests found
......@@ -177,7 +177,7 @@ def run(argv=sys.argv[1:]):
work_function=do_predictions,
work_items=work_items,
constant_data=GLOBAL_DATA,
result_serialization_method="pickle",
result_serialization_method="dill",
clear_constant_data=True)
else:
worker_pool = worker_pool_with_gpu_assignments_from_args(args)
......
......@@ -64,6 +64,7 @@ def cluster_results_from_args(
work_function,
work_items,
constant_data=None,
input_serialization_method="pickle",
result_serialization_method="pickle",
clear_constant_data=False):
"""
......@@ -95,6 +96,7 @@ def cluster_results_from_args(
results_workdir=args.cluster_results_workdir,
additional_complete_file=args.additional_complete_file,
script_prefix_path=args.cluster_script_prefix_path,
input_serialization_method=input_serialization_method,
result_serialization_method=result_serialization_method,
max_retries=args.cluster_max_retries,
clear_constant_data=clear_constant_data
......@@ -109,6 +111,7 @@ def cluster_results(
results_workdir="./cluster-workdir",
additional_complete_file=None,
script_prefix_path=None,
input_serialization_method="pickle",
result_serialization_method="pickle",
max_retries=3,
clear_constant_data=False):
......@@ -156,6 +159,13 @@ def cluster_results(
generator of B
"""
if input_serialization_method == "dill":
import dill
input_serialization_module = dill
else:
assert input_serialization_method == "pickle"
input_serialization_module = pickle
constant_payload = {
'constant_data': constant_data,
'function': work_function,
......@@ -165,9 +175,14 @@ def cluster_results(
str(int(time.time())))
os.mkdir(work_dir)
constant_payload_path = os.path.join(work_dir, "global_data.pkl")
constant_payload_path = os.path.join(
work_dir,
"global_data." + input_serialization_method)
with open(constant_payload_path, "wb") as fd:
pickle.dump(constant_payload, fd, protocol=pickle.HIGHEST_PROTOCOL)
input_serialization_module.dump(
constant_payload,
fd,
protocol=input_serialization_module.HIGHEST_PROTOCOL)
print("Wrote:", constant_payload_path)
if clear_constant_data:
constant_data.clear()
......@@ -186,9 +201,11 @@ def cluster_results(
work_dir, "work-item.%03d-of-%03d" % (i, len(work_items)))
os.mkdir(item_workdir)
item_data_path = os.path.join(item_workdir, "data.pkl")
item_data_path = os.path.join(
item_workdir, "data." + input_serialization_method)
with open(item_data_path, "wb") as fd:
pickle.dump(item, fd, protocol=pickle.HIGHEST_PROTOCOL)
input_serialization_module.dump(
item, fd, protocol=input_serialization_module.HIGHEST_PROTOCOL)
print("Wrote:", item_data_path)
item_result_path = os.path.join(item_workdir, "result")
......@@ -205,6 +222,7 @@ def cluster_results(
"--result-out", quote(item_result_path),
"--error-out", quote(item_error_path),
"--complete-dir", quote(item_finished_path),
"--input-serialization-method", input_serialization_method,
"--result-serialization-method", result_serialization_method,
]))
item_script = "\n".join(item_script_pieces)
......@@ -271,6 +289,8 @@ def cluster_results(
with open(error_path, "rb") as fd:
exception = pickle.load(fd)
print(exception)
else:
exception = RuntimeError("Error, but no exception saved")
if not os.path.exists(result_path):
print("Result path does NOT exist", result_path)
......@@ -297,10 +317,14 @@ def cluster_results(
print("Result path exists", result_path)
if result_serialization_method == "save_predictor":
result = Class1AffinityPredictor.load(result_path)
else:
assert result_serialization_method == "pickle"
elif result_serialization_method == "pickle":
with open(result_path, "rb") as fd:
result = pickle.load(fd)
else:
raise ValueError(
"Unsupported serialization method",
result_serialization_method)
yield result
else:
raise RuntimeError("Results do not exist", result_path)
......@@ -329,6 +353,10 @@ parser.add_argument(
parser.add_argument(
"--complete-dir",
)
parser.add_argument(
"--input-serialization-method",
choices=("pickle", "dill"),
default="pickle")
parser.add_argument(
"--result-serialization-method",
choices=("pickle", "save_predictor"),
......@@ -349,11 +377,18 @@ def worker_entry_point(argv=sys.argv[1:]):
args = parser.parse_args(argv)
if args.input_serialization_method == "dill":
import dill
input_serialization_module = dill
else:
assert args.input_serialization_method == "pickle"
input_serialization_module = pickle
with open(args.constant_data, "rb") as fd:
constant_payload = pickle.load(fd)
constant_payload = input_serialization_module.load(fd)
with open(args.worker_data, "rb") as fd:
worker_data = pickle.load(fd)
worker_data = input_serialization_module.load(fd)
kwargs = dict(worker_data)
if constant_payload['constant_data'] is not None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment