diff --git a/mhcflurry/parallelism.py b/mhcflurry/parallelism.py index b21af5c64f2c72bd13336161fefa8604fa0edd3c..55d7c450aa2af5614486d4765bb422038df4f096 100644 --- a/mhcflurry/parallelism.py +++ b/mhcflurry/parallelism.py @@ -21,7 +21,7 @@ def add_worker_pool_args(parser): type=int, metavar="N", help="Number of processes to parallelize training over. Experimental. " - "Set to 1 for serial run. Set to 0 to use number of cores. Default: %(default)s.") + "Set to 0 for serial run. Default: %(default)s.") group.add_argument( "--backend", choices=("tensorflow-gpu", "tensorflow-cpu", "tensorflow-default"), @@ -67,7 +67,7 @@ def worker_pool_with_gpu_assignments( num_workers = num_jobs if num_jobs else cpu_count() - if num_workers == 1: + if num_workers == 0: if backend: set_keras_backend(backend) return None diff --git a/test/test_train_pan_allele_models_command.py b/test/test_train_pan_allele_models_command.py index d2c303886e9906ecf2230ed3b5de598e62b6feef..f4142a1d869187298d039a1968a6bbe080d2df7b 100644 --- a/test/test_train_pan_allele_models_command.py +++ b/test/test_train_pan_allele_models_command.py @@ -146,11 +146,12 @@ def run_and_check(n_jobs=0): if os.environ.get("KERAS_BACKEND") != "theano": def test_run_parallel(): + run_and_check(n_jobs=1) run_and_check(n_jobs=2) def test_run_serial(): - run_and_check(n_jobs=1) + run_and_check(n_jobs=0) if __name__ == "__main__": test_run_serial() \ No newline at end of file