From d69999965f53cd8a4fc0fa217c96bfb51e1f92a4 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Sat, 30 Nov 2019 23:13:21 -0500
Subject: [PATCH] working

---
 mhcflurry/batch_generator.py            |  13 ++-
 mhcflurry/class1_ligandome_predictor.py |   3 +-
 test/test_batch_generator.py            | 114 ++++++++++++++++++++++++
 3 files changed, 122 insertions(+), 8 deletions(-)

diff --git a/mhcflurry/batch_generator.py b/mhcflurry/batch_generator.py
index 25659c51..4b000ac4 100644
--- a/mhcflurry/batch_generator.py
+++ b/mhcflurry/batch_generator.py
@@ -116,13 +116,12 @@ class MultiallelicMassSpecBatchGenerator(object):
     def plan_from_dataframe(df, hyperparameters):
         affinity_fraction = hyperparameters["batch_generator_affinity_fraction"]
         batch_size = hyperparameters["batch_generator_batch_size"]
-        classes = {}
-        df["equivalence_class"] = [
-            classes.setdefault(
-                tuple(row[["is_affinity", "is_binder", "experiment_name"]]),
-                len(classes))
-            for _, row in df.iterrows()
-        ]
+        equivalence_columns = ["is_affinity", "is_binder", "experiment_name"]
+        df["equivalence_key"] = df[equivalence_columns].astype(str).sum(1)
+        equivalence_map = dict(
+            (v, i)
+            for (i, v) in zip(*df.equivalence_key.factorize()))
+        df["equivalence_class"] = df.equivalence_key.map(equivalence_map)
         df["first_allele"] = df.alleles.str.get(0)
         df["unused"] = True
         df["idx"] = df.index
diff --git a/mhcflurry/class1_ligandome_predictor.py b/mhcflurry/class1_ligandome_predictor.py
index 8ddd4c20..cc851e69 100644
--- a/mhcflurry/class1_ligandome_predictor.py
+++ b/mhcflurry/class1_ligandome_predictor.py
@@ -725,7 +725,8 @@ class Class1LigandomePredictor(object):
                 axis=0)
 
         if existing_weights_shape != reshaped.shape:
-            print("Performing network surgery", existing_weights_shape, reshaped.shape)
+            print(
+                "Performing network surgery", existing_weights_shape, reshaped.shape)
             # Network surgery required. Make a new network with this layer's
             # dimensions changed. Kind of a hack.
             layer.input_dim = reshaped.shape[0]
diff --git a/test/test_batch_generator.py b/test/test_batch_generator.py
index ab3dcf52..f5aac2aa 100644
--- a/test/test_batch_generator.py
+++ b/test/test_batch_generator.py
@@ -1,12 +1,34 @@
+import logging
+logging.getLogger('matplotlib').disabled = True
+logging.getLogger('tensorflow').disabled = True
+
+import os
+import collections
+import time
+import cProfile
+import pstats
+
 import pandas
 import numpy
 
+from mhcflurry.allele_encoding import MultipleAlleleEncoding
+from mhcflurry.downloads import get_path
 from mhcflurry.batch_generator import (
     MultiallelicMassSpecBatchGenerator)
+from mhcflurry.regression_target import to_ic50
+from mhcflurry import Class1AffinityPredictor
 
 from numpy.testing import assert_equal
 
 
