diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index c2b6ed5dcb702d91e471b87a46708f326af6ce20..ce6231ab53f644b37e55a622027f3808e879f217 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -252,6 +252,7 @@ def run(argv=sys.argv[1:]):
     if worker_pool:
         worker_pool.close()
         worker_pool.join()
+        worker_pool = None
 
     start = time.time()
     if args.percent_rank_calibration_num_peptides_per_length > 0:
@@ -305,6 +306,7 @@ def run(argv=sys.argv[1:]):
     if worker_pool:
         worker_pool.close()
         worker_pool.join()
+        worker_pool = None
 
     print("Train time: %0.2f min. Percent rank calibration time: %0.2f min." % (
         training_time / 60.0, percent_rank_calibration_time / 60.0))