From 5e96719fdcb3ac9add109e588a605b727f9537bf Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Sun, 28 Jan 2018 15:34:45 -0500 Subject: [PATCH] attempt fix race condition --- mhcflurry/class1_neural_network.py | 7 +++++++ mhcflurry/train_allele_specific_models_command.py | 5 ++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py index 73405fdb..043a5c0c 100644 --- a/mhcflurry/class1_neural_network.py +++ b/mhcflurry/class1_neural_network.py @@ -125,6 +125,13 @@ class Class1NeuralNetwork(object): (Keras model, existing network weights) """ + @classmethod + def clear_model_cache(klass): + """ + Clear the Keras model cache. + """ + klass.KERAS_MODELS_CACHE.clear() + @classmethod def borrow_cached_network(klass, network_json, network_weights): """ diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py index 3d72b32a..c2b6ed5d 100644 --- a/mhcflurry/train_allele_specific_models_command.py +++ b/mhcflurry/train_allele_specific_models_command.py @@ -17,6 +17,7 @@ from mhcnames import normalize_allele_name import tqdm # progress bar from .class1_affinity_predictor import Class1AffinityPredictor +from .class1_neural_network import Class1NeuralNetwork from .common import configure_logging, set_keras_backend @@ -281,10 +282,11 @@ def run(argv=sys.argv[1:]): # Store peptides in global variable so they are in shared memory # after fork, instead of needing to be pickled. GLOBAL_DATA["calibration_peptides"] = encoded_peptides + Class1NeuralNetwork.clear_model_cache() worker_pool = Pool( processes=( args.calibration_num_jobs - if args.train_num_jobs else None)) + if args.calibration_num_jobs else None)) print("Using worker pool: %s" % str(worker_pool)) results = worker_pool.imap_unordered( partial( @@ -376,6 +378,7 @@ def calibrate_percentile_ranks(allele, predictor, peptides=None): """ Private helper function. """ + global GLOBAL_DATA if peptides is None: peptides = GLOBAL_DATA["calibration_peptides"] predictor.calibrate_percentile_ranks( -- GitLab