Add metadata to events in progress, new dynamo event
Summary:
X-link: pytorch/pytorch#138477

This diff does a few things:

## Add metadata to events in progress
Adds the ability to attach extra metadata to Chromium events via `add_event_data`.
Metadata can only be added to chromium events that have started but not yet ended (i.e., in-progress events).
- The data you add is merged into the event's metadata when you call `log_event_end()` (see the sketch below).
- The metadata appears on chromium events in tlparse, and it also gets logged to scuba.
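
A minimal usage sketch (not from the PR itself): it assumes the `get_chromium_event_logger()` helper and the start/end signatures suggested by the diff below (`event_name`, `time_ns`, `metadata`, plus a `start_time_ns` for the end event). The import path simply mirrors where this file lives and the argument values are made up; in the real flow these calls happen inside dynamo's compile context.

```
import time

from torch._dynamo.utils import get_chromium_event_logger

logger = get_chromium_event_logger()

start_ns = time.time_ns()
logger.log_event_start("dynamo", start_ns, {})

# While "dynamo" is still on the event stack, extra metadata can be attached.
# Calling this for an event that isn't in progress raises RuntimeError.
logger.add_event_data("dynamo", guard_count=3, cache_size=1)

# The accumulated metadata is merged into the end ("E") event here.
logger.log_event_end("dynamo", time.time_ns(), {}, start_ns)
```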

## New `dynamo` chromium event
We add a new `dynamo` chromium event to the top of the stack, where we collect various metadata found in dynamo_compile. So the new order of events goes:

```
__start__
-> dynamo (dynamo compile metrics)
-> entire_frame_compile (compile.inner)
-> backend_compile (i.e. aotdispatch)
-> create_aot_dispatch_function
-> inductor_compile
-> ...
```

`BwdCompilationMetrics` doesn't carry any dynamo-specific information (it's mostly inductor timings), so we don't include it here.

*FAQ: Why can't we use `entire_frame_compile` as the event?*
This is mostly for backward compatibility with `dynamo_compile`: `dynamo_compile` collects CompilationMetrics outside of `compile.compile_inner`, and uses `dynamo_timed` to grab timings from phases of the compiler, including `entire_frame_compile`. So we don't have a CompilationMetrics object until after the `entire_frame_compile` event ends! Separately, `dynamo` as a name for all of dynamo compile is more descriptive than `entire_frame_compile`, imo.
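
To make the ordering constraint concrete, here is a rough sketch, continuing with the `logger` and `time` imports from the sketch above. This is a simplification for illustration, not the real `dynamo_compile` code path:

```
t0 = time.time_ns()
logger.log_event_start("dynamo", t0, {})
t1 = time.time_ns()
logger.log_event_start("entire_frame_compile", t1, {})
# ... compile.compile_inner runs; dynamo_timed records the phase timings ...
logger.log_event_end("entire_frame_compile", time.time_ns(), {}, t1)

# CompilationMetrics is only assembled here, after entire_frame_compile has
# already ended, so the only in-progress event it can be attached to is the
# outer "dynamo" event.
logger.add_event_data("dynamo", guard_count=3, frame_key="0")
logger.log_event_end("dynamo", time.time_ns(), {}, t0)
```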

## Log metadata as separate columns
(Meta only): Separately, this also changes the `metadata` column in PT2 Compile Events. Instead of logging a single JSON-encoded metadata column, it splits the JSON into one column per key. This is much better for data analysis; now that the table is more mature, logging keys to separate columns is a better system.
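
For illustration only (the actual Scuba logging is Meta-internal and not shown here), this is the kind of flattening being described: one column per top-level metadata key instead of a single JSON blob. The helper name and example values are hypothetical.

```
import json
from typing import Any, Dict

def flatten_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
    # One column per top-level key; non-string values stay JSON-encoded.
    return {
        key: value if isinstance(value, str) else json.dumps(value)
        for key, value in metadata.items()
    }

row = flatten_metadata({"guard_count": 3, "restart_reasons": ["graph break"]})
# {"guard_count": "3", "restart_reasons": "[\"graph break\"]"}
```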
ghstack-source-id: 249373269

Reviewed By: aorenste

Differential Revision: D64696287

fbshipit-source-id: 441f57e2d1c0210e81c06eb86d4482e95bed4971
jamesjwu authored and facebook-github-bot committed Oct 22, 2024
1 parent 1154318 commit 8fce9c1
Showing 1 changed file with 72 additions and 6 deletions.
78 changes: 72 additions & 6 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
@@ -821,13 +821,45 @@ class BwdCompilationMetrics:
] = collections.deque(maxlen=DEFAULT_COMPILATION_METRICS_LIMIT)


def add_compilation_metrics_to_chromium(c: CompilationMetrics):
    event_logger = get_chromium_event_logger()
    # The following compilation metrics are related to
    # dynamo, so attach them to the top-level "dynamo" event
    event_logger.add_event_data(
        event_name="dynamo",
        frame_key=c.frame_key,
        co_name=c.co_name,
        co_filename=c.co_filename,
        co_firstlineno=c.co_firstlineno,
        cache_size=c.cache_size,
        accumulated_cache_size=c.accumulated_cache_size,
        guard_count=c.guard_count,
        shape_env_guard_count=c.shape_env_guard_count,
        graph_op_count=c.graph_op_count,
        graph_node_count=c.graph_node_count,
        graph_input_count=c.graph_input_count,
        fail_type=c.fail_type,
        fail_reason=c.fail_reason,
        fail_user_frame_filename=c.fail_user_frame_filename,
        fail_user_frame_lineno=c.fail_user_frame_lineno,
        # Sets aren't JSON serializable
        non_compliant_ops=list(c.non_compliant_ops),
        compliant_custom_ops=list(c.compliant_custom_ops),
        restart_reasons=list(c.restart_reasons),
        dynamo_time_before_restart_s=c.dynamo_time_before_restart_s,
        has_guarded_code=c.has_guarded_code,
        dynamo_config=c.dynamo_config,
    )


def record_compilation_metrics(
    compilation_metrics: Union[CompilationMetrics, BwdCompilationMetrics]
):
    global _compilation_metrics
    _compilation_metrics.append(compilation_metrics)
    if isinstance(compilation_metrics, CompilationMetrics):
        name = "compilation_metrics"
        add_compilation_metrics_to_chromium(compilation_metrics)
    else:
        name = "bwd_compilation_metrics"
    torch._logging.trace_structured(
@@ -877,6 +909,11 @@ def get_stack(self):
            self.tls.stack = ["__start__"]
            return self.tls.stack

    def get_event_data(self) -> Dict[str, Any]:
        if not hasattr(self.tls, "event_data"):
            self.tls.event_data = {}
        return self.tls.event_data

    def __init__(self):
        self.tls = threading.local()
        # Generate a unique id for this logger, which we can use in scuba to filter down
@@ -886,6 +923,25 @@ def __init__(self):
        # TODO: log to init/id tlparse after I add support for it
        log.info("ChromiumEventLogger initialized with id %s", self.id_)

    def add_event_data(
        self,
        event_name: str,
        **kwargs,
    ) -> None:
        """
        Adds additional metadata info to an in-progress event.
        This metadata is recorded in the END event.
        """
        if event_name not in self.get_stack():
            raise RuntimeError(
                "Cannot add metadata to events that aren't in progress. "
                "Please make sure the event has started and hasn't ended."
            )
        event_data = self.get_event_data()
        if event_name not in event_data:
            event_data[event_name] = {}
        event_data[event_name].update(kwargs)

    def log_event_start(
        self,
        event_name: str,
@@ -898,7 +954,6 @@ def log_event_start(
        :param time_ns: Timestamp in nanoseconds
        :param metadata: Any extra metadata associated with this event
        """

        compile_id = str(torch._guards.CompileContext.current_compile_id())
        metadata["compile_id"] = compile_id
        self._log_timed_event(
@@ -915,6 +970,8 @@ def reset(self) -> None:
        stack = self.get_stack()
        stack.clear()
        stack.append("__start__")
        event_data = self.get_event_data()
        event_data.clear()

    def log_event_end(
        self,
@@ -932,11 +989,22 @@ def log_event_end(
"""
compile_id = str(torch._guards.CompileContext.current_compile_id())
metadata["compile_id"] = compile_id

# Grab metadata collected during event span
all_event_data = self.get_event_data()
if event_name in all_event_data:
event_metadata = all_event_data[event_name]
del all_event_data[event_name]
else:
event_metadata = {}
# Add the passed in metadata
event_metadata.update(metadata)

event = self._log_timed_event(
event_name,
time_ns,
"E",
metadata,
event_metadata,
)

# These stack health checks currently never happen,
@@ -958,7 +1026,7 @@ def log_event_end(
            )
            stack.pop()

        log_chromium_event_internal(event, stack, compile_id, self.id_, start_time_ns)
        log_chromium_event_internal(event, stack, self.id_, start_time_ns)
        # Finally pop the actual event off the stack
        stack.pop()

@@ -1025,9 +1093,7 @@ def log_instant_event(
            expect_trace_id=True,
        )
        # Log an instant event with the same start and end time
        log_chromium_event_internal(
            event, self.get_stack(), compile_id, self.id_, time_ns
        )
        log_chromium_event_internal(event, self.get_stack(), self.id_, time_ns)


CHROMIUM_EVENT_LOG: Optional[ChromiumEventLogger] = None