From 5545f31f1c3198cb9f6cdb30ad1a7b629bca3617 Mon Sep 17 00:00:00 2001
From: Tim O'Donnell <timodonnell@gmail.com>
Date: Wed, 19 Sep 2018 13:40:40 -0400
Subject: [PATCH] add missing file

---
 mhcflurry/allele_encoding_transforms.py | 77 +++++++++++++++++++++++++
 1 file changed, 77 insertions(+)
 create mode 100644 mhcflurry/allele_encoding_transforms.py

diff --git a/mhcflurry/allele_encoding_transforms.py b/mhcflurry/allele_encoding_transforms.py
new file mode 100644
index 00000000..30092586
--- /dev/null
+++ b/mhcflurry/allele_encoding_transforms.py
@@ -0,0 +1,77 @@
+import time
+
+import pandas
+import sklearn.decomposition
+
+
+class AlleleEncodingTransform(object):
+    def transform(self, data):
+        raise NotImplementedError()
+
+    def get_fit(self):
+        """
+        Get the fit for serialization, which must be in the form of one or more
+        dataframes.
+
+        Returns
+        -------
+        dict : string to DataFrame
+        """
+        raise NotImplementedError()
+
+    def restore_fit(self, fit):
+        """
+        Restore a serialized fit.
+
+        Parameters
+        ----------
+        fit : string to DataFrame
+        """
+
+
+class PCATransform(AlleleEncodingTransform):
+    name = 'pca'
+    serialization_keys = ['data']
+
+    def __init__(self):
+        self.model = None
+
+    def is_fit(self):
+        return self.model is not None
+
+    def fit(self, allele_representations):
+        self.model = sklearn.decomposition.PCA()
+        shape = allele_representations.shape
+        flattened = allele_representations.reshape(
+            (shape[0], shape[1] * shape[2]))
+        print("Fitting PCA allele encoding transform on data of shape: %s" % (
+            str(flattened.shape)))
+        start = time.time()
+        self.model.fit(flattened)
+        print("Fit complete in %0.3f sec." % (time.time() - start))
+
+    def get_fit(self):
+        df = pandas.DataFrame(self.model.components_)
+        df.columns = ["pca_%s" % c for c in df.columns]
+        df["mean"] = self.model.mean_
+        return {
+            'data': df
+        }
+
+    def restore_fit(self, fit):
+        assert list(fit) == ['data']
+        data = fit["data"]
+        self.model = sklearn.decomposition.PCA()
+        self.model.mean_ = data["mean"].values
+        self.model.components_ = data.drop(columns="mean").values
+
+    def transform(self, allele_representations):
+        if not self.is_fit():
+            self.fit(allele_representations)
+        flattened = allele_representations.reshape(
+            (allele_representations.shape[0],
+             allele_representations.shape[1] * allele_representations.shape[2]))
+        return self.model.transform(flattened)
+
+
+TRANSFORMS = dict((klass.name, klass) for klass in [PCATransform])
-- 
GitLab