def clear_allele_representations(self):
    """
    Replace the allele representations with an empty (zero-row) array.

    Looks up the "allele_representation" layer of the current network to
    learn the width of a single allele representation
    (``layer.output_dim``), then calls ``set_allele_representations`` with
    a ``(0, output_dim)`` array.

    Per the original author's note, the purpose is to leave the stored
    representations as NaN so that saved models are smaller (NaNs compress
    easily); it is not expected to shrink the model in memory.

    NOTE(review): the NaN fill presumably happens inside
    ``set_allele_representations`` when it receives fewer rows than the
    layer currently holds -- confirm against that method.
    """
    network = self.network()
    embedding_layer = network.get_layer("allele_representation")
    existing_weights_shape = (
        embedding_layer.input_dim, embedding_layer.output_dim)
    # Slice the tuple itself: a plain tuple has no ``.shape`` attribute, so
    # ``existing_weights_shape.shape[1:]`` raised AttributeError.
    # Keeping every dimension after the first preserves the representation
    # width while dropping all allele rows.
    empty_representations = numpy.zeros(
        shape=(0,) + existing_weights_shape[1:])
    self.set_allele_representations(empty_representations)
model.clear_allele_representations() model.update_network_description() # save weights and config model._network = None # release tensorflow network return predictor