From 5cf8563fa8178632a1c9f5a67997e46fb7a160b5 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Thu, 13 Jun 2019 12:17:31 -0400
Subject: [PATCH] drop allele encoding transforms

---
 mhcflurry/allele_encoding.py                 | 46 ++----------
 mhcflurry/allele_encoding_transforms.py      | 76 --------------------
 mhcflurry/class1_affinity_predictor.py       | 36 +---------
 mhcflurry/train_pan_allele_models_command.py |  7 +-
 test/test_allele_encoding.py                 | 19 -----
 5 files changed, 12 insertions(+), 172 deletions(-)
 delete mode 100644 mhcflurry/allele_encoding_transforms.py

diff --git a/mhcflurry/allele_encoding.py b/mhcflurry/allele_encoding.py
index 20350da0..bfd61ec5 100644
--- a/mhcflurry/allele_encoding.py
+++ b/mhcflurry/allele_encoding.py
@@ -1,9 +1,6 @@
-from six import callable
-
 import pandas
 
 from . import amino_acid
-from .allele_encoding_transforms import TRANSFORMS
 
 
 class AlleleEncoding(object):
@@ -11,7 +8,6 @@ class AlleleEncoding(object):
             self,
             alleles=None,
             allele_to_sequence=None,
-            transforms=None,
             borrow_from=None):
         """
         A place to cache encodings for a (potentially large) sequence of alleles.
@@ -30,11 +26,6 @@ class AlleleEncoding(object):
         self.borrow_from = borrow_from
         self.allele_to_sequence = allele_to_sequence
 
-        if transforms is None:
-            transforms = dict(
-                (name, klass()) for (name, klass) in TRANSFORMS.items())
-        self.transforms = transforms
-
         if self.borrow_from is None:
             assert allele_to_sequence is not None
             all_alleles = (
@@ -54,7 +45,6 @@ class AlleleEncoding(object):
             self.allele_to_index = borrow_from.allele_to_index
             self.sequences = borrow_from.sequences
             self.allele_to_sequence = borrow_from.allele_to_sequence
-            self.transforms = borrow_from.transforms
 
         if alleles is not None:
             assert all(
@@ -76,34 +66,12 @@ class AlleleEncoding(object):
             "allele_representations",
             encoding_name)
         if cache_key not in self.encoding_cache:
-            if ":" in encoding_name:
-                # Transform
-                pieces = encoding_name.split(":", 3)
-                if pieces[0] != "transform":
-                    raise RuntimeError(
-                        "Expected 'transform' but saw: %s" % pieces[0])
-                if len(pieces) == 1:
-                    raise RuntimeError("Expected: 'transform:<name>[:argument]")
-                transform_name = pieces[1]
-                argument = None if len(pieces) == 2 else pieces[2]
-                try:
-                    transform = self.transforms[transform_name]
-                except KeyError:
-                    raise KeyError(
-                        "Unsupported transform: %s. Supported transforms: %s" % (
-                            transform_name,
-                            " ".join(self.transforms) if self.transforms else "(none)"))
-                vector_encoded = (
-                    transform.transform(self) if argument is None
-                    else transform.transform(self, argument))
-            else:
-                # No transform.
-                index_encoded_matrix = amino_acid.index_encoding(
-                    self.sequences.values,
-                    amino_acid.AMINO_ACID_INDEX)
-                vector_encoded = amino_acid.fixed_vectors_encoding(
-                    index_encoded_matrix,
-                    amino_acid.ENCODING_DATA_FRAMES[encoding_name])
+            index_encoded_matrix = amino_acid.index_encoding(
+                self.sequences.values,
+                amino_acid.AMINO_ACID_INDEX)
+            vector_encoded = amino_acid.fixed_vectors_encoding(
+                index_encoded_matrix,
+                amino_acid.ENCODING_DATA_FRAMES[encoding_name])
             self.encoding_cache[cache_key] = vector_encoded
         return self.encoding_cache[cache_key]
 
@@ -117,8 +85,6 @@ class AlleleEncoding(object):
             One of "BLOSUM62", "one-hot", etc. Full list of supported vector
             encodings is given by available_vector_encodings() in amino_acid.
 
-            Also supported are names like pca:BLOSUM62, which would run the
-            "pca" transform on BLOSUM62-encoded sequences.
         Returns
         -------
         numpy.array with shape (num sequences, sequence length, m) where m is
diff --git a/mhcflurry/allele_encoding_transforms.py b/mhcflurry/allele_encoding_transforms.py
deleted file mode 100644
index 236066eb..00000000
--- a/mhcflurry/allele_encoding_transforms.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import time
-
-import pandas
-import sklearn.decomposition
-
-
-class AlleleEncodingTransform(object):
-    def transform(self, allele_encoding, argument=None):
-        raise NotImplementedError()
-
-    def get_fit(self):
-        """
-        Get the fit for serialization, which must be in the form of one or more
-        dataframes.
-
-        Returns
-        -------
-        dict : string to DataFrame
-        """
-        raise NotImplementedError()
-
-    def restore_fit(self, fit):
-        """
-        Restore a serialized fit.
-
-        Parameters
-        ----------
-        fit : string to array
-        """
-        pass
-
-
-class PCATransform(AlleleEncodingTransform):
-    name = 'pca'
-    serialization_keys = ['mean', 'components']
-
-    def __init__(self):
-        self.model = None
-
-    def is_fit(self):
-        return self.model is not None
-
-    def fit(self, allele_representations):
-        self.model = sklearn.decomposition.PCA()
-        shape = allele_representations.shape
-        flattened = allele_representations.reshape(
-            (shape[0], shape[1] * shape[2]))
-        print("Fitting PCA allele encoding transform on data of shape: %s" % (
-            str(flattened.shape)))
-        start = time.time()
-        self.model.fit(flattened)
-        print("Fit complete in %0.3f sec." % (time.time() - start))
-
-    def get_fit(self):
-        return {
-            'mean': self.model.mean_,
-            'components': self.model.components_,
-        }
-
-    def restore_fit(self, fit):
-        self.model = sklearn.decomposition.PCA()
-        self.model.mean_ = fit["mean"]
-        self.model.components_ = fit["components"]
-
-    def transform(self, allele_encoding, underlying_representation):
-        allele_representations = allele_encoding.allele_representations(
-            underlying_representation)
-        if not self.is_fit():
-            self.fit(allele_representations)
-        flattened = allele_representations.reshape(
-            (allele_representations.shape[0],
-             allele_representations.shape[1] * allele_representations.shape[2]))
-        return self.model.transform(flattened)
-
-
-TRANSFORMS = dict((klass.name, klass) for klass in [PCATransform])
diff --git a/mhcflurry/class1_affinity_predictor.py b/mhcflurry/class1_affinity_predictor.py
index d6410b49..3497e25d 100644
--- a/mhcflurry/class1_affinity_predictor.py
+++ b/mhcflurry/class1_affinity_predictor.py
@@ -25,7 +25,6 @@ from .regression_target import to_ic50
 from .version import __version__
 from .ensemble_centrality import CENTRALITY_MEASURES
 from .allele_encoding import AlleleEncoding
-from .allele_encoding_transforms import TRANSFORMS as ALLELE_ENCODING_TRANSFORMS
 
 
 # Default function for combining predictions across models in an ensemble.
@@ -47,7 +46,6 @@ class Class1AffinityPredictor(object):
             allele_to_allele_specific_models=None,
             class1_pan_allele_models=None,
             allele_to_sequence=None,
-            allele_encoding_transforms=None,
             manifest_df=None,
             allele_to_percent_rank_transform=None,
             metadata_dataframes=None):
@@ -83,8 +81,6 @@ class Class1AffinityPredictor(object):
             class1_pan_allele_models = []
 
         self.allele_to_sequence = allele_to_sequence
-        self.allele_encoding_transforms = (
-            allele_encoding_transforms if allele_encoding_transforms else {})
         self.master_allele_encoding = None
         if class1_pan_allele_models:
             assert self.allele_to_sequence
@@ -378,15 +374,6 @@ class Class1AffinityPredictor(object):
                 join(models_dir, "allele_sequences.csv"), index=False)
             logging.info("Wrote: %s" % join(models_dir, "allele_sequences.csv"))
 
-        # Save allele encoding transforms
-        for transform in self.allele_encoding_transforms.values():
-            if transform.is_fit():
-                fit_data = transform.get_fit()
-                assert set(fit_data) == set(transform.serialization_keys)
-                target_path = join(models_dir, "%s.npz" % transform.name)
-                numpy.savez(target_path, **fit_data)
-                logging.info("Wrote: %s" % target_path)
-
         if self.allele_to_percent_rank_transform:
             percent_ranks_df = None
             for (allele, transform) in self.allele_to_percent_rank_transform.items():
@@ -454,26 +441,6 @@ class Class1AffinityPredictor(object):
                 join(models_dir, "allele_sequences.csv"),
                 index_col="allele").sequence.to_dict()
 
-        # Load allele encoding transforms
-        allele_encoding_transforms = {}
-        for transform_name in ALLELE_ENCODING_TRANSFORMS:
-            klass = ALLELE_ENCODING_TRANSFORMS[transform_name]
-            transform = klass()
-            target_path = join(models_dir, "%s.npz" % transform_name)
-            if exists(target_path):
-                with numpy.load(target_path) as loaded:
-                    restored_fit = dict(
-                        (key, loaded[key]) for key in loaded.keys())
-                if set(restored_fit) != set(klass.serialization_keys):
-                    logging.warning(
-                        "Missing some allele encoding transform serialization "
-                        "data from %s. Found: %s. Expected: %s." % (
-                            models_dir,
-                            str(set(restored_fit)),
-                            str(set(klass.serialization_keys))))
-                transform.restore_fit(restored_fit)
-            allele_encoding_transforms[transform_name] = transform
-
         allele_to_percent_rank_transform = {}
         percent_ranks_path = join(models_dir, "percent_ranks.csv")
         if exists(percent_ranks_path):
@@ -543,8 +510,7 @@ class Class1AffinityPredictor(object):
                     self.master_allele_encoding.allele_to_sequence !=
                     self.allele_to_sequence):
             self.master_allele_encoding = AlleleEncoding(
-                allele_to_sequence=self.allele_to_sequence,
-                transforms=self.allele_encoding_transforms)
+                allele_to_sequence=self.allele_to_sequence)
         return self.master_allele_encoding
 
     def fit_allele_specific_predictors(
diff --git a/mhcflurry/train_pan_allele_models_command.py b/mhcflurry/train_pan_allele_models_command.py
index 88ca3e2a..1d67f003 100644
--- a/mhcflurry/train_pan_allele_models_command.py
+++ b/mhcflurry/train_pan_allele_models_command.py
@@ -415,9 +415,7 @@ def train_model(
         progress_print_interval,
         predictor,
         save_to):
-
     import keras.backend as K
-    K.clear_session()  # release memory
 
     df = GLOBAL_DATA["train_data"]
     folds_df = GLOBAL_DATA["folds_df"]
@@ -531,6 +529,11 @@ def train_model(
         predictor.manifest_df.shape[0], len(predictor.class1_pan_allele_models))
     predictor.clear_cache()
 
+    # Delete the network and release memory
+    model.update_network_description()  # save weights and config
+    model._network = None  # release tensorflow network
+    K.clear_session()  # release graph
+
     return predictor
 
 
diff --git a/test/test_allele_encoding.py b/test/test_allele_encoding.py
index a6c1a972..caba3bd3 100644
--- a/test/test_allele_encoding.py
+++ b/test/test_allele_encoding.py
@@ -2,10 +2,7 @@ import time
 
 from mhcflurry.allele_encoding import AlleleEncoding
 from mhcflurry.amino_acid import BLOSUM62_MATRIX
-from nose.tools import eq_
 from numpy.testing import assert_equal
-import numpy
-import pandas
 
 
 def test_allele_encoding_speed():
@@ -37,19 +34,3 @@ def test_allele_encoding_speed():
     start = time.time()
     encoding1 = encoding.fixed_length_vector_encoded_sequences("BLOSUM62")
     print("Long encoding in %0.2f sec." % (time.time() - start))
-
-
-def test_pca():
-    encoding = AlleleEncoding(
-        ["A*02:01", "A*02:03", "A*02:01"],
-        {
-            "A*02:01": "AC",
-            "A*02:03": "AE",
-        }
-    )
-    encoded1 = encoding.fixed_length_vector_encoded_sequences(
-        "transform:pca:BLOSUM62")
-
-    numpy.testing.assert_array_equal(encoded1[0], encoded1[2])
-    assert not numpy.array_equal(encoded1[0], encoded1[1])
-    print(encoded1)
\ No newline at end of file
-- 
GitLab