From 63482a3fa759f20e5774f3ad83ae13be435e4401 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Fri, 29 Nov 2019 23:33:57 -0500
Subject: [PATCH] Rename batch generator module and use it in Class1LigandomePredictor

---
 ..._batch_generator.py => batch_generator.py} |  77 ++++--
 mhcflurry/class1_ligandome_predictor.py       | 220 ++++--------------
 ...h_generator.py => test_batch_generator.py} |   3 +-
 3 files changed, 105 insertions(+), 195 deletions(-)
 rename mhcflurry/{multiallelic_mass_spec_batch_generator.py => batch_generator.py} (76%)
 rename test/{test_multiallelic_mass_spec_batch_generator.py => test_batch_generator.py} (95%)

diff --git a/mhcflurry/multiallelic_mass_spec_batch_generator.py b/mhcflurry/batch_generator.py
similarity index 76%
rename from mhcflurry/multiallelic_mass_spec_batch_generator.py
rename to mhcflurry/batch_generator.py
index 754d0355..25659c51 100644
--- a/mhcflurry/multiallelic_mass_spec_batch_generator.py
+++ b/mhcflurry/batch_generator.py
@@ -8,7 +8,16 @@ from .hyperparameters import HyperparameterDefaults
 
 
 class BatchPlan(object):
-    def __init__(self, equivalence_classes, batch_compositions):
+    def __init__(self, equivalence_classes, batch_compositions, equivalence_class_labels=None):
+        """
+
+        Parameters
+        ----------
+        equivalence_classes
+        batch_compositions
+        equivalence_class_labels : list of string, optional
+            Used only for summary().
+        """
         # batch_compositions is (num batches_generator, batch size)
 
         self.equivalence_classes = equivalence_classes # indices into points
@@ -23,6 +32,9 @@ class BatchPlan(object):
             indices_into_equivalence_classes.append(
                 numpy.array(indices, dtype=int))
         self.indices_into_equivalence_classes = indices_into_equivalence_classes
+        self.equivalence_class_labels = (
+            numpy.array(equivalence_class_labels)
+            if equivalence_class_labels is not None else None)
 
     def batch_indices_generator(self, epochs=1):
         batch_nums = numpy.arange(len(self.batch_compositions))
@@ -54,21 +66,35 @@ class BatchPlan(object):
 
     def summary(self, indent=0):
+        """
+        Return a text summary of equivalence class label counts per batch
+        (first 6 and last 3 batches only), indented by `indent` levels.
+        """
         lines = []
-        lines.append("Equivalence class sizes: ")
-        lines.append(pandas.Series(
-            [len(c) for c in self.equivalence_classes]))
-        lines.append("Batch compositions: ")
-        lines.append(self.batch_compositions)
+        labels = self.equivalence_class_labels
+        if labels is None:
+            labels = numpy.array([
+                "class-%d" % c for c in range(len(self.equivalence_classes))])
+        num_batches = len(self.batch_compositions)
+        for (i, composition) in enumerate(self.batch_compositions):
+            if 5 < i < num_batches - 3:
+                if i == 6:
+                    lines.append("...")  # printed once for the elided middle
+                continue
+            counts = pandas.Series(labels[composition]).value_counts()
+            lines.append(("Batch %5d: " % i) + ", ".join(
+                "{key}[{value}]".format(key=key, value=value)
+                for (key, value) in counts.items()))
+
         indent_spaces = "    " * indent
         return "\n".join([indent_spaces + str(line) for line in lines])
 
     @property
     def num_batches(self):
-        return self.batch_compositions.shape[0]
+        return len(self.batch_compositions)
 
     @property
     def batch_size(self):
-        return self.batch_compositions.shape[1]
+        return max([len(b) for b in self.batch_compositions] or [0])  # 0 if empty
 
 
 class MultiallelicMassSpecBatchGenerator(object):
@@ -100,6 +126,15 @@ class MultiallelicMassSpecBatchGenerator(object):
         df["first_allele"] = df.alleles.str.get(0)
         df["unused"] = True
         df["idx"] = df.index
