From eee968c980d2412c8b8fd3560e5c231b14c06aaf Mon Sep 17 00:00:00 2001 From: Alex Rubinsteyn <alex.rubinsteyn@gmail.com> Date: Tue, 19 Apr 2016 13:26:59 -0400 Subject: [PATCH] training script seems to work --- .../train-class1-allele-specific-models.py | 57 +++++++++++++------ 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/scripts/train-class1-allele-specific-models.py b/scripts/train-class1-allele-specific-models.py index 260674d2..7a6bd17e 100755 --- a/scripts/train-class1-allele-specific-models.py +++ b/scripts/train-class1-allele-specific-models.py @@ -80,12 +80,18 @@ parser.add_argument( help="Don't train predictors for alleles with fewer samples than this", type=int) +parser.add_argument( + "--alleles", + default=[], + nargs="+", + type=normalize_allele_name) + # add options for neural network hyperparameters parser = add_hyperparameter_arguments_to_parser(parser) if __name__ == "__main__": args = parser.parse_args() - + print(args) if not exists(args.output_dir): makedirs(args.output_dir) @@ -103,7 +109,29 @@ if __name__ == "__main__": Y_all = np.concatenate([group.Y for group in allele_groups.values()]) print("Total Dataset size = %d" % len(Y_all)) - for allele_name, allele_data in allele_groups.items(): + # if user didn't specify alleles then train models for all available alleles + alleles = args.alleles + + if not alleles: + alleles = sorted(allele_groups.keys()) + + for allele_name in alleles: + allele_name = normalize_allele_name(allele_name) + if allele_name.isdigit(): + print("Skipping allele %s" % (allele_name,)) + continue + + allele_data = allele_groups[allele_name] + X = allele_data.X_index + Y = allele_data.Y + + n_allele = len(allele_data.Y) + assert len(X) == n_allele + + print("\n=== Training predictor for %s: %d samples" % ( + allele_name, + n_allele)) + model = Class1BindingPredictor.from_hyperparameters( name=allele_name, peptide_length=9, @@ -115,12 +143,6 @@ if __name__ == "__main__": init=args.initialization, dropout_probability=args.dropout, learning_rate=args.learning_rate) - allele_name = normalize_allele_name(allele_name) - if allele_name.isdigit(): - print("Skipping allele %s" % (allele_name,)) - continue - n_allele = len(allele_data.Y) - print("%s: total count = %d" % (allele_name, n_allele)) json_filename = allele_name + ".json" json_path = join(args.output_dir, json_filename) @@ -139,19 +161,18 @@ if __name__ == "__main__": if exists(json_path): print("-- removing old model description %s" % json_path) remove(json_path) + if exists(hdf_path): print("-- removing old weights file %s" % hdf_path) remove(hdf_path) model.fit( - allele_data.X, + allele_data.X_index, allele_data.Y, - nb_epoch=args.training_epochs, - show_accuracy=True) - print("Saving model description for %s to %s" % ( - allele_name, json_path)) - with open(json_path, "w") as f: - f.write(model.to_json()) - print("Saving model weights for %s to %s" % ( - allele_name, hdf_path)) - model.save_weights(hdf_path) + n_training_epochs=args.training_epochs, + verbose=True) + + model.to_disk( + model_json_path=json_path, + weights_hdf_path=hdf_path, + overwrite=args.overwrite) -- GitLab