From 38adfb4d8f310f7f3399e4149271b1a8bc7aef82 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Tue, 28 Nov 2017 10:28:21 -0500
Subject: [PATCH] Change prediction batch_size from 32 to 4096 for better
 performance

---
 .../class1_neural_network.py                          | 11 ++++++++---
 test/test_class1_affinity_predictor.py                |  1 -
 test/test_train_allele_specific_models_command.py     |  4 ++--
 3 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/mhcflurry/class1_affinity_prediction/class1_neural_network.py b/mhcflurry/class1_affinity_prediction/class1_neural_network.py
index f37827f3..73a9670a 100644
--- a/mhcflurry/class1_affinity_prediction/class1_neural_network.py
+++ b/mhcflurry/class1_affinity_prediction/class1_neural_network.py
@@ -479,7 +479,7 @@ class Class1NeuralNetwork(object):
                         break
         self.fit_seconds = time.time() - start
 
-    def predict(self, peptides, allele_pseudosequences=None):
+    def predict(self, peptides, allele_pseudosequences=None, batch_size=4096):
         """
         Predict affinities
         
@@ -490,6 +490,9 @@ class Class1NeuralNetwork(object):
         allele_pseudosequences : EncodableSequences or list of string, optional
             Only required when this model is a pan-allele model
 
+        batch_size : int, optional
+            Batch size passed to the Keras network's predict call (default 4096)
+
         Returns
         -------
         numpy.array of nM affinity predictions 
@@ -501,8 +504,10 @@ class Class1NeuralNetwork(object):
             pseudosequences_input = self.pseudosequence_to_network_input(
                 allele_pseudosequences)
             x_dict['pseudosequence'] = pseudosequences_input
-        (predictions,) = numpy.array(
-            self.network(borrow=True).predict(x_dict), dtype="float64").T
+
+        network = self.network(borrow=True)
+        raw_predictions = network.predict(x_dict, batch_size=batch_size)
+        predictions = numpy.array(raw_predictions, dtype = "float64")[:,0]
         return to_ic50(predictions)
 
     def compile(self):
diff --git a/test/test_class1_affinity_predictor.py b/test/test_class1_affinity_predictor.py
index 3ab6277c..65f1f587 100644
--- a/test/test_class1_affinity_predictor.py
+++ b/test/test_class1_affinity_predictor.py
@@ -132,7 +132,6 @@ def test_class1_affinity_predictor_a0205_memorize_training_data():
         dense_layer_l1_regularization=0.0,
         dropout_probability=0.0)
 
-    # First test a Class1NeuralNetwork, then a Class1AffinityPredictor.
     allele = "HLA-A*02:05"
 
     df = pandas.read_csv(
diff --git a/test/test_train_allele_specific_models_command.py b/test/test_train_allele_specific_models_command.py
index f08963a1..19cde4f2 100644
--- a/test/test_train_allele_specific_models_command.py
+++ b/test/test_train_allele_specific_models_command.py
@@ -60,9 +60,9 @@ def test_run():
     args = [
         "--data", get_path("data_curated", "curated_training_data.csv.bz2"),
         "--hyperparameters", hyperparameters_filename,
-        "--min-measurements-per-allele", "9000",
+        "--allele", "HLA-A*02:01", "HLA-A*01:01", "HLA-A*03:01",
         "--out-models-dir", models_dir,
-        "--percent-rank-calibration-num-peptides-per-length", "1000",
+        "--percent-rank-calibration-num-peptides-per-length", "10000",
     ]
     print("Running with args: %s" % args)
     train_allele_specific_models_command.run(args)
-- 
GitLab