Load class I models from JSON + HDF files; share the 9mer expansion helper

* Add MAX_IC50 to the allele-specific hyperparameters and use it as the
  default max_ic50 of Mhc1BindingPredictor.
* Move the peptide expansion logic out of Mhc1BindingPredictor into
  mhcflurry.common.expand_9mer_peptides.
* Build each model from <allele>.json (architecture) and then load
  <allele>.hdf (weights); the JSON file is the one passed to
  model_from_json.
* Peptides longer than 9 residues are reduced by deleting interior windows
  that start at base-1 positions 4-9, so every expanded peptide is a 9mer.

NOTE(review): this patch was re-flowed from a copy whose line breaks,
indentation and blank lines had been collapsed. Indentation and blank
context lines are inferred from the hunk line counts; confirm them
against the target tree before applying.

diff --git a/mhcflurry/class1_allele_specific_hyperparameters.py b/mhcflurry/class1_allele_specific_hyperparameters.py
index aaf1ae3a525ce7354af2deb66305876f7b3fea00..b4017d06f6c4db6162b32378c736b7ba13b008b0 100644
--- a/mhcflurry/class1_allele_specific_hyperparameters.py
+++ b/mhcflurry/class1_allele_specific_hyperparameters.py
@@ -19,3 +19,4 @@ INITIALIZATION_METHOD = "lecun_uniform"
 EMBEDDING_DIM = 64
 HIDDEN_LAYER_SIZE = 400
 DROPOUT_PROBABILITY = 0.25
+MAX_IC50 = 5000.0
diff --git a/mhcflurry/common.py b/mhcflurry/common.py
index 8623aa97b9ef1f75298237c15b41e3b7d72c7344..eb91f0cfeb8072c39e76aed4ad342f5b5e5094f7 100644
--- a/mhcflurry/common.py
+++ b/mhcflurry/common.py
@@ -17,6 +17,7 @@ from __future__ import (
     division,
     absolute_import,
 )
+from .amino_acid import amino_acid_letters
 
 
 def parse_int_list(s):
@@ -63,3 +64,45 @@ def split_allele_names(s):
         for part
         in s.split(",")
     ]
+
+
+def expand_9mer_peptides(peptides, length):
+    """
+    Expand non-9mer peptides using methods from
+    Accurate approximation method for prediction of class I MHC
+    affinities for peptides of length 8, 10 and 11 using prediction
+    tools trained on 9mers.
+    by Lundegaard et al.
+    http://bioinformatics.oxfordjournals.org/content/24/11/1397
+
+    `peptides` must be a non-empty list of strings that all have
+    `length` residues. The result is a flat list of 9mers: 9mers are
+    returned unchanged, every 8mer yields 5 * len(amino_acid_letters)
+    insertion variants and every longer peptide yields 6 deletion
+    variants. Raises ValueError if length < 8.
+    """
+    assert len(peptides) > 0
+    if length < 8:
+        raise ValueError("Invalid peptide length: %d (%s)" % (
+            length, peptides[0]))
+    elif length == 9:
+        return peptides
+    elif length == 8:
+        # extend each peptide by inserting every possible amino acid
+        # between base-1 positions 4-8
+        return [
+            peptide[:i] + extra_amino_acid + peptide[i:]
+            for peptide in peptides
+            for i in range(3, 8)
+            for extra_amino_acid in amino_acid_letters
+        ]
+    else:
+        # drop n_skip consecutive interior residues; the deleted window
+        # starts at base-1 positions 4-9, so the first three residues and
+        # the last residue are always kept and every result is a 9mer
+        n_skip = length - 9
+        return [
+            peptide[:i] + peptide[i + n_skip:]
+            for peptide in peptides
+            for i in range(3, 9)
+        ]
diff --git a/mhcflurry/mhc1_binding_predictor.py b/mhcflurry/mhc1_binding_predictor.py
index 983a8293ab96e85def14294b629f4e81772a1326..bd41aa458c7d7862eb03f76be59e30f3c33d7a97 100644
--- a/mhcflurry/mhc1_binding_predictor.py
+++ b/mhcflurry/mhc1_binding_predictor.py
@@ -26,18 +26,12 @@ from itertools import groupby
 
 import numpy as np
 import pandas as pd
+from keras.models import model_from_json
 
-from .amino_acid import amino_acid_letters
-from .feedforward import make_network
-from .class1_allele_specific_hyperparameters import (
-    EMBEDDING_DIM,
-    HIDDEN_LAYER_SIZE,
-    ACTIVATION,
-    INITIALIZATION_METHOD,
-    DROPOUT_PROBABILITY,
-)
+from .class1_allele_specific_hyperparameters import MAX_IC50
 from .data_helpers import index_encoding, normalize_allele_name
 from .paths import CLASS1_MODEL_DIRECTORY
