diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 73405fdbf6e6a27e060873512cd5503ba735c30b..043a5c0c5704533dd7939e26572be9ee56a7bd41 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -125,6 +125,13 @@ class Class1NeuralNetwork(object):
     (Keras model, existing network weights)
     """
 
+    @classmethod
+    def clear_model_cache(klass):
+        """
+        Clear the Keras model cache.
+        """
+        klass.KERAS_MODELS_CACHE.clear()
+
     @classmethod
     def borrow_cached_network(klass, network_json, network_weights):
         """
diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index 3d72b32a1a50ea5df3b0abf4533ce74dd05f394a..c2b6ed5dcb702d91e471b87a46708f326af6ce20 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -17,6 +17,7 @@ from mhcnames import normalize_allele_name
 import tqdm  # progress bar
 
 from .class1_affinity_predictor import Class1AffinityPredictor
+from .class1_neural_network import Class1NeuralNetwork
 from .common import configure_logging, set_keras_backend
 
 
@@ -281,10 +282,11 @@ def run(argv=sys.argv[1:]):
             # Store peptides in global variable so they are in shared memory
             # after fork, instead of needing to be pickled.
             GLOBAL_DATA["calibration_peptides"] = encoded_peptides
+            Class1NeuralNetwork.clear_model_cache()
             worker_pool = Pool(
                 processes=(
                     args.calibration_num_jobs
-                    if args.train_num_jobs else None))
+                    if args.calibration_num_jobs else None))
             print("Using worker pool: %s" % str(worker_pool))
             results = worker_pool.imap_unordered(
                 partial(
@@ -376,6 +378,7 @@ def calibrate_percentile_ranks(allele, predictor, peptides=None):
     """
     Private helper function.
     """
+    global GLOBAL_DATA
     if peptides is None:
         peptides = GLOBAL_DATA["calibration_peptides"]
     predictor.calibrate_percentile_ranks(