diff --git a/scripts/train-class1-allele-specific-models.py b/scripts/train-class1-allele-specific-models.py
index 260674d2f43c01d79dbcac94d395311529484f23..7a6bd17e9770799aa0426aca39aeb04c72fb20f2 100755
--- a/scripts/train-class1-allele-specific-models.py
+++ b/scripts/train-class1-allele-specific-models.py
@@ -80,12 +80,18 @@ parser.add_argument(
     help="Don't train predictors for alleles with fewer samples than this",
     type=int)
 
+parser.add_argument(
+    "--alleles",
+    default=[],
+    nargs="+",
+    type=normalize_allele_name)
+
 # add options for neural network hyperparameters
 parser = add_hyperparameter_arguments_to_parser(parser)
 
 if __name__ == "__main__":
     args = parser.parse_args()
-
+    print(args)
     if not exists(args.output_dir):
         makedirs(args.output_dir)
 
@@ -103,7 +109,29 @@ if __name__ == "__main__":
     Y_all = np.concatenate([group.Y for group in allele_groups.values()])
     print("Total Dataset size = %d" % len(Y_all))
 
-    for allele_name, allele_data in allele_groups.items():
+    # if user didn't specify alleles then train models for all available alleles
+    alleles = args.alleles
+
+    if not alleles:
+        alleles = sorted(allele_groups.keys())
+
+    for allele_name in alleles:
+        allele_name = normalize_allele_name(allele_name)
+        if allele_name.isdigit():
+            print("Skipping allele %s" % (allele_name,))
+            continue
+
+        allele_data = allele_groups[allele_name]
+        X = allele_data.X_index
+        Y = allele_data.Y
+
+        n_allele = len(allele_data.Y)
+        assert len(X) == n_allele
+
+        print("\n=== Training predictor for %s: %d samples" % (
+            allele_name,
+            n_allele))
+
         model = Class1BindingPredictor.from_hyperparameters(
             name=allele_name,
             peptide_length=9,
@@ -115,12 +143,6 @@ if __name__ == "__main__":
             init=args.initialization,
             dropout_probability=args.dropout,
             learning_rate=args.learning_rate)
-        allele_name = normalize_allele_name(allele_name)
-        if allele_name.isdigit():
-            print("Skipping allele %s" % (allele_name,))
-            continue
-        n_allele = len(allele_data.Y)
-        print("%s: total count = %d" % (allele_name, n_allele))
 
         json_filename = allele_name + ".json"
         json_path = join(args.output_dir, json_filename)
@@ -139,19 +161,18 @@ if __name__ == "__main__":
         if exists(json_path):
             print("-- removing old model description %s" % json_path)
             remove(json_path)
+
         if exists(hdf_path):
             print("-- removing old weights file %s" % hdf_path)
             remove(hdf_path)
 
         model.fit(
-            allele_data.X,
+            allele_data.X_index,
             allele_data.Y,
-            nb_epoch=args.training_epochs,
-            show_accuracy=True)
-        print("Saving model description for %s to %s" % (
-            allele_name, json_path))
-        with open(json_path, "w") as f:
-            f.write(model.to_json())
-        print("Saving model weights for %s to %s" % (
-            allele_name, hdf_path))
-        model.save_weights(hdf_path)
+            n_training_epochs=args.training_epochs,
+            verbose=True)
+
+        model.to_disk(
+            model_json_path=json_path,
+            weights_hdf_path=hdf_path,
+            overwrite=args.overwrite)