diff --git a/.travis.yml b/.travis.yml
index bbd40a5af803f3b0afd5331dba69063788ae2292..870c5e3b316f9fd2705704038823a5d733ec7382 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -50,6 +50,7 @@ script:
       --embedding-size 10
       --hidden-layer-size 10
       --training-epochs 100
+      --imputation-method mice
   # run tests
   - nosetests test --with-coverage --cover-package=mhcflurry  && ./lint.sh
 after_success:
diff --git a/mhcflurry/class1_binding_predictor.py b/mhcflurry/class1_binding_predictor.py
index 28378df3b3c3a58d328546d1a1cb09cce5cd9323..e00f1c2095bf5778883364209e960d99d12ec8ce 100644
--- a/mhcflurry/class1_binding_predictor.py
+++ b/mhcflurry/class1_binding_predictor.py
@@ -131,7 +131,7 @@ class Class1BindingPredictor(PredictorBase):
             sample_weights,
             X_pretrain,
             Y_pretrain,
-            pretrain_sample_weights,
+            sample_weights_pretrain,
             verbose=False):
         """
         Make sure the shapes of given training and pre-training data
@@ -198,18 +198,18 @@ class Class1BindingPredictor(PredictorBase):
             raise ValueError("Maximum value of Y_pretrain can't be greater than 1, got %f" % (
                 Y.max()))
 
-        if pretrain_sample_weights is None:
-            pretrain_sample_weights = np.ones_like(Y_pretrain)
+        if sample_weights_pretrain is None:
+            sample_weights_pretrain = np.ones_like(Y_pretrain)
         else:
-            pretrain_sample_weights = np.asarray(pretrain_sample_weights)
+            sample_weights_pretrain = np.asarray(sample_weights_pretrain)
         if verbose:
             print("sample weights mean = %f, pretrain weights mean = %f" % (
                 sample_weights.mean(),
-                pretrain_sample_weights.mean()))
+                sample_weights_pretrain.mean()))
         X_combined = np.vstack([X_pretrain, X])
         Y_combined = np.concatenate([Y_pretrain, Y])
         combined_weights = np.concatenate([
-            pretrain_sample_weights,
+            sample_weights_pretrain,
             sample_weights,
         ])
         return X_combined, Y_combined, combined_weights, n_pretrain_samples
@@ -221,7 +221,7 @@ class Class1BindingPredictor(PredictorBase):
             sample_weights=None,
             X_pretrain=None,
             Y_pretrain=None,
-            pretrain_sample_weights=None,
+            sample_weights_pretrain=None,
             n_training_epochs=200,
             verbose=False,
             batch_size=128):
@@ -247,7 +247,7 @@ class Class1BindingPredictor(PredictorBase):
         Y_pretrain : array
             Labels for extra samples, shape
 
-        pretrain_sample_weights : array
+        sample_weights_pretrain : array
             Initial weights for the rows of X_pretrain. If not specified then
             initialized to ones.
 
@@ -259,7 +259,8 @@ class Class1BindingPredictor(PredictorBase):
         """
         X_combined, Y_combined, combined_weights, n_pretrain = \
             self._combine_training_data(
-                X, Y, sample_weights, X_pretrain, Y_pretrain, pretrain_sample_weights,
+                X, Y, sample_weights,
+                X_pretrain, Y_pretrain, sample_weights_pretrain,
                 verbose=verbose)
 
         total_pretrain_sample_weight = combined_weights[:n_pretrain].sum()
diff --git a/mhcflurry/imputation.py b/mhcflurry/imputation.py
index 8e9bb6136daa6e879faf4d49c56eff6fa2ebca92..115d66a20d4c76750fd58f27c19c46fcc782faa6 100644
--- a/mhcflurry/imputation.py
+++ b/mhcflurry/imputation.py
@@ -21,16 +21,12 @@ from collections import defaultdict
 import logging
 
 import numpy as np
