Commit 098f5e71 authored by Tim O'Donnell

redo parallelization implementation of percentile rank calibration

parent ad9a450e
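Summary of the change, as reflected in the diff below: parallelization is moved out of `Class1AffinityPredictor.calibrate_percentile_ranks` (which loses its `worker_pool` argument, always runs serially, and now returns the encoded peptides it used) and into the training command. The command calibrates the first allele on its own to build the random-peptide universe, then farms the remaining alleles out to a `multiprocessing.Pool`, sharing the encoded peptides with workers through a module-level `GLOBAL_DATA` dict populated before the fork.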
@@ -889,8 +889,7 @@ class Class1AffinityPredictor(object):
             peptides=None,
             num_peptides_per_length=int(1e5),
             alleles=None,
-            bins=None,
-            worker_pool=None):
+            bins=None):
         """
         Compute the cumulative distribution of ic50 values for a set of alleles
         over a large universe of random peptides, to enable computing quantiles in
@@ -898,7 +897,7 @@ class Class1AffinityPredictor(object):
         Parameters
         ----------
-        peptides : sequence of string, optional
+        peptides : sequence of string or EncodableSequences, optional
             Peptides to use
         num_peptides_per_length : int, optional
             If peptides argument is not specified, then num_peptides_per_length
@@ -911,8 +910,10 @@ class Class1AffinityPredictor(object):
             Anything that can be passed to numpy.histogram's "bins" argument
             can be used here, i.e. either an integer or a sequence giving bin
             edges. This is in ic50 space.
-        worker_pool : multiprocessing.Pool, optional
-            If specified multiple alleles will be calibrated in parallel
+
+        Returns
+        ----------
+        EncodableSequences : peptides used for calibration
         """
         if bins is None:
             bins = to_ic50(numpy.linspace(1, 0, 1000))
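The new return value lets a caller calibrate one allele and then reuse the identical peptide universe for others, which is exactly how the training command below uses it. A hypothetical usage sketch, assuming `predictor` is a loaded `Class1AffinityPredictor` supporting these alleles:

```python
# Calibrate one allele; get back the EncodableSequences that were used.
peptides = predictor.calibrate_percentile_ranks(alleles=["HLA-A*02:01"])

# Reuse the same random-peptide universe for another allele, so the
# percentile ranks are comparable across alleles.
predictor.calibrate_percentile_ranks(peptides=peptides, alleles=["HLA-A*01:01"])
```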
@@ -931,57 +932,12 @@ class Class1AffinityPredictor(object):
         encoded_peptides = EncodableSequences.create(peptides)
 
-        if worker_pool and len(alleles) > 1:
-            # Run in parallel
-
-            # Performance hack.
-            self.neural_networks[0].peptides_to_network_input(encoded_peptides)
-
-            do_work = partial(
-                _calibrate_percentile_ranks,
-                predictor=self,
-                peptides=encoded_peptides,
-                bins=bins)
-            list_of_singleton_alleles = [ [allele] for allele in alleles ]
-            results = worker_pool.imap_unordered(
-                do_work, list_of_singleton_alleles, chunksize=1)
-
-            # Add progress bar
-            results = tqdm.tqdm(results, ascii=True, total=len(alleles))
-
-            # Merge results
-            for partial_dict in results:
-                self.allele_to_percent_rank_transform.update(partial_dict)
-        else:
-            # Run in serial
-            self.allele_to_percent_rank_transform.update(
-                _calibrate_percentile_ranks(
-                    alleles=alleles,
-                    predictor=self,
-                    peptides=encoded_peptides,
-                    bins=bins))
+        for (i, allele) in enumerate(alleles):
+            predictions = self.predict(peptides, allele=allele)
+            transform = PercentRankTransform()
+            transform.fit(predictions, bins=bins)
+            self.allele_to_percent_rank_transform[allele] = transform
 
-
-def _calibrate_percentile_ranks(alleles, predictor, peptides, bins):
-    """
-    Private helper function.
-
-    Parameters
-    ----------
-    alleles : list of string
-    predictor : Class1AffinityPredictor
-    peptides : list of string or EncodableSequences
-    bins : object
-
-    Returns
-    -------
-    dict : allele -> percentile rank transform
-    """
-    result = {}
-    for (i, allele) in enumerate(alleles):
-        predictions = predictor.predict(peptides, allele=allele)
-        transform = PercentRankTransform()
-        transform.fit(predictions, bins=bins)
-        result[allele] = transform
-    return result
+        return encoded_peptides
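For reference, the fit/transform logic the loop relies on: `PercentRankTransform.fit` builds an empirical distribution of predicted affinities over the given bins, and later predictions are mapped to percentile ranks against it. A minimal sketch of the idea (an illustration, not mhcflurry's actual `PercentRankTransform`, whose details differ):

```python
import numpy

class PercentRankSketch(object):
    """Illustrative stand-in: fit an empirical CDF of IC50 predictions,
    then map new predictions to percentile ranks in [0, 100]."""

    def fit(self, values, bins):
        # Histogram the calibration predictions over the supplied bin
        # edges (ic50 space), then accumulate into a CDF.
        counts, self.bin_edges = numpy.histogram(values, bins=bins)
        cdf = numpy.cumsum(counts).astype(float)
        self.cdf = 100.0 * cdf / cdf[-1]

    def transform(self, values):
        # Percentile rank of each value, interpolated from the fitted CDF.
        return numpy.interp(
            values, self.bin_edges[1:], self.cdf, left=0.0, right=100.0)
```

With the default bins from `to_ic50(numpy.linspace(1, 0, 1000))`, the edges span roughly 1 nM to 50000 nM in ic50 space, increasing as `numpy.histogram` requires.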
@@ -8,6 +8,7 @@ import sys
 import time
 import traceback
 from multiprocessing import Pool
+from functools import partial
 
 import pandas
 import yaml
@@ -17,6 +18,9 @@ import tqdm  # progress bar
 from .class1_affinity_predictor import Class1AffinityPredictor
 from .common import configure_logging, set_keras_backend
 
+
+GLOBAL_DATA = {}
+
 parser = argparse.ArgumentParser(usage=__doc__)
 parser.add_argument(
......@@ -106,6 +110,8 @@ parser.add_argument(
help="Keras backend. If not specified will use system default.")
def run(argv=sys.argv[1:]):
global GLOBAL_DATA
# On sigusr1 print stack trace
print("To show stack trace, run:\nkill -s USR1 %d" % os.getpid())
signal.signal(signal.SIGUSR1, lambda sig, frame: traceback.print_stack())
@@ -210,7 +216,7 @@ def run(argv=sys.argv[1:]):
         predictors = list(
             tqdm.tqdm(
                 worker_pool.imap_unordered(
-                    work_entrypoint, work_items, chunksize=1),
+                    train_model_entrypoint, work_items, chunksize=1),
                 ascii=True,
                 total=len(work_items)))
@@ -225,7 +231,7 @@ def run(argv=sys.argv[1:]):
         start = time.time()
         for _ in tqdm.trange(len(work_items)):
             item = work_items.pop(0)  # want to keep freeing up memory
-            work_predictor = work_entrypoint(item)
+            work_predictor = train_model_entrypoint(item)
             assert work_predictor is predictor
         assert not work_items
@@ -240,24 +246,46 @@ def run(argv=sys.argv[1:]):
         worker_pool.join()
 
     if args.percent_rank_calibration_num_peptides_per_length > 0:
+        alleles = list(predictor.supported_alleles)
+        first_allele = alleles.pop(0)
+
+        print("Performing percent rank calibration. Calibrating first allele.")
+        start = time.time()
+        encoded_peptides = predictor.calibrate_percentile_ranks(
+            alleles=[first_allele],
+            num_peptides_per_length=args.percent_rank_calibration_num_peptides_per_length)
+        percent_rank_calibration_time = time.time() - start
+        print("Finished calibrating percent ranks for first allele in %0.2f sec." % (
+            percent_rank_calibration_time))
+
+        print("Calibrating %d additional alleles." % len(alleles))
+        if args.calibration_num_jobs == 1:
+            # Serial run
+            worker_pool = None
+            results = (
+                calibrate_percentile_ranks(
+                    allele=allele,
+                    predictor=predictor,
+                    peptides=encoded_peptides)
+                for allele in alleles)
+        else:
+            # Parallel run
+            # Store peptides in global variable so they are in shared memory
+            # after fork, instead of needing to be pickled.
+            GLOBAL_DATA["calibration_peptides"] = encoded_peptides
+            worker_pool = Pool(
+                processes=(
+                    args.calibration_num_jobs
+                    if args.train_num_jobs else None))
+            print("Using worker pool: %s" % str(worker_pool))
-        print("Performing percent rank calibration.")
-        start = time.time()
-        predictor.calibrate_percentile_ranks(
-            num_peptides_per_length=args.percent_rank_calibration_num_peptides_per_length,
-            worker_pool=worker_pool)
-        percent_rank_calibration_time = time.time() - start
-        print("Finished calibrating percent ranks in %0.2f sec." % (
-            percent_rank_calibration_time))
+            results = worker_pool.imap_unordered(
+                partial(
+                    calibrate_percentile_ranks,
+                    predictor=predictor), alleles, chunksize=1)
+
+        for result in tqdm.tqdm(results, ascii=True, total=len(alleles)):
+            predictor.allele_to_percent_rank_transform.update(result)
+        print("Done calibrating %d additional alleles." % len(alleles))
         predictor.save(args.out_models_dir, model_names_to_write=[])
 
     if worker_pool:
@@ -269,11 +297,11 @@ def run(argv=sys.argv[1:]):
     print("Predictor written to: %s" % args.out_models_dir)
 
 
-def work_entrypoint(item):
-    return process_work(**item)
+def train_model_entrypoint(item):
+    return train_model(**item)
 
 
-def process_work(
+def train_model(
         model_group,
         n_models,
         allele_num,
@@ -325,5 +353,19 @@ def process_work(
     return predictor
 
 
+def calibrate_percentile_ranks(allele, predictor, peptides=None):
+    """
+    Private helper function.
+    """
+    if peptides is None:
+        peptides = GLOBAL_DATA["calibration_peptides"]
+    predictor.calibrate_percentile_ranks(
+        peptides=peptides,
+        alleles=[allele])
+    return {
+        allele: predictor.allele_to_percent_rank_transform[allele],
+    }
+
+
 if __name__ == '__main__':
     run()
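The `GLOBAL_DATA` pattern above is worth spelling out: anything stored in a module-level dict before the `Pool` is created is inherited by worker processes at fork time, so the large encoded-peptide set never has to be pickled per task; only the small per-allele result dicts cross process boundaries. A self-contained sketch of the same pattern (names here are illustrative, not mhcflurry's):

```python
from functools import partial
from multiprocessing import Pool

GLOBAL_DATA = {}

def work(key, scale):
    # Workers read the big object from the module-level dict they
    # inherited at fork time; nothing large is pickled per task.
    return {key: GLOBAL_DATA["big"][key] * scale}

if __name__ == "__main__":
    # Populate the global BEFORE creating the pool so forked children
    # see it. This relies on the "fork" start method (the default on
    # Linux); under "spawn" a fresh interpreter re-imports this module
    # and GLOBAL_DATA would be empty in the workers.
    GLOBAL_DATA["big"] = {"a": 1.0, "b": 2.0}
    pool = Pool(processes=2)
    merged = {}
    for result in pool.imap_unordered(partial(work, scale=10.0), ["a", "b"]):
        merged.update(result)  # mirror the update() merge used above
    pool.close()
    pool.join()
    print(merged)  # -> {'a': 10.0, 'b': 20.0}
```

`imap_unordered` with `chunksize=1`, as used in the commit, yields each allele's result as soon as any worker finishes it, which is what lets tqdm report steady progress.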
@@ -62,7 +62,8 @@ def run_and_check(n_jobs=0):
         "--allele", "HLA-A*02:01", "HLA-A*01:01", "HLA-A*03:01",
         "--out-models-dir", models_dir,
         "--percent-rank-calibration-num-peptides-per-length", "10000",
-        "--parallelization-num-jobs", str(n_jobs),
+        "--train-num-jobs", str(n_jobs),
+        "--calibration-num-jobs", str(n_jobs),
         "--ignore-inequalities",
     ]
     print("Running with args: %s" % args)