Skip to content
Snippets Groups Projects
masked_slice.py 1.09 KiB
Newer Older
import keras.layers

class MaskedSlice(keras.layers.Lambda):
    """
    Slice a masked sequence tensor along its time axis.

    Takes an embedded representation of a sentence with dims
    (n_samples, max_length, n_dims)
    where each sample may be masked to allow for variable-length inputs.
    Returns the time steps ``time_start:time_end`` (half-open, like a Python
    slice) of every sample, i.e. a tensor of shape
    (n_samples, time_end - time_start, n_dims), and slices the mask the
    same way so downstream layers keep seeing a matching mask.

    NOTE(review): all remaining positional/keyword arguments are forwarded
    unchanged to ``keras.layers.Lambda.__init__`` -- confirm what callers
    are expected to pass there.
    """
    supports_masking = True

    def __init__(
            self,
            time_start,
            time_end,
            *args,
            **kwargs):
        """
        Parameters
        ----------
        time_start : int
            Index of the first time step to keep (inclusive, >= 0).
        time_end : int
            Index one past the last time step to keep (exclusive, >= 0).
        *args, **kwargs
            Forwarded to the parent layer's constructor.
        """
        # Negative indices are rejected: their meaning would depend on the
        # runtime sequence length, which the static shape inference below
        # cannot express.
        assert time_start >= 0, "time_start must be non-negative"
        assert time_end >= 0, "time_end must be non-negative"
        self.time_start = time_start
        self.time_end = time_end
        super(MaskedSlice, self).__init__(*args, **kwargs)

    def call(self, x, mask=None):
        """Return ``x`` restricted to the configured window of time steps."""
        return x[:, self.time_start:self.time_end, :]

    def compute_mask(self, x, mask=None):
        """
        Return the mask restricted to the same window as the output.

        Returns None when no mask was supplied. The mask is indexed on its
        first two axes only, because a sequence mask is normally
        (n_samples, max_length) and has no feature axis to index.
        """
        if mask is None:
            return None
        return mask[:, self.time_start:self.time_end]

    def get_output_shape_for(self, input_shape):
        """
        Compute the static output shape for a 3-D ``input_shape``.

        The time dimension is the length of the half-open slice
        ``time_start:time_end``, clipped to the input's time dimension when
        that dimension is known.
        """
        assert len(input_shape) == 3
        n_samples, max_length, n_dims = input_shape
        window_start = self.time_start
        window_stop = self.time_end
        if max_length is not None:
            # Slicing silently clips to the available length, so mirror that
            # here to keep the declared shape in agreement with call().
            window_start = min(window_start, max_length)
            window_stop = min(window_stop, max_length)
        # When max_length is unknown (None) the nominal window length is
        # reported; this assumes the sequence is at least time_end long.
        n_steps = max(window_stop - window_start, 0)
        return (n_samples, n_steps, n_dims)