From b3840422fb6196a5b59d073c5cb29e4814c1a456 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Mon, 5 Feb 2018 16:32:13 -0500 Subject: [PATCH] better handling of default class1 models configuration --- mhcflurry/class1_affinity_predictor.py | 4 +-- mhcflurry/downloads.py | 36 +++++++++++++++++++++++++- mhcflurry/downloads_command.py | 6 ++--- mhcflurry/predict_command.py | 6 ++--- 4 files changed, 42 insertions(+), 10 deletions(-) diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py index 3eb2bfb6..fb19b09f 100644 --- a/mhcflurry/class1_affinity_predictor.py +++ b/mhcflurry/class1_affinity_predictor.py @@ -18,7 +18,7 @@ from six import string_types from .class1_neural_network import Class1NeuralNetwork from .common import random_peptides -from .downloads import get_path +from .downloads import get_default_class1_models_dir from .encodable_sequences import EncodableSequences from .percent_rank_transform import PercentRankTransform from .regression_target import to_ic50 @@ -283,7 +283,7 @@ class Class1AffinityPredictor(object): `Class1AffinityPredictor` instance """ if models_dir is None: - models_dir = get_path("models_class1", "models") + models_dir = get_default_class1_models_dir() manifest_path = join(models_dir, "manifest.csv") manifest_df = pandas.read_csv(manifest_path, nrows=max_models) diff --git a/mhcflurry/downloads.py b/mhcflurry/downloads.py index e88675ed..b3b7e9d3 100644 --- a/mhcflurry/downloads.py +++ b/mhcflurry/downloads.py @@ -9,7 +9,7 @@ from __future__ import ( ) import logging import yaml -from os.path import join, exists +from os.path import join, exists, relpath from pipes import quote from os import environ from collections import OrderedDict @@ -20,11 +20,14 @@ ENVIRONMENT_VARIABLES = [ "MHCFLURRY_DATA_DIR", "MHCFLURRY_DOWNLOADS_CURRENT_RELEASE", "MHCFLURRY_DOWNLOADS_DIR", + "MHCFLURRY_DEFAULT_CLASS1_MODELS_DIR" ] _DOWNLOADS_DIR = None _CURRENT_RELEASE = None _METADATA = None
+_MHCFLURRY_DEFAULT_CLASS1_MODELS_DIR = environ.get( + "MHCFLURRY_DEFAULT_CLASS1_MODELS_DIR") def get_downloads_dir(): @@ -51,6 +54,37 @@ def get_downloads_metadata(): return _METADATA +def get_default_class1_models_dir(test_exists=True): + """ + Return the absolute path to the default class1 models dir. + + If environment variable MHCFLURRY_DEFAULT_CLASS1_MODELS_DIR is set to an + absolute path, return that path. If it's set to a relative path (i.e. does + not start with /) then return that path taken to be relative to the mhcflurry + downloads dir. + + If environment variable MHCFLURRY_DEFAULT_CLASS1_MODELS_DIR is NOT set, + then return the path to downloaded models in the "models_class1" download. + + Parameters + ---------- + + test_exists : boolean, optional + Whether to raise an exception if the path does not exist + + Returns + ------- + string : absolute path + """ + if _MHCFLURRY_DEFAULT_CLASS1_MODELS_DIR: + result = join(get_downloads_dir(), _MHCFLURRY_DEFAULT_CLASS1_MODELS_DIR) + if test_exists and not exists(result): + raise IOError("No such directory: %s" % result) + return result + else: + return get_path("models_class1", "models", test_exists=test_exists) + + def get_current_release_downloads(): """ Return a dict of all available downloads in the current release.
diff --git a/mhcflurry/downloads_command.py b/mhcflurry/downloads_command.py index 08648a99..3af2fa3e 100644 --- a/mhcflurry/downloads_command.py +++ b/mhcflurry/downloads_command.py @@ -229,15 +229,13 @@ def info_subcommand(args): downloads = get_current_release_downloads() - format_string = "%-40s %-12s %-12s %-20s " - print(format_string % ( - "DOWNLOAD NAME", "DOWNLOADED?", "DEFAULT?", "URL")) + format_string = "%-40s %-12s %-20s " + print(format_string % ("DOWNLOAD NAME", "DOWNLOADED?", "URL")) for (item, info) in downloads.items(): print(format_string % ( item, yes_no(info['downloaded']), - yes_no(info['metadata']['default']), info['metadata']["url"])) diff --git a/mhcflurry/predict_command.py b/mhcflurry/predict_command.py index ab07cfac..aea742d3 100644 --- a/mhcflurry/predict_command.py +++ b/mhcflurry/predict_command.py @@ -33,7 +33,7 @@ import logging import pandas -from .downloads import get_path +from .downloads import get_default_class1_models_dir from .class1_affinity_predictor import Class1AffinityPredictor @@ -122,7 +122,7 @@ model_args.add_argument( metavar="DIR", default=None, help="Directory containing models. " - "Default: %s" % get_path("models_class1", "models", test_exists=False)) + "Default: %s" % get_default_class1_models_dir(test_exists=False)) model_args.add_argument( "--include-individual-model-predictions", action="store_true", @@ -143,7 +143,7 @@ def run(argv=sys.argv[1:]): # The reason we set the default here instead of in the argument parser is that # we want to test_exists at this point, so the user gets a message instructing # them to download the models if needed. - models_dir = get_path("models_class1", "models") + models_dir = get_default_class1_models_dir(test_exists=True) predictor = Class1AffinityPredictor.load(models_dir) # The following two are informative commands that can come -- GitLab