Newer
Older
Alex Rubinsteyn
committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import keras.layers
class MaskedSlice(keras.layers.Lambda):
    """
    Slice a masked sequence tensor along its time axis.

    Takes an embedded representation of a batch of sentences with dims
    (n_samples, max_length, n_dims), where each sample is masked to allow
    for variable-length inputs, and returns the timesteps in the half-open
    interval [time_start, time_end). The result has shape
    (n_samples, time_end - time_start, n_dims), or fewer timesteps if
    time_end runs past max_length, following Python slicing rules.
    The input mask, if any, is sliced the same way so downstream layers
    keep seeing a mask that lines up with the output.

    Parameters
    ----------
    time_start : int
        First timestep to keep (inclusive). Must be >= 0.
    time_end : int
        Timestep at which to stop (exclusive). Must be >= time_start.
    *args, **kwargs
        Forwarded unchanged to ``keras.layers.Lambda.__init__``.

    Raises
    ------
    ValueError
        If time_start is negative or time_end < time_start.
    """
    supports_masking = True

    def __init__(
            self,
            time_start,
            time_end,
            *args,
            **kwargs):
        # Explicit checks rather than `assert`, which is stripped under -O.
        # time_end >= time_start >= 0 also guarantees time_end >= 0.
        if time_start < 0:
            raise ValueError(
                "time_start must be >= 0, got %r" % (time_start,))
        if time_end < time_start:
            raise ValueError(
                "time_end (%r) must be >= time_start (%r)" % (
                    time_end, time_start))
        self.time_start = time_start
        self.time_end = time_end
        # NOTE(review): Keras' Lambda.__init__ expects a `function`
        # argument; presumably callers pass one through *args/**kwargs.
        # The forward pass does not rely on it because `call` is
        # overridden below -- confirm against callers.
        super(MaskedSlice, self).__init__(*args, **kwargs)

    def call(self, x, mask=None):
        """
        Return the timesteps [time_start, time_end) of `x`.

        `mask` is accepted (and may be None) to match the Keras layer hook
        signature; masking of the output is handled by `compute_mask`.
        """
        return x[:, self.time_start:self.time_end, :]

    def compute_mask(self, x, mask=None):
        """
        Slice the incoming mask along the time axis to match `call`.

        Returns None when the input carries no mask. Only axis 1 is
        indexed, so this works for the usual 2D (n_samples, max_length)
        sequence mask as well as for higher-rank masks.
        """
        if mask is None:
            return None
        return mask[:, self.time_start:self.time_end]

    def get_output_shape_for(self, input_shape):
        """
        Return the static output shape for a 3D input shape.

        The time dimension equals the length of the half-open slice used
        in `call`: time_end - time_start, clipped to the input length when
        that length is statically known (None means unknown).

        Raises
        ------
        ValueError
            If `input_shape` is not 3-dimensional.
        """
        if len(input_shape) != 3:
            raise ValueError(
                "MaskedSlice expects a 3D input shape, got %r" % (
                    input_shape,))
        n_timesteps = input_shape[1]
        if n_timesteps is None:
            # Unknown sequence length: report the nominal slice length.
            length = self.time_end - self.time_start
        else:
            # Mirror Python slicing: indices past the end are clipped.
            start = min(self.time_start, n_timesteps)
            end = min(self.time_end, n_timesteps)
            length = end - start
        return (input_shape[0], length, input_shape[2])

    # Keras 2 renamed `get_output_shape_for` to `compute_output_shape`;
    # expose the same logic under both names.
    compute_output_shape = get_output_shape_for