Skip to content
Snippets Groups Projects
Commit a940a110 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

Cache predictions in Class1NeuralNetwork under certain circumstances

parent 877c60a0
No related branches found
No related tags found
No related merge requests found
......@@ -2,6 +2,7 @@ import time
import collections
import logging
import json
import weakref
import numpy
import pandas
......@@ -155,6 +156,8 @@ class Class1NeuralNetwork(object):
self.fit_seconds = None
self.fit_num_points = None
self.prediction_cache = weakref.WeakKeyDictionary()
KERAS_MODELS_CACHE = {}
"""
Process-wide keras model cache, a map from: architecture JSON string to
......@@ -274,6 +277,7 @@ class Class1NeuralNetwork(object):
result['_network'] = None
result['network_weights'] = None
result['network_weights_loader'] = None
result['prediction_cache'] = None
return result
@classmethod
......@@ -299,6 +303,7 @@ class Class1NeuralNetwork(object):
instance.__dict__.update(config)
instance.network_weights = weights
instance.network_weights_loader = weights_loader
instance.prediction_cache = weakref.WeakKeyDictionary()
return instance
def load_weights(self):
......@@ -338,6 +343,7 @@ class Class1NeuralNetwork(object):
self.load_weights()
result = dict(self.__dict__)
result['_network'] = None
result['prediction_cache'] = None
return result
def peptides_to_network_input(self, peptides):
......@@ -704,13 +710,18 @@ class Class1NeuralNetwork(object):
def predict(self, peptides, allele_encoding=None, batch_size=4096):
"""
Predict affinities
Predict affinities.
If peptides are specified as EncodableSequences, then the predictions
will be cached for this predictor as long as the EncodableSequences object
remains in memory. The cache is keyed in the object identity of the
EncodableSequences, not the sequences themselves.
Parameters
----------
peptides : EncodableSequences or list of string
allele_pseudosequences : AlleleEncoding, optional
allele_encoding : AlleleEncoding, optional
Only required when this model is a pan-allele model
batch_size : int
......@@ -720,6 +731,12 @@ class Class1NeuralNetwork(object):
-------
numpy.array of nM affinity predictions
"""
use_cache = (
allele_encoding is None and
isinstance(peptides, EncodableSequences))
if use_cache and peptides in self.prediction_cache:
return self.prediction_cache[peptides].copy()
x_dict = {
'peptide': self.peptides_to_network_input(peptides)
}
......@@ -730,7 +747,10 @@ class Class1NeuralNetwork(object):
network = self.network(borrow=True)
raw_predictions = network.predict(x_dict, batch_size=batch_size)
predictions = numpy.array(raw_predictions, dtype = "float64")[:,0]
return to_ic50(predictions)
result = to_ic50(predictions)
if use_cache:
self.prediction_cache[peptides] = result
return result
@staticmethod
def make_network(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment