Skip to content
Snippets Groups Projects
Commit 77fda00d authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

improve parallelization

parent fc83dd0b
No related merge requests found
......@@ -37,8 +37,7 @@ time mhcflurry-class1-train-allele-specific-models \
--out-models-dir models \
--percent-rank-calibration-num-peptides-per-length 1000000 \
--min-measurements-per-allele 75 \
--train-num-jobs 0 \
--calibration-num-jobs 1
--num-jobs 0
cp $SCRIPT_ABSOLUTE_PATH .
bzip2 LOG.txt
......
......@@ -600,7 +600,7 @@ class Class1NeuralNetwork(object):
self.hyperparameters['patience'])
if i > threshold:
print((progress_preamble + " " +
"Early stopping at epoch %3d / %3d: loss=%g. "
"Stopping at epoch %3d / %3d: loss=%g. "
"Min val loss (%s) at epoch %s" % (
i,
self.hyperparameters['max_epochs'],
......
......@@ -17,7 +17,6 @@ from mhcnames import normalize_allele_name
import tqdm # progress bar
from .class1_affinity_predictor import Class1AffinityPredictor
from .class1_neural_network import Class1NeuralNetwork
from .common import configure_logging, set_keras_backend
......@@ -29,6 +28,12 @@ from .common import configure_logging, set_keras_backend
GLOBAL_DATA = {}
# Note on parallelization:
# It seems essential currently (tensorflow==1.4.1) that no processes are forked
# after tensorflow has been used at all, which includes merely importing
# keras.backend. So we must make sure not to use tensorflow in the main process
# if we are running in parallel.
parser = argparse.ArgumentParser(usage=__doc__)
parser.add_argument(
......@@ -97,26 +102,19 @@ parser.add_argument(
help="Keras verbosity. Default: %(default)s",
default=0)
parser.add_argument(
"--train-num-jobs",
default=1,
type=int,
metavar="N",
help="Number of processes to parallelize training over. "
"Set to 1 for serial run. Set to 0 to use number of cores. Experimental."
"Default: %(default)s.")
parser.add_argument(
"--calibration-num-jobs",
"--num-jobs",
default=1,
type=int,
metavar="N",
help="Number of processes to parallelize percent rank calibration over. "
"Set to 1 for serial run. Set to 0 to use number of cores. Experimental."
"Default: %(default)s.")
help="Number of processes to parallelize training and percent rank "
"calibration over. Experimental. "
"Set to 1 for serial run. Set to 0 to use number of cores. Default: %(default)s.")
parser.add_argument(
"--backend",
choices=("tensorflow-gpu", "tensorflow-cpu"),
help="Keras backend. If not specified will use system default.")
def run(argv=sys.argv[1:]):
global GLOBAL_DATA
......@@ -171,15 +169,15 @@ def run(argv=sys.argv[1:]):
GLOBAL_DATA["train_data"] = df
predictor = Class1AffinityPredictor()
if args.train_num_jobs == 1:
if args.num_jobs == 1:
# Serial run
print("Running in serial.")
worker_pool = None
else:
worker_pool = Pool(
processes=(
args.train_num_jobs
if args.train_num_jobs else None))
args.num_jobs
if args.num_jobs else None))
print("Using worker pool: %s" % str(worker_pool))
if args.out_models_dir and not os.path.exists(args.out_models_dir):
......@@ -257,18 +255,22 @@ def run(argv=sys.argv[1:]):
start = time.time()
if args.percent_rank_calibration_num_peptides_per_length > 0:
alleles = list(predictor.supported_alleles)
first_allele = alleles.pop(0)
print("Performing percent rank calibration. Calibrating first allele.")
print("Performing percent rank calibration. Encoding peptides.")
encoded_peptides = predictor.calibrate_percentile_ranks(
alleles=[first_allele],
alleles=[], # don't actually do any calibration, just return peptides
num_peptides_per_length=args.percent_rank_calibration_num_peptides_per_length)
# Now we encode the peptides for each neural network, so the encoding
# becomes cached.
for network in predictor.neural_networks:
network.peptides_to_network_input(encoded_peptides)
assert encoded_peptides.encoding_cache # must have cached the encoding
print("Finished calibrating percent ranks for first allele in %0.2f sec." % (
print("Finished encoding peptides for percent ranks in %0.2f sec." % (
time.time() - start))
print("Calibrating %d additional alleles." % len(alleles))
print("Calibrating percent rank calibration for %d alleles." % len(alleles))
if args.calibration_num_jobs == 1:
if args.num_jobs == 1:
# Serial run
print("Running in serial.")
worker_pool = None
......@@ -283,21 +285,21 @@ def run(argv=sys.argv[1:]):
# Store peptides in global variable so they are in shared memory
# after fork, instead of needing to be pickled.
GLOBAL_DATA["calibration_peptides"] = encoded_peptides
Class1NeuralNetwork.clear_model_cache()
worker_pool = Pool(
processes=(
args.calibration_num_jobs
if args.calibration_num_jobs else None))
args.num_jobs
if args.num_jobs else None))
print("Using worker pool: %s" % str(worker_pool))
results = worker_pool.imap_unordered(
partial(
calibrate_percentile_ranks,
predictor=args.out_models_dir),
predictor=predictor),
alleles,
chunksize=1)
for result in tqdm.tqdm(results, ascii=True, total=len(alleles)):
predictor.allele_to_percent_rank_transform.update(result)
print("Done calibrating %d additional alleles." % len(alleles))
predictor.save(args.out_models_dir, model_names_to_write=[])
......@@ -306,7 +308,6 @@ def run(argv=sys.argv[1:]):
if worker_pool:
worker_pool.close()
worker_pool.join()
worker_pool = None
print("Train time: %0.2f min. Percent rank calibration time: %0.2f min." % (
training_time / 60.0, percent_rank_calibration_time / 60.0))
......@@ -381,13 +382,8 @@ def calibrate_percentile_ranks(allele, predictor, peptides=None):
Private helper function.
"""
global GLOBAL_DATA
Class1NeuralNetwork.clear_model_cache()
import keras.backend as K
K.clear_session()
if peptides is None:
peptides = GLOBAL_DATA["calibration_peptides"]
if isinstance(predictor, str):
predictor = Class1AffinityPredictor.load(predictor)
predictor.calibrate_percentile_ranks(
peptides=peptides,
alleles=[allele])
......
......@@ -62,8 +62,7 @@ def run_and_check(n_jobs=0):
"--allele", "HLA-A*02:01", "HLA-A*01:01", "HLA-A*03:01",
"--out-models-dir", models_dir,
"--percent-rank-calibration-num-peptides-per-length", "10000",
"--train-num-jobs", str(n_jobs),
"--calibration-num-jobs", str(n_jobs),
"--num-jobs", str(n_jobs),
"--ignore-inequalities",
]
print("Running with args: %s" % args)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment