diff --git a/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py b/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py index 388f1bf9b1424bb5d0cd71348e46833f21217215..de7fbc8172f51bacc603e8da999196baa7b8a1e6 100644 --- a/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py +++ b/mhcflurry/antigen_presentation/presentation_component_models/presentation_component_model.py @@ -21,8 +21,6 @@ def cache_dict_for_policy(policy): class PresentationComponentModel(object): - cache_fields = ["cached_fits", "cached_predictions"] - ''' Base class for inputs to a "final model". By "final model" we mean something like a logistic regression model that takes as input expression, @@ -31,9 +29,24 @@ class PresentationComponentModel(object): ''' def __init__( self, fit_cache_policy="weak", predictions_cache_policy="weak"): - self.cached_fits = cache_dict_for_policy(fit_cache_policy) + self.fit_cache_policy = fit_cache_policy + self.predictions_cache_policy = predictions_cache_policy + self.reset_cache() + + def reset_cache(self): + self.cached_fits = cache_dict_for_policy(self.fit_cache_policy) self.cached_predictions = cache_dict_for_policy( - predictions_cache_policy) + self.predictions_cache_policy) + + def __getstate__(self): + d = dict(self.__dict__) + d["cached_fits"] = None + d["cached_predictions"] = None + return d + + def __setstate__(self, state): + self.__dict__.update(state) + self.reset_cache() def column_names(self): """ @@ -98,19 +111,6 @@ class PresentationComponentModel(object): print("Cache hit in clone_and_restore_fit: %s" % str(self)) return result - def __getstate__(self): - d = dict(self.__dict__) - - # Don't pickle the cache variables, but remember what type they are. - for key in PresentationComponentModel.cache_fields: - d[key] = type(d[key]) - return d - - def __setstate__(self, state): - self.__dict__.update(state) - for key in PresentationComponentModel.cache_fields: - setattr(self, key, state[key]()) - def fit(self, hits_df): """ Train the final model input. @@ -229,14 +229,9 @@ class PresentationComponentModel(object): self.cached_predictions[cache_key] = return_value return return_value - def fit_ensemble_and_predict(peptides_df): + def fit_ensemble_and_predict(self, peptides_df): raise NotImplementedError - def reset_cache(self): - for key in PresentationComponentModel.cache_fields: - obj_type = type(getattr(self, key)) - setattr(self, key, obj_type()) - def clone(self): """ Copy this object so that the original and copy can be fit