From a940a110e51b66faba57d2fe1edb24e62b9c2f99 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Thu, 8 Feb 2018 18:45:34 -0500 Subject: [PATCH] Cache predictions in Class1NeuralNetwork under certain circumstances --- mhcflurry/class1_neural_network.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py index 06d363dc..6ce50ffc 100644 --- a/mhcflurry/class1_neural_network.py +++ b/mhcflurry/class1_neural_network.py @@ -2,6 +2,7 @@ import time import collections import logging import json +import weakref import numpy import pandas @@ -155,6 +156,8 @@ class Class1NeuralNetwork(object): self.fit_seconds = None self.fit_num_points = None + self.prediction_cache = weakref.WeakKeyDictionary() + KERAS_MODELS_CACHE = {} """ Process-wide keras model cache, a map from: architecture JSON string to @@ -274,6 +277,7 @@ class Class1NeuralNetwork(object): result['_network'] = None result['network_weights'] = None result['network_weights_loader'] = None + result['prediction_cache'] = None return result @classmethod @@ -299,6 +303,7 @@ class Class1NeuralNetwork(object): instance.__dict__.update(config) instance.network_weights = weights instance.network_weights_loader = weights_loader + instance.prediction_cache = weakref.WeakKeyDictionary() return instance def load_weights(self): @@ -338,6 +343,7 @@ class Class1NeuralNetwork(object): self.load_weights() result = dict(self.__dict__) result['_network'] = None + result['prediction_cache'] = None return result def peptides_to_network_input(self, peptides): @@ -704,13 +710,18 @@ class Class1NeuralNetwork(object): def predict(self, peptides, allele_encoding=None, batch_size=4096): """ - Predict affinities - + Predict affinities. + + If peptides are specified as EncodableSequences, then the predictions + will be cached for this predictor as long as the EncodableSequences object + remains in memory. 
The cache is keyed on the object identity of the + EncodableSequences, not the sequences themselves. + Parameters ---------- peptides : EncodableSequences or list of string - allele_pseudosequences : AlleleEncoding, optional + allele_encoding : AlleleEncoding, optional Only required when this model is a pan-allele model batch_size : int @@ -720,6 +731,12 @@ class Class1NeuralNetwork(object): ------- numpy.array of nM affinity predictions """ + use_cache = ( + allele_encoding is None and + isinstance(peptides, EncodableSequences)) + if use_cache and peptides in self.prediction_cache: + return self.prediction_cache[peptides].copy() + x_dict = { 'peptide': self.peptides_to_network_input(peptides) } @@ -730,7 +747,10 @@ class Class1NeuralNetwork(object): network = self.network(borrow=True) raw_predictions = network.predict(x_dict, batch_size=batch_size) predictions = numpy.array(raw_predictions, dtype = "float64")[:,0] - return to_ic50(predictions) + result = to_ic50(predictions) + if use_cache: + self.prediction_cache[peptides] = result + return result @staticmethod def make_network( -- GitLab