diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py index decd4b980679d2dc6623629464c522a41237b1d3..d075167ddd7d82709aece6cdabde09c63215a78c 100644 --- a/mhcflurry/train_allele_specific_models_command.py +++ b/mhcflurry/train_allele_specific_models_command.py @@ -200,7 +200,6 @@ def run(argv=sys.argv[1:]): print("Training data: %s" % (str(df.shape))) GLOBAL_DATA["train_data"] = df - GLOBAL_DATA["train_data_size_by_allele"] = df.allele.value_counts() GLOBAL_DATA["args"] = args if not os.path.exists(args.out_models_dir): @@ -477,7 +476,18 @@ def train_model( pretrain_min_points = hyperparameters['train_data']['pretrain_min_points'] full_data = GLOBAL_DATA["train_data"] - data_size_by_allele = GLOBAL_DATA["train_data_size_by_allele"] + + subset = hyperparameters.get("train_data", {}).get("subset", "all") + if subset == "quantitative": + data = full_data.loc[ + full_data.measurement_type == "quantitative" + ] + elif subset == "all": + pass + else: + raise ValueError("Unsupported subset: %s" % subset) + + data_size_by_allele = data.allele.value_counts() if pretrain_min_points: similar_alleles = alleles_by_similarity(allele) @@ -485,22 +495,12 @@ def train_model( while not alleles or data_size_by_allele.loc[alleles].sum() < pretrain_min_points: alleles.append(similar_alleles.pop(0)) print(alleles) - data = full_data.loc[full_data.allele.isin(alleles)] + data = data.loc[data.allele.isin(alleles)] assert len(data) >= pretrain_min_points, (len(data), pretrain_min_points) - train_rounds = (data.allele == allele).astype(int).values + data = (data.allele == allele).astype(int).values else: train_rounds = None - data = full_data.loc[full_data.allele == allele] - - subset = hyperparameters.get("train_data", {}).get("subset", "all") - if subset == "quantitative": - data = data.loc[ - data.measurement_type == "quantitative" - ] - elif subset == "all": - pass - else: - raise ValueError("Unsupported subset: %s" % subset) + data = data.loc[data.allele == allele] progress_preamble = ( "[%2d / %2d hyperparameters] "