From 5e96719fdcb3ac9add109e588a605b727f9537bf Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Sun, 28 Jan 2018 15:34:45 -0500
Subject: [PATCH] Attempt to fix race condition

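Clear the Keras model cache (Class1NeuralNetwork.KERAS_MODELS_CACHE)
immediately before forking the percentile rank calibration worker
pool, so the forked workers do not start out holding references to
models compiled in the parent process. Also fix the pool size
condition, which tested args.train_num_jobs where
args.calibration_num_jobs was intended.

A minimal sketch of the fork-inheritance pattern being avoided (only
multiprocessing is real; the other names are illustrative):

    import multiprocessing

    CACHE = {}  # module-level cache; forked children inherit it

    def worker(key):
        # A forked child sees whatever the parent had in CACHE at
        # fork time, unless the parent cleared it first.
        return key in CACHE

    if __name__ == "__main__":
        CACHE["model"] = object()  # parent populates the cache
        CACHE.clear()              # drop shared state before forking
        with multiprocessing.Pool(processes=2) as pool:
            print(pool.map(worker, ["model"]))  # -> [False]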
---
 mhcflurry/class1_neural_network.py                | 7 +++++++
 mhcflurry/train_allele_specific_models_command.py | 6 +++++-
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 73405fdb..043a5c0c 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -125,6 +125,13 @@ class Class1NeuralNetwork(object):
     (Keras model, existing network weights)
     """
 
+    @classmethod
+    def clear_model_cache(klass):
+        """
+        Clear the Keras model cache.
+        """
+        klass.KERAS_MODELS_CACHE.clear()
+
     @classmethod
     def borrow_cached_network(klass, network_json, network_weights):
         """
diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index 3d72b32a..c2b6ed5d 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -17,6 +17,7 @@ from mhcnames import normalize_allele_name
 import tqdm  # progress bar
 
 from .class1_affinity_predictor import Class1AffinityPredictor
+from .class1_neural_network import Class1NeuralNetwork
 from .common import configure_logging, set_keras_backend
 
 
@@ -281,10 +282,12 @@ def run(argv=sys.argv[1:]):
             # Store peptides in global variable so they are in shared memory
             # after fork, instead of needing to be pickled.
             GLOBAL_DATA["calibration_peptides"] = encoded_peptides
+            Class1NeuralNetwork.clear_model_cache()
             worker_pool = Pool(
                 processes=(
                     args.calibration_num_jobs
-                    if args.train_num_jobs else None))
+                    if args.calibration_num_jobs else None))
             print("Using worker pool: %s" % str(worker_pool))
             results = worker_pool.imap_unordered(
                 partial(
@@ -376,6 +379,7 @@ def calibrate_percentile_ranks(allele, predictor, peptides=None):
     """
     Private helper function.
     """
+    global GLOBAL_DATA  # populated in the parent process before workers fork
     if peptides is None:
         peptides = GLOBAL_DATA["calibration_peptides"]
     predictor.calibrate_percentile_ranks(
-- 
GitLab