Skip to content
Snippets Groups Projects
Commit a940a110 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

Cache predictions in Class1NeuralNetwork under certain circumstances

parent 877c60a0
No related branches found
No related tags found
No related merge requests found
...@@ -2,6 +2,7 @@ import time ...@@ -2,6 +2,7 @@ import time
import collections import collections
import logging import logging
import json import json
import weakref
import numpy import numpy
import pandas import pandas
...@@ -155,6 +156,8 @@ class Class1NeuralNetwork(object): ...@@ -155,6 +156,8 @@ class Class1NeuralNetwork(object):
self.fit_seconds = None self.fit_seconds = None
self.fit_num_points = None self.fit_num_points = None
self.prediction_cache = weakref.WeakKeyDictionary()
KERAS_MODELS_CACHE = {} KERAS_MODELS_CACHE = {}
""" """
Process-wide keras model cache, a map from: architecture JSON string to Process-wide keras model cache, a map from: architecture JSON string to
...@@ -274,6 +277,7 @@ class Class1NeuralNetwork(object): ...@@ -274,6 +277,7 @@ class Class1NeuralNetwork(object):
result['_network'] = None result['_network'] = None
result['network_weights'] = None result['network_weights'] = None
result['network_weights_loader'] = None result['network_weights_loader'] = None
result['prediction_cache'] = None
return result return result
@classmethod @classmethod
...@@ -299,6 +303,7 @@ class Class1NeuralNetwork(object): ...@@ -299,6 +303,7 @@ class Class1NeuralNetwork(object):
instance.__dict__.update(config) instance.__dict__.update(config)
instance.network_weights = weights instance.network_weights = weights
instance.network_weights_loader = weights_loader instance.network_weights_loader = weights_loader
instance.prediction_cache = weakref.WeakKeyDictionary()
return instance return instance
def load_weights(self): def load_weights(self):
...@@ -338,6 +343,7 @@ class Class1NeuralNetwork(object): ...@@ -338,6 +343,7 @@ class Class1NeuralNetwork(object):
self.load_weights() self.load_weights()
result = dict(self.__dict__) result = dict(self.__dict__)
result['_network'] = None result['_network'] = None
result['prediction_cache'] = None
return result return result
def peptides_to_network_input(self, peptides): def peptides_to_network_input(self, peptides):
...@@ -704,13 +710,18 @@ class Class1NeuralNetwork(object): ...@@ -704,13 +710,18 @@ class Class1NeuralNetwork(object):
def predict(self, peptides, allele_encoding=None, batch_size=4096): def predict(self, peptides, allele_encoding=None, batch_size=4096):
""" """
Predict affinities Predict affinities.
If peptides are specified as EncodableSequences, then the predictions
will be cached for this predictor as long as the EncodableSequences object
remains in memory. The cache is keyed in the object identity of the
EncodableSequences, not the sequences themselves.
Parameters Parameters
---------- ----------
peptides : EncodableSequences or list of string peptides : EncodableSequences or list of string
allele_pseudosequences : AlleleEncoding, optional allele_encoding : AlleleEncoding, optional
Only required when this model is a pan-allele model Only required when this model is a pan-allele model
batch_size : int batch_size : int
...@@ -720,6 +731,12 @@ class Class1NeuralNetwork(object): ...@@ -720,6 +731,12 @@ class Class1NeuralNetwork(object):
------- -------
numpy.array of nM affinity predictions numpy.array of nM affinity predictions
""" """
use_cache = (
allele_encoding is None and
isinstance(peptides, EncodableSequences))
if use_cache and peptides in self.prediction_cache:
return self.prediction_cache[peptides].copy()
x_dict = { x_dict = {
'peptide': self.peptides_to_network_input(peptides) 'peptide': self.peptides_to_network_input(peptides)
} }
...@@ -730,7 +747,10 @@ class Class1NeuralNetwork(object): ...@@ -730,7 +747,10 @@ class Class1NeuralNetwork(object):
network = self.network(borrow=True) network = self.network(borrow=True)
raw_predictions = network.predict(x_dict, batch_size=batch_size) raw_predictions = network.predict(x_dict, batch_size=batch_size)
predictions = numpy.array(raw_predictions, dtype = "float64")[:,0] predictions = numpy.array(raw_predictions, dtype = "float64")[:,0]
return to_ic50(predictions) result = to_ic50(predictions)
if use_cache:
self.prediction_cache[peptides] = result
return result
@staticmethod @staticmethod
def make_network( def make_network(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment