diff --git a/mhcflurry/class1_allele_specific/cv_and_train_command.py b/mhcflurry/class1_allele_specific/cv_and_train_command.py index 075e8e3761c5421efc3f7f6d2f30ed30d2e8b565..b1256ac0dc586d8b6dd56eeaf3ec7d7af5fdb5bc 100644 --- a/mhcflurry/class1_allele_specific/cv_and_train_command.py +++ b/mhcflurry/class1_allele_specific/cv_and_train_command.py @@ -169,6 +169,12 @@ parser.add_argument( default=False, help="Output more info") +try: + import kubeface + kubeface.Client.add_args(parser) +except ImportError: + logging.error("Kubeface support disabled, not installed.") + def run(argv=sys.argv[1:]): args = parser.parse_args(argv) @@ -183,6 +189,8 @@ def run(argv=sys.argv[1:]): if args.dask_scheduler: backend = parallelism.DaskDistributedParallelBackend( args.dask_scheduler) + elif hasattr(args, 'storage_prefix') and args.storage_prefix: + backend = parallelism.KubefaceParallelBackend(args) else: if args.num_local_processes: backend = parallelism.ConcurrentFuturesParallelBackend( diff --git a/mhcflurry/parallelism.py b/mhcflurry/parallelism.py index 18008b4e9057c3b4e0c3b9c6d1a2db81522beb5f..4d0c47acabc88248244b54e5c0640755af44c706 100644 --- a/mhcflurry/parallelism.py +++ b/mhcflurry/parallelism.py @@ -36,6 +36,21 @@ class ParallelBackend(object): return [result_dict[future] for future in fs] +class KubefaceParallelBackend(ParallelBackend): + """ + ParallelBackend that uses kubeface + """ + def __init__(self, args): + from kubeface import Client # pylint: disable=import-error + self.client = Client.from_args(args) + + def map(self, func, iterable): + return self.client.map(func, iterable) + + def __str__(self): + return "<Kubeface backend, client=%s>" % self.client + + class DaskDistributedParallelBackend(ParallelBackend): """ ParallelBackend that uses dask.distributed