From cf1c077e0bebd79c082d623ae7fdf4fafc7f39ec Mon Sep 17 00:00:00 2001
From: Alex Rubinsteyn <alex.rubinsteyn@gmail.com>
Date: Tue, 28 Jul 2015 15:17:16 -0400
Subject: [PATCH] print dataset size for each allele

---
 mhcflurry/paths.py                        |  5 +++-
 scripts/create-combined-class1-dataset.py |  7 ++---
 scripts/print-class1-alleles.py           | 36 +++++++++++++++++++----
 3 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/mhcflurry/paths.py b/mhcflurry/paths.py
index 2a587612..3e58e57c 100644
--- a/mhcflurry/paths.py
+++ b/mhcflurry/paths.py
@@ -17,4 +17,7 @@ from appdirs import user_data_dir
 
 BASE_DIRECTORY = user_data_dir("mhcflurry", version="0.1")
 CLASS1_DATA_DIRECTORY = join(BASE_DIRECTORY, "class1_data")
-CLASS1_MODEL_DIRECTORY = join(BASE_DIRECTORY, "class1_models")
\ No newline at end of file
+CLASS1_MODEL_DIRECTORY = join(BASE_DIRECTORY, "class1_models")
+
+CLASS1_DATA_CSV_FILENAME = "combined_human_class1_dataset.csv"
+CLASS1_DATA_CSV_PATH = join(CLASS1_DATA_DIRECTORY, CLASS1_DATA_CSV_FILENAME)
diff --git a/scripts/create-combined-class1-dataset.py b/scripts/create-combined-class1-dataset.py
index 0395b20c..ab60e598 100755
--- a/scripts/create-combined-class1-dataset.py
+++ b/scripts/create-combined-class1-dataset.py
@@ -19,7 +19,7 @@ import argparse
 
 import pandas as pd
 
-from mhcflurry.paths import CLASS1_DATA_DIRECTORY
+from mhcflurry.paths import CLASS1_DATA_DIRECTORY, CLASS1_DATA_CSV_PATH
 
 IEDB_PICKLE_FILENAME = "iedb_human_class1_assay_datasets.pickle"
 IEDB_PICKLE_PATH = join(CLASS1_DATA_DIRECTORY, IEDB_PICKLE_FILENAME)
@@ -27,9 +27,6 @@ IEDB_PICKLE_PATH = join(CLASS1_DATA_DIRECTORY, IEDB_PICKLE_FILENAME)
 PETERS_CSV_FILENAME = "bdata.20130222.mhci.public.1.txt"
 PETERS_CSV_PATH = join(CLASS1_DATA_DIRECTORY, PETERS_CSV_FILENAME)
 
-OUTPUT_CSV_FILENAME = "combined_human_class1_dataset.csv"
-OUTPUT_CSV_PATH = join(CLASS1_DATA_DIRECTORY, OUTPUT_CSV_FILENAME)
-
 parser = argparse.ArgumentParser()
 
 parser.add_argument("--ic50-fraction-tolerance",
@@ -59,7 +56,7 @@ parser.add_argument("--netmhcpan-csv-path",
     help="Path to CSV with NetMHCpan dataset from 2013 Peters paper")
 
 parser.add_argument("--output-csv-path",
-    default=OUTPUT_CSV_PATH,
+    default=CLASS1_DATA_CSV_PATH,
     help="Path to CSV of combined assay results")
 
 parser.add_argument("--extra-dataset-csv-path",
diff --git a/scripts/print-class1-alleles.py b/scripts/print-class1-alleles.py
index 1d1b984e..bc130533 100755
--- a/scripts/print-class1-alleles.py
+++ b/scripts/print-class1-alleles.py
@@ -22,7 +22,9 @@ trained models are available
 import argparse
 import os
 
-from mhcflurry.paths import CLASS1_MODEL_DIRECTORY
+import pandas as pd
+
+from mhcflurry.paths import CLASS1_MODEL_DIRECTORY, CLASS1_DATA_CSV_PATH
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -30,17 +32,39 @@ parser.add_argument(
     default=False,
     action="store_true")
 
+parser.add_argument("--with-dataset-size",
+    default=False,
+    action="store_true")
+
+parser.add_argument("--all",
+    default=False,
+    action="store_true",
+    help="Include serotypes (like 'A2') which include multiple 4-digit types")
 
 if __name__ == "__main__":
     args = parser.parse_args()
+    if args.with_dataset_size:
+        df = pd.read_csv(CLASS1_DATA_CSV_PATH)
+        allele_sizes = {
+            allele: len(group) for (allele, group) in df.groupby("mhc")
+        }
+    else:
+        allele_sizes = None
 
     for filename in os.listdir(CLASS1_MODEL_DIRECTORY):
         allele = filename.replace(".hdf", "")
-        if len(allele) < 5:
+        if len(allele) >= 5:
+            allele = "HLA-%s*%s:%s" % (allele[0], allele[1:3], allele[3:])
+        elif args.all:
+            allele = "HLA-%s" % allele
+        else:
             # skipping serotype names like A2 or B7
             continue
-        allele = "HLA-%s*%s:%s" % (allele[0], allele[1:3], allele[3:])
+
+        line = allele
+
         if args.with_peptide_lengths:
-            print("%s\t8,9,10,11,12" % allele)
-        else:
-            print(allele)
\ No newline at end of file
+            line += "\t8,9,10,11,12"
+        if args.with_dataset_size:
+            line += "\t%d" % allele_sizes.get(allele, 0)  # 0 if not in CSV
+        print(line)
\ No newline at end of file
-- 
GitLab