diff --git a/mhcflurry/class1_allele_specific/cross_validation.py b/mhcflurry/class1_allele_specific/cross_validation.py index 4d3dc41a4f7312e132c39ff30364ea1058621ff6..184bae580ea4ed0c87efd9f0f9b9720a7a66dee1 100644 --- a/mhcflurry/class1_allele_specific/cross_validation.py +++ b/mhcflurry/class1_allele_specific/cross_validation.py @@ -192,4 +192,4 @@ def cross_validation_folds( for (result_fold, imputation_result) in zip( result_folds, imputation_results) ] - return result_fold + return result_folds diff --git a/mhcflurry/parallelism.py b/mhcflurry/parallelism.py index c82b40e0bce3d51bf795fe019afa1d633004a247..faecc871469210058ebb11f3323ae0a1ccdc20ce 100644 --- a/mhcflurry/parallelism.py +++ b/mhcflurry/parallelism.py @@ -14,6 +14,7 @@ class ParallelBackend(object): self.module = module self.verbose = verbose + class KubefaceParallelBackend(ParallelBackend): """ ParallelBackend that uses kubeface diff --git a/test/test_class1_allele_specific_cv_and_train_command.py b/test/test_class1_allele_specific_cv_and_train_command.py index f6e52d87a4aa73ba6555f4d0c86c493b58a48a85..9f0c193bb2330ecfb0aa6438fd67cd67d61d43a6 100644 --- a/test/test_class1_allele_specific_cv_and_train_command.py +++ b/test/test_class1_allele_specific_cv_and_train_command.py @@ -61,6 +61,7 @@ def test_small_run(): "--alleles", "HLA-A0201", "HLA-A0301", "--verbose", "--num-local-threads", "1", + "--storage-prefix", "/tmp/", ] print("Running cv_and_train_command with args: %s " % str(args)) diff --git a/test/test_cross_validation.py b/test/test_cross_validation.py index cf95333b81c993ef0282af397abd85658f737d8b..c0d4296e815dad534d8cb8f46b799dd0343f9f69 100644 --- a/test/test_cross_validation.py +++ b/test/test_cross_validation.py @@ -77,7 +77,7 @@ def test_cross_validation_with_imputation(): n_imputations=2, n_burn_in=1, n_nearest_columns=25) train_data = ( mhcflurry.dataset.Dataset.from_csv( - get_path("data_kim2014" , "bdata.2009.mhci.public.1.txt")) + get_path("data_kim2014", "bdata.2009.mhci.public.1.txt")) .get_alleles(["HLA-A0201", "HLA-A0202", "HLA-A0301"])) folds = cross_validation_folds(