+def data_path(name):
+    '''
+    Return the absolute path to a file in the test/data directory.
+    The name specified should be relative to test/data.
+    '''
+    return os.path.join(os.path.dirname(__file__), "data", name)
+
+
 def test_basic():
     planner = MultiallelicMassSpecBatchGenerator(
         hyperparameters=dict(
@@ -64,3 +86,95 @@ def test_basic():
 
     #import ipdb;ipdb.set_trace()
 
+
+def test_large(sample_rate=0.01):
+    multi_train_df = pandas.read_csv(
+        data_path("multiallelic_ms.benchmark1.csv.bz2"))
+    multi_train_df["label"] = multi_train_df.hit
+    multi_train_df["is_affinity"] = False
+
+    sample_table = multi_train_df.loc[
+        multi_train_df.label == True
+    ].drop_duplicates("sample_id").set_index("sample_id").loc[
+        multi_train_df.sample_id.unique()
+    ]
+    grouped = multi_train_df.groupby("sample_id").nunique()
+    for col in sample_table.columns:
+        if (grouped[col] > 1).any():
+            del sample_table[col]
+    sample_table["alleles"] = sample_table.hla.str.split()
+
+    pan_train_df = pandas.read_csv(
+        get_path(
+            "models_class1_pan", "models.with_mass_spec/train_data.csv.bz2"))
+    pan_sub_train_df = pan_train_df
+    pan_sub_train_df["label"] = pan_sub_train_df["measurement_value"]
+    del pan_sub_train_df["measurement_value"]
+    pan_sub_train_df["is_affinity"] = True
+
+    pan_sub_train_df = pan_sub_train_df.sample(frac=sample_rate)
+    multi_train_df = multi_train_df.sample(frac=sample_rate)
+
+    pan_predictor = Class1AffinityPredictor.load(
+        get_path("models_class1_pan", "models.with_mass_spec"),
+        optimization_level=0,
+        max_models=1)
+
+    allele_encoding = MultipleAlleleEncoding(
+        experiment_names=multi_train_df.sample_id.values,
+        experiment_to_allele_list=sample_table.alleles.to_dict(),
+        max_alleles_per_experiment=sample_table.alleles.str.len().max(),
+        allele_to_sequence=pan_predictor.allele_to_sequence,
+    )
+    allele_encoding.append_alleles(pan_sub_train_df.allele.values)
+    allele_encoding = allele_encoding.compact()
+
+    combined_train_df = pandas.concat(
+        [multi_train_df, pan_sub_train_df], ignore_index=True, sort=True)
+
+    print("Total size", combined_train_df)
+
+    planner = MultiallelicMassSpecBatchGenerator(
+        hyperparameters=dict(
+            batch_generator_validation_split=0.2,
+            batch_generator_batch_size=1024,
+            batch_generator_affinity_fraction=0.5))
+
+    s = time.time()
+    profiler = cProfile.Profile()
+    profiler.enable()
+    planner.plan(
+        affinities_mask=combined_train_df.is_affinity.values,
+        experiment_names=combined_train_df.sample_id.values,
+        alleles_matrix=allele_encoding.alleles,
+        is_binder=numpy.where(
+            combined_train_df.is_affinity.values,
+            combined_train_df.label.values,
+            to_ic50(combined_train_df.label.values)) < 1000.0)
+    stats = pstats.Stats(profiler)
+    stats.sort_stats("cumtime").reverse_order().print_stats()
+    print(planner.summary())
+    print("Planning took [sec]: ", time.time() - s)
+
+    (train_iter, test_iter) = planner.get_train_and_test_generators(
+        x_dict={
+            "idx": numpy.arange(len(combined_train_df)),
+        },
+        y_list=[])
+
+    for (kind, it) in [("train", train_iter), ("test", test_iter)]:
+        for (i, (x_item, y_item)) in enumerate(it):
+            idx = x_item["idx"]
+            combined_train_df.loc[idx, "kind"] = kind
+            combined_train_df.loc[idx, "idx"] = idx
+            combined_train_df.loc[idx, "batch"] = i
+    combined_train_df["idx"] = combined_train_df.idx.astype(int)
+    combined_train_df["batch"] = combined_train_df.batch.astype(int)
+
+    for ((kind, batch_num), batch_df) in combined_train_df.groupby(["kind", "batch"]):
+        if not batch_df.is_affinity.all():
+            # Test each batch has at most one multiallelic ms experiment.
+            assert_equal(
+                batch_df.loc[
+                    ~batch_df.is_affinity
+                ].sample_id.nunique(), 1)
\ No newline at end of file
-- 
GitLab