Commit 172effc2 authored by Tim O'Donnell

Fix bug in model loading when the models dir is changed. Refactored into a Class1AlleleSpecificPredictorLoader class
parent 09201ca5
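For context, a hedged usage sketch of the entry points this commit introduces (see the diff below). The import path and the allele name are assumptions for illustration; neither appears in the diff itself.

# Sketch only: the module's location within the package is assumed.
from mhcflurry.class1_allele_specific import load

# The module-level helpers now delegate to a cached loader keyed on the models
# directory path, so a changed downloads directory invalidates the cache
# (the bug this commit fixes).
print(load.supported_alleles())                 # list of allele names with models
predictor = load.from_allele_name("HLA-A0201")  # illustrative allele name

# A loader can also be pointed at an explicit directory containing
# production.csv and models/ (path is illustrative).
loader = load.Class1AlleleSpecificPredictorLoader("/path/to/models")
predictor = loader.from_allele_name("HLA-A0201")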
@@ -13,23 +13,22 @@
 # limitations under the License.
 '''
 Load predictors
 '''
 import pickle
+from os.path import join
 
 import pandas
 
 from ..downloads import get_path
 from ..common import normalize_allele_name
 
-_ALLELE_PREDICTOR_CACHE = {}
-_PRODUCTION_MODELS_DATAFRAME = None
+CACHED_LOADER = None
 
 
 def from_allele_name(allele_name):
     """
-    Load a predictor for an allele.
+    Load a predictor for an allele using the default loader.
 
     Parameters
     ----------
@@ -39,46 +38,85 @@ def from_allele_name(allele_name):
     ----------
     Class1BindingPredictor
     """
-    global _ALLELE_PREDICTOR_CACHE
-    allele_name = normalize_allele_name(allele_name)
-    if allele_name in _ALLELE_PREDICTOR_CACHE:
-        return _ALLELE_PREDICTOR_CACHE[allele_name]
-
-    models_df = production_models_dataframe()
-    try:
-        predictor_name = models_df.ix[allele_name].predictor_name
-    except KeyError:
-        raise ValueError(
-            "No models for allele '%s'. Alleles with models: %s"
-            % (allele_name, ' '.join(supported_alleles())))
-    model_path = get_path(
-        "models_class1_allele_specific_single",
-        "models/%s.pickle" % predictor_name)
-    with open(model_path, 'rb') as fd:
-        predictor = pickle.load(fd)
-    _ALLELE_PREDICTOR_CACHE[allele_name] = predictor
-    return predictor
+    return get_loader_for_downloaded_models().from_allele_name(allele_name)
 
 
 def supported_alleles():
     """
     Return a list of the names of the alleles for which there are trained
-    predictors.
+    predictors in the default loader.
     """
-    return list(sorted(production_models_dataframe().allele))
+    return get_loader_for_downloaded_models().supported_alleles
+
+
+def get_loader_for_downloaded_models():
+    """
+    Return a Class1AlleleSpecificPredictorLoader that uses downloaded models.
+    """
+    global CACHED_LOADER
+
+    # Some of the unit tests manipulate the downloads directory configuration
+    # so get_path here may return different results in the same Python process.
+    # For this reason we check the path and invalidate the loader if it's
+    # different.
+    path = get_path("models_class1_allele_specific_single")
+    if CACHED_LOADER is None or path != CACHED_LOADER.path:
+        CACHED_LOADER = Class1AlleleSpecificPredictorLoader(path)
+    return CACHED_LOADER
 
 
-def production_models_dataframe():
+class Class1AlleleSpecificPredictorLoader(object):
     """
-    Return a pandas.DataFrame describing the currently available trained
-    predictors.
+    Factory for Class1BindingPredictor instances that are stored on disk
+    using this directory structure:
+
+    production.csv - Manifest file giving information on all models
+
+    models/ - directory of models with names given in the manifest file
+        MODEL-BAR.pickle
+        MODEL-FOO.pickle
+        ...
     """
-    global _PRODUCTION_MODELS_DATAFRAME
-    if _PRODUCTION_MODELS_DATAFRAME is None:
-        _PRODUCTION_MODELS_DATAFRAME = pandas.read_csv(
-            get_path("models_class1_allele_specific_single", "production.csv"))
-        _PRODUCTION_MODELS_DATAFRAME.index = (
-            _PRODUCTION_MODELS_DATAFRAME.allele)
-    return _PRODUCTION_MODELS_DATAFRAME
+    def __init__(self, path):
+        """
+        Parameters
+        ----------
+        path : string
+            Path to directory containing manifest and models
+        """
+        self.path = path
+        self.path_to_models_csv = join(path, "production.csv")
+        self.df = pandas.read_csv(self.path_to_models_csv)
+        self.df.index = self.df["allele"]
+        self.supported_alleles = list(sorted(self.df.allele))
+        self.predictors_cache = {}
+
+    def from_allele_name(self, allele_name):
+        """
+        Load a predictor for an allele.
+
+        Parameters
+        ----------
+        allele_name : class I allele name
+
+        Returns
+        ----------
+        Class1BindingPredictor
+        """
+        allele_name = normalize_allele_name(allele_name)
+        if allele_name not in self.predictors_cache:
+            try:
+                predictor_name = self.df.ix[allele_name].predictor_name
+            except KeyError:
+                raise ValueError(
+                    "No models for allele '%s'. Alleles with models: %s"
+                    " in models file: %s" % (
+                        allele_name,
+                        ' '.join(self.supported_alleles),
+                        self.path_to_models_csv))
+            model_path = join(self.path, "models", predictor_name + ".pickle")
+            with open(model_path, 'rb') as fd:
+                self.predictors_cache[allele_name] = pickle.load(fd)
+        return self.predictors_cache[allele_name]
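A sketch of the on-disk layout Class1AlleleSpecificPredictorLoader reads, following __init__ and from_allele_name above. The directory path, allele, and predictor name are illustrative, a placeholder object stands in for a trained Class1BindingPredictor, and the class is assumed to be importable from this module.

import os
import pickle
import pandas

models_dir = "/tmp/example-models"  # illustrative path
os.makedirs(os.path.join(models_dir, "models"), exist_ok=True)

# production.csv: manifest with at least 'allele' and 'predictor_name' columns,
# as read by Class1AlleleSpecificPredictorLoader.__init__.
pandas.DataFrame({
    "allele": ["HLA-A0201"],
    "predictor_name": ["MODEL-FOO"],
}).to_csv(os.path.join(models_dir, "production.csv"), index=False)

# models/<predictor_name>.pickle: one pickled predictor per manifest row.
with open(os.path.join(models_dir, "models", "MODEL-FOO.pickle"), "wb") as fd:
    pickle.dump({"placeholder": "predictor"}, fd)  # stand-in for a real predictor

# The loader reads the manifest once, then unpickles and caches each predictor
# on the first request for its allele.
loader = Class1AlleleSpecificPredictorLoader(models_dir)
print(loader.supported_alleles)            # ['HLA-A0201']
predictor = loader.from_allele_name("HLA-A0201")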