diff --git a/downloads-generation/data_mass_spec_benchmark/run_mhcflurry.py b/downloads-generation/data_mass_spec_benchmark/run_mhcflurry.py
index 8cb557792f9f415069d3abc9930b478c2b78e7d2..37000577d752e438e48e3d2ddf742119d96a6f34 100644
--- a/downloads-generation/data_mass_spec_benchmark/run_mhcflurry.py
+++ b/downloads-generation/data_mass_spec_benchmark/run_mhcflurry.py
@@ -177,7 +177,8 @@ def run(argv=sys.argv[1:]):
             work_function=do_predictions,
             work_items=work_items,
             constant_data=GLOBAL_DATA,
+            input_serialization_method="dill",
             result_serialization_method="pickle",
             clear_constant_data=True)
     else:
         worker_pool = worker_pool_with_gpu_assignments_from_args(args)
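
Aside (not part of the patch): the point of a dill input path is that
work_function and the work_items may capture lambdas or closures, which the
standard pickle module refuses to serialize; dill handles them. A minimal
sketch of the difference, assuming dill is installed (pip install dill):

import pickle

import dill

predicate = lambda x: x > 0.5  # no importable name, so pickle rejects it

try:
    pickle.dumps(predicate)
except pickle.PicklingError as exc:
    print("pickle failed:", exc)

# dill serializes the function object itself, so the round trip succeeds.
restored = dill.loads(dill.dumps(predicate))
print(restored(0.7))  # True
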
diff --git a/mhcflurry/cluster_parallelism.py b/mhcflurry/cluster_parallelism.py
index 0669ac2119b3127a5a6df7df55548db682b8b874..ae49f5ee1c6e7c362c95bfd22a0403197dfe07c4 100644
--- a/mhcflurry/cluster_parallelism.py
+++ b/mhcflurry/cluster_parallelism.py
@@ -64,6 +64,7 @@ def cluster_results_from_args(
         work_function,
         work_items,
         constant_data=None,
+        input_serialization_method="pickle",
         result_serialization_method="pickle",
         clear_constant_data=False):
     """
@@ -95,6 +96,7 @@ def cluster_results_from_args(
         results_workdir=args.cluster_results_workdir,
         additional_complete_file=args.additional_complete_file,
         script_prefix_path=args.cluster_script_prefix_path,
+        input_serialization_method=input_serialization_method,
         result_serialization_method=result_serialization_method,
         max_retries=args.cluster_max_retries,
         clear_constant_data=clear_constant_data
@@ -109,6 +111,7 @@ def cluster_results(
         results_workdir="./cluster-workdir",
         additional_complete_file=None,
         script_prefix_path=None,
+        input_serialization_method="pickle",
         result_serialization_method="pickle",
         max_retries=3,
         clear_constant_data=False):
@@ -156,6 +159,13 @@ def cluster_results(
     generator of B
     """
 
+    if input_serialization_method == "dill":
+        import dill
+        input_serialization_module = dill
+    else:
+        assert input_serialization_method == "pickle"
+        input_serialization_module = pickle
+
     constant_payload = {
         'constant_data': constant_data,
         'function': work_function,
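
Aside (not part of the patch): this method-to-module dispatch is duplicated in
worker_entry_point below. A hypothetical shared helper (name and placement
illustrative only) would keep the writer and reader sides in sync:

def serialization_module(method):
    # Resolve a serialization method name to the module providing
    # dump/load/HIGHEST_PROTOCOL. dill is imported lazily so it stays an
    # optional dependency on the default pickle path.
    if method == "pickle":
        import pickle
        return pickle
    if method == "dill":
        import dill
        return dill
    raise ValueError("Unsupported serialization method", method)
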
@@ -165,9 +175,14 @@ def cluster_results(
         str(int(time.time())))
     os.mkdir(work_dir)
 
-    constant_payload_path = os.path.join(work_dir, "global_data.pkl")
+    constant_payload_path = os.path.join(
+        work_dir,
+        "global_data." + input_serialization_method)
     with open(constant_payload_path, "wb") as fd:
-        pickle.dump(constant_payload, fd, protocol=pickle.HIGHEST_PROTOCOL)
+        input_serialization_module.dump(
+            constant_payload,
+            fd,
+            protocol=input_serialization_module.HIGHEST_PROTOCOL)
     print("Wrote:", constant_payload_path)
     if clear_constant_data:
         constant_data.clear()
@@ -186,9 +201,11 @@ def cluster_results(
             work_dir, "work-item.%03d-of-%03d" % (i, len(work_items)))
         os.mkdir(item_workdir)
 
-        item_data_path = os.path.join(item_workdir, "data.pkl")
+        item_data_path = os.path.join(
+            item_workdir, "data." + input_serialization_method)
         with open(item_data_path, "wb") as fd:
-            pickle.dump(item, fd, protocol=pickle.HIGHEST_PROTOCOL)
+            input_serialization_module.dump(
+                item, fd, protocol=input_serialization_module.HIGHEST_PROTOCOL)
         print("Wrote:", item_data_path)
 
         item_result_path = os.path.join(item_workdir, "result")
@@ -205,6 +222,7 @@ def cluster_results(
             "--result-out", quote(item_result_path),
             "--error-out", quote(item_error_path),
             "--complete-dir", quote(item_finished_path),
+            "--input-serialization-method", input_serialization_method,
             "--result-serialization-method", result_serialization_method,
         ]))
         item_script = "\n".join(item_script_pieces)
@@ -271,6 +289,8 @@ def cluster_results(
                     with open(error_path, "rb") as fd:
                         exception = pickle.load(fd)
                         print(exception)
+                else:
+                    exception = RuntimeError("Error, but no exception saved")
                 if not os.path.exists(result_path):
                     print("Result path does NOT exist", result_path)
 
@@ -297,10 +317,14 @@ def cluster_results(
                 print("Result path exists", result_path)
                 if result_serialization_method == "save_predictor":
                     result = Class1AffinityPredictor.load(result_path)
-                else:
-                    assert result_serialization_method == "pickle"
+                elif result_serialization_method == "pickle":
                     with open(result_path, "rb") as fd:
                         result = pickle.load(fd)
+                else:
+                    raise ValueError(
+                        "Unsupported serialization method",
+                        result_serialization_method)
+
                 yield result
             else:
                 raise RuntimeError("Results do not exist", result_path)
@@ -329,6 +353,10 @@ parser.add_argument(
 parser.add_argument(
     "--complete-dir",
 )
+parser.add_argument(
+    "--input-serialization-method",
+    choices=("pickle", "dill"),
+    default="pickle")
 parser.add_argument(
     "--result-serialization-method",
     choices=("pickle", "save_predictor"),
@@ -349,11 +377,18 @@ def worker_entry_point(argv=sys.argv[1:]):
 
     args = parser.parse_args(argv)
 
+    if args.input_serialization_method == "dill":
+        import dill
+        input_serialization_module = dill
+    else:
+        assert args.input_serialization_method == "pickle"
+        input_serialization_module = pickle
+
     with open(args.constant_data, "rb") as fd:
-        constant_payload = pickle.load(fd)
+        constant_payload = input_serialization_module.load(fd)
 
     with open(args.worker_data, "rb") as fd:
-        worker_data = pickle.load(fd)
+        worker_data = input_serialization_module.load(fd)
 
     kwargs = dict(worker_data)
     if constant_payload['constant_data'] is not None:
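
Aside (not part of the patch): the generated --input-serialization-method flag
guarantees that writer (cluster_results) and reader (worker_entry_point) agree
on the format. A self-contained round trip of the dill path, with the same
method-derived file extension used above (payload contents illustrative):

import os
import tempfile

import dill  # assumed installed

payload = {
    "constant_data": {"alleles": ["HLA-A*02:01"]},  # illustrative stand-in
    "function": lambda peptide: len(peptide),  # plain pickle would reject this
}

work_dir = tempfile.mkdtemp()
path = os.path.join(work_dir, "global_data.dill")
with open(path, "wb") as fd:
    dill.dump(payload, fd, protocol=dill.HIGHEST_PROTOCOL)

with open(path, "rb") as fd:
    restored = dill.load(fd)
print(restored["function"]("SIINFEKL"))  # 8
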