From 024618ec7cd5e728a4a0ff89408adf1d6a435c73 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Tue, 3 Dec 2019 19:09:00 -0500
Subject: [PATCH] presentation model saving and loading

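Move the ".npz" weight save/load helpers out of Class1AffinityPredictor into
shared functions in mhcflurry.common, add get_config()/from_config()
serialization to Class1PresentationNeuralNetwork, implement saving and
loading on Class1PresentationPredictor, and add a default presentation
models directory lookup to mhcflurry.downloads.

Rough usage sketch (the models directory path is illustrative, not a
shipped download):

    from mhcflurry.class1_presentation_predictor import (
        Class1PresentationPredictor)

    predictor = Class1PresentationPredictor.load(
        "/tmp/presentation-models")  # assumed local models dir
    predictor.save("/tmp/presentation-models-copy")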
---
 mhcflurry/class1_affinity_predictor.py        |  48 +-------
 .../class1_presentation_neural_network.py     |  45 ++++++++
 mhcflurry/class1_presentation_predictor.py    | 107 ++++--------------
 mhcflurry/common.py                           |  32 ++++++
 mhcflurry/downloads.py                        |  34 ++++++
 5 files changed, 139 insertions(+), 127 deletions(-)

diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py
index 740f7697..9a8aaecc 100644
--- a/mhcflurry/class1_affinity_predictor.py
+++ b/mhcflurry/class1_affinity_predictor.py
@@ -26,6 +26,7 @@ from .regression_target import to_ic50
 from .version import __version__
 from .ensemble_centrality import CENTRALITY_MEASURES
 from .allele_encoding import AlleleEncoding
+from .common import save_weights, load_weights
 
 
 # Default function for combining predictions across models in an ensemble.
@@ -370,8 +371,7 @@ class Class1AffinityPredictor(object):
             updated_network_config_jsons.append(
                 json.dumps(row.model.get_config()))
             weights_path = self.weights_path(models_dir, row.model_name)
-            Class1AffinityPredictor.save_weights(
-                row.model.get_weights(), weights_path)
+            save_weights(row.model.get_weights(), weights_path)
             logging.info("Wrote: %s", weights_path)
         sub_manifest_df["config_json"] = updated_network_config_jsons
         self.manifest_df.loc[
@@ -469,9 +469,7 @@ class Class1AffinityPredictor(object):
             # We will lazy-load weights when the network is used.
             model = Class1NeuralNetwork.from_config(
                 config,
-                weights_loader=partial(
-                    Class1AffinityPredictor.load_weights,
-                    abspath(weights_filename)))
+                weights_loader=partial(load_weights, abspath(weights_filename)))
             if row.allele == "pan-class1":
                 class1_pan_allele_models.append(model)
             else:
@@ -1235,46 +1233,6 @@ class Class1AffinityPredictor(object):
         del df["normalized_allele"]
         return df
 
-    @staticmethod
-    def save_weights(weights_list, filename):
-        """
-        Save the model weights to the given filename using numpy's ".npz"
-        format.
-    
-        Parameters
-        ----------
-        weights_list : list of array
-        
-        filename : string
-            Should end in ".npz".
-    
-        """
-        numpy.savez(
-            filename,
-            **dict((("array_%d" % i), w) for (i, w) in enumerate(weights_list)))
-
-    @staticmethod
-    def load_weights(filename):
-        """
-        Restore model weights from the given filename, which should have been
-        created with `save_weights`.
-    
-        Parameters
-        ----------
-        filename : string
-            Should end in ".npz".
-
-        Returns
-        ----------
-        list of array
-        """
-        with numpy.load(filename) as loaded:
-            weights = [
-                loaded["array_%d" % i]
-                for i in range(len(loaded.keys()))
-            ]
-        return weights
-
     def calibrate_percentile_ranks(
             self,
             peptides=None,
diff --git a/mhcflurry/class1_presentation_neural_network.py b/mhcflurry/class1_presentation_neural_network.py
index 897fafd8..b37adcff 100644
--- a/mhcflurry/class1_presentation_neural_network.py
+++ b/mhcflurry/class1_presentation_neural_network.py
@@ -699,4 +699,49 @@ class Class1PresentationNeuralNetwork(object):
             if network_weights is not None:
                 self.network.set_weights(network_weights)
 
+    def get_config(self):
+        """
+        Serialize to a dict all attributes except the Keras network and its
+        weights. The network architecture is stored as JSON; weights are
+        handled separately (see `from_config`).
+
+        Returns
+        -------
+        dict
+        """
+        result = dict(self.__dict__)
+        # The Keras model and its weight arrays are not JSON-serializable:
+        # store only the architecture JSON here and leave the weights to be
+        # written separately (e.g. with save_weights).
+        result['network'] = None
+        result['network_weights'] = None
+        result['network_json'] = None
+        if self.network:
+            result['network_json'] = self.network.to_json()
+        return result
+
+    @classmethod
+    def from_config(cls, config, weights=None):
+        """
+        Deserialize from a dict returned by get_config().
+
+        Parameters
+        ----------
+        config : dict
+        weights : list of array, optional
+            Network weights to restore, e.g. as returned by `load_weights`.
+            If given, these take precedence over any weights in the config.
 
+        Returns
+        -------
+        Class1PresentationNeuralNetwork
+        """
+        config = dict(config)
+        instance = cls(**config.pop('hyperparameters'))
+        network_json = config.pop('network_json')
+        network_weights = config.pop('network_weights')
+        if weights is not None:
+            # Weights passed explicitly (e.g. loaded from an ".npz" file)
+            # override any weights carried in the config dict.
+            network_weights = weights
+        instance.__dict__.update(config)
+        assert instance.network is None
+        if network_json is not None:
+            import keras.models
+            instance.network = keras.models.model_from_json(network_json)
+            if network_weights is not None:
+                instance.network.set_weights(network_weights)
+        return instance
\ No newline at end of file
diff --git a/mhcflurry/class1_presentation_predictor.py b/mhcflurry/class1_presentation_predictor.py
index 625b32f4..a4cdd3b7 100644
--- a/mhcflurry/class1_presentation_predictor.py
+++ b/mhcflurry/class1_presentation_predictor.py
@@ -31,16 +31,19 @@ from .custom_loss import (
     MSEWithInequalities,
     MultiallelicMassSpecLoss,
     ZeroLoss)
+from .downloads import get_default_class1_presentation_models_dir
+from .class1_presentation_neural_network import Class1PresentationNeuralNetwork
+from .common import save_weights, load_weights
 
 
 class Class1PresentationPredictor(object):
     def __init__(
             self,
-            class1_presentation_neural_networks,
+            models,
             allele_to_sequence,
             manifest_df=None,
             metadata_dataframes=None):
-        self.networks = class1_presentation_neural_networks
+        self.models = models
         self.allele_to_sequence = allele_to_sequence
         self._manifest_df = manifest_df
         self.metadata_dataframes = (
@@ -57,7 +60,7 @@ class Class1PresentationPredictor(object):
         """
         if self._manifest_df is None:
             rows = []
-            for (i, model) in enumerate(self.networks):
+            for (i, model) in enumerate(self.models):
                 rows.append((
                     self.model_name(i),
                     json.dumps(model.get_config()),
@@ -70,10 +73,10 @@ class Class1PresentationPredictor(object):
 
     @property
     def max_alleles(self):
-        max_alleles = self.networks[0].hyperparameters['max_alleles']
+        max_alleles = self.models[0].hyperparameters['max_alleles']
         assert all(
             n.hyperparameters['max_alleles'] == self.max_alleles
-            for n in self.networks)
+            for n in self.models)
         return max_alleles
 
     @staticmethod
@@ -153,7 +156,7 @@ class Class1PresentationPredictor(object):
         score_array = []
         affinity_array = []
 
-        for (i, network) in enumerate(self.networks):
+        for (i, network) in enumerate(self.models):
             predictions = network.predict(
                 peptides=peptides,
                 allele_encoding=alleles,
@@ -191,24 +194,6 @@ class Class1PresentationPredictor(object):
                     numpy.percentile(affinity_array[:, :, i], 5.0, axis=0))
         return result_df
 
-    @staticmethod
-    def save_weights(weights_list, filename):
-        """
-        Save the model weights to the given filename using numpy's ".npz"
-        format.
-
-        Parameters
-        ----------
-        weights_list : list of array
-
-        filename : string
-            Should end in ".npz".
-
-        """
-        numpy.savez(
-            filename,
-            **dict((("array_%d" % i), w) for (i, w) in enumerate(weights_list)))
-
     def check_consistency(self):
         """
         Verify that self.manifest_df is consistent with instance variables.
@@ -217,10 +202,10 @@ class Class1PresentationPredictor(object):
 
         Throws AssertionError if inconsistent.
         """
-        assert len(self.manifest_df) == len(self.networks), (
+        assert len(self.manifest_df) == len(self.models), (
             "Manifest seems out of sync with models: %d vs %d entries: \n%s"% (
                 len(self.manifest_df),
-                len(self.networks),
+                len(self.models),
                 str(self.manifest_df)))
 
     def save(self, models_dir, model_names_to_write=None, write_metadata=True):
@@ -301,8 +286,8 @@ class Class1PresentationPredictor(object):
                 join(models_dir, "allele_sequences.csv"), index=False)
             logging.info("Wrote: %s", join(models_dir, "allele_sequences.csv"))
 
-    @staticmethod
-    def load(models_dir=None, max_models=None):
+    @classmethod
+    def load(cls, models_dir=None, max_models=None):
         """
         Deserialize a predictor from a directory on disk.
 
@@ -317,35 +302,24 @@ class Class1PresentationPredictor(object):
 
         Returns
         -------
-        `Class1AffinityPredictor` instance
+        `Class1PresentationPredictor` instance
         """
         if models_dir is None:
-            models_dir = get_default_class1_models_dir()
+            models_dir = get_default_class1_presentation_models_dir()
 
         manifest_path = join(models_dir, "manifest.csv")
         manifest_df = pandas.read_csv(manifest_path, nrows=max_models)
 
-        allele_to_allele_specific_models = collections.defaultdict(list)
-        class1_pan_allele_models = []
-        all_models = []
+        models = []
         for (_, row) in manifest_df.iterrows():
-            weights_filename = Class1AffinityPredictor.weights_path(
-                models_dir, row.model_name)
+            weights_filename = cls.weights_path(models_dir, row.model_name)
             config = json.loads(row.config_json)
-
-            # We will lazy-load weights when the network is used.
-            model = Class1NeuralNetwork.from_config(
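+            # Unlike Class1AffinityPredictor.load, weights are loaded eagerly
+            # here rather than through a lazy weights_loader callable.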
+            model = Class1PresentationNeuralNetwork.from_config(
                 config,
-                weights_loader=partial(
-                    Class1AffinityPredictor.load_weights,
-                    abspath(weights_filename)))
-            if row.allele == "pan-class1":
-                class1_pan_allele_models.append(model)
-            else:
-                allele_to_allele_specific_models[row.allele].append(model)
-            all_models.append(model)
+                weights=load_weights(abspath(weights_filename)))
+            models.append(model)
 
-        manifest_df["model"] = all_models
+        manifest_df["model"] = models
 
         # Load allele sequences
         allele_to_sequence = None
@@ -354,40 +328,9 @@ class Class1PresentationPredictor(object):
                 join(models_dir, "allele_sequences.csv"),
                 index_col=0).iloc[:, 0].to_dict()
 
-        allele_to_percent_rank_transform = {}
-        percent_ranks_path = join(models_dir, "percent_ranks.csv")
-        if exists(percent_ranks_path):
-            percent_ranks_df = pandas.read_csv(percent_ranks_path, index_col=0)
-            for allele in percent_ranks_df.columns:
-                allele_to_percent_rank_transform[allele] = (
-                    PercentRankTransform.from_series(percent_ranks_df[allele]))
-
-        logging.info(
-            "Loaded %d class1 pan allele predictors, %d allele sequences, "
-            "%d percent rank distributions, and %d allele specific models: %s",
-            len(class1_pan_allele_models),
-            len(allele_to_sequence) if allele_to_sequence else 0,
-            len(allele_to_percent_rank_transform),
-            sum(len(v) for v in allele_to_allele_specific_models.values()),
-            ", ".join(
-                "%s (%d)" % (allele, len(v))
-                for (allele, v)
-                in sorted(allele_to_allele_specific_models.items())))
-
-        result = Class1AffinityPredictor(
-            allele_to_allele_specific_models=allele_to_allele_specific_models,
-            class1_pan_allele_models=class1_pan_allele_models,
+        logging.info("Loaded %d class1 presentation models", len(models))
+        result = cls(
+            models=models,
             allele_to_sequence=allele_to_sequence,
-            manifest_df=manifest_df,
-            allele_to_percent_rank_transform=allele_to_percent_rank_transform,
-        )
-        if optimization_level >= 1:
-            optimized = result.optimize()
-            logging.info(
-                "Model optimization %s",
-                "succeeded" if optimized else "not supported for these models")
+            manifest_df=manifest_df)
         return result
-
-
-
-    # TODO: implement saving and loading
\ No newline at end of file
diff --git a/mhcflurry/common.py b/mhcflurry/common.py
index 8885637b..f51696ea 100644
--- a/mhcflurry/common.py
+++ b/mhcflurry/common.py
@@ -174,3 +174,35 @@ def positional_frequency_matrix(peptides):
     result = (counts / len(peptides)).fillna(0.0).T
     result.index.name = 'position'
     return result
+
+
+def save_weights(weights_list, filename):
+    """
+    Save model weights to the given filename using numpy's ".npz" format.
+
+    Parameters
+    ----------
+    weights_list : list of numpy array
+
+    filename : string
+        Should end in ".npz".
+    """
+    numpy.savez(
+        filename,
+        **dict((("array_%d" % i), w) for (i, w) in enumerate(weights_list)))
+
+
+def load_weights(filename):
+    """
+    Restore model weights from the given filename, which should have been
+    created with `save_weights`.
+
+    Parameters
+    ----------
+    filename : string
+        Should end in ".npz".
+
+    Returns
+    -------
+    list of array
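+
+    Examples
+    --------
+    Illustrative round trip; the ".npz" path below is an assumed scratch
+    location:
+
+    >>> weights = [numpy.zeros(3), numpy.ones((2, 2))]
+    >>> save_weights(weights, "/tmp/example_weights.npz")
+    >>> restored = load_weights("/tmp/example_weights.npz")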
+    """
+    with numpy.load(filename) as loaded:
+        weights = [loaded["array_%d" % i] for i in range(len(loaded.keys()))]
+    return weights
diff --git a/mhcflurry/downloads.py b/mhcflurry/downloads.py
index 7f29ea52..c3eca7cd 100644
--- a/mhcflurry/downloads.py
+++ b/mhcflurry/downloads.py
@@ -28,6 +28,8 @@ _CURRENT_RELEASE = None
 _METADATA = None
 _MHCFLURRY_DEFAULT_CLASS1_MODELS_DIR = environ.get(
     "MHCFLURRY_DEFAULT_CLASS1_MODELS")
+_MHCFLURRY_DEFAULT_CLASS1_PRESENTATION_MODELS_DIR = environ.get(
+    "MHCFLURRY_DEFAULT_CLASS1_PRESENTATION_MODELS_DIR")
 
 
 def get_downloads_dir():
@@ -84,6 +86,38 @@ def get_default_class1_models_dir(test_exists=True):
     return get_path("models_class1", "models", test_exists=test_exists)
 
 
+def get_default_class1_presentation_models_dir(test_exists=True):
+    """
+    Return the absolute path to the default class1 presentation models dir.
+
+    See `get_default_class1_models_dir`.
+
+    If environment variable MHCFLURRY_DEFAULT_CLASS1_PRESENTATION_MODELS is set
+    to an absolute path, return that path. If it's set to a relative path (does
+    not start with /) then return that path taken to be relative to the mhcflurry
+    downloads dir.
+
+    Parameters
+    ----------
+
+    test_exists : boolean, optional
+        Whether to raise an exception if the path does not exist
+
+    Returns
+    -------
+    string : absolute path
+    """
+    if _MHCFLURRY_DEFAULT_CLASS1_PRESENTATION_MODELS_DIR:
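+        # join() returns the second argument unchanged when it is an absolute
+        # path, giving the absolute/relative behavior documented above.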
+        result = join(
+            get_downloads_dir(),
+            _MHCFLURRY_DEFAULT_CLASS1_PRESENTATION_MODELS_DIR)
+        if test_exists and not exists(result):
+            raise IOError("No such directory: %s" % result)
+        return result
+    return get_path(
+        "models_class1_pan_refined", "presentation", test_exists=test_exists)
+
+
 def get_current_release_downloads():
     """
     Return a dict of all available downloads in the current release.
-- 
GitLab