Skip to content
Snippets Groups Projects
Commit 38adfb4d authored by Tim O'Donnell's avatar Tim O'Donnell
Browse files

Change prediction batch_size from 32 to 4096 for better performance

parent cbb649d6
No related merge requests found
......@@ -479,7 +479,7 @@ class Class1NeuralNetwork(object):
break
self.fit_seconds = time.time() - start
def predict(self, peptides, allele_pseudosequences=None):
def predict(self, peptides, allele_pseudosequences=None, batch_size=4096):
"""
Predict affinities
......@@ -490,6 +490,9 @@ class Class1NeuralNetwork(object):
allele_pseudosequences : EncodableSequences or list of string, optional
Only required when this model is a pan-allele model
batch_size : int
batch_size passed to Keras
Returns
-------
numpy.array of nM affinity predictions
......@@ -501,8 +504,10 @@ class Class1NeuralNetwork(object):
pseudosequences_input = self.pseudosequence_to_network_input(
allele_pseudosequences)
x_dict['pseudosequence'] = pseudosequences_input
(predictions,) = numpy.array(
self.network(borrow=True).predict(x_dict), dtype="float64").T
network = self.network(borrow=True)
raw_predictions = network.predict(x_dict, batch_size=batch_size)
predictions = numpy.array(raw_predictions, dtype = "float64")[:,0]
return to_ic50(predictions)
def compile(self):
......
......@@ -132,7 +132,6 @@ def test_class1_affinity_predictor_a0205_memorize_training_data():
dense_layer_l1_regularization=0.0,
dropout_probability=0.0)
# First test a Class1NeuralNetwork, then a Class1AffinityPredictor.
allele = "HLA-A*02:05"
df = pandas.read_csv(
......
......@@ -60,9 +60,9 @@ def test_run():
args = [
"--data", get_path("data_curated", "curated_training_data.csv.bz2"),
"--hyperparameters", hyperparameters_filename,
"--min-measurements-per-allele", "9000",
"--allele", "HLA-A*02:01", "HLA-A*01:01", "HLA-A*03:01",
"--out-models-dir", models_dir,
"--percent-rank-calibration-num-peptides-per-length", "1000",
"--percent-rank-calibration-num-peptides-per-length", "10000",
]
print("Running with args: %s" % args)
train_allele_specific_models_command.run(args)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment