Skip to content
Snippets Groups Projects
test_predict_command.py 1.74 KiB
Newer Older
import tempfile
import os

import pandas
from numpy.testing import assert_equal

from mhcflurry import predict_command

# Small CSV fixture fed to the predict command: a header row followed by three
# allele/peptide rows. Besides the allele and peptide columns, every row also
# carries an extra "Experiment" value. No trailing newline.
TEST_CSV = "\n".join([
    "Allele,Peptide,Experiment",
    "HLA-A0201,SYNFEKKL,17",
    "HLA-B4403,AAAAAAAAA,17",
    "HLA-B4403,PPPPPPPP,18",
])


def test_csv():
    """Run the predict command on a CSV input file and check the output shape.

    Writes ``TEST_CSV`` to a temporary file, runs ``predict_command.run`` with
    explicit allele/peptide column names and a temporary ``--out`` path, then
    reads the CSV written there. Both temporary files are removed in a
    ``finally`` block, whether or not the command succeeded.
    """
    args = ["--allele-column", "Allele", "--peptide-column", "Peptide"]
    deletes = []
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as fd:
            fd.write(TEST_CSV.encode())
            deletes.append(fd.name)
        # Only the output *path* is needed. Close our handle right away so it
        # is not leaked, and so the command can reopen the file and the
        # unlink below can remove it on platforms that refuse to reopen or
        # delete a named temporary file that is still open.
        with tempfile.NamedTemporaryFile(
                delete=False, suffix=".csv") as fd_out:
            deletes.append(fd_out.name)
        full_args = [fd.name] + args + ["--out", fd_out.name]
        print("Running with args: %s" % full_args)
        predict_command.run(full_args)
        result = pandas.read_csv(fd_out.name)
        print(result)
    finally:
        for delete in deletes:
            os.unlink(delete)

    # Three input rows; the three input columns plus one more (presumably the
    # prediction column added by the command).
    assert_equal(result.shape, (3, 4))


def test_no_csv():
    """Run the predict command with alleles and peptides given as arguments.

    Two alleles and three peptides are passed on the command line (no input
    CSV), with the prediction column named ``prediction``. The output CSV
    written to a temporary ``--out`` path is expected to have 6 rows
    (presumably one per allele/peptide pair) and 3 columns. For peptide
    SIINFEKL, the H-2Kb prediction must be strictly lower than the HLA-A0201
    one. The temporary file is removed whether or not the command succeeded.
    """
    args = [
        "--alleles", "HLA-A0201", "H-2Kb",
        "--peptides", "SIINFEKL", "DENDREKLLL", "PICKLE",
        "--prediction-column", "prediction",
    ]

    deletes = []
    try:
        # Only the output *path* is needed; close our handle immediately so it
        # is not leaked and the command / unlink can open and remove the file.
        with tempfile.NamedTemporaryFile(
                delete=False, suffix=".csv") as fd_out:
            deletes.append(fd_out.name)
        full_args = args + ["--out", fd_out.name]
        print("Running with args: %s" % full_args)
        predict_command.run(full_args)
        result = pandas.read_csv(fd_out.name)
        print(result)
    finally:
        for delete in deletes:
            os.unlink(delete)

    assert_equal(result.shape, (6, 3))
    # DataFrame.ix was removed in pandas 1.0; use label-based .loc instead
    # (boolean mask for the rows, then (row label, column) for the scalar).
    sub_result1 = result.loc[result.peptide == "SIINFEKL"].set_index("allele")
    assert (
        sub_result1.loc["H-2Kb", "prediction"] <
        sub_result1.loc["HLA-A0201", "prediction"])