Commit d2ff112b authored by Tim O'Donnell

better parallelization of model training

parent 098f5e71
@@ -18,6 +18,12 @@ import tqdm  # progress bar
 from .class1_affinity_predictor import Class1AffinityPredictor
 from .common import configure_logging, set_keras_backend
 
+# To avoid pickling large matrices to send to child processes when running in
+# parallel, we use this global variable as a place to store data. Data that is
+# stored here before creating the worker pool will be inherited by the child
+# processes upon fork(), allowing us to share large data with the workers
+# efficiently.
 GLOBAL_DATA = {}
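
The comment above describes a standard copy-on-write trick: anything stored in module globals before the worker pool is created is visible to forked child processes without being pickled per task. A minimal, self-contained sketch of that pattern (the function and data names here are illustrative, not mhcflurry code), assuming a platform where the fork start method is available:

    import multiprocessing

    GLOBAL_DATA = {}

    def summarize(key):
        # The worker reads the inherited global rather than receiving the
        # (potentially huge) object as a pickled argument.
        return (key, len(GLOBAL_DATA[key]))

    if __name__ == "__main__":
        # Populate the global BEFORE creating the pool so fork() copies it
        # into each worker's address space.
        GLOBAL_DATA["train_data"] = list(range(1_000_000))
        ctx = multiprocessing.get_context("fork")
        with ctx.Pool(processes=2) as pool:
            print(pool.map(summarize, ["train_data"]))
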
@@ -160,6 +166,8 @@ def run(argv=sys.argv[1:]):
     print("Selected %d alleles: %s" % (len(alleles), ' '.join(alleles)))
     print("Training data: %s" % (str(df.shape)))
 
+    GLOBAL_DATA["train_data"] = df
+
     predictor = Class1AffinityPredictor()
     if args.train_num_jobs == 1:
         # Serial run
@@ -190,9 +198,7 @@ def run(argv=sys.argv[1:]):
         hyperparameters['max_epochs'] = args.max_epochs
 
         work_items = []
-        total_data_to_train_on = 0
-        for (i, (allele, sub_df)) in enumerate(df.groupby("allele")):
-            total_data_to_train_on += len(sub_df) * n_models
+        for (i, allele) in enumerate(df.allele.unique()):
             for model_group in range(n_models):
                 work_dict = {
                     'model_group': model_group,
@@ -202,7 +208,7 @@ def run(argv=sys.argv[1:]):
                     'hyperparameter_set_num': h,
                     'num_hyperparameter_sets': len(hyperparameters_lst),
                     'allele': allele,
-                    'data': sub_df,
+                    'data': None,  # subselect from GLOBAL_DATA["train_data"]
                     'hyperparameters': hyperparameters,
                     'verbose': args.verbosity,
                     'predictor': predictor if not worker_pool else None,
@@ -318,6 +324,10 @@ def train_model(
     if predictor is None:
         predictor = Class1AffinityPredictor()
 
+    if data is None:
+        full_data = GLOBAL_DATA["train_data"]
+        data = full_data.loc[full_data.allele == allele]
+
     progress_preamble = (
         "[%2d / %2d hyperparameters] "
         "[%4d / %4d alleles] "
...
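
Putting the pieces together: the driver now ships only the allele name in each work item ('data': None) and each worker pulls its own rows out of the DataFrame inherited via GLOBAL_DATA. A simplified end-to-end sketch of that flow (train_model here is a stand-in with a reduced signature, and call_train_model is a hypothetical dispatch helper), again assuming fork-based multiprocessing:

    import multiprocessing
    import pandas

    GLOBAL_DATA = {}

    def train_model(allele, data=None):
        if data is None:
            # Lazily subselect this allele's rows from the inherited global,
            # so the work item itself stays tiny when pickled.
            full_data = GLOBAL_DATA["train_data"]
            data = full_data.loc[full_data.allele == allele]
        # Placeholder for actually fitting a model on `data`.
        return (allele, len(data))

    def call_train_model(kwargs):
        return train_model(**kwargs)

    if __name__ == "__main__":
        GLOBAL_DATA["train_data"] = pandas.DataFrame({
            "allele": ["HLA-A*02:01", "HLA-A*02:01", "HLA-B*07:02"],
            "peptide": ["SIINFEKL", "SIINFEKLL", "APRGPHGGAAL"],
        })
        work_items = [
            {"allele": allele, "data": None}
            for allele in GLOBAL_DATA["train_data"].allele.unique()
        ]
        ctx = multiprocessing.get_context("fork")
        with ctx.Pool(processes=2) as pool:
            print(pool.map(call_train_model, work_items))
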