From 6c8392837a35c3132be07e9c74baab041c2308ac Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Sat, 26 Jan 2019 17:39:11 -0500
Subject: [PATCH] Add --threads option to mhcflurry-predict. Closes #135

---
 mhcflurry/common.py          | 13 +++++++++----
 mhcflurry/parallelism.py     |  2 +-
 mhcflurry/predict_command.py | 36 +++++++++++++++++++++++++-----------
 mhcflurry/version.py         |  2 +-
 4 files changed, 36 insertions(+), 17 deletions(-)

diff --git a/mhcflurry/common.py b/mhcflurry/common.py
index a2da79dc..6b19251b 100644
--- a/mhcflurry/common.py
+++ b/mhcflurry/common.py
@@ -12,7 +12,7 @@ import pandas
 from . import amino_acid
 
 
-def set_keras_backend(backend=None, gpu_device_nums=None):
+def set_keras_backend(backend=None, gpu_device_nums=None, num_threads=None):
     """
     Configure Keras backend to use GPU or CPU. Only tensorflow is supported.
 
@@ -24,6 +24,9 @@ def set_keras_backend(backend=None, gpu_device_nums=None):
     gpu_device_nums : list of int, optional
         GPU devices to potentially use
 
+    num_threads : int, optional
+        Tensorflow threads to use
+
     """
     os.environ["KERAS_BACKEND"] = "tensorflow"
 
@@ -49,9 +52,11 @@ def set_keras_backend(backend=None, gpu_device_nums=None):
 
     import tensorflow
     from keras import backend as K
-    config = tensorflow.ConfigProto(
-        device_count=device_count)
-    config.gpu_options.allow_growth=True 
+    config = tensorflow.ConfigProto(device_count=device_count)
+    config.gpu_options.allow_growth = True
+    if num_threads:
+        config.inter_op_parallelism_threads = num_threads
+        config.intra_op_parallelism_threads = num_threads
     session = tensorflow.Session(config=config)
     K.set_session(session)
 
diff --git a/mhcflurry/parallelism.py b/mhcflurry/parallelism.py
index ccf9d005..b21af5c6 100644
--- a/mhcflurry/parallelism.py
+++ b/mhcflurry/parallelism.py
@@ -136,7 +136,7 @@ def make_worker_pool(
         issue we add a second 'backup queue'. This queue always contains the
         full set of initializer arguments: whenever a worker reads from it, it
         always pushes the pop'd args back to the end of the queue immediately.
-        If the primary arg queue is every empty, then workers will read
+        If the primary arg queue is ever empty, then workers will read
         from this backup queue.
 
     Parameters
diff --git a/mhcflurry/predict_command.py b/mhcflurry/predict_command.py
index 61f64e4f..0a8e5c97 100644
--- a/mhcflurry/predict_command.py
+++ b/mhcflurry/predict_command.py
@@ -33,6 +33,7 @@ import logging
 
 import pandas
 
+from .common import set_keras_backend
 from .downloads import get_default_class1_models_dir
 from .class1_affinity_predictor import Class1AffinityPredictor
 from .version import __version__
@@ -68,7 +69,7 @@ helper_args.add_argument(
     version="mhcflurry %s" % __version__,
 )
 
-input_args = parser.add_argument_group(title="Required input arguments")
+input_args = parser.add_argument_group(title="Input (required)")
 input_args.add_argument(
     "input",
     metavar="INPUT.csv",
@@ -86,7 +87,7 @@ input_args.add_argument(
     help="Peptides to predict (exclusive with --input)")
 
 
-input_mod_args = parser.add_argument_group(title="Optional input modifiers")
+input_mod_args = parser.add_argument_group(title="Input options")
 input_mod_args.add_argument(
     "--allele-column",
     metavar="NAME",
@@ -104,7 +105,7 @@ input_mod_args.add_argument(
     help="Return NaNs for unsupported alleles or peptides instead of raising")
 
 
-output_args = parser.add_argument_group(title="Optional output modifiers")
+output_args = parser.add_argument_group(title="Output options")
 output_args.add_argument(
     "--out",
     metavar="OUTPUT.csv",
@@ -119,26 +120,39 @@ output_args.add_argument(
     metavar="CHAR",
     default=",",
     help="Delimiter character for results. Default: '%(default)s'")
+output_args.add_argument(
+    "--include-individual-model-predictions",
+    action="store_true",
+    default=False,
+    help="Include predictions from each model in the ensemble"
+)
 
-
-model_args = parser.add_argument_group(title="Optional model settings")
+model_args = parser.add_argument_group(title="Model options")
 model_args.add_argument(
     "--models",
     metavar="DIR",
     default=None,
     help="Directory containing models. "
     "Default: %s" % get_default_class1_models_dir(test_exists=False))
-model_args.add_argument(
-    "--include-individual-model-predictions",
-    action="store_true",
-    default=False,
-    help="Include predictions from each model in the ensemble"
-)
+
+implementation_args = parser.add_argument_group(title="Implementation options")
+implementation_args.add_argument(
+    "--backend",
+    choices=("tensorflow-gpu", "tensorflow-cpu", "tensorflow-default"),
+    help="Keras backend. If not specified will use system default.")
+implementation_args.add_argument(
+    "--threads",
+    metavar="N",
+    type=int,
+    help="Num threads for tensorflow to use. If unspecified, tensorflow will "
+    "pick a value based on the number of cores.")
 
 
 def run(argv=sys.argv[1:]):
     args = parser.parse_args(argv)
 
+    set_keras_backend(backend=args.backend, num_threads=args.threads)
+
     # It's hard to pass a tab in a shell, so we correct a common error:
     if args.output_delimiter == "\\t":
         args.output_delimiter = "\t"
diff --git a/mhcflurry/version.py b/mhcflurry/version.py
index bc86c944..10aa336c 100644
--- a/mhcflurry/version.py
+++ b/mhcflurry/version.py
@@ -1 +1 @@
-__version__ = "1.2.2"
+__version__ = "1.2.3"
-- 
GitLab