Skip to content
Snippets Groups Projects
Commit 32c5fa3e authored by Alex Rubinsteyn's avatar Alex Rubinsteyn
Browse files

added missing serialization_helpers module

parent bb4701a0
No related merge requests found
......@@ -19,6 +19,7 @@ from . import common
from . import peptide_encoding
from . import amino_acid
from .class1_binding_predictor import Class1BindingPredictor
from .ensemble import Ensemble
__all__ = [
"paths",
......@@ -27,5 +28,6 @@ __all__ = [
"peptide_encoding",
"amino_acid",
"common",
"Class1BindingPredictor"
"Class1BindingPredictor",
"Ensemble",
]
......@@ -13,10 +13,10 @@
# limitations under the License.
N_PRETRAIN_EPOCHS = 5
N_EPOCHS = 150
N_EPOCHS = 250
ACTIVATION = "tanh"
INITIALIZATION_METHOD = "lecun_uniform"
EMBEDDING_DIM = 32
HIDDEN_LAYER_SIZE = 200
DROPOUT_PROBABILITY = 0.25
MAX_IC50 = 20000.0
MAX_IC50 = 50000.0
# Copyright (c) 2015. Mount Sinai School of Medicine
# Copyright (c) 2016. Mount Sinai School of Medicine
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -20,32 +20,37 @@ from __future__ import (
division,
absolute_import,
)
import logging
from os import listdir, remove
from os.path import exists, join
import json
from os import listdir
from os.path import exists, join
import numpy as np
from keras.models import model_from_config
from .common import normalize_allele_name
from .paths import CLASS1_MODEL_DIRECTORY
from .feedforward import make_embedding_network
from .predictor_base import PredictorBase
from .serialization_helpers import (
load_keras_model_from_disk,
save_keras_model_to_disk
)
from .class1_allele_specific_hyperparameters import MAX_IC50
_allele_predictor_cache = {}
class Class1BindingPredictor(PredictorBase):
"""
Allele-specific Class I MHC binding predictor which uses
fixed-length (9mer) index encoding for inputs and outputs
a value between 0 and 1 (where 1 is the strongest binder).
"""
def __init__(
self,
model,
name=None,
max_ic50=MAX_IC50,
allow_unknown_amino_acids=False,
allow_unknown_amino_acids=True,
verbose=False):
PredictorBase.__init__(
self,
......@@ -61,33 +66,24 @@ class Class1BindingPredictor(PredictorBase):
cls,
model_json_path,
weights_hdf_path=None,
name=None,
max_ic50=MAX_IC50):
**kwargs):
"""
Load model from stored JSON representation of network and
(optionally) load weights from HDF5 file.
"""
if not exists(model_json_path):
raise ValueError("Model file %s (name = %s) not found" % (
model_json_path, name,))
with open(model_json_path, "r") as f:
config_dict = json.load(f)
model = model_from_config(config_dict)
if weights_hdf_path:
if not exists(weights_hdf_path):
raise ValueError(
"Missing model weights file %s (name = %s)" % (
weights_hdf_path, name))
model.load_weights(weights_hdf_path)
model = load_keras_model_from_disk(
model_json_path,
weights_hdf_path,
name=None)
return cls(model=model, **kwargs)
return cls.__init__(
model=model,
max_ic50=max_ic50,
name=name)
def to_disk(self, model_json_path, weights_hdf_path, overwrite=False):
save_keras_model_to_disk(
self.model,
model_json_path,
weights_hdf_path,
overwrite=overwrite,
name=self.name)
@classmethod
def from_hyperparameters(
......@@ -330,38 +326,6 @@ class Class1BindingPredictor(PredictorBase):
verbose=0,
batch_size=batch_size)
def to_disk(self, model_json_path, weights_hdf_path, overwrite=False):
if exists(model_json_path) and overwrite:
logging.info(
"Removing existing model JSON file '%s'" % (
model_json_path,))
remove(model_json_path)
if exists(model_json_path):
logging.warn(
"Model JSON file '%s' already exists" % (model_json_path,))
else:
logging.info(
"Saving model file %s (name=%s)" % (model_json_path, self.name))
with open(model_json_path, "w") as f:
f.write(self.model.to_json())
if exists(weights_hdf_path) and overwrite:
logging.info(
"Removing existing model weights HDF5 file '%s'" % (
weights_hdf_path,))
remove(weights_hdf_path)
if exists(weights_hdf_path):
logging.warn(
"Model weights HDF5 file '%s' already exists" % (
weights_hdf_path,))
else:
logging.info(
"Saving model weights HDF5 file %s (name=%s)" % (
weights_hdf_path, self.name))
self.model.save_weights(weights_hdf_path)
@classmethod
def from_allele_name(
cls,
......@@ -417,7 +381,7 @@ class Class1BindingPredictor(PredictorBase):
def __str__(self):
return repr(self)
def predict_encoded(self, X):
def predict(self, X):
max_expected_index = 20 if self.allow_unknown_amino_acids else 19
assert X.max() <= max_expected_index, \
"Got index %d in peptide encoding" % (X.max(),)
......
......@@ -12,16 +12,81 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from os import listdir
from os.path import splitext, join
class Ensemble(object):
def __init__(self, models, name=None):
self.name = name
self.models = models
import numpy as np
@classmethod
def from_directory(cls, directory_path):
files = os.listdir(directory_path)
from .class1_allele_specific_hyperparameters import MAX_IC50
from .predictor_base import PredictorBase
class Ensemble(PredictorBase):
    """
    Predictor which combines several underlying predictors by taking the
    arithmetic mean of their outputs.
    """
    def __init__(
            self,
            predictors,
            name=None,
            max_ic50=MAX_IC50,
            allow_unknown_amino_acids=True,
            verbose=False):
        PredictorBase.__init__(
            self,
            name=name,
            max_ic50=max_ic50,
            allow_unknown_amino_acids=allow_unknown_amino_acids,
            verbose=verbose)
        self.predictors = predictors

    @classmethod
    def from_directory(
            cls,
            predictor_class,
            directory_path,
            name=None,
            allow_unknown_amino_acids=True,
            max_ic50=MAX_IC50,
            verbose=False):
        """
        Build an ensemble from every "<prefix>.json" file in `directory_path`
        which has a matching "<prefix>.hdf5" file next to it. Each pair is
        loaded with `predictor_class.from_disk`.

        When `name` is given, the i'th loaded predictor is named
        "<name>_<i>"; when `name` is None the predictors are left unnamed.
        """
        # listdir returns entries in arbitrary order; sorting keeps the
        # numbering of the loaded predictors reproducible between runs.
        filenames = sorted(listdir(directory_path))
        filename_set = set(filenames)
        predictors = []
        for filename in filenames:
            prefix, ext = splitext(filename)
            if ext != ".json":
                continue
            weights_filename = prefix + ".hdf5"
            if weights_filename not in filename_set:
                continue
            # `name` defaults to None, which cannot be concatenated with
            # the numeric suffix, so only derive a name when one was given.
            if name is None:
                predictor_name = None
            else:
                predictor_name = name + ("_%d" % (len(predictors)))
            predictor = predictor_class.from_disk(
                join(directory_path, filename),
                join(directory_path, weights_filename),
                name=predictor_name,
                max_ic50=max_ic50,
                allow_unknown_amino_acids=allow_unknown_amino_acids,
                verbose=verbose)
            predictors.append(predictor)
        return cls(
            predictors,
            name=name,
            max_ic50=max_ic50,
            allow_unknown_amino_acids=allow_unknown_amino_acids,
            verbose=verbose)

    def to_directory(self, directory_path, base_name=None):
        """
        Placeholder for serializing the ensemble; currently always raises
        ValueError (either because no base name is available or because
        saving is not implemented yet).
        """
        if not base_name:
            base_name = self.name
        if not base_name:
            raise ValueError("Base name for serialized models required")
        raise ValueError("Not yet implemented")

    def predict(self, X):
        """
        Return the mean of the member predictors' outputs for the 2d array
        of encoded peptides `X`.

        Raises ValueError if `X` is not two-dimensional.
        """
        X = np.asarray(X)
        if len(X.shape) != 2:
            raise ValueError("Expected encoded peptides to be 2d, got %s array" % (
                X.shape,))
        n = len(X)
        y_combined = np.zeros(n)
        for predictor in self.predictors:
            y = predictor.predict(X)
            assert len(y) == len(y_combined)
            y_combined += y
        # NOTE(review): with an empty `predictors` list this divides by zero.
        y_combined /= len(self.predictors)
        return y_combined
......@@ -86,7 +86,7 @@ class PredictorBase(object):
if any(len(peptide) != 9 for peptide in peptides):
raise ValueError("Can only predict 9mer peptides")
X, _ = self.encode_peptides(peptides)
return self.predict_encoded(X)
return self.predict(X)
def predict_9mer_peptides_ic50(self, peptides):
return self.log_to_ic50(self.predict_9mer_peptides(peptides))
......@@ -98,8 +98,8 @@ class PredictorBase(object):
return self.log_to_ic50(
self.predict_peptides(peptides))
def predict_encoded(self, X):
raise ValueError("Not yet implemented for %s!" % (
def predict(self, X):
raise ValueError("Method 'predict' not yet implemented for %s!" % (
self.__class__.__name__,))
def predict_peptides(
......@@ -118,7 +118,7 @@ class PredictorBase(object):
# non-9mer peptides get multiple predictions, which are then combined
# with the combine_fn argument
multiple_predictions_dict = defaultdict(list)
fixed_length_predictions = self.predict_encoded(input_matrix)
fixed_length_predictions = self.predict(input_matrix)
for i, yi in enumerate(fixed_length_predictions):
original_peptide_index = original_peptide_indices[i]
original_peptide = peptides[original_peptide_index]
......
# Copyright (c) 2015. Mount Sinai School of Medicine
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Helper functions for serialization/deserialization of Keras models
"""
from __future__ import (
print_function,
division,
absolute_import,
)
import logging
from os.path import exists
from os import remove
import json
from keras.models import model_from_config
def load_keras_model_from_disk(model_json_path, weights_hdf_path, name=None):
    """
    Rebuild a Keras model from the JSON config stored at `model_json_path`
    and, when `weights_hdf_path` is given, load its weights from that file.

    `name` only appears in error messages. Raises ValueError when the JSON
    file does not exist, or when a weights path is given but that file does
    not exist.
    """
    json_is_missing = not exists(model_json_path)
    if json_is_missing:
        message = "Model file %s (name = %s) not found" % (
            model_json_path, name,)
        raise ValueError(message)

    with open(model_json_path, "r") as json_file:
        model = model_from_config(json.load(json_file))

    # Weights are optional: without a path the model is returned as built.
    if not weights_hdf_path:
        return model

    if not exists(weights_hdf_path):
        raise ValueError(
            "Missing model weights file %s (name = %s)" % (
                weights_hdf_path, name))
    model.load_weights(weights_hdf_path)
    return model
def save_keras_model_to_disk(
        model,
        model_json_path,
        weights_hdf_path,
        overwrite=False,
        name=None):
    """
    Write a model's architecture (`model.to_json()`) to `model_json_path`
    and its weights (`model.save_weights`) to `weights_hdf_path`.

    The two files are handled independently. If a file already exists it is
    removed and rewritten when `overwrite` is True; otherwise it is left
    untouched and a warning is logged. `name` only appears in log messages.
    """
    if exists(model_json_path) and overwrite:
        logging.info(
            "Removing existing model JSON file '%s'", model_json_path)
        remove(model_json_path)

    if exists(model_json_path):
        # logging.warning replaces the deprecated logging.warn alias,
        # which no longer exists in recent Python versions.
        logging.warning(
            "Model JSON file '%s' already exists", model_json_path)
    else:
        logging.info(
            "Saving model file %s (name=%s)", model_json_path, name)
        with open(model_json_path, "w") as f:
            f.write(model.to_json())

    if exists(weights_hdf_path) and overwrite:
        logging.info(
            "Removing existing model weights HDF5 file '%s'",
            weights_hdf_path)
        remove(weights_hdf_path)

    if exists(weights_hdf_path):
        logging.warning(
            "Model weights HDF5 file '%s' already exists", weights_hdf_path)
    else:
        logging.info(
            "Saving model weights HDF5 file %s (name=%s)",
            weights_hdf_path, name)
        model.save_weights(weights_hdf_path)
import numpy as np
from mhcflurry import Class1BindingPredictor
class Dummy9merIndexEncodingModel(object):
    """
    Dummy model used for testing the pMHC binding predictor.

    Ignores the contents of its input and predicts the same constant value
    for every row.
    """
    def __init__(self, constant_output_value=0):
        self.constant_output_value = constant_output_value

    def predict(self, X, verbose=False):
        """
        Return a float array holding `constant_output_value` once per row of
        the 2d array `X`, which must have exactly 9 columns.
        """
        assert isinstance(X, np.ndarray)
        assert len(X.shape) == 2
        n_rows, n_cols = X.shape
        # Without the `assert` keyword this check was a bare tuple
        # expression and never fired.
        assert n_cols == 9, "Expected 9mer index input input, got %d columns" % (
            n_cols,)
        return np.ones(n_rows, dtype=float) * self.constant_output_value
# Shared predictor fixtures: each wraps a dummy model whose output is a
# constant (0 or 1), with unknown amino acids either allowed or disallowed.

# Constant output 0, unknown amino acids allowed.
always_zero_predictor_with_unknown_AAs = Class1BindingPredictor(
model=Dummy9merIndexEncodingModel(0),
allow_unknown_amino_acids=True)
# Constant output 0, unknown amino acids disallowed.
always_zero_predictor_without_unknown_AAs = Class1BindingPredictor(
model=Dummy9merIndexEncodingModel(0),
allow_unknown_amino_acids=False)
# Constant output 1, unknown amino acids allowed.
always_one_predictor_with_unknown_AAs = Class1BindingPredictor(
model=Dummy9merIndexEncodingModel(1),
allow_unknown_amino_acids=True)
# Constant output 1, unknown amino acids disallowed.
always_one_predictor_without_unknown_AAs = Class1BindingPredictor(
model=Dummy9merIndexEncodingModel(1),
allow_unknown_amino_acids=False)
import numpy as np
from mhcflurry import Class1BindingPredictor
class Dummy9merIndexEncodingModel(object):
    """
    Dummy model used for testing the pMHC binding predictor.

    Ignores the contents of its input and predicts zero for every row.
    """
    def predict(self, X, verbose=False):
        """
        Return a float array of zeros with one entry per row of the 2d
        array `X`, which must have exactly 9 columns.
        """
        assert isinstance(X, np.ndarray)
        assert len(X.shape) == 2
        n_rows, n_cols = X.shape
        # Without the `assert` keyword this check was a bare tuple
        # expression and never fired.
        assert n_cols == 9, "Expected 9mer index input input, got %d columns" % (
            n_cols,)
        return np.zeros(n_rows, dtype=float)
import dummy_predictors
# `always_zero_predictor_with_unknown_AAs` is a module-level object in
# dummy_predictors rather than a submodule, so the dotted
# `import a.b as c` form cannot load it; bind it with a from-import instead.
from dummy_predictors import always_zero_predictor_with_unknown_AAs as predictor
def test_always_zero_9mer_inputs():
predictor = Class1BindingPredictor(
model=Dummy9merIndexEncodingModel(),
allow_unknown_amino_acids=True)
test_9mer_peptides = [
"SIISIISII",
"AAAAAAAAA",
......@@ -41,9 +27,6 @@ def test_always_zero_9mer_inputs():
def test_always_zero_8mer_inputs():
predictor = Class1BindingPredictor(
model=Dummy9merIndexEncodingModel(),
allow_unknown_amino_acids=True)
test_8mer_peptides = [
"SIISIISI",
"AAAAAAAA",
......@@ -60,9 +43,7 @@ def test_always_zero_8mer_inputs():
def test_always_zero_10mer_inputs():
predictor = Class1BindingPredictor(
model=Dummy9merIndexEncodingModel(),
allow_unknown_amino_acids=True)
test_10mer_peptides = [
"SIISIISIYY",
"AAAAAAAAYY",
......@@ -79,9 +60,6 @@ def test_always_zero_10mer_inputs():
def test_encode_peptides_9mer():
predictor = Class1BindingPredictor(
model=Dummy9merIndexEncodingModel(),
allow_unknown_amino_acids=True)
X = predictor.encode_9mer_peptides(["AAASSSYYY"])
assert X.shape[0] == 1, X.shape
assert X.shape[1] == 9, X.shape
......@@ -94,9 +72,6 @@ def test_encode_peptides_9mer():
def test_encode_peptides_8mer():
predictor = Class1BindingPredictor(
model=Dummy9merIndexEncodingModel(),
allow_unknown_amino_acids=True)
X, indices = predictor.encode_peptides(["AAASSSYY"])
assert len(indices) == 9
assert (indices == 0).all()
......@@ -105,9 +80,6 @@ def test_encode_peptides_8mer():
def test_encode_peptides_10mer():
predictor = Class1BindingPredictor(
model=Dummy9merIndexEncodingModel(),
allow_unknown_amino_acids=True)
X, indices = predictor.encode_peptides(["AAASSSYYFF"])
assert len(indices) == 10
assert (indices == 0).all()
......
from dummy_predictors import (
always_zero_predictor_with_unknown_AAs,
always_one_predictor_with_unknown_AAs,
)
from mhcflurry import Ensemble
def test_ensemble_of_dummy_predictors():
    """
    An ensemble of an always-one and an always-zero predictor is expected
    to predict 0.5 for every peptide.
    """
    members = [
        always_one_predictor_with_unknown_AAs,
        always_zero_predictor_with_unknown_AAs,
    ]
    ensemble = Ensemble(members)
    peptides = ["SYYFFYLLY"]
    predictions = ensemble.predict_peptides(peptides)
    assert len(predictions) == len(peptides)
    assert all(value == 0.5 for value in predictions)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment