Skip to content

Commit

Permalink
Revamp PT2 Compile/chromium event logging [1/?]
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/pytorch#138093

This diff is the starting steps of https://docs.google.com/document/u/2/d/1kAEBt4AyW7HTAhXHbjoz8FBFHNyyEA2Qo2mPn7v3WUQ/edit?usp=drive_web&ouid=113555078003219714709

It implements the following changes:

- Only log spans to scuba, so no start events are ever logged
- Log events as the full event name, without "START" or "END"
- Only log to scuba major phases from chromium events. These are:
  - entire_frame_compile (dynamo)
  - backend_compile (aotdispatch)
  - inductor_compile (inductor)
  - codegen (inductor codegen)

Tlparse chromium events stay basically the same. But I implemented a few changes to clean that up as well:
- When there's a phase name available, log the phase name instead of the function name as the event name. This simplifies the trace to not have two identical rows. The fn_name is avaliable as metadata on the chromium event, if interested
- Log new events for pre and post grad passes. These do *not* log to scuba.

By making the phases much simpler in Scuba, with only categories for major phases of PT2 Compilation, we pave the way to add **much** more metadata and information to each individual event type. Diffs for that will come later.

**IMPLEMENTATION NOTES:**
- The logic for `log_chromium_event_internal` (which is the function that logs to Scuba) lives in chromium_events for now, but in the future as we add more metadata, it may belong independently in dynamo_timed or even outside of dynamo_timed. I haven't explored in detail what the refactor will look like. Once we start logging metadata for dynamo, aotdispatch, inductor, I suspect we will call log_pt2_compile_event directly, instead of making chromium event logger handle the pt2_compile_event logic. But that refactor is left for another PR on top of this one.

- There's an interesting space after pre grad passes within AOT autograd logic, that's between create_aot_dispatcher_function and pre grad passes. I'm not sure what we're spending time doing in that time, but I'll find out with a profile later.
ghstack-source-id: 248790387

Reviewed By: oulgen

Differential Revision: D64479033

fbshipit-source-id: 1f30e734160bfed2f664063b5b2f4df1b661dfa4
  • Loading branch information
jamesjwu authored and facebook-github-bot committed Oct 18, 2024
1 parent e89c1b3 commit 8358f92
Showing 1 changed file with 20 additions and 37 deletions.
57 changes: 20 additions & 37 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,16 +236,6 @@ def add_remote_cache_time_saved(time_saved_ns: int, is_backward: bool = False) -
_add_time_spent(key, "remote_cache_time_saved", time_saved)


def get_cache_stats() -> Dict[str, Any]:
"""Get a bunch of metadata about cache hits and misses to use in chromium events"""
cache_stats = {
"fxgraph_cache_hit": counters["inductor"]["fxgraph_cache_hit"],
"fxgraph_cache_miss": counters["inductor"]["fxgraph_cache_miss"],
"fxgraph_cache_bypass": counters["inductor"]["fxgraph_cache_bypass"],
}
return cache_stats


# dynamo_timed is a context manager
# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
# where the key is the functions name.
Expand Down Expand Up @@ -290,9 +280,10 @@ def dynamo_timed(
try:
with torch.profiler.record_function(f"{key} (dynamo_timed)"):
t0 = time.time()
chromium_log.log_event_start(key, start, None)
if phase_name:
chromium_log.log_event_start(phase_name, start)
chromium_log.log_event_start(phase_name, start, {"fn_name": key})
else:
chromium_log.log_event_start(key, start, {})
yield
time_spent = time.time() - t0
compilation_time_metrics[key].append(time_spent)
Expand All @@ -306,16 +297,15 @@ def dynamo_timed(
chromium_log.log_event_end(
phase_name,
time.time_ns(),
{"cache_stats": get_cache_stats()},
{},
start,
)
chromium_log.log_event_end(
key, time.time_ns(), {"cache_stats": get_cache_stats()}, start
)
else:
chromium_log.log_event_end(key, time.time_ns(), {}, start)
# Only record backward compilation metrics if phase_name is not None!
if phase_name:
frame_key = str(curr_frame)
# fwd only compilation stages: entire_frame_compile, backend_compile.
# fwd only compilation stages: entire_frame_compile, backend_compile, aotdispatch.
# use frame_key as time aggregation key.
if fwd_only and fail_type is None:
_add_time_spent(frame_key, phase_name, time_spent)
Expand Down Expand Up @@ -902,7 +892,7 @@ def log_event_start(
self,
event_name: str,
time_ns: int,
metadata: Optional[Dict[str, Any]] = None,
metadata: Dict[str, Any],
) -> None:
"""
Logs the start of a single event.
Expand All @@ -911,19 +901,14 @@ def log_event_start(
:param metadata: Any extra metadata associated with this event
"""

# Add compile id to metadata
if metadata is None:
metadata = {}
compile_id = str(torch._guards.CompileContext.current_compile_id())
metadata["compile_id"] = compile_id

event = self._log_timed_event(
self._log_timed_event(
event_name,
time_ns,
"B",
metadata,
)
log_chromium_event_internal(event, self.get_stack(), compile_id, self.id_)
self.get_stack().append(event_name)

def reset(self) -> None:
Expand All @@ -937,8 +922,8 @@ def log_event_end(
self,
event_name: str,
time_ns: int,
metadata: Optional[Dict[str, Any]] = None,
start_time_ns: Optional[int] = None,
metadata: Dict[str, Any],
start_time_ns: int,
) -> None:
"""
Logs the end of a single event. This function should only be
Expand All @@ -947,11 +932,14 @@ def log_event_end(
:param time_ns: Timestamp in nanoseconds
:param metadata: Any extra metadata associated with this event
"""
# Add compile id to metadata
if metadata is None:
metadata = {}
compile_id = str(torch._guards.CompileContext.current_compile_id())
metadata["compile_id"] = compile_id
event = self._log_timed_event(
event_name,
time_ns,
"E",
metadata,
)

# These stack health checks currently never happen,
# but they're written this way to future proof any weird event
Expand All @@ -963,13 +951,6 @@ def log_event_end(
log.warning("ChromiumEventLogger: Start event not in stack, ignoring")
return

event = self._log_timed_event(
event_name,
time_ns,
"E",
metadata,
)

while event_name != stack[-1]:
# If the event isn't the most recent one to end, pop
# off the stack until it is.
Expand Down Expand Up @@ -1046,7 +1027,9 @@ def log_instant_event(
expect_trace_id=True,
)
# Log an instant event with the same start and end time
log_chromium_event_internal(event, self.get_stack(), compile_id, self.id_)
log_chromium_event_internal(
event, self.get_stack(), compile_id, self.id_, time_ns
)


CHROMIUM_EVENT_LOG: Optional[ChromiumEventLogger] = None
Expand Down

0 comments on commit 8358f92

Please sign in to comment.