diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py
index 1f2507e8d29f2237f3a0c02f56274d966426fed2..17dc13c50268f407c2ab7f452090ba00bffa32df 100644
--- a/mhcflurry/class1_affinity_predictor.py
+++ b/mhcflurry/class1_affinity_predictor.py
@@ -312,6 +312,11 @@ class Class1AffinityPredictor(object):
                 str(self.class1_pan_allele_models),
                 str(self.allele_to_allele_specific_models)))
 
+    def save_metadata_df(self, models_dir, name):
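+        """Write the named metadata dataframe to models_dir as a bz2-compressed CSV."""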
+        df = self.metadata_dataframes[name]
+        metadata_df_path = join(models_dir, "%s.csv.bz2" % name)
+        df.to_csv(metadata_df_path, index=False, compression="bz2")
+
     def save(self, models_dir, model_names_to_write=None, write_metadata=True):
         """
         Serialize the predictor to a directory on disk. If the directory does
@@ -375,9 +380,8 @@ class Class1AffinityPredictor(object):
                 info_path, sep="\t", header=False, index=False)
 
             if self.metadata_dataframes:
-                for (name, df) in self.metadata_dataframes.items():
-                    metadata_df_path = join(models_dir, "%s.csv.bz2" % name)
-                    df.to_csv(metadata_df_path, index=False, compression="bz2")
+                for name in self.metadata_dataframes:
+                    self.save_metadata_df(models_dir, name)
 
         # Save allele sequences
         if self.allele_to_sequence is not None:
diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py
index 8e56bbd468e9508c0f95c6ae96c87bea42a6ebb7..e756485ccd2a17742a0e92fd7dccd12b96e3da5e 100644
--- a/mhcflurry/train_pan_allele_models_command.py
+++ b/mhcflurry/train_pan_allele_models_command.py
@@ -12,6 +12,7 @@ import pprint
 import hashlib
 import pickle
 import subprocess
+import uuid
 from functools import partial
 
 import numpy
@@ -40,7 +41,7 @@ from .encodable_sequences import EncodableSequences
 # stored here before creating the thread pool will be inherited by the child
 # processes upon the fork() call, allowing us to share large data with the
 # workers via shared memory.
-GLOBAL_DATA = {}
+GLOBAL_DATA = None
 
 # Note on parallelization:
 # It seems essential currently (tensorflow==1.4.1) that no processes are forked
@@ -50,6 +51,12 @@ GLOBAL_DATA = {}
 
 parser = argparse.ArgumentParser(usage=__doc__)
 
+parser.add_argument(
+    "--action",
+    nargs="+",
+    default=["setup", "train", "finalize"],
+    choices=["setup", "train", "finalize"],
+    help="Actions to run")
 parser.add_argument(
     "--data",
     metavar="FILE.csv",
@@ -108,13 +115,6 @@ parser.add_argument(
     "--allele-sequences",
     metavar="FILE.csv",
     help="Allele sequences file.")
-parser.add_argument(
-    "--save-interval",
-    type=float,
-    metavar="N",
-    default=60,
-    help="Write models to disk every N seconds. Only affects parallel runs; "
-    "serial runs write each model to disk as it is trained.")
 parser.add_argument(
     "--verbosity",
     type=int,
@@ -212,6 +212,18 @@ def pretrain_data_iterator(
             yield (allele_encoding, encodable_peptides, df.stack().values)
 
 
+def mark_work_item_complete(predictor, work_item_name):
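+    """Flag the named work item as complete in the predictor's metadata dataframes."""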
+    if "training_work_items" not in predictor.metadata_dataframes:
+        predictor.metadata_dataframes["training_work_items"] = pandas.DataFrame({
+            'work_item_name': [work_item_name],
+            'complete': [True],
+        }).set_index("work_item_name", drop=False)
+    else:
+        predictor.metadata_dataframes["training_work_items"].loc[
+            work_item_name, "complete"
+        ] = True
+
+
 def run(argv=sys.argv[1:]):
     # On sigusr1 print stack trace
     print("To show stack trace, run:\nkill -s USR1 %d" % os.getpid())
@@ -240,6 +252,9 @@ def main(args):
 
     configure_logging(verbose=args.verbosity > 1)
 
+
+def setup_action(args):
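+    """Load training data, build the work-item list, and write setup outputs to disk."""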
+    print("Beginning setup action.")
     hyperparameters_lst = yaml.load(open(args.hyperparameters))
     assert isinstance(hyperparameters_lst, list)
     print("Loaded hyperparameters list:")
@@ -297,9 +312,10 @@ def main(args):
         alleles=allele_sequences_in_use.index.values,
         allele_to_sequence=allele_sequences_in_use.to_dict())
 
-    GLOBAL_DATA["train_data"] = df
-    GLOBAL_DATA["folds_df"] = folds_df
-    GLOBAL_DATA["allele_encoding"] = allele_encoding
+    global_data = {
+        "train_data": df,
+        "folds_df": folds_df,
+        "allele_encoding": allele_encoding,
+    }
 
     if not os.path.exists(args.out_models_dir):
         print("Attempting to create directory: %s" % args.out_models_dir)
@@ -315,7 +331,6 @@ def main(args):
                 left_index=True,
                 right_index=True)
         })
-    serial_run = args.num_jobs == 0
 
     work_items = []
     for (h, hyperparameters) in enumerate(hyperparameters_lst):
@@ -332,6 +347,7 @@ def main(args):
         for fold in range(args.ensemble_size):
             for replicate in range(args.num_replicates):
                 work_dict = {
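+                    # A random UUID names each work item so its completion
+                    # can be tracked across separate actions and restarts.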
+                    'work_item_name': str(uuid.uuid4()),
                     'architecture_num': h,
                     'num_architectures': len(hyperparameters_lst),
                     'fold_num': fold,
@@ -340,14 +356,44 @@ def main(args):
                     'num_replicates': args.num_replicates,
                     'hyperparameters': hyperparameters,
                     'pretrain_data_filename': args.pretrain_data,
-                    'verbose': args.verbosity,
-                    'progress_print_interval': 60.0 if not serial_run else 5.0,
-                    'predictor': predictor if serial_run else None,
-                    'save_to': args.out_models_dir if serial_run else None,
                 }
                 work_items.append(work_dict)
 
-    start = time.time()
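+    # Persist the full work-item list (all marked incomplete) plus the large
+    # training objects, so the train action can run in a separate invocation.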
+    predictor.metadata_dataframes["training_work_items"] = pandas.DataFrame(
+        work_items,
+    ).set_index("work_item_name", drop=False)
+    predictor.metadata_dataframes["training_work_items"]["complete"] = False
+    with open(os.path.join(args.out_models_dir, "global_data.pkl"), "wb") as fd:
+        pickle.dump(global_data, fd, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+def train_action(args):
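+    """Train a model for each incomplete work item, saving results as they finish."""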
+    global GLOBAL_DATA
+
+    print("Beginning train action.")
+    predictor = Class1AffinityPredictor.load(args.out_models_dir)
+    print("Loaded predictor with %d networks" % len(predictor.neural_networks))
+
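+    # Restore the large training objects serialized by the setup action.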
+    with open(os.path.join(args.out_models_dir, "global_data.pkl"), "rb") as fd:
+        GLOBAL_DATA = pickle.load(fd)
+    print("Loaded global data:")
+    print(GLOBAL_DATA)
+
+    work_items_df = predictor.metadata_dataframes[
+        "training_work_items"
+    ].set_index("work_item_name", drop=False)
+    print("Loaded work items:")
+    print(work_items_df)
+    print("Work items complete:")
+    print(work_items_df.complete.value_counts())
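+    # Keep only work items not yet marked complete, so an interrupted run
+    # can be resumed by re-invoking the train action.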
+    work_items_df = work_items_df.loc[
+        ~work_items_df.complete
+    ].copy()
+    del work_items_df["complete"]
+    work_items = work_items_df.to_dict('records')
+    print("Will process %d work items" % len(work_items))
+
+    serial_run = args.num_jobs == 0
 
     # The estimated time to completion is more accurate if we randomize
     # the order of the work.
@@ -355,6 +401,14 @@ def main(args):
     for (work_item_num, item) in enumerate(work_items):
         item['work_item_num'] = work_item_num
         item['num_work_items'] = len(work_items)
+        item['progress_print_interval'] = 60.0 if not serial_run else 5.0
+        item['predictor'] = predictor if serial_run else None
+        item['save_to'] = args.out_models_dir if serial_run else None
+        item['verbose'] = args.verbosity
+        if args.pretrain_data:
+            item['pretrain_data_filename'] = args.pretrain_data
+
+    start = time.time()
 
     if args.cluster_parallelism:
         # Run using separate processes on an HPC cluster.
@@ -391,30 +445,25 @@ def main(args):
             results_generator = None
 
     if results_generator:
-        unsaved_predictors = []
-        last_save_time = time.time()
         for new_predictor in tqdm.tqdm(results_generator, total=len(work_items)):
-            unsaved_predictors.append(new_predictor)
-
-            if time.time() > last_save_time + args.save_interval:
-                # Save current predictor.
-                save_start = time.time()
-                new_model_names = predictor.merge_in_place(unsaved_predictors)
-                predictor.save(
-                    args.out_models_dir,
-                    model_names_to_write=new_model_names,
-                    write_metadata=False)
-                print(
-                    "Saved predictor (%d models total) including %d new models "
-                    "in %0.2f sec to %s" % (
-                        len(predictor.neural_networks),
-                        len(new_model_names),
-                        time.time() - save_start,
-                        args.out_models_dir))
-                unsaved_predictors = []
-                last_save_time = time.time()
-
-        predictor.merge_in_place(unsaved_predictors)
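+            # Each worker returns a predictor holding exactly one new model and
+            # a single-row training_work_items dataframe naming its work item.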
+            save_start = time.time()
+            (work_item_name,) = new_predictor.metadata_dataframes[
+                "training_work_items"
+            ].work_item_name.values
+            (new_model_name,) = predictor.merge_in_place([new_predictor])
+            mark_work_item_complete(predictor, work_item_name)
+            predictor.save(
+                args.out_models_dir,
+                model_names_to_write=[new_model_name],
+                write_metadata=False)
+            predictor.save_metadata_df(
+                args.out_models_dir, "training_work_items")
+            print(
+                "Saved predictor (%d models total) with 1 new models"
+                "in %0.2f sec to %s" % (
+                    len(predictor.neural_networks),
+                    time.time() - save_start,
+                    args.out_models_dir))
 
     print("Saving final predictor to: %s" % args.out_models_dir)
     # We want the final predictor to support all alleles with sequences, not
@@ -438,6 +487,7 @@ def main(args):
 
 
 def train_model(
+        work_item_name,
         work_item_num,
         num_work_items,
         architecture_num,
@@ -587,10 +637,12 @@ def train_model(
         "architecture_num": architecture_num,
         "num_architectures": num_architectures,
         "train_peptide_hash": train_peptide_hash.hexdigest(),
+        "work_item_name": work_item_name,
     })
 
     numpy.testing.assert_equal(
         predictor.manifest_df.shape[0], len(predictor.class1_pan_allele_models))
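+    # Record completion in this worker's local predictor; the parent process
+    # reads this metadata to mark the item complete after merging.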
+    mark_work_item_complete(predictor, work_item_name)
     predictor.add_pan_allele_model(model, models_dir_for_save=save_to)
     numpy.testing.assert_equal(
         predictor.manifest_df.shape[0], len(predictor.class1_pan_allele_models))