From bdab816b0b7942741d41c1a279f9d0d614c592ee Mon Sep 17 00:00:00 2001 From: AnjaSei Date: Wed, 31 Oct 2018 10:35:27 +0100 Subject: [PATCH 1/3] defined conversion function for global average pooling --- deeplift/conversion/kerasapi_conversion.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/deeplift/conversion/kerasapi_conversion.py b/deeplift/conversion/kerasapi_conversion.py index 2b4b9f2..a3b7aaf 100644 --- a/deeplift/conversion/kerasapi_conversion.py +++ b/deeplift/conversion/kerasapi_conversion.py @@ -262,6 +262,12 @@ def maxpool1d_conversion(config, name, verbose, **pool1d_kwargs)] +def globalavgpooling1d_conversion(config, name, verbose, **kwargs): + return [layers.GlobalAvgPool1D( + name=name, + verbose=verbose)] + + def avgpool1d_conversion(config, name, verbose, **kwargs): pool1d_kwargs = prep_pool1d_kwargs( config=config, @@ -305,6 +311,7 @@ def layer_name_to_conversion_function(layer_name): 'maxpooling1d': maxpool1d_conversion, 'globalmaxpooling1d': globalmaxpooling1d_conversion, 'averagepooling1d': avgpool1d_conversion, + 'globalaveragepooling1d': globalavgpooling1d_conversion, 'conv2d': conv2d_conversion, 'maxpooling2d': maxpool2d_conversion, From 290ff7880ab49089d17155ff71d71b369bb8fe99 Mon Sep 17 00:00:00 2001 From: AnjaSei Date: Wed, 31 Oct 2018 10:39:17 +0100 Subject: [PATCH 2/3] added global average pooling layer --- deeplift/layers/pooling.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/deeplift/layers/pooling.py b/deeplift/layers/pooling.py index e4a354f..c6b0515 100644 --- a/deeplift/layers/pooling.py +++ b/deeplift/layers/pooling.py @@ -171,6 +171,37 @@ def _get_mxts_increments_for_inputs(self): return pos_mxts_increments, neg_mxts_increments +class GlobalAvgPool1D(SingleInputMixin, Node): + + def __init__(self, **kwargs): + super(GlobalAvgPool1D, self).__init__(**kwargs) + + def _compute_shape(self, input_shape): + assert len(input_shape)==3 + shape_to_return = [None, input_shape[-1]] + return shape_to_return + + def _build_activation_vars(self, input_act_vars): + return tf.reduce_mean(input_act_vars, axis=1) + + def _build_pos_and_neg_contribs(self): + inp_pos_contribs, inp_neg_contribs =\ + self._get_input_pos_and_neg_contribs() + pos_contribs = self._build_activation_vars(inp_pos_contribs) + neg_contribs = self._build_activation_vars(inp_neg_contribs) + return pos_contribs, neg_contribs + + def _grad_op(self, out_grad): + width = self._get_input_activation_vars().get_shape().as_list()[1] + mask = tf.ones_like(self._get_input_activation_vars()) / float(width) + return tf.multiply(tf.expand_dims(out_grad, axis=1), mask) + + def _get_mxts_increments_for_inputs(self): + pos_mxts_increments = self._grad_op(self.get_pos_mxts()) + neg_mxts_increments = self._grad_op(self.get_neg_mxts()) + return pos_mxts_increments, neg_mxts_increments + + class Pool2D(SingleInputMixin, Node): def __init__(self, pool_size, strides, padding, data_format, **kwargs): From 8b9b7d5255fd1b2eb7b23433eb4bc3110f87adbd Mon Sep 17 00:00:00 2001 From: AnjaSei Date: Wed, 31 Oct 2018 15:03:33 +0100 Subject: [PATCH 3/3] updated version number --- deeplift.egg-info/PKG-INFO | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deeplift.egg-info/PKG-INFO b/deeplift.egg-info/PKG-INFO index 4320e06..934e690 100644 --- a/deeplift.egg-info/PKG-INFO +++ b/deeplift.egg-info/PKG-INFO @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: deeplift -Version: 0.6.7.1 +Version: 0.6.8.0 Summary: DeepLIFT (Deep Learning Important FeaTures) Home-page: https://github.com/kundajelab/deeplift License: UNKNOWN diff --git a/setup.py b/setup.py index 2d06a3a..a973599 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ Implements the methods in "Learning Important Features Through Propagating Activation Differences" by Shrikumar, Greenside & Kundaje, as well as other commonly-used methods such as gradients, guided backprop and integrated gradients. See https://github.com/kundajelab/deeplift for documentation and FAQ. """, url='https://github.com/kundajelab/deeplift', - version='0.6.7.1', + version='0.6.8.0', packages=['deeplift', 'deeplift.layers', 'deeplift.visualization', 'deeplift.conversion'],