diff --git a/mhcflurry/allele_encoding.py b/mhcflurry/allele_encoding.py index d627993783e6c809fbc57b2b1fce51150d10c6f8..61140b4653dc9f4702e632ca96c3d752a49887fb 100644 --- a/mhcflurry/allele_encoding.py +++ b/mhcflurry/allele_encoding.py @@ -1,10 +1,16 @@ import pandas from . import amino_acid +from .allele_encoding_transforms import TRANSFORMS class AlleleEncoding(object): - def __init__(self, alleles=None, allele_to_sequence=None, borrow_from=None): + def __init__( + self, + alleles=None, + allele_to_sequence=None, + transforms=None, + borrow_from=None): """ A place to cache encodings for a (potentially large) sequence of alleles. @@ -22,6 +28,11 @@ class AlleleEncoding(object): self.borrow_from = borrow_from self.allele_to_sequence = allele_to_sequence + if transforms is None: + transforms = dict( + (name, klass()) for (name, klass) in TRANSFORMS.items()) + self.transforms = transforms + if self.borrow_from is None: assert allele_to_sequence is not None all_alleles = ( @@ -41,6 +52,7 @@ class AlleleEncoding(object): self.allele_to_index = borrow_from.allele_to_index self.sequences = borrow_from.sequences self.allele_to_sequence = borrow_from.allele_to_sequence + self.transforms = borrow_from.transforms if alleles is not None: assert all( @@ -52,32 +64,50 @@ class AlleleEncoding(object): self.encoding_cache = {} - def allele_representations(self, vector_encoding_name): + def allele_representations(self, encoding_name): if self.borrow_from is not None: - return self.borrow_from.allele_representations(vector_encoding_name) + return self.borrow_from.allele_representations(encoding_name) cache_key = ( "allele_representations", - vector_encoding_name) + encoding_name) if cache_key not in self.encoding_cache: - index_encoded_matrix = amino_acid.index_encoding( - self.sequences.values, - amino_acid.AMINO_ACID_INDEX) - vector_encoded = amino_acid.fixed_vectors_encoding( - index_encoded_matrix, - amino_acid.ENCODING_DATA_FRAMES[vector_encoding_name]) + if ":" in encoding_name: + # Apply transform + (transform_name, rest) = encoding_name.split(":", 2) + preliminary_encoded = self.allele_representations(rest) + try: + transform = self.transforms[transform_name] + except KeyError: + raise KeyError( + "Unsupported transform: %s. Supported transforms: %s" % ( + transform_name, + " ".join(self.transforms) if self.transforms else "(none)")) + + vector_encoded = transform.transform(preliminary_encoded) + else: + # No transform. + index_encoded_matrix = amino_acid.index_encoding( + self.sequences.values, + amino_acid.AMINO_ACID_INDEX) + vector_encoded = amino_acid.fixed_vectors_encoding( + index_encoded_matrix, + amino_acid.ENCODING_DATA_FRAMES[encoding_name]) self.encoding_cache[cache_key] = vector_encoded return self.encoding_cache[cache_key] - def fixed_length_vector_encoded_sequences(self, vector_encoding_name): + def fixed_length_vector_encoded_sequences(self, encoding_name): """ Encode alleles. Parameters ---------- - vector_encoding_name : string + encoding_name : string How to represent amino acids. One of "BLOSUM62", "one-hot", etc. Full list of supported vector encodings is given by available_vector_encodings() in amino_acid. + + Also supported are names like pca:BLOSUM62, which would run the + "pca" transform on BLOSUM62-encoded sequences. Returns ------- numpy.array with shape (num sequences, sequence length, m) where m is @@ -85,9 +115,9 @@ class AlleleEncoding(object): """ cache_key = ( "fixed_length_vector_encoding", - vector_encoding_name) + encoding_name) if cache_key not in self.encoding_cache: - vector_encoded = self.allele_representations(vector_encoding_name) + vector_encoded = self.allele_representations(encoding_name) result = vector_encoded[self.indices] self.encoding_cache[cache_key] = result return self.encoding_cache[cache_key] diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py index b70337849017c6ddaf719c798415b3d2340506b6..211ca1ad508d2fda72aa92f4880ec2a486b53bf7 100644 --- a/mhcflurry/class1_affinity_predictor.py +++ b/mhcflurry/class1_affinity_predictor.py @@ -25,6 +25,7 @@ from .regression_target import to_ic50 from .version import __version__ from .ensemble_centrality import CENTRALITY_MEASURES from .allele_encoding import AlleleEncoding +from .allele_encoding_transforms import TRANSFORMS as ALLELE_ENCODING_TRANSFORMS # Default function for combining predictions across models in an ensemble. @@ -46,6 +47,7 @@ class Class1AffinityPredictor(object): allele_to_allele_specific_models=None, class1_pan_allele_models=None, allele_to_sequence=None, + allele_encoding_transforms=None, manifest_df=None, allele_to_percent_rank_transform=None, metadata_dataframes=None): @@ -80,8 +82,9 @@ class Class1AffinityPredictor(object): if class1_pan_allele_models is None: class1_pan_allele_models = [] - self.allele_to_sequence = allele_to_sequence + self.allele_encoding_transforms = ( + allele_encoding_transforms if allele_encoding_transforms else {}) self.master_allele_encoding = None if class1_pan_allele_models: assert self.allele_to_sequence @@ -350,6 +353,7 @@ class Class1AffinityPredictor(object): metadata_df_path = join(models_dir, "%s.csv.bz2" % name) df.to_csv(metadata_df_path, index=False, compression="bz2") + # Save allele sequences if self.allele_to_sequence is not None: allele_to_sequence_df = pandas.DataFrame( list(self.allele_to_sequence.items()), @@ -359,6 +363,18 @@ class Class1AffinityPredictor(object): join(models_dir, "allele_sequences.csv"), index=False) logging.info("Wrote: %s" % join(models_dir, "allele_sequences.csv")) + # Save allele encoding transforms + for transform in self.allele_encoding_transforms.values(): + if transform.is_fit(): + fit_data = transform.get_fit() + assert set(fit_data) == set(transform.serialization_keys) + for (serialization_key, fit_df) in fit_data.items(): + csv_path = join( + models_dir, + "%s.%s.csv" % (transform.name, serialization_key)) + fit_df.to_csv(csv_path) + logging.info("Wrote: %s" % csv_path) + if self.allele_to_percent_rank_transform: percent_ranks_df = None for (allele, transform) in self.allele_to_percent_rank_transform.items(): @@ -419,12 +435,37 @@ class Class1AffinityPredictor(object): manifest_df["model"] = all_models + # Load allele sequences allele_to_fixed_length_sequence = None if exists(join(models_dir, "allele_sequences.csv")): allele_to_fixed_length_sequence = pandas.read_csv( join(models_dir, "allele_sequences.csv"), index_col="allele").to_dict() + # Load allele encoding transforms + allele_encoding_transforms = {} + for transform_name in ALLELE_ENCODING_TRANSFORMS: + klass = ALLELE_ENCODING_TRANSFORMS[transform_name] + transform = klass() + restored_fit = {} + for serialization_key in klass.serialization_keys: + csv_path = join( + models_dir, + "%s.%s.csv" % (transform_name, serialization_key)) + if exists(csv_path): + restored_fit[serialization_key] = pandas.read_csv( + csv_path, index_col=0) + if restored_fit: + if set(restored_fit) != set(klass.serialization_keys): + logging.warning( + "Missing some allele encoding transform serialization " + "data from %s. Found: %s. Expected: %s." % ( + models_dir, + str(set(restored_fit)), + str(set(klass.serialization_keys)))) + transform.restore_fit(restored_fit) + allele_encoding_transforms[transform_name] = transform + allele_to_percent_rank_transform = {} percent_ranks_path = join(models_dir, "percent_ranks.csv") if exists(percent_ranks_path): @@ -494,7 +535,8 @@ class Class1AffinityPredictor(object): self.master_allele_encoding.allele_to_sequence != self.allele_to_sequence): self.master_allele_encoding = AlleleEncoding( - allele_to_sequence=self.allele_to_sequence) + allele_to_sequence=self.allele_to_sequence, + transforms=self.allele_encoding_transforms) return self.master_allele_encoding def fit_allele_specific_predictors( diff --git a/test/test_allele_encoding.py b/test/test_allele_encoding.py index e8b772904b0f221a08f712579c909e093c68c0c2..ef526e2d67507f6396854485266c7100e37ae7fe 100644 --- a/test/test_allele_encoding.py +++ b/test/test_allele_encoding.py @@ -37,3 +37,18 @@ def test_allele_encoding_speed(): start = time.time() encoding1 = encoding.fixed_length_vector_encoded_sequences("BLOSUM62") print("Long encoding in %0.2f sec." % (time.time() - start)) + + +def test_pca(): + encoding = AlleleEncoding( + ["A*02:01", "A*02:03", "A*02:01"], + { + "A*02:01": "AC", + "A*02:03": "AE", + } + ) + encoded1 = encoding.fixed_length_vector_encoded_sequences("pca:BLOSUM62") + + numpy.testing.assert_array_equal(encoded1[0], encoded1[2]) + assert not numpy.array_equal(encoded1[0], encoded1[1]) + print(encoded1) \ No newline at end of file