Skip to content
Snippets Groups Projects
Commit 11c2c332 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

add support for dill

parent 792a9aa3
No related branches found
No related tags found
No related merge requests found
......@@ -177,7 +177,7 @@ def run(argv=sys.argv[1:]):
work_function=do_predictions,
work_items=work_items,
constant_data=GLOBAL_DATA,
result_serialization_method="pickle",
result_serialization_method="dill",
clear_constant_data=True)
else:
worker_pool = worker_pool_with_gpu_assignments_from_args(args)
......
......@@ -64,6 +64,7 @@ def cluster_results_from_args(
work_function,
work_items,
constant_data=None,
input_serialization_method="pickle",
result_serialization_method="pickle",
clear_constant_data=False):
"""
......@@ -95,6 +96,7 @@ def cluster_results_from_args(
results_workdir=args.cluster_results_workdir,
additional_complete_file=args.additional_complete_file,
script_prefix_path=args.cluster_script_prefix_path,
input_serialization_method=input_serialization_method,
result_serialization_method=result_serialization_method,
max_retries=args.cluster_max_retries,
clear_constant_data=clear_constant_data
......@@ -109,6 +111,7 @@ def cluster_results(
results_workdir="./cluster-workdir",
additional_complete_file=None,
script_prefix_path=None,
input_serialization_method="pickle",
result_serialization_method="pickle",
max_retries=3,
clear_constant_data=False):
......@@ -156,6 +159,13 @@ def cluster_results(
generator of B
"""
if input_serialization_method == "dill":
import dill
input_serialization_module = dill
else:
assert input_serialization_method == "pickle"
input_serialization_module = pickle
constant_payload = {
'constant_data': constant_data,
'function': work_function,
......@@ -165,9 +175,14 @@ def cluster_results(
str(int(time.time())))
os.mkdir(work_dir)
constant_payload_path = os.path.join(work_dir, "global_data.pkl")
constant_payload_path = os.path.join(
work_dir,
"global_data." + input_serialization_method)
with open(constant_payload_path, "wb") as fd:
pickle.dump(constant_payload, fd, protocol=pickle.HIGHEST_PROTOCOL)
input_serialization_module.dump(
constant_payload,
fd,
protocol=input_serialization_module.HIGHEST_PROTOCOL)
print("Wrote:", constant_payload_path)
if clear_constant_data:
constant_data.clear()
......@@ -186,9 +201,11 @@ def cluster_results(
work_dir, "work-item.%03d-of-%03d" % (i, len(work_items)))
os.mkdir(item_workdir)
item_data_path = os.path.join(item_workdir, "data.pkl")
item_data_path = os.path.join(
item_workdir, "data." + input_serialization_method)
with open(item_data_path, "wb") as fd:
pickle.dump(item, fd, protocol=pickle.HIGHEST_PROTOCOL)
input_serialization_module.dump(
item, fd, protocol=input_serialization_module.HIGHEST_PROTOCOL)
print("Wrote:", item_data_path)
item_result_path = os.path.join(item_workdir, "result")
......@@ -205,6 +222,7 @@ def cluster_results(
"--result-out", quote(item_result_path),
"--error-out", quote(item_error_path),
"--complete-dir", quote(item_finished_path),
"--input-serialization-method", input_serialization_method,
"--result-serialization-method", result_serialization_method,
]))
item_script = "\n".join(item_script_pieces)
......@@ -271,6 +289,8 @@ def cluster_results(
with open(error_path, "rb") as fd:
exception = pickle.load(fd)
print(exception)
else:
exception = RuntimeError("Error, but no exception saved")
if not os.path.exists(result_path):
print("Result path does NOT exist", result_path)
......@@ -297,10 +317,14 @@ def cluster_results(
print("Result path exists", result_path)
if result_serialization_method == "save_predictor":
result = Class1AffinityPredictor.load(result_path)
else:
assert result_serialization_method == "pickle"
elif result_serialization_method == "pickle":
with open(result_path, "rb") as fd:
result = pickle.load(fd)
else:
raise ValueError(
"Unsupported serialization method",
result_serialization_method)
yield result
else:
raise RuntimeError("Results do not exist", result_path)
......@@ -329,6 +353,10 @@ parser.add_argument(
parser.add_argument(
"--complete-dir",
)
parser.add_argument(
"--input-serialization-method",
choices=("pickle", "dill"),
default="pickle")
parser.add_argument(
"--result-serialization-method",
choices=("pickle", "save_predictor"),
......@@ -349,11 +377,18 @@ def worker_entry_point(argv=sys.argv[1:]):
args = parser.parse_args(argv)
if args.input_serialization_method == "dill":
import dill
input_serialization_module = dill
else:
assert args.input_serialization_method == "pickle"
input_serialization_module = pickle
with open(args.constant_data, "rb") as fd:
constant_payload = pickle.load(fd)
constant_payload = input_serialization_module.load(fd)
with open(args.worker_data, "rb") as fd:
worker_data = pickle.load(fd)
worker_data = input_serialization_module.load(fd)
kwargs = dict(worker_data)
if constant_payload['constant_data'] is not None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment