From 6a292951d27c7c3535a95ba4975165cd22e2b216 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Wed, 23 May 2018 13:24:35 -0400 Subject: [PATCH] more efficient allele encoding --- mhcflurry/allele_encoding.py | 17 ++++++++--------- test/test_allele_encoding.py | 2 +- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/mhcflurry/allele_encoding.py b/mhcflurry/allele_encoding.py index 092ee7c4..46ab3d8c 100644 --- a/mhcflurry/allele_encoding.py +++ b/mhcflurry/allele_encoding.py @@ -54,8 +54,13 @@ class AlleleEncoding(object): Returns ------- - numpy.array with shape (num sequences, sequence length, m) where m is + list of numpy arrays. Pass it to numpy.array to get an array with shape + (num sequences, sequence length, m) where m is vector_encoding_length(vector_encoding_name) + + The reason to return a list instead of an array is that the list can + use much less memory in the common case where many of the rows are + the same. """ cache_key = ( "fixed_length_vector_encoding", @@ -79,16 +84,10 @@ class AlleleEncoding(object): vector_encoded = amino_acid.fixed_vectors_encoding( index_encoded_matrix, amino_acid.ENCODING_DATA_FRAMES[vector_encoding_name]) - flattened = pandas.DataFrame( - vector_encoded.reshape( - (len(allele_to_fixed_length_sequence), -1))) - encoding_shape = vector_encoded.shape[1:] else: # Raw values - flattened = allele_to_fixed_length_sequence - encoding_shape = (allele_to_fixed_length_sequence.shape[1],) - result = flattened.iloc[indices].values.reshape( - (len(self.alleles),) + encoding_shape) + vector_encoded = allele_to_fixed_length_sequence.values + result = [vector_encoded[i] for i in indices] self.encoding_cache[cache_key] = result return self.encoding_cache[cache_key] diff --git a/test/test_allele_encoding.py b/test/test_allele_encoding.py index caebb66e..a3f270e0 100644 --- a/test/test_allele_encoding.py +++ b/test/test_allele_encoding.py @@ -55,4 +55,4 @@ def test_allele_encoding_raw_values(): [0, 1, -1], [10, 11, 12], [0, 1, -1], - ], encoding1) + ], numpy.array(encoding1)) -- GitLab