Skip to content

Commit

Permalink
add option to pass trace_file list to Trace() object (#179)
Browse files Browse the repository at this point in the history
Summary:
## What does this PR do?
1. Allows users to pass a list of trace files to the `Trace` class. HTA will then determine the rank-to-trace-file mapping.
2. Makes "distributedInfo" optional: if it is missing, HTA will now default to rank = 0 and emit a warning.

Code example for (1)
```
        inference_trace_files = ["tests/data/inference_single_rank/inference_rank_0.json.gz"]
        trace: Trace = Trace(
            trace_files=inference_trace_files, trace_dir=os.getcwd()
        )
```

## Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
  - [x] N/A
- [x] Did you write any new necessary tests?
  - [ ] N/A
- [ ] Did you make sure to update the docs?
  - [x] N/A
- [x] Did you update the [changelog](https://github.com/facebookresearch/HolisticTraceAnalysis/blob/main/CHANGELOG.md)?
  - [ ] N/A

Pull Request resolved: #179

Reviewed By: fullthom

Differential Revision: D61665964

Pulled By: briancoutinho

fbshipit-source-id: 7b1cb636d263e9f96241439a468e83f11bb4458b
  • Loading branch information
briancoutinho authored and facebook-github-bot committed Aug 23, 2024
1 parent 471d468 commit 5f883b1
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 32 deletions.
9 changes: 5 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@ Versioning](https://semver.org/spec/v2.0.0.html).
- (Experimental) Added lightweight critical path analysis feature.
- (Experimental) Critical path analysis features: event attribution and `summary()`
- (Experimental) Critical path analysis fixes: fixing async memcpy and adding GPU to CPU event based synchronization.
- Add a workaround for overlapping events when using ns resolution traces (https://github.com/pytorch/pytorch/pull/122425)
- Better handling of CUDA sync events with stream = -1
- (Experimental) Added save and restore feature for critical path graph.
- Add nccl collective fields to parser config
- Fix ijson metadata parser for some corner cases
- Add an option for ns rounding and cover ijson loading with it.

#### Changed
- Change test data path in unittests from relative path to real path to support running test within IDEs.
- Add a workaround for overlapping events when using ns resolution traces (https://github.com/pytorch/pytorch/pull/122425)
- Better handling of CUDA sync events with stream = -1
- Fix ijson metadata parser for some corner cases
- Add an option for ns rounding and cover ijson loading with it.
- Updated the `Trace()` API to accept a list of trace files and automatically infer their ranks.

#### Deprecated
- Deprecated 'call_stack'; use 'trace_call_stack' and 'trace_call_graph' instead.
Expand Down
30 changes: 22 additions & 8 deletions hta/common/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
import sys
import time
import tracemalloc
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

import pandas as pd

from hta.common.trace_file import get_trace_files
from hta.common.trace_file import create_rank_to_trace_dict, get_trace_files
from hta.common.trace_filter import CPUOperatorFilter, GPUKernelFilter
from hta.common.trace_parser import parse_trace_dataframe, parse_trace_dict
from hta.common.trace_symbol_table import (
Expand Down Expand Up @@ -311,24 +311,38 @@ class Trace:

def __init__(
self,
trace_files: Optional[Dict[int, str]] = None,
trace_files: Union[List[str], Optional[Dict[int, str]]] = None,
trace_dir: str = DEFAULT_TRACE_DIR,
) -> None:
"""
The constructor of a Trace object.
Args:
trace_files: Dict[int, str] : a map from rank to trace file names. The trace file names can be either
relative to the path `trace_path` or absolute file paths.
trace_files: Union[List[str], Dict[int, str]] : either a list of trace file names or a map from rank to trace file names.
When a list is provided, HTA will infer the ranks by reading the trace file metadata.
The trace file names can be either relative to the path `trace_path` or absolute file paths.
trace_dir (str) : a path used to derive `trace_path = normalize_path(trace_dir)`.
Raises:
AssertionError
"""
self.trace_path: str = normalize_path(trace_dir)
logger.info(f"{self.trace_path}")
self.trace_files: Dict[int, str] = (
trace_files if trace_files is not None else get_trace_files(self.trace_path)
)

self.trace_files: Dict[int, str]
if trace_files is None:
self.trace_files = get_trace_files(self.trace_path)
elif isinstance(trace_files, dict):
self.trace_files = trace_files
elif isinstance(trace_files, list):
ok, self.trace_files = create_rank_to_trace_dict(trace_files)
if not ok:
logger.warning("failed to create rank to trace map")
else:
logger.error(
f"Unsupported type for trace_files = {trace_files}, should be list or dict"
)
return

logger.debug(self.trace_files)
self.traces: Dict[int, pd.DataFrame] = {}
self.symbol_table = TraceSymbolTable()
Expand Down
25 changes: 16 additions & 9 deletions hta/common/trace_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import json
import os
import re
from typing import Any, Dict, Tuple
from typing import Any, Dict, List, Tuple

from hta.configs.config import logger


def create_rank_to_trace_dict(trace_dir: str) -> Tuple[bool, Dict]:
def create_rank_to_trace_dict_from_dir(trace_dir: str) -> Tuple[bool, Dict]:
"""
Create a rank -> trace_filename map for traces located within the directory <trace_path>
Expand All @@ -35,14 +35,19 @@ def create_rank_to_trace_dict(trace_dir: str) -> Tuple[bool, Dict]:
if len(file_list) == 0:
logger.warning(f"No trace file is found in {trace_dir}")
return False, {}
return create_rank_to_trace_dict(
[os.path.join(trace_dir, file) for file in file_list]
)


def create_rank_to_trace_dict(file_list: List[str]) -> Tuple[bool, Dict]:
rank_to_trace_dict: Dict[int, str] = {}
rank_re = re.compile(r'"rank":\s+(\d+)')
for file in file_list:
file_path = os.path.join(trace_dir, file)

for file_path in file_list:
with (
gzip.open(file_path, "rb") if file.endswith("gz") else open(file_path, "r")
gzip.open(file_path, "rb")
if file_path.endswith("gz")
else open(file_path, "r")
) as f:
for line in f:
data = line.decode() if isinstance(line, bytes) else line
Expand All @@ -54,16 +59,18 @@ def create_rank_to_trace_dict(trace_dir: str) -> Tuple[bool, Dict]:
if match:
rank = int(match.group(1))
if rank in rank_to_trace_dict:
logger.error(
logger.warning(
f"File {rank_to_trace_dict[rank]} and file {file_path} has the same rank. Will use {file_path} as the path to rank: {rank}."
)
rank_to_trace_dict[int(rank)] = file_path
else:
logger.error(
logger.warning(
"If the trace file does not have the rank specified in it, then add the following snippet "
'key to the json files to use HTA; "distributedInfo": {"rank": 0}. If there are multiple '
"traces files, then each file should have a unique rank value."
"For now we will default to rank = 0."
)
rank_to_trace_dict[0] = file_path

return True, rank_to_trace_dict

Expand All @@ -84,7 +91,7 @@ def get_trace_files(trace_path: str) -> Dict[int, str]:
if not os.path.exists(trace_path):
logger.warning(f"{trace_path} is not a valid path")
else:
ok, rank_to_trace_dict = create_rank_to_trace_dict(trace_path)
ok, rank_to_trace_dict = create_rank_to_trace_dict_from_dir(trace_path)
if not ok:
logger.warning("failed to create rank to trace map")
return {}
Expand Down
Binary file not shown.
33 changes: 23 additions & 10 deletions tests/test_trace_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from hta.common.trace_file import (
create_rank_to_trace_dict,
create_rank_to_trace_dict_from_dir,
read_trace,
update_trace_rank,
write_trace,
Expand Down Expand Up @@ -56,31 +57,43 @@ def setUp(self) -> None:
self.trace_without_distributed_info = "tests/data/distributed_info_unavailable"
self.trace_without_rank = "tests/data/rank_unavailable"
self.trace_mixed_files = "tests/data/mixed_files"
self.trace_file_list = ["tests/data/trace_file_list/inference_rank_1.json.gz"]
self.logger = logger

def test_create_rank_to_trace_dict_without_distributed_info(self):
with self.assertLogs(logger, level="ERROR") as cm:
with self.assertLogs(logger, level="WARNING") as cm:
self.assertEqual(
create_rank_to_trace_dict(self.trace_without_distributed_info),
(True, {}),
create_rank_to_trace_dict_from_dir(self.trace_without_distributed_info),
(
True,
{
0: "tests/data/distributed_info_unavailable/distributed_info_not_found.json.gz"
},
),
)
self.assertIn("trace file does not have the rank", cm.output[0])

def test_create_rank_to_trace_dict_without_rank(self) -> None:
with self.assertLogs(logger, level="ERROR") as cm:
with self.assertLogs(logger, level="WARNING") as cm:
self.assertEqual(
create_rank_to_trace_dict(self.trace_without_rank), (True, {})
create_rank_to_trace_dict_from_dir(self.trace_without_rank),
(True, {0: "tests/data/rank_unavailable/rank_not_found.json.gz"}),
)
self.assertIn("trace file does not have the rank", cm.output[0])

def test_create_rank_to_trace_dict_with_mixed_dir(self) -> None:
with self.assertLogs(logger, level="ERROR") as cm:
self.assertEqual(
create_rank_to_trace_dict(self.trace_mixed_files),
(True, {0: "tests/data/mixed_files/rank_non_gpu.json.gz"}),
)
with self.assertLogs(logger, level="WARNING") as cm:
ok, res_dict = create_rank_to_trace_dict_from_dir(self.trace_mixed_files)
self.assertTrue(ok)
self.assertEqual(list(res_dict.keys()), [0])
self.assertIn("has the same rank", cm.output[0])

def test_create_rank_to_trace_dict_with_file_list(self) -> None:
self.assertEqual(
create_rank_to_trace_dict(self.trace_file_list),
(True, {1: "tests/data/trace_file_list/inference_rank_1.json.gz"}),
)

def test_read_write_trace(self) -> None:
with tempfile.TemporaryDirectory() as tmpdirname:
test_trace_file = os.path.join(tmpdirname, "test.json.gz")
Expand Down
7 changes: 6 additions & 1 deletion tests/test_trace_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ def setUpClass(cls):
inference_trace_dir: str = "tests/data/inference_single_rank"
vision_transformer_rank_0_file: str = "rank-0.json.gz"
inference_rank_0_file: str = "inference_rank_0.json.gz"
inference_trace_files = [
os.path.join(inference_trace_dir, inference_rank_0_file)
]
max_ranks = 8

# Trace parser for vision transformer
Expand All @@ -109,7 +112,9 @@ def setUpClass(cls):
vision_transformer_trace_dir, vision_transformer_rank_0_file
)
# Trace parser for inference
cls.inference_t: Trace = Trace(trace_dir=inference_trace_dir)
cls.inference_t: Trace = Trace(
trace_files=inference_trace_files, trace_dir=os.getcwd()
)
cls.inference_t.parse_traces(max_ranks=max_ranks, use_multiprocessing=True)
cls.inference_raw_df = prepare_ground_truth_df(
inference_trace_dir, inference_rank_0_file
Expand Down

0 comments on commit 5f883b1

Please sign in to comment.