-from fancyimpute.dictionary_helpers import (
-    dense_matrix_from_nested_dictionary
-)
-from fancyimpute import (
-    KNN,
-    IterativeSVD,
-    SimpleFill,
-    SoftImpute,
-    MICE
-)
+from fancyimpute.knn import KNN
+from fancyimpute.iterative_svd import IterativeSVD
+from fancyimpute.simple_fill import SimpleFill
+from fancyimpute.soft_impute import SoftImpute
+from fancyimpute.mice import MICE
+from fancyimpute.dictionary_helpers import dense_matrix_from_nested_dictionary
 
 from .data import (
     create_allele_data_from_peptide_to_ic50_dict,
@@ -143,15 +139,6 @@ def create_incomplete_dense_pMHC_matrix(
             if allele_name not in peptide_to_allele_to_affinity_dict[peptide]:
                 peptide_to_allele_to_affinity_dict[peptide][allele_name] = affinity
 
-    n_binding_values = sum(
-        len(allele_dict)
-        for allele_dict in
-        allele_data_dict.values()
-    )
-    print("Collected %d binding values for %d alleles" % (
-        n_binding_values,
-        len(peptide_to_allele_to_affinity_dict)))
-
     X, peptide_list, allele_list = \
         dense_matrix_from_nested_dictionary(peptide_to_allele_to_affinity_dict)
     _check_dense_pMHC_array(X, peptide_list, allele_list)
diff --git a/mhcflurry/serialization_helpers.py b/mhcflurry/serialization_helpers.py
index a3a732d28abdd92a96a46e4facd8be6717fac3cb..7089b65bb6d9ce8b3df9407e908dbdbfeff6f76e 100644
--- a/mhcflurry/serialization_helpers.py
+++ b/mhcflurry/serialization_helpers.py
@@ -28,7 +28,10 @@ import json
 
 from keras.models import model_from_config
 
-def load_keras_model_from_disk(model_json_path, weights_hdf_path, name=None):
+def load_keras_model_from_disk(
+        model_json_path,
+        weights_hdf_path,
+        name=None):
 
     if not exists(model_json_path):
         raise ValueError("Model file %s (name = %s) not found" % (
diff --git a/script/mhcflurry-train-class1-allele-specific-models.py b/script/mhcflurry-train-class1-allele-specific-models.py
index fddf924124f6bc3ecdf04210489b03f3f34d916f..d2755ce286165619aaa20037750fe4b9c4946fa9 100755
--- a/script/mhcflurry-train-class1-allele-specific-models.py
+++ b/script/mhcflurry-train-class1-allele-specific-models.py
@@ -41,8 +41,6 @@ from os import makedirs, remove
 from os.path import exists, join
 import argparse
 
-import numpy as np
-
 from mhcflurry.common import normalize_allele_name
 from mhcflurry.data import load_allele_datasets
 from mhcflurry.class1_binding_predictor import Class1BindingPredictor
@@ -53,7 +51,7 @@ from mhcflurry.paths import (
     CLASS1_MODEL_DIRECTORY,
     CLASS1_DATA_DIRECTORY
 )
-from mhcflurry.imputation import imputer_from_name, create_imputed_datasets
+from mhcflurry.imputation import create_imputed_datasets, imputer_from_name
 
 CSV_FILENAME = "combined_human_class1_dataset.csv"
 CSV_PATH = join(CLASS1_DATA_DIRECTORY, CSV_FILENAME)
@@ -98,7 +96,7 @@ parser.add_argument(
     "--imputation-method",
     default=None,
     choices=("mice", "knn", "softimpute", "svd", "mean"),
-    type=imputer_from_name,
+    type=lambda s: s.strip().lower(),
     help="Use the given imputation method to generate data for pre-training models")
 
 # add options for neural network hyperparameters
@@ -119,36 +117,57 @@ if __name__ == "__main__":
         sep=",",
         peptide_column_name="peptide")
 
-    # concatenate datasets from all alleles to use for pre-training of
-    # allele-specific predictors
-    X_all = np.vstack([group.X_index for group in allele_data_dict.values()])
-    Y_all = np.concatenate([group.Y for group in allele_data_dict.values()])
-    print("Total Dataset size = %d" % len(Y_all))
-
-    if args.imputation_method is not None:
-        # TODO: use imputed data for training
-        imputed_data_dict = create_imputed_datasets(
-            allele_data_dict,
-            args.imputation_method)
-
     # if user didn't specify alleles then train models for all available alleles
     alleles = args.alleles
 
     if not alleles:
         alleles = sorted(allele_data_dict.keys())
 
-    for allele_name in alleles:
-        allele_name = normalize_allele_name(allele_name)
-        if allele_name.isdigit():
-            print("Skipping allele %s" % (allele_name,))
-            continue
+    # restrict the data dictionary to only the specified alleles
+    # this also propagates to the imputation logic below, so we don't
+    # impute from other alleles
+    allele_data_dict = {
+        allele: allele_data_dict[allele]
+        for allele in alleles
+    }
+
+    if args.imputation_method is None:
+        imputer = None
+    else:
+        imputer = imputer_from_name(args.imputation_method)
+
+    if imputer is None:
+        imputed_data_dict = {}
+    else:
+        imputed_data_dict = create_imputed_datasets(
+            allele_data_dict,
+            imputer)
 
+    for allele_name in alleles:
         allele_data = allele_data_dict[allele_name]
         X = allele_data.X_index
         Y = allele_data.Y
+        weights = allele_data.weights
 
         n_allele = len(allele_data.Y)
         assert len(X) == n_allele
+        assert len(weights) == n_allele
+
+        if allele_name in imputed_data_dict:
+            imputed_data = imputed_data_dict[allele_name]
+            X_pretrain = imputed_data.X_index
+            Y_pretrain = imputed_data.Y
+            weights_pretrain = imputed_data.weights
+        else:
+            X_pretrain = None
+            Y_pretrain = None
+            weights_pretrain = None
+
+        # normalize allele name to check if it's just
+        allele_name = normalize_allele_name(allele_name)
+        if allele_name.isdigit():
+            print("Skipping allele %s" % (allele_name,))
+            continue
 
         print("\n=== Training predictor for %s: %d samples, %d unique" % (
             allele_name,
@@ -189,8 +208,12 @@ if __name__ == "__main__":
             remove(hdf_path)
 
         model.fit(
-            allele_data.X_index,
-            allele_data.Y,
+            X=allele_data.X_index,
+            Y=allele_data.Y,
+            sample_weights=weights,
+            X_pretrain=X_pretrain,
+            Y_pretrain=Y_pretrain,
+            sample_weights_pretrain=weights_pretrain,
             n_training_epochs=args.training_epochs,
             verbose=True)
 
diff --git a/test/test_imputation.py b/test/test_imputation.py
index 4302620309f1d04e42403ac68b52b0830729facf..70e136811ed2832b106834e8efd44410cd4c0769 100644
--- a/test/test_imputation.py
+++ b/test/test_imputation.py
@@ -84,13 +84,14 @@ def test_performance_improves_for_A0205_with_pretraining():
 
     predictor_with_imputation = \
         Class1BindingPredictor.from_hyperparameters(name="A0205-impute")
+
     predictor_with_imputation.fit(
         X=X_index,
         Y=Y_true,
         sample_weights=sample_weights,
         X_pretrain=X_index_imputed,
         Y_pretrain=Y_imputed,
-        pretrain_sample_weights=sample_weights_imputed,
+        sample_weights_pretrain=sample_weights_imputed,
         n_training_epochs=10)
 
     Y_pred_with_imputation = predictor_with_imputation.predict(X_index)