Internal build changes. #138

Open: wants to merge 1 commit into base: master
morph_net/tools/configurable_ops.py (45 changes: 24 additions & 21 deletions)
@@ -50,6 +50,9 @@

 import tensorflow as tf

+from tensorflow.contrib import framework as contrib_framework
+from tensorflow.contrib import layers as contrib_layers
+
 gfile = tf.gfile  # Alias needed for mock.

 VANISHED = 0.0
@@ -80,14 +83,14 @@ class FallbackRule(Enum):


 DEFAULT_FUNCTION_DICT = {
-    'fully_connected': tf.contrib.layers.fully_connected,
-    'conv2d': tf.contrib.layers.conv2d,
-    'separable_conv2d': tf.contrib.layers.separable_conv2d,
+    'fully_connected': contrib_layers.fully_connected,
+    'conv2d': contrib_layers.conv2d,
+    'separable_conv2d': contrib_layers.separable_conv2d,
     'concat': tf.concat,
     'add_n': tf.add_n,
-    'avg_pool2d': tf.contrib.layers.avg_pool2d,
-    'max_pool2d': tf.contrib.layers.max_pool2d,
-    'batch_norm': tf.contrib.layers.batch_norm,
+    'avg_pool2d': contrib_layers.avg_pool2d,
+    'max_pool2d': contrib_layers.max_pool2d,
+    'batch_norm': contrib_layers.batch_norm,
 }

 # Maps function names to the suffix of the name of the regularized ops.
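
Note for reviewers: the dict above is the default set of layer functions that `ConfigurableOps` masks or passes through. A minimal usage sketch, assuming the constructor accepts a `parameterization` dict (as suggested by the `parameterization` property below) and an optional `function_dict` override; all names here are illustrative:

```
from morph_net.tools import configurable_ops

# Hypothetical parameterization: keep 13 output channels for this op.
# The key format follows the op-name resolution described in this file.
parameterization = {'net/conv1/Conv2D': 13}

# Assumed constructor signature; `function_dict` would presumably default
# to DEFAULT_FUNCTION_DICT above when omitted.
ops = configurable_ops.ConfigurableOps(parameterization=parameterization)
```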
@@ -164,13 +167,13 @@ def parameterization(self):
     """Returns the parameterization dict mapping op names to num_outputs."""
     return self._parameterization

-  @tf.contrib.framework.add_arg_scope
+  @contrib_framework.add_arg_scope
   def conv2d(self, *args, **kwargs):
     """Masks num_outputs from the function pointed to by 'conv2d'.

     The object's parameterization has precedence over the given NUM_OUTPUTS
     argument. The resolution of the op names uses
-    tf.contrib.framework.get_name_scope() and kwargs['scope'].
+    contrib_framework.get_name_scope() and kwargs['scope'].

     Args:
       *args: Arguments for the operation.
@@ -187,13 +190,13 @@ def conv2d(self, *args, **kwargs):
     fn, suffix = self._get_function_and_suffix('conv2d')
     return self._mask(fn, suffix, *args, **kwargs)

-  @tf.contrib.framework.add_arg_scope
+  @contrib_framework.add_arg_scope
   def fully_connected(self, *args, **kwargs):
     """Masks NUM_OUTPUTS from the function pointed to by 'fully_connected'.

     The object's parameterization has precedence over the given NUM_OUTPUTS
     argument. The resolution of the op names uses
-    tf.contrib.framework.get_name_scope() and kwargs['scope'].
+    contrib_framework.get_name_scope() and kwargs['scope'].

     Args:
       *args: Arguments for the operation.
@@ -214,13 +217,13 @@ def fully_connected(self, *args, **kwargs):
     fn, suffix = self._get_function_and_suffix('fully_connected')
     return self._mask(fn, suffix, *args, **kwargs)

-  @tf.contrib.framework.add_arg_scope
+  @contrib_framework.add_arg_scope
   def separable_conv2d(self, *args, **kwargs):
     """Masks NUM_OUTPUTS from the function pointed to by 'separable_conv2d'.

     The object's parameterization has precedence over the given NUM_OUTPUTS
     argument. The resolution of the op names uses
-    tf.contrib.framework.get_name_scope() and kwargs['scope'].
+    contrib_framework.get_name_scope() and kwargs['scope'].

     Args:
       *args: Arguments for the operation.
@@ -251,7 +254,7 @@ def _mask(self, function, suffix, *args, **kwargs):

     The object's parameterization has precedence over the given NUM_OUTPUTS
     argument. The resolution of the op names uses
-    `tf.contrib.framework.get_name_scope()` and `kwargs['scope']`.
+    `contrib_framework.get_name_scope()` and `kwargs['scope']`.

     The NUM_OUTPUTS argument is assumed to be either in **kwargs or held in
     *args[1].
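
A small sketch of the NUM_OUTPUTS lookup the docstring describes; the helper name and exact logic are assumptions, not established by this diff:

```
def _get_num_outputs(args, kwargs):
  # NUM_OUTPUTS is either passed by keyword or sits right after `inputs`,
  # e.g. conv2d(inputs, num_outputs, kernel_size, ...).
  if 'num_outputs' in kwargs:
    return kwargs['num_outputs']
  return args[1]
```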
@@ -284,7 +287,7 @@ def _mask(self, function, suffix, *args, **kwargs):

     # Support for tf.contrib.layers and tf.layers API.
     op_scope = kwargs.get('scope') or kwargs.get('name')
-    current_scope = tf.contrib.framework.get_name_scope() or ''
+    current_scope = contrib_framework.get_name_scope() or ''
     if current_scope and not current_scope.endswith('/'):
       current_scope += '/'
     op_name = ''.join([current_scope, op_scope, '/', suffix])
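
To make the resolution above concrete, a worked example mirroring the hunk; the 'Conv2D' suffix is an assumption taken from the suffix map mentioned earlier, which this diff does not show:

```
# Inside `with tf.variable_scope('net'):`, get_name_scope() returns 'net'.
current_scope = 'net'
if current_scope and not current_scope.endswith('/'):
  current_scope += '/'                 # -> 'net/'

op_scope = 'conv1'                     # kwargs['scope']
suffix = 'Conv2D'                      # assumed suffix for 'conv2d'
op_name = ''.join([current_scope, op_scope, '/', suffix])
assert op_name == 'net/conv1/Conv2D'   # the key looked up in parameterization
```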
@@ -320,17 +323,17 @@ def concat(self, *args, **kwargs):
   def add_n(self, *args, **kwargs):
     return self._pass_through_mask_list('add_n', 'inputs', *args, **kwargs)

-  @tf.contrib.framework.add_arg_scope
+  @contrib_framework.add_arg_scope
   def avg_pool2d(self, *args, **kwargs):
     return self._pass_through_mask(
         self._function_dict['avg_pool2d'], *args, **kwargs)

-  @tf.contrib.framework.add_arg_scope
+  @contrib_framework.add_arg_scope
   def max_pool2d(self, *args, **kwargs):
     return self._pass_through_mask(
         self._function_dict['max_pool2d'], *args, **kwargs)

-  @tf.contrib.framework.add_arg_scope
+  @contrib_framework.add_arg_scope
   def batch_norm(self, *args, **kwargs):
     return self._pass_through_mask(
         self._function_dict['batch_norm'], *args, **kwargs)
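
For context, a sketch of how these decorated ops stand in for direct contrib-layers calls in client code; `ops` is assumed to be a `ConfigurableOps` instance, and the scopes and layer sizes are illustrative:

```
import tensorflow as tf

def build_net(inputs, ops):
  # Each call below is masked or passed through according to the
  # parameterization held by `ops`.
  with tf.variable_scope('net'):
    net = ops.conv2d(inputs, num_outputs=64, kernel_size=3, scope='conv1')
    net = ops.batch_norm(net, scope='bn1')
    net = ops.max_pool2d(net, kernel_size=2, scope='pool1')
    net = tf.layers.flatten(net)
    return ops.fully_connected(net, num_outputs=10, scope='logits')
```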
@@ -432,8 +435,8 @@ def hijack_module_functions(configurable_ops, module):

   example_module.py
   ```
-  conv2d = tf.contrib.layers.conv2d
-  fully_connected = tf.contrib.layers.fully_connected
+  conv2d = contrib_layers.conv2d
+  fully_connected = contrib_layers.fully_connected

   def build_layer(inputs):
     return conv2d(inputs, 64, 3, scope='demo')
@@ -444,15 +447,15 @@ def build_layer(inputs):

  So after a call to `hijack_module_functions(configurable_ops, example_module)`
  the call `example_module.build_layer(net)` will under the hood use
-  `configurable_ops.conv2d` rather than `tf.contrib.layers.conv2d`.
+  `configurable_ops.conv2d` rather than `contrib_layers.conv2d`.

  Note: This function could be unsafe, as it depends on aliases defined in a
  possibly external module. In addition, a function in that module that calls
  the library directly will not be affected by the hijacking, for instance:

  ```
  def build_layer_not_affected(inputs):
-    return tf.contrib.layers.conv2d(inputs, 64, 3, scope='bad')
+    return contrib_layers.conv2d(inputs, 64, 3, scope='bad')
  ```

  Args:
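
Finally, a hedged usage sketch of `hijack_module_functions` itself, assuming `example_module` is laid out as in the docstring above and some `inputs` tensor is in scope:

```
import example_module  # the hypothetical module from the docstring above

# `ops` is assumed to be a ConfigurableOps instance.
hijack_module_functions(ops, example_module)

# build_layer now resolves its module-level `conv2d` alias to ops.conv2d;
# functions that call tf.contrib.layers.conv2d directly (like
# build_layer_not_affected) are not redirected.
net = example_module.build_layer(inputs)
```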