diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index 7d42c9faa8745c4307f9438a2945342d2a02c107..c9b018b210ba4aa436e8b192a75f8878dadbf37a 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -197,6 +197,9 @@ def run(argv=sys.argv[1:]):
         worker_init_args = None
         if args.gpus:
             print("Attempting to round-robin assign each worker a GPU.")
+            if args.backend != "tensorflow-default":
+                print("Forcing keras backend to be tensorflow-default")
+                args.backend = "tensorflow-default"
 
             gpu_assignments_remaining = dict((
                 (gpu, args.max_workers_per_gpu) for gpu in range(args.gpus)
@@ -211,6 +214,7 @@ def run(argv=sys.argv[1:]):
                     gpu_assignments_remaining[gpu_num] -= 1
                     if not gpu_assignments_remaining[gpu_num]:
                         del gpu_assignments_remaining[gpu_num]
+                    gpu_assignment = [gpu_num]
                 else:
                     # Use CPU
                     gpu_assignment = []
@@ -219,9 +223,10 @@ def run(argv=sys.argv[1:]):
                     'gpu_device_nums': gpu_assignment,
                     'keras_backend': args.backend
                 })
+                print("Worker %d assigned GPUs: %s" % (worker_num, gpu_assignment))
 
         worker_pool = Pool(
-            initializer=worker_init,
+            initializer=worker_init_entrypoint,
             initargs=(worker_init_args,),
             processes=num_workers)
         print("Started pool of %d workers: %s" % (num_workers, str(worker_pool)))
@@ -449,17 +454,16 @@ def calibrate_percentile_ranks(allele, predictor, peptides=None):
 
 def worker_init_entrypoint(arg_queue):
     if arg_queue:
-        (args, kwargs) = arg_queue.get()
+        kwargs = arg_queue.get()
     else:
-        args = []
         kwargs = {}
-    worker_init(*args, **kwargs)
+    worker_init(**kwargs)
 
 
 def worker_init(keras_backend=None, gpu_device_nums=None):
     if keras_backend or gpu_device_nums:
         print("WORKER pid=%d assigned GPU devices: %s" % (
-            os.getpgid()), gpu_device_nums)
+            os.getpid(), gpu_device_nums))
         set_keras_backend(
             keras_backend, gpu_device_nums=gpu_device_nums)