diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index 694e4a47c4e7aed1d4363b3990525371e9ea3f0e..7e4fdeb39f0272718992ca732d1e52bad566452a 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -258,21 +258,22 @@ def run(argv=sys.argv[1:]):
             hyperparameters['max_epochs'] = args.max_epochs
 
         for (i, allele) in enumerate(df.allele.unique()):
-            work_dict = {
-                'n_models': n_models,
-                'allele_num': i,
-                'n_alleles': len(alleles),
-                'hyperparameter_set_num': h,
-                'num_hyperparameter_sets': len(hyperparameters_lst),
-                'allele': allele,
-                'data': None,  # subselect from GLOBAL_DATA["train_data"]
-                'hyperparameters': hyperparameters,
-                'verbose': args.verbosity,
-                'progress_print_interval': None if worker_pool else 5.0,
-                'predictor': predictor if not worker_pool else None,
-                'save_to': args.out_models_dir if not worker_pool else None,
-            }
-            work_items.append(work_dict)
+            for model_num in range(n_models):
+                work_dict = {
+                    'n_models': 1,
+                    'allele_num': i,
+                    'n_alleles': len(alleles),
+                    'hyperparameter_set_num': h,
+                    'num_hyperparameter_sets': len(hyperparameters_lst),
+                    'allele': allele,
+                    'data': None,  # subselect from GLOBAL_DATA["train_data"]
+                    'hyperparameters': hyperparameters,
+                    'verbose': args.verbosity,
+                    'progress_print_interval': None if worker_pool else 5.0,
+                    'predictor': predictor if not worker_pool else None,
+                    'save_to': args.out_models_dir if not worker_pool else None,
+                }
+                work_items.append(work_dict)
 
     if worker_pool:
         print("Processing %d work items in parallel." % len(work_items))