From b8e4f2ef3082bf7fe73de340ba5a7e8147e71f83 Mon Sep 17 00:00:00 2001
From: Alex Rubinsteyn <alex.rubinsteyn@gmail.com>
Date: Wed, 1 Jul 2015 10:47:47 -0400
Subject: [PATCH] support more lengths

---
 mhcflurry/mhc1_binding_predictor.py           | 27 +++++++------------
 .../train-class1-allele-specific-models.py    |  9 ++++++-
 2 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/mhcflurry/mhc1_binding_predictor.py b/mhcflurry/mhc1_binding_predictor.py
index 58f54e92..d3e437e5 100644
--- a/mhcflurry/mhc1_binding_predictor.py
+++ b/mhcflurry/mhc1_binding_predictor.py
@@ -98,7 +98,13 @@ class Mhc1BindingPredictor(object):
         Difference from the paper: instead of taking the geometric mean,
         we're taking the median of log-transformed IC50 values
         """
-        if length == 8:
+        assert len(peptides) > 0
+        if length < 8 or length > 15:
+            raise ValueError("Invalid peptide length: %d (%s)" % (
+                length, peptides[0]))
+        elif length == 9:
+            return peptides
+        elif length == 8:
             # extend each peptide by inserting every possible amino acid
             # between base-1 positions 4-8
             return [
@@ -107,27 +113,14 @@ class Mhc1BindingPredictor(object):
                 for i in xrange(3, 8)
                 for extra_amino_acid in amino_acid_letters
             ]
-        if length == 9:
-            return peptides
-        elif length == 10:
+        else:
             # drop interior residues between base-1 positions 4-9
+            n_skip = length - 9
             return [
-                peptide[:i] + peptide[i + 1:]
+                peptide[:i] + peptide[i + n_skip:]
                 for peptide in peptides
                 for i in range(3, 9)
             ]
-        elif length == 11:
-            # drop pairs of amino acids from interior residues
-            return [
-                peptide[:i] + peptide[i + 2:]
-                for peptide in peptides
-                for i in range(3, 9)
-            ]
-        else:
-            raise ValueError(
-                "Only lengths 8-11 supported, can't predict %s (len=%d)" % (
-                    peptides[0],
-                    length))
 
     def predict_peptides(self, peptides):
         column_names = [
diff --git a/scripts/train-class1-allele-specific-models.py b/scripts/train-class1-allele-specific-models.py
index e42e8b2d..e422588c 100755
--- a/scripts/train-class1-allele-specific-models.py
+++ b/scripts/train-class1-allele-specific-models.py
@@ -24,6 +24,7 @@ import argparse
 
 import numpy as np
 
+from mhcflurry.common import normalize_allele_name
 from mhcflurry.feedforward import make_network
 from mhcflurry.data_helpers import load_data
 from mhcflurry.class1_allele_specific_hyperparameters import (
@@ -88,10 +89,14 @@ if __name__ == "__main__":
         activation=ACTIVATION,
         init=INITIALIZATION_METHOD,
         dropout_probability=DROPOUT_PROBABILITY)
+    print("Model config: %s" % (model.get_config(),))
     model.fit(X_all, Y_all, nb_epoch=N_PRETRAIN_EPOCHS)
     old_weights = model.get_weights()
     for allele_name, allele_data in allele_groups.items():
-        allele_name = allele_name.replace("/", "_").replace("*", "").replace(":", "")
+        allele_name = normalize_allele_name(allele_name)
+        if allele_name.isdigit():
+            print("Skipping allele %s" % (allele_name,))
+            continue
         n_allele = len(allele_data.Y)
         print("%s: total count = %d" % (allele_name, n_allele))
         filename = allele_name + ".hdf"
@@ -101,10 +106,12 @@ if __name__ == "__main__":
             continue
         if n_allele < 10:
             print("-- too few data points, skipping")
+            continue
         model.set_weights(old_weights)
         model.fit(
             allele_data.X,
             allele_data.Y,
             nb_epoch=N_EPOCHS,
             show_accuracy=True)
+        print("Saving model for %s to %s" % (allele_name, path))
         model.save_weights(path)
-- 
GitLab