From a940a110e51b66faba57d2fe1edb24e62b9c2f99 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Thu, 8 Feb 2018 18:45:34 -0500
Subject: [PATCH] Cache predictions in Class1NeuralNetwork under certain
 circumstances

---
 mhcflurry/class1_neural_network.py | 28 ++++++++++++++++++++++++----
 1 file changed, 24 insertions(+), 4 deletions(-)

diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 06d363dc..6ce50ffc 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -2,6 +2,7 @@ import time
 import collections
 import logging
 import json
+import weakref
 
 import numpy
 import pandas
@@ -155,6 +156,8 @@ class Class1NeuralNetwork(object):
         self.fit_seconds = None
         self.fit_num_points = None
 
+        self.prediction_cache = weakref.WeakKeyDictionary()
+
     KERAS_MODELS_CACHE = {}
     """
     Process-wide keras model cache, a map from: architecture JSON string to
@@ -274,6 +277,7 @@ class Class1NeuralNetwork(object):
         result['_network'] = None
         result['network_weights'] = None
         result['network_weights_loader'] = None
+        result['prediction_cache'] = None
         return result
 
     @classmethod
@@ -299,6 +303,7 @@ class Class1NeuralNetwork(object):
         instance.__dict__.update(config)
         instance.network_weights = weights
         instance.network_weights_loader = weights_loader
+        instance.prediction_cache = weakref.WeakKeyDictionary()
         return instance
 
     def load_weights(self):
@@ -338,6 +343,7 @@ class Class1NeuralNetwork(object):
         self.load_weights()
         result = dict(self.__dict__)
         result['_network'] = None
+        result['prediction_cache'] = None
         return result
 
     def peptides_to_network_input(self, peptides):
@@ -704,13 +710,18 @@ class Class1NeuralNetwork(object):
 
     def predict(self, peptides, allele_encoding=None, batch_size=4096):
         """
-        Predict affinities
-        
+        Predict affinities.
+
+        If peptides are specified as EncodableSequences and no allele_encoding
+        is given, the predictions are cached for this predictor as long as the
+        EncodableSequences object remains in memory. The cache is keyed on the
+        object identity of the EncodableSequences, not the sequences themselves.
+
         Parameters
         ----------
         peptides : EncodableSequences or list of string
         
-        allele_pseudosequences : AlleleEncoding, optional
+        allele_encoding : AlleleEncoding, optional
             Only required when this model is a pan-allele model
 
         batch_size : int
@@ -720,6 +731,12 @@ class Class1NeuralNetwork(object):
         -------
         numpy.array of nM affinity predictions 
         """
+        use_cache = (
+            allele_encoding is None and
+            isinstance(peptides, EncodableSequences))
+        if use_cache and peptides in self.prediction_cache:
+            return self.prediction_cache[peptides].copy()
+
         x_dict = {
             'peptide': self.peptides_to_network_input(peptides)
         }
@@ -730,7 +747,10 @@ class Class1NeuralNetwork(object):
         network = self.network(borrow=True)
         raw_predictions = network.predict(x_dict, batch_size=batch_size)
         predictions = numpy.array(raw_predictions, dtype = "float64")[:,0]
-        return to_ic50(predictions)
+        result = to_ic50(predictions)
+        if use_cache:
+            self.prediction_cache[peptides] = result
+        return result
 
     @staticmethod
     def make_network(
-- 
GitLab