From 1e884d0e0e936a1d90a689a117d6e67d6ed826d1 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Mon, 9 Dec 2019 16:16:41 -0500
Subject: [PATCH] fixes

---
 docs/generate_class1_pan.py                   | 37 ++++++++++++-------
 .../models_class1_pan_refined/GENERATE.sh     | 37 +++++++++++++++++--
 .../hyperparameters.yaml                      |  5 +++
 .../make_multiallelic_training_data.py        | 18 +++++++++
 mhcflurry/downloads.yml                       |  4 ++
 mhcflurry/multiallelic_refinement_command.py  |  6 ++-
 6 files changed, 90 insertions(+), 17 deletions(-)

diff --git a/docs/generate_class1_pan.py b/docs/generate_class1_pan.py
index af81ed36..dff04069 100644
--- a/docs/generate_class1_pan.py
+++ b/docs/generate_class1_pan.py
@@ -34,6 +34,13 @@ parser.add_argument(
         "models_class1_pan", "models.no_mass_spec", test_exists=False),
     help="Class1 models. Default: %(default)s",
 )
+parser.add_argument(
+    "--class1-models-dir-refined",
+    metavar="DIR",
+    default=get_path(
+        "models_class1_pan_refined", "models.affinity", test_exists=False),
+    help="Class1 refined models. Default: %(default)s",
+)
 parser.add_argument(
     "--logo-cutoff",
     default=0.01,
@@ -80,8 +87,13 @@ def model_info(models_dir):
         join(models_dir, "length_distributions.csv.bz2"))
     frequency_matrices_df = pandas.read_csv(
         join(models_dir, "frequency_matrices.csv.bz2"))
-    train_data_df = pandas.read_csv(
-        join(models_dir, "train_data.csv.bz2"))
+    try:
+        train_data_df = pandas.read_csv(
+            join(models_dir, "train_data.csv.bz2"))
+        observations_per_allele = (
+            train_data_df.groupby("allele").peptide.nunique().to_dict())
+    except IOError:
+        observations_per_allele = None
 
     distribution = frequency_matrices_df.loc[
         (frequency_matrices_df.cutoff_fraction == 1.0), AMINO_ACIDS
@@ -91,9 +103,6 @@ def model_info(models_dir):
     normalized_frequency_matrices.loc[:, AMINO_ACIDS] = (
             normalized_frequency_matrices[AMINO_ACIDS] / distribution)
 
-    observations_per_allele = (
-        train_data_df.groupby("allele").peptide.nunique().to_dict())
-
     return {
         'length_distributions': length_distributions_df,
         'normalized_frequency_matrices': normalized_frequency_matrices,
@@ -182,6 +191,7 @@ def go(argv):
 
     predictors = [
         ("with_mass_spec", args.class1_models_dir_with_ms),
+        ("refined", args.class1_models_dir_refined),
         ("no_mass_spec", args.class1_models_dir_no_ms),
     ]
     info_per_predictor = OrderedDict()
@@ -240,14 +250,15 @@ def go(argv):
                 models_label=label)
             if not length_distribution_image_path:
                 continue
-
-            w(
-                "*" + (
-                    "With mass-spec" if label == "with_mass_spec" else "Affinities only")
-                + "*\n")
-            w("Training observations (unique peptides): %d" % (
-                info['observations_per_allele'].get(allele, 0)))
-            w("\n")
+            w("*%s*\n" % {
+                "with_mass_spec": "With mass-spec",
+                "no_mass_spec": "Affinities only",
+                "refined": "With mass-spec after multiallelic refinement",
+            }[label])
+            if info['observations_per_allele'] is not None:
+                w("Training observations (unique peptides): %d" % (
+                    info['observations_per_allele'].get(allele, 0)))
+                w("\n")
             w(image(length_distribution_image_path))
             w(image(write_logo(
                 normalized_frequency_matrices=normalized_frequency_matrices,
diff --git a/downloads-generation/models_class1_pan_refined/GENERATE.sh b/downloads-generation/models_class1_pan_refined/GENERATE.sh
index 7b1e0db5..fbb9f502 100755
--- a/downloads-generation/models_class1_pan_refined/GENERATE.sh
+++ b/downloads-generation/models_class1_pan_refined/GENERATE.sh
@@ -54,12 +54,43 @@ export PYTHONUNBUFFERED=1
 cp $SCRIPT_DIR/make_multiallelic_training_data.py .
 cp $SCRIPT_DIR/hyperparameters.yaml .
 
+MONOALLELIC_TRAIN="$(mhcflurry-downloads path models_class1_pan)/models.with_mass_spec/train_data.csv.bz2"
+
+# ********************************************************
+# First we refine a single model excluding chromosome 1.
+echo "Beginning testing run."
 time python make_multiallelic_training_data.py \
     --hits "$(mhcflurry-downloads path data_mass_spec_annotated)/annotated_ms.csv.bz2" \
     --expression "$(mhcflurry-downloads path data_curated)/rna_expression.csv.bz2" \
-    --out train.multiallelic.csv
+    --exclude-contig "1" \
+    --out train.multiallelic.no_chr1.csv
 
-MONOALLELIC_TRAIN="$(mhcflurry-downloads path models_class1_pan)/models.with_mass_spec/train_data.csv.bz2"
+time mhcflurry-multiallelic-refinement \
+    --monoallelic-data "$MONOALLELIC_TRAIN" \
+    --multiallelic-data train.multiallelic.no_chr1.csv \
+    --models-dir "$(mhcflurry-downloads path models_class1_pan)/models.with_mass_spec" \
+    --max-models 1 \
+    --hyperparameters hyperparameters.yaml \
+    --out-affinity-predictor-dir $(pwd)/test_models.no_chr1.affinity \
+    --out-presentation-predictor-dir $(pwd)/test_models.no_chr1.presentation \
+    --worker-log-dir "$SCRATCH_DIR/$DOWNLOAD_NAME" \
+    $PARALLELISM_ARGS
+
+time mhcflurry-calibrate-percentile-ranks \
+    --models-dir $(pwd)/test_models.no_chr1.affinity   \
+    --match-amino-acid-distribution-data "$MONOALLELIC_TRAIN" \
+    --motif-summary \
+    --num-peptides-per-length 100000 \
+    --allele "HLA-A*02:01" "HLA-A*02:20" "HLA-C*02:10" \
+    --verbosity 1 \
+    $PARALLELISM_ARGS
+
+# ********************************************************
+echo "Beginning production run"
+time python make_multiallelic_training_data.py \
+    --hits "$(mhcflurry-downloads path data_mass_spec_annotated)/annotated_ms.csv.bz2" \
+    --expression "$(mhcflurry-downloads path data_curated)/rna_expression.csv.bz2" \
+    --out train.multiallelic.csv
 
 ALLELE_LIST=$(bzcat "$MONOALLELIC_TRAIN" | cut -f 1 -d , | grep -v allele | uniq | sort | uniq)
 ALLELE_LIST+=$(cat train.multiallelic.csv | cut -f 7 -d , | gerp -v hla | uniq | tr ' ' '\n' | sort | uniq)
@@ -86,7 +117,7 @@ time mhcflurry-calibrate-percentile-ranks \
 
 echo "Done training."
 
-rm train.multiallelic.csv
+rm train.multiallelic.*
 
 cp $SCRIPT_ABSOLUTE_PATH .
 bzip2 -f "$LOG"
diff --git a/downloads-generation/models_class1_pan_refined/hyperparameters.yaml b/downloads-generation/models_class1_pan_refined/hyperparameters.yaml
index e8b53748..8e40e769 100644
--- a/downloads-generation/models_class1_pan_refined/hyperparameters.yaml
+++ b/downloads-generation/models_class1_pan_refined/hyperparameters.yaml
@@ -5,3 +5,8 @@ batch_generator_validation_split: 0.1
 batch_generator_batch_size: 1024
 batch_generator_affinity_fraction: 0.5
 max_epochs: 500
+random_negative_rate: 1.0
+random_negative_constant: 25
+learning_rate: 0.0001
+patience: 5
+min_delta: 0.0
\ No newline at end of file
diff --git a/downloads-generation/models_class1_pan_refined/make_multiallelic_training_data.py b/downloads-generation/models_class1_pan_refined/make_multiallelic_training_data.py
index 635fbbae..df14ce6f 100644
--- a/downloads-generation/models_class1_pan_refined/make_multiallelic_training_data.py
+++ b/downloads-generation/models_class1_pan_refined/make_multiallelic_training_data.py
@@ -31,12 +31,16 @@ parser.add_argument(
     type=int,
     default=None,
     help="If not specified will use all possible decoys")
+parser.add_argument(
+    "--exclude-contig",
+    help="Exclude entries annotated to the given contig")
 parser.add_argument(
     "--out",
     metavar="CSV",
     required=True,
     help="File to write")
 
+
 def run():
     args = parser.parse_args(sys.argv[1:])
     hit_df = pandas.read_csv(args.hits)
@@ -47,6 +51,20 @@ def run():
         (hit_df.peptide.str.len() >= 7) &
         (~hit_df.protein_ensembl.isnull())
     ]
+    if args.exclude_contig:
+        new_hit_df = hit_df.loc[
+            hit_df.protein_primary_ensembl_contig.astype(str) !=
+            args.exclude_contig
+        ]
+        print(
+            "Excluding contig",
+            args.exclude_contig,
+            "reduced dataset from",
+            len(hit_df),
+            "to",
+            len(new_hit_df))
+        hit_df = new_hit_df.copy()
+
     hit_df["alleles"] = hit_df.hla.str.split()
 
     sample_table = hit_df.drop_duplicates("sample_id").set_index("sample_id")
diff --git a/mhcflurry/downloads.yml b/mhcflurry/downloads.yml
index 24b26730..421c7e5e 100644
--- a/mhcflurry/downloads.yml
+++ b/mhcflurry/downloads.yml
@@ -29,6 +29,10 @@ releases:
                 - https://github.com/openvax/mhcflurry/releases/download/pre-1.4.0/models_class1_pan_unselected.20190924.tar.bz2.part.aa
               default: false
 
+            - name: models_class1_pan_refined
+              url: https://github.com/openvax/mhcflurry/releases/download/1.4.0/models_class1_pan_refined.20191209.tar.bz2
+              default: false
+
             - name: models_class1_pan_variants
               part_urls:
                 - https://github.com/openvax/mhcflurry/releases/download/1.4.0/models_class1_pan_variants.20191101.tar.bz2.part.aa
diff --git a/mhcflurry/multiallelic_refinement_command.py b/mhcflurry/multiallelic_refinement_command.py
index e059e112..ac302861 100644
--- a/mhcflurry/multiallelic_refinement_command.py
+++ b/mhcflurry/multiallelic_refinement_command.py
@@ -68,6 +68,10 @@ parser.add_argument(
     metavar="DIR",
     required=True,
     help="Directory to write preentation predictor")
+parser.add_argument(
+    "--max-models",
+    type=int,
+    default=None)
 parser.add_argument(
     "--verbosity",
     type=int,
@@ -103,7 +107,7 @@ def run(argv=sys.argv[1:]):
     print("Loaded monoallelic data: %s" % (str(monoallelic_df.shape)))
 
     input_predictor = Class1AffinityPredictor.load(
-        args.models_dir, optimization_level=0)
+        args.models_dir, optimization_level=0, max_models=args.max_models)
     print("Loaded: %s" % input_predictor)
 
     sample_table = multiallelic_df.drop_duplicates(
-- 
GitLab