Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync up and add new parser config fields #185

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 52 additions & 4 deletions hta/configs/parser_config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# pyre-strict

import copy
from enum import Enum
from typing import Dict, List, NamedTuple, Optional, Set, Union


class ParserBackend(str, Enum):
"""Tracer parser and laoder backend
"""Tracer parser and loader backend
See https://github.com/facebookresearch/HolisticTraceAnalysis/pull/125
for details on performance and memory usage.
"""
Expand All @@ -24,6 +26,14 @@ class ValueType(Enum):
Object = 4


class TraceType(str, Enum):
    """Enumerate the possible trace types.

    The class mixes in ``str``, so every member is also a string and compares
    equal to its value (for example ``TraceType.Training == "training"``).
    """

    Training = "training"
    # NOTE(review): the value suggests a training trace that lacks
    # ProfilerStep annotations -- confirm against the code that consumes it.
    TrainingWoProfilerstepAnnot = "training_wo_profilerstep_annot"
    Inference = "inference"


class AttributeSpec(NamedTuple):
"""AttributeSpec specifies what an attribute looks like and how to parse it.

Expand Down Expand Up @@ -57,6 +67,12 @@ class AttributeSpec(NamedTuple):
"cpu_op::input_type": AttributeSpec(
"input_type", "Input type", ValueType.Object, "-1"
),
"cpu_op::input_strides": AttributeSpec(
"input_strides",
"Input Strides",
ValueType.Object,
"-1",
),
"cpu_op::sequence_number": AttributeSpec(
"sequence", "Sequence number", ValueType.Int, -1
),
Expand Down Expand Up @@ -112,7 +128,7 @@ class AttributeSpec(NamedTuple):
"out_msg_nelems", "Out msg nelems", ValueType.Int, 0
),
"nccl::group_size": AttributeSpec("group_size", "Group size", ValueType.Int, 0),
"nccl::dtype": AttributeSpec("dtype", "dtype", ValueType.String, ""),
"nccl::dtype": AttributeSpec("msg_dtype", "dtype", ValueType.String, ""),
"nccl::in_split_size": AttributeSpec(
"in_split_size", "In split size", ValueType.Object, "[]"
),
Expand All @@ -128,6 +144,7 @@ class AttributeSpec(NamedTuple):
"nccl::process_group_ranks": AttributeSpec(
"process_group_ranks", "Process Group Ranks", ValueType.Object, "[]"
),
"nccl::rank": AttributeSpec("process_rank", "Rank", ValueType.Int, -1),
}


Expand All @@ -149,7 +166,8 @@ class ParserConfig:
"""

ARGS_INPUT_SHAPE: List[AttributeSpec] = [
AVAILABLE_ARGS[k] for k in ["cpu_op::input_dims", "cpu_op::input_type"]
AVAILABLE_ARGS[k]
for k in ["cpu_op::input_dims", "cpu_op::input_type", "cpu_op::input_strides"]
]
ARGS_BANDWIDTH: List[AttributeSpec] = [
AVAILABLE_ARGS[k] for k in ["data::bytes", "data::bandwidth"]
Expand All @@ -163,17 +181,39 @@ class ParserConfig:
# Every spec in AVAILABLE_ARGS whose key does not start with "info".
ARGS_COMPLETE: List[AttributeSpec] = [
AVAILABLE_ARGS[k] for k in AVAILABLE_ARGS if not k.startswith("info")
]
# The specs registered under the three "info::" keys listed here.
ARGS_INFO: List[AttributeSpec] = [
AVAILABLE_ARGS[k] for k in ["info::labels", "info::name", "info::sort_index"]
]
# The specs registered under the "nccl::" keys listed here, in this order.
# enable_communication_args() adds this list to the default parser config.
ARGS_COMMUNICATION: List[AttributeSpec] = [
AVAILABLE_ARGS[k]
for k in [
"nccl::collective_name",
"nccl::in_msg_nelems",
"nccl::out_msg_nelems",
"nccl::dtype",
"nccl::group_size",
"nccl::rank",
"nccl::in_split_size",
"nccl::out_split_size",
]
]
# Default selection: the minimum, bandwidth, sync and input-shape groups
# concatenated in that order, followed by the "index::external_id" spec.
# ARGS_COMPLETE, ARGS_INFO and ARGS_COMMUNICATION are not part of it.
ARGS_DEFAULT: List[AttributeSpec] = (
ARGS_MINIMUM
+ ARGS_BANDWIDTH
+ ARGS_SYNC
+ ARGS_INPUT_SHAPE
+ [AVAILABLE_ARGS["index::external_id"]]
)

def __init__(self, args: Optional[List[AttributeSpec]] = None):
def __init__(
    self,
    args: Optional[List[AttributeSpec]] = None,
    user_provide_trace_type: Optional[TraceType] = None,
) -> None:
    """Build a parser configuration.

    Args:
        args: attribute specs to store on the config. ``None`` or an empty
            list selects the result of ``get_default_args()`` instead.
        user_provide_trace_type: optional trace type, stored unchanged.
    """
    # Any falsy value (None or an empty list) falls back to the defaults.
    if args:
        selected_args = args
    else:
        selected_args = self.get_default_args()
    self.args: List[AttributeSpec] = selected_args
    self.parser_backend: Optional[ParserBackend] = None
    self.trace_memory: bool = False
    self.user_provide_trace_type: Optional[TraceType] = user_provide_trace_type

@classmethod
def get_default_cfg(cls) -> "ParserConfig":
Expand All @@ -191,6 +231,10 @@ def get_minimum_args(cls) -> List[AttributeSpec]:
def get_default_args(cls) -> List[AttributeSpec]:
return cls.ARGS_DEFAULT.copy()

@classmethod
def get_info_args(cls) -> List[AttributeSpec]:
    """Return a new list holding the entries of ``ARGS_INFO``.

    The class-level list itself is not handed out, so callers may mutate
    the returned list freely.
    """
    return list(cls.ARGS_INFO)

def set_args(self, args: List[AttributeSpec]) -> None:
if args != self.args:
self.args.clear()
Expand All @@ -209,6 +253,10 @@ def add_args(self, args: List[AttributeSpec]) -> None:
def set_parser_backend(self, parser_backend: ParserBackend) -> None:
    """Store ``parser_backend`` on this config as ``self.parser_backend``."""
    self.parser_backend = parser_backend

@staticmethod
def enable_communication_args() -> None:
    """Add the ``ARGS_COMMUNICATION`` specs to ``_DEFAULT_PARSER_CONFIG``."""
    communication_specs = ParserConfig.ARGS_COMMUNICATION
    # The class-level list is passed as-is, exactly like a direct call would.
    _DEFAULT_PARSER_CONFIG.add_args(communication_specs)


# Define a global ParserConfig variable for internal use. To access this variable,
# clients should use ParserConfig.get_default_cfg and ParserConfig.set_default_cfg.
Expand Down
Loading