diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 73405fdbf6e6a27e060873512cd5503ba735c30b..043a5c0c5704533dd7939e26572be9ee56a7bd41 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -125,6 +125,13 @@ class Class1NeuralNetwork(object):
     (Keras model, existing network weights)
     """
 
+    @classmethod
+    def clear_model_cache(klass):
+        """
+        Clear the Keras model cache.
+        """
+        klass.KERAS_MODELS_CACHE.clear()
+
     @classmethod
     def borrow_cached_network(klass, network_json, network_weights):
         """
diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index 3d72b32a1a50ea5df3b0abf4533ce74dd05f394a..c2b6ed5dcb702d91e471b87a46708f326af6ce20 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -17,6 +17,7 @@ from mhcnames import normalize_allele_name
 import tqdm  # progress bar
 
 from .class1_affinity_predictor import Class1AffinityPredictor
+from .class1_neural_network import Class1NeuralNetwork
 from .common import configure_logging, set_keras_backend
 
 
@@ -281,10 +282,11 @@ def run(argv=sys.argv[1:]):
         # Store peptides in global variable so they are in shared memory
         # after fork, instead of needing to be pickled.
         GLOBAL_DATA["calibration_peptides"] = encoded_peptides
+        Class1NeuralNetwork.clear_model_cache()
         worker_pool = Pool(
             processes=(
                 args.calibration_num_jobs
-                if args.train_num_jobs else None))
+                if args.calibration_num_jobs else None))
         print("Using worker pool: %s" % str(worker_pool))
         results = worker_pool.imap_unordered(
             partial(
@@ -376,6 +378,7 @@ def calibrate_percentile_ranks(allele, predictor, peptides=None):
     """
     Private helper function.
     """
+    global GLOBAL_DATA
     if peptides is None:
         peptides = GLOBAL_DATA["calibration_peptides"]
     predictor.calibrate_percentile_ranks(
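
For context on the pattern this change applies: `Class1NeuralNetwork.clear_model_cache()` empties a class-level, process-wide cache right before `run()` forks a multiprocessing pool, so the children do not inherit cached Keras model objects through the fork. Below is a minimal, self-contained sketch of that pattern, not mhcflurry's actual implementation; `PredictorSketch` and `_work` are hypothetical stand-ins for `Class1NeuralNetwork` and the per-allele calibration task, and only the cache attribute and classmethod names mirror the diff.

```python
# Sketch: a process-wide, class-level cache that is cleared before forking
# worker processes, so children start without inherited Keras/TF state.
# PredictorSketch and _work are hypothetical; the cache attribute and the
# clear_model_cache classmethod follow the names used in the diff.
from multiprocessing import Pool


class PredictorSketch(object):
    # Process-wide cache: architecture JSON string -> model object.
    KERAS_MODELS_CACHE = {}

    @classmethod
    def clear_model_cache(klass):
        """Clear the process-wide model cache."""
        klass.KERAS_MODELS_CACHE.clear()


def _work(item):
    # Stand-in for the per-allele calibration done by the worker pool.
    return item * 2


if __name__ == "__main__":
    # Simulate a model cached during training in the parent process.
    PredictorSketch.KERAS_MODELS_CACHE["arch-json"] = object()

    # Empty the cache before fork, as the diff does just before Pool(...).
    PredictorSketch.clear_model_cache()

    with Pool(processes=2) as pool:
        print(list(pool.imap_unordered(_work, range(4))))
```

The other one-line change in `run()` is an independent bug fix: the calibration pool size was gated on `args.train_num_jobs` rather than `args.calibration_num_jobs`, so the intended "default to cpu_count when unset" behavior keyed off the wrong argument.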