diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py index 1c0e50219d5d663268842a1dc2377699738ccca0..b21360cf45e3b5fb105f5d0d816a4490df1b169d 100644 --- a/mhcflurry/train_allele_specific_models_command.py +++ b/mhcflurry/train_allele_specific_models_command.py @@ -497,7 +497,7 @@ def train_model( print(alleles) data = data.loc[data.allele.isin(alleles)] assert len(data) >= pretrain_min_points, (len(data), pretrain_min_points) - data = (data.allele == allele).astype(int).values + train_rounds = (data.allele == allele).astype(int).values else: train_rounds = None data = data.loc[data.allele == allele]