Skip to content
This repository has been archived by the owner on Jun 9, 2023. It is now read-only.

[WIP] - Allows different shapes for multi inputs & outputs #5

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions nvtx_plugins/python/nvtx/plugins/tf/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,29 @@

nvtx_tf_ops = load_library('lib/nvtx_ops' + get_ext_suffix())

def _maybe_process_inputs(inputs):

raw_inputs = None

if isinstance(inputs, (list, tuple)):
raw_inputs = inputs
inputs = tf.zeros(shape=(), dtype=tf.float32)

assert isinstance(inputs, tf.Tensor)

return inputs, raw_inputs

def _maybe_process_outputs(nvtx_out, raw_inputs):

if raw_inputs is not None:
null_op = tf.debugging.assert_type(tensor=nvtx_out, tf_type=tf.float32)
with tf.control_dependencies([null_op]):
inputs = [tf.identity(x) for x in raw_inputs]
else:
inputs = nvtx_out

return inputs

def _maybe_convert_list_to_tensor(inputs):

inputs_were_processed = False
Expand Down Expand Up @@ -121,14 +144,13 @@ def start(inputs, message, domain_name=None,
initializer=tf.zeros_initializer,
trainable=True)

inputs, should_unstack = _maybe_convert_list_to_tensor(inputs)
inputs, raw_inputs = _maybe_process_inputs(inputs)

inputs, marker_id, domain_handle = nvtx_tf_ops.nvtx_start(
inputs=inputs, null_input=null_input,
message=message, domain_name=domain_name, name=name)

if should_unstack:
inputs = tf.unstack(inputs, axis=0)
inputs = _maybe_process_outputs(inputs, raw_inputs)

return inputs, (marker_id, domain_handle, grad_message, grad_domain_name)

Expand Down Expand Up @@ -166,15 +188,14 @@ def end(inputs, nvtx_context, name=None):

marker_id, domain_handle, grad_message, grad_domain_name = nvtx_context

inputs, should_unstack = _maybe_convert_list_to_tensor(inputs)
inputs, raw_inputs = _maybe_process_inputs(inputs)

output, null_output = nvtx_tf_ops.nvtx_end(inputs=inputs,
marker_id=marker_id, domain_handle=domain_handle,
grad_message=grad_message, grad_domain_name=grad_domain_name, name=name
)

if should_unstack:
output = tf.unstack(output, axis=0)
output = _maybe_process_outputs(output, raw_inputs)

return output

Expand Down Expand Up @@ -213,26 +234,22 @@ def func_wrapper(wrapped, instance, args, kwargs):
raise ValueError("The input tensor must be the first argument"
" or named `inputs`")

inputs, should_unstack = _maybe_convert_list_to_tensor(inputs)

start_name = '{}_start'.format(name) if name else None
end_name = '{}_end'.format(name) if name else None

inputs, nvtx_context = start(inputs=inputs,
nvtx_out, nvtx_context = start(inputs=inputs,
message=message, domain_name=domain_name,
grad_message=grad_message, grad_domain_name=grad_domain_name,
enabled=enabled, trainable=trainable, name=start_name
)

if should_unstack:
inputs = tf.unstack(inputs, axis=0)

if "inputs" in kwargs:
kwargs["inputs"] = inputs
else:
args = [inputs] + list(args[1:])

output = wrapped(*args, **kwargs)

output = end(inputs=output, nvtx_context=nvtx_context, name=end_name)

return output
Expand Down