Skip to content
Snippets Groups Projects
Commit b8e4f2ef authored by Alex Rubinsteyn's avatar Alex Rubinsteyn
Browse files

support more lengths

parent 428ab2a3
No related merge requests found
......@@ -98,7 +98,13 @@ class Mhc1BindingPredictor(object):
Difference from the paper: instead of taking the geometric mean,
we're taking the median of log-transformed IC50 values
"""
if length == 8:
assert len(peptides) > 0
if length < 8 or length > 15:
raise ValueError("Invalid peptide length: %d (%s)" % (
length, peptides[0]))
elif length == 9:
return peptides
elif length == 8:
# extend each peptide by inserting every possible amino acid
# between base-1 positions 4-8
return [
......@@ -107,27 +113,14 @@ class Mhc1BindingPredictor(object):
for i in xrange(3, 8)
for extra_amino_acid in amino_acid_letters
]
if length == 9:
return peptides
elif length == 10:
else:
# drop interior residues between base-1 positions 4-9
n_skip = length - 9
return [
peptide[:i] + peptide[i + 1:]
peptide[:i] + peptide[i + n_skip:]
for peptide in peptides
for i in range(3, 9)
]
elif length == 11:
# drop pairs of amino acids from interior residues
return [
peptide[:i] + peptide[i + 2:]
for peptide in peptides
for i in range(3, 9)
]
else:
raise ValueError(
"Only lengths 8-11 supported, can't predict %s (len=%d)" % (
peptides[0],
length))
def predict_peptides(self, peptides):
column_names = [
......
......@@ -24,6 +24,7 @@ import argparse
import numpy as np
from mhcflurry.common import normalize_allele_name
from mhcflurry.feedforward import make_network
from mhcflurry.data_helpers import load_data
from mhcflurry.class1_allele_specific_hyperparameters import (
......@@ -88,10 +89,14 @@ if __name__ == "__main__":
activation=ACTIVATION,
init=INITIALIZATION_METHOD,
dropout_probability=DROPOUT_PROBABILITY)
print("Model config: %s" % (model.get_config(),))
model.fit(X_all, Y_all, nb_epoch=N_PRETRAIN_EPOCHS)
old_weights = model.get_weights()
for allele_name, allele_data in allele_groups.items():
allele_name = allele_name.replace("/", "_").replace("*", "").replace(":", "")
allele_name = normalize_allele_name(allele_name)
if allele_name.isdigit():
print("Skipping allele %s" % (allele_name,))
continue
n_allele = len(allele_data.Y)
print("%s: total count = %d" % (allele_name, n_allele))
filename = allele_name + ".hdf"
......@@ -101,10 +106,12 @@ if __name__ == "__main__":
continue
if n_allele < 10:
print("-- too few data points, skipping")
continue
model.set_weights(old_weights)
model.fit(
allele_data.X,
allele_data.Y,
nb_epoch=N_EPOCHS,
show_accuracy=True)
print("Saving model for %s to %s" % (allele_name, path))
model.save_weights(path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment