From 32c5fa3e4cb30a5a10cd49d6095f97dc653b7e0d Mon Sep 17 00:00:00 2001
From: Alex Rubinsteyn <alex.rubinsteyn@gmail.com>
Date: Thu, 18 Feb 2016 14:16:13 -0500
Subject: [PATCH] added missing serialization_helpers module

---
 mhcflurry/__init__.py                         |  4 +-
 .../class1_allele_specific_hyperparameters.py |  4 +-
 mhcflurry/class1_binding_predictor.py         | 90 ++++++-------------
 mhcflurry/ensemble.py                         | 81 +++++++++++++++--
 mhcflurry/predictor_base.py                   |  8 +-
 mhcflurry/serialization_helpers.py            | 89 ++++++++++++++++++
 test/dummy_predictors.py                      | 34 +++++++
 test/test_class1_binding_predictor.py         | 36 +-------
 test/test_ensemble.py                         | 15 ++++
 9 files changed, 251 insertions(+), 110 deletions(-)
 create mode 100644 mhcflurry/serialization_helpers.py
 create mode 100644 test/dummy_predictors.py
 create mode 100644 test/test_ensemble.py

diff --git a/mhcflurry/__init__.py b/mhcflurry/__init__.py
index 4ab1d343..7c32e11b 100644
--- a/mhcflurry/__init__.py
+++ b/mhcflurry/__init__.py
@@ -19,6 +19,7 @@ from . import common
 from . import peptide_encoding
 from . import amino_acid
 from .class1_binding_predictor import Class1BindingPredictor
+from .ensemble import Ensemble
 
 __all__ = [
     "paths",
@@ -27,5 +28,6 @@ __all__ = [
     "peptide_encoding",
     "amino_acid",
     "common",
-    "Class1BindingPredictor"
+    "Class1BindingPredictor",
+    "Ensemble",
 ]
diff --git a/mhcflurry/class1_allele_specific_hyperparameters.py b/mhcflurry/class1_allele_specific_hyperparameters.py
index 420da0d2..6d2a4f2d 100644
--- a/mhcflurry/class1_allele_specific_hyperparameters.py
+++ b/mhcflurry/class1_allele_specific_hyperparameters.py
@@ -13,10 +13,10 @@
 # limitations under the License.
 
 N_PRETRAIN_EPOCHS = 5
-N_EPOCHS = 150
+N_EPOCHS = 250
 ACTIVATION = "tanh"
 INITIALIZATION_METHOD = "lecun_uniform"
 EMBEDDING_DIM = 32
 HIDDEN_LAYER_SIZE = 200
 DROPOUT_PROBABILITY = 0.25
-MAX_IC50 = 20000.0
+MAX_IC50 = 50000.0
diff --git a/mhcflurry/class1_binding_predictor.py b/mhcflurry/class1_binding_predictor.py
index 7ad212f9..1104a071 100644
--- a/mhcflurry/class1_binding_predictor.py
+++ b/mhcflurry/class1_binding_predictor.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2015. Mount Sinai School of Medicine
+# Copyright (c) 2016. Mount Sinai School of Medicine
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,32 +20,37 @@ from __future__ import (
     division,
     absolute_import,
 )
-import logging
-from os import listdir, remove
-from os.path import exists, join
 
-import json
+from os import listdir
+from os.path import exists, join
 
 import numpy as np
-from keras.models import model_from_config
-
 
 from .common import normalize_allele_name
 from .paths import CLASS1_MODEL_DIRECTORY
 from .feedforward import make_embedding_network
 from .predictor_base import PredictorBase
+from .serialization_helpers import (
+    load_keras_model_from_disk,
+    save_keras_model_to_disk
+)
 
 from .class1_allele_specific_hyperparameters import MAX_IC50
 
 _allele_predictor_cache = {}
 
 class Class1BindingPredictor(PredictorBase):
+    """
+    Allele-specific Class I MHC binding predictor which uses
+    fixed-length (9mer) index encoding for inputs and outputs
+    a value between 0 and 1 (where 1 is the strongest binder).
+    """
     def __init__(
             self,
             model,
             name=None,
             max_ic50=MAX_IC50,
-            allow_unknown_amino_acids=False,
+            allow_unknown_amino_acids=True,
             verbose=False):
         PredictorBase.__init__(
             self,
@@ -61,33 +66,24 @@ class Class1BindingPredictor(PredictorBase):
             cls,
             model_json_path,
             weights_hdf_path=None,
-            name=None,
-            max_ic50=MAX_IC50):
+            **kwargs):
         """
         Load model from stored JSON representation of network and
         (optionally) load weights from HDF5 file.
         """
