Skip to content
Snippets Groups Projects
Commit c1189219 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

mhcflurry.predict can now work with custom model loaders

parent 770a1e21
No related merge requests found
......@@ -25,7 +25,7 @@ from os.path import join
import pandas
from ..downloads import get_path
from ..common import normalize_allele_name
from ..common import normalize_allele_name, UnsupportedAllele
CACHED_LOADER = None
......@@ -113,7 +113,7 @@ class Class1AlleleSpecificPredictorLoader(object):
try:
predictor_name = self.df.ix[allele_name].predictor_name
except KeyError:
raise ValueError(
raise UnsupportedAllele(
"No models for allele '%s'. Alleles with models: %s"
" in models file: %s" % (
allele_name,
......
......@@ -20,6 +20,10 @@ from collections import defaultdict
import numpy as np
class UnsupportedAllele(Exception):
pass
def parse_int_list(s):
return [int(part.strip()) for part in s.split(",")]
......
......@@ -70,7 +70,9 @@ class Dataset(object):
for expected_column_name in {"allele", "peptide", "affinity"}:
if expected_column_name not in columns:
raise ValueError("Missing column '%s' from DataFrame")
raise ValueError(
"Missing column '%s' from DataFrame" %
expected_column_name)
# make allele and peptide columns the index, and copy it
# so we can add a column without any observable side-effect in
# the calling code
......
......@@ -17,10 +17,10 @@ from collections import OrderedDict
import pandas as pd
from .class1_allele_specific import load
from .common import normalize_allele_name
from .common import normalize_allele_name, UnsupportedAllele
def predict(alleles, peptides):
def predict(alleles, peptides, loaders=None):
"""
Parameters
----------
......@@ -32,6 +32,10 @@ def predict(alleles, peptides):
Returns DataFrame with columns "Allele", "Peptide", and "Prediction"
"""
if loaders is None:
loaders = [
load.get_loader_for_downloaded_models(),
]
result_dict = OrderedDict([
("Allele", []),
("Peptide", []),
......@@ -39,7 +43,22 @@ def predict(alleles, peptides):
])
for allele in alleles:
allele = normalize_allele_name(allele)
model = load.from_allele_name(allele)
exceptions = {} # loader -> UnsupportedAllele exception
model = None
for loader in loaders:
try:
model = loader.from_allele_name(allele)
break
except UnsupportedAllele as e:
exceptions[loader] = e
if model is None:
raise UnsupportedAllele(
"No loaders support allele '%s'. Errors were:\n%s" % (
allele,
"\n".join(
("\t%-20s : %s" % (k, v))
for (k, v) in exceptions.items())))
for i, ic50 in enumerate(model.predict(peptides)):
result_dict["Allele"].append(allele)
result_dict["Peptide"].append(peptides[i])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment