Skip to content
Snippets Groups Projects
Commit 1c1e75c6 authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

Simpler cache handling

parent 337b1039
No related branches found
No related tags found
No related merge requests found
......@@ -21,8 +21,6 @@ def cache_dict_for_policy(policy):
class PresentationComponentModel(object):
cache_fields = ["cached_fits", "cached_predictions"]
'''
Base class for inputs to a "final model". By "final model" we mean
something like a logistic regression model that takes as input expression,
......@@ -31,9 +29,24 @@ class PresentationComponentModel(object):
'''
def __init__(
self, fit_cache_policy="weak", predictions_cache_policy="weak"):
self.cached_fits = cache_dict_for_policy(fit_cache_policy)
self.fit_cache_policy = fit_cache_policy
self.predictions_cache_policy = predictions_cache_policy
self.reset_cache()
def reset_cache(self):
self.cached_fits = cache_dict_for_policy(self.fit_cache_policy)
self.cached_predictions = cache_dict_for_policy(
predictions_cache_policy)
self.predictions_cache_policy)
def __getstate__(self):
d = dict(self.__dict__)
d["cached_fits"] = None
d["cached_predictions"] = None
return d
def __setstate__(self, state):
self.__dict__.update(state)
self.reset_cache()
def column_names(self):
"""
......@@ -98,19 +111,6 @@ class PresentationComponentModel(object):
print("Cache hit in clone_and_restore_fit: %s" % str(self))
return result
def __getstate__(self):
d = dict(self.__dict__)
# Don't pickle the cache variables, but remember what type they are.
for key in PresentationComponentModel.cache_fields:
d[key] = type(d[key])
return d
def __setstate__(self, state):
self.__dict__.update(state)
for key in PresentationComponentModel.cache_fields:
setattr(self, key, state[key]())
def fit(self, hits_df):
"""
Train the final model input.
......@@ -229,14 +229,9 @@ class PresentationComponentModel(object):
self.cached_predictions[cache_key] = return_value
return return_value
def fit_ensemble_and_predict(peptides_df):
def fit_ensemble_and_predict(self, peptides_df):
raise NotImplementedError
def reset_cache(self):
for key in PresentationComponentModel.cache_fields:
obj_type = type(getattr(self, key))
setattr(self, key, obj_type())
def clone(self):
"""
Copy this object so that the original and copy can be fit
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment