Skip to content

Commit

Permalink
Changed logging in aotautograd a little (pytorch#97289)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#97289
Approved by: https://github.com/mlazos
  • Loading branch information
Chillee authored and pytorchmergebot committed Mar 22, 2023
1 parent 4ab1588 commit e49b4d3
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 12 deletions.
3 changes: 2 additions & 1 deletion test/dynamo/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def single_record_test(**kwargs):
class LoggingTests(LoggingTestCase):
test_bytecode = multi_record_test(2, bytecode=True)
test_output_code = multi_record_test(1, output_code=True)
test_aot_graphs = multi_record_test(2, aot_graphs=True)

@requires_cuda()
@make_logging_test(schedule=True)
Expand Down Expand Up @@ -142,7 +143,7 @@ def test_open_registration(self, records):


# single record tests
exclusions = {"bytecode", "output_code", "schedule"}
exclusions = {"bytecode", "output_code", "schedule", "aot_graphs"}
for name in torch._logging._internal.log_registry.artifact_names:
if name not in exclusions:
setattr(LoggingTests, f"test_{name}", single_record_test(**{name: True}))
Expand Down
9 changes: 4 additions & 5 deletions torch/_functorch/aot_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@
from torch._guards import TracingContext, DuplicateInputs, Source

log = logging.getLogger(__name__)
aot_forward_log = getArtifactLogger(__name__, "aot_forward_graph")
aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph")
aot_backward_log = getArtifactLogger(__name__, "aot_backward_graph")
aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")

MutationType = Enum(
"MutationType", ("none", "metadata_only", "data", "data_and_metadata")
Expand Down Expand Up @@ -1269,7 +1268,7 @@ def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *

assert copy_count == copy_count2

aot_forward_log.info(format_graph_code(f"====== Forward graph {aot_config.aot_id} ======\n", fw_module))
aot_graphs_log.info(format_graph_code(f"====== Forward graph {aot_config.aot_id} ======\n", fw_module))

disable_amp = torch._C._is_any_autocast_enabled()
context = disable_autocast_manager if disable_amp else nullcontext
Expand Down Expand Up @@ -2309,8 +2308,8 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
]
_num_symints_saved_for_bw = len(symint_outs_saved_for_bw)

aot_forward_log.info(format_graph_code(f"====== Forward graph {aot_config.aot_id} ======\n", fw_module))
aot_backward_log.info(format_graph_code(f"====== Backward graph {aot_config.aot_id} ======\n", bw_module))
aot_graphs_log.info(format_graph_code(f"====== Forward graph {aot_config.aot_id} ======\n", fw_module))
aot_graphs_log.info(format_graph_code(f"====== Backward graph {aot_config.aot_id} ======\n", bw_module))

with track_graph_compiling(aot_config, "forward"):
compiled_fw_func = aot_config.fw_compiler(
Expand Down
6 changes: 2 additions & 4 deletions torch/_logging/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,7 @@ def set_logs(
aot=DEFAULT_LOG_LEVEL,
inductor=DEFAULT_LOG_LEVEL,
bytecode=False,
aot_forward_graph=False,
aot_backward_graph=False,
aot_graphs=False,
aot_joint_graph=False,
graph=False,
graph_code=False,
Expand Down Expand Up @@ -174,8 +173,7 @@ def _set_logs(**kwargs):
aot=aot,
inductor=inductor,
bytecode=bytecode,
aot_forward_graph=aot_forward_graph,
aot_backward_graph=aot_backward_graph,
aot_graphs=aot_graphs,
aot_joint_graph=aot_joint_graph,
graph=graph,
graph_code=graph_code,
Expand Down
3 changes: 1 addition & 2 deletions torch/_logging/_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
register_artifact("bytecode")
register_artifact("graph")
register_artifact("graph_code")
register_artifact("aot_forward_graph")
register_artifact("aot_backward_graph")
register_artifact("aot_graphs")
register_artifact("aot_joint_graph")
register_artifact("output_code", off_by_default=True)
register_artifact("schedule", off_by_default=True)

0 comments on commit e49b4d3

Please sign in to comment.