diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py index 936e75119f6974fb6b619ad4f57461a9a7e50339..06d363dc6e0baef7b5c022712dfab666c6e313b8 100644 --- a/mhcflurry/class1_neural_network.py +++ b/mhcflurry/class1_neural_network.py @@ -1,6 +1,7 @@ import time import collections import logging +import json import numpy import pandas @@ -190,17 +191,26 @@ class Class1NeuralNetwork(object): keras.models.Model """ assert network_weights is not None - if network_json not in klass.KERAS_MODELS_CACHE: + key = klass.keras_network_cache_key(network_json) + if key not in klass.KERAS_MODELS_CACHE: # Cache miss. import keras.models network = keras.models.model_from_json(network_json) existing_weights = None else: # Cache hit. - (network, existing_weights) = klass.KERAS_MODELS_CACHE[network_json] + (network, existing_weights) = klass.KERAS_MODELS_CACHE[key] if existing_weights is not network_weights: network.set_weights(network_weights) - klass.KERAS_MODELS_CACHE[network_json] = (network, network_weights) + klass.KERAS_MODELS_CACHE[key] = (network, network_weights) + + + # As an added safety check we overwrite the fit method on the returned + # model to throw an error if it is called. + def throw(*args, **kwargs): + raise NotImplementedError("Do not call fit on cached model.") + + network.fit = throw return network def network(self, borrow=False): @@ -237,6 +247,20 @@ class Class1NeuralNetwork(object): self.network_json = self._network.to_json() self.network_weights = self._network.get_weights() + @staticmethod + def keras_network_cache_key(network_json): + # As an optimization, we remove anything about regularization as these + # do not affect predictions. 
+ def drop_properties(d): + if 'kernel_regularizer' in d: + del d['kernel_regularizer'] + return d + + description = json.loads( + network_json, + object_hook=drop_properties) + return json.dumps(description) + def get_config(self): """ serialize to a dict all attributes except model weights diff --git a/mhcflurry/percent_rank_transform.py b/mhcflurry/percent_rank_transform.py index 7054736db6cfb48d903eb862d9697e77e2eb55d1..6f42477d56f94619e526d84bd3a8a5c5965caa45 100644 --- a/mhcflurry/percent_rank_transform.py +++ b/mhcflurry/percent_rank_transform.py @@ -41,7 +41,7 @@ class PercentRankTransform(object): indices = numpy.searchsorted(self.bin_edges, values) result = self.cdf[indices] assert len(result) == len(values) - return result + return numpy.minimum(result, 100.0) def to_series(self): """ diff --git a/test/test_download_models_class1.py b/test/test_download_models_class1.py index 7c0b6e082941ffc1246576fad28eaf82145a5da9..1dec5633c865d1dbea686b7c5bfcaa3a0364f9c5 100644 --- a/test/test_download_models_class1.py +++ b/test/test_download_models_class1.py @@ -1,7 +1,9 @@ import numpy numpy.random.seed(0) -from mhcflurry import Class1AffinityPredictor +from numpy.testing import assert_equal + +from mhcflurry import Class1AffinityPredictor, Class1NeuralNetwork DOWNLOADED_PREDICTOR = Class1AffinityPredictor.load() @@ -46,3 +48,22 @@ def test_a2_hiv_epitope_downloaded_models(): # The HIV-1 HLA-A2-SLYNTVATL Is a Help-Independent CTL Epitope predict_and_check("HLA-A*02:01", "SLYNTVATL") +def test_caching(): + Class1NeuralNetwork.KERAS_MODELS_CACHE.clear() + DOWNLOADED_PREDICTOR.predict( + peptides=["SIINFEKL"], + allele="HLA-A*02:01") + num_cached = len(Class1NeuralNetwork.KERAS_MODELS_CACHE) + + # A new allele should leave the same number of models cached (under the + # current scheme in which all alleles use the same architectures). 
+ DOWNLOADED_PREDICTOR.predict( + peptides=["SIINFEKL"], + allele="HLA-A*03:01") + print("Cached networks: %d" % len(Class1NeuralNetwork.KERAS_MODELS_CACHE)) + assert_equal(num_cached, len(Class1NeuralNetwork.KERAS_MODELS_CACHE)) + + + + +