From f179489b2330574ea956a6a047118b1057e4b2e5 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Tue, 28 Nov 2017 15:25:56 -0500
Subject: [PATCH] Add time remaining estimation to training script

---
 .../train_allele_specific_models_command.py   | 28 +++++++++++++++----
 1 file changed, 22 insertions(+), 6 deletions(-)

diff --git a/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py b/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py
index 836e4fdc..afbf0fba 100644
--- a/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py
+++ b/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py
@@ -120,14 +120,14 @@ def run(argv=sys.argv[1:]):
 
     if args.allele:
         alleles = [normalize_allele_name(a) for a in args.allele]
-
-        # Allele names in data are assumed to be already normalized.
-        df = df.ix[df.allele.isin(alleles)]
     else:
         alleles = list(allele_counts.ix[
            allele_counts > args.min_measurements_per_allele
         ].index)
 
+    # Allele names in data are assumed to be already normalized.
+    df = df.loc[df.allele.isin(alleles)].dropna()
+
     print("Selected %d alleles: %s" % (len(alleles), ' '.join(alleles)))
     print("Training data: %s" % (str(df.shape)))
 
@@ -160,7 +160,9 @@ def run(argv=sys.argv[1:]):
             hyperparameters['max_epochs'] = args.max_epochs
 
         work_items = []
+        total_data_to_train_on = 0
         for (i, (allele, sub_df)) in enumerate(df.groupby("allele")):
+            total_data_to_train_on += len(sub_df) * n_models
             for model_group in range(n_models):
                 work_dict = {
                     'model_group': model_group,
@@ -170,7 +172,7 @@ def run(argv=sys.argv[1:]):
                     'hyperparameter_set_num': h,
                     'num_hyperparameter_sets': len(hyperparameters_lst),
                     'allele': allele,
-                    'sub_df': sub_df,
+                    'data': sub_df,
                     'hyperparameters': hyperparameters,
                     'verbose': args.verbosity,
                     'predictor': predictor if not worker_pool else None,
@@ -189,11 +191,25 @@ def run(argv=sys.argv[1:]):
             # Run in serial. In this case, every worker is passed the same predictor,
             # which it adds models to, so no merging is required. It also saves
             # as it goes so no saving is required at the end.
+            start = time.time()
+            data_trained_on = 0
             while work_items:
                 item = work_items.pop(0)
                 work_predictor = work_entrypoint(item)
                 assert work_predictor is predictor
 
+                # When running in serial we try to estimate time remaining.
+                data_trained_on += len(item['data'])
+                progress = data_trained_on / total_data_to_train_on
+                time_elapsed = time.time() - start
+                total_time = time_elapsed / progress
+                print(
+                    "Estimated total training time: %0.2f min, "
+                    "remaining: %0.2f min" % (
+                        total_time / 60,
+                        (total_time - time_elapsed) / 60))
+
+
     if worker_pool:
         worker_pool.close()
         worker_pool.join()
@@ -220,7 +236,7 @@ def process_work(
         hyperparameter_set_num,
         num_hyperparameter_sets,
         allele,
-        sub_df,
+        data,
         hyperparameters,
         verbose,
         predictor,
@@ -241,7 +257,7 @@ def process_work(
             n_models,
             allele))
 
-    train_data = sub_df.dropna().sample(frac=1.0)
+    train_data = data.sample(frac=1.0)
     (model,) = predictor.fit_allele_specific_predictors(
         n_models=1,
         architecture_hyperparameters=hyperparameters,
-- 
GitLab
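
Review note on the first hunk: the allele filtering moves out of the `if args.allele:` branch so it now applies in both cases, switches from the deprecated pandas .ix indexer to .loc, and runs dropna() once up front instead of inside each work item (replacing the old sub_df.dropna() in process_work). A minimal sketch of the new selection, assuming a toy frame whose column names are illustrative, not taken from the real training data:

    import pandas as pd

    # Toy stand-in for the curated training data; column names are assumptions.
    df = pd.DataFrame({
        "allele": ["HLA-A*02:01", "HLA-A*02:01", "HLA-B*07:02", "HLA-C*04:01"],
        "peptide": ["SIINFEKL", "LLFGYPVYV", None, "GILGFVFTL"],
        "measurement_value": [120.0, 5300.0, 44.0, None],
    })
    alleles = ["HLA-A*02:01", "HLA-C*04:01"]

    # Boolean-mask selection via .loc, then drop incomplete rows once,
    # up front, rather than per work item.
    df = df.loc[df.allele.isin(alleles)].dropna()
    print(df)  # keeps only the two complete HLA-A*02:01 rows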
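
On the headline change: in serial mode the script now reports an estimated total and remaining training time. Progress is measured as the fraction of training examples consumed (len(sub_df) * n_models summed over alleles) rather than the fraction of work items, since fit time scales roughly with dataset size, and the total is linearly extrapolated as elapsed / progress. A self-contained sketch of that extrapolation with made-up work sizes; estimate_remaining is an illustrative helper, not part of mhcflurry:

    import time

    def estimate_remaining(start_time, work_done, total_work):
        # Linear extrapolation: if a `progress` fraction of the work took
        # `elapsed` seconds, the whole job should take elapsed / progress.
        elapsed = time.time() - start_time
        progress = float(work_done) / total_work  # float() guards Python 2 int division
        total = elapsed / progress
        return total, total - elapsed

    start = time.time()
    work_sizes = [1200, 450, 3000]  # hypothetical len(item['data']) per work item
    done = 0
    for size in work_sizes:
        time.sleep(0.05)  # stand-in for the actual model fit
        done += size
        total_time, remaining = estimate_remaining(start, done, sum(work_sizes))
        print("Estimated total training time: %0.2f min, remaining: %0.2f min" % (
            total_time / 60, remaining / 60))

Two caveats: the linear model assumes time per training example is constant across alleles and hyperparameter sets, so early estimates can be rough; and the patch's own data_trained_on / total_data_to_train_on relies on true division, so under Python 2 it would need from __future__ import division (or a float cast as above) to keep progress from truncating to zero.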