diff --git a/downloads-generation/data_mass_spec_benchmark/run_mhcflurry.py b/downloads-generation/data_mass_spec_benchmark/run_mhcflurry.py index 1b4253ae3b78ba151d5c8966f124a6f7888809c4..804cd531d0ecdcd97b548737d89641fb812899e9 100644 --- a/downloads-generation/data_mass_spec_benchmark/run_mhcflurry.py +++ b/downloads-generation/data_mass_spec_benchmark/run_mhcflurry.py @@ -230,22 +230,14 @@ def run(argv=sys.argv[1:]): prediction_time / 60.0)) -def do_predictions(chunk_index, peptides, alleles, constant_data=GLOBAL_DATA): - return predict_for_allele( - chunk_index, - peptides, - alleles, - predictor=constant_data['predictor'], - **constant_data["args"]) - - -def predict_for_allele( - chunk_index, - peptides, - alleles, - predictor, - verbose=False, - model_kwargs={}): +def do_predictions(chunk_index, peptides, alleles, constant_data=None): + if constant_data is None: + constant_data = GLOBAL_DATA + + predictor = constant_data['predictor'] + verbose = constant_data['args'].get("verbose", False) + model_kwargs = constant_data['args'].get("model_kwargs", {}) + predictor.optimize(warn=False) # since we loaded with optimization_level=0 start = time.time() results = {}