From f22f80cff9941cec7ed25db78ee9834ad0cb7144 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Thu, 8 Feb 2018 11:56:58 -0500
Subject: [PATCH] Performance enhancement for faster start time: reuse compiled
 networks even if their regularizations differ. This is safe because
 regularization should not affect predictions

---
 mhcflurry/class1_neural_network.py  | 30 ++++++++++++++++++++++++++---
 mhcflurry/percent_rank_transform.py |  2 +-
 test/test_download_models_class1.py | 23 +++++++++++++++++++++-
 3 files changed, 50 insertions(+), 5 deletions(-)

diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 936e7511..06d363dc 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -1,6 +1,7 @@
 import time
 import collections
 import logging
+import json
 
 import numpy
 import pandas
@@ -190,17 +191,26 @@ class Class1NeuralNetwork(object):
         keras.models.Model
         """
         assert network_weights is not None
-        if network_json not in klass.KERAS_MODELS_CACHE:
+        key = klass.keras_network_cache_key(network_json)
+        if key not in klass.KERAS_MODELS_CACHE:
             # Cache miss.
             import keras.models
             network = keras.models.model_from_json(network_json)
             existing_weights = None
         else:
             # Cache hit.
-            (network, existing_weights) = klass.KERAS_MODELS_CACHE[network_json]
+            (network, existing_weights) = klass.KERAS_MODELS_CACHE[key]
         if existing_weights is not network_weights:
             network.set_weights(network_weights)
-            klass.KERAS_MODELS_CACHE[network_json] = (network, network_weights)
+            klass.KERAS_MODELS_CACHE[key] = (network, network_weights)
+
+
+        # As an added safety check we overwrite the fit method on the returned
+        # model to throw an error if it is called.
+        def throw(*args, **kwargs):
+            raise NotImplementedError("Do not call fit on cached model.")
+
+        network.fit = throw
         return network
 
     def network(self, borrow=False):
@@ -237,6 +247,20 @@ class Class1NeuralNetwork(object):
             self.network_json = self._network.to_json()
             self.network_weights = self._network.get_weights()
 
+    @staticmethod
+    def keras_network_cache_key(network_json):
+        # As an optimization, we remove anything about regularization as these
+        # do not affect predictions.
+        def drop_properties(d):
+            if 'kernel_regularizer' in d:
+                del d['kernel_regularizer']
+            return d
+
+        description = json.loads(
+            network_json,
+            object_hook=drop_properties)
+        return json.dumps(description)
+
     def get_config(self):
         """
         serialize to a dict all attributes except model weights
diff --git a/mhcflurry/percent_rank_transform.py b/mhcflurry/percent_rank_transform.py
index 7054736d..6f42477d 100644
--- a/mhcflurry/percent_rank_transform.py
+++ b/mhcflurry/percent_rank_transform.py
@@ -41,7 +41,7 @@ class PercentRankTransform(object):
         indices = numpy.searchsorted(self.bin_edges, values)
         result = self.cdf[indices]
         assert len(result) == len(values)
-        return result
+        return numpy.minimum(result, 100.0)
 
     def to_series(self):
         """
diff --git a/test/test_download_models_class1.py b/test/test_download_models_class1.py
index 7c0b6e08..1dec5633 100644
--- a/test/test_download_models_class1.py
+++ b/test/test_download_models_class1.py
@@ -1,7 +1,9 @@
 import numpy
 numpy.random.seed(0)
 
-from mhcflurry import Class1AffinityPredictor
+from numpy.testing import assert_equal
+
+from mhcflurry import Class1AffinityPredictor, Class1NeuralNetwork
 
 
 DOWNLOADED_PREDICTOR = Class1AffinityPredictor.load()
@@ -46,3 +48,22 @@ def test_a2_hiv_epitope_downloaded_models():
     #    The HIV-1 HLA-A2-SLYNTVATL Is a Help-Independent CTL Epitope
     predict_and_check("HLA-A*02:01", "SLYNTVATL")
 
+def test_caching():
+    Class1NeuralNetwork.KERAS_MODELS_CACHE.clear()
+    DOWNLOADED_PREDICTOR.predict(
+        peptides=["SIINFEKL"],
+        allele="HLA-A*02:01")
+    num_cached = len(Class1NeuralNetwork.KERAS_MODELS_CACHE)
+
+    # A new allele should leave the same number of models cached (under the
+    # current scheme in which all alelles use the same architectures).
+    DOWNLOADED_PREDICTOR.predict(
+        peptides=["SIINFEKL"],
+        allele="HLA-A*03:01")
+    print("Cached networks: %d" % len(Class1NeuralNetwork.KERAS_MODELS_CACHE))
+    assert_equal(num_cached, len(Class1NeuralNetwork.KERAS_MODELS_CACHE))
+
+
+
+
+
-- 
GitLab