From 024618ec7cd5e728a4a0ff89408adf1d6a435c73 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Tue, 3 Dec 2019 19:09:00 -0500 Subject: [PATCH] presentation model saving and loading --- mhcflurry/class1_affinity_predictor.py | 48 +------- .../class1_presentation_neural_network.py | 45 ++++++++ mhcflurry/class1_presentation_predictor.py | 107 ++++-------------- mhcflurry/common.py | 32 ++++++ mhcflurry/downloads.py | 34 ++++++ 5 files changed, 139 insertions(+), 127 deletions(-) diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py index 740f7697..9a8aaecc 100644 --- a/mhcflurry/class1_affinity_predictor.py +++ b/mhcflurry/class1_affinity_predictor.py @@ -26,6 +26,7 @@ from .regression_target import to_ic50 from .version import __version__ from .ensemble_centrality import CENTRALITY_MEASURES from .allele_encoding import AlleleEncoding +from .common import save_weights, load_weights # Default function for combining predictions across models in an ensemble. @@ -370,8 +371,7 @@ class Class1AffinityPredictor(object): updated_network_config_jsons.append( json.dumps(row.model.get_config())) weights_path = self.weights_path(models_dir, row.model_name) - Class1AffinityPredictor.save_weights( - row.model.get_weights(), weights_path) + save_weights(row.model.get_weights(), weights_path) logging.info("Wrote: %s", weights_path) sub_manifest_df["config_json"] = updated_network_config_jsons self.manifest_df.loc[ @@ -469,9 +469,7 @@ class Class1AffinityPredictor(object): # We will lazy-load weights when the network is used. model = Class1NeuralNetwork.from_config( config, - weights_loader=partial( - Class1AffinityPredictor.load_weights, - abspath(weights_filename))) + weights_loader=partial(load_weights, abspath(weights_filename))) if row.allele == "pan-class1": class1_pan_allele_models.append(model) else: @@ -1235,46 +1233,6 @@ class Class1AffinityPredictor(object): del df["normalized_allele"] return df - @staticmethod - def save_weights(weights_list, filename): - """ - Save the model weights to the given filename using numpy's ".npz" - format. - - Parameters - ---------- - weights_list : list of array - - filename : string - Should end in ".npz". - - """ - numpy.savez( - filename, - **dict((("array_%d" % i), w) for (i, w) in enumerate(weights_list))) - - @staticmethod - def load_weights(filename): - """ - Restore model weights from the given filename, which should have been - created with `save_weights`. - - Parameters - ---------- - filename : string - Should end in ".npz". - - Returns - ---------- - list of array - """ - with numpy.load(filename) as loaded: - weights = [ - loaded["array_%d" % i] - for i in range(len(loaded.keys())) - ] - return weights - def calibrate_percentile_ranks( self, peptides=None, diff --git a/mhcflurry/class1_presentation_neural_network.py b/mhcflurry/class1_presentation_neural_network.py index 897fafd8..b37adcff 100644 --- a/mhcflurry/class1_presentation_neural_network.py +++ b/mhcflurry/class1_presentation_neural_network.py @@ -699,4 +699,49 @@ class Class1PresentationNeuralNetwork(object): if network_weights is not None: self.network.set_weights(network_weights) + def get_config(self): + """ + serialize to a dict all attributes except model weights + + Returns + ------- + dict + """ + result = dict(self.__dict__) + result['network'] = None + result['network_weights'] = None + result['network_json'] = None + if self.network: + result['network_weights'] = self.network.get_weights() + result['network_json'] = self.network.to_json() + return result + + @classmethod + def from_config(cls, config, weights=None): + """ + deserialize from a dict returned by get_config(). + + Parameters + ---------- + config : dict + weights : list of array, optional + Network weights to restore + weights_loader : callable, optional + Function to call (no arguments) to load weights when needed + Returns + ------- + Class1NeuralNetwork + """ + config = dict(config) + instance = cls(**config.pop('hyperparameters')) + network_json = config.pop('network_json') + network_weights = config.pop('network_weights') + instance.__dict__.update(config) + assert instance.network is None + if network_json is not None: + import keras.models + instance.network = keras.models.model_from_json(network_json) + if network_weights is not None: + instance.network.set_weights(network_weights) + return instance \ No newline at end of file diff --git a/mhcflurry/class1_presentation_predictor.py b/mhcflurry/class1_presentation_predictor.py index 625b32f4..a4cdd3b7 100644 --- a/mhcflurry/class1_presentation_predictor.py +++ b/mhcflurry/class1_presentation_predictor.py @@ -31,16 +31,19 @@ from .custom_loss import ( MSEWithInequalities, MultiallelicMassSpecLoss, ZeroLoss) +from .downloads import get_default_class1_presentation_models_dir +from .class1_presentation_neural_network import Class1PresentationNeuralNetwork +from .common import save_weights, load_weights class Class1PresentationPredictor(object): def __init__( self, - class1_presentation_neural_networks, + models, allele_to_sequence, manifest_df=None, metadata_dataframes=None): - self.networks = class1_presentation_neural_networks + self.models = models self.allele_to_sequence = allele_to_sequence self._manifest_df = manifest_df self.metadata_dataframes = ( @@ -57,7 +60,7 @@ class Class1PresentationPredictor(object): """ if self._manifest_df is None: rows = [] - for (i, model) in enumerate(self.networks): + for (i, model) in enumerate(self.models): rows.append(( self.model_name(i), json.dumps(model.get_config()), @@ -70,10 +73,10 @@ class Class1PresentationPredictor(object): @property def max_alleles(self): - max_alleles = self.networks[0].hyperparameters['max_alleles'] + max_alleles = self.models[0].hyperparameters['max_alleles'] assert all( n.hyperparameters['max_alleles'] == self.max_alleles - for n in self.networks) + for n in self.models) return max_alleles @staticmethod @@ -153,7 +156,7 @@ class Class1PresentationPredictor(object): score_array = [] affinity_array = [] - for (i, network) in enumerate(self.networks): + for (i, network) in enumerate(self.models): predictions = network.predict( peptides=peptides, allele_encoding=alleles, @@ -191,24 +194,6 @@ class Class1PresentationPredictor(object): numpy.percentile(affinity_array[:, :, i], 5.0, axis=0)) return result_df - @staticmethod - def save_weights(weights_list, filename): - """ - Save the model weights to the given filename using numpy's ".npz" - format. - - Parameters - ---------- - weights_list : list of array - - filename : string - Should end in ".npz". - - """ - numpy.savez( - filename, - **dict((("array_%d" % i), w) for (i, w) in enumerate(weights_list))) - def check_consistency(self): """ Verify that self.manifest_df is consistent with instance variables. @@ -217,10 +202,10 @@ class Class1PresentationPredictor(object): Throws AssertionError if inconsistent. """ - assert len(self.manifest_df) == len(self.networks), ( + assert len(self.manifest_df) == len(self.models), ( "Manifest seems out of sync with models: %d vs %d entries: \n%s"% ( len(self.manifest_df), - len(self.networks), + len(self.models), str(self.manifest_df))) def save(self, models_dir, model_names_to_write=None, write_metadata=True): @@ -301,8 +286,8 @@ class Class1PresentationPredictor(object): join(models_dir, "allele_sequences.csv"), index=False) logging.info("Wrote: %s", join(models_dir, "allele_sequences.csv")) - @staticmethod - def load(models_dir=None, max_models=None): + @classmethod + def load(cls, models_dir=None, max_models=None): """ Deserialize a predictor from a directory on disk. @@ -317,35 +302,24 @@ class Class1PresentationPredictor(object): Returns ------- - `Class1AffinityPredictor` instance + `Class1PresentationPredictor` instance """ if models_dir is None: - models_dir = get_default_class1_models_dir() + models_dir = get_default_class1_presentation_models_dir() manifest_path = join(models_dir, "manifest.csv") manifest_df = pandas.read_csv(manifest_path, nrows=max_models) - allele_to_allele_specific_models = collections.defaultdict(list) - class1_pan_allele_models = [] - all_models = [] + models = [] for (_, row) in manifest_df.iterrows(): - weights_filename = Class1AffinityPredictor.weights_path( - models_dir, row.model_name) + weights_filename = cls.weights_path(models_dir, row.model_name) config = json.loads(row.config_json) - - # We will lazy-load weights when the network is used. - model = Class1NeuralNetwork.from_config( + model = Class1PresentationNeuralNetwork.from_config( config, - weights_loader=partial( - Class1AffinityPredictor.load_weights, - abspath(weights_filename))) - if row.allele == "pan-class1": - class1_pan_allele_models.append(model) - else: - allele_to_allele_specific_models[row.allele].append(model) - all_models.append(model) + weights=load_weights(abspath(weights_filename))) + models.append(model) - manifest_df["model"] = all_models + manifest_df["model"] = models # Load allele sequences allele_to_sequence = None @@ -354,40 +328,9 @@ class Class1PresentationPredictor(object): join(models_dir, "allele_sequences.csv"), index_col=0).iloc[:, 0].to_dict() - allele_to_percent_rank_transform = {} - percent_ranks_path = join(models_dir, "percent_ranks.csv") - if exists(percent_ranks_path): - percent_ranks_df = pandas.read_csv(percent_ranks_path, index_col=0) - for allele in percent_ranks_df.columns: - allele_to_percent_rank_transform[allele] = ( - PercentRankTransform.from_series(percent_ranks_df[allele])) - - logging.info( - "Loaded %d class1 pan allele predictors, %d allele sequences, " - "%d percent rank distributions, and %d allele specific models: %s", - len(class1_pan_allele_models), - len(allele_to_sequence) if allele_to_sequence else 0, - len(allele_to_percent_rank_transform), - sum(len(v) for v in allele_to_allele_specific_models.values()), - ", ".join( - "%s (%d)" % (allele, len(v)) - for (allele, v) - in sorted(allele_to_allele_specific_models.items()))) - - result = Class1AffinityPredictor( - allele_to_allele_specific_models=allele_to_allele_specific_models, - class1_pan_allele_models=class1_pan_allele_models, + logging.info("Loaded %d class1 presentation models", len(models)) + result = cls( + models=models, allele_to_sequence=allele_to_sequence, - manifest_df=manifest_df, - allele_to_percent_rank_transform=allele_to_percent_rank_transform, - ) - if optimization_level >= 1: - optimized = result.optimize() - logging.info( - "Model optimization %s", - "succeeded" if optimized else "not supported for these models") + manifest_df=manifest_df) return result - - - - # TODO: implement saving and loading \ No newline at end of file diff --git a/mhcflurry/common.py b/mhcflurry/common.py index 8885637b..f51696ea 100644 --- a/mhcflurry/common.py +++ b/mhcflurry/common.py @@ -174,3 +174,35 @@ def positional_frequency_matrix(peptides): result = (counts / len(peptides)).fillna(0.0).T result.index.name = 'position' return result + + +def save_weights(weights_list, filename): + """ + Save model weights to the given filename using numpy's ".npz" format. + + Parameters + ---------- + weights_list : list of numpy array + + filename : string + """ + numpy.savez(filename, + **dict((("array_%d" % i), w) for (i, w) in enumerate(weights_list))) + + +def load_weights(filename): + """ + Restore model weights from the given filename, which should have been + created with `save_weights`. + + Parameters + ---------- + filename : string + + Returns + ---------- + list of array + """ + with numpy.load(filename) as loaded: + weights = [loaded["array_%d" % i] for i in range(len(loaded.keys()))] + return weights diff --git a/mhcflurry/downloads.py b/mhcflurry/downloads.py index 7f29ea52..c3eca7cd 100644 --- a/mhcflurry/downloads.py +++ b/mhcflurry/downloads.py @@ -28,6 +28,8 @@ _CURRENT_RELEASE = None _METADATA = None _MHCFLURRY_DEFAULT_CLASS1_MODELS_DIR = environ.get( "MHCFLURRY_DEFAULT_CLASS1_MODELS") +_MHCFLURRY_DEFAULT_CLASS1_PRESENTATION_MODELS_DIR = environ.get( + "MHCFLURRY_DEFAULT_CLASS1_PRESENTATION_MODELS_DIR") def get_downloads_dir(): @@ -84,6 +86,38 @@ def get_default_class1_models_dir(test_exists=True): return get_path("models_class1", "models", test_exists=test_exists) +def get_default_class1_presentation_models_dir(test_exists=True): + """ + Return the absolute path to the default class1 presentation models dir. + + See `get_default_class1_models_dir`. + + If environment variable MHCFLURRY_DEFAULT_CLASS1_PRESENTATION_MODELS is set + to an absolute path, return that path. If it's set to a relative path (does + not start with /) then return that path taken to be relative to the mhcflurry + downloads dir. + + Parameters + ---------- + + test_exists : boolean, optional + Whether to raise an exception of the path does not exist + + Returns + ------- + string : absolute path + """ + if _MHCFLURRY_DEFAULT_CLASS1_PRESENTATION_MODELS_DIR: + result = join( + get_downloads_dir(), + _MHCFLURRY_DEFAULT_CLASS1_PRESENTATION_MODELS_DIR) + if test_exists and not exists(result): + raise IOError("No such directory: %s" % result) + return result + return get_path( + "models_class1_pan_refined", "presentation", test_exists=test_exists) + + def get_current_release_downloads(): """ Return a dict of all available downloads in the current release. -- GitLab