From e55afc59840f805999e20fd6342c5efe32bf2901 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Fri, 24 Jan 2020 21:12:40 -0500
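Subject: [PATCH] Fix pmid validation in make_benchmark and hla_column
 attribute name

make_benchmark.py: cast the "pmid" column to str immediately after loading
the hits and validate --only-pmid / --exclude-pmid against the unique pmids
of the unfiltered hits rather than against the sample ids. The assertion
message now also reports the known pmids.

train_presentation_models_command.py: read args.hla_column instead of
args.hla_columns when building the experiment_id column.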

---
 .../models_class1_presentation/make_benchmark.py         | 9 ++++-----
 mhcflurry/train_presentation_models_command.py           | 2 +-
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/downloads-generation/models_class1_presentation/make_benchmark.py b/downloads-generation/models_class1_presentation/make_benchmark.py
index 6a14d7b6..6d094eb4 100644
--- a/downloads-generation/models_class1_presentation/make_benchmark.py
+++ b/downloads-generation/models_class1_presentation/make_benchmark.py
@@ -53,7 +53,8 @@ parser.add_argument(
 def run():
     args = parser.parse_args(sys.argv[1:])
     hit_df = pandas.read_csv(args.hits)
-    original_sample_ids = hit_df.sample_id.unique()
+    hit_df["pmid"] = hit_df["pmid"].astype(str)
+    original_samples_pmids = hit_df.pmid.unique()
     numpy.testing.assert_equal(hit_df.hit_id.nunique(), len(hit_df))
     hit_df = hit_df.loc[
         (hit_df.mhc_class == "I") &
@@ -69,14 +70,12 @@ def run():
         print("Subselected to %d %s samples" % (
             hit_df.sample_id.nunique(), args.only_format))
 
-    hit_df["pmid"] = hit_df["pmid"].astype(str)
-
     if args.only_pmid or args.exclude_pmid:
         assert not (args.only_pmid and args.exclude_pmid)
 
         pmids = list(args.only_pmid) + list(args.exclude_pmid)
-        missing = [pmid for pmid in pmids if pmid not in original_sample_ids]
-        assert not missing, missing
+        missing = [pmid for pmid in pmids if pmid not in original_samples_pmids]
+        assert not missing, (missing, original_samples_pmids)
 
         mask = hit_df.pmid.isin(pmids)
         if args.exclude_pmid:
diff --git a/mhcflurry/train_presentation_models_command.py b/mhcflurry/train_presentation_models_command.py
index d23d0106..ca1ac349 100644
--- a/mhcflurry/train_presentation_models_command.py
+++ b/mhcflurry/train_presentation_models_command.py
@@ -96,7 +96,7 @@ def main(args):
     ]
     print("Subselected to 8-15mers: %s" % (str(df.shape)))
 
-    df["experiment_id"] = df[args.hla_columns]
+    df["experiment_id"] = df[args.hla_column]
     experiment_to_alleles = dict((
         key, key.split()) for key in df.experiment_id.unique())
 
-- 
GitLab