From 9f1bb326b1af2320d13da04adfa5977afd781589 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Wed, 14 Feb 2018 16:25:23 -0500
Subject: [PATCH] fix: apply train_data subset filter before counting data per allele

---
 .../train_allele_specific_models_command.py   | 30 +++++++++----------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/mhcflurry/train_allele_specific_models_command.py b/mhcflurry/train_allele_specific_models_command.py
index decd4b98..d075167d 100644
--- a/mhcflurry/train_allele_specific_models_command.py
+++ b/mhcflurry/train_allele_specific_models_command.py
@@ -200,7 +200,6 @@ def run(argv=sys.argv[1:]):
     print("Training data: %s" % (str(df.shape)))
 
     GLOBAL_DATA["train_data"] = df
-    GLOBAL_DATA["train_data_size_by_allele"] = df.allele.value_counts()
     GLOBAL_DATA["args"] = args
 
     if not os.path.exists(args.out_models_dir):
@@ -477,7 +476,18 @@ def train_model(
     pretrain_min_points = hyperparameters['train_data']['pretrain_min_points']
 
     full_data = GLOBAL_DATA["train_data"]
-    data_size_by_allele = GLOBAL_DATA["train_data_size_by_allele"]
+
+    subset = hyperparameters.get("train_data", {}).get("subset", "all")
+    if subset == "quantitative":
+        data = full_data.loc[
+            full_data.measurement_type == "quantitative"
+        ]
+    elif subset == "all":
+        data = full_data  # no filtering; `data` must be bound for the count below
+    else:
+        raise ValueError("Unsupported subset: %s" % subset)
+
+    data_size_by_allele = data.allele.value_counts()
 
     if pretrain_min_points:
         similar_alleles = alleles_by_similarity(allele)
@@ -485,22 +495,12 @@ def train_model(
         while not alleles or data_size_by_allele.loc[alleles].sum() < pretrain_min_points:
             alleles.append(similar_alleles.pop(0))
         print(alleles)
-        data = full_data.loc[full_data.allele.isin(alleles)]
+        data = data.loc[data.allele.isin(alleles)]
         assert len(data) >= pretrain_min_points, (len(data), pretrain_min_points)
-        train_rounds = (data.allele == allele).astype(int).values
+        train_rounds = (data.allele == allele).astype(int).values  # 1 where row is `allele`, else 0
     else:
         train_rounds = None
-        data = full_data.loc[full_data.allele == allele]
-
-    subset = hyperparameters.get("train_data", {}).get("subset", "all")
-    if subset == "quantitative":
-        data = data.loc[
-            data.measurement_type == "quantitative"
-        ]
-    elif subset == "all":
-        pass
-    else:
-        raise ValueError("Unsupported subset: %s" % subset)
+        data = data.loc[data.allele == allele]
 
     progress_preamble = (
         "[%2d / %2d hyperparameters] "
-- 
GitLab