diff --git a/nvtx_plugins/python/nvtx/plugins/tf/ops.py b/nvtx_plugins/python/nvtx/plugins/tf/ops.py index 6242ec1..6e90812 100644 --- a/nvtx_plugins/python/nvtx/plugins/tf/ops.py +++ b/nvtx_plugins/python/nvtx/plugins/tf/ops.py @@ -28,6 +28,29 @@ nvtx_tf_ops = load_library('lib/nvtx_ops' + get_ext_suffix()) +def _maybe_process_inputs(inputs): + + raw_inputs = None + + if isinstance(inputs, (list, tuple)): + raw_inputs = inputs + inputs = tf.zeros(shape=(), dtype=tf.float32) + + assert isinstance(inputs, tf.Tensor) + + return inputs, raw_inputs + +def _maybe_process_outputs(nvtx_out, raw_inputs): + + if raw_inputs is not None: + null_op = tf.debugging.assert_type(tensor=nvtx_out, tf_type=tf.float32) + with tf.control_dependencies([null_op]): + inputs = [tf.identity(x) for x in raw_inputs] + else: + inputs = nvtx_out + + return inputs + def _maybe_convert_list_to_tensor(inputs): inputs_were_processed = False @@ -121,14 +144,13 @@ def start(inputs, message, domain_name=None, initializer=tf.zeros_initializer, trainable=True) - inputs, should_unstack = _maybe_convert_list_to_tensor(inputs) + inputs, raw_inputs = _maybe_process_inputs(inputs) inputs, marker_id, domain_handle = nvtx_tf_ops.nvtx_start( inputs=inputs, null_input=null_input, message=message, domain_name=domain_name, name=name) - if should_unstack: - inputs = tf.unstack(inputs, axis=0) + inputs = _maybe_process_outputs(inputs, raw_inputs) return inputs, (marker_id, domain_handle, grad_message, grad_domain_name) @@ -166,15 +188,14 @@ def end(inputs, nvtx_context, name=None): marker_id, domain_handle, grad_message, grad_domain_name = nvtx_context - inputs, should_unstack = _maybe_convert_list_to_tensor(inputs) + inputs, raw_inputs = _maybe_process_inputs(inputs) output, null_output = nvtx_tf_ops.nvtx_end(inputs=inputs, marker_id=marker_id, domain_handle=domain_handle, grad_message=grad_message, grad_domain_name=grad_domain_name, name=name ) - if should_unstack: - output = tf.unstack(output, axis=0) + output = _maybe_process_outputs(output, raw_inputs) return output @@ -213,26 +234,22 @@ def func_wrapper(wrapped, instance, args, kwargs): raise ValueError("The input tensor must be the first argument" " or named `inputs`") - inputs, should_unstack = _maybe_convert_list_to_tensor(inputs) - start_name = '{}_start'.format(name) if name else None end_name = '{}_end'.format(name) if name else None - inputs, nvtx_context = start(inputs=inputs, + nvtx_out, nvtx_context = start(inputs=inputs, message=message, domain_name=domain_name, grad_message=grad_message, grad_domain_name=grad_domain_name, enabled=enabled, trainable=trainable, name=start_name ) - if should_unstack: - inputs = tf.unstack(inputs, axis=0) - if "inputs" in kwargs: kwargs["inputs"] = inputs else: args = [inputs] + list(args[1:]) - + output = wrapped(*args, **kwargs) + output = end(inputs=output, nvtx_context=nvtx_context, name=end_name) return output