diff --git a/mhcflurry/select_allele_specific_models_command.py b/mhcflurry/select_allele_specific_models_command.py index 2f35d92bba4c62f055ad0af820a671126b036024..660f1d8e06f92516e7d91152e18bbaee9233a75f 100644 --- a/mhcflurry/select_allele_specific_models_command.py +++ b/mhcflurry/select_allele_specific_models_command.py @@ -442,13 +442,13 @@ def model_select(allele): unselected_score_function = ( unselected_accuracy_scorer.score_function(allele)) - unselected_score = unselected_score_function(predictor) - scrambled_predictor = ScrambledPredictor(predictor) additional_metadata = {} + unselected_score = unselected_score_function( + predictor, additional_metadata_out=additional_metadata) + scrambled_predictor = ScrambledPredictor(predictor) scrambled_scores = numpy.array([ unselected_score_function( - scrambled_predictor, - additional_metadata_out=additional_metadata) + scrambled_predictor) for _ in range(unselected_accuracy_scorer_samples) ]) unselected_score_scrambled_mean = scrambled_scores.mean() @@ -779,8 +779,7 @@ class MassSpecModelSelector(object): # We additionally compute AUC score. additional_metadata_out["score_mass_spec_AUC"] = roc_auc_score( - self.df[allele].values, - -1 * predictions) + self.df[allele].values, -1 * predictions) return ppv * multiplier summary = "mass-spec (%d hits / %d decoys)" % (total_hits, total_decoys) diff --git a/test/test_train_and_related_commands.py b/test/test_train_and_related_commands.py index 409f8337c3a40572cf57e42f04aadb875c069f8d..a2077858de020600d75363927ead5058e2cff5ce 100644 --- a/test/test_train_and_related_commands.py +++ b/test/test_train_and_related_commands.py @@ -105,7 +105,7 @@ def run_and_check_with_model_selection(n_jobs=1): deepcopy(HYPERPARAMETERS[0]), deepcopy(HYPERPARAMETERS[0]), ] - hyperparameters[-1]["max_epochs"] = 0 + hyperparameters[-1]["max_epochs"] = 10 with open(hyperparameters_filename, "w") as fd: json.dump(hyperparameters, fd) @@ -153,9 +153,9 @@ def run_and_check_with_model_selection(n_jobs=1): result.allele_to_allele_specific_models["HLA-A*03:01"][ 0].hyperparameters["max_epochs"], 500) - #print("Deleting: %s" % models_dir1) - #print("Deleting: %s" % models_dir2) - #shutil.rmtree(models_dir1) + print("Deleting: %s" % models_dir1) + print("Deleting: %s" % models_dir2) + shutil.rmtree(models_dir1) if os.environ.get("KERAS_BACKEND") != "theano":