forked from SongDark/DeepFM_keras
-
Notifications
You must be signed in to change notification settings - Fork 0
/
MyMeanPooling.py
31 lines (27 loc) · 1.01 KB
/
MyMeanPooling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from keras import backend as K
from keras.engine.topology import Layer
import tensorflow as tf
class MyMeanPool(Layer):
    """Mean-pool a tensor along one axis, honoring an optional Keras mask.

    With a mask, masked-out positions are zeroed and the sum is divided by
    the number of unmasked positions, so padding does not dilute the mean.
    Without a mask this is a plain ``K.mean`` over ``axis``. The mask is
    consumed here and is not propagated to downstream layers.

    Args:
        axis: Axis to pool over. Negative values count from the end.
        **kwargs: Forwarded to ``Layer.__init__`` (e.g. ``name``).
    """

    def __init__(self, axis, **kwargs):
        super(MyMeanPool, self).__init__(**kwargs)
        # Assigned after the base initializer: in Keras 2.x Layer.__init__
        # sets ``supports_masking = False``, which would discard an earlier
        # assignment made before the super() call.
        self.supports_masking = True
        self.axis = axis

    def compute_mask(self, input, input_mask=None):
        """Return None: the pooled output carries no mask downstream.

        Parameter names are kept unchanged for compatibility with existing
        callers.
        """
        return None

    def call(self, x, mask=None):
        """Return the (masked) mean of ``x`` over ``self.axis``.

        Args:
            x: Input tensor.
            mask: Optional mask. If it has one fewer dimension than ``x``
                (e.g. a (batch, steps) mask for a (batch, steps, features)
                input -- TODO confirm this is the intended layout), it is
                broadcast across the trailing dimension of ``x``.

        Returns:
            ``x`` reduced over ``self.axis``. Rows whose mask is entirely
            zero produce 0 rather than NaN.
        """
        if mask is None:
            return K.mean(x, axis=self.axis)

        mask = K.cast(mask, K.floatx())
        if K.ndim(x) != K.ndim(mask):
            # Add a trailing axis and expand it to x's full shape. For a
            # 2-D mask on a 3-D input this equals the former
            # repeat-then-transpose, but it does not need a static last
            # dimension and is not limited to 2-D masks.
            mask = K.expand_dims(mask, axis=-1) * K.ones_like(x)
        x = x * mask
        # Clamp the denominator: a fully padded row has no unmasked
        # positions, and 0 / 0 would otherwise yield NaN.
        count = K.maximum(K.sum(mask, axis=self.axis), K.epsilon())
        return K.sum(x, axis=self.axis) / count

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` with the pooled axis removed.

        The axis is normalized first, so a negative ``self.axis`` (e.g. -1)
        drops the same dimension that ``call`` reduces.
        """
        axis = self.axis % len(input_shape)
        return tuple(dim for i, dim in enumerate(input_shape) if i != axis)

    def get_config(self):
        """Return the layer config, including ``axis``, for serialization."""
        config = super(MyMeanPool, self).get_config()
        config['axis'] = self.axis
        return config