diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index 1c0e50219d5d663268842a1dc2377699738ccca0..b21360cf45e3b5fb105f5d0d816a4490df1b169d 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -497,7 +497,7 @@ def train_model(
         print(alleles)
         data = data.loc[data.allele.isin(alleles)]
         assert len(data) >= pretrain_min_points, (len(data), pretrain_min_points)
-        data = (data.allele == allele).astype(int).values
+        train_rounds = (data.allele == allele).astype(int).values
     else:
         train_rounds = None
         data = data.loc[data.allele == allele]