Skip to content
Snippets Groups Projects
Commit 5186ccf0 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

Support calibrating percentile ranks in parallel

parent 67f17c27
No related branches found
No related tags found
No related merge requests found
...@@ -9,10 +9,12 @@ from os.path import join, exists ...@@ -9,10 +9,12 @@ from os.path import join, exists
from os import mkdir from os import mkdir
from socket import gethostname from socket import gethostname
from getpass import getuser from getpass import getuser
from functools import partial
import mhcnames import mhcnames
import numpy import numpy
import pandas import pandas
import tqdm # progress bars
from numpy.testing import assert_equal from numpy.testing import assert_equal
from six import string_types from six import string_types
...@@ -588,7 +590,8 @@ class Class1AffinityPredictor(object): ...@@ -588,7 +590,8 @@ class Class1AffinityPredictor(object):
num_peptides_per_length=int(1e5), num_peptides_per_length=int(1e5),
alleles=None, alleles=None,
bins=None, bins=None,
quiet=False): quiet=False,
worker_pool=None):
""" """
Compute the cumulative distribution of ic50 values for a set of alleles Compute the cumulative distribution of ic50 values for a set of alleles
over a large universe of random peptides, to enable computing quantiles in over a large universe of random peptides, to enable computing quantiles in
...@@ -627,6 +630,8 @@ class Class1AffinityPredictor(object): ...@@ -627,6 +630,8 @@ class Class1AffinityPredictor(object):
peptides.extend( peptides.extend(
random_peptides(num_peptides_per_length, length)) random_peptides(num_peptides_per_length, length))
if quiet: if quiet:
def msg(s): def msg(s):
pass pass
...@@ -948,3 +953,100 @@ class Class1AffinityPredictor(object): ...@@ -948,3 +953,100 @@ class Class1AffinityPredictor(object):
] ]
loaded.close() loaded.close()
return weights return weights
def calibrate_percentile_ranks(
        self,
        peptides=None,
        num_peptides_per_length=int(1e5),
        alleles=None,
        bins=None,
        worker_pool=None):
    """
    Compute the cumulative distribution of ic50 values for a set of
    alleles over a large universe of random peptides, to enable
    computing quantiles in this distribution later.

    Parameters
    ----------
    peptides : sequence of string, optional
        Peptides to use.
    num_peptides_per_length : int, optional
        If the peptides argument is not specified, then
        num_peptides_per_length peptides are randomly sampled from a
        uniform distribution for each supported length.
    alleles : sequence of string, optional
        Alleles to perform calibration for. If not specified, all
        supported alleles will be calibrated.
    bins : object
        Anything that can be passed to numpy.histogram's "bins" argument
        can be used here, i.e. either an integer or a sequence giving
        bin edges. This is in ic50 space.
    worker_pool : multiprocessing.Pool, optional
        If specified, multiple alleles will be calibrated in parallel.
    """
    if bins is None:
        # Default bins: 1000 edges spanning the full ic50 range.
        bins = to_ic50(numpy.linspace(1, 0, 1000))
    if alleles is None:
        alleles = self.supported_alleles
    if peptides is None:
        # Sample a fixed number of random peptides at every supported length.
        (min_length, max_length) = self.supported_peptide_lengths
        peptides = []
        for length in range(min_length, max_length + 1):
            peptides.extend(
                random_peptides(num_peptides_per_length, length))

    encoded = EncodableSequences.create(peptides)

    if worker_pool and len(alleles) > 1:
        # Parallel path: dispatch one single-allele task per allele and
        # merge the per-allele transforms as results arrive.
        work_function = partial(
            _calibrate_percentile_ranks,
            predictor=self,
            peptides=encoded,
            bins=bins)
        singleton_allele_lists = [[allele] for allele in alleles]
        result_iterator = worker_pool.imap_unordered(
            work_function, singleton_allele_lists, chunksize=1)
        # Wrap the result iterator in a progress bar.
        result_iterator = tqdm.tqdm(
            result_iterator, ascii=True, total=len(alleles))
        for partial_result in result_iterator:
            self.allele_to_percent_rank_transform.update(partial_result)
    else:
        # Serial path: calibrate every allele in a single call.
        self.allele_to_percent_rank_transform.update(
            _calibrate_percentile_ranks(
                alleles=alleles,
                predictor=self,
                peptides=encoded,
                bins=bins))
def _calibrate_percentile_ranks(alleles, predictor, peptides, bins):
    """
    Private helper: fit a PercentRankTransform for each of the given
    alleles.

    For every allele, the predictor is run over the given peptides and a
    percentile-rank transform is fit to the resulting ic50 predictions.
    Module-level (not a method) so it can be pickled and sent to a
    multiprocessing pool worker.

    Parameters
    ----------
    alleles : sequence of string
        Alleles to calibrate.
    predictor : object
        Object exposing a predict(peptides, allele=...) method returning
        ic50 predictions (presumably a Class1AffinityPredictor — confirm
        against callers).
    peptides : object
        Peptides to predict over (an EncodableSequences, per callers).
    bins : object
        Passed through to PercentRankTransform.fit; anything accepted by
        numpy.histogram's "bins" argument, in ic50 space.

    Returns
    -------
    dict of string -> PercentRankTransform
        Maps each allele to its fitted transform.
    """
    result = {}
    # Fixed: the original enumerated the alleles but never used the index.
    for allele in alleles:
        predictions = predictor.predict(peptides, allele=allele)
        transform = PercentRankTransform()
        transform.fit(predictions, bins=bins)
        result[allele] = transform
    return result
...@@ -215,36 +215,26 @@ def run(argv=sys.argv[1:]): ...@@ -215,36 +215,26 @@ def run(argv=sys.argv[1:]):
# as it goes so no saving is required at the end. # as it goes so no saving is required at the end.
start = time.time() start = time.time()
data_trained_on = 0 data_trained_on = 0
while work_items:
item = work_items.pop(0) for _ in tqdm.trange(len(work_items)):
item = work_items.pop(0) # want to keep freeing up memory
work_predictor = work_entrypoint(item) work_predictor = work_entrypoint(item)
assert work_predictor is predictor assert work_predictor is predictor
# When running in serial we try to estimate time remaining.
data_trained_on += len(item['data'])
progress = float(data_trained_on) / total_data_to_train_on
time_elapsed = time.time() - start
total_time = time_elapsed / progress
print(
"Estimated total training time: %0.2f min, "
"remaining: %0.2f min" % (
total_time / 60,
(total_time - time_elapsed) / 60))
if worker_pool:
worker_pool.close()
worker_pool.join()
if args.percent_rank_calibration_num_peptides_per_length > 0: if args.percent_rank_calibration_num_peptides_per_length > 0:
start = time.time() start = time.time()
print("Performing percent rank calibration.") print("Performing percent rank calibration.")
predictor.calibrate_percentile_ranks( predictor.calibrate_percentile_ranks(
num_peptides_per_length=args.percent_rank_calibration_num_peptides_per_length) num_peptides_per_length=args.percent_rank_calibration_num_peptides_per_length,
worker_pool=worker_pool)
print("Finished calibrating percent ranks in %0.2f sec." % ( print("Finished calibrating percent ranks in %0.2f sec." % (
time.time() - start)) time.time() - start))
predictor.save(args.out_models_dir, model_names_to_write=[]) predictor.save(args.out_models_dir, model_names_to_write=[])
if worker_pool:
worker_pool.close()
worker_pool.join()
def work_entrypoint(item): def work_entrypoint(item):
return process_work(**item) return process_work(**item)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment