Skip to content
Snippets Groups Projects
Commit eee968c9 authored by Alex Rubinsteyn's avatar Alex Rubinsteyn
Browse files

training script seems to work

parent 79b04065
No related merge requests found
...@@ -80,12 +80,18 @@ parser.add_argument( ...@@ -80,12 +80,18 @@ parser.add_argument(
help="Don't train predictors for alleles with fewer samples than this", help="Don't train predictors for alleles with fewer samples than this",
type=int) type=int)
parser.add_argument(
"--alleles",
default=[],
nargs="+",
type=normalize_allele_name)
# add options for neural network hyperparameters # add options for neural network hyperparameters
parser = add_hyperparameter_arguments_to_parser(parser) parser = add_hyperparameter_arguments_to_parser(parser)
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
print(args)
if not exists(args.output_dir): if not exists(args.output_dir):
makedirs(args.output_dir) makedirs(args.output_dir)
...@@ -103,7 +109,29 @@ if __name__ == "__main__": ...@@ -103,7 +109,29 @@ if __name__ == "__main__":
Y_all = np.concatenate([group.Y for group in allele_groups.values()]) Y_all = np.concatenate([group.Y for group in allele_groups.values()])
print("Total Dataset size = %d" % len(Y_all)) print("Total Dataset size = %d" % len(Y_all))
for allele_name, allele_data in allele_groups.items(): # if user didn't specify alleles then train models for all available alleles
alleles = args.alleles
if not alleles:
alleles = sorted(allele_groups.keys())
for allele_name in alleles:
allele_name = normalize_allele_name(allele_name)
if allele_name.isdigit():
print("Skipping allele %s" % (allele_name,))
continue
allele_data = allele_groups[allele_name]
X = allele_data.X_index
Y = allele_data.Y
n_allele = len(allele_data.Y)
assert len(X) == n_allele
print("\n=== Training predictor for %s: %d samples" % (
allele_name,
n_allele))
model = Class1BindingPredictor.from_hyperparameters( model = Class1BindingPredictor.from_hyperparameters(
name=allele_name, name=allele_name,
peptide_length=9, peptide_length=9,
...@@ -115,12 +143,6 @@ if __name__ == "__main__": ...@@ -115,12 +143,6 @@ if __name__ == "__main__":
init=args.initialization, init=args.initialization,
dropout_probability=args.dropout, dropout_probability=args.dropout,
learning_rate=args.learning_rate) learning_rate=args.learning_rate)
allele_name = normalize_allele_name(allele_name)
if allele_name.isdigit():
print("Skipping allele %s" % (allele_name,))
continue
n_allele = len(allele_data.Y)
print("%s: total count = %d" % (allele_name, n_allele))
json_filename = allele_name + ".json" json_filename = allele_name + ".json"
json_path = join(args.output_dir, json_filename) json_path = join(args.output_dir, json_filename)
...@@ -139,19 +161,18 @@ if __name__ == "__main__": ...@@ -139,19 +161,18 @@ if __name__ == "__main__":
if exists(json_path): if exists(json_path):
print("-- removing old model description %s" % json_path) print("-- removing old model description %s" % json_path)
remove(json_path) remove(json_path)
if exists(hdf_path): if exists(hdf_path):
print("-- removing old weights file %s" % hdf_path) print("-- removing old weights file %s" % hdf_path)
remove(hdf_path) remove(hdf_path)
model.fit( model.fit(
allele_data.X, allele_data.X_index,
allele_data.Y, allele_data.Y,
nb_epoch=args.training_epochs, n_training_epochs=args.training_epochs,
show_accuracy=True) verbose=True)
print("Saving model description for %s to %s" % (
allele_name, json_path)) model.to_disk(
with open(json_path, "w") as f: model_json_path=json_path,
f.write(model.to_json()) weights_hdf_path=hdf_path,
print("Saving model weights for %s to %s" % ( overwrite=args.overwrite)
allele_name, hdf_path))
model.save_weights(hdf_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment