From 1c1e75c6f8b32e4a9cda0fe72c37c74819034da1 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Sun, 12 Feb 2017 11:41:19 -0500
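Subject: [PATCH] Simpler cache handling

Store the fit and prediction cache policies on the instance and have
reset_cache() rebuild both caches from them via cache_dict_for_policy(),
instead of inspecting the type of the existing cache objects. __init__
and __setstate__ now both call reset_cache(), and __getstate__ replaces
the caches with None so they are never pickled. This removes the
cache_fields class attribute, which previously sat ahead of the class
docstring.

Also add the missing `self` parameter to fit_ensemble_and_predict().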
---
 .../presentation_component_model.py           | 41 ++++++++-----------
 1 file changed, 18 insertions(+), 23 deletions(-)

diff --git a/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py b/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py
index 388f1bf9..de7fbc81 100644
--- a/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py
+++ b/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py
@@ -21,8 +21,6 @@ def cache_dict_for_policy(policy):
 
 
 class PresentationComponentModel(object):
-    cache_fields = ["cached_fits", "cached_predictions"]
-
     '''
     Base class for inputs to a "final model". By "final model" we mean
     something like a logistic regression model that takes as input expression,
@@ -31,9 +29,24 @@ class PresentationComponentModel(object):
     '''
     def __init__(
             self, fit_cache_policy="weak", predictions_cache_policy="weak"):
-        self.cached_fits = cache_dict_for_policy(fit_cache_policy)
+        self.fit_cache_policy = fit_cache_policy
+        self.predictions_cache_policy = predictions_cache_policy
+        self.reset_cache()
+
+    def reset_cache(self):
+        self.cached_fits = cache_dict_for_policy(self.fit_cache_policy)
         self.cached_predictions = cache_dict_for_policy(
-            predictions_cache_policy)
+            self.predictions_cache_policy)
+
+    def __getstate__(self):
+        d = dict(self.__dict__)
+        d["cached_fits"] = None
+        d["cached_predictions"] = None
+        return d
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self.reset_cache()
 
     def column_names(self):
         """
@@ -98,19 +111,6 @@ class PresentationComponentModel(object):
             print("Cache hit in clone_and_restore_fit: %s" % str(self))
         return result
 
-    def __getstate__(self):
-        d = dict(self.__dict__)
-
-        # Don't pickle the cache variables, but remember what type they are.
-        for key in PresentationComponentModel.cache_fields:
-            d[key] = type(d[key])
-        return d
-
-    def __setstate__(self, state):
-        self.__dict__.update(state)
-        for key in PresentationComponentModel.cache_fields:
-            setattr(self, key, state[key]())
-
     def fit(self, hits_df):
         """
         Train the final model input.
@@ -229,14 +229,9 @@ class PresentationComponentModel(object):
             self.cached_predictions[cache_key] = return_value
         return return_value
 
-    def fit_ensemble_and_predict(peptides_df):
+    def fit_ensemble_and_predict(self, peptides_df):
         raise NotImplementedError
 
-    def reset_cache(self):
-        for key in PresentationComponentModel.cache_fields:
-            obj_type = type(getattr(self, key))
-            setattr(self, key, obj_type())
-
     def clone(self):
         """
         Copy this object so that the original and copy can be fit
-- 
GitLab