From 419853382a008de3d498664ebedc82c526351419 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Mon, 27 Nov 2017 17:35:36 -0500
Subject: [PATCH] Much faster performance (~5X improvement) in amino acid
 encoding

---
 mhcflurry/amino_acid.py          | 23 ++++++++++++++---------
 mhcflurry/encodable_sequences.py |  2 +-
 test/test_amino_acid.py          | 15 ++++++++++-----
 3 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/mhcflurry/amino_acid.py b/mhcflurry/amino_acid.py
index eba9de5f..c412dcad 100644
--- a/mhcflurry/amino_acid.py
+++ b/mhcflurry/amino_acid.py
@@ -21,7 +21,6 @@ import collections
 from copy import copy
 
 import pandas
-import numpy
 from six import StringIO
 
 
@@ -136,21 +135,27 @@ def index_encoding(sequences, letter_to_index_dict):
     return result.values
 
 
-def fixed_vectors_encoding(sequences, letter_to_vector_function):
+def fixed_vectors_encoding(sequences, letter_to_vector_df):
     """
-    Given a sequence of n strings all of length k, return a n * k * m array where
-    the (i, j)th element is letter_to_vector_function(sequence[i][j]).
+    Given a sequence of n strings all of length k, and a dataframe mapping each
+    character to an arbitrary vector, return a n * k * m array where
+    the (i, j)th element is letter_to_vector_df.loc[sequence[i][j]].
 
     Parameters
     ----------
     sequences : list of length n of strings of length k
-    letter_to_vector_function : function : string -> vector of length m
+    letter_to_vector_df : pandas.DataFrame of shape (alphabet size, m)
+        The index of the dataframe should be amino acid characters.
 
     Returns
     -------
     numpy.array of integers with shape (n, k, m)
     """
-    arr = numpy.array([list(s) for s in sequences])
-    result = numpy.vectorize(
-        letter_to_vector_function, signature='()->(n)')(arr)
-    return result
\ No newline at end of file
+    target_shape = (
+        len(sequences),
+        len(sequences[0]),
+        letter_to_vector_df.shape[0])
+    result = letter_to_vector_df.loc[
+            (letter for seq in sequences for letter in seq)
+    ].values.reshape(target_shape)
+    return result
diff --git a/mhcflurry/encodable_sequences.py b/mhcflurry/encodable_sequences.py
index 67553d0f..ca71ba68 100644
--- a/mhcflurry/encodable_sequences.py
+++ b/mhcflurry/encodable_sequences.py
@@ -140,7 +140,7 @@ class EncodableSequences(object):
             ]
             result = amino_acid.fixed_vectors_encoding(
                 fixed_length_sequences,
-                amino_acid.ENCODING_DFS[vector_encoding_name].loc.__getitem__)
+                amino_acid.ENCODING_DFS[vector_encoding_name])
             assert result.shape[0] == len(self.sequences)
             self.encoding_cache[cache_key] = result
         return self.encoding_cache[cache_key]
diff --git a/test/test_amino_acid.py b/test/test_amino_acid.py
index 26c50b7b..5f7eb830 100644
--- a/test/test_amino_acid.py
+++ b/test/test_amino_acid.py
@@ -2,6 +2,7 @@ from mhcflurry import amino_acid
 from nose.tools import eq_
 from numpy.testing import assert_equal
 import numpy
+import pandas
 
 letter_to_index_dict = {
     'A': 0,
@@ -11,6 +12,14 @@ letter_to_index_dict = {
 
 
 def test_index_and_one_hot_encoding():
+    letter_to_vector_df = pandas.DataFrame(
+        [
+            [1, 0, 0,],
+            [0, 1, 0,],
+            [0, 0, 1,]
+        ], columns=[0, 1, 2]
+    )
+
     index_encoding = amino_acid.index_encoding(
         ["AAAA", "ABCA"], letter_to_index_dict)
     assert_equal(
@@ -21,11 +30,7 @@ def test_index_and_one_hot_encoding():
         ])
     one_hot = amino_acid.fixed_vectors_encoding(
         index_encoding,
-        {
-            0: numpy.array([1, 0, 0]),
-            1: numpy.array([0, 1, 0]),
-            2: numpy.array([0, 0, 1]),
-        }.get)
+        letter_to_vector_df)
     eq_(one_hot.shape, (2, 4, 3))
     assert_equal(
         one_hot[0],
-- 
GitLab