+        equivalence_class_to_label = dict(
+            (idx, (
+                "{first_allele} {binder}" if row.is_affinity else
+                "{experiment_name} {binder}"
+                ).format(
+                    binder="binder" if row.is_binder else "nonbinder",
+                    **row.to_dict()))
+            for (idx, row) in df.drop_duplicates(
+                "equivalence_class").set_index("equivalence_class").iterrows())
         df = df.sample(frac=1.0)
         #df["key"] = df.is_binder ^ (numpy.arange(len(df)) % 2).astype(bool)
         #df = df.sort_values("key")
@@ -171,14 +206,19 @@ class MultiallelicMassSpecBatchGenerator(object):
         ]
         return BatchPlan(
             equivalence_classes=equivalence_classes,
-            batch_compositions=batch_compositions)
+            batch_compositions=batch_compositions,
+            equivalence_class_labels=[
+                equivalence_class_to_label[i] for i in
+                range(len(class_to_indices))
+            ])
 
     def plan(
             self,
             affinities_mask,
             experiment_names,
             alleles_matrix,
-            is_binder):
+            is_binder,
+            potential_validation_mask=None):
         affinities_mask = numpy.array(affinities_mask, copy=False, dtype=bool)
         experiment_names = numpy.array(experiment_names, copy=False)
         alleles_matrix = numpy.array(alleles_matrix, copy=False)
@@ -190,10 +230,13 @@ class MultiallelicMassSpecBatchGenerator(object):
         numpy.testing.assert_equal(len(is_binder), n)
         numpy.testing.assert_equal(
             affinities_mask, pandas.isnull(experiment_names))
+        if potential_validation_mask is not None:
+            numpy.testing.assert_equal(len(potential_validation_mask), n)
 
         validation_items = numpy.random.choice(
-            n, int(
-                self.hyperparameters['batch_generator_validation_split'] * n))
+            n if potential_validation_mask is None
+                else numpy.where(potential_validation_mask)[0],
+            int(self.hyperparameters['batch_generator_validation_split'] * n))
         validation_mask = numpy.zeros(n, dtype=bool)
         validation_mask[validation_items] = True
 
@@ -216,7 +259,7 @@ class MultiallelicMassSpecBatchGenerator(object):
 
     def summary(self):
         return (
-            "Train: " + self.train_batch_plan.summary(indent=1) +
-            "\n***\nTest: " + self.test_batch_plan.summary(indent=1))
+            "Train:\n" + self.train_batch_plan.summary(indent=1) +
+            "\n***\nTest:\n" + self.test_batch_plan.summary(indent=1))
 
     def get_train_and_test_generators(self, x_dict, y_list, epochs=1):
@@ -225,3 +268,11 @@ class MultiallelicMassSpecBatchGenerator(object):
         test_generator = self.test_batch_plan.batches_generator(
             x_dict, y_list, epochs=epochs)
         return (train_generator, test_generator)
+
+    @property
+    def num_train_batches(self):
+        return self.train_batch_plan.num_batches
+
+    @property
+    def num_test_batches(self):
+        return self.test_batch_plan.num_batches
diff --git a/mhcflurry/class1_ligandome_predictor.py b/mhcflurry/class1_ligandome_predictor.py
index 0c140d11..8ddd4c20 100644
--- a/mhcflurry/class1_ligandome_predictor.py
+++ b/mhcflurry/class1_ligandome_predictor.py
@@ -16,6 +16,7 @@ from .regression_target import from_ic50, to_ic50
 from .random_negative_peptides import RandomNegativePeptides
 from .allele_encoding import MultipleAlleleEncoding, AlleleEncoding
 from .auxiliary_input import AuxiliaryInputEncoder
