diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py index 3eb2bfb69c855c07a16082ec60b813988ba64fbe..fb19b09f87144973872305b716b4d195c6a42393 100644 --- a/mhcflurry/class1_affinity_predictor.py +++ b/mhcflurry/class1_affinity_predictor.py @@ -18,7 +18,7 @@ from six import string_types from .class1_neural_network import Class1NeuralNetwork from .common import random_peptides -from .downloads import get_path +from .downloads import get_default_class1_models_dir from .encodable_sequences import EncodableSequences from .percent_rank_transform import PercentRankTransform from .regression_target import to_ic50 @@ -283,7 +283,7 @@ class Class1AffinityPredictor(object): `Class1AffinityPredictor` instance """ if models_dir is None: - models_dir = get_path("models_class1", "models") + models_dir = get_default_class1_models_dir() manifest_path = join(models_dir, "manifest.csv") manifest_df = pandas.read_csv(manifest_path, nrows=max_models) diff --git a/mhcflurry/downloads.py b/mhcflurry/downloads.py index e88675ed00b78718f0687e28ed4911473593e13c..b3b7e9d3b78d421af685841f370b5a99885f595c 100644 --- a/mhcflurry/downloads.py +++ b/mhcflurry/downloads.py @@ -9,7 +9,7 @@ from __future__ import ( ) import logging import yaml -from os.path import join, exists +from os.path import join, exists, relpath from pipes import quote from os import environ from collections import OrderedDict @@ -20,11 +20,14 @@ ENVIRONMENT_VARIABLES = [ "MHCFLURRY_DATA_DIR", "MHCFLURRY_DOWNLOADS_CURRENT_RELEASE", "MHCFLURRY_DOWNLOADS_DIR", + "MHCFLURRY_DEFAULT_CLASS1_MODELS_DIR" ] _DOWNLOADS_DIR = None _CURRENT_RELEASE = None _METADATA = None +_MHCFLURRY_DEFAULT_CLASS1_MODELS_DIR = environ.get( + "MHCFLURRY_DEFAULT_CLASS1_MODELS_DIR") def get_downloads_dir(): @@ -51,6 +54,37 @@ def get_downloads_metadata(): return _METADATA +def get_default_class1_models_dir(test_exists=True): + """ + Return the absolute path to the default class1 models dir. 
+ + If environment variable MHCFLURRY_DEFAULT_CLASS1_MODELS_DIR is set to an + absolute path, return that path. If it's set to a relative path (i.e. does + not start with /) then return that path taken to be relative to the mhcflurry + downloads dir. + + If environment variable MHCFLURRY_DEFAULT_CLASS1_MODELS_DIR is NOT set, + then return the path to downloaded models in the "models_class1" download. + + Parameters + ---------- + + test_exists : boolean, optional + Whether to raise an exception if the path does not exist + + Returns + ------- + string : absolute path + """ + if _MHCFLURRY_DEFAULT_CLASS1_MODELS_DIR: + result = join(get_downloads_dir(), _MHCFLURRY_DEFAULT_CLASS1_MODELS_DIR) + if test_exists and not exists(result): + raise IOError("No such directory: %s" % result) + return result + else: + return get_path("models_class1", "models", test_exists=test_exists) + + def get_current_release_downloads(): """ Return a dict of all available downloads in the current release. diff --git a/mhcflurry/downloads_command.py b/mhcflurry/downloads_command.py index 08648a99e8eb6d2c03967cfd3ebb0bd12bbdaab8..3af2fa3ecf47ad71778a35844372dfa65c2a046a 100644 --- a/mhcflurry/downloads_command.py +++ b/mhcflurry/downloads_command.py @@ -229,15 +229,13 @@ def info_subcommand(args): downloads = get_current_release_downloads() - format_string = "%-40s %-12s %-12s %-20s " - print(format_string % ( - "DOWNLOAD NAME", "DOWNLOADED?", "DEFAULT?", "URL")) + format_string = "%-40s %-12s %-20s " + print(format_string % ("DOWNLOAD NAME", "DOWNLOADED?", "URL")) for (item, info) in downloads.items(): print(format_string % ( item, yes_no(info['downloaded']), - yes_no(info['metadata']['default']), info['metadata']["url"])) diff --git a/mhcflurry/predict_command.py b/mhcflurry/predict_command.py index ab07cfac25b77f74a16a4e93aea4a93ce95eb942..aea742d3928254f24a392e3d6156817333de3d33 100644 --- a/mhcflurry/predict_command.py +++ b/mhcflurry/predict_command.py @@ -33,7 +33,7 @@ import logging 
import pandas -from .downloads import get_path +from .downloads import get_default_class1_models_dir from .class1_affinity_predictor import Class1AffinityPredictor @@ -122,7 +122,7 @@ model_args.add_argument( metavar="DIR", default=None, help="Directory containing models. " - "Default: %s" % get_path("models_class1", "models", test_exists=False)) + "Default: %s" % get_default_class1_models_dir(test_exists=False)) model_args.add_argument( "--include-individual-model-predictions", action="store_true", @@ -143,7 +143,7 @@ def run(argv=sys.argv[1:]): # The reason we set the default here instead of in the argument parser is that # we want to test_exists at this point, so the user gets a message instructing # them to download the models if needed. - models_dir = get_path("models_class1", "models") + models_dir = get_default_class1_models_dir(test_exists=True) predictor = Class1AffinityPredictor.load(models_dir) # The following two are informative commands that can come