Skip to content

Commit

Permalink
Fix leak related to stop iteration (#1170)
Browse files Browse the repository at this point in the history
  • Loading branch information
lantiga authored Sep 23, 2024
1 parent a639260 commit 63887b3
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 26 deletions.
20 changes: 17 additions & 3 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ jobs:
os: ["ubuntu-22.04", "macOS-14", "windows-2022"]
python-version: ["3.10"]
requires: ["latest", "nightly"] # , 'oldest'
suite: ["core", "ops"]
include:
- { os: "ubuntu-22.04", python-version: "3.11", requires: "latest" }
- { os: "ubuntu-22.04", python-version: "3.12", requires: "latest" }
exclude:
- { os: "windows-2022", suite: "ops" }

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 35
Expand Down Expand Up @@ -86,25 +89,36 @@ jobs:
shell: bash

- name: Testing Local
if: matrix.python-version == '3.10'
if: matrix.python-version == '3.10' && matrix.suite == 'core'
run: |
coverage run --source thunder -m \
pytest thunder/tests/ \
--ignore=thunder/tests/distributed \
--ignore=thunder/tests/test_ops.py \
--ignore=thunder/tests/test_grad.py \
-v --datefmt="%Y%m%d-%H:%M:%S.%f" \
--random-order-seed=$GITHUB_RUN_ID \
-n 4 --durations=250
- name: Testing Distributed
# run all found tests in given past as standalone
if: matrix.python-version == '3.10' && runner.os == 'Linux'
if: matrix.python-version == '3.10' && runner.os == 'Linux' && matrix.suite == 'core'
run: |
pytest thunder/tests/distributed/ \
-v --datefmt="%Y%m%d-%H:%M:%S.%f" \
--random-order-seed=$GITHUB_RUN_ID \
--durations=250
- name: Testing just a few
- name: Testing Ops
if: matrix.python-version == '3.10' && matrix.suite == 'ops' && runner.os != 'Windows'
run: |
coverage run --source thunder -m \
pytest thunder/tests/test_ops.py thunder/tests/test_grad.py \
-v --datefmt="%Y%m%d-%H:%M:%S.%f" \
--random-order-seed=$GITHUB_RUN_ID \
-n 4 --durations=250
- name: Testing interpreter
if: matrix.python-version == '3.11' || matrix.python-version == '3.12'
#continue-on-error: true
run: |
Expand Down
63 changes: 40 additions & 23 deletions thunder/core/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1175,40 +1175,55 @@ def get_localsplus_name(self, idx: int) -> str:
return self.code._varname_from_oparg(idx) # type: ignore

def get_or_make_python_frame(self) -> FrameType:
def fn():
pass

assert self.positions is not None
lineno = self.positions.lineno
if lineno is None:
lineno = self.code.co_firstlineno

rel_lineno = lineno - self.code.co_firstlineno + 1
filename = self.code.co_filename
firstlineno = self.code.co_firstlineno
name = self.code.co_name
qualname = self.qualname

# we prefer this code object over fn.__code__ to get the first lineno and the current lineno right,
# which the following does by inserting so many empty lines that relative to the start line
# the exception is raised at the right line
code = compile((rel_lineno - 1) * "\n" + "raise ValueError()", self.code.co_filename, "exec")
def get_frame(l, rel_lineno, filename, firstlineno, name, qualname):
def fn():
pass

replacements = dict(
co_filename=self.code.co_filename, co_firstlineno=self.code.co_firstlineno, co_name=self.code.co_name
)
# we prefer this code object over fn.__code__ to get the first lineno and the current lineno right,
# which the following does by inserting so many empty lines that relative to the start line
# the exception is raised at the right line
code = compile((rel_lineno - 1) * "\n" + "raise ValueError()", filename, "exec")

if hasattr(fn.__code__, "co_qualname"):
replacements["co_qualname"] = self.qualname
replacements = dict(co_filename=filename, co_firstlineno=firstlineno, co_name=name)

fn.__code__ = code.replace(**replacements) # type: ignore (The replaced fields are the correct types)
if hasattr(fn.__code__, "co_qualname"):
replacements["co_qualname"] = qualname

try:
fn()
assert False, "Unreachable."
except ValueError as e:
tb = e.__traceback__
fn.__code__ = code.replace(**replacements) # type: ignore (The replaced fields are the correct types)

assert tb is not None
while tb.tb_next is not None:
tb = tb.tb_next
return tb.tb_frame
try:
fn()
assert False, "Unreachable."
except ValueError as e:
tb = e.__traceback__

assert tb is not None
while tb.tb_next is not None:
tb = tb.tb_next
l.append(tb.tb_frame)

# we run the getting of the frame in a separate thread because
# we want to avoid having f_back pointing to the function
# handling the error
import _thread

result_container = []
_thread.start_new_thread(get_frame, (result_container, rel_lineno, filename, firstlineno, name, qualname))
while not result_container:
pass
return result_container[0]


#
Expand Down Expand Up @@ -4246,6 +4261,7 @@ def _next_impl(tos):
if r is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
ctx = get_interpreterruntimectx()
if isinstance(ctx.curexc, StopIteration):
ctx.curexc = None
if sys.version_info >= (3, 12):
# 3.12 uses jumps relative to the next instruction offset and does not pop here
# instead it pushes a fake value?!
Expand Down Expand Up @@ -6779,8 +6795,9 @@ def _setup_frame_and_run_python_function(
except Exception as e:
# We need to cheat a bit to get a Python frame here...
python_frame = frame.get_or_make_python_frame()
tb = TracebackType(e.__traceback__, python_frame, python_frame.f_lasti, python_frame.f_lineno)
raise e.with_traceback(tb)
e.__traceback__ = TracebackType(e.__traceback__, python_frame, python_frame.f_lasti, python_frame.f_lineno)
del e # avoid memory leak
raise
return res


Expand Down
48 changes: 48 additions & 0 deletions thunder/tests/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,54 @@ def cross_function_exceptions():
assert jit(cross_function_exceptions)() == True


def test_stop_exception_no_leak(jit):

class Identity(torch.nn.Module):
def forward(self, x):
for p in self.parameters():
pass
return x

def foo():
model = thunder.jit(Identity())
x = torch.randn(16, 16)

model(x)

return weakref.ref(x)

weak_x = foo()

assert weak_x() is None


def test_exception_no_leak(jit):

class Identity(torch.nn.Module):
@staticmethod
def raises():
raise RuntimeError("Exc")

def forward(self, x):
try:
self.raises()
except RuntimeError:
pass
return x

def foo():
model = thunder.jit(Identity())
x = torch.randn(16, 16)

model(x)

return weakref.ref(x)

weak_x = foo()

assert weak_x() is None


def test_walrus_operator(jit):
def foo(a, b):
c = (a := b)
Expand Down

0 comments on commit 63887b3

Please sign in to comment.