
Commit

merged with dev. Reorganization applied
Former-commit-id: e82e2cd [formerly 6f80488]
Former-commit-id: d50a4f9
AvantiShri committed Oct 31, 2016
2 parents dad41d5 + 0a2ba1b commit 9614233
Showing 25 changed files with 818 additions and 429 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -134,10 +134,11 @@ Here are the steps necessary to implement the backward pass, which is where the
4. Compile the importance score computation function with

```python
deeplift.backend.function([input_layer.get_activation_vars()...],
deeplift.backend.function([input_layer.get_activation_vars()...,
input_layer.get_reference_vars()...],
blob_to_find_scores_for.get_target_contrib_vars())
```
- The first argument represents the inputs to the function and should be a list of one symbolic tensor for each input layer (this was explained under the instructions for compiling the forward pass).
- The first argument represents the inputs to the function and should be a list of one symbolic tensor for the activations of each input layer (as for the forward pass), followed by a list of one symbolic tensor for the references of each input layer.
- The second argument represents the output of the function. In the example above, it is a single tensor containing the importance scores of a single blob, but it can also be a list of tensors if you wish to compute the scores for multiple blobs at once.
- Instead of `get_target_contrib_vars()` which returns the importance scores (in the case of `MxtsMode.DeepLIFT`, these are called "contribution scores"), you can use `get_mxts()` to get the multipliers.
5. Now you are ready to call the function to find the importance scores.
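For context, here is a minimal sketch of how the scoring function compiled in step 4 might be called; the array names, shapes, and all-zeros reference are illustrative assumptions, not taken from the README.

```python
import numpy as np
import deeplift.backend

# Assumed setup: `input_layer` and `blob_to_find_scores_for` come from a
# deeplift model loaded in the earlier README steps (single input layer).
score_func = deeplift.backend.function(
    [input_layer.get_activation_vars(),
     input_layer.get_reference_vars()],
    blob_to_find_scores_for.get_target_contrib_vars())

# Illustrative data: 10 examples with 784 features, scored against an
# all-zeros reference (the choice of reference is up to the user).
X = np.random.rand(10, 784).astype("float32")
X_ref = np.zeros_like(X)

scores = score_func(X, X_ref)  # importance scores, one row per example
```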
15 changes: 10 additions & 5 deletions deeplift/backend/theano_backend.py
@@ -62,7 +62,7 @@ def sum(x, axis):


def ones_like(x):
return T.ones_like(x)
return T.ones_like(x, dtype=theano.config.floatX)


def zeros_like(x):
@@ -79,7 +79,9 @@ def set_subtensor(subtensor, amnt):

def function(inputs, outputs, **kwargs):
return theano.function(inputs, outputs,
allow_input_downcast=True, **kwargs)
allow_input_downcast=True,
on_unused_input='ignore',
**kwargs)


def tensor_with_dims(num_dims, name):
@@ -132,27 +134,30 @@ def abs(inp):


def conv2d(inp, filters, border_mode, subsample):

inp = T.cast(inp, dtype=theano.config.floatX)
if (border_mode==BorderMode.same):
#'half' kernel width padding results in outputs of the same
#dimensions as input
border_mode=BorderMode.half
assert filters.shape[2]%2 == 1 and filters.shape[3]%2 == 1,\
"haven't handled even filter shapes for border mode 'half'"
return T.nnet.conv2d(input=inp,
filters=theano.shared(value=filters),
filters=T.cast(theano.shared(value=filters),
dtype=theano.config.floatX),
border_mode=border_mode,
subsample=subsample,
filter_shape=filters.shape)


def conv2d_grad(out_grad, conv_in, filters, border_mode, subsample):
out_grad=T.cast(out_grad, dtype=theano.config.floatX)
conv_op = T.nnet.conv.ConvOp(output_mode=border_mode,
dx=subsample[0],
dy=subsample[1])
inverse_conv2d = conv_op.grad(
(conv_in,
T.as_tensor_variable(filters)),
T.cast(T.as_tensor_variable(filters),
dtype=theano.config.floatX)),
(out_grad,))
#grad returns d_input and d_filters; we just care about
#the first
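The new `on_unused_input='ignore'` flag in the `function` wrapper above is presumably needed because the reference tensors now passed to `deeplift.backend.function` may not appear in every requested output, and Theano raises `UnusedInputError` for unused inputs by default. A minimal standalone illustration of that Theano behaviour (not taken from this commit):

```python
import theano
import theano.tensor as T

x = T.matrix("x")
ref = T.matrix("ref")   # declared as an input but unused by the output below

# Without on_unused_input='ignore', compiling this would raise UnusedInputError.
f = theano.function([x, ref], x * 2,
                    allow_input_downcast=True,
                    on_unused_input="ignore")

print(f([[1.0, 2.0]], [[0.0, 0.0]]))   # ref is accepted and ignored -> [[2. 4.]]
```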
34 changes: 17 additions & 17 deletions deeplift/blobs/activations.py
@@ -12,15 +12,14 @@ class Activation(SingleInputMixin, OneDimOutputMixin, Node):
#if you tried to call its functions for a layer that was
#not actually one dimensional)

def __init__(self, mxts_mode,
**kwargs):
self.mxts_mode = mxts_mode
def __init__(self, nonlinear_mxts_mode, **kwargs):
self.nonlinear_mxts_mode = nonlinear_mxts_mode
super(Activation, self).__init__(**kwargs)

def get_yaml_compatible_object_kwargs(self):
kwargs_dict = super(Activation, self).\
get_yaml_compatible_object_kwargs()
kwargs_dict['mxts_mode'] = self.mxts_mode
kwargs_dict['nonlinear_mxts_mode'] = self.nonlinear_mxts_mode
return kwargs_dict

def _compute_shape(self, input_shape):
@@ -29,7 +28,7 @@ def _compute_shape(self, input_shape):
def _build_fwd_pass_vars(self):
super(Activation, self)._build_fwd_pass_vars()
self._gradient_at_default_activation =\
self._get_gradient_at_activation(self._get_default_activation_vars())
self._get_gradient_at_activation(self.get_reference_vars())

def _get_gradient_at_default_activation_var(self):
return self._gradient_at_default_activation
@@ -38,13 +37,13 @@ def _build_activation_vars(self, input_act_vars):
raise NotImplementedError()

def _deeplift_get_scale_factor(self):
input_diff_from_default = self._get_input_diff_from_default_vars()
near_zero_contrib_mask = (B.abs(input_diff_from_default)\
input_diff_from_reference = self._get_input_diff_from_reference_vars()
near_zero_contrib_mask = (B.abs(input_diff_from_reference)\
< NEAR_ZERO_THRESHOLD)
far_from_zero_contrib_mask = 1-(1*near_zero_contrib_mask)
#the pseudocount is to avoid division-by-zero for the ones that
#we won't use anyway
pc_diff_from_default = input_diff_from_default +\
pc_diff_from_reference = input_diff_from_reference +\
(1*near_zero_contrib_mask)
#when total contrib is near zero,
#the scale factor is 1 (gradient; piecewise linear). Otherwise,
@@ -54,34 +53,35 @@ def _deeplift_get_scale_factor(self):
scale_factor = near_zero_contrib_mask*\
self._get_gradient_at_default_activation_var() +\
(far_from_zero_contrib_mask*\
(self._get_diff_from_default_vars()/
pc_diff_from_default))
(self._get_diff_from_reference_vars()/
pc_diff_from_reference))
return scale_factor

def _gradients_get_scale_factor(self):
return self._get_gradient_at_activation(
self._get_input_activation_vars())

def _get_mxts_increments_for_inputs(self):
if (self.mxts_mode == MxtsMode.DeconvNet):
if (self.nonlinear_mxts_mode==NonlinearMxtsMode.DeconvNet):
#apply the given nonlinearity in reverse
mxts = self._build_activation_vars(self.get_mxts())
else:
#all the other ones here are of the form:
# scale_factor*self.get_mxts()
if (self.mxts_mode == MxtsMode.DeepLIFT):
if (self.nonlinear_mxts_mode==NonlinearMxtsMode.DeepLIFT):
scale_factor = self._deeplift_get_scale_factor()
elif (self.mxts_mode == MxtsMode.GuidedBackpropDeepLIFT):
elif (self.nonlinear_mxts_mode==
NonlinearMxtsMode.GuidedBackpropDeepLIFT):
deeplift_scale_factor = self._deeplift_get_scale_factor()
scale_factor = deeplift_scale_factor*(self.get_mxts() > 0)
elif (self.mxts_mode == MxtsMode.Gradient):
elif (self.nonlinear_mxts_mode==NonlinearMxtsMode.Gradient):
scale_factor = self._gradients_get_scale_factor()
elif (self.mxts_mode == MxtsMode.GuidedBackprop):
elif (self.nonlinear_mxts_mode==NonlinearMxtsMode.GuidedBackprop):
scale_factor = self._gradients_get_scale_factor()\
*(self.get_mxts() > 0)
else:
raise RuntimeError("Unsupported mxts_mode: "
+str(self.mxts_mode))
raise RuntimeError("Unsupported nonlinear_mxts_mode: "
+str(self.nonlinear_mxts_mode))
orig_mxts = scale_factor*self.get_mxts()
return orig_mxts
return mxts
4 changes: 2 additions & 2 deletions deeplift/blobs/convolution.py
@@ -181,9 +181,9 @@ def _get_mxts_increments_for_inputs(self):
elif (self.maxpool_deeplift_mode==
MaxPoolDeepLiftMode.scaled_gradient):
grad_times_diff_def = self._get_input_grad_given_outgrad(
out_grad=self.get_mxts()*self._get_diff_from_default_vars())
out_grad=self.get_mxts()*self._get_diff_from_reference_vars())
pcd_input_diff_default = (pseudocount_near_zero(
self._get_input_diff_from_default_vars()))
self._get_input_diff_from_reference_vars()))
return grad_times_diff_def/pcd_input_diff_default
else:
raise RuntimeError("Unsupported maxpool_deeplift_mode: "+
125 changes: 75 additions & 50 deletions deeplift/blobs/core.py
Expand Up @@ -14,11 +14,16 @@

ScoringMode = deeplift.util.enum(OneAndZeros="OneAndZeros",
SoftmaxPreActivation="SoftmaxPreActivation")
MxtsMode = deeplift.util.enum(Gradient="Gradient", DeepLIFT="DeepLIFT",
DeconvNet="DeconvNet",
GuidedBackprop="GuidedBackprop",
GuidedBackpropDeepLIFT=\
"GuidedBackpropDeepLIFT")
NonlinearMxtsMode = deeplift.util.enum(
Gradient="Gradient",
DeepLIFT="DeepLIFT",
DeconvNet="DeconvNet",
GuidedBackprop="GuidedBackprop",
GuidedBackpropDeepLIFT="GuidedBackpropDeepLIFT")
DenseMxtsMode = deeplift.util.enum(
Linear="Linear",
PosOnly="PosOnly",
Counterbalance="Counterbalance")
ActivationNames = deeplift.util.enum(sigmoid="sigmoid",
hard_sigmoid="hard_sigmoid",
tanh="tanh",
@@ -75,36 +80,36 @@ def get_activation_vars(self):
self._layer_needs_to_be_built_message()
return self._activation_vars

def _build_default_activation_vars(self):
def _build_reference_vars(self):
raise NotImplementedError()

def _build_diff_from_default_vars(self):
def _build_diff_from_reference_vars(self):
"""
instantiate theano vars whose value is the difference between
the activation and the default activation
the activation and the reference activation
"""
return self.get_activation_vars() - self._get_default_activation_vars()
return self.get_activation_vars() - self.get_reference_vars()

def _build_target_contrib_vars(self):
"""
the contrib to the target is mxts*(Ax - Ax0)
"""
return self.get_mxts()*self._get_diff_from_default_vars()
return self.get_mxts()*self._get_diff_from_reference_vars()

def _get_diff_from_default_vars(self):
def _get_diff_from_reference_vars(self):
"""
return the theano vars representing the difference between
the activation and the default activation
the activation and the reference activation
"""
return self._diff_from_default_vars
return self._diff_from_reference_vars

def _get_default_activation_vars(self):
def get_reference_vars(self):
"""
get the activation that corresponds to zero contrib
"""
if (hasattr(self, '_default_activation_vars')==False):
raise RuntimeError("_default_activation_vars is unset")
return self._default_activation_vars
if (hasattr(self, '_reference_vars')==False):
raise RuntimeError("_reference_vars is unset")
return self._reference_vars

def _increment_mxts(self, increment_var):
"""
@@ -195,8 +200,9 @@ def __init__(self, num_dims, shape, **kwargs):
def get_activation_vars(self):
return self._activation_vars

def _build_default_activation_vars(self):
raise NotImplementedError()
def _build_reference_vars(self):
return B.tensor_with_dims(self._num_dims,
name="ref_"+str(self.get_name()))

def get_yaml_compatible_object_kwargs(self):
kwargs_dict = super(Input,self).get_yaml_compatible_object_kwargs()
Expand All @@ -205,27 +211,11 @@ def get_yaml_compatible_object_kwargs(self):
return kwargs_dict

def _build_fwd_pass_vars(self):
self._default_activation_vars = self._build_default_activation_vars()
self._diff_from_default_vars = self._build_diff_from_default_vars()
self._reference_vars = self._build_reference_vars()
self._diff_from_reference_vars = self._build_diff_from_reference_vars()
self._mxts = B.zeros_like(self.get_activation_vars())


class Input_FixedDefault(Input):

def __init__(self, default=0.0, **kwargs):
super(Input_FixedDefault, self).__init__(**kwargs)
self.default = default

def get_yaml_compatible_object_kwargs(self):
kwargs_dict = super(Input_FixedDefault, self).\
get_yaml_compatible_object_kwargs()
kwargs_dict['default'] = self.default
return kwargs_dict

def _build_default_activation_vars(self):
return B.ones_like(self._activation_vars)*self.default


class Node(Blob):

def __init__(self, **kwargs):
@@ -257,13 +247,13 @@ def _get_input_activation_vars(self):
return self._call_function_on_blobs_within_inputs(
'get_activation_vars')

def _get_input_default_activation_vars(self):
def _get_input_reference_vars(self):
return self._call_function_on_blobs_within_inputs(
'_get_default_activation_vars')
'get_reference_vars')

def _get_input_diff_from_default_vars(self):
def _get_input_diff_from_reference_vars(self):
return self._call_function_on_blobs_within_inputs(
'_get_diff_from_default_vars')
'_get_diff_from_reference_vars')

def _get_input_shape(self):
return self._call_function_on_blobs_within_inputs('get_shape')
@@ -292,11 +282,11 @@ def _build_fwd_pass_vars(self):
self._activation_vars =\
self._build_activation_vars(
self._get_input_activation_vars())
self._default_activation_vars =\
self._build_default_activation_vars()
self._diff_from_default_vars =\
self._build_diff_from_default_vars()
self._mxts = B.zeros_like(self._get_default_activation_vars())
self._reference_vars =\
self._build_reference_vars()
self._diff_from_reference_vars =\
self._build_diff_from_reference_vars()
self._mxts = B.zeros_like(self.get_reference_vars())

def _compute_shape(self, input_shape):
"""
@@ -311,9 +301,9 @@ def _build_activation_vars(self, input_act_vars):
"""
raise NotImplementedError()

def _build_default_activation_vars(self):
def _build_reference_vars(self):
return self._build_activation_vars(
self._get_input_default_activation_vars())
self._get_input_reference_vars())

def _update_mxts_for_inputs(self):
"""
@@ -461,10 +451,11 @@ def _get_mxts_increments_for_inputs(self):

class Dense(SingleInputMixin, OneDimOutputMixin, Node):

def __init__(self, W, b, **kwargs):
def __init__(self, W, b, dense_mxts_mode, **kwargs):
super(Dense, self).__init__(**kwargs)
self.W = W
self.b = b
self.dense_mxts_mode = dense_mxts_mode

def get_yaml_compatible_object_kwargs(self):
kwargs_dict = super(Dense, self).\
@@ -480,7 +471,41 @@ def _build_activation_vars(self, input_act_vars):
return B.dot(input_act_vars, self.W) + self.b

def _get_mxts_increments_for_inputs(self):
return B.dot(self.get_mxts(),self.W.T)
if (self.dense_mxts_mode == DenseMxtsMode.PosOnly):
return B.dot(self.get_mxts()*(self.get_mxts()>0.0),self.W.T)
elif (self.dense_mxts_mode == DenseMxtsMode.Counterbalance):
#self.W has dims input x output
#fwd_contribs has dims batch x output x input
fwd_contribs = self._get_input_activation_vars()[:,None,:]\
*self.W.T[None,:,:]
#total_pos_contribs and total_neg_contribs have dim batch x output
total_pos_contribs = B.sum(fwd_contribs*(fwd_contribs>0), axis=-1)
total_neg_contribs = B.abs(B.sum(fwd_contribs*(fwd_contribs<0),
axis=-1))
#if output diff-from-def is positive but there are some neg
#contribs, temper positive by some portion of the neg
#to_distribute has dims batch x output
to_distribute =\
B.maximum(
(total_neg_contribs*(total_neg_contribs < total_pos_contribs)
- B.maximum(self.get_reference_vars(),0)),0.0)\
*(1.0-((total_neg_contribs)/
pseudocount_near_zero(total_pos_contribs)))
#total_pos_contribs_new has dims batch x output
total_pos_contribs_new = total_pos_contribs - to_distribute
total_neg_contribs_new = total_neg_contribs - to_distribute
#positive_rescale has dims batch x output
positive_rescale = total_pos_contribs_new/pseudocount_near_zero(total_pos_contribs)
negative_rescale = total_neg_contribs_new/pseudocount_near_zero(total_neg_contribs)
#new_Wt has dims batch x output x input
new_Wt = self.W.T[None,:,:]*(fwd_contribs>0)*positive_rescale[:,:,None]
new_Wt += self.W.T[None,:,:]*(fwd_contribs<0)*negative_rescale[:,:,None]
return B.sum(self.get_mxts()[:,:,None]*new_Wt[:,:,:],axis=1)
elif (self.dense_mxts_mode == DenseMxtsMode.Linear):
return B.dot(self.get_mxts(),self.W.T)
else:
raise RuntimeError("Unsupported mxts mode: "
+str(self.dense_mxts_mode))
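The Counterbalance branch above packs the redistribution logic into a few broadcasted Theano expressions. As a reading aid, here is a rough NumPy re-statement of the same arithmetic; the function name, argument layout, and the eps-based stand-in for `pseudocount_near_zero` are assumptions of this sketch, not part of the commit.

```python
import numpy as np

def counterbalance_mxts(mxts_out, inputs, W, output_refs, eps=1e-7):
    """Sketch of the Counterbalance rule: mxts_out is batch x output,
    inputs is batch x input, W is input x output, output_refs holds the
    layer's reference activations (batch x output)."""
    def pseudocount_near_zero(x):
        # stand-in for deeplift's helper: nudge near-zero entries away from 0
        return x + eps * (np.abs(x) < eps)

    # per-connection contributions, batch x output x input
    fwd_contribs = inputs[:, None, :] * W.T[None, :, :]
    total_pos = np.sum(fwd_contribs * (fwd_contribs > 0), axis=-1)
    total_neg = np.abs(np.sum(fwd_contribs * (fwd_contribs < 0), axis=-1))

    # portion of the positive contributions cancelled against the negatives
    to_distribute = np.maximum(
        total_neg * (total_neg < total_pos) - np.maximum(output_refs, 0.0),
        0.0) * (1.0 - total_neg / pseudocount_near_zero(total_pos))
    pos_rescale = (total_pos - to_distribute) / pseudocount_near_zero(total_pos)
    neg_rescale = (total_neg - to_distribute) / pseudocount_near_zero(total_neg)

    # rescale the weights per example and back-propagate the multipliers
    new_Wt = W.T[None, :, :] * (fwd_contribs > 0) * pos_rescale[:, :, None] \
           + W.T[None, :, :] * (fwd_contribs < 0) * neg_rescale[:, :, None]
    return np.sum(mxts_out[:, :, None] * new_Wt, axis=1)  # batch x input

# tiny smoke test with random values
rng = np.random.RandomState(0)
print(counterbalance_mxts(rng.randn(2, 3), rng.randn(2, 4),
                          rng.randn(4, 3), rng.randn(2, 3)).shape)  # (2, 4)
```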


class BatchNormalization(SingleInputMixin, Node):