From 5545f31f1c3198cb9f6cdb30ad1a7b629bca3617 Mon Sep 17 00:00:00 2001 From: Tim O'Donnell <timodonnell@gmail.com> Date: Wed, 19 Sep 2018 13:40:40 -0400 Subject: [PATCH] add missing file --- mhcflurry/allele_encoding_transforms.py | 77 +++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 mhcflurry/allele_encoding_transforms.py diff --git a/mhcflurry/allele_encoding_transforms.py b/mhcflurry/allele_encoding_transforms.py new file mode 100644 index 00000000..30092586 --- /dev/null +++ b/mhcflurry/allele_encoding_transforms.py @@ -0,0 +1,77 @@ +import time + +import pandas +import sklearn.decomposition + + +class AlleleEncodingTransform(object): + def transform(self, data): + raise NotImplementedError() + + def get_fit(self): + """ + Get the fit for serialization, which must be in the form of one or more + dataframes. + + Returns + ------- + dict : string to DataFrame + """ + raise NotImplementedError() + + def restore_fit(self, fit): + """ + Restore a serialized fit. + + Parameters + ---------- + fit : string to DataFrame + """ + + +class PCATransform(AlleleEncodingTransform): + name = 'pca' + serialization_keys = ['data'] + + def __init__(self): + self.model = None + + def is_fit(self): + return self.model is not None + + def fit(self, allele_representations): + self.model = sklearn.decomposition.PCA() + shape = allele_representations.shape + flattened = allele_representations.reshape( + (shape[0], shape[1] * shape[2])) + print("Fitting PCA allele encoding transform on data of shape: %s" % ( + str(flattened.shape))) + start = time.time() + self.model.fit(flattened) + print("Fit complete in %0.3f sec." % (time.time() - start)) + + def get_fit(self): + df = pandas.DataFrame(self.model.components_) + df.columns = ["pca_%s" % c for c in df.columns] + df["mean"] = self.model.mean_ + return { + 'data': df + } + + def restore_fit(self, fit): + assert list(fit) == ['data'] + data = fit["data"] + self.model = sklearn.decomposition.PCA() + self.model.mean_ = data["mean"].values + self.model.components_ = data.drop(columns="mean").values + + def transform(self, allele_representations): + if not self.is_fit(): + self.fit(allele_representations) + flattened = allele_representations.reshape( + (allele_representations.shape[0], + allele_representations.shape[1] * allele_representations.shape[2])) + return self.model.transform(flattened) + + +TRANSFORMS = dict((klass.name, klass) for klass in [PCATransform]) -- GitLab