From c09c9ffaeee7f0ec34528dd531432ad86d57d3c0 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Tue, 3 Sep 2019 17:13:31 -0400
Subject: [PATCH] implement merging optimization

---
 mhcflurry/class1_affinity_predictor.py |  95 ++++++++++++++++-----
 mhcflurry/class1_neural_network.py     | 114 +++++++++++++++++++++++++
 test/test_network_merging.py           |  42 +++++++++
 test/test_speed.py                     |   1 +
 4 files changed, 231 insertions(+), 21 deletions(-)
 create mode 100644 test/test_network_merging.py

diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py
index f85f4cfe..c9aafb96 100644
--- a/mhcflurry/class1_affinity_predictor.py
+++ b/mhcflurry/class1_affinity_predictor.py
@@ -5,7 +5,7 @@ import logging
 import time
 import warnings
 from os.path import join, exists, abspath
-from os import mkdir
+from os import mkdir, environ
 from socket import gethostname
 from getpass import getuser
 from functools import partial
@@ -31,6 +31,9 @@ from .allele_encoding import AlleleEncoding
 # See ensemble_centrality.py for other options.
 DEFAULT_CENTRALITY_MEASURE = "mean"
 
+# Any value > 0 will result in attempting to optimize models after loading.
+OPTIMIZATION_LEVEL = int(environ.get("MHCFLURRY_OPTIMIZATION_LEVEL", 0))
+
 
 class Class1AffinityPredictor(object):
     """
@@ -98,6 +101,7 @@ class Class1AffinityPredictor(object):
         self.metadata_dataframes = (
             dict(metadata_dataframes) if metadata_dataframes else {})
         self._cache = {}
+        self.optimization_info = {}
 
         assert isinstance( self.allele_to_allele_specific_models, dict)
         assert isinstance(self.class1_pan_allele_models, list)
@@ -483,8 +487,48 @@ class Class1AffinityPredictor(object):
             manifest_df=manifest_df,
             allele_to_percent_rank_transform=allele_to_percent_rank_transform,
         )
+        if OPTIMIZATION_LEVEL >= 1:
+            logging.info("Optimizing models")
+            optimized = result.optimize()
+            logging.info(
+                "Optimization " + ("succeeded" if optimized else "failed"))
         return result
 
+    def optimize(self):
+        """
+        EXPERIMENTAL: Optimize the predictor for faster predictions.
+
+        Currently the only optimization implemented is to merge multiple pan-
+        allele predictors at the tensorflow level.
+
+        The optimization is performed in-place, mutating the instance.
+
+        Returns
+        ----------
+        bool
+            Whether optimization was performed
+
+        """
+        num_class1_pan_allele_models = len(self.class1_pan_allele_models)
+        if num_class1_pan_allele_models > 1:
+            try:
+                self.class1_pan_allele_models = [
+                    Class1NeuralNetwork.merge(
+                        self.class1_pan_allele_models,
+                        merge_method="concatenate")
+                ]
+            except NotImplementedError as e:
+                logging.warning("Optimization failed: %s" % str(e))
+                return False
+            self._manifest_df = None
+            self.clear_cache()
+            self.optimization_info["pan_models_merged"] = True
+            self.optimization_info["num_pan_models_merged"] = (
+                num_class1_pan_allele_models)
+        else:
+            return False
+        return True
+
     @staticmethod
     def model_name(allele, num):
         """
@@ -987,7 +1031,10 @@ class Class1AffinityPredictor(object):
             df["supported_peptide_length"] = True
             all_peptide_lengths_supported = True
 
-        num_pan_models = len(self.class1_pan_allele_models)
+        num_pan_models = (
+            len(self.class1_pan_allele_models)
+            if not self.optimization_info.get("pan_models_merged", False)
+            else self.optimization_info["num_pan_models_merged"])
         max_single_allele_models = max(
             len(self.allele_to_allele_specific_models.get(allele, []))
             for allele in unique_alleles
@@ -1015,40 +1062,46 @@ class Class1AffinityPredictor(object):
                     raise ValueError(msg)
             mask = df.supported_peptide_length & (
                 ~df.normalized_allele.isin(unsupported_alleles))
+
+            row_slice = None
             if mask is None or mask.all():
-                # Common case optimization
-                allele_encoding = AlleleEncoding(
+                row_slice = slice(None, None, None)  # all rows
+                masked_allele_encoding = AlleleEncoding(
                     df.normalized_allele,
                     borrow_from=master_allele_encoding)
+                masked_peptides = peptides
+            elif mask.sum() > 0:
+                row_slice = mask
+                masked_allele_encoding = AlleleEncoding(
+                    df.loc[mask].normalized_allele,
+                    borrow_from=master_allele_encoding)
+                masked_peptides = peptides.sequences[mask]
 
+            if row_slice is not None:
                 # The following line is a performance optimization that may be
                 # revisited. It causes the neural network to set to include
                 # only the alleles actually being predicted for. This makes
                 # the network much smaller. However, subsequent calls to
                 # predict will need to reset these weights, so there is a
                 # tradeoff.
-                allele_encoding = allele_encoding.compact()
-
-                for (i, model) in enumerate(self.class1_pan_allele_models):
-                    predictions_array[:, i] = (
-                        model.predict(
-                            peptides,
-                            allele_encoding=allele_encoding,
-                            **model_kwargs))
-            elif mask.sum() > 0:
-                masked_allele_encoding = AlleleEncoding(
-                    df.loc[mask].normalized_allele,
-                    borrow_from=master_allele_encoding)
-
-                # See above performance note.
                 masked_allele_encoding = masked_allele_encoding.compact()
 
-                masked_peptides = peptides.sequences[mask]
-                for (i, model) in enumerate(self.class1_pan_allele_models):
-                    predictions_array[mask, i] = model.predict(
+                if self.optimization_info.get("pan_models_merged"):
+                    # Multiple pan-allele models have been merged into one
+                    # at the tensorflow level.
+                    assert len(self.class1_pan_allele_models) == 1
+                    predictions = self.class1_pan_allele_models[0].predict(
                         masked_peptides,
                         allele_encoding=masked_allele_encoding,
+                        output_index=None,
                         **model_kwargs)
+                    predictions_array[row_slice, :num_pan_models] = predictions
+                else:
+                    for (i, model) in enumerate(self.class1_pan_allele_models):
+                        predictions_array[row_slice, i] = model.predict(
+                            masked_peptides,
+                            allele_encoding=masked_allele_encoding,
+                            **model_kwargs)
 
         if self.allele_to_allele_specific_models:
             unsupported_alleles = [
diff --git a/mhcflurry/class1_neural_network.py b/mhcflurry/class1_neural_network.py
index 30e3cc62..7bbdf4b5 100644
--- a/mhcflurry/class1_neural_network.py
+++ b/mhcflurry/class1_neural_network.py
@@ -5,6 +5,7 @@ import weakref
 import itertools
 import os
 import logging
+import pickle
 
 import numpy
 import pandas
@@ -1063,6 +1064,119 @@ class Class1NeuralNetwork(object):
             self.prediction_cache[peptides] = result
         return result
 
+    @classmethod
+    def merge(cls, models, merge_method="average"):
+        """
+        Merge multiple models at the tensorflow (or other backend) level.
+
+        Only certain neural network architectures support merging. Others will
+        throw NotImplementedError.
+
+        Parameters
+        ----------
+        models : list of Class1NeuralNetwork
+            instances to merge
+        merge_method : string, one of "average", "sum", or "concatenate"
+            How to merge the predictions of the different models
+
+        Returns
+        -------
+        Class1NeuralNetwork
+            The merged neural network
+
+        """
+        import keras
+        from keras.layers import Input
+        from keras.models import Model
+
+        if len(models) == 1:
+            return models[0]
+
+        # Copy models since we are going to mutate their underlying networks
+        models = [
+            pickle.loads(pickle.dumps(model, protocol=pickle.HIGHEST_PROTOCOL))
+            for model in models
+        ]
+        assert len(models) > 1
+
+        result = Class1NeuralNetwork(**dict(models[0].hyperparameters))
+
+        # Remove hyperparameters that are not shared by all models.
+        for model in models:
+            for (key, value) in model.hyperparameters.items():
+                if result.hyperparameters.get(key, value) != value:
+                    del result.hyperparameters[key]
+
+        assert result._network is None
+
+        networks = [
+            model.network() for model in models
+        ]
+
+        layer_names = [
+            [layer.name for layer in network.layers]
+            for network in networks
+        ]
+
+        pan_allele_layer_names1 = [
+            'allele', 'peptide', 'allele_representation', 'flattened_0',
+            'allele_flat', 'allele_peptide_merged', 'dense_0', 'dropout_0',
+            'dense_1', 'dropout_1', 'output',
+        ]
+
+        if all(names == pan_allele_layer_names1 for names in layer_names):
+            # Merging an ensemble of pan-allele architectures
+            network = networks[0]
+            peptide_input = Input(
+                shape=tuple(int(x) for x in network.inputs[0].shape[1:]),
+                dtype='float32',
+                name='peptide')
+            allele_input = Input(
+                shape=(1,),
+                dtype='float32',
+                name='allele')
+
+            allele_embedding = network.get_layer(
+                "allele_representation")(allele_input)
+            peptide_flat = network.get_layer("flattened_0")(peptide_input)
+            allele_flat = network.get_layer("allele_flat")(allele_embedding)
+            allele_peptide_merged = network.get_layer("allele_peptide_merged")(
+                [peptide_flat, allele_flat])
+
+            sub_networks = []
+            for (i, network) in enumerate(networks):
+                layers = network.layers[
+                    pan_allele_layer_names1.index("allele_peptide_merged") + 1:
+                ]
+                node = allele_peptide_merged
+                for layer in layers:
+                    layer.name += "_%d" % i
+                    node = layer(node)
+                sub_networks.append(node)
+
+            if merge_method == 'average':
+                output = keras.layers.average(sub_networks)
+            elif merge_method == 'sum':
+                output = keras.layers.add(sub_networks)
+            elif merge_method == 'concatenate':
+                output = keras.layers.concatenate(sub_networks)
+            else:
+                raise NotImplementedError(
+                    "Unsupported merge method", merge_method)
+
+            result._network = Model(
+                inputs=[peptide_input, allele_input],
+                outputs=[output],
+                name="merged_predictor"
+            )
+            result.update_network_description()
+        else:
+            raise NotImplementedError(
+                "Don't know merge_method to merge networks with layer names: ",
+                layer_names)
+        return result
+
+
     def make_network(
             self,
             peptide_encoding,
diff --git a/test/test_network_merging.py b/test/test_network_merging.py
new file mode 100644
index 00000000..1f083e8a
--- /dev/null
+++ b/test/test_network_merging.py
@@ -0,0 +1,42 @@
+from nose.tools import eq_, assert_less, assert_greater, assert_almost_equal, assert_equal
+
+import numpy
+import pandas
+from numpy import testing
+
+numpy.random.seed(0)
+
+import logging
+logging.getLogger('tensorflow').disabled = True
+
+from mhcflurry import Class1AffinityPredictor, Class1NeuralNetwork
+from mhcflurry.common import random_peptides
+from mhcflurry.downloads import get_path
+
+ALLELE_SPECIFIC_PREDICTOR = Class1AffinityPredictor.load(
+    get_path("models_class1", "models"))
+
+PAN_ALLELE_PREDICTOR = Class1AffinityPredictor.load(
+    get_path("models_class1_pan", "models.with_mass_spec"))
+
+
+def test_merge():
+    peptides = random_peptides(100, length=9)
+    peptides.extend(random_peptides(100, length=10))
+    peptides = pandas.Series(peptides).sample(frac=1.0)
+
+    alleles = pandas.Series(
+        ["HLA-A*03:01", "HLA-B*57:01", "HLA-C*02:01"]
+    ).sample(n=len(peptides), replace=True)
+
+    predictions1 = PAN_ALLELE_PREDICTOR.predict(
+        peptides=peptides, alleles=alleles)
+
+    merged = Class1NeuralNetwork.merge(
+        PAN_ALLELE_PREDICTOR.class1_pan_allele_models)
+    merged_predictor = Class1AffinityPredictor(
+        allele_to_sequence=PAN_ALLELE_PREDICTOR.allele_to_sequence,
+        class1_pan_allele_models=[merged],
+    )
+    predictions2 = merged_predictor.predict(peptides=peptides, alleles=alleles)
+    numpy.testing.assert_allclose(predictions1, predictions2, atol=0.1)
diff --git a/test/test_speed.py b/test/test_speed.py
index e9c342a6..f51e47bd 100644
--- a/test/test_speed.py
+++ b/test/test_speed.py
@@ -145,6 +145,7 @@ if __name__ == '__main__':
 
     if "pan-allele" in args.predictor:
         print("Running pan-allele test")
+        PAN_ALLELE_PREDICTOR.optimize()
         result = test_speed_pan_allele(
             profile=True, num=args.num_predictions)
         result[
-- 
GitLab