From b8e4f2ef3082bf7fe73de340ba5a7e8147e71f83 Mon Sep 17 00:00:00 2001 From: Alex Rubinsteyn <alex.rubinsteyn@gmail.com> Date: Wed, 1 Jul 2015 10:47:47 -0400 Subject: [PATCH] support more lengths --- mhcflurry/mhc1_binding_predictor.py | 27 +++++++------------ .../train-class1-allele-specific-models.py | 9 ++++++- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/mhcflurry/mhc1_binding_predictor.py b/mhcflurry/mhc1_binding_predictor.py index 58f54e92..d3e437e5 100644 --- a/mhcflurry/mhc1_binding_predictor.py +++ b/mhcflurry/mhc1_binding_predictor.py @@ -98,7 +98,13 @@ class Mhc1BindingPredictor(object): Difference from the paper: instead of taking the geometric mean, we're taking the median of log-transformed IC50 values """ - if length == 8: + assert len(peptides) > 0 + if length < 8 or length > 15: + raise ValueError("Invalid peptide length: %d (%s)" % ( + length, peptides[0])) + elif length == 9: + return peptides + elif length == 8: # extend each peptide by inserting every possible amino acid # between base-1 positions 4-8 return [ @@ -107,27 +113,14 @@ class Mhc1BindingPredictor(object): for i in xrange(3, 8) for extra_amino_acid in amino_acid_letters ] - if length == 9: - return peptides - elif length == 10: + else: # drop interior residues between base-1 positions 4-9 + n_skip = length - 9 return [ - peptide[:i] + peptide[i + 1:] + peptide[:i] + peptide[i + n_skip:] for peptide in peptides for i in range(3, 9) ] - elif length == 11: - # drop pairs of amino acids from interior residues - return [ - peptide[:i] + peptide[i + 2:] - for peptide in peptides - for i in range(3, 9) - ] - else: - raise ValueError( - "Only lengths 8-11 supported, can't predict %s (len=%d)" % ( - peptides[0], - length)) def predict_peptides(self, peptides): column_names = [ diff --git a/scripts/train-class1-allele-specific-models.py b/scripts/train-class1-allele-specific-models.py index e42e8b2d..e422588c 100755 --- a/scripts/train-class1-allele-specific-models.py +++ b/scripts/train-class1-allele-specific-models.py @@ -24,6 +24,7 @@ import argparse import numpy as np +from mhcflurry.common import normalize_allele_name from mhcflurry.feedforward import make_network from mhcflurry.data_helpers import load_data from mhcflurry.class1_allele_specific_hyperparameters import ( @@ -88,10 +89,14 @@ if __name__ == "__main__": activation=ACTIVATION, init=INITIALIZATION_METHOD, dropout_probability=DROPOUT_PROBABILITY) + print("Model config: %s" % (model.get_config(),)) model.fit(X_all, Y_all, nb_epoch=N_PRETRAIN_EPOCHS) old_weights = model.get_weights() for allele_name, allele_data in allele_groups.items(): - allele_name = allele_name.replace("/", "_").replace("*", "").replace(":", "") + allele_name = normalize_allele_name(allele_name) + if allele_name.isdigit(): + print("Skipping allele %s" % (allele_name,)) + continue n_allele = len(allele_data.Y) print("%s: total count = %d" % (allele_name, n_allele)) filename = allele_name + ".hdf" @@ -101,10 +106,12 @@ if __name__ == "__main__": continue if n_allele < 10: print("-- too few data points, skipping") + continue model.set_weights(old_weights) model.fit( allele_data.X, allele_data.Y, nb_epoch=N_EPOCHS, show_accuracy=True) + print("Saving model for %s to %s" % (allele_name, path)) model.save_weights(path) -- GitLab