Skip to content
Snippets Groups Projects
Commit 79b04065 authored by Alex Rubinsteyn's avatar Alex Rubinsteyn
Browse files

added LR to hyperparameters

parent 6d52392f
No related merge requests found
......@@ -19,6 +19,7 @@ EMBEDDING_DIM = 32
# Default hyperparameter values consumed by add_hyperparameter_arguments_to_parser.
HIDDEN_LAYER_SIZE = 200
DROPOUT_PROBABILITY = 0.25
MAX_IC50 = 50000.0
# Added in this commit. NOTE(review): the --learning-rate option below still
# hard-codes the literal 0.001 instead of referencing this constant — confirm intent.
LEARNING_RATE = 0.001
# Registers neural-network hyperparameter CLI options on an argparse parser and
# returns the same parser.
# NOTE(review): this is a diff rendering — "......@@" lines are hunk headers that
# elide original lines (e.g. the --activation/--initialization options referenced
# by the training script are not visible here), and leading indentation was
# stripped by the page renderer.
def add_hyperparameter_arguments_to_parser(parser):
"""
......@@ -33,6 +34,7 @@ def add_hyperparameter_arguments_to_parser(parser):
"""
parser.add_argument(
"--training-epochs",
# type=int added in this commit so the epoch count is parsed as an integer.
type=int,
default=N_EPOCHS,
help="Number of training epochs")
......@@ -48,22 +50,32 @@ def add_hyperparameter_arguments_to_parser(parser):
parser.add_argument(
"--embedding-size",
type=int,
default=EMBEDDING_DIM,
help="Size of vector representations for embedding amino acids")
parser.add_argument(
"--hidden-layer-size",
type=int,
default=HIDDEN_LAYER_SIZE,
help="Size of hidden neural network layer")
parser.add_argument(
"--dropout",
type=float,
default=DROPOUT_PROBABILITY,
help="Dropout probability after neural network layers")
parser.add_argument(
"--max-ic50",
type=float,
default=MAX_IC50,
help="Largest IC50 represented by neural network output")
# New option added by this commit.
parser.add_argument(
"--learning-rate",
type=float,
# NOTE(review): literal 0.001 duplicates the LEARNING_RATE constant this
# commit introduces; presumably should be default=LEARNING_RATE — confirm.
default=0.001,
help="Learning rate for training neural network")
return parser
......@@ -44,10 +44,10 @@ import argparse
import numpy as np
from mhcflurry.common import normalize_allele_name
from mhcflurry.feedforward import make_network
from mhcflurry.data_helpers import load_allele_datasets
from mhcflurry.data import load_allele_datasets
from mhcflurry.class1_binding_predictor import Class1BindingPredictor
from mhcflurry.class1_allele_specific_hyperparameters import (
add_hyperparamer_arguments_to_parser
add_hyperparameter_arguments_to_parser
)
from mhcflurry.paths import (
CLASS1_MODEL_DIRECTORY,
......@@ -58,23 +58,22 @@ CSV_PATH = join(CLASS1_DATA_DIRECTORY, CSV_FILENAME)
# Command-line interface for the allele-specific training script.
# NOTE(review): diff rendering — removed and added lines appear side by side
# with no +/- markers; indentation was stripped by the page renderer.
parser = argparse.ArgumentParser()
parser.add_argument(
"--binding-data-csv",
default=CSV_PATH,
help="CSV file with 'mhc', 'peptide', 'peptide_length', 'meas' columns")
parser.add_argument(
"--output-dir",
default=CLASS1_MODEL_DIRECTORY,
help="Output directory for allele-specific predictor HDF weights files")
parser.add_argument(
"--overwrite",
default=False,
action="store_true",
help="Overwrite existing output directory")
# NOTE(review): the next add_argument ("--binding-data-csv-path") is the OLD
# option removed by this commit in favor of "--binding-data-csv" above; it only
# appears here because the diff view shows deleted lines inline.
parser.add_argument(
"--binding-data-csv-path",
default=CSV_PATH,
help="CSV file with 'mhc', 'peptide', 'peptide_length', 'meas' columns")
parser.add_argument(
"--min-samples-per-allele",
default=5,
......@@ -82,7 +81,7 @@ parser.add_argument(
type=int)
# add options for neural network hyperparameters
# NOTE(review): the first call below uses the misspelled old name
# ("hyperparamer"); the commit replaces it with the corrected spelling on the
# following line — old/new pair from the diff view.
parser = add_hyperparamer_arguments_to_parser(parser)
parser = add_hyperparameter_arguments_to_parser(parser)
# Script entry point: load per-allele binding data, then build one predictor
# per allele. NOTE(review): diff rendering — removed and added lines are fused
# with no +/- markers, and a "......@@" hunk header elides lines after L93.
if __name__ == "__main__":
args = parser.parse_args()
......@@ -91,28 +90,31 @@ if __name__ == "__main__":
makedirs(args.output_dir)
# NOTE(review): the next two lines are the old positional argument
# (args.binding_data_csv_path) and its replacement keyword argument
# (filename=args.binding_data_csv) — an old/new pair from the diff.
allele_groups = load_allele_datasets(
args.binding_data_csv_path,
filename=args.binding_data_csv,
peptide_length=9,
binary_encoding=False,
use_multiple_peptide_lengths=True,
max_ic50=args.max_ic50,
sep=",",
peptide_column_name="peptide")
# concatenate datasets from all alleles to use for pre-training of
# allele-specific predictors
# NOTE(review): old/new pair — the commit switches from group.X to
# group.X_index for the stacked input matrix.
X_all = np.vstack([group.X for group in allele_groups.values()])
X_all = np.vstack([group.X_index for group in allele_groups.values()])
Y_all = np.concatenate([group.Y for group in allele_groups.values()])
print("Total Dataset size = %d" % len(Y_all))
for allele_name, allele_data in allele_groups.items():
# NOTE(review): old/new pair — make_network(...) is replaced by the
# Class1BindingPredictor.from_hyperparameters(...) constructor, which now
# also receives learning_rate from the new CLI option.
model = make_network(
input_size=9,
model = Class1BindingPredictor.from_hyperparameters(
name=allele_name,
peptide_length=9,
max_ic50=args.max_ic50,
embedding_input_dim=20,
embedding_output_dim=args.embedding_size,
layer_sizes=(args.hidden_layer_size,),
activation=args.activation,
init=args.initialization,
dropout_probability=args.dropout)
dropout_probability=args.dropout,
learning_rate=args.learning_rate)
allele_name = normalize_allele_name(allele_name)
if allele_name.isdigit():
print("Skipping allele %s" % (allele_name,))
......
def test_known_class1_epitopes():
    """Placeholder: checks against known class I epitopes are not implemented yet."""
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment