Skip to content
Snippets Groups Projects
Commit 6c839283 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

Add --threads option to mhcflurry-predict. Closes #135

parent 387346ac
No related merge requests found
......@@ -12,7 +12,7 @@ import pandas
from . import amino_acid
def set_keras_backend(backend=None, gpu_device_nums=None):
def set_keras_backend(backend=None, gpu_device_nums=None, num_threads=None):
"""
Configure Keras backend to use GPU or CPU. Only tensorflow is supported.
......@@ -24,6 +24,9 @@ def set_keras_backend(backend=None, gpu_device_nums=None):
gpu_device_nums : list of int, optional
GPU devices to potentially use
num_threads : int, optional
Tensorflow threads to use
"""
os.environ["KERAS_BACKEND"] = "tensorflow"
......@@ -49,9 +52,11 @@ def set_keras_backend(backend=None, gpu_device_nums=None):
import tensorflow
from keras import backend as K
config = tensorflow.ConfigProto(
device_count=device_count)
config.gpu_options.allow_growth=True
config = tensorflow.ConfigProto(device_count=device_count)
config.gpu_options.allow_growth = True
if num_threads:
config.inter_op_parallelism_threads = num_threads
config.intra_op_parallelism_threads = num_threads
session = tensorflow.Session(config=config)
K.set_session(session)
......
......@@ -136,7 +136,7 @@ def make_worker_pool(
issue we add a second 'backup queue'. This queue always contains the
full set of initializer arguments: whenever a worker reads from it, it
always pushes the pop'd args back to the end of the queue immediately.
If the primary arg queue is every empty, then workers will read
If the primary arg queue is ever empty, then workers will read
from this backup queue.
Parameters
......
......@@ -33,6 +33,7 @@ import logging
import pandas
from .common import set_keras_backend
from .downloads import get_default_class1_models_dir
from .class1_affinity_predictor import Class1AffinityPredictor
from .version import __version__
......@@ -68,7 +69,7 @@ helper_args.add_argument(
version="mhcflurry %s" % __version__,
)
input_args = parser.add_argument_group(title="Required input arguments")
input_args = parser.add_argument_group(title="Input (required)")
input_args.add_argument(
"input",
metavar="INPUT.csv",
......@@ -86,7 +87,7 @@ input_args.add_argument(
help="Peptides to predict (exclusive with --input)")
input_mod_args = parser.add_argument_group(title="Optional input modifiers")
input_mod_args = parser.add_argument_group(title="Input options")
input_mod_args.add_argument(
"--allele-column",
metavar="NAME",
......@@ -104,7 +105,7 @@ input_mod_args.add_argument(
help="Return NaNs for unsupported alleles or peptides instead of raising")
output_args = parser.add_argument_group(title="Optional output modifiers")
output_args = parser.add_argument_group(title="Output options")
output_args.add_argument(
"--out",
metavar="OUTPUT.csv",
......@@ -119,26 +120,39 @@ output_args.add_argument(
metavar="CHAR",
default=",",
help="Delimiter character for results. Default: '%(default)s'")
output_args.add_argument(
"--include-individual-model-predictions",
action="store_true",
default=False,
help="Include predictions from each model in the ensemble"
)
model_args = parser.add_argument_group(title="Optional model settings")
model_args = parser.add_argument_group(title="Model options")
model_args.add_argument(
"--models",
metavar="DIR",
default=None,
help="Directory containing models. "
"Default: %s" % get_default_class1_models_dir(test_exists=False))
model_args.add_argument(
"--include-individual-model-predictions",
action="store_true",
default=False,
help="Include predictions from each model in the ensemble"
)
implementation_args = parser.add_argument_group(title="Implementation options")
implementation_args.add_argument(
"--backend",
choices=("tensorflow-gpu", "tensorflow-cpu", "tensorflow-default"),
help="Keras backend. If not specified will use system default.")
implementation_args.add_argument(
"--threads",
metavar="N",
type=int,
help="Num threads for tensorflow to use. If unspecified, tensorflow will "
"pick a value based on the number of cores.")
def run(argv=sys.argv[1:]):
args = parser.parse_args(argv)
set_keras_backend(backend=args.backend, num_threads=args.threads)
# It's hard to pass a tab in a shell, so we correct a common error:
if args.output_delimiter == "\\t":
args.output_delimiter = "\t"
......
__version__ = "1.2.2"
__version__ = "1.2.3"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment