Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make exceptions from user code more readable #1460

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
44 changes: 43 additions & 1 deletion thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
AnyProxy,
)
from thunder.core.interpreter import print_interpreter_log, print_to_log
from thunder.core.jit_ext import thunder_general_jit
from thunder.core.jit_ext import thunder_general_jit, InnerException
from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction

# NOTE This import is intentionally pytorch so that it thunder.torch doesn't import this
Expand Down Expand Up @@ -814,7 +814,49 @@ def maybe_call_epilogue(cache_entry, result, pro_to_epi):

return result

def unwrap_inner_exception(c: Callable) -> Callable:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this could live in the interpreter instead.

Copy link
Collaborator Author

@apaz-cli apaz-cli Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function can. The unwrapping actually can't, tried that already. Traceback frames are pushed as the stack unwinds, not when the exception is thrown.

def _thunder_unwrap_inner_exception(*args, **kwargs):
# Run the function, and caputre the exception if there is one.
try:
return c(*args, **kwargs)
except InnerException as e:
exc = e.value

def internal_to_thunder(co):
if co is thunder_general_jit.__code__ or co is _thunder_unwrap_inner_exception.__code__:
return True
return co.co_filename.endswith("thunder" + os.sep + "core" + os.sep + "interpreter.py") and (
co.co_name in ("fn_", "fn_2")
)
Comment on lines +828 to +830
Copy link
Collaborator

@t-vi t-vi Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Why just fn_ and fn_2 and not all of the methods in the interpreter?
  • Please don't use the filename/code object here. Instead check whether that is in f_globals of the frame has __name__ == 'thunder.core.interpreter'.

If we drop all interpreter methods, we likely want to keep internal frames at the top of the stacktrace.

Copy link
Collaborator Author

@apaz-cli apaz-cli Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excluding everything in thunder.core.interpreter sounds good. I was trying to be conservative by default.

But using f_globals to do that instead of co_name doesn't work, because not every function has a f_globals with a __name__ in it. Certainly it would not work with some uses of the exec() lookaside. It also depends on how exactly the module is imported. There are a bunch of reasons why it doesn't work.

As for the functions filtered out by code object, those live outside of interpreter.py, in files that I don't want to exclude because some of their functions could be called by user code. And using the code objects is very convenient because it exactly matches the functions to exclude, and we wouldn't want to exclude all of thunder.__init__.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which functions we want to filter does not have an f_globals with __name__?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thunder_general_jit() does not. When I add an assertion that f_globals has a __name__, it fails.

ap-readable_exceptions ~/lightning-thunder python exctest.py 
Traceback (most recent call last):
  File "/teamspace/studios/this_studio/lightning-thunder/exctest.py", line 14, in <module>
    jfn()
  File "/teamspace/studios/this_studio/lightning-thunder/thunder/__init__.py", line 836, in _thunder_unwrap_inner_exception
    assert hasattr(tb.tb_frame.f_globals, "__name__"), tb.tb_frame.f_code.co_name
AssertionError: thunder_general_jit

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's because f_globals is a mapping, so it's members, not attributes.


# Iterate over the traceback and collect frames that don't correspond to thunder internal functions.
tb = exc.__traceback__
tb_frames = []
while tb != None:
co = tb.tb_frame.f_code
if not internal_to_thunder(co):
tb_frames.append(tb)
tb = tb.tb_next

# Relink the non-internal traceback frames
if tb_frames:
top_tb = tb = tb_frames[0]
for _tb in tb_frames[1:]:
tb.tb_next = _tb
tb = _tb
exc.__traceback__ = top_tb

# Re-raise the exception without retaining it in this stack frame to avoid leaking tensors.
try:
raise exc
except Exception:
del exc
raise # re-raises current exception

return _thunder_unwrap_inner_exception

@wraps(fn)
@unwrap_inner_exception
@update_call_statistics
def fn_(*args, **kwargs) -> Any:
if is_tracing():
Expand Down
11 changes: 10 additions & 1 deletion thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@
}


class InnerException(BaseException):
def __init__(self, *, value: BaseException):
self.value = value


class JITSharpEdgeError(RuntimeError):
"""
Thrown when the program cannot be safely translated to a thunder program,
Expand Down Expand Up @@ -1741,7 +1746,11 @@ def thunder_general_jit(

with jit_ctx(ctx):
with tracectx(computation_trace):
result = jfn(*args, **kwargs)
try:
result = jfn(*args, **kwargs)
except BaseException as e:
raise InnerException(value=e)

prims.python_return(result)
computation_trace.set_current_source_location(None, None)
process_recorded_modifications(ctx, epilogue_trace)
Expand Down
30 changes: 30 additions & 0 deletions thunder/tests/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,36 @@ def main():
assert weak_x() is None


def test_backtrace_filter():
import thunder

def fn1():
fn2()

def fn2():
fn3()

def fn3():
raise ValueError

jfn = thunder.jit(fn1)

expected_frame_names = ["test_backtrace_filter", "_thunder_unwrap_inner_exception", "fn1", "fn2", "fn3"]

try:
jfn()
except ValueError as e:
tb_frames = []
tb = e.__traceback__
while tb != None:
tb_frames.append(tb)
tb = tb.tb_next
frame_names = [tb.tb_frame.f_code.co_name for tb in tb_frames]
assert frame_names == expected_frame_names
except BaseException as e:
assert False, e # Wrong exception type.


def test_walrus_operator(jit):
def foo(a, b):
c = (a := b)
Expand Down
Loading