Skip to content
Snippets Groups Projects
Commit 1c1e75c6 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

Simpler cache handling

parent 337b1039
No related branches found
No related tags found
No related merge requests found
...@@ -21,8 +21,6 @@ def cache_dict_for_policy(policy): ...@@ -21,8 +21,6 @@ def cache_dict_for_policy(policy):
class PresentationComponentModel(object): class PresentationComponentModel(object):
cache_fields = ["cached_fits", "cached_predictions"]
''' '''
Base class for inputs to a "final model". By "final model" we mean Base class for inputs to a "final model". By "final model" we mean
something like a logistic regression model that takes as input expression, something like a logistic regression model that takes as input expression,
...@@ -31,9 +29,24 @@ class PresentationComponentModel(object): ...@@ -31,9 +29,24 @@ class PresentationComponentModel(object):
''' '''
def __init__( def __init__(
self, fit_cache_policy="weak", predictions_cache_policy="weak"): self, fit_cache_policy="weak", predictions_cache_policy="weak"):
self.cached_fits = cache_dict_for_policy(fit_cache_policy) self.fit_cache_policy = fit_cache_policy
self.predictions_cache_policy = predictions_cache_policy
self.reset_cache()
def reset_cache(self):
self.cached_fits = cache_dict_for_policy(self.fit_cache_policy)
self.cached_predictions = cache_dict_for_policy( self.cached_predictions = cache_dict_for_policy(
predictions_cache_policy) self.predictions_cache_policy)
def __getstate__(self):
d = dict(self.__dict__)
d["cached_fits"] = None
d["cached_predictions"] = None
return d
def __setstate__(self, state):
self.__dict__.update(state)
self.reset_cache()
def column_names(self): def column_names(self):
""" """
...@@ -98,19 +111,6 @@ class PresentationComponentModel(object): ...@@ -98,19 +111,6 @@ class PresentationComponentModel(object):
print("Cache hit in clone_and_restore_fit: %s" % str(self)) print("Cache hit in clone_and_restore_fit: %s" % str(self))
return result return result
def __getstate__(self):
d = dict(self.__dict__)
# Don't pickle the cache variables, but remember what type they are.
for key in PresentationComponentModel.cache_fields:
d[key] = type(d[key])
return d
def __setstate__(self, state):
self.__dict__.update(state)
for key in PresentationComponentModel.cache_fields:
setattr(self, key, state[key]())
def fit(self, hits_df): def fit(self, hits_df):
""" """
Train the final model input. Train the final model input.
...@@ -229,14 +229,9 @@ class PresentationComponentModel(object): ...@@ -229,14 +229,9 @@ class PresentationComponentModel(object):
self.cached_predictions[cache_key] = return_value self.cached_predictions[cache_key] = return_value
return return_value return return_value
def fit_ensemble_and_predict(peptides_df): def fit_ensemble_and_predict(self, peptides_df):
raise NotImplementedError raise NotImplementedError
def reset_cache(self):
for key in PresentationComponentModel.cache_fields:
obj_type = type(getattr(self, key))
setattr(self, key, obj_type())
def clone(self): def clone(self):
""" """
Copy this object so that the original and copy can be fit Copy this object so that the original and copy can be fit
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment