From 9631e6f2e4b7f304ad1a4b7d3ee4ebe9980ce926 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Mon, 23 Sep 2019 14:03:30 -0400
Subject: [PATCH] update

---
 docs/generate_class1_pan.py                   | 75 ++++++++++---------
 ...test_released_predictors_on_hpv_dataset.py |  3 +
 2 files changed, 42 insertions(+), 36 deletions(-)

diff --git a/docs/generate_class1_pan.py b/docs/generate_class1_pan.py
index 7e70cada..af81ed36 100644
--- a/docs/generate_class1_pan.py
+++ b/docs/generate_class1_pan.py
@@ -36,13 +36,13 @@ parser.add_argument(
 )
 parser.add_argument(
     "--logo-cutoff",
-    default=0.001,
+    default=0.01,
     type=float,
     help="Fraction of top to use for motifs",
 )
 parser.add_argument(
     "--length-cutoff",
-    default=0.0001,
+    default=0.01,
     type=float,
     help="Fraction of top to use for length distribution",
 )
@@ -104,34 +104,39 @@ def model_info(models_dir):
 def write_logo(
         normalized_frequency_matrices,
         allele,
-        length,
+        lengths,
         cutoff,
         models_label,
         out_dir):
-    matrix = normalized_frequency_matrices.loc[
-        (normalized_frequency_matrices.allele == allele) &
-        (normalized_frequency_matrices.length == length) &
-        (normalized_frequency_matrices.cutoff_fraction == cutoff)
-    ].set_index("position")[AMINO_ACIDS]
-    if matrix.shape[0] == 0:
-        return None
 
-    matrix = (matrix.T / matrix.sum(1)).T  # row normalize
-
-    fig = pyplot.figure(figsize=(8,4))
-    ss_logo = logomaker.Logo(
-        matrix,
-        width=.8,
-        vpad=.05,
-        fade_probabilities=True,
-        stack_order='small_on_top',
-        # color_scheme='dodgerblue',
-        # ax=ax,
-    )
-    pyplot.title("%s %d-mer (%s)" % (allele, length, models_label))
-    pyplot.xticks(matrix.index.values)
-    name = "%s-%dmer.%s.png" % (
-        allele.replace("*", "-").replace(":", "-"), length, models_label)
+    fig = pyplot.figure(figsize=(8,10))
+
+    for (i, length) in enumerate(lengths):
+        ax = pyplot.subplot(len(lengths), 1, i + 1)
+        matrix = normalized_frequency_matrices.loc[
+            (normalized_frequency_matrices.allele == allele) &
+            (normalized_frequency_matrices.length == length) &
+            (normalized_frequency_matrices.cutoff_fraction == cutoff)
+        ].set_index("position")[AMINO_ACIDS]
+        if matrix.shape[0] == 0:
+            return None
+
+        matrix = (matrix.T / matrix.sum(1)).T  # row normalize
+
+        ss_logo = logomaker.Logo(
+            matrix,
+            width=.8,
+            vpad=.05,
+            fade_probabilities=True,
+            stack_order='small_on_top',
+            ax=ax,
+        )
+        pyplot.title(
+            "%s %d-mer (%s)" % (allele, length, models_label), y=0.85)
+        pyplot.xticks(matrix.index.values)
+    pyplot.tight_layout()
+    name = "%s.motifs.%s.png" % (
+        allele.replace("*", "-").replace(":", "-"), models_label)
     filename = os.path.abspath(join(out_dir, name))
     pyplot.savefig(filename)
     print("Wrote: ", filename)
@@ -244,16 +249,14 @@ def go(argv):
                 info['observations_per_allele'].get(allele, 0)))
             w("\n")
             w(image(length_distribution_image_path))
-
-            for length in args.motif_lengths:
-                w(image(write_logo(
-                    normalized_frequency_matrices=normalized_frequency_matrices,
-                    allele=allele,
-                    length=length,
-                    cutoff=args.logo_cutoff,
-                    out_dir=args.out_dir,
-                    models_label=label,
-                )))
+            w(image(write_logo(
+                normalized_frequency_matrices=normalized_frequency_matrices,
+                allele=allele,
+                lengths=args.motif_lengths,
+                cutoff=args.logo_cutoff,
+                out_dir=args.out_dir,
+                models_label=label,
+            )))
         w("")
 
     document_path = join(args.out_dir, "allele_motifs.rst")
diff --git a/test/test_released_predictors_on_hpv_dataset.py b/test/test_released_predictors_on_hpv_dataset.py
index 26bf0227..25be5823 100644
--- a/test/test_released_predictors_on_hpv_dataset.py
+++ b/test/test_released_predictors_on_hpv_dataset.py
@@ -61,6 +61,9 @@ def test_on_hpv(df=DF):
     scores_df = scores_df.pivot(
         index="metric", columns="predictor", values="score")
 
+    print("Predictions")
+    print(df)
+
     print("Scores")
     print(scores_df)
 
-- 
GitLab