+from .batch_generator import MultiallelicMassSpecBatchGenerator
 from .custom_loss import (
     MSEWithInequalities,
     MultiallelicMassSpecLoss,
@@ -39,11 +40,10 @@ class Class1LigandomePredictor(object):
 
     fit_hyperparameter_defaults = HyperparameterDefaults(
         max_epochs=500,
-        validation_split=0.1,
         early_stopping=True,
-        minibatch_size=128,
         random_negative_affinity_min=20000.0,).extend(
-        RandomNegativePeptides.hyperparameter_defaults
+        RandomNegativePeptides.hyperparameter_defaults).extend(
+        MultiallelicMassSpecBatchGenerator.hyperparameter_defaults
     )
     """
     Hyperparameters for neural network training.
@@ -366,12 +366,6 @@ class Class1LigandomePredictor(object):
 
         peptide_input = self.peptides_to_network_input(encodable_peptides)
 
-        validation_items = numpy.random.choice(
-            len(labels),
-            int(self.hyperparameters['validation_split'] * len(labels)))
-        validation_mask = numpy.zeros(len(labels), dtype=bool)
-        validation_mask[validation_items] = True
-
         # Optional optimization
         (allele_encoding_input, allele_representations) = (
             self.allele_encoding_to_network_input(allele_encoding))
@@ -403,10 +397,6 @@ class Class1LigandomePredictor(object):
                 allele_encoding.max_alleles_per_experiment),
             borrow_from=allele_encoding.allele_encoding)
         num_random_negatives = random_negatives_planner.get_total_count()
-        validation_mask_with_random_negatives = numpy.concatenate([
-            numpy.tile(False, num_random_negatives),
-            validation_mask
-        ])
 
         # Reverse inequalities because from_ic50() flips the direction
         # (i.e. lower affinity results in higher y values).
@@ -466,6 +456,37 @@ class Class1LigandomePredictor(object):
         if verbose:
             self.network.summary()
 
+        batch_generator = MultiallelicMassSpecBatchGenerator(
+            MultiallelicMassSpecBatchGenerator.hyperparameter_defaults.subselect(
+                self.hyperparameters))
+        start = time.time()
+        batch_generator.plan(
+            affinities_mask=numpy.concatenate([
+                numpy.tile(True, num_random_negatives),
+                affinities_mask
+            ]),
+            experiment_names=numpy.concatenate([
+                numpy.tile(None, num_random_negatives),
+                allele_encoding.experiment_names
+            ]),
+            alleles_matrix=numpy.concatenate([
+                random_negatives_allele_encoding.alleles,
+                allele_encoding.alleles,
+            ]),
+            is_binder=numpy.concatenate([
+                numpy.tile(False, num_random_negatives),
+                numpy.where(affinities_mask, labels, to_ic50(labels)) < 1000.0
+            ]),
+            potential_validation_mask=numpy.concatenate([
+                numpy.tile(False, num_random_negatives),
+                numpy.tile(True, len(labels))
+            ]),
+        )
+        if verbose:
+            print("Generated batch generation plan in %0.2f sec." % (
+                time.time() - start))
+            print(batch_generator.summary())
+
         min_val_loss_iteration = None
         min_val_loss = None
         last_progress_print = 0
@@ -519,27 +540,22 @@ class Class1LigandomePredictor(object):
                         "peptide"
                     ][:num_random_negatives] = random_negative_peptides_encoding
 
-            (train_generator, train_batches, test_generator, test_batches) = (
-                self.train_and_test_generators(
+            (train_generator, test_generator) = (
+                batch_generator.get_train_and_test_generators(
                     x_dict=x_dict_with_random_negatives,
                     y_list=[encoded_y1, encoded_y2, encoded_y2],
-                    batch_size=self.hyperparameters['minibatch_size'],
-                    validation_mask=validation_mask_with_random_negatives,
-                    experiment_names=numpy.concatenate([
-                        numpy.tile(None, num_random_negatives),
-                        allele_encoding.experiment_names
-                    ])))
+                    epochs=1))
             self.assert_allele_representations_hash(allele_representations_hash)
             fit_history = self.network.fit_generator(
                 train_generator,
-                steps_per_epoch=train_batches,
+                steps_per_epoch=batch_generator.num_train_batches,
                 epochs=i + 1,
                 initial_epoch=i,
                 verbose=verbose,
                 use_multiprocessing=False,
                 workers=0,
                 validation_data=test_generator,
-                validation_steps=test_batches)
+                validation_steps=batch_generator.num_test_batches)
 
             """
             fit_history = self.network.fit(
@@ -575,7 +591,7 @@ class Class1LigandomePredictor(object):
                            min_val_loss_iteration)).strip())
                 last_progress_print = time.time()
 
