Commit 098f5e71 authored by Tim O'Donnell

redo parallelization implementation of percentile rank calibration

parent ad9a450e
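
In brief: the worker pool is removed from Class1AffinityPredictor.calibrate_percentile_ranks() and parallelization moves into the training command. The command calibrates the first allele on its own to generate and encode the random peptides, stores them in the module-level GLOBAL_DATA dict before forking the pool, and then calibrates the remaining alleles one per task, merging the returned allele -> PercentRankTransform dicts back into the predictor. Below is a minimal self-contained sketch of that fork-and-share pattern (illustrative names, not the project's code; it assumes the fork start method on Linux, where workers inherit the parent's globals instead of receiving a pickled copy per task):

from functools import partial
from multiprocessing import Pool

# Populated in the parent process before the Pool is created, so forked
# workers see it via copy-on-write rather than receiving it in each task.
GLOBAL_DATA = {}


def calibrate_one(key, scale, shared=None):
    # Per-task worker (illustrative). The serial path passes the shared data
    # explicitly; the parallel path falls back to the inherited global.
    if shared is None:
        shared = GLOBAL_DATA["payload"]
    return {key: scale * sum(shared)}  # stand-in for fitting a transform


def calibrate_all(keys, num_jobs=1):
    payload = list(range(1000))  # stand-in for the encoded calibration peptides
    if num_jobs == 1:
        # Serial: a lazy generator, consumed by the same merge loop below.
        results = (calibrate_one(k, scale=2, shared=payload) for k in keys)
    else:
        GLOBAL_DATA["payload"] = payload  # must be set before Pool() forks
        pool = Pool(processes=num_jobs)
        results = pool.imap_unordered(
            partial(calibrate_one, scale=2), keys, chunksize=1)
    merged = {}
    for result in results:
        merged.update(result)
    return merged


if __name__ == "__main__":
    print(calibrate_all(["HLA-A*01:01", "HLA-A*02:01"], num_jobs=2))

Using imap_unordered with chunksize=1 yields each allele's result as soon as it finishes, which keeps the command's tqdm progress bar responsive and balances load when alleles take different amounts of time.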
@@ -889,8 +889,7 @@ class Class1AffinityPredictor(object):
             peptides=None,
             num_peptides_per_length=int(1e5),
             alleles=None,
-            bins=None,
-            worker_pool=None):
+            bins=None):
         """
         Compute the cumulative distribution of ic50 values for a set of alleles
         over a large universe of random peptides, to enable computing quantiles in
@@ -898,7 +897,7 @@ class Class1AffinityPredictor(object):
         Parameters
         ----------
-        peptides : sequence of string, optional
+        peptides : sequence of string or EncodableSequences, optional
            Peptides to use
         num_peptides_per_length : int, optional
            If peptides argument is not specified, then num_peptides_per_length
@@ -911,8 +910,10 @@ class Class1AffinityPredictor(object):
            Anything that can be passed to numpy.histogram's "bins" argument
            can be used here, i.e. either an integer or a sequence giving bin
            edges. This is in ic50 space.
-        worker_pool : multiprocessing.Pool, optional
-           If specified multiple alleles will be calibrated in parallel
+
+        Returns
+        ----------
+        EncodableSequences : peptides used for calibration
         """
         if bins is None:
             bins = to_ic50(numpy.linspace(1, 0, 1000))
@@ -931,57 +932,12 @@ class Class1AffinityPredictor(object):
             encoded_peptides = EncodableSequences.create(peptides)
 
-        if worker_pool and len(alleles) > 1:
-            # Run in parallel
-
-            # Performance hack.
-            self.neural_networks[0].peptides_to_network_input(encoded_peptides)
-
-            do_work = partial(
-                _calibrate_percentile_ranks,
-                predictor=self,
-                peptides=encoded_peptides,
-                bins=bins)
-            list_of_singleton_alleles = [ [allele] for allele in alleles ]
-            results = worker_pool.imap_unordered(
-                do_work, list_of_singleton_alleles, chunksize=1)
-
-            # Add progress bar
-            results = tqdm.tqdm(results, ascii=True, total=len(alleles))
-
-            # Merge results
-            for partial_dict in results:
-                self.allele_to_percent_rank_transform.update(partial_dict)
-        else:
-            # Run in serial
-            self.allele_to_percent_rank_transform.update(
-                _calibrate_percentile_ranks(
-                    alleles=alleles,
-                    predictor=self,
-                    peptides=encoded_peptides,
-                    bins=bins))
+        for (i, allele) in enumerate(alleles):
+            predictions = self.predict(peptides, allele=allele)
+            transform = PercentRankTransform()
+            transform.fit(predictions, bins=bins)
+            self.allele_to_percent_rank_transform[allele] = transform
+
+        return encoded_peptides
 
-
-def _calibrate_percentile_ranks(alleles, predictor, peptides, bins):
-    """
-    Private helper function.
-
-    Parameters
-    ----------
-    alleles : list of string
-    predictor : Class1AffinityPredictor
-    peptides : list of string or EncodableSequences
-    bins : object
-
-    Returns
-    -------
-    dict : allele -> percentile rank transform
-    """
-    result = {}
-    for (i, allele) in enumerate(alleles):
-        predictions = predictor.predict(peptides, allele=allele)
-        transform = PercentRankTransform()
-        transform.fit(predictions, bins=bins)
-        result[allele] = transform
-    return result
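
With the worker_pool argument gone, callers that want parallelism drive the method once per allele and reuse the EncodableSequences it returns, which is exactly what the updated training command does. A hedged usage sketch of the new signature (the model directory path is illustrative, and this assumes a predictor loaded via Class1AffinityPredictor.load):

from mhcflurry import Class1AffinityPredictor

# Load an existing set of trained models (path is illustrative).
predictor = Class1AffinityPredictor.load("/path/to/models_dir")

alleles = list(predictor.supported_alleles)
first_allele = alleles.pop(0)

# Calibrating the first allele generates the random peptides and returns them
# already encoded, so subsequent calls can skip regenerating them.
encoded_peptides = predictor.calibrate_percentile_ranks(
    alleles=[first_allele],
    num_peptides_per_length=int(1e5))

for allele in alleles:
    predictor.calibrate_percentile_ranks(
        peptides=encoded_peptides,
        alleles=[allele])

# One PercentRankTransform per calibrated allele is stored on the predictor.
print(len(predictor.allele_to_percent_rank_transform))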
@@ -8,6 +8,7 @@ import sys
 import time
 import traceback
 from multiprocessing import Pool
+from functools import partial
 
 import pandas
 import yaml
@@ -17,6 +18,9 @@ import tqdm  # progress bar
 from .class1_affinity_predictor import Class1AffinityPredictor
 from .common import configure_logging, set_keras_backend
 
+GLOBAL_DATA = {}
+
+
 parser = argparse.ArgumentParser(usage=__doc__)
 
 parser.add_argument(
@@ -106,6 +110,8 @@ parser.add_argument(
     help="Keras backend. If not specified will use system default.")
 
 
 def run(argv=sys.argv[1:]):
+    global GLOBAL_DATA
+
     # On sigusr1 print stack trace
     print("To show stack trace, run:\nkill -s USR1 %d" % os.getpid())
     signal.signal(signal.SIGUSR1, lambda sig, frame: traceback.print_stack())
@@ -210,7 +216,7 @@ def run(argv=sys.argv[1:]):
         predictors = list(
             tqdm.tqdm(
                 worker_pool.imap_unordered(
-                    work_entrypoint, work_items, chunksize=1),
+                    train_model_entrypoint, work_items, chunksize=1),
                 ascii=True,
                 total=len(work_items)))
@@ -225,7 +231,7 @@ def run(argv=sys.argv[1:]):
         start = time.time()
         for _ in tqdm.trange(len(work_items)):
             item = work_items.pop(0)  # want to keep freeing up memory
-            work_predictor = work_entrypoint(item)
+            work_predictor = train_model_entrypoint(item)
             assert work_predictor is predictor
         assert not work_items
@@ -240,24 +246,46 @@ def run(argv=sys.argv[1:]):
         worker_pool.join()
 
     if args.percent_rank_calibration_num_peptides_per_length > 0:
+        alleles = list(predictor.supported_alleles)
+        first_allele = alleles.pop(0)
+
+        print("Performing percent rank calibration. Calibrating first allele.")
+        start = time.time()
+        encoded_peptides = predictor.calibrate_percentile_ranks(
+            alleles=[first_allele],
+            num_peptides_per_length=args.percent_rank_calibration_num_peptides_per_length)
+        percent_rank_calibration_time = time.time() - start
+        print("Finished calibrating percent ranks for first allele in %0.2f sec." % (
+            percent_rank_calibration_time))
+        print("Calibrating %d additional alleles." % len(alleles))
+
         if args.calibration_num_jobs == 1:
             # Serial run
             worker_pool = None
+            results = (
+                calibrate_percentile_ranks(
+                    allele=allele,
+                    predictor=predictor,
+                    peptides=encoded_peptides)
+                for allele in alleles)
         else:
+            # Parallel run
+            # Store peptides in global variable so they are in shared memory
+            # after fork, instead of needing to be pickled.
+            GLOBAL_DATA["calibration_peptides"] = encoded_peptides
             worker_pool = Pool(
                 processes=(
                     args.calibration_num_jobs
                     if args.train_num_jobs else None))
             print("Using worker pool: %s" % str(worker_pool))
+            results = worker_pool.imap_unordered(
+                partial(
+                    calibrate_percentile_ranks,
+                    predictor=predictor), alleles, chunksize=1)
 
-        print("Performing percent rank calibration.")
-        start = time.time()
-        predictor.calibrate_percentile_ranks(
-            num_peptides_per_length=args.percent_rank_calibration_num_peptides_per_length,
-            worker_pool=worker_pool)
-        percent_rank_calibration_time = time.time() - start
-        print("Finished calibrating percent ranks in %0.2f sec." % (
-            percent_rank_calibration_time))
+        for result in tqdm.tqdm(results, ascii=True, total=len(alleles)):
+            predictor.allele_to_percent_rank_transform.update(result)
+        print("Done calibrating %d additional alleles." % len(alleles))
 
     predictor.save(args.out_models_dir, model_names_to_write=[])
 
     if worker_pool:
@@ -269,11 +297,11 @@ def run(argv=sys.argv[1:]):
     print("Predictor written to: %s" % args.out_models_dir)
 
 
-def work_entrypoint(item):
-    return process_work(**item)
+def train_model_entrypoint(item):
+    return train_model(**item)
 
 
-def process_work(
+def train_model(
         model_group,
         n_models,
         allele_num,
@@ -325,5 +353,19 @@ def process_work(
     return predictor
 
 
+def calibrate_percentile_ranks(allele, predictor, peptides=None):
+    """
+    Private helper function.
+    """
+    if peptides is None:
+        peptides = GLOBAL_DATA["calibration_peptides"]
+    predictor.calibrate_percentile_ranks(
+        peptides=peptides,
+        alleles=[allele])
+    return {
+        allele: predictor.allele_to_percent_rank_transform[allele],
+    }
+
+
 if __name__ == '__main__':
     run()
@@ -62,7 +62,8 @@ def run_and_check(n_jobs=0):
         "--allele", "HLA-A*02:01", "HLA-A*01:01", "HLA-A*03:01",
         "--out-models-dir", models_dir,
         "--percent-rank-calibration-num-peptides-per-length", "10000",
-        "--parallelization-num-jobs", str(n_jobs),
+        "--train-num-jobs", str(n_jobs),
+        "--calibration-num-jobs", str(n_jobs),
         "--ignore-inequalities",
     ]
     print("Running with args: %s" % args)