Skip to content
Snippets Groups Projects
test_encodable_sequences.py 840 B
Newer Older
Tim O'Donnell's avatar
Tim O'Donnell committed
from mhcflurry import encodable_sequences
from nose.tools import eq_
from numpy.testing import assert_equal

letter_to_index_dict = {
    'A': 0,
    'B': 1,
    'C': 2,
}


def test_index_and_one_hot_encoding():
    index_encoding = encodable_sequences.index_encoding(
        ["AAAA", "ABCA"], letter_to_index_dict)
    assert_equal(
        index_encoding,
        [
            [0, 0, 0, 0],
            [0, 1, 2, 0],
        ])
    one_hot = encodable_sequences.one_hot_encoding(index_encoding, 3)
    eq_(one_hot.shape, (2, 4, 3))
    assert_equal(
        one_hot[0],
        [
            [1, 0, 0],
            [1, 0, 0],
            [1, 0, 0],
            [1, 0, 0],
        ])
    assert_equal(
        one_hot[1],
        [
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
            [1, 0, 0],
        ])