From 7f1e6c99ca3a113ee66af575079bcec0fba25036 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Thu, 27 Jul 2017 12:17:02 -0400
Subject: [PATCH] Script and experiment updates

* Add more experiments to models_class1_experiments
* Add `--quantitative-only` flag to training script
* Add `--no-throw` flag to prediction script
---
 .../models_class1_experiments1/GENERATE.sh    | 13 +++++++++-
 .../hyperparameters-0local.json               |  2 +-
 .../hyperparameters-0local_noL1.json          | 26 +++++++++++++++++++
 .../hyperparameters-standard.json             |  1 +
 .../train_allele_specific_models_command.py   | 11 ++++++++
 mhcflurry/predict_command.py                  |  8 +++++-
 6 files changed, 58 insertions(+), 3 deletions(-)
 create mode 100644 downloads-generation/models_class1_experiments1/hyperparameters-0local_noL1.json
 create mode 120000 downloads-generation/models_class1_experiments1/hyperparameters-standard.json

diff --git a/downloads-generation/models_class1_experiments1/GENERATE.sh b/downloads-generation/models_class1_experiments1/GENERATE.sh
index b92b2701..989fb5ce 100755
--- a/downloads-generation/models_class1_experiments1/GENERATE.sh
+++ b/downloads-generation/models_class1_experiments1/GENERATE.sh
@@ -23,7 +23,18 @@ git status
 
 cd $SCRATCH_DIR/$DOWNLOAD_NAME
 
-for mod in 0local 1local dense16 dense64 noL1
+# Standard architecture on quantitative only
+cp $SCRIPT_DIR/hyperparameters-standard.json .
+mkdir models-standard-quantitative
+time mhcflurry-class1-train-allele-specific-models \
+    --data "$(mhcflurry-downloads path data_curated)/curated_training_data.csv.bz2" \
+    --only-quantitative \
+    --hyperparameters hyperparameters-standard.json \
+    --out-models-dir models-standard-quantitative \
+    --min-measurements-per-allele 100
+
+# Model variations on qualitative + quantitative
+for mod in 0local_noL1 0local 1local dense16 dense64 noL1 
 do
     cp $SCRIPT_DIR/hyperparameters-${mod}.json .
     mkdir models-${mod}
diff --git a/downloads-generation/models_class1_experiments1/hyperparameters-0local.json b/downloads-generation/models_class1_experiments1/hyperparameters-0local.json
index 82f5a361..b39d5b60 100644
--- a/downloads-generation/models_class1_experiments1/hyperparameters-0local.json
+++ b/downloads-generation/models_class1_experiments1/hyperparameters-0local.json
@@ -23,4 +23,4 @@
         "dense_layer_l1_regularization": 0.001,
         "dropout_probability": 0.0
     }
-]
\ No newline at end of file
+]
diff --git a/downloads-generation/models_class1_experiments1/hyperparameters-0local_noL1.json b/downloads-generation/models_class1_experiments1/hyperparameters-0local_noL1.json
new file mode 100644
index 00000000..82f5a361
--- /dev/null
+++ b/downloads-generation/models_class1_experiments1/hyperparameters-0local_noL1.json
@@ -0,0 +1,26 @@
+[
+    {
+        "n_models": 8,
+        "max_epochs": 500,
+        "patience": 10,
+        "early_stopping": true,
+        "validation_split": 0.2,
+
+        "random_negative_rate": 0.0,
+        "random_negative_constant": 25,
+
+        "use_embedding": false,
+        "kmer_size": 15,
+        "batch_normalization": false,
+        "locally_connected_layers": [],
+        "activation": "relu",
+        "output_activation": "sigmoid",
+        "layer_sizes": [
+            32
+        ],
+        "random_negative_affinity_min": 20000.0,
+        "random_negative_affinity_max": 50000.0,
+        "dense_layer_l1_regularization": 0.001,
+        "dropout_probability": 0.0
+    }
+]
\ No newline at end of file
diff --git a/downloads-generation/models_class1_experiments1/hyperparameters-standard.json b/downloads-generation/models_class1_experiments1/hyperparameters-standard.json
new file mode 120000
index 00000000..8d78d631
--- /dev/null
+++ b/downloads-generation/models_class1_experiments1/hyperparameters-standard.json
@@ -0,0 +1 @@
+../models_class1/hyperparameters.json
\ No newline at end of file
diff --git a/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py b/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py
index 212e8918..6e03489d 100644
--- a/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py
+++ b/mhcflurry/class1_affinity_prediction/train_allele_specific_models_command.py
@@ -43,6 +43,11 @@ parser.add_argument(
     metavar="N",
     default=50,
     help="Train models for alleles with >=N measurements.")
+parser.add_argument(
+    "--only-quantitative",
+    action="store_true",
+    default=False,
+    help="Use only quantitative training data")
 parser.add_argument(
     "--verbosity",
     type=int,
@@ -67,6 +72,12 @@ def run(argv=sys.argv[1:]):
     ]
     print("Subselected to 8-15mers: %s" % (str(df.shape)))
 
+    if args.only_quantitative:
+        df = df.loc[
+            df.measurement_type == "quantitative"
+        ]
+        print("Subselected to quantitative: %s" % (str(df.shape)))
+
     allele_counts = df.allele.value_counts()
 
     if args.allele:
diff --git a/mhcflurry/predict_command.py b/mhcflurry/predict_command.py
index 40cb7130..f3a4a201 100644
--- a/mhcflurry/predict_command.py
+++ b/mhcflurry/predict_command.py
@@ -105,6 +105,11 @@ input_mod_args.add_argument(
     metavar="NAME",
     default="peptide",
     help="Input column name for peptides. Default: '%(default)s'")
+input_mod_args.add_argument(
+    "--no-throw",
+    action="store_true",
+    default=False,
+    help="Return NaNs for unsupported alleles or peptides instead of raising")
 
 
 output_args = parser.add_argument_group(title="Optional output modifiers")
@@ -200,7 +205,8 @@ def run(argv=sys.argv[1:]):
     predictions = predictor.predict_to_dataframe(
         peptides=df[args.peptide_column].values,
         alleles=df[args.allele_column].values,
-        include_individual_model_predictions=args.include_individual_model_predictions)
+        include_individual_model_predictions=args.include_individual_model_predictions,
+        throw=not args.no_throw)
 
     for col in predictions.columns:
         if col not in ("allele", "peptide"):
-- 
GitLab