diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py index 694e4a47c4e7aed1d4363b3990525371e9ea3f0e..7e4fdeb39f0272718992ca732d1e52bad566452a 100644 --- a/mhcflurry/train_allele_specific_models_command.py +++ b/mhcflurry/train_allele_specific_models_command.py @@ -258,21 +258,22 @@ def run(argv=sys.argv[1:]): hyperparameters['max_epochs'] = args.max_epochs for (i, allele) in enumerate(df.allele.unique()): - work_dict = { - 'n_models': n_models, - 'allele_num': i, - 'n_alleles': len(alleles), - 'hyperparameter_set_num': h, - 'num_hyperparameter_sets': len(hyperparameters_lst), - 'allele': allele, - 'data': None, # subselect from GLOBAL_DATA["train_data"] - 'hyperparameters': hyperparameters, - 'verbose': args.verbosity, - 'progress_print_interval': None if worker_pool else 5.0, - 'predictor': predictor if not worker_pool else None, - 'save_to': args.out_models_dir if not worker_pool else None, - } - work_items.append(work_dict) + for model_num in range(n_models): + work_dict = { + 'n_models': 1, + 'allele_num': i, + 'n_alleles': len(alleles), + 'hyperparameter_set_num': h, + 'num_hyperparameter_sets': len(hyperparameters_lst), + 'allele': allele, + 'data': None, # subselect from GLOBAL_DATA["train_data"] + 'hyperparameters': hyperparameters, + 'verbose': args.verbosity, + 'progress_print_interval': None if worker_pool else 5.0, + 'predictor': predictor if not worker_pool else None, + 'save_to': args.out_models_dir if not worker_pool else None, + } + work_items.append(work_dict) if worker_pool: print("Processing %d work items in parallel." % len(work_items))