Commit f22f80cf authored by Tim O'Donnell

Performance enhancement for faster start time: reuse compiled networks even if their regularizations differ. This is safe because regularization should not affect predictions.
parent 4c0b193f
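The safety claim in the commit message is straightforward to check: a kernel regularizer only adds a penalty term to the training loss, so two models with identical weights but different regularization settings produce identical outputs from predict. A minimal standalone sketch of that check, assuming a TensorFlow/Keras installation (toy layer sizes and names; not mhcflurry code):

import numpy as np
from tensorflow import keras

def make_model(regularizer=None):
    # Tiny dense model; only the kernel_regularizer differs between variants.
    inputs = keras.Input(shape=(8,))
    outputs = keras.layers.Dense(3, kernel_regularizer=regularizer)(inputs)
    return keras.Model(inputs, outputs)

plain = make_model(None)
regularized = make_model(keras.regularizers.l2(0.01))
regularized.set_weights(plain.get_weights())  # same weights, different regularization

x = np.random.rand(4, 8).astype("float32")
# The forward pass ignores the regularizer, so the predictions match.
assert np.allclose(plain.predict(x), regularized.predict(x))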
 import time
 import collections
 import logging
+import json
 import numpy
 import pandas
@@ -190,17 +191,26 @@ class Class1NeuralNetwork(object):
         keras.models.Model
         """
         assert network_weights is not None
-        if network_json not in klass.KERAS_MODELS_CACHE:
+        key = klass.keras_network_cache_key(network_json)
+        if key not in klass.KERAS_MODELS_CACHE:
             # Cache miss.
             import keras.models
             network = keras.models.model_from_json(network_json)
             existing_weights = None
         else:
             # Cache hit.
-            (network, existing_weights) = klass.KERAS_MODELS_CACHE[network_json]
+            (network, existing_weights) = klass.KERAS_MODELS_CACHE[key]
         if existing_weights is not network_weights:
             network.set_weights(network_weights)
-            klass.KERAS_MODELS_CACHE[network_json] = (network, network_weights)
+            klass.KERAS_MODELS_CACHE[key] = (network, network_weights)
+
+        # As an added safety check we overwrite the fit method on the returned
+        # model to throw an error if it is called.
+        def throw(*args, **kwargs):
+            raise NotImplementedError("Do not call fit on cached model.")
+        network.fit = throw
+
         return network

     def network(self, borrow=False):
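The reuse logic above amounts to: look up the architecture by its (regularization-stripped) key, rebuild the Keras model only on a cache miss, reapply weights only when the weights object itself has changed (an identity check, not an equality check), and overwrite fit so the shared cached instance cannot be trained by accident. A toy sketch of the same pattern with a stand-in network class (hypothetical names; not mhcflurry code):

class FakeNetwork(object):
    # Stand-in for a compiled Keras model.
    def set_weights(self, weights):
        self.weights = weights

CACHE = {}

def borrow(key, weights):
    if key not in CACHE:
        network, existing_weights = FakeNetwork(), None  # cache miss: build once
    else:
        network, existing_weights = CACHE[key]           # cache hit: reuse the instance
    if existing_weights is not weights:                  # identity, not equality
        network.set_weights(weights)
        CACHE[key] = (network, weights)
    def throw(*args, **kwargs):
        raise NotImplementedError("Do not call fit on cached model.")
    network.fit = throw                                  # guard the shared instance
    return network

w = [1.0, 2.0, 3.0]
assert borrow("arch-A", w) is borrow("arch-A", w)  # second call reuses the same object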
@@ -237,6 +247,20 @@ class Class1NeuralNetwork(object):
         self.network_json = self._network.to_json()
         self.network_weights = self._network.get_weights()

+    @staticmethod
+    def keras_network_cache_key(network_json):
+        # As an optimization, we remove anything about regularization as these
+        # do not affect predictions.
+        def drop_properties(d):
+            if 'kernel_regularizer' in d:
+                del d['kernel_regularizer']
+            return d
+
+        description = json.loads(
+            network_json,
+            object_hook=drop_properties)
+        return json.dumps(description)
+
     def get_config(self):
         """
         serialize to a dict all attributes except model weights
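The new keras_network_cache_key collapses architectures that differ only in their regularization configuration onto the same cache entry, which is what lets models trained with different regularizers share one compiled network. A standalone illustration of the json object_hook trick using made-up architecture JSON (not real Keras model JSON):

import json

def cache_key(network_json):
    def drop_properties(d):
        # The object_hook runs on every decoded dict, so the property is
        # stripped wherever it appears in the architecture.
        d.pop('kernel_regularizer', None)
        return d
    return json.dumps(json.loads(network_json, object_hook=drop_properties))

with_l1 = '{"layers": [{"units": 16, "kernel_regularizer": {"l1": 0.001}}]}'
without = '{"layers": [{"units": 16, "kernel_regularizer": null}]}'
assert cache_key(with_l1) == cache_key(without)  # one compiled network serves both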
@@ -41,7 +41,7 @@ class PercentRankTransform(object):
         indices = numpy.searchsorted(self.bin_edges, values)
         result = self.cdf[indices]
         assert len(result) == len(values)
-        return result
+        return numpy.minimum(result, 100.0)

     def to_series(self):
         """
 import numpy
 numpy.random.seed(0)

-from mhcflurry import Class1AffinityPredictor
+from numpy.testing import assert_equal
+from mhcflurry import Class1AffinityPredictor, Class1NeuralNetwork

 DOWNLOADED_PREDICTOR = Class1AffinityPredictor.load()
@@ -46,3 +48,22 @@ def test_a2_hiv_epitope_downloaded_models():
     # The HIV-1 HLA-A2-SLYNTVATL Is a Help-Independent CTL Epitope
     predict_and_check("HLA-A*02:01", "SLYNTVATL")
+
+
+def test_caching():
+    Class1NeuralNetwork.KERAS_MODELS_CACHE.clear()
+    DOWNLOADED_PREDICTOR.predict(
+        peptides=["SIINFEKL"],
+        allele="HLA-A*02:01")
+    num_cached = len(Class1NeuralNetwork.KERAS_MODELS_CACHE)
+
+    # A new allele should leave the same number of models cached (under the
+    # current scheme in which all alleles use the same architectures).
+    DOWNLOADED_PREDICTOR.predict(
+        peptides=["SIINFEKL"],
+        allele="HLA-A*03:01")
+    print("Cached networks: %d" % len(Class1NeuralNetwork.KERAS_MODELS_CACHE))
+    assert_equal(num_cached, len(Class1NeuralNetwork.KERAS_MODELS_CACHE))