diff --git a/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py b/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py
index 388f1bf9b1424bb5d0cd71348e46833f21217215..de7fbc8172f51bacc603e8da999196baa7b8a1e6 100644
--- a/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py
+++ b/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py
@@ -21,8 +21,6 @@ def cache_dict_for_policy(policy):
 
 
 class PresentationComponentModel(object):
-    cache_fields = ["cached_fits", "cached_predictions"]
-
     '''
     Base class for inputs to a "final model". By "final model" we mean
     something like a logistic regression model that takes as input expression,
@@ -31,9 +29,24 @@ class PresentationComponentModel(object):
     '''
     def __init__(
             self, fit_cache_policy="weak", predictions_cache_policy="weak"):
-        self.cached_fits = cache_dict_for_policy(fit_cache_policy)
+        self.fit_cache_policy = fit_cache_policy
+        self.predictions_cache_policy = predictions_cache_policy
+        self.reset_cache()
+
+    def reset_cache(self):
+        self.cached_fits = cache_dict_for_policy(self.fit_cache_policy)
         self.cached_predictions = cache_dict_for_policy(
-            predictions_cache_policy)
+            self.predictions_cache_policy)
+
+    def __getstate__(self):
+        d = dict(self.__dict__)
+        d["cached_fits"] = None
+        d["cached_predictions"] = None
+        return d
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self.reset_cache()
 
     def column_names(self):
         """
@@ -98,19 +111,6 @@ class PresentationComponentModel(object):
             print("Cache hit in clone_and_restore_fit: %s" % str(self))
         return result
 
-    def __getstate__(self):
-        d = dict(self.__dict__)
-
-        # Don't pickle the cache variables, but remember what type they are.
-        for key in PresentationComponentModel.cache_fields:
-            d[key] = type(d[key])
-        return d
-
-    def __setstate__(self, state):
-        self.__dict__.update(state)
-        for key in PresentationComponentModel.cache_fields:
-            setattr(self, key, state[key]())
-
     def fit(self, hits_df):
         """
         Train the final model input.
@@ -229,14 +229,9 @@ class PresentationComponentModel(object):
             self.cached_predictions[cache_key] = return_value
         return return_value
 
-    def fit_ensemble_and_predict(peptides_df):
+    def fit_ensemble_and_predict(self, peptides_df):
         raise NotImplementedError
 
-    def reset_cache(self):
-        for key in PresentationComponentModel.cache_fields:
-            obj_type = type(getattr(self, key))
-            setattr(self, key, obj_type())
-
     def clone(self):
         """
         Copy this object so that the original and copy can be fit