Skip to content
Snippets Groups Projects
Commit b8e4f2ef authored by Alex Rubinsteyn
Browse files

support more lengths

parent 428ab2a3
No related branches found
No related tags found
No related merge requests found
......@@ -98,7 +98,13 @@ class Mhc1BindingPredictor(object):
Difference from the paper: instead of taking the geometric mean,
we're taking the median of log-transformed IC50 values
"""
if length == 8:
assert len(peptides) > 0
if length < 8 or length > 15:
raise ValueError("Invalid peptide length: %d (%s)" % (
length, peptides[0]))
elif length == 9:
return peptides
elif length == 8:
# extend each peptide by inserting every possible amino acid
# between base-1 positions 4-8
return [
......@@ -107,27 +113,14 @@ class Mhc1BindingPredictor(object):
for i in xrange(3, 8)
for extra_amino_acid in amino_acid_letters
]
if length == 9:
return peptides
elif length == 10:
else:
# drop interior residues between base-1 positions 4-9
n_skip = length - 9
return [
peptide[:i] + peptide[i + 1:]
peptide[:i] + peptide[i + n_skip:]
for peptide in peptides
for i in range(3, 9)
]
elif length == 11:
# drop pairs of amino acids from interior residues
return [
peptide[:i] + peptide[i + 2:]
for peptide in peptides
for i in range(3, 9)
]
else:
raise ValueError(
"Only lengths 8-11 supported, can't predict %s (len=%d)" % (
peptides[0],
length))
def predict_peptides(self, peptides):
column_names = [
......
......@@ -24,6 +24,7 @@ import argparse
import numpy as np
from mhcflurry.common import normalize_allele_name
from mhcflurry.feedforward import make_network
from mhcflurry.data_helpers import load_data
from mhcflurry.class1_allele_specific_hyperparameters import (
......@@ -88,10 +89,14 @@ if __name__ == "__main__":
activation=ACTIVATION,
init=INITIALIZATION_METHOD,
dropout_probability=DROPOUT_PROBABILITY)
print("Model config: %s" % (model.get_config(),))
model.fit(X_all, Y_all, nb_epoch=N_PRETRAIN_EPOCHS)
old_weights = model.get_weights()
for allele_name, allele_data in allele_groups.items():
allele_name = allele_name.replace("/", "_").replace("*", "").replace(":", "")
allele_name = normalize_allele_name(allele_name)
if allele_name.isdigit():
print("Skipping allele %s" % (allele_name,))
continue
n_allele = len(allele_data.Y)
print("%s: total count = %d" % (allele_name, n_allele))
filename = allele_name + ".hdf"
......@@ -101,10 +106,12 @@ if __name__ == "__main__":
continue
if n_allele < 10:
print("-- too few data points, skipping")
continue
model.set_weights(old_weights)
model.fit(
allele_data.X,
allele_data.Y,
nb_epoch=N_EPOCHS,
show_accuracy=True)
print("Saving model for %s to %s" % (allele_name, path))
model.save_weights(path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment