From a39afecb2b5da3808376b3eccd52a54f28fd28d0 Mon Sep 17 00:00:00 2001 From: apaz Date: Tue, 19 Nov 2024 22:59:09 +0000 Subject: [PATCH 1/8] Made exceptions more readable. --- thunder/__init__.py | 33 ++++++++++++++++++++++++++++++++- thunder/core/jit_ext.py | 11 ++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 54c94855d..0c5a26ba0 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -71,7 +71,7 @@ AnyProxy, ) from thunder.core.interpreter import print_interpreter_log, print_to_log -from thunder.core.jit_ext import thunder_general_jit +from thunder.core.jit_ext import thunder_general_jit, InnerException from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction # NOTE This import is intentionally pytorch so that it thunder.torch doesn't import this @@ -814,7 +814,38 @@ def maybe_call_epilogue(cache_entry, result, pro_to_epi): return result + def unwrap_inner_exception(c: Callable) -> Callable: + def _thunder_unwrap_inner_exception(*args, **kwargs): + try: + return c(*args, **kwargs) + except InnerException as e: + exc = e.value + + tb = exc.__traceback__ + tb_frames = [] + while tb != None: + co = tb.tb_frame.f_code + co_fname = co.co_filename + co_name = co.co_name + if ((co is _thunder_unwrap_inner_exception.__code__ or co is thunder_general_jit.__code__) or + (co_fname.endswith("thunder" + os.sep + "core" + os.sep + "interpreter.py") and (co_name in ("fn_", "fn_2")))): + pass + else: + tb_frames.append(tb) + tb = tb.tb_next + + if tb_frames: + top_tb = tb = tb_frames[0] + for _tb in tb_frames[1:]: + tb.tb_next = _tb + tb = _tb + exc.with_traceback(top_tb) + + raise exc + return _thunder_unwrap_inner_exception + @wraps(fn) + @unwrap_inner_exception @update_call_statistics def fn_(*args, **kwargs) -> Any: if is_tracing(): diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index cf186b1ea..e822cb5b6 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -132,6 +132,11 @@ } +class InnerException(BaseException): + def __init__(self, *, value: BaseException): + self.value = value + + class JITSharpEdgeError(RuntimeError): """ Thrown when the program cannot be safely translated to a thunder program, @@ -1741,7 +1746,11 @@ def thunder_general_jit( with jit_ctx(ctx): with tracectx(computation_trace): - result = jfn(*args, **kwargs) + try: + result = jfn(*args, **kwargs) + except BaseException as e: + raise InnerException(value=e) + prims.python_return(result) computation_trace.set_current_source_location(None, None) process_recorded_modifications(ctx, epilogue_trace) From 2a9d738bb11db6bff03bca5e63e72921a2553a38 Mon Sep 17 00:00:00 2001 From: apaz Date: Wed, 20 Nov 2024 14:49:44 +0000 Subject: [PATCH 2/8] Fix memory leak. --- thunder/__init__.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 0c5a26ba0..490ccfd45 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -816,11 +816,13 @@ def maybe_call_epilogue(cache_entry, result, pro_to_epi): def unwrap_inner_exception(c: Callable) -> Callable: def _thunder_unwrap_inner_exception(*args, **kwargs): + # Run the function, and caputre the exception if there is one. try: return c(*args, **kwargs) except InnerException as e: exc = e.value + # Iterate over the traceback and exclude thunder internal functions. tb = exc.__traceback__ tb_frames = [] while tb != None: @@ -834,14 +836,22 @@ def _thunder_unwrap_inner_exception(*args, **kwargs): tb_frames.append(tb) tb = tb.tb_next + # Relink the non-internal traceback frames if tb_frames: top_tb = tb = tb_frames[0] for _tb in tb_frames[1:]: tb.tb_next = _tb tb = _tb - exc.with_traceback(top_tb) + exc.__traceback__ = top_tb + + # Re-raise the exception without retaining it in this stack frame to avoid leaking tensors. + try: + raise exc + except Exception as e: + del exc + del e + raise # re-raises current exception - raise exc return _thunder_unwrap_inner_exception @wraps(fn) From 0b618c151a48cb65b5de70ecc8866b94e16971f4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Nov 2024 14:56:57 +0000 Subject: [PATCH 3/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 490ccfd45..e59059be6 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -829,8 +829,10 @@ def _thunder_unwrap_inner_exception(*args, **kwargs): co = tb.tb_frame.f_code co_fname = co.co_filename co_name = co.co_name - if ((co is _thunder_unwrap_inner_exception.__code__ or co is thunder_general_jit.__code__) or - (co_fname.endswith("thunder" + os.sep + "core" + os.sep + "interpreter.py") and (co_name in ("fn_", "fn_2")))): + if (co is _thunder_unwrap_inner_exception.__code__ or co is thunder_general_jit.__code__) or ( + co_fname.endswith("thunder" + os.sep + "core" + os.sep + "interpreter.py") + and (co_name in ("fn_", "fn_2")) + ): pass else: tb_frames.append(tb) @@ -850,7 +852,7 @@ def _thunder_unwrap_inner_exception(*args, **kwargs): except Exception as e: del exc del e - raise # re-raises current exception + raise # re-raises current exception return _thunder_unwrap_inner_exception From f270ad4be17f6d3622d778f9dbb4705741de6bbc Mon Sep 17 00:00:00 2001 From: apaz Date: Wed, 20 Nov 2024 18:56:23 +0000 Subject: [PATCH 4/8] Separate out check logic --- thunder/__init__.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index e59059be6..fa41ca5b5 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -822,19 +822,17 @@ def _thunder_unwrap_inner_exception(*args, **kwargs): except InnerException as e: exc = e.value - # Iterate over the traceback and exclude thunder internal functions. + def internal_to_thunder(co): + if co is thunder_general_jit.__code__ or co is _thunder_unwrap_inner_exception.__code__: + return True + return co.co_filename.endswith("thunder" + os.sep + "core" + os.sep + "interpreter.py") and (co.co_name in ("fn_", "fn_2")) + + # Iterate over the traceback and collect frames that don't correspond to thunder internal functions. tb = exc.__traceback__ tb_frames = [] while tb != None: co = tb.tb_frame.f_code - co_fname = co.co_filename - co_name = co.co_name - if (co is _thunder_unwrap_inner_exception.__code__ or co is thunder_general_jit.__code__) or ( - co_fname.endswith("thunder" + os.sep + "core" + os.sep + "interpreter.py") - and (co_name in ("fn_", "fn_2")) - ): - pass - else: + if not internal_to_thunder(co): tb_frames.append(tb) tb = tb.tb_next From 4c7d7be8ddb80443dbe2bd8d4c8aec94e72f14f4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Nov 2024 18:59:27 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index fa41ca5b5..a46986379 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -825,7 +825,9 @@ def _thunder_unwrap_inner_exception(*args, **kwargs): def internal_to_thunder(co): if co is thunder_general_jit.__code__ or co is _thunder_unwrap_inner_exception.__code__: return True - return co.co_filename.endswith("thunder" + os.sep + "core" + os.sep + "interpreter.py") and (co.co_name in ("fn_", "fn_2")) + return co.co_filename.endswith("thunder" + os.sep + "core" + os.sep + "interpreter.py") and ( + co.co_name in ("fn_", "fn_2") + ) # Iterate over the traceback and collect frames that don't correspond to thunder internal functions. tb = exc.__traceback__ From 69221b7c0b450202f0940dfb8459b2d65169fd5f Mon Sep 17 00:00:00 2001 From: apaz Date: Thu, 21 Nov 2024 17:01:14 +0000 Subject: [PATCH 6/8] Remove potentially unused exception name. --- thunder/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index a46986379..c5eb46819 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -849,9 +849,8 @@ def internal_to_thunder(co): # Re-raise the exception without retaining it in this stack frame to avoid leaking tensors. try: raise exc - except Exception as e: + except Exception: del exc - del e raise # re-raises current exception return _thunder_unwrap_inner_exception From 1e02e0ace184f9f076c7d520fd77932541e564ae Mon Sep 17 00:00:00 2001 From: apaz Date: Thu, 21 Nov 2024 19:48:57 +0000 Subject: [PATCH 7/8] Add test with expected frame names. --- thunder/tests/test_interpreter.py | 36 +++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py index 580ac914d..7c26d4595 100644 --- a/thunder/tests/test_interpreter.py +++ b/thunder/tests/test_interpreter.py @@ -850,6 +850,42 @@ def main(): assert weak_x() is None +def test_backtrace_filter(): + import thunder + + def fn1(): + fn2() + + def fn2(): + fn3() + + def fn3(): + raise ValueError + + jfn = thunder.jit(fn1) + + expected_frame_names = [ + "test_backtrace_filter", + "_thunder_unwrap_inner_exception", + "fn1", + "fn2", + "fn3" + ] + + try: + jfn() + except ValueError as e: + tb_frames = [] + tb = e.__traceback__ + while tb != None: + tb_frames.append(tb) + tb = tb.tb_next + frame_names = [tb.tb_frame.f_code.co_name for tb in tb_frames] + assert frame_names == expected_frame_names + except BaseException as e: + assert False, e # Wrong exception type. + + def test_walrus_operator(jit): def foo(a, b): c = (a := b) From 1ac4664b07918721430ed5cb8651abaa0c7fb4ea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Nov 2024 19:50:23 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/tests/test_interpreter.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py index 7c26d4595..08f53f0b7 100644 --- a/thunder/tests/test_interpreter.py +++ b/thunder/tests/test_interpreter.py @@ -864,13 +864,7 @@ def fn3(): jfn = thunder.jit(fn1) - expected_frame_names = [ - "test_backtrace_filter", - "_thunder_unwrap_inner_exception", - "fn1", - "fn2", - "fn3" - ] + expected_frame_names = ["test_backtrace_filter", "_thunder_unwrap_inner_exception", "fn1", "fn2", "fn3"] try: jfn() @@ -883,7 +877,7 @@ def fn3(): frame_names = [tb.tb_frame.f_code.co_name for tb in tb_frames] assert frame_names == expected_frame_names except BaseException as e: - assert False, e # Wrong exception type. + assert False, e # Wrong exception type. def test_walrus_operator(jit):