-            if self.hyperparameters['validation_split']:
+            if batch_generator.num_test_batches:
                 #import ipdb ; ipdb.set_trace()
                 val_loss = fit_info['val_loss'][-1]
                 if min_val_loss is None or (
@@ -609,162 +625,6 @@ class Class1LigandomePredictor(object):
         fit_info["num_points"] = len(labels)
         self.fit_info.append(dict(fit_info))
 
-    @classmethod
-    def train_and_test_generators(
-            cls,
-            x_dict,
-            y_list,
-            batch_size,
-            validation_mask,
-            experiment_names):
-
-        points = len(y_list[0])
-        train_x_dict = {}
-        test_x_dict = {}
-        for (key, value) in x_dict.items():
-            train_x_dict[key] = value[~validation_mask]
-            test_x_dict[key] = value[validation_mask]
-
-        train_y_list = []
-        test_y_list = []
-        for value in y_list:
-            train_y_list.append(value[~validation_mask])
-            test_y_list.append(value[validation_mask])
-
-        train_generator = cls.batch_generator(
-            x_dict=train_x_dict,
-            y_list=train_y_list,
-            batch_size=batch_size,
-            experiment_names=experiment_names[~validation_mask])
-        test_generator = cls.batch_generator(
-            x_dict=test_x_dict,
-            y_list=test_y_list,
-            batch_size=batch_size,
-            experiment_names=experiment_names[validation_mask])
-
-        train_batches = next(train_generator)
-        test_batches = next(test_generator)
-
-        return (train_generator, train_batches, test_generator, test_batches)
-
-    @staticmethod
-    def batch_generator(x_dict, y_list, batch_size, experiment_names, affinity_fraction_for_mass_spec_batches=0.5):
-        # Each batch should have a mix of:
-        #   - random negative peptides
-        #   - affinity measurements (binder + non-binder)
-        #   - multiallelic mass spec
-        start = time.time()
-        df = pandas.DataFrame({"experiment": experiment_names})
-        df["unused"] = True
-        df["mass_spec_label"] = y_list[1]
-        assert set(
-            df.loc[~df.experiment.isnull()].mass_spec_label.unique()) == {
-                0.0, 1.0
-            }, df.loc[~df.experiment.isnull()].mass_spec_label.unique()
-        hit_rate = df.loc[~df.experiment.isnull()].mass_spec_label.mean()
-        affinities_per_batch = int(affinity_fraction_for_mass_spec_batches * batch_size)
-        mass_spec_per_batch = batch_size - affinities_per_batch
-
-        hits_per_mass_spec_batch = int(hit_rate * mass_spec_per_batch)
-        decoys_per_mass_spec_batch = (
-            mass_spec_per_batch - hits_per_mass_spec_batch)
-
-        print("affinity count", affinities_per_batch)
-        print("mass_spec count", mass_spec_per_batch,hits_per_mass_spec_batch, decoys_per_mass_spec_batch )
-
-        # Mixed mass spec / affinity batches_generator
-        experiments = df.experiment.unique()
-        batch_indices = []
-        batch_descriptions = []
-        for experiment in experiments:
-            if experiment is None:
-                continue
-            while True:
-                experiment_df = df.loc[
-                    df.unused & (df.experiment == experiment)]
-                if len(experiment_df) == 0:
-                    break
-
-                affinities_df = df.loc[df.unused & df.experiment.isnull()]
-                affinities_for_this_batch = min(
-                    affinities_per_batch, len(affinities_df))
-                mass_spec_for_this_batch = (
-                    batch_size - affinities_for_this_batch)
-                if len(experiment_df) < mass_spec_for_this_batch:
-                    mass_spec_for_this_batch = len(experiment_df)
-                    affinities_for_this_batch = (
-                            batch_size - mass_spec_for_this_batch)
-                    if affinities_for_this_batch < len(affinities_df):
-                        # For mass spec, we only do whole batches_generator, since it's
-                        # unclear how our pairwise loss would interact with
-                        # a smaller batch.
-                        break
-
-                mass_spec_labels = y_list[1][experiment_df.index.values]
-                assert ((mass_spec_labels == 0) | (mass_spec_labels == 1)).all(), mass_spec_labels
-
-                to_use_list = []
-
-                # sample hits
-                to_use = experiment_df.sample(
-                    n=hits_per_mass_spec_batch,
-                    weights=experiment_df.mass_spec_label + 1e-10,
-                    replace=False)
-                to_use_list.append(to_use.index.values)
-
-                # sample decoys
-                to_use = experiment_df.loc[
-                    ~experiment_df.index.isin(to_use.index)
-                ].sample(
-                    n=decoys_per_mass_spec_batch,
-                    weights=(1 - experiment_df.mass_spec_label) + 1e-10,
-                    replace=False)
-                to_use_list.append(to_use.index.values)
-
-                # sample affinities
-                to_use = affinities_df.sample(
-                    n=affinities_for_this_batch,
-                    replace=False)
-                to_use_list.append(to_use.index.values)
-
-                to_use_indices = numpy.concatenate(to_use_list)
-                df.loc[to_use_indices, "unused"] = False
-                batch_indices.append(to_use_indices)
-                batch_descriptions.append("multiallelic-mass-spec")
-
-        # Affinities-only batches_generator
-        affinities_df = df.loc[df.unused & df.experiment.isnull()]
-        while len(affinities_df) > 0:
-            if len(affinities_df) <= batch_size:
-                to_use = affinities_df
-            else:
-                to_use = affinities_df.sample(n=batch_size, replace=False)
-            df.loc[to_use.index, "unused"] = False
-            batch_indices.append(to_use.index)
-            affinities_df = df.loc[df.unused & df.experiment.isnull()]
-            batch_descriptions.append("affinities-only")
-
-        numpy.random.shuffle(batch_indices)
-        print("Planning %d batches_generator took" % len(batch_indices), time.time() - start, "sec")
-        print("remaining unused: ")
-        print(df.loc[df.unused].experiment.fillna("[affinity]").value_counts())
-        print("batch descriptions")
-        print(pandas.Series(batch_descriptions).value_counts())
-        #import ipdb ; ipdb.set_trace()
-        yield len(batch_indices)
-        for indices in batch_indices:
-            x_dict_batch = {}
-            for (key, value) in x_dict.items():
-                x_dict_batch[key] = value[indices]
-            y_list_batch = []
-            for value in y_list:
-                y_list_batch.append(value[indices])
-
-            yield (x_dict_batch, y_list_batch)
-        #import ipdb ; ipdb.set_trace()
-        #yield None
-
-
     def predict(
             self,
             peptides,
diff --git a/test/test_multiallelic_mass_spec_batch_generator.py b/test/test_batch_generator.py
similarity index 95%
rename from test/test_multiallelic_mass_spec_batch_generator.py
rename to test/test_batch_generator.py
index 778add9e..ab3dcf52 100644
--- a/test/test_multiallelic_mass_spec_batch_generator.py
+++ b/test/test_batch_generator.py
@@ -1,7 +1,7 @@
 import pandas
 import numpy
 
-from mhcflurry.multiallelic_mass_spec_batch_generator import (
+from mhcflurry.batch_generator import (
     MultiallelicMassSpecBatchGenerator)
 
 from numpy.testing import assert_equal
@@ -56,7 +56,6 @@ def test_basic():
 
     for ((kind, batch_num), batch_df) in df.groupby(["kind", "batch"]):
         if not batch_df.affinities_mask.all():
-            print(batch_df)
             # Test each batch has at most one multiallelic ms experiment.
             assert_equal(
                 batch_df.loc[
-- 
GitLab