From 6a292951d27c7c3535a95ba4975165cd22e2b216 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Wed, 23 May 2018 13:24:35 -0400
Subject: [PATCH] more efficient allele encoding

---
 mhcflurry/allele_encoding.py | 17 ++++++++---------
 test/test_allele_encoding.py |  2 +-
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/mhcflurry/allele_encoding.py b/mhcflurry/allele_encoding.py
index 092ee7c4..46ab3d8c 100644
--- a/mhcflurry/allele_encoding.py
+++ b/mhcflurry/allele_encoding.py
@@ -54,8 +54,13 @@ class AlleleEncoding(object):
 
         Returns
         -------
-        numpy.array with shape (num sequences, sequence length, m) where m is
+        list of numpy arrays. Pass it to numpy.array to get an array with shape
+        (num sequences, sequence length, m) where m is
         vector_encoding_length(vector_encoding_name)
+
+        A list is returned instead of an array because it can use much less
+        memory in the common case where many of the rows are the same:
+        repeated rows reference the same encoded data instead of copying it.
         """
         cache_key = (
             "fixed_length_vector_encoding",
@@ -79,16 +84,10 @@ class AlleleEncoding(object):
                 vector_encoded = amino_acid.fixed_vectors_encoding(
                     index_encoded_matrix,
                     amino_acid.ENCODING_DATA_FRAMES[vector_encoding_name])
-                flattened = pandas.DataFrame(
-                    vector_encoded.reshape(
-                        (len(allele_to_fixed_length_sequence), -1)))
-                encoding_shape = vector_encoded.shape[1:]
             else:
                 # Raw values
-                flattened = allele_to_fixed_length_sequence
-                encoding_shape = (allele_to_fixed_length_sequence.shape[1],)
-            result = flattened.iloc[indices].values.reshape(
-                (len(self.alleles),) + encoding_shape)
+                vector_encoded = allele_to_fixed_length_sequence.values
+            result = [vector_encoded[i] for i in indices]
             self.encoding_cache[cache_key] = result
         return self.encoding_cache[cache_key]
 
diff --git a/test/test_allele_encoding.py b/test/test_allele_encoding.py
index caebb66e..a3f270e0 100644
--- a/test/test_allele_encoding.py
+++ b/test/test_allele_encoding.py
@@ -55,4 +55,4 @@ def test_allele_encoding_raw_values():
             [0, 1, -1],
             [10, 11, 12],
             [0, 1, -1],
-        ], encoding1)
+        ], numpy.array(encoding1))
-- 
GitLab