diff --git a/mhcflurry/allele_encoding.py b/mhcflurry/allele_encoding.py index 1fb7b1319f226b90a8c32bef513f049585bec46b..092ee7c48fd4677e435dff1e2eb3f66b3510b531 100644 --- a/mhcflurry/allele_encoding.py +++ b/mhcflurry/allele_encoding.py @@ -8,7 +8,7 @@ class AlleleEncoding(object): def __init__( self, alleles, - allele_to_fixed_length_sequence=None): + allele_to_fixed_length_sequence): """ A place to cache encodings for a (potentially large) sequence of alleles. @@ -18,22 +18,22 @@ class AlleleEncoding(object): Allele names allele_to_fixed_length_sequence : dict of str -> str - Allele name to fixed lengths sequence ("pseudosequence") + Allele name to fixed lengths sequence ("pseudosequence"), or a + pandas dataframe with allele names as the index and arbitrary values + to use for the encoding of those alleles """ - alleles = pandas.Series(alleles) + self.alleles = pandas.Series(alleles) - all_alleles = list(sorted(alleles.unique())) - - self.allele_to_index = dict( - (allele, i) - for (i, allele) in enumerate(all_alleles)) - - self.indices = alleles.map(self.allele_to_index) - - self.fixed_length_sequences = pandas.Series( - [allele_to_fixed_length_sequence[a] for a in all_alleles], - index=all_alleles) + if isinstance(allele_to_fixed_length_sequence, dict): + self.allele_to_fixed_length_sequence = pandas.DataFrame( + index=allele_to_fixed_length_sequence) + self.allele_to_fixed_length_sequence["value"] = ( + self.allele_to_fixed_length_sequence.index.map( + allele_to_fixed_length_sequence.get)) + else: + assert isinstance(allele_to_fixed_length_sequence, pandas.DataFrame) + self.allele_to_fixed_length_sequence = allele_to_fixed_length_sequence self.encoding_cache = {} @@ -48,6 +48,10 @@ class AlleleEncoding(object): One of "BLOSUM62", "one-hot", etc. Full list of supported vector encodings is given by available_vector_encodings() in amino_acid. 
+ If a DataFrame was provided as `allele_to_fixed_length_sequence` + in the constructor, then those values will be used and this argument + will be ignored. + Returns ------- numpy.array with shape (num sequences, sequence length, m) where m is @@ -57,13 +61,34 @@ class AlleleEncoding(object): "fixed_length_vector_encoding", vector_encoding_name) if cache_key not in self.encoding_cache: - index_encoded_matrix = amino_acid.index_encoding( - self.fixed_length_sequences.values, - amino_acid.AMINO_ACID_INDEX) - vector_encoded = amino_acid.fixed_vectors_encoding( - index_encoded_matrix, - amino_acid.ENCODING_DATA_FRAMES[vector_encoding_name]) - result = vector_encoded[self.indices] + all_alleles = list(sorted(self.alleles.unique())) + allele_to_index = dict( + (allele, i) + for (i, allele) in enumerate(all_alleles)) + indices = self.alleles.map(allele_to_index) + + allele_to_fixed_length_sequence = self.allele_to_fixed_length_sequence.loc[ + all_alleles + ].copy() + + if list(allele_to_fixed_length_sequence) == ["value"]: + # Pseudosequence + index_encoded_matrix = amino_acid.index_encoding( + allele_to_fixed_length_sequence["value"].values, + amino_acid.AMINO_ACID_INDEX) + vector_encoded = amino_acid.fixed_vectors_encoding( + index_encoded_matrix, + amino_acid.ENCODING_DATA_FRAMES[vector_encoding_name]) + flattened = pandas.DataFrame( + vector_encoded.reshape( + (len(allele_to_fixed_length_sequence), -1))) + encoding_shape = vector_encoded.shape[1:] + else: + # Raw values + flattened = allele_to_fixed_length_sequence + encoding_shape = (allele_to_fixed_length_sequence.shape[1],) + result = flattened.iloc[indices].values.reshape( + (len(self.alleles),) + encoding_shape) self.encoding_cache[cache_key] = result return self.encoding_cache[cache_key] diff --git a/test/test_allele_encoding.py b/test/test_allele_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..caebb66e15d4571130ba0ede04c7f459a4d31deb --- /dev/null +++ 
"""
Tests for mhcflurry.allele_encoding.AlleleEncoding.

These are the contents of test/test_allele_encoding.py (the new file added
by this patch), reflowed from the collapsed hunk into regular source form.
"""
import time

from mhcflurry.allele_encoding import AlleleEncoding
from mhcflurry.amino_acid import BLOSUM62_MATRIX
from nose.tools import eq_
from numpy.testing import assert_equal
import numpy
import pandas


def test_allele_encoding_speed():
    """
    Check the BLOSUM62 encoding of short pseudosequences, then time the
    encoding of a long allele list and verify the result it produces.
    """
    encoding = AlleleEncoding(
        ["A*02:01", "A*02:03", "A*02:01"],
        {
            "A*02:01": "AC",
            "A*02:03": "AE",
        }
    )
    start = time.time()
    encoding1 = encoding.fixed_length_vector_encoded_sequences("BLOSUM62")
    assert_equal(
        [
            [BLOSUM62_MATRIX["A"], BLOSUM62_MATRIX["C"]],
            [BLOSUM62_MATRIX["A"], BLOSUM62_MATRIX["E"]],
            [BLOSUM62_MATRIX["A"], BLOSUM62_MATRIX["C"]],
        ], encoding1)
    print("Simple encoding in %0.2f sec." % (time.time() - start))
    print(encoding1)

    alleles = ["A*02:01", "A*02:03", "A*02:01"] * int(1e5)
    sequences = {
        "A*02:01": "AC" * 16,
        "A*02:03": "AE" * 16,
    }
    encoding = AlleleEncoding(alleles, sequences)
    start = time.time()
    encoding1 = encoding.fixed_length_vector_encoded_sequences("BLOSUM62")
    print("Long encoding in %0.2f sec." % (time.time() - start))

    # The timing above is informational only; without the checks below the
    # large case would pass even if the result were wrong.
    # One row per input allele, one position per pseudosequence residue.
    assert_equal(encoding1.shape[:2], (len(alleles), 32))

    # The first repeat of the allele list must be encoded residue by residue,
    # using the same expected-value construction as the simple case above.
    assert_equal(
        [
            [BLOSUM62_MATRIX[residue] for residue in sequences[allele]]
            for allele in alleles[:3]
        ], encoding1[:3])

    # Repeated alleles must map to identical encodings, including at the far
    # end of the list (the last entry is "A*02:01", like the first).
    assert_equal(encoding1[0], encoding1[2])
    assert_equal(encoding1[0], encoding1[-1])


def test_allele_encoding_raw_values():
    """
    Check that a DataFrame of raw per-allele values is returned as the
    encoding, one row per input allele in input order.
    """
    encoding = AlleleEncoding(
        ["A*02:01", "A*02:03", "A*02:01"],
        pandas.DataFrame(
            [
                [0, 1, -1],
                [10, 11, 12],
            ],
            index=["A*02:01", "A*02:03"]))

    # Per the method's docstring, the encoding name is ignored when a
    # DataFrame was given to the constructor; "BLOSUM62" is a placeholder.
    encoding1 = encoding.fixed_length_vector_encoded_sequences("BLOSUM62")
    assert_equal(
        [
            [0, 1, -1],
            [10, 11, 12],
            [0, 1, -1],
        ], encoding1)