+from .common import expand_9mer_peptides
 
 _allele_model_cache = {}
 
@@ -47,35 +41,35 @@ class Mhc1BindingPredictor(object):
             self,
             allele,
             model_directory=CLASS1_MODEL_DIRECTORY,
-            max_ic50=5000.0):
+            max_ic50=MAX_IC50):
         self.max_ic50 = max_ic50
         if not exists(model_directory) or len(listdir(model_directory)) == 0:
             raise ValueError(
                 "No MHC prediction models found in %s" % (model_directory,))
         original_allele_name = allele
         self.allele = normalize_allele_name(allele)
-        if self.allele in _allele_model_cache:
-            self.model = _allele_model_cache[self.allele]
-        else:
-            filename = self.allele + ".hdf"
-            path = join(model_directory, filename)
-            print("HDF path: %s" % path)
-            if not exists(path):
+        if self.allele not in _allele_model_cache:
+            json_filename = self.allele + ".json"
+            json_path = join(model_directory, json_filename)
+            if not exists(json_path):
                 raise ValueError("Unsupported allele: %s" % (
                     original_allele_name,))
-            self.model = make_network(
-                input_size=9,
-                embedding_input_dim=20,
-                embedding_output_dim=EMBEDDING_DIM,
-                layer_sizes=(HIDDEN_LAYER_SIZE,),
-                activation=ACTIVATION,
-                init=INITIALIZATION_METHOD,
-                dropout_probability=DROPOUT_PROBABILITY,
-                compile_for_training=True)
-            print("before", len(self.model.get_weights()), self.model.get_weights()[0][0])
-            self.model.load_weights(path)
-            print("after", len(self.model.get_weights()), self.model.get_weights()[0][0])
+
+            hdf_filename = self.allele + ".hdf"
+            hdf_path = join(model_directory, hdf_filename)
+
+            if not exists(hdf_path):
+                raise ValueError("Missing model weights for allele %s" % (
+                    original_allele_name,))
+
+            # the architecture comes from the JSON file, the weights from HDF
+            with open(json_path, "r") as f:
+                self.model = model_from_json(f.read())
+
+            self.model.load_weights(hdf_path)
             _allele_model_cache[self.allele] = self.model
+        else:
+            self.model = _allele_model_cache[self.allele]
 
     def __repr__(self):
         return "Mhc1BindingPredictor(allele=%s, model_directory=%s)" % (
@@ -105,42 +99,6 @@ class Mhc1BindingPredictor(object):
         log_y = self._predict_9mer_peptides(peptides)
         return self._log_to_ic50(log_y)
 
-    def _expand_peptides(self, peptides, length):
-        """
-        Expand non-9mer peptides using methods from
-        Accurate approximation method for prediction of class I MHC
-        affinities for peptides of length 8, 10 and 11 using prediction
-        tools trained on 9mers.
-        by Lundegaard et. al.
-        http://bioinformatics.oxfordjournals.org/content/24/11/1397
-
-        Difference from the paper: instead of taking the geometric mean,
-        we're taking the median of log-transformed IC50 values
-        """
-        assert len(peptides) > 0
-        if length < 8 or length > 15:
-            raise ValueError("Invalid peptide length: %d (%s)" % (
-                length, peptides[0]))
-        elif length == 9:
-            return peptides
-        elif length == 8:
-            # extend each peptide by inserting every possible amino acid
-            # between base-1 positions 4-8
-            return [
-                peptide[:i] + extra_amino_acid + peptide[i:]
-                for peptide in peptides
-                for i in range(3, 8)
-                for extra_amino_acid in amino_acid_letters
-            ]
-        else:
-            # drop interior residues between base-1 positions 4-9
-            n_skip = length - 9
-            return [
-                peptide[:i] + peptide[i + n_skip:]
-                for peptide in peptides
-                for i in range(3, 9)
-            ]
-
     def predict_peptides(self, peptides):
         column_names = [
             "Allele",
@@ -153,7 +111,7 @@ class Mhc1BindingPredictor(object):
 
         for length, group_peptides in groupby(peptides, lambda x: len(x)):
             group_peptides = list(group_peptides)
-            expanded_peptides = self._expand_peptides(group_peptides, length)
+            expanded_peptides = expand_9mer_peptides(group_peptides, length)
             n_group = len(group_peptides)
             n_expanded = len(expanded_peptides)
             expansion_factor = int(n_expanded / n_group)