diff --git a/downloads-generation/models_class1/generate_hyperparameters.py b/downloads-generation/models_class1/generate_hyperparameters.py
index c9f546460060f9d3beab947cfe5a394ee3e79e9c..e958e101639acc5bc1a4524d6f73de22e3528a98 100644
--- a/downloads-generation/models_class1/generate_hyperparameters.py
+++ b/downloads-generation/models_class1/generate_hyperparameters.py
@@ -64,8 +64,8 @@ base_hyperparameters = {
 
 grid = []
 for train_subset in ["all", "quantitative"]:
-    for minibatch_size in [128]:
-        for dense_layer_size in [8, 16, 32, 64]:
+    for minibatch_size in [128, 512]:
+        for dense_layer_size in [8, 16, 32, 64, 128]:
             for l1 in [0.0, 0.001]:
                 for num_lc in [0, 1, 2]:
                     for lc_kernel_size in [3, 5]:
diff --git a/mhcflurry/common.py b/mhcflurry/common.py
index 3bd2390cf2d77240aeb8626abd3007b37911fffa..4b6f6d56bc1e312cadeeb247c8f1205c4572f5a5 100644
--- a/mhcflurry/common.py
+++ b/mhcflurry/common.py
@@ -32,7 +32,7 @@ def set_keras_backend(backend=None, gpu_device_nums=None):
         os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
             [str(i) for i in gpu_device_nums])
 
-    if backend == "tensorflow-cpu":
+    if backend == "tensorflow-cpu" or not gpu_device_nums:
         print("Forcing tensorflow/CPU backend.")
         os.environ["CUDA_VISIBLE_DEVICES"] = ""
         device_count = {'CPU': 1, 'GPU': 0}
diff --git a/mhcflurry/parallelism.py b/mhcflurry/parallelism.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcce96e5cf64352d065a6c03c62c44837508232e
--- /dev/null
+++ b/mhcflurry/parallelism.py
@@ -0,0 +1,100 @@
+from multiprocessing import Pool, Queue, cpu_count
+from queue import Empty
+from multiprocessing.util import Finalize
+from pprint import pprint
+
+
+def make_worker_pool(
+        processes=None,
+        initializer=None,
+        initializer_kwargs_per_process=None,
+        max_tasks_per_worker=None):
+    """
+    Convenience wrapper to create a multiprocessing.Pool.
+
+    This function adds support for per-worker initializer arguments, which are
+    not natively supported by the multiprocessing module. The motivation for
+    this feature is to support allocating each worker to a (different) GPU.
+
+    IMPLEMENTATION NOTE:
+        The per-worker initializer arguments are implemented using a Queue. Each
+        worker reads its arguments from this queue when it starts. When it
+        terminates, it adds its initializer arguments back to the queue, so a
+        future process can initialize itself using these arguments.
+
+        There is one issue with this approach, however. If a worker crashes, it
+        never repopulates the queue of initializer arguments. This will prevent
+        any future worker from re-using those arguments. To deal with this
+        issue we add a second 'backup queue'. This queue always contains the
+        full set of initializer arguments: whenever a worker reads from it, it
+        always pushes the popped args back to the end of the queue immediately.
+        If the primary arg queue is ever empty, then workers will read
+        from this backup queue.
+
+    Parameters
+    ----------
+    processes : int
+        Number of workers. Default: num CPUs.
+
+    initializer : function, optional
+        Init function to call in each worker
+
+    initializer_kwargs_per_process : list of dict, optional
+        Arguments to pass to initializer function for each worker. Length of
+        list must equal the number of workers.
+
+    max_tasks_per_worker : int, optional
+        Restart workers after this many tasks. Requires Python >=3.2.
+
+    Returns
+    -------
+    multiprocessing.Pool
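+
+    Example
+    -------
+    A minimal sketch (hypothetical usage; assumes set_keras_backend from
+    mhcflurry.common as the per-worker initializer, one GPU per worker)::
+
+        pool = make_worker_pool(
+            processes=2,
+            initializer=set_keras_backend,
+            initializer_kwargs_per_process=[
+                {"gpu_device_nums": [0]},
+                {"gpu_device_nums": [1]}])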
+    """
+
+    if not processes:
+        processes = cpu_count()
+
+    pool_kwargs = {
+        'processes': processes,
+    }
+    if max_tasks_per_worker:
+        pool_kwargs["maxtasksperchild"] = max_tasks_per_worker
+
+    if initializer:
+        if initializer_kwargs_per_process:
+            assert len(initializer_kwargs_per_process) == processes
+            kwargs_queue = Queue()
+            kwargs_queue_backup = Queue()
+            for kwargs in initializer_kwargs_per_process:
+                kwargs_queue.put(kwargs)
+                kwargs_queue_backup.put(kwargs)
+            pool_kwargs["initializer"] = worker_init_entry_point
+            pool_kwargs["initargs"] = (
+                initializer, kwargs_queue, kwargs_queue_backup)
+        else:
+            pool_kwargs["initializer"] = initializer
+
+    worker_pool = Pool(**pool_kwargs)
+    print("Started pool: %s" % str(worker_pool))
+    pprint(pool_kwargs)
+    return worker_pool
+
+
+def worker_init_entry_point(
+        init_function, arg_queue=None, backup_arg_queue=None):
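+    """
+    Pool initializer that pops per-worker kwargs from `arg_queue` (or, if it
+    is empty, from `backup_arg_queue`) and calls `init_function` with them.
+    """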
+    kwargs = {}
+    if arg_queue:
+        try:
+            kwargs = arg_queue.get(block=False)
+        except Empty:
+            print("Argument queue empty. Using round robin arg queue.")
+            kwargs = backup_arg_queue.get(block=True)
+            backup_arg_queue.put(kwargs)
+
+        # On exit we add the init args back to the queue so restarted workers
+        # (e.g. when running with maxtasksperchild) will pick up init
+        # arguments from a previously exited worker.
+        Finalize(None, arg_queue.put, (kwargs,), exitpriority=1)
+
+    print("Initializing worker: %s" % str(kwargs))
+    init_function(**kwargs)
diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index 98dfc408c7fe0871d3feab2a2681a3dd7ea8430b..eac3f9c3a9fabfeb51a42e3a584b87305e17a383 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -8,11 +8,7 @@ import sys
 import time
 import traceback
 import random
-from multiprocessing import Pool, Queue, cpu_count
-from queue import Empty
-from multiprocessing.util import Finalize
 from functools import partial
-from pprint import pprint
 
 import pandas
 import yaml
@@ -22,6 +18,7 @@ tqdm.monitor_interval = 0  # see https://github.com/tqdm/tqdm/issues/481
 
 from .class1_affinity_predictor import Class1AffinityPredictor
 from .common import configure_logging, set_keras_backend
+from .parallelism import make_worker_pool, cpu_count
 
 
 # To avoid pickling large matrices to send to child processes when running in
@@ -142,82 +139,6 @@ parser.add_argument(
     "leaks. Requires Python >=3.2.")
 
 
-def make_worker_pool(
-        processes=None,
-        initializer=None,
-        initializer_kwargs_per_process=None,
-        max_tasks_per_worker=None):
-    """
-    Convenience wrapper to create a multiprocessing.Pool.
-
-    Parameters
-    ----------
-    processes : int
-        Number of workers. Default: num CPUs.
-
-    initializer : function, optional
-        Init function to call in each worker
-
-    initializer_kwargs_per_process : list of dict, optional
-        Arguments to pass to initializer function for each worker. Length of
-        list must equal the number of workers.
-
-    max_tasks_per_worker : int, optional
-        Restart workers after this many tasks. Requires Python >=3.2.
-
-    Returns
-    -------
-    multiprocessing.Pool
-    """
-
-    if not processes:
-        processes = cpu_count()
-
-    pool_kwargs = {
-        'processes': processes,
-    }
-    if max_tasks_per_worker:
-        pool_kwargs["maxtasksperchild"] = max_tasks_per_worker
-
-    if initializer:
-        if initializer_kwargs_per_process:
-            assert len(initializer_kwargs_per_process) == processes
-            kwargs_queue = Queue()
-            kwargs_queue2 = Queue()
-            for kwargs in initializer_kwargs_per_process:
-                kwargs_queue.put(kwargs)
-                kwargs_queue2.put(kwargs)
-            pool_kwargs["initializer"] = worker_init_entry_point
-            pool_kwargs["initargs"] = (initializer, kwargs_queue, kwargs_queue2)
-        else:
-            pool_kwargs["initializer"] = initializer
-
-    worker_pool = Pool(**pool_kwargs)
-    print("Started pool: %s" % str(worker_pool))
-    pprint(pool_kwargs)
-    return worker_pool
-
-
-def worker_init_entry_point(
-        init_function, arg_queue=None, round_robin_arg_queue=None):
-    kwargs = {}
-    if arg_queue:
-        try:
-            kwargs = arg_queue.get(block=False)
-        except Empty:
-            print("Argument queue empty. Using round robin arg queue.")
-            kwargs = round_robin_arg_queue.get(block=True)
-            round_robin_arg_queue.put(kwargs)
-
-        # On exit we add the init args back to the queue so restarted workers
-        # (e.g. when when running with maxtasksperchild) will pickup init
-        # arguments from a previously exited worker.
-        Finalize(None, arg_queue.put, (kwargs,), exitpriority=1)
-
-    print("Initializing worker: %s" % str(kwargs))
-    init_function(**kwargs)
-
-
 def run(argv=sys.argv[1:]):
     global GLOBAL_DATA