From e49b4d38278ad995f38fa0b5845bb68d29ceff8c Mon Sep 17 00:00:00 2001 From: Horace He Date: Wed, 22 Mar 2023 03:04:41 +0000 Subject: [PATCH] Changed logging in aotautograd a little (#97289) Pull Request resolved: https://github.com/pytorch/pytorch/pull/97289 Approved by: https://github.com/mlazos --- test/dynamo/test_logging.py | 3 ++- torch/_functorch/aot_autograd.py | 9 ++++----- torch/_logging/_internal.py | 6 ++---- torch/_logging/_registrations.py | 3 +-- 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index fc93a41184a042..bddb2db1b16dd2 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -71,6 +71,7 @@ def single_record_test(**kwargs): class LoggingTests(LoggingTestCase): test_bytecode = multi_record_test(2, bytecode=True) test_output_code = multi_record_test(1, output_code=True) + test_aot_graphs = multi_record_test(2, aot_graphs=True) @requires_cuda() @make_logging_test(schedule=True) @@ -142,7 +143,7 @@ def test_open_registration(self, records): # single record tests -exclusions = {"bytecode", "output_code", "schedule"} +exclusions = {"bytecode", "output_code", "schedule", "aot_graphs"} for name in torch._logging._internal.log_registry.artifact_names: if name not in exclusions: setattr(LoggingTests, f"test_{name}", single_record_test(**{name: True})) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 4fd7cf3db39ebc..8670f3ef364d68 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -31,9 +31,8 @@ from torch._guards import TracingContext, DuplicateInputs, Source log = logging.getLogger(__name__) -aot_forward_log = getArtifactLogger(__name__, "aot_forward_graph") aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph") -aot_backward_log = getArtifactLogger(__name__, "aot_backward_graph") +aot_graphs_log = getArtifactLogger(__name__, "aot_graphs") MutationType = Enum( "MutationType", ("none", "metadata_only", "data", "data_and_metadata") @@ -1269,7 +1268,7 @@ def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, * assert copy_count == copy_count2 - aot_forward_log.info(format_graph_code(f"====== Forward graph {aot_config.aot_id} ======\n", fw_module)) + aot_graphs_log.info(format_graph_code(f"====== Forward graph {aot_config.aot_id} ======\n", fw_module)) disable_amp = torch._C._is_any_autocast_enabled() context = disable_autocast_manager if disable_amp else nullcontext @@ -2309,8 +2308,8 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig, ] _num_symints_saved_for_bw = len(symint_outs_saved_for_bw) - aot_forward_log.info(format_graph_code(f"====== Forward graph {aot_config.aot_id} ======\n", fw_module)) - aot_backward_log.info(format_graph_code(f"====== Backward graph {aot_config.aot_id} ======\n", bw_module)) + aot_graphs_log.info(format_graph_code(f"====== Forward graph {aot_config.aot_id} ======\n", fw_module)) + aot_graphs_log.info(format_graph_code(f"====== Backward graph {aot_config.aot_id} ======\n", bw_module)) with track_graph_compiling(aot_config, "forward"): compiled_fw_func = aot_config.fw_compiler( diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index a1402e7d52570e..1c2b4c1af0390b 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -124,8 +124,7 @@ def set_logs( aot=DEFAULT_LOG_LEVEL, inductor=DEFAULT_LOG_LEVEL, bytecode=False, - aot_forward_graph=False, - aot_backward_graph=False, + aot_graphs=False, aot_joint_graph=False, graph=False, graph_code=False, @@ -174,8 +173,7 @@ def _set_logs(**kwargs): aot=aot, inductor=inductor, bytecode=bytecode, - aot_forward_graph=aot_forward_graph, - aot_backward_graph=aot_backward_graph, + aot_graphs=aot_graphs, aot_joint_graph=aot_joint_graph, graph=graph, graph_code=graph_code, diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py index 6ddd521876de06..08be97d7a2840e 100644 --- a/torch/_logging/_registrations.py +++ b/torch/_logging/_registrations.py @@ -9,8 +9,7 @@ register_artifact("bytecode") register_artifact("graph") register_artifact("graph_code") -register_artifact("aot_forward_graph") -register_artifact("aot_backward_graph") +register_artifact("aot_graphs") register_artifact("aot_joint_graph") register_artifact("output_code", off_by_default=True) register_artifact("schedule", off_by_default=True)