From 32c5fa3e4cb30a5a10cd49d6095f97dc653b7e0d Mon Sep 17 00:00:00 2001 From: Alex Rubinsteyn <alex.rubinsteyn@gmail.com> Date: Thu, 18 Feb 2016 14:16:13 -0500 Subject: [PATCH] added missing serialization_helpers module --- mhcflurry/__init__.py | 4 +- .../class1_allele_specific_hyperparameters.py | 4 +- mhcflurry/class1_binding_predictor.py | 90 ++++++------------- mhcflurry/ensemble.py | 81 +++++++++++++++-- mhcflurry/predictor_base.py | 8 +- mhcflurry/serialization_helpers.py | 89 ++++++++++++++++++ test/dummy_predictors.py | 34 +++++++ test/test_class1_binding_predictor.py | 36 +------- test/test_ensemble.py | 15 ++++ 9 files changed, 251 insertions(+), 110 deletions(-) create mode 100644 mhcflurry/serialization_helpers.py create mode 100644 test/dummy_predictors.py create mode 100644 test/test_ensemble.py diff --git a/mhcflurry/__init__.py b/mhcflurry/__init__.py index 4ab1d343..7c32e11b 100644 --- a/mhcflurry/__init__.py +++ b/mhcflurry/__init__.py @@ -19,6 +19,7 @@ from . import common from . import peptide_encoding from . import amino_acid from .class1_binding_predictor import Class1BindingPredictor +from .ensemble import Ensemble __all__ = [ "paths", @@ -27,5 +28,6 @@ __all__ = [ "peptide_encoding", "amino_acid", "common", - "Class1BindingPredictor" + "Class1BindingPredictor", + "Ensemble", ] diff --git a/mhcflurry/class1_allele_specific_hyperparameters.py b/mhcflurry/class1_allele_specific_hyperparameters.py index 420da0d2..6d2a4f2d 100644 --- a/mhcflurry/class1_allele_specific_hyperparameters.py +++ b/mhcflurry/class1_allele_specific_hyperparameters.py @@ -13,10 +13,10 @@ # limitations under the License. N_PRETRAIN_EPOCHS = 5 -N_EPOCHS = 150 +N_EPOCHS = 250 ACTIVATION = "tanh" INITIALIZATION_METHOD = "lecun_uniform" EMBEDDING_DIM = 32 HIDDEN_LAYER_SIZE = 200 DROPOUT_PROBABILITY = 0.25 -MAX_IC50 = 20000.0 +MAX_IC50 = 50000.0 diff --git a/mhcflurry/class1_binding_predictor.py b/mhcflurry/class1_binding_predictor.py index 7ad212f9..1104a071 100644 --- a/mhcflurry/class1_binding_predictor.py +++ b/mhcflurry/class1_binding_predictor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2015. Mount Sinai School of Medicine +# Copyright (c) 2016. Mount Sinai School of Medicine # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,32 +20,37 @@ from __future__ import ( division, absolute_import, ) -import logging -from os import listdir, remove -from os.path import exists, join -import json +from os import listdir +from os.path import exists, join import numpy as np -from keras.models import model_from_config - from .common import normalize_allele_name from .paths import CLASS1_MODEL_DIRECTORY from .feedforward import make_embedding_network from .predictor_base import PredictorBase +from .serialization_helpers import ( + load_keras_model_from_disk, + save_keras_model_to_disk +) from .class1_allele_specific_hyperparameters import MAX_IC50 _allele_predictor_cache = {} class Class1BindingPredictor(PredictorBase): + """ + Allele-specific Class I MHC binding predictor which uses + fixed-length (9mer) index encoding for inputs and outputs + a value between 0 and 1 (where 1 is the strongest binder). + """ def __init__( self, model, name=None, max_ic50=MAX_IC50, - allow_unknown_amino_acids=False, + allow_unknown_amino_acids=True, verbose=False): PredictorBase.__init__( self, @@ -61,33 +66,24 @@ class Class1BindingPredictor(PredictorBase): cls, model_json_path, weights_hdf_path=None, - name=None, - max_ic50=MAX_IC50): + **kwargs): """ Load model from stored JSON representation of network and (optionally) load weights from HDF5 file. """ - if not exists(model_json_path): - raise ValueError("Model file %s (name = %s) not found" % ( - model_json_path, name,)) - - with open(model_json_path, "r") as f: - config_dict = json.load(f) - - model = model_from_config(config_dict) - - if weights_hdf_path: - if not exists(weights_hdf_path): - raise ValueError( - "Missing model weights file %s (name = %s)" % ( - weights_hdf_path, name)) - - model.load_weights(weights_hdf_path) + model = load_keras_model_from_disk( + model_json_path, + weights_hdf_path, + name=None) + return cls(model=model, **kwargs) - return cls.__init__( - model=model, - max_ic50=max_ic50, - name=name) + def to_disk(self, model_json_path, weights_hdf_path, overwrite=False): + save_keras_model_to_disk( + self.model, + model_json_path, + weights_hdf_path, + overwrite=overwrite, + name=self.name) @classmethod def from_hyperparameters( @@ -330,38 +326,6 @@ class Class1BindingPredictor(PredictorBase): verbose=0, batch_size=batch_size) - def to_disk(self, model_json_path, weights_hdf_path, overwrite=False): - if exists(model_json_path) and overwrite: - logging.info( - "Removing existing model JSON file '%s'" % ( - model_json_path,)) - remove(model_json_path) - - if exists(model_json_path): - logging.warn( - "Model JSON file '%s' already exists" % (model_json_path,)) - else: - logging.info( - "Saving model file %s (name=%s)" % (model_json_path, self.name)) - with open(model_json_path, "w") as f: - f.write(self.model.to_json()) - - if exists(weights_hdf_path) and overwrite: - logging.info( - "Removing existing model weights HDF5 file '%s'" % ( - weights_hdf_path,)) - remove(weights_hdf_path) - - if exists(weights_hdf_path): - logging.warn( - "Model weights HDF5 file '%s' already exists" % ( - weights_hdf_path,)) - else: - logging.info( - "Saving model weights HDF5 file %s (name=%s)" % ( - weights_hdf_path, self.name)) - self.model.save_weights(weights_hdf_path) - @classmethod def from_allele_name( cls, @@ -417,7 +381,7 @@ class Class1BindingPredictor(PredictorBase): def __str__(self): return repr(self) - def predict_encoded(self, X): + def predict(self, X): max_expected_index = 20 if self.allow_unknown_amino_acids else 19 assert X.max() <= max_expected_index, \ "Got index %d in peptide encoding" % (X.max(),) diff --git a/mhcflurry/ensemble.py b/mhcflurry/ensemble.py index a0373607..10880ce3 100644 --- a/mhcflurry/ensemble.py +++ b/mhcflurry/ensemble.py @@ -12,16 +12,81 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os +from os import listdir +from os.path import splitext, join -class Ensemble(object): - def __init__(self, models, name=None): - self.name = name - self.models = models +import numpy as np - @classmethod - def from_directory(cls, directory_path): - files = os.listdir(directory_path) +from .class1_allele_specific_hyperparameters import MAX_IC50 +from .predictor_base import PredictorBase + +class Ensemble(PredictorBase): + def __init__( + self, + predictors, + name=None, + max_ic50=MAX_IC50, + allow_unknown_amino_acids=True, + verbose=False): + PredictorBase.__init__( + self, + name=name, + max_ic50=max_ic50, + allow_unknown_amino_acids=allow_unknown_amino_acids, + verbose=verbose) + self.predictors = predictors + @classmethod + def from_directory( + cls, + predictor_class, + directory_path, + name=None, + allow_unknown_amino_acids=True, + max_ic50=MAX_IC50, + verbose=False): + filenames = listdir(directory_path) + filename_set = set(filenames) + predictors = [] + for filename in filenames: + prefix, ext = splitext(filename) + if ext == ".json": + weights_filename = prefix + ".hdf5" + if weights_filename in filename_set: + json_path = join(directory_path, filename) + weights_path = join(directory_path, weights_filename) + predictor = predictor_class.from_disk( + json_path, + weights_path, + name=name + ("_%d" % (len(predictors))), + max_ic50=max_ic50, + allow_unknown_amino_acids=allow_unknown_amino_acids, + verbose=verbose) + predictors.append(predictor) + return cls( + predictors, + name=name, + max_ic50=max_ic50, + allow_unknown_amino_acids=allow_unknown_amino_acids, + verbose=verbose) + def to_directory(self, directory_path, base_name=None): + if not base_name: + base_name = self.name + if not base_name: + raise ValueError("Base name for serialized models required") + raise ValueError("Not yet implemented") + def predict(self, X): + X = np.asarray(X) + if len(X.shape) != 2: + raise ValueError("Expected encoded peptides to be 2d, got %s array" % ( + X.shape,)) + n = len(X) + y_combined = np.zeros(n) + for predictor in self.predictors: + y = predictor.predict(X) + assert len(y) == len(y_combined) + y_combined += y + y_combined /= len(self.predictors) + return y_combined diff --git a/mhcflurry/predictor_base.py b/mhcflurry/predictor_base.py index 9939f553..4b104049 100644 --- a/mhcflurry/predictor_base.py +++ b/mhcflurry/predictor_base.py @@ -86,7 +86,7 @@ class PredictorBase(object): if any(len(peptide) != 9 for peptide in peptides): raise ValueError("Can only predict 9mer peptides") X, _ = self.encode_peptides(peptides) - return self.predict_encoded(X) + return self.predict(X) def predict_9mer_peptides_ic50(self, peptides): return self.log_to_ic50(self.predict_9mer_peptides(peptides)) @@ -98,8 +98,8 @@ class PredictorBase(object): return self.log_to_ic50( self.predict_peptides(peptides)) - def predict_encoded(self, X): - raise ValueError("Not yet implemented for %s!" % ( + def predict(self, X): + raise ValueError("Method 'predict' not yet implemented for %s!" % ( self.__class__.__name__,)) def predict_peptides( @@ -118,7 +118,7 @@ class PredictorBase(object): # non-9mer peptides get multiple predictions, which are then combined # with the combine_fn argument multiple_predictions_dict = defaultdict(list) - fixed_length_predictions = self.predict_encoded(input_matrix) + fixed_length_predictions = self.predict(input_matrix) for i, yi in enumerate(fixed_length_predictions): original_peptide_index = original_peptide_indices[i] original_peptide = peptides[original_peptide_index] diff --git a/mhcflurry/serialization_helpers.py b/mhcflurry/serialization_helpers.py new file mode 100644 index 00000000..52914166 --- /dev/null +++ b/mhcflurry/serialization_helpers.py @@ -0,0 +1,89 @@ +# Copyright (c) 2015. Mount Sinai School of Medicine +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Helper functions for serialization/deserialization of Keras models +""" + +from __future__ import ( + print_function, + division, + absolute_import, +) +import logging +from os.path import exists +from os import remove +import json + + +from keras.models import model_from_config + + +def load_keras_model_from_disk(model_json_path, weights_hdf_path, name=None): + + if not exists(model_json_path): + raise ValueError("Model file %s (name = %s) not found" % ( + model_json_path, name,)) + + with open(model_json_path, "r") as f: + config_dict = json.load(f) + + model = model_from_config(config_dict) + + if weights_hdf_path: + if not exists(weights_hdf_path): + raise ValueError( + "Missing model weights file %s (name = %s)" % ( + weights_hdf_path, name)) + + model.load_weights(weights_hdf_path) + return model + + +def save_keras_model_to_disk( + model, + model_json_path, + weights_hdf_path, + overwrite=False, + name=None): + if exists(model_json_path) and overwrite: + logging.info( + "Removing existing model JSON file '%s'" % ( + model_json_path,)) + remove(model_json_path) + + if exists(model_json_path): + logging.warn( + "Model JSON file '%s' already exists" % (model_json_path,)) + else: + logging.info( + "Saving model file %s (name=%s)" % (model_json_path, name)) + with open(model_json_path, "w") as f: + f.write(model.to_json()) + + if exists(weights_hdf_path) and overwrite: + logging.info( + "Removing existing model weights HDF5 file '%s'" % ( + weights_hdf_path,)) + remove(weights_hdf_path) + + if exists(weights_hdf_path): + logging.warn( + "Model weights HDF5 file '%s' already exists" % ( + weights_hdf_path,)) + else: + logging.info( + "Saving model weights HDF5 file %s (name=%s)" % ( + weights_hdf_path, name)) + model.save_weights(weights_hdf_path) diff --git a/test/dummy_predictors.py b/test/dummy_predictors.py new file mode 100644 index 00000000..45f7b21c --- /dev/null +++ b/test/dummy_predictors.py @@ -0,0 +1,34 @@ +import numpy as np +from mhcflurry import Class1BindingPredictor + +class Dummy9merIndexEncodingModel(object): + """ + Dummy molde used for testing the pMHC binding predictor. + """ + def __init__(self, constant_output_value=0): + self.constant_output_value = constant_output_value + + def predict(self, X, verbose=False): + assert isinstance(X, np.ndarray) + assert len(X.shape) == 2 + n_rows, n_cols = X.shape + n_cols == 9, "Expected 9mer index input input, got %d columns" % ( + n_cols,) + return np.ones(n_rows, dtype=float) * self.constant_output_value + +always_zero_predictor_with_unknown_AAs = Class1BindingPredictor( + model=Dummy9merIndexEncodingModel(0), + allow_unknown_amino_acids=True) + +always_zero_predictor_without_unknown_AAs = Class1BindingPredictor( + model=Dummy9merIndexEncodingModel(0), + allow_unknown_amino_acids=False) + + +always_one_predictor_with_unknown_AAs = Class1BindingPredictor( + model=Dummy9merIndexEncodingModel(1), + allow_unknown_amino_acids=True) + +always_one_predictor_without_unknown_AAs = Class1BindingPredictor( + model=Dummy9merIndexEncodingModel(1), + allow_unknown_amino_acids=False) diff --git a/test/test_class1_binding_predictor.py b/test/test_class1_binding_predictor.py index c824ce1a..950a13cd 100644 --- a/test/test_class1_binding_predictor.py +++ b/test/test_class1_binding_predictor.py @@ -1,25 +1,11 @@ import numpy as np -from mhcflurry import Class1BindingPredictor - - -class Dummy9merIndexEncodingModel(object): - """ - Dummy molde used for testing the pMHC binding predictor. - """ - def predict(self, X, verbose=False): - assert isinstance(X, np.ndarray) - assert len(X.shape) == 2 - n_rows, n_cols = X.shape - n_cols == 9, "Expected 9mer index input input, got %d columns" % ( - n_cols,) - return np.zeros(n_rows, dtype=float) +import dummy_predictors +import dummy_predictors.always_zero_predictor_with_unknown_AAs as predictor def test_always_zero_9mer_inputs(): - predictor = Class1BindingPredictor( - model=Dummy9merIndexEncodingModel(), - allow_unknown_amino_acids=True) + test_9mer_peptides = [ "SIISIISII", "AAAAAAAAA", @@ -41,9 +27,6 @@ def test_always_zero_9mer_inputs(): def test_always_zero_8mer_inputs(): - predictor = Class1BindingPredictor( - model=Dummy9merIndexEncodingModel(), - allow_unknown_amino_acids=True) test_8mer_peptides = [ "SIISIISI", "AAAAAAAA", @@ -60,9 +43,7 @@ def test_always_zero_8mer_inputs(): def test_always_zero_10mer_inputs(): - predictor = Class1BindingPredictor( - model=Dummy9merIndexEncodingModel(), - allow_unknown_amino_acids=True) + test_10mer_peptides = [ "SIISIISIYY", "AAAAAAAAYY", @@ -79,9 +60,6 @@ def test_always_zero_10mer_inputs(): def test_encode_peptides_9mer(): - predictor = Class1BindingPredictor( - model=Dummy9merIndexEncodingModel(), - allow_unknown_amino_acids=True) X = predictor.encode_9mer_peptides(["AAASSSYYY"]) assert X.shape[0] == 1, X.shape assert X.shape[1] == 9, X.shape @@ -94,9 +72,6 @@ def test_encode_peptides_9mer(): def test_encode_peptides_8mer(): - predictor = Class1BindingPredictor( - model=Dummy9merIndexEncodingModel(), - allow_unknown_amino_acids=True) X, indices = predictor.encode_peptides(["AAASSSYY"]) assert len(indices) == 9 assert (indices == 0).all() @@ -105,9 +80,6 @@ def test_encode_peptides_8mer(): def test_encode_peptides_10mer(): - predictor = Class1BindingPredictor( - model=Dummy9merIndexEncodingModel(), - allow_unknown_amino_acids=True) X, indices = predictor.encode_peptides(["AAASSSYYFF"]) assert len(indices) == 10 assert (indices == 0).all() diff --git a/test/test_ensemble.py b/test/test_ensemble.py new file mode 100644 index 00000000..cc376da0 --- /dev/null +++ b/test/test_ensemble.py @@ -0,0 +1,15 @@ + +from dummy_predictors import ( + always_zero_predictor_with_unknown_AAs, + always_one_predictor_with_unknown_AAs, +) +from mhcflurry import Ensemble + +def test_ensemble_of_dummy_predictors(): + ensemble = Ensemble([ + always_one_predictor_with_unknown_AAs, + always_zero_predictor_with_unknown_AAs]) + peptides = ["SYYFFYLLY"] + y = ensemble.predict_peptides(peptides) + assert len(y) == len(peptides) + assert all(yi == 0.5 for yi in y) -- GitLab