diff --git a/mhcflurry/amino_acid.py b/mhcflurry/amino_acid.py index eba9de5f1c04b05b0a39c89dc30fc69612b4ad53..c412dcad956abbd91b16cd796e9aded916608247 100644 --- a/mhcflurry/amino_acid.py +++ b/mhcflurry/amino_acid.py @@ -21,7 +21,6 @@ import collections from copy import copy import pandas -import numpy from six import StringIO @@ -136,21 +135,27 @@ def index_encoding(sequences, letter_to_index_dict): return result.values -def fixed_vectors_encoding(sequences, letter_to_vector_function): +def fixed_vectors_encoding(sequences, letter_to_vector_df): """ - Given a sequence of n strings all of length k, return a n * k * m array where - the (i, j)th element is letter_to_vector_function(sequence[i][j]). + Given a sequence of n strings all of length k, and a dataframe mapping each + character to an arbitrary vector, return a n * k * m array where + the (i, j)th element is letter_to_vector_df.loc[sequence[i][j]]. Parameters ---------- sequences : list of length n of strings of length k - letter_to_vector_function : function : string -> vector of length m + letter_to_vector_df : pandas.DataFrame of shape (alphabet size, m) + The index of the dataframe should be amino acid characters. Returns ------- numpy.array of integers with shape (n, k, m) """ - arr = numpy.array([list(s) for s in sequences]) - result = numpy.vectorize( - letter_to_vector_function, signature='()->(n)')(arr) - return result \ No newline at end of file + target_shape = ( + len(sequences), + len(sequences[0]), + letter_to_vector_df.shape[0]) + result = letter_to_vector_df.loc[ + (letter for seq in sequences for letter in seq) + ].values.reshape(target_shape) + return result diff --git a/mhcflurry/encodable_sequences.py b/mhcflurry/encodable_sequences.py index 67553d0fd2ae9983b48e60727d81b3a2950ac97f..ca71ba68bcfd58fce9d82daf52365e19ef014433 100644 --- a/mhcflurry/encodable_sequences.py +++ b/mhcflurry/encodable_sequences.py @@ -140,7 +140,7 @@ class EncodableSequences(object): ] result = amino_acid.fixed_vectors_encoding( fixed_length_sequences, - amino_acid.ENCODING_DFS[vector_encoding_name].loc.__getitem__) + amino_acid.ENCODING_DFS[vector_encoding_name]) assert result.shape[0] == len(self.sequences) self.encoding_cache[cache_key] = result return self.encoding_cache[cache_key] diff --git a/test/test_amino_acid.py b/test/test_amino_acid.py index 26c50b7bfc6e47c91d42bcd7d0747cd2581d5011..5f7eb83011dd7ada8b5d1cda3f5ff9f8c8a23032 100644 --- a/test/test_amino_acid.py +++ b/test/test_amino_acid.py @@ -2,6 +2,7 @@ from mhcflurry import amino_acid from nose.tools import eq_ from numpy.testing import assert_equal import numpy +import pandas letter_to_index_dict = { 'A': 0, @@ -11,6 +12,14 @@ letter_to_index_dict = { def test_index_and_one_hot_encoding(): + letter_to_vector_df = pandas.DataFrame( + [ + [1, 0, 0,], + [0, 1, 0,], + [0, 0, 1,] + ], columns=[0, 1, 2] + ) + index_encoding = amino_acid.index_encoding( ["AAAA", "ABCA"], letter_to_index_dict) assert_equal( @@ -21,11 +30,7 @@ def test_index_and_one_hot_encoding(): ]) one_hot = amino_acid.fixed_vectors_encoding( index_encoding, - { - 0: numpy.array([1, 0, 0]), - 1: numpy.array([0, 1, 0]), - 2: numpy.array([0, 0, 1]), - }.get) + letter_to_vector_df) eq_(one_hot.shape, (2, 4, 3)) assert_equal( one_hot[0],