From 716e2251f338c30b4ce44ceab032882976721332 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Fri, 13 Dec 2019 12:40:52 -0500
Subject: [PATCH] test

---
 .../models_class1_pan_refined/GENERATE.sh     |  3 ++-
 .../make_multiallelic_training_data.py        | 23 ++++++++++++++++---
 mhcflurry/downloads.yml                       |  2 +-
 3 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/downloads-generation/models_class1_pan_refined/GENERATE.sh b/downloads-generation/models_class1_pan_refined/GENERATE.sh
index a18b9389..17d63e6f 100755
--- a/downloads-generation/models_class1_pan_refined/GENERATE.sh
+++ b/downloads-generation/models_class1_pan_refined/GENERATE.sh
@@ -97,7 +97,8 @@ else
         --hits "$(mhcflurry-downloads path data_mass_spec_annotated)/annotated_ms.csv.bz2" \
         --expression "$(mhcflurry-downloads path data_curated)/rna_expression.csv.bz2" \
         --decoys-per-hit 1 \
-        --out train.multiallelic.csv
+        --out train.multiallelic.csv \
+        --alleles "HLA-A*02:01" "HLA-B*27:01" "HLA-C*07:01" "HLA-A*03:01" "HLA-B*15:01" "HLA-C*01:02"
 fi
 
 ALLELE_LIST=$(bzcat "$MONOALLELIC_TRAIN" | cut -f 1 -d , | grep -v allele | uniq | sort | uniq)
diff --git a/downloads-generation/models_class1_pan_refined/make_multiallelic_training_data.py b/downloads-generation/models_class1_pan_refined/make_multiallelic_training_data.py
index df14ce6f..998d8dd4 100644
--- a/downloads-generation/models_class1_pan_refined/make_multiallelic_training_data.py
+++ b/downloads-generation/models_class1_pan_refined/make_multiallelic_training_data.py
@@ -39,12 +39,17 @@ parser.add_argument(
     metavar="CSV",
     required=True,
     help="File to write")
-
+parser.add_argument(
+    "--alleles",
+    nargs="+",
+    help="Include only the specified alleles")
 
 def run():
     args = parser.parse_args(sys.argv[1:])
     hit_df = pandas.read_csv(args.hits)
     expression_df = pandas.read_csv(args.expression, index_col=0).fillna(0)
+    hit_df["alleles"] = hit_df.hla.str.split()
+
     hit_df = hit_df.loc[
         (hit_df.mhc_class == "I") &
         (hit_df.peptide.str.len() <= 15) &
@@ -64,8 +69,20 @@ def run():
             "to",
             len(new_hit_df))
         hit_df = new_hit_df.copy()
-
-    hit_df["alleles"] = hit_df.hla.str.split()
+    if args.alleles:
+        filter_alleles = set(args.alleles)
+        new_hit_df = hit_df.loc[
+            hit_df.alleles.isin.map(
+                lambda a: len(set(a).intersection(filter_alleles)) > 0)
+        ]
+        print(
+            "Selecting alleles",
+            args.alleles,
+            "reduced dataset from",
+            len(hit_df),
+            "to",
+            len(new_hit_df))
+        hit_df = new_hit_df.copy()
 
     sample_table = hit_df.drop_duplicates("sample_id").set_index("sample_id")
     grouped = hit_df.groupby("sample_id").nunique()
diff --git a/mhcflurry/downloads.yml b/mhcflurry/downloads.yml
index 841372dc..1200b72e 100644
--- a/mhcflurry/downloads.yml
+++ b/mhcflurry/downloads.yml
@@ -30,7 +30,7 @@ releases:
               default: false
 
             - name: models_class1_pan_refined
-              url: https://github.com/openvax/mhcflurry/releases/download/1.4.0/models_class1_pan_refined.20191212b.tar.bz2
+              url: https://github.com/openvax/mhcflurry/releases/download/1.4.0/models_class1_pan_refined.20191212c.tar.bz2
               default: false
 
             - name: models_class1_pan_variants
-- 
GitLab