Skip to content
Snippets Groups Projects
Commit f22f80cf authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

Performance enhancement for faster start time: reuse compiled networks even if...

Performance enhancement for faster start time: reuse compiled networks even if their regularizations differ. This is safe because regularization should not affect predictions
parent 4c0b193f
No related merge requests found
import time import time
import collections import collections
import logging import logging
import json
import numpy import numpy
import pandas import pandas
...@@ -190,17 +191,26 @@ class Class1NeuralNetwork(object): ...@@ -190,17 +191,26 @@ class Class1NeuralNetwork(object):
keras.models.Model keras.models.Model
""" """
assert network_weights is not None assert network_weights is not None
if network_json not in klass.KERAS_MODELS_CACHE: key = klass.keras_network_cache_key(network_json)
if key not in klass.KERAS_MODELS_CACHE:
# Cache miss. # Cache miss.
import keras.models import keras.models
network = keras.models.model_from_json(network_json) network = keras.models.model_from_json(network_json)
existing_weights = None existing_weights = None
else: else:
# Cache hit. # Cache hit.
(network, existing_weights) = klass.KERAS_MODELS_CACHE[network_json] (network, existing_weights) = klass.KERAS_MODELS_CACHE[key]
if existing_weights is not network_weights: if existing_weights is not network_weights:
network.set_weights(network_weights) network.set_weights(network_weights)
klass.KERAS_MODELS_CACHE[network_json] = (network, network_weights) klass.KERAS_MODELS_CACHE[key] = (network, network_weights)
# As an added safety check we overwrite the fit method on the returned
# model to throw an error if it is called.
def throw(*args, **kwargs):
raise NotImplementedError("Do not call fit on cached model.")
network.fit = throw
return network return network
def network(self, borrow=False): def network(self, borrow=False):
...@@ -237,6 +247,20 @@ class Class1NeuralNetwork(object): ...@@ -237,6 +247,20 @@ class Class1NeuralNetwork(object):
self.network_json = self._network.to_json() self.network_json = self._network.to_json()
self.network_weights = self._network.get_weights() self.network_weights = self._network.get_weights()
@staticmethod
def keras_network_cache_key(network_json):
# As an optimization, we remove anything about regularization as these
# do not affect predictions.
def drop_properties(d):
if 'kernel_regularizer' in d:
del d['kernel_regularizer']
return d
description = json.loads(
network_json,
object_hook=drop_properties)
return json.dumps(description)
def get_config(self): def get_config(self):
""" """
serialize to a dict all attributes except model weights serialize to a dict all attributes except model weights
......
...@@ -41,7 +41,7 @@ class PercentRankTransform(object): ...@@ -41,7 +41,7 @@ class PercentRankTransform(object):
indices = numpy.searchsorted(self.bin_edges, values) indices = numpy.searchsorted(self.bin_edges, values)
result = self.cdf[indices] result = self.cdf[indices]
assert len(result) == len(values) assert len(result) == len(values)
return result return numpy.minimum(result, 100.0)
def to_series(self): def to_series(self):
""" """
......
import numpy import numpy
numpy.random.seed(0) numpy.random.seed(0)
from mhcflurry import Class1AffinityPredictor from numpy.testing import assert_equal
from mhcflurry import Class1AffinityPredictor, Class1NeuralNetwork
DOWNLOADED_PREDICTOR = Class1AffinityPredictor.load() DOWNLOADED_PREDICTOR = Class1AffinityPredictor.load()
...@@ -46,3 +48,22 @@ def test_a2_hiv_epitope_downloaded_models(): ...@@ -46,3 +48,22 @@ def test_a2_hiv_epitope_downloaded_models():
# The HIV-1 HLA-A2-SLYNTVATL Is a Help-Independent CTL Epitope # The HIV-1 HLA-A2-SLYNTVATL Is a Help-Independent CTL Epitope
predict_and_check("HLA-A*02:01", "SLYNTVATL") predict_and_check("HLA-A*02:01", "SLYNTVATL")
def test_caching():
    """
    Predicting for a second allele must not grow the compiled-network cache.

    Under the current scheme all alleles use the same architectures, so the
    networks cached for the first allele should be reused for the second.
    """
    def cache_size():
        # Look the cache up afresh on every call rather than holding a
        # reference to it across predictions.
        return len(Class1NeuralNetwork.KERAS_MODELS_CACHE)

    Class1NeuralNetwork.KERAS_MODELS_CACHE.clear()

    # First prediction populates the cache from empty.
    DOWNLOADED_PREDICTOR.predict(peptides=["SIINFEKL"], allele="HLA-A*02:01")
    size_after_first_allele = cache_size()

    # A different allele should be served from the same cached networks.
    DOWNLOADED_PREDICTOR.predict(peptides=["SIINFEKL"], allele="HLA-A*03:01")
    size_after_second_allele = cache_size()

    print("Cached networks: %d" % size_after_second_allele)
    assert_equal(size_after_first_allele, size_after_second_allele)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment