Skip to content

Commit

Permalink
Merge pull request #68 from AnjaSei/master
Browse files Browse the repository at this point in the history
Global Average Pooling
  • Loading branch information
AvantiShri authored Oct 31, 2018
2 parents beb8923 + 8b9b7d5 commit 5cc76a8
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 2 deletions.
2 changes: 1 addition & 1 deletion deeplift.egg-info/PKG-INFO
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: deeplift
Version: 0.6.7.1
Version: 0.6.8.0
Summary: DeepLIFT (Deep Learning Important FeaTures)
Home-page: https://github.com/kundajelab/deeplift
License: UNKNOWN
Expand Down
7 changes: 7 additions & 0 deletions deeplift/conversion/kerasapi_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,12 @@ def maxpool1d_conversion(config, name, verbose,
**pool1d_kwargs)]


def globalavgpooling1d_conversion(config, name, verbose, **kwargs):
return [layers.GlobalAvgPool1D(
name=name,
verbose=verbose)]


def avgpool1d_conversion(config, name, verbose, **kwargs):
pool1d_kwargs = prep_pool1d_kwargs(
config=config,
Expand Down Expand Up @@ -305,6 +311,7 @@ def layer_name_to_conversion_function(layer_name):
'maxpooling1d': maxpool1d_conversion,
'globalmaxpooling1d': globalmaxpooling1d_conversion,
'averagepooling1d': avgpool1d_conversion,
'globalaveragepooling1d': globalavgpooling1d_conversion,

'conv2d': conv2d_conversion,
'maxpooling2d': maxpool2d_conversion,
Expand Down
31 changes: 31 additions & 0 deletions deeplift/layers/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,37 @@ def _get_mxts_increments_for_inputs(self):
return pos_mxts_increments, neg_mxts_increments


class GlobalAvgPool1D(SingleInputMixin, Node):

def __init__(self, **kwargs):
super(GlobalAvgPool1D, self).__init__(**kwargs)

def _compute_shape(self, input_shape):
assert len(input_shape)==3
shape_to_return = [None, input_shape[-1]]
return shape_to_return

def _build_activation_vars(self, input_act_vars):
return tf.reduce_mean(input_act_vars, axis=1)

def _build_pos_and_neg_contribs(self):
inp_pos_contribs, inp_neg_contribs =\
self._get_input_pos_and_neg_contribs()
pos_contribs = self._build_activation_vars(inp_pos_contribs)
neg_contribs = self._build_activation_vars(inp_neg_contribs)
return pos_contribs, neg_contribs

def _grad_op(self, out_grad):
width = self._get_input_activation_vars().get_shape().as_list()[1]
mask = tf.ones_like(self._get_input_activation_vars()) / float(width)
return tf.multiply(tf.expand_dims(out_grad, axis=1), mask)

def _get_mxts_increments_for_inputs(self):
pos_mxts_increments = self._grad_op(self.get_pos_mxts())
neg_mxts_increments = self._grad_op(self.get_neg_mxts())
return pos_mxts_increments, neg_mxts_increments


class Pool2D(SingleInputMixin, Node):

def __init__(self, pool_size, strides, padding, data_format, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Implements the methods in "Learning Important Features Through Propagating Activation Differences" by Shrikumar, Greenside & Kundaje, as well as other commonly-used methods such as gradients, guided backprop and integrated gradients. See https://github.com/kundajelab/deeplift for documentation and FAQ.
""",
url='https://github.com/kundajelab/deeplift',
version='0.6.7.1',
version='0.6.8.0',
packages=['deeplift',
'deeplift.layers', 'deeplift.visualization',
'deeplift.conversion'],
Expand Down

0 comments on commit 5cc76a8

Please sign in to comment.