-        if not exists(model_json_path):
-            raise ValueError("Model file %s (name = %s) not found" % (
-                model_json_path, name,))
-
-        with open(model_json_path, "r") as f:
-            config_dict = json.load(f)
-
-        model = model_from_config(config_dict)
-
-        if weights_hdf_path:
-            if not exists(weights_hdf_path):
-                raise ValueError(
-                    "Missing model weights file %s (name = %s)" % (
-                        weights_hdf_path, name))
-
-            model.load_weights(weights_hdf_path)
+        model = load_keras_model_from_disk(
+            model_json_path,
+            weights_hdf_path,
+            name=None)
+        return cls(model=model, **kwargs)
 
-        return cls.__init__(
-            model=model,
-            max_ic50=max_ic50,
-            name=name)
+    def to_disk(self, model_json_path, weights_hdf_path, overwrite=False):
+        save_keras_model_to_disk(
+            self.model,
+            model_json_path,
+            weights_hdf_path,
+            overwrite=overwrite,
+            name=self.name)
 
     @classmethod
     def from_hyperparameters(
@@ -330,38 +326,6 @@ class Class1BindingPredictor(PredictorBase):
                     verbose=0,
                     batch_size=batch_size)
 
-    def to_disk(self, model_json_path, weights_hdf_path, overwrite=False):
-        if exists(model_json_path) and overwrite:
-            logging.info(
-                "Removing existing model JSON file '%s'" % (
-                    model_json_path,))
-            remove(model_json_path)
-
-        if exists(model_json_path):
-            logging.warn(
-                "Model JSON file '%s' already exists" % (model_json_path,))
-        else:
-            logging.info(
-                "Saving model file %s (name=%s)" % (model_json_path, self.name))
-            with open(model_json_path, "w") as f:
-                f.write(self.model.to_json())
-
-        if exists(weights_hdf_path) and overwrite:
-            logging.info(
-                "Removing existing model weights HDF5 file '%s'" % (
-                    weights_hdf_path,))
-            remove(weights_hdf_path)
-
-        if exists(weights_hdf_path):
-            logging.warn(
-                "Model weights HDF5 file '%s' already exists" % (
-                    weights_hdf_path,))
-        else:
-            logging.info(
-                "Saving model weights HDF5 file %s (name=%s)" % (
-                    weights_hdf_path, self.name))
-            self.model.save_weights(weights_hdf_path)
-
     @classmethod
     def from_allele_name(
             cls,
@@ -417,7 +381,7 @@ class Class1BindingPredictor(PredictorBase):
     def __str__(self):
         return repr(self)
 
-    def predict_encoded(self, X):
+    def predict(self, X):
         max_expected_index = 20 if self.allow_unknown_amino_acids else 19
         assert X.max() <= max_expected_index, \
             "Got index %d in peptide encoding" % (X.max(),)
diff --git a/mhcflurry/ensemble.py b/mhcflurry/ensemble.py
index a0373607..10880ce3 100644
--- a/mhcflurry/ensemble.py
+++ b/mhcflurry/ensemble.py
@@ -12,16 +12,81 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
+from os import listdir
+from os.path import splitext, join
 
-class Ensemble(object):
-    def __init__(self, models, name=None):
-        self.name = name
-        self.models = models
+import numpy as np
 
-    @classmethod
-    def from_directory(cls, directory_path):
-        files = os.listdir(directory_path)
+from .class1_allele_specific_hyperparameters import MAX_IC50
+from .predictor_base import PredictorBase
+
+class Ensemble(PredictorBase):
+    def __init__(
+            self,
+            predictors,
+            name=None,
+            max_ic50=MAX_IC50,
+            allow_unknown_amino_acids=True,
+            verbose=False):
+        PredictorBase.__init__(
+            self,
+            name=name,
+            max_ic50=max_ic50,
+            allow_unknown_amino_acids=allow_unknown_amino_acids,
+            verbose=verbose)
+        self.predictors = predictors
 
+    @classmethod
+    def from_directory(
+            cls,
+            predictor_class,
+            directory_path,
+            name=None,
+            allow_unknown_amino_acids=True,
+            max_ic50=MAX_IC50,
+            verbose=False):
+        filenames = listdir(directory_path)
+        filename_set = set(filenames)
+        predictors = []
+        for filename in filenames:
+            prefix, ext = splitext(filename)
+            if ext == ".json":
+                weights_filename = prefix + ".hdf5"
+                if weights_filename in filename_set:
+                    json_path = join(directory_path, filename)
+                    weights_path = join(directory_path, weights_filename)
+                    predictor = predictor_class.from_disk(
+                        json_path,
+                        weights_path,
+                        name=name + ("_%d" % (len(predictors))),
+                        max_ic50=max_ic50,
+                        allow_unknown_amino_acids=allow_unknown_amino_acids,
+                        verbose=verbose)
+                    predictors.append(predictor)
+        return cls(
+            predictors,
+            name=name,
+            max_ic50=max_ic50,
+            allow_unknown_amino_acids=allow_unknown_amino_acids,
+            verbose=verbose)
 
+    def to_directory(self, directory_path, base_name=None):
+        if not base_name:
+            base_name = self.name
+        if not base_name:
+            raise ValueError("Base name for serialized models required")
+        raise ValueError("Not yet implemented")
 
+    def predict(self, X):
+        X = np.asarray(X)
+        if len(X.shape) != 2:
+            raise ValueError("Expected encoded peptides to be 2d, got %s array" % (
+                X.shape,))
+        n = len(X)
+        y_combined = np.zeros(n)
+        for predictor in self.predictors:
+            y = predictor.predict(X)
+            assert len(y) == len(y_combined)
+            y_combined += y
+        y_combined /= len(self.predictors)
+        return y_combined
diff --git a/mhcflurry/predictor_base.py b/mhcflurry/predictor_base.py
index 9939f553..4b104049 100644
--- a/mhcflurry/predictor_base.py
+++ b/mhcflurry/predictor_base.py
@@ -86,7 +86,7 @@ class PredictorBase(object):
         if any(len(peptide) != 9 for peptide in peptides):
             raise ValueError("Can only predict 9mer peptides")
         X, _ = self.encode_peptides(peptides)
-        return self.predict_encoded(X)
+        return self.predict(X)
 
     def predict_9mer_peptides_ic50(self, peptides):
         return self.log_to_ic50(self.predict_9mer_peptides(peptides))
@@ -98,8 +98,8 @@ class PredictorBase(object):
         return self.log_to_ic50(
             self.predict_peptides(peptides))
 
-    def predict_encoded(self, X):
-        raise ValueError("Not yet implemented for %s!" % (
+    def predict(self, X):
+        raise ValueError("Method 'predict' not yet implemented for %s!" % (
             self.__class__.__name__,))
 
     def predict_peptides(
@@ -118,7 +118,7 @@ class PredictorBase(object):
         # non-9mer peptides get multiple predictions, which are then combined
         # with the combine_fn argument
         multiple_predictions_dict = defaultdict(list)
-        fixed_length_predictions = self.predict_encoded(input_matrix)
+        fixed_length_predictions = self.predict(input_matrix)
         for i, yi in enumerate(fixed_length_predictions):
             original_peptide_index = original_peptide_indices[i]
             original_peptide = peptides[original_peptide_index]
diff --git a/mhcflurry/serialization_helpers.py b/mhcflurry/serialization_helpers.py
new file mode 100644
index 00000000..52914166
--- /dev/null
+++ b/mhcflurry/serialization_helpers.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2015. Mount Sinai School of Medicine
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Helper functions for serialization/deserialization of Keras models
+"""
+
+from __future__ import (
+    print_function,
+    division,
+    absolute_import,
+)
+import logging
+from os.path import exists
+from os import remove
+import json
+
+
+from keras.models import model_from_config
+
+
+def load_keras_model_from_disk(model_json_path, weights_hdf_path, name=None):
+
+    if not exists(model_json_path):
+        raise ValueError("Model file %s (name = %s) not found" % (
+            model_json_path, name,))
+
+    with open(model_json_path, "r") as f:
+        config_dict = json.load(f)
+
+    model = model_from_config(config_dict)
+
+    if weights_hdf_path:
+        if not exists(weights_hdf_path):
+            raise ValueError(
+                "Missing model weights file %s (name = %s)" % (
+                    weights_hdf_path, name))
+
+        model.load_weights(weights_hdf_path)
+    return model
+
+
+def save_keras_model_to_disk(
+        model,
+        model_json_path,
+        weights_hdf_path,
+        overwrite=False,
+        name=None):
+    if exists(model_json_path) and overwrite:
+        logging.info(
+            "Removing existing model JSON file '%s'" % (
+                model_json_path,))
+        remove(model_json_path)
+
+    if exists(model_json_path):
+        logging.warn(
+            "Model JSON file '%s' already exists" % (model_json_path,))
+    else:
+        logging.info(
+            "Saving model file %s (name=%s)" % (model_json_path, name))
+        with open(model_json_path, "w") as f:
+            f.write(model.to_json())
+
+    if exists(weights_hdf_path) and overwrite:
+        logging.info(
+            "Removing existing model weights HDF5 file '%s'" % (
+                weights_hdf_path,))
+        remove(weights_hdf_path)
+
+    if exists(weights_hdf_path):
+        logging.warn(
+            "Model weights HDF5 file '%s' already exists" % (
+                weights_hdf_path,))
+    else:
+        logging.info(
+            "Saving model weights HDF5 file %s (name=%s)" % (
+                weights_hdf_path, name))
+        model.save_weights(weights_hdf_path)
diff --git a/test/dummy_predictors.py b/test/dummy_predictors.py
new file mode 100644
index 00000000..45f7b21c
--- /dev/null
+++ b/test/dummy_predictors.py
@@ -0,0 +1,34 @@
+import numpy as np
+from mhcflurry import Class1BindingPredictor
+
+class Dummy9merIndexEncodingModel(object):
+    """
+    Dummy molde used for testing the pMHC binding predictor.
+    """
+    def __init__(self, constant_output_value=0):
+        self.constant_output_value = constant_output_value
+
+    def predict(self, X, verbose=False):
+        assert isinstance(X, np.ndarray)
+        assert len(X.shape) == 2
+        n_rows, n_cols = X.shape
+        n_cols == 9, "Expected 9mer index input input, got %d columns" % (
+            n_cols,)
+        return np.ones(n_rows, dtype=float) * self.constant_output_value
+
+always_zero_predictor_with_unknown_AAs = Class1BindingPredictor(
+    model=Dummy9merIndexEncodingModel(0),
+    allow_unknown_amino_acids=True)
+
+always_zero_predictor_without_unknown_AAs = Class1BindingPredictor(
+    model=Dummy9merIndexEncodingModel(0),
+    allow_unknown_amino_acids=False)
+
+
+always_one_predictor_with_unknown_AAs = Class1BindingPredictor(
+    model=Dummy9merIndexEncodingModel(1),
+    allow_unknown_amino_acids=True)
+
+always_one_predictor_without_unknown_AAs = Class1BindingPredictor(
+    model=Dummy9merIndexEncodingModel(1),
+    allow_unknown_amino_acids=False)
diff --git a/test/test_class1_binding_predictor.py b/test/test_class1_binding_predictor.py
index c824ce1a..950a13cd 100644
--- a/test/test_class1_binding_predictor.py
+++ b/test/test_class1_binding_predictor.py
@@ -1,25 +1,11 @@
 import numpy as np
 
-from mhcflurry import Class1BindingPredictor
-
-
-class Dummy9merIndexEncodingModel(object):
-    """
-    Dummy molde used for testing the pMHC binding predictor.
-    """
-    def predict(self, X, verbose=False):
-        assert isinstance(X, np.ndarray)
-        assert len(X.shape) == 2
-        n_rows, n_cols = X.shape
-        n_cols == 9, "Expected 9mer index input input, got %d columns" % (
-            n_cols,)
-        return np.zeros(n_rows, dtype=float)
+import dummy_predictors
+import dummy_predictors.always_zero_predictor_with_unknown_AAs as predictor
 
 
 def test_always_zero_9mer_inputs():
-    predictor = Class1BindingPredictor(
-        model=Dummy9merIndexEncodingModel(),
-        allow_unknown_amino_acids=True)
+
     test_9mer_peptides = [
         "SIISIISII",
         "AAAAAAAAA",
@@ -41,9 +27,6 @@ def test_always_zero_9mer_inputs():
 
 
 def test_always_zero_8mer_inputs():
-    predictor = Class1BindingPredictor(
-        model=Dummy9merIndexEncodingModel(),
-        allow_unknown_amino_acids=True)
     test_8mer_peptides = [
         "SIISIISI",
         "AAAAAAAA",
@@ -60,9 +43,7 @@ def test_always_zero_8mer_inputs():
 
 
 def test_always_zero_10mer_inputs():
-    predictor = Class1BindingPredictor(
-        model=Dummy9merIndexEncodingModel(),
-        allow_unknown_amino_acids=True)
+
     test_10mer_peptides = [
         "SIISIISIYY",
         "AAAAAAAAYY",
@@ -79,9 +60,6 @@ def test_always_zero_10mer_inputs():
 
 
 def test_encode_peptides_9mer():
-    predictor = Class1BindingPredictor(
-        model=Dummy9merIndexEncodingModel(),
-        allow_unknown_amino_acids=True)
     X = predictor.encode_9mer_peptides(["AAASSSYYY"])
     assert X.shape[0] == 1, X.shape
     assert X.shape[1] == 9, X.shape
@@ -94,9 +72,6 @@ def test_encode_peptides_9mer():
 
 
 def test_encode_peptides_8mer():
-    predictor = Class1BindingPredictor(
-        model=Dummy9merIndexEncodingModel(),
-        allow_unknown_amino_acids=True)
     X, indices = predictor.encode_peptides(["AAASSSYY"])
     assert len(indices) == 9
     assert (indices == 0).all()
@@ -105,9 +80,6 @@ def test_encode_peptides_8mer():
 
 
 def test_encode_peptides_10mer():
-    predictor = Class1BindingPredictor(
-        model=Dummy9merIndexEncodingModel(),
-        allow_unknown_amino_acids=True)
     X, indices = predictor.encode_peptides(["AAASSSYYFF"])
     assert len(indices) == 10
     assert (indices == 0).all()
diff --git a/test/test_ensemble.py b/test/test_ensemble.py
new file mode 100644
index 00000000..cc376da0
--- /dev/null
+++ b/test/test_ensemble.py
@@ -0,0 +1,15 @@
+
+from dummy_predictors import (
+    always_zero_predictor_with_unknown_AAs,
+    always_one_predictor_with_unknown_AAs,
+)
+from mhcflurry import Ensemble
+
+def test_ensemble_of_dummy_predictors():
+    ensemble = Ensemble([
+        always_one_predictor_with_unknown_AAs,
+        always_zero_predictor_with_unknown_AAs])
+    peptides = ["SYYFFYLLY"]
+    y = ensemble.predict_peptides(peptides)
+    assert len(y) == len(peptides)
+    assert all(yi == 0.5 for yi in y)
-- 
GitLab