Skip to content
Snippets Groups Projects
Commit a96ecdd1 authored by Tim O'Donnell
Browse files

add support for allele encoding transforms (for pca)

parent 257e39ad
No related merge requests found
import pandas
from . import amino_acid
from .allele_encoding_transforms import TRANSFORMS
class AlleleEncoding(object):
def __init__(self, alleles=None, allele_to_sequence=None, borrow_from=None):
def __init__(
self,
alleles=None,
allele_to_sequence=None,
transforms=None,
borrow_from=None):
"""
A place to cache encodings for a (potentially large) sequence of alleles.
......@@ -22,6 +28,11 @@ class AlleleEncoding(object):
self.borrow_from = borrow_from
self.allele_to_sequence = allele_to_sequence
if transforms is None:
transforms = dict(
(name, klass()) for (name, klass) in TRANSFORMS.items())
self.transforms = transforms
if self.borrow_from is None:
assert allele_to_sequence is not None
all_alleles = (
......@@ -41,6 +52,7 @@ class AlleleEncoding(object):
self.allele_to_index = borrow_from.allele_to_index
self.sequences = borrow_from.sequences
self.allele_to_sequence = borrow_from.allele_to_sequence
self.transforms = borrow_from.transforms
if alleles is not None:
assert all(
......@@ -52,32 +64,50 @@ class AlleleEncoding(object):
self.encoding_cache = {}
def allele_representations(self, encoding_name):
    """
    Encode the unique allele sequences, caching the result per encoding name.

    Parameters
    ----------
    encoding_name : string
        Either a plain vector encoding name (a key of
        amino_acid.ENCODING_DATA_FRAMES, e.g. "BLOSUM62"), or a name of the
        form "<transform>:<encoding>" such as "pca:BLOSUM62", which applies
        the named transform from self.transforms to the representations
        computed for "<encoding>". Transforms can be chained
        ("a:b:BLOSUM62"); the leftmost transform is applied last.

    Returns
    -------
    The encoded representations for self.sequences: the output of
    amino_acid.fixed_vectors_encoding for plain names, or of the transform's
    transform() method otherwise.
    NOTE(review): presumably a numpy array with one row per unique
    sequence -- confirm against amino_acid.fixed_vectors_encoding.

    Raises
    ------
    KeyError
        If the requested transform is not present in self.transforms.
    """
    # Borrowed encodings share the lender's sequences and cache.
    if self.borrow_from is not None:
        return self.borrow_from.allele_representations(encoding_name)

    cache_key = ("allele_representations", encoding_name)
    if cache_key not in self.encoding_cache:
        if ":" in encoding_name:
            # Apply transform. Split on the first ":" only (maxsplit=1) so
            # that everything after it is handed to the recursive call;
            # maxsplit=2 could yield three parts and break the unpacking
            # for chained names such as "a:b:BLOSUM62".
            (transform_name, rest) = encoding_name.split(":", 1)

            # Look the transform up before doing any encoding work so an
            # unknown transform name fails fast.
            try:
                transform = self.transforms[transform_name]
            except KeyError:
                raise KeyError(
                    "Unsupported transform: %s. Supported transforms: %s" % (
                        transform_name,
                        " ".join(self.transforms)
                        if self.transforms else "(none)"))

            preliminary_encoded = self.allele_representations(rest)
            vector_encoded = transform.transform(preliminary_encoded)
        else:
            # No transform.
            index_encoded_matrix = amino_acid.index_encoding(
                self.sequences.values,
                amino_acid.AMINO_ACID_INDEX)
            vector_encoded = amino_acid.fixed_vectors_encoding(
                index_encoded_matrix,
                amino_acid.ENCODING_DATA_FRAMES[encoding_name])
        self.encoding_cache[cache_key] = vector_encoded
    return self.encoding_cache[cache_key]
def fixed_length_vector_encoded_sequences(self, vector_encoding_name):
def fixed_length_vector_encoded_sequences(self, encoding_name):
"""
Encode alleles.
Parameters
----------
vector_encoding_name : string
encoding_name : string
How to represent amino acids.
One of "BLOSUM62", "one-hot", etc. Full list of supported vector
encodings is given by available_vector_encodings() in amino_acid.
Also supported are names like pca:BLOSUM62, which would run the
"pca" transform on BLOSUM62-encoded sequences.
Returns
-------
numpy.array with shape (num sequences, sequence length, m) where m is
......@@ -85,9 +115,9 @@ class AlleleEncoding(object):
"""
cache_key = (
"fixed_length_vector_encoding",
vector_encoding_name)
encoding_name)
if cache_key not in self.encoding_cache:
vector_encoded = self.allele_representations(vector_encoding_name)
vector_encoded = self.allele_representations(encoding_name)
result = vector_encoded[self.indices]
self.encoding_cache[cache_key] = result
return self.encoding_cache[cache_key]
......
......@@ -25,6 +25,7 @@ from .regression_target import to_ic50
from .version import __version__
from .ensemble_centrality import CENTRALITY_MEASURES
from .allele_encoding import AlleleEncoding
from .allele_encoding_transforms import TRANSFORMS as ALLELE_ENCODING_TRANSFORMS
# Default function for combining predictions across models in an ensemble.
......@@ -46,6 +47,7 @@ class Class1AffinityPredictor(object):
allele_to_allele_specific_models=None,
class1_pan_allele_models=None,
allele_to_sequence=None,
allele_encoding_transforms=None,
manifest_df=None,
allele_to_percent_rank_transform=None,
metadata_dataframes=None):
......@@ -80,8 +82,9 @@ class Class1AffinityPredictor(object):
if class1_pan_allele_models is None:
class1_pan_allele_models = []
self.allele_to_sequence = allele_to_sequence
self.allele_encoding_transforms = (
allele_encoding_transforms if allele_encoding_transforms else {})
self.master_allele_encoding = None
if class1_pan_allele_models:
assert self.allele_to_sequence
......@@ -350,6 +353,7 @@ class Class1AffinityPredictor(object):
metadata_df_path = join(models_dir, "%s.csv.bz2" % name)
df.to_csv(metadata_df_path, index=False, compression="bz2")
# Save allele sequences
if self.allele_to_sequence is not None:
allele_to_sequence_df = pandas.DataFrame(
list(self.allele_to_sequence.items()),
......@@ -359,6 +363,18 @@ class Class1AffinityPredictor(object):
join(models_dir, "allele_sequences.csv"), index=False)
logging.info("Wrote: %s" % join(models_dir, "allele_sequences.csv"))
# Save allele encoding transforms
for transform in self.allele_encoding_transforms.values():
if transform.is_fit():
fit_data = transform.get_fit()
assert set(fit_data) == set(transform.serialization_keys)
for (serialization_key, fit_df) in fit_data.items():
csv_path = join(
models_dir,
"%s.%s.csv" % (transform.name, serialization_key))
fit_df.to_csv(csv_path)
logging.info("Wrote: %s" % csv_path)
if self.allele_to_percent_rank_transform:
percent_ranks_df = None
for (allele, transform) in self.allele_to_percent_rank_transform.items():
......@@ -419,12 +435,37 @@ class Class1AffinityPredictor(object):
manifest_df["model"] = all_models
# Load allele sequences
allele_to_fixed_length_sequence = None
if exists(join(models_dir, "allele_sequences.csv")):
allele_to_fixed_length_sequence = pandas.read_csv(
join(models_dir, "allele_sequences.csv"),
index_col="allele").to_dict()
# Load allele encoding transforms
allele_encoding_transforms = {}
for transform_name in ALLELE_ENCODING_TRANSFORMS:
klass = ALLELE_ENCODING_TRANSFORMS[transform_name]
transform = klass()
restored_fit = {}
for serialization_key in klass.serialization_keys:
csv_path = join(
models_dir,
"%s.%s.csv" % (transform_name, serialization_key))
if exists(csv_path):
restored_fit[serialization_key] = pandas.read_csv(
csv_path, index_col=0)
if restored_fit:
if set(restored_fit) != set(klass.serialization_keys):
logging.warning(
"Missing some allele encoding transform serialization "
"data from %s. Found: %s. Expected: %s." % (
models_dir,
str(set(restored_fit)),
str(set(klass.serialization_keys))))
transform.restore_fit(restored_fit)
allele_encoding_transforms[transform_name] = transform
allele_to_percent_rank_transform = {}
percent_ranks_path = join(models_dir, "percent_ranks.csv")
if exists(percent_ranks_path):
......@@ -494,7 +535,8 @@ class Class1AffinityPredictor(object):
self.master_allele_encoding.allele_to_sequence !=
self.allele_to_sequence):
self.master_allele_encoding = AlleleEncoding(
allele_to_sequence=self.allele_to_sequence)
allele_to_sequence=self.allele_to_sequence,
transforms=self.allele_encoding_transforms)
return self.master_allele_encoding
def fit_allele_specific_predictors(
......
......@@ -37,3 +37,18 @@ def test_allele_encoding_speed():
start = time.time()
encoding1 = encoding.fixed_length_vector_encoded_sequences("BLOSUM62")
print("Long encoding in %0.2f sec." % (time.time() - start))
def test_pca():
    """
    Check that the "pca:BLOSUM62" encoding gives identical output for
    repeated alleles and different output for distinct alleles.
    """
    allele_list = ["A*02:01", "A*02:03", "A*02:01"]
    sequence_by_allele = {
        "A*02:01": "AC",
        "A*02:03": "AE",
    }
    allele_encoding = AlleleEncoding(allele_list, sequence_by_allele)
    pca_encoded = allele_encoding.fixed_length_vector_encoded_sequences(
        "pca:BLOSUM62")

    # Entries 0 and 2 are the same allele, so their encodings must match.
    numpy.testing.assert_array_equal(pca_encoded[0], pca_encoded[2])

    # Entry 1 is a different allele with a different sequence.
    assert not numpy.array_equal(pca_encoded[0], pca_encoded[1])

    print(pca_encoded)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment