From c11892191d7fab646b1418a1838715dd4a26f9ae Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Wed, 21 Sep 2016 17:07:04 -0400
Subject: [PATCH] mhcflurry.predict can now work with custom model loaders

---
 mhcflurry/class1_allele_specific/load.py |  4 ++--
 mhcflurry/common.py                      |  4 ++++
 mhcflurry/dataset.py                     |  4 +++-
 mhcflurry/predict.py                     | 25 +++++++++++++++++++++---
 4 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/mhcflurry/class1_allele_specific/load.py b/mhcflurry/class1_allele_specific/load.py
index 8ee0f1d2..619bd0f3 100644
--- a/mhcflurry/class1_allele_specific/load.py
+++ b/mhcflurry/class1_allele_specific/load.py
@@ -25,7 +25,7 @@ from os.path import join
 import pandas
 
 from ..downloads import get_path
-from ..common import normalize_allele_name
+from ..common import normalize_allele_name, UnsupportedAllele
 
 CACHED_LOADER = None
 
@@ -113,7 +113,7 @@ class Class1AlleleSpecificPredictorLoader(object):
             try:
                 predictor_name = self.df.ix[allele_name].predictor_name
             except KeyError:
-                raise ValueError(
+                raise UnsupportedAllele(
                     "No models for allele '%s'. Alleles with models: %s"
                     " in models file: %s" % (
                         allele_name,
diff --git a/mhcflurry/common.py b/mhcflurry/common.py
index d67d34cd..809109c3 100644
--- a/mhcflurry/common.py
+++ b/mhcflurry/common.py
@@ -20,6 +20,10 @@ from collections import defaultdict
 import numpy as np
 
 
+class UnsupportedAllele(Exception):
+    """Raised when no predictor model is available for a requested allele."""
+
+
 def parse_int_list(s):
     return [int(part.strip()) for part in s.split(",")]
 
diff --git a/mhcflurry/dataset.py b/mhcflurry/dataset.py
index f70273dc..43fff5c9 100644
--- a/mhcflurry/dataset.py
+++ b/mhcflurry/dataset.py
@@ -70,7 +70,9 @@ class Dataset(object):
 
         for expected_column_name in {"allele", "peptide", "affinity"}:
             if expected_column_name not in columns:
-                raise ValueError("Missing column '%s' from DataFrame")
+                raise ValueError(
+                    "Missing column '%s' from DataFrame" %
+                    expected_column_name)
         # make allele and peptide columns the index, and copy it
         # so we can add a column without any observable side-effect in
         # the calling code
diff --git a/mhcflurry/predict.py b/mhcflurry/predict.py
index fee483d3..0f432cf0 100644
--- a/mhcflurry/predict.py
+++ b/mhcflurry/predict.py
@@ -17,10 +17,10 @@ from collections import OrderedDict
 import pandas as pd
 
 from .class1_allele_specific import load
-from .common import normalize_allele_name
+from .common import normalize_allele_name, UnsupportedAllele
 
 
-def predict(alleles, peptides):
+def predict(alleles, peptides, loaders=None):
     """
     Parameters
     ----------
@@ -32,6 +32,10 @@ def predict(alleles, peptides):
 
     Returns DataFrame with columns "Allele", "Peptide", and "Prediction"
     """
+    if loaders is None:
+        loaders = [
+            load.get_loader_for_downloaded_models(),
+        ]
     result_dict = OrderedDict([
         ("Allele", []),
         ("Peptide", []),
@@ -39,7 +43,22 @@ def predict(alleles, peptides):
     ])
     for allele in alleles:
         allele = normalize_allele_name(allele)
-        model = load.from_allele_name(allele)
+        exceptions = {}  # loader -> UnsupportedAllele exception
+        model = None
+        for loader in loaders:
+            try:
+                model = loader.from_allele_name(allele)
+                break
+            except UnsupportedAllele as e:
+                exceptions[loader] = e
+        if model is None:
+            raise UnsupportedAllele(
+                "No loaders support allele '%s'. Errors were:\n%s" % (
+                    allele,
+                    "\n".join(
+                        ("\t%-20s : %s" % (k, v))
+                        for (k, v) in exceptions.items())))
+
         for i, ic50 in enumerate(model.predict(peptides)):
             result_dict["Allele"].append(allele)
             result_dict["Peptide"].append(peptides[i])
-- 
GitLab