From aa89d495796746ef09fc5ece7d1abbc48479dd3a Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Fri, 13 Sep 2019 11:18:27 -0400
Subject: [PATCH] Add Class1NeuralNetwork.clear_allele_representations()

---
 mhcflurry/class1_neural_network.py            | 15 +++++++++++++++
 mhcflurry/select_pan_allele_models_command.py |  1 +
 mhcflurry/train_pan_allele_models_command.py  |  1 +
 3 files changed, 17 insertions(+)

diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 277c683b..600b9877 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -1346,6 +1346,21 @@ class Class1NeuralNetwork(object):
 
         return model
 
+    def clear_allele_representations(self):
+        """
+        Set allele representations to NaN.
+
+        This reduces the size of saved models since the NaNs will compress
+        easily. It doesn't actually shrink the size of the model in memory,
+        though.
+        """
+        original_model = self.network()
+        layer = original_model.get_layer("allele_representation")
+        existing_weights_shape = (layer.input_dim, layer.output_dim)
+        self.set_allele_representations(  # 0 rows x output_dim columns
+            numpy.zeros(shape=(0,) + existing_weights_shape[1:]))
+
+
     def set_allele_representations(self, allele_representations):
         """
         Set the allele representations in use by this model. This means mutating
diff --git a/mhcflurry/select_pan_allele_models_command.py b/mhcflurry/select_pan_allele_models_command.py
index 6e6f45fa..fbc11edf 100644
--- a/mhcflurry/select_pan_allele_models_command.py
+++ b/mhcflurry/select_pan_allele_models_command.py
@@ -288,6 +288,7 @@ def run(argv=sys.argv[1:]):
             len(models), fold_num, result['selected_indices']))
         models_by_fold[fold_num] = models
         for model in models:
+            model.clear_allele_representations()
             result_predictor.add_pan_allele_model(model)
 
     summary_df = pandas.concat(summary_dfs, ignore_index=False)
diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py
index c1b015e4..f20c9e3e 100644
--- a/mhcflurry/train_pan_allele_models_command.py
+++ b/mhcflurry/train_pan_allele_models_command.py
@@ -693,6 +693,7 @@ def train_model(
     predictor.clear_cache()
 
     # Delete the network to release memory
+    model.clear_allele_representations()
     model.update_network_description()  # save weights and config
     model._network = None  # release tensorflow network
     return predictor
-- 
GitLab