From f22f80cff9941cec7ed25db78ee9834ad0cb7144 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Thu, 8 Feb 2018 11:56:58 -0500 Subject: [PATCH] Performance enhancement for faster start time: reuse compiled networks even if their regularizations differ. This is safe because regularization should not affect predictions --- mhcflurry/class1_neural_network.py | 30 ++++++++++++++++++++++++++--- mhcflurry/percent_rank_transform.py | 2 +- test/test_download_models_class1.py | 23 +++++++++++++++++++++- 3 files changed, 50 insertions(+), 5 deletions(-) diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py index 936e7511..06d363dc 100644 --- a/mhcflurry/class1_neural_network.py +++ b/mhcflurry/class1_neural_network.py @@ -1,6 +1,7 @@ import time import collections import logging +import json import numpy import pandas @@ -190,17 +191,26 @@ class Class1NeuralNetwork(object): keras.models.Model """ assert network_weights is not None - if network_json not in klass.KERAS_MODELS_CACHE: + key = klass.keras_network_cache_key(network_json) + if key not in klass.KERAS_MODELS_CACHE: # Cache miss. import keras.models network = keras.models.model_from_json(network_json) existing_weights = None else: # Cache hit. - (network, existing_weights) = klass.KERAS_MODELS_CACHE[network_json] + (network, existing_weights) = klass.KERAS_MODELS_CACHE[key] if existing_weights is not network_weights: network.set_weights(network_weights) - klass.KERAS_MODELS_CACHE[network_json] = (network, network_weights) + klass.KERAS_MODELS_CACHE[key] = (network, network_weights) + + + # As an added safety check we overwrite the fit method on the returned + # model to throw an error if it is called. 
+ def throw(*args, **kwargs): + raise NotImplementedError("Do not call fit on cached model.") + + network.fit = throw return network def network(self, borrow=False): @@ -237,6 +247,20 @@ class Class1NeuralNetwork(object): self.network_json = self._network.to_json() self.network_weights = self._network.get_weights() + @staticmethod + def keras_network_cache_key(network_json): + # As an optimization, we remove anything about regularization as these + # do not affect predictions. + def drop_properties(d): + if 'kernel_regularizer' in d: + del d['kernel_regularizer'] + return d + + description = json.loads( + network_json, + object_hook=drop_properties) + return json.dumps(description) + def get_config(self): """ serialize to a dict all attributes except model weights diff --git a/mhcflurry/percent_rank_transform.py b/mhcflurry/percent_rank_transform.py index 7054736d..6f42477d 100644 --- a/mhcflurry/percent_rank_transform.py +++ b/mhcflurry/percent_rank_transform.py @@ -41,7 +41,7 @@ class PercentRankTransform(object): indices = numpy.searchsorted(self.bin_edges, values) result = self.cdf[indices] assert len(result) == len(values) - return result + return numpy.minimum(result, 100.0) def to_series(self): """ diff --git a/test/test_download_models_class1.py b/test/test_download_models_class1.py index 7c0b6e08..1dec5633 100644 --- a/test/test_download_models_class1.py +++ b/test/test_download_models_class1.py @@ -1,7 +1,9 @@ import numpy numpy.random.seed(0) -from mhcflurry import Class1AffinityPredictor +from numpy.testing import assert_equal + +from mhcflurry import Class1AffinityPredictor, Class1NeuralNetwork DOWNLOADED_PREDICTOR = Class1AffinityPredictor.load() @@ -46,3 +48,22 @@ def test_a2_hiv_epitope_downloaded_models(): # The HIV-1 HLA-A2-SLYNTVATL Is a Help-Independent CTL Epitope predict_and_check("HLA-A*02:01", "SLYNTVATL") +def test_caching(): + Class1NeuralNetwork.KERAS_MODELS_CACHE.clear() + DOWNLOADED_PREDICTOR.predict( + peptides=["SIINFEKL"], + 
allele="HLA-A*02:01") + num_cached = len(Class1NeuralNetwork.KERAS_MODELS_CACHE) + + # A new allele should leave the same number of models cached (under the + # current scheme in which all alleles use the same architectures). + DOWNLOADED_PREDICTOR.predict( + peptides=["SIINFEKL"], + allele="HLA-A*03:01") + print("Cached networks: %d" % len(Class1NeuralNetwork.KERAS_MODELS_CACHE)) + assert_equal(num_cached, len(Class1NeuralNetwork.KERAS_MODELS_CACHE)) + + + + + -- GitLab