From f179489b2330574ea956a6a047118b1057e4b2e5 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Tue, 28 Nov 2017 15:25:56 -0500
Subject: [PATCH] Add time remaining estimation to training script

---
 .../train_allele_specific_models_command.py   | 28 +++++++++++++++----
 1 file changed, 22 insertions(+), 6 deletions(-)

diff --git a/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py b/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py
index 836e4fdc..afbf0fba 100644
--- a/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py
+++ b/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py
@@ -120,14 +120,14 @@ def run(argv=sys.argv[1:]):
 
     if args.allele:
         alleles = [normalize_allele_name(a) for a in args.allele]
-
-        # Allele names in data are assumed to be already normalized.
-        df = df.ix[df.allele.isin(alleles)]
     else:
         alleles = list(allele_counts.ix[
             allele_counts > args.min_measurements_per_allele
         ].index)
 
+    # Allele names in data are assumed to be already normalized.
+    df = df.loc[df.allele.isin(alleles)].dropna()
+
     print("Selected %d alleles: %s" % (len(alleles), ' '.join(alleles)))
     print("Training data: %s" % (str(df.shape)))
 
@@ -160,7 +160,9 @@ def run(argv=sys.argv[1:]):
             hyperparameters['max_epochs'] = args.max_epochs
 
         work_items = []
+        total_data_to_train_on = 0
         for (i, (allele, sub_df)) in enumerate(df.groupby("allele")):
+            total_data_to_train_on += len(sub_df) * n_models
             for model_group in range(n_models):
                 work_dict = {
                     'model_group': model_group,
@@ -170,7 +172,7 @@ def run(argv=sys.argv[1:]):
                     'hyperparameter_set_num': h,
                     'num_hyperparameter_sets': len(hyperparameters_lst),
                     'allele': allele,
-                    'sub_df': sub_df,
+                    'data': sub_df,
                     'hyperparameters': hyperparameters,
                     'verbose': args.verbosity,
                     'predictor': predictor if not worker_pool else None,
@@ -189,11 +191,25 @@ def run(argv=sys.argv[1:]):
             # Run in serial. In this case, every worker is passed the same predictor,
             # which it adds models to, so no merging is required. It also saves
             # as it goes so no saving is required at the end.
+            start = time.time()
+            data_trained_on = 0
             while work_items:
                 item = work_items.pop(0)
                 work_predictor = work_entrypoint(item)
                 assert work_predictor is predictor
 
+                # When running in serial we try to estimate time remaining.
+                data_trained_on += len(item['data'])
+                progress = data_trained_on / total_data_to_train_on
+                time_elapsed = time.time() - start
+                total_time = time_elapsed / progress
+                print(
+                    "Estimated total training time: %0.2f min, "
+                    "remaining: %0.2f min" % (
+                        total_time / 60,
+                        (total_time - time_elapsed) / 60))
+
+
     if worker_pool:
         worker_pool.close()
         worker_pool.join()
@@ -220,7 +236,7 @@ def process_work(
         hyperparameter_set_num,
         num_hyperparameter_sets,
         allele,
-        sub_df,
+        data,
         hyperparameters,
         verbose,
         predictor,
@@ -241,7 +257,7 @@ def process_work(
             n_models,
             allele))
 
-    train_data = sub_df.dropna().sample(frac=1.0)
+    train_data = data.sample(frac=1.0)
     (model,) = predictor.fit_allele_specific_predictors(
         n_models=1,
         architecture_hyperparameters=hyperparameters,
-- 
GitLab