Skip to content
Snippets Groups Projects
Commit 6d52392f authored by Alex Rubinsteyn's avatar Alex Rubinsteyn
Browse files

adding options to training script and updating to work with new predictor class

parent 7cb4b493
No related branches found
No related tags found
No related merge requests found
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
N_PRETRAIN_EPOCHS = 5
N_EPOCHS = 250
ACTIVATION = "tanh"
INITIALIZATION_METHOD = "lecun_uniform"
......@@ -20,3 +19,51 @@ EMBEDDING_DIM = 32
HIDDEN_LAYER_SIZE = 200
DROPOUT_PROBABILITY = 0.25
MAX_IC50 = 50000.0
def add_hyperparameter_arguments_to_parser(parser):
"""
Extend an argument parser with the following options:
--training-epochs
--activation
--initialization
--embedding-size
--hidden-layer-size
--dropout
--max-ic50
"""
parser.add_argument(
"--training-epochs",
default=N_EPOCHS,
help="Number of training epochs")
parser.add_argument(
"--initialization",
default=INITIALIZATION_METHOD,
help="Initialization for neural network weights")
parser.add_argument(
"--activation",
default=ACTIVATION,
help="Activation function for neural network layers")
parser.add_argument(
"--embedding-size",
default=EMBEDDING_DIM,
help="Size of vector representations for embedding amino acids")
parser.add_argument(
"--hidden-layer-size",
default=HIDDEN_LAYER_SIZE,
help="Size of hidden neural network layer")
parser.add_argument(
"--dropout",
default=DROPOUT_PROBABILITY,
help="Dropout probability after neural network layers")
parser.add_argument(
"--max-ic50",
default=MAX_IC50,
help="Largest IC50 represented by neural network output")
return parser
......@@ -47,14 +47,7 @@ from mhcflurry.common import normalize_allele_name
from mhcflurry.feedforward import make_network
from mhcflurry.data_helpers import load_allele_datasets
from mhcflurry.class1_allele_specific_hyperparameters import (
N_PRETRAIN_EPOCHS,
N_EPOCHS,
ACTIVATION,
INITIALIZATION_METHOD,
EMBEDDING_DIM,
HIDDEN_LAYER_SIZE,
DROPOUT_PROBABILITY,
MAX_IC50
add_hyperparamer_arguments_to_parser
)
from mhcflurry.paths import (
CLASS1_MODEL_DIRECTORY,
......@@ -70,6 +63,7 @@ parser.add_argument(
default=CLASS1_MODEL_DIRECTORY,
help="Output directory for allele-specific predictor HDF weights files")
parser.add_argument(
"--overwrite",
default=False,
......@@ -87,6 +81,9 @@ parser.add_argument(
help="Don't train predictors for alleles with fewer samples than this",
type=int)
# add options for neural network hyperparameters
parser = add_hyperparamer_arguments_to_parser(parser)
if __name__ == "__main__":
args = parser.parse_args()
......@@ -97,7 +94,7 @@ if __name__ == "__main__":
args.binding_data_csv_path,
peptide_length=9,
binary_encoding=False,
max_ic50=MAX_IC50,
max_ic50=args.max_ic50,
sep=",",
peptide_column_name="peptide")
......@@ -107,18 +104,15 @@ if __name__ == "__main__":
Y_all = np.concatenate([group.Y for group in allele_groups.values()])
print("Total Dataset size = %d" % len(Y_all))
model = make_network(
input_size=9,
embedding_input_dim=20,
embedding_output_dim=EMBEDDING_DIM,
layer_sizes=(HIDDEN_LAYER_SIZE,),
activation=ACTIVATION,
init=INITIALIZATION_METHOD,
dropout_probability=DROPOUT_PROBABILITY)
print("Model config: %s" % (model.get_config(),))
model.fit(X_all, Y_all, nb_epoch=N_PRETRAIN_EPOCHS)
old_weights = model.get_weights()
for allele_name, allele_data in allele_groups.items():
model = make_network(
input_size=9,
embedding_input_dim=20,
embedding_output_dim=args.embedding_size,
layer_sizes=(args.hidden_layer_size,),
activation=args.activation,
init=args.initialization,
dropout_probability=args.dropout)
allele_name = normalize_allele_name(allele_name)
if allele_name.isdigit():
print("Skipping allele %s" % (allele_name,))
......@@ -147,11 +141,10 @@ if __name__ == "__main__":
print("-- removing old weights file %s" % hdf_path)
remove(hdf_path)
model.set_weights(old_weights)
model.fit(
allele_data.X,
allele_data.Y,
nb_epoch=N_EPOCHS,
nb_epoch=args.training_epochs,
show_accuracy=True)
print("Saving model description for %s to %s" % (
allele_name, json_path))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment