From adf2c56ae65ac99632a526baefd4248a1cc792ed Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Mon, 15 Jul 2019 15:45:22 -0400
Subject: [PATCH] Add data dependent weights initialization (LSUV)

---
 mhcflurry/class1_neural_network.py            |  69 ++++++++++-
 .../data_dependent_weights_initialization.py  | 107 ++++++++++++++++++
 test/expensive_test_pretrain_optimizable.py   |   3 +-
 3 files changed, 174 insertions(+), 5 deletions(-)
 create mode 100644 mhcflurry/data_dependent_weights_initialization.py

diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 3773f7d4..ec5feb07 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -3,6 +3,7 @@ import collections
 import logging
 import json
 import weakref
+import itertools
 
 import numpy
 import pandas
@@ -13,6 +14,7 @@ from .encodable_sequences import EncodableSequences, EncodingError
 from .regression_target import to_ic50, from_ic50
 from .common import random_peptides, amino_acid_distribution
 from .custom_loss import get_loss
+from .data_dependent_weights_initialization import lsuv_init
 
 
 class Class1NeuralNetwork(object):
@@ -76,6 +78,7 @@ class Class1NeuralNetwork(object):
         validation_split=0.1,
         early_stopping=True,
         minibatch_size=128,
+        data_dependent_initialization_method=None,
         random_negative_rate=0.0,
         random_negative_constant=25,
         random_negative_affinity_min=20000.0,
@@ -419,6 +422,31 @@ class Class1NeuralNetwork(object):
             allele_encoding.allele_representations(
                 self.hyperparameters['allele_amino_acid_encoding']))
 
+    @staticmethod
+    def data_dependent_weights_initialization(
+            network,
+            x_dict=None,
+            method="lsuv",
+            verbose=1):
+        """
+        Data dependent weights initialization.
+
+        Parameters
+        ----------
+        method
+
+        Returns
+        -------
+
+        """
+        if verbose:
+            print("Performing data-dependent init: ", method)
+        if method == "lsuv":
+            assert x_dict is not None, "Data required for LSUV init"
+            lsuv_init(network, x_dict, verbose=verbose > 0)
+        else:
+            raise RuntimeError("Unsupported init method: ", method)
+
     def fit_generator(
             self,
             generator,
@@ -505,7 +533,9 @@ class Class1NeuralNetwork(object):
             'output': output,
         }
 
-        yielded_values_box = [0]
+        mutable_generator_state = {
+            'yielded_values': 0  # total number of data points yielded
+        }
 
         def wrapped_generator():
             for (alleles, peptides, affinities) in generator:
@@ -519,12 +549,28 @@ class Class1NeuralNetwork(object):
                     'output': from_ic50(affinities)
                 }
                 yield (x_dict, y_dict)
-                yielded_values_box[0] += len(affinities)
+                mutable_generator_state['yielded_values'] += len(affinities)
 
         start = time.time()
 
+        iterator = wrapped_generator()
+
+        # Initialization required if a data_dependent_initialization_method
+        # is set and this is our first time fitting (i.e. fit_info is empty).
+        data_dependent_init = self.hyperparameters[
+            'data_dependent_initialization_method'
+        ]
+        if data_dependent_init and not self.fit_info:
+            first_chunk = next(iterator)
+            self.data_dependent_weights_initialization(
+                network,
+                first_chunk[0],  # x_dict
+                method=data_dependent_init,
+                verbose=verbose)
+            iterator = itertools.chain([first_chunk], iterator)
+
         fit_history = network.fit_generator(
-            wrapped_generator(),
+            iterator,
             steps_per_epoch=steps_per_epoch,
             epochs=epochs,
             use_multiprocessing=False,
@@ -541,7 +587,7 @@ class Class1NeuralNetwork(object):
             fit_info[key].extend(value)
 
         fit_info["time"] = time.time() - start
-        fit_info["num_points"] = yielded_values_box[0]
+        fit_info["num_points"] = mutable_generator_state["yielded_values"]
         self.fit_info.append(dict(fit_info))
 
     def fit(
@@ -777,6 +823,12 @@ class Class1NeuralNetwork(object):
         min_val_loss_iteration = None
         min_val_loss = None
 
+        # Initialization required if a data_dependent_initialization_method
+        # is set and this is our first time fitting (i.e. fit_info is empty).
+        needs_initialization = self.hyperparameters[
+            'data_dependent_initialization_method'
+        ] is not None and not self.fit_info
+
         start = time.time()
         last_progress_print = None
         x_dict_with_random_negatives = {}
@@ -828,6 +880,15 @@ class Class1NeuralNetwork(object):
                             ]
                         )
 
+            if needs_initialization:
+                self.data_dependent_weights_initialization(
+                    self.network(),
+                    x_dict_with_random_negatives,
+                    method=self.hyperparameters[
+                        'data_dependent_initialization_method'],
+                    verbose=verbose)
+                needs_initialization = False
+
             fit_history = self.network().fit(
                 x_dict_with_random_negatives,
                 y_dict_with_random_negatives,
diff --git a/mhcflurry/data_dependent_weights_initialization.py b/mhcflurry/data_dependent_weights_initialization.py
new file mode 100644
index 00000000..165c1e6d
--- /dev/null
+++ b/mhcflurry/data_dependent_weights_initialization.py
@@ -0,0 +1,107 @@
+# LSUV initialization code in this file is adapted from:
+#   https://github.com/ducha-aiki/LSUV-keras/blob/master/lsuv_init.py
+# by Dmytro Mishkin
+#
+# Here is the license for the original code:
+#
+#
+# Copyright (C) 2017, Dmytro Mishkin
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+# 1. Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the
+#    distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from __future__ import print_function
+import numpy
+
+
+def svd_orthonormal(shape):
+    # Orthonorm init code is from Lasagne
+    # https://github.com/Lasagne/Lasagne/blob/master/lasagne/init.py
+    if len(shape) < 2:
+        raise RuntimeError("Only shapes of length 2 or more are supported.")
+    flat_shape = (shape[0], numpy.prod(shape[1:]))
+    a = numpy.random.standard_normal(flat_shape).astype("float32")
+    u, _, v = numpy.linalg.svd(a, full_matrices=False)
+    q = u if u.shape == flat_shape else v
+    q = q.reshape(shape)
+    return q
+
+
+def get_activations(model, layer, X_batch):
+    from keras.models import Model
+    intermediate_layer_model = Model(
+        inputs=model.get_input_at(0),
+        outputs=layer.get_output_at(0)
+    )
+    activations = intermediate_layer_model.predict(X_batch)
+    return activations
+
+
+def lsuv_init(model, batch, verbose=True, margin=0.1, max_iter=100):
+    from keras.layers import Dense, Convolution2D
+    needed_variance = 1.0
+    layers_inintialized = 0
+    for layer in model.layers:
+        if not isinstance(layer, (Dense, Convolution2D)):
+            continue
+        # avoid small layers where activation variance close to zero, esp.
+        # for small batches
+        if numpy.prod(layer.get_output_shape_at(0)[1:]) < 32:
+            if verbose:
+                print('LSUV initialization skipping', layer.name)
+            continue
+        layers_inintialized += 1
+        weights_and_biases = layer.get_weights()
+        weights_and_biases[0] = svd_orthonormal(weights_and_biases[0].shape)
+        layer.set_weights(weights_and_biases)
+        activations = get_activations(model, layer, batch)
+        variance = numpy.var(activations)
+        iteration = 0
+        if verbose:
+            print(layer.name, variance)
+        while abs(needed_variance - variance) > margin:
+            if verbose:
+                print(
+                    'LSUV initialization',
+                    layer.name,
+                    iteration,
+                    needed_variance,
+                    margin,
+                    variance)
+
+            if numpy.abs(numpy.sqrt(variance)) < 1e-7:
+                break  # avoid zero division
+
+            weights_and_biases = layer.get_weights()
+            weights_and_biases[0] /= numpy.sqrt(variance) / numpy.sqrt(
+                needed_variance)
+            layer.set_weights(weights_and_biases)
+            activations = get_activations(model, layer, batch)
+            variance = numpy.var(activations)
+
+            iteration += 1
+            if iteration >= max_iter:
+                break
+    if verbose:
+        print('Done with LSUV: total layers initialized', layers_inintialized)
+    return model
\ No newline at end of file
diff --git a/test/expensive_test_pretrain_optimizable.py b/test/expensive_test_pretrain_optimizable.py
index 3a261c65..c2b0e461 100644
--- a/test/expensive_test_pretrain_optimizable.py
+++ b/test/expensive_test_pretrain_optimizable.py
@@ -50,10 +50,11 @@ HYPERPARAMTERS = {
     'random_negative_distribution_smoothing': 0.0,
     'random_negative_match_distribution': True, 'random_negative_rate': 0.2,
     'train_data': {'pretrain': True,
-                   'pretrain_max_epochs': 1,
+                   'pretrain_max_epochs': 3,
                    'pretrain_peptides_per_epoch': 1024,
                    'pretrain_steps_per_epoch': 16},
     'validation_split': 0.1,
+    'data_dependent_initialization_method': "lsuv",
 }
 
 
-- 
GitLab