Skip to content
Snippets Groups Projects
allele_encoding_transforms.py 2.13 KiB
Newer Older
Tim O'Donnell's avatar
Tim O'Donnell committed
import time

import pandas
import sklearn.decomposition


class AlleleEncodingTransform(object):
    """
    Base class for transforms applied to allele encodings.

    Subclasses override ``transform`` and the ``get_fit`` / ``restore_fit``
    serialization pair. Classes registered in the module-level
    ``TRANSFORMS`` dict are keyed by their ``name`` class attribute, so
    registered subclasses must define one.
    """

    def transform(self, allele_encoding, argument=None):
        """
        Apply the transform to ``allele_encoding``.

        Parameters
        ----------
        allele_encoding : object
            The allele encoding to transform.
        argument : optional
            Transform-specific extra argument whose meaning is defined by the
            subclass (``PCATransform``, for example, receives the underlying
            representation here).

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError()

    def get_fit(self):
        """
        Get the fit for serialization.

        Returns
        -------
        dict : string to array-like
            Mapping from parameter name to fitted values (``PCATransform``,
            for example, returns numpy arrays). The mapping must be accepted
            by ``restore_fit``.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError()

    def restore_fit(self, fit):
        """
        Restore a serialized fit.

        The base implementation does nothing.

        Parameters
        ----------
        fit : dict : string to array-like
            A mapping in the form returned by ``get_fit``.
        """
        pass


class PCATransform(AlleleEncodingTransform):
    """
    Allele encoding transform that projects flattened allele representations
    onto principal components with ``sklearn.decomposition.PCA``.

    If the model has not been fit (via ``fit``) or restored (via
    ``restore_fit``), the first call to ``transform`` fits it on the data
    being transformed.
    """
    name = 'pca'

    # Keys of the dict produced by get_fit and consumed by restore_fit.
    serialization_keys = ['mean', 'components']

    def __init__(self):
        # Fitted sklearn PCA instance, or None until fit/restore_fit succeeds.
        self.model = None

    def is_fit(self):
        """Return True if a PCA model has been fit or restored."""
        return self.model is not None

    @staticmethod
    def _flatten(allele_representations):
        """
        Collapse every axis after the first into a single feature axis.

        Parameters
        ----------
        allele_representations : array with ``shape`` and ``reshape``
            Presumably one row per allele along the first axis.

        Returns
        -------
        Array of shape (shape[0], product of the remaining dimensions).
        """
        shape = allele_representations.shape
        num_features = 1
        for dim in shape[1:]:
            num_features *= dim
        # Compute the feature count explicitly rather than passing -1 to
        # reshape: -1 is ambiguous (and raises) when shape[0] == 0.
        return allele_representations.reshape((shape[0], num_features))

    def fit(self, allele_representations):
        """
        Fit a new PCA model on the flattened allele representations.

        The previous model (if any) is kept until the new fit succeeds, so a
        failure here never leaves an unfitted model behind ``is_fit()``.

        Parameters
        ----------
        allele_representations : numpy.ndarray
            At least 2-dimensional. All axes after the first are flattened
            into features before fitting.
        """
        flattened = self._flatten(allele_representations)
        print("Fitting PCA allele encoding transform on data of shape: %s" % (
            str(flattened.shape)))
        start = time.perf_counter()
        model = sklearn.decomposition.PCA()
        model.fit(flattened)
        self.model = model
        print("Fit complete in %0.3f sec." % (time.perf_counter() - start))

    def get_fit(self):
        """
        Return the fitted PCA parameters for serialization.

        Returns
        -------
        dict : string to numpy.ndarray
            ``'mean'`` -> ``mean_`` and ``'components'`` -> ``components_``
            of the fitted PCA model.
        """
        return {
            'mean': self.model.mean_,
            'components': self.model.components_,
        }

    def restore_fit(self, fit):
        """
        Restore a PCA model from parameters produced by ``get_fit``.

        Only ``mean_`` and ``components_`` are restored. This is presumably
        sufficient for ``PCA.transform`` because ``whiten`` is left at its
        default (False).

        Parameters
        ----------
        fit : dict : string to numpy.ndarray
            Must contain the keys ``'mean'`` and ``'components'``.
        """
        # Build the model locally so a missing key leaves self.model as it was
        # instead of a half-initialised PCA that is_fit() would report as fit.
        model = sklearn.decomposition.PCA()
        model.mean_ = fit["mean"]
        model.components_ = fit["components"]
        self.model = model

    def transform(self, allele_encoding, underlying_representation):
        """
        Project an allele encoding's representations onto the PCA components.

        Fits the PCA model on these representations first if it is not
        already fit.

        Parameters
        ----------
        allele_encoding : object
            Must provide ``allele_representations(underlying_representation)``
            returning an array (assumed numpy) of at least 2 dimensions.
        underlying_representation
            Passed through to ``allele_encoding.allele_representations``.

        Returns
        -------
        numpy.ndarray of shape (num rows, num components)
        """
        allele_representations = allele_encoding.allele_representations(
            underlying_representation)
        if not self.is_fit():
            self.fit(allele_representations)
        return self.model.transform(self._flatten(allele_representations))


# Registry of available allele encoding transforms, keyed by each class's
# ``name`` attribute.
TRANSFORMS = {klass.name: klass for klass in [PCATransform]}