Skip to content

Commit

Permalink
Add record_history flag to thunder.jit. Off by default. Fixes memory issues. (#239)
Browse files Browse the repository at this point in the history
  • Loading branch information
apaz-cli authored Apr 21, 2024
1 parent 593c2f1 commit 23bb666
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 51 deletions.
19 changes: 16 additions & 3 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,16 @@ def __version__():


# Translates the Python function to a thunder program using the thunder interpreter
def _general_frontend(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> TraceResults:
return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
def _general_frontend(
fn: Callable,
args: tuple[Any, ...],
kwargs: dict[str, Any],
/,
*,
record_history: bool,
sharp_edges: SHARP_EDGES_OPTIONS,
) -> TraceResults:
return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)


class ThunderModule(pytorch.nn.Module):
Expand Down Expand Up @@ -327,6 +335,7 @@ def jit(
cache: None | CACHE_OPTIONS | str = None,
disable_torch_autograd: bool = False, # TODO Revisit this UX for RC1
additional_transforms: list | None = None,
record_history: bool = False,
**compile_options, # TODO RC1 Make this explicit -- dict of options
) -> Callable:
"""Just-in-time compile a callable (function or model).
Expand Down Expand Up @@ -377,6 +386,8 @@ def jit(
# TODO: sharp edge if lookasides are shadowed?
executor_lookasides.update(ex._lookasides)

assert type(record_history) is bool

# TODO RC1 Refine the compile data option to remove unused options
cd = CompileData(
fn=fn,
Expand Down Expand Up @@ -524,7 +535,9 @@ def get_computation_and_inputs(*args, **kwargs):
with langctxs.langctx(cd.langctx):
prologue_trc: TraceCtx
computation_trc: TraceCtx
jit_results: TraceResults = interpreter(fn, args, kwargs, sharp_edges=cd.sharp_edges)
jit_results: TraceResults = interpreter(
fn, args, kwargs, record_history=record_history, sharp_edges=cd.sharp_edges
)
prologue_trc = jit_results.prologue_trace
computation_trc = jit_results.computation_trace
epilogue_trc = jit_results.epilogue_trace
Expand Down
89 changes: 57 additions & 32 deletions thunder/core/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,30 +453,30 @@ def interpretercompilectx(_interpretercompilectx: InterpreterCompileCtx):

class LineLogItem(TypedDict):
kind: Literal["Line"]
fn: Callable | CodeType
fn: str
filename: str
position: Positions | None


class OpaqueLogItem(TypedDict):
kind: Literal["Opaque"]
fn: Callable
fn: str


class LookasideLogItem(TypedDict):
kind: Literal["Lookaside"]
fn: Callable
fn: str


class CallLogItem(TypedDict):
kind: Literal["InterpreterCall"]
fn: Callable
fn: str
prev_frame: str


class ReturnLogItem(TypedDict):
kind: Literal["InterpreterReturn"]
fn: Callable
fn: str
is_signal: bool
rval: type | INTERPRETER_SIGNALS

Expand Down Expand Up @@ -512,7 +512,7 @@ class ReturnLogItem(TypedDict):

# The interpreter's runtime context, which tracks stack changes in Python mode
class InterpreterRuntimeCtx:
def __init__(self, *, debug_log: None | StringIO = None):
def __init__(self, record_history: bool, debug_log: None | StringIO):
self.frame_stack: list[InterpreterFrame] = []
self._globals_dict: dict[str, Any] | None = None
self._interpreter_log: list[InterpreterLogItem] = []
Expand All @@ -533,6 +533,7 @@ def __init__(self, *, debug_log: None | StringIO = None):
self._prev_position: Positions | None = None
self._known_wrappers = {}
self._proxied_values = set()
self._record_history = record_history

def register_proxied_value(self, v):
self._proxied_values.add(v)
Expand Down Expand Up @@ -572,6 +573,9 @@ def interp_log(self) -> list[InterpreterLogItem]:
return self._interpreter_log

def record(self, val: InterpreterLogItem, /) -> None:
if not self._record_history:
return

self._interpreter_log.append(val)

if self.debug_log is not None:
Expand Down Expand Up @@ -605,50 +609,72 @@ def push_frame_stack(self, frame: InterpreterFrame):
# consider just appending to the log and then filtering to only instructions when
# interpreted_instructions is accessed
def record_interpreted_instruction(self, inst: dis.Instruction, /) -> InterpreterRuntimeCtx:
if not self._record_history:
return self

self._interpreted_instructions.append(inst)
self.record(inst)
return self

def record_interpreter_call(self, fn: Callable) -> InterpreterRuntimeCtx:
if not self._record_history:
return self

frame: InterpreterFrame | None = self.peek_frame_stack()

# If frame is None, that means that this is the first call to _interpret_call, in _run_frame.
# In that case we should also print out what line we're starting on, since
# no line number changes have happened yet.
if frame is not None:
self.record({"kind": "InterpreterCall", "fn": fn, "prev_frame": frame.qualname})
self.record(
{"kind": "InterpreterCall", "fn": extract_callable_name(unwrap(fn)), "prev_frame": frame.qualname}
)
else:
if hasattr(self._original_callsite, "positions"):
pos = self._original_callsite.positions
else:
pos = Positions(self._original_callsite.lineno, self._original_callsite.lineno, 0, 999)
# self.record_position(fn, self._original_callsite.filename, pos)
self.record_position(fn, self._original_callsite.filename, pos)
self.record(
{
"kind": "InterpreterCall",
"fn": fn,
"fn": extract_callable_name(unwrap(fn)),
"prev_frame": self._original_callsite.function,
}
)
return self

def record_interpreter_return(self, fn: Callable, rval: Any | INTERPRETER_SIGNALS, /) -> InterpreterRuntimeCtx:
if not self._record_history:
return self

is_signal: bool = isinstance(rval, INTERPRETER_SIGNALS)
rv: type | INTERPRETER_SIGNALS = rval if is_signal else type(unwrap(rval))
self.record(ReturnLogItem(kind="InterpreterReturn", fn=fn, is_signal=is_signal, rval=rv))
self.record(
ReturnLogItem(kind="InterpreterReturn", fn=extract_callable_name(unwrap(fn)), is_signal=is_signal, rval=rv)
)
return self

def record_opaque_call(self, fn: Callable) -> InterpreterRuntimeCtx:
self.record(OpaqueLogItem(kind="Opaque", fn=fn))
if not self._record_history:
return self

self.record(OpaqueLogItem(kind="Opaque", fn=extract_callable_name(unwrap(fn))))
return self

def record_lookaside(self, fn: Callable) -> InterpreterRuntimeCtx:
self.record(LookasideLogItem(kind="Lookaside", fn=fn))
if not self._record_history:
return self

self.record(LookasideLogItem(kind="Lookaside", fn=extract_callable_name(unwrap(fn))))
return self

def record_position(
self, fn: Callable | CodeType, filename: str, position: Positions | None, /
) -> InterpreterRuntimeCtx:
if not self._record_history:
return self

# Only record a change in the Python line
if filename == self._prev_filename and _positions_equal(position, self._prev_position):
return self
Expand All @@ -658,7 +684,7 @@ def record_position(

self._prev_position = position
self._prev_filename = filename
line = LineLogItem(kind="Line", fn=fn, filename=filename, position=position)
line = LineLogItem(kind="Line", fn=extract_callable_name(unwrap(fn)), filename=filename, position=position)
self.record(line)
return self

Expand Down Expand Up @@ -5974,12 +6000,12 @@ def _interpret_call_with_unwrapping(fn: Callable, /, *args, **kwargs) -> Any | I
return unwrap(res)


def _interpret_call(fn: Callable | WrappedValue, /, *args, **kwargs) -> Any | INTERPRETER_SIGNALS:
def _interpret_call(fn: Callable, /, *args, **kwargs) -> Any | INTERPRETER_SIGNALS:
compilectx: InterpreterCompileCtx = get_interpretercompilectx()
runtimectx: InterpreterRuntimeCtx = get_interpreterruntimectx()

# TODO: Implement generics and fix WrappedValue[T] everywhere.
runtimectx.record_interpreter_call(fn) # type: ignore
runtimectx.record_interpreter_call(fn)
rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs) # type: ignore
if compilectx._with_provenance_tracking:
assert isinstance(rval, (INTERPRETER_SIGNALS, WrappedValue)), f"return {rval} unexpected calling {unwrap(fn)}"
Expand Down Expand Up @@ -6160,10 +6186,11 @@ def partial_call_impl(partial_function, /, *args, **kwargs):

# (4) Handles opaque functions
if is_opaque(fn):
runtimectx.record_opaque_call(fn)
args_ = [unwrap(a) for a in args]
# TODO: Deeper unwrapping?
args_ = tuple(unwrap(a) for a in args)
kwargs_ = {unwrap(k): unwrap(v) for k, v in kwargs.items()}
try:
runtimectx.record_opaque_call(fn)
opaque_result: Any = fn(*args_, **kwargs_)
except Exception as e:
runtimectx.curexc = e
Expand Down Expand Up @@ -6601,6 +6628,7 @@ def interpret(
debug_log: None | StringIO = None,
with_provenance_tracking: bool = False,
uncacheable_classes: list[type] | None = None,
record_history: bool = False,
) -> Callable:
compilectx: InterpreterCompileCtx = InterpreterCompileCtx(
opcode_interpreter=opcode_interpreter,
Expand All @@ -6614,7 +6642,7 @@ def interpret(

@functools.wraps(fn)
def fn_(*args, **kwargs) -> Any:
runtimectx: InterpreterRuntimeCtx = InterpreterRuntimeCtx(debug_log=debug_log)
runtimectx: InterpreterRuntimeCtx = InterpreterRuntimeCtx(debug_log=debug_log, record_history=record_history)

with interpreter_ctx(compilectx, runtimectx):
try:
Expand Down Expand Up @@ -6720,18 +6748,16 @@ def print_interpreter_log(
linecolor = colors["RESET"]
log_line = item

case {"kind": "Line", "fn": _fn, "filename": filename, "position": position}:
case {"kind": "Line", "fn": fn, "filename": filename, "position": position}:
# LineLogItem
_fn = unwrap(_fn)
inside_inner_interpreter = interpreter_path in filename
if color_internals or not inside_inner_interpreter:
linecolor = colors["YELLOW"]
nl = os.linesep
fnname = extract_callable_name(_fn)
if position:
log_line = f"# Line {filename}:{position.lineno} in {fnname}()"
log_line = f"# Line {filename}:{position.lineno} in {fn}()"
else:
log_line = f"# {filename} in {fnname}()"
log_line = f"# {filename} in {fn}()"

if not print_source_code or not position:
continue
Expand All @@ -6745,36 +6771,35 @@ def print_interpreter_log(

case {"kind": "InterpreterCall", "fn": fn, "prev_frame": prev_frame}:
# CallLogItem
fn = unwrap(fn)
if color_internals or not inside_inner_interpreter:
linecolor = colors["GREEN"]
c_indent += 1
log_line = f"Interpreting call to {extract_callable_name(fn)}() from {prev_frame}{'()' if not prev_frame.endswith('>') else ''}"
log_line = (
f"Interpreting call to {fn}() from {prev_frame}{'()' if not prev_frame.endswith('>') else ''}"
)

case {"kind": "InterpreterReturn", "fn": fn, "is_signal": is_signal, "rval": rval}:
case {"kind": "InterpreterReturn", "fn": fn, "rval": rval}:
# ReturnLogItem
fn = unwrap(fn)
rval = unwrap(rval)
if color_internals or not inside_inner_interpreter:
linecolor = colors["RED"]
deindent = True
is_signal = isinstance(rval, INTERPRETER_SIGNALS)
meaning = "signal" if is_signal else "value of type"
val = rval if is_signal else rval.__qualname__
log_line = f"Returning from call to {extract_callable_name(fn)}() with {meaning} {val}"
log_line = f"Returning from call to {fn}() with {meaning} {val}"

case {"kind": "Lookaside", "fn": fn}:
# LookasideLogItem
fn = unwrap(fn)
if color_internals or not inside_inner_interpreter:
linecolor = colors["BLUE"]
log_line = f"Lookaside to {extract_callable_name(fn)}()"
log_line = f"Lookaside to {fn}()"

case {"kind": "Opaque", "fn": fn}:
# OpaqueLogItem
fn = unwrap(fn)
if color_internals or not inside_inner_interpreter:
linecolor = colors["CYAN"]
log_line = f"Opaque call to {fn} with name {extract_callable_name(fn)}"
log_line = f"Opaque call to {fn}()"

case _:
raise NotImplementedError(f"Unexpected log item {item}")
Expand Down
15 changes: 12 additions & 3 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,9 +440,9 @@ def _minimal_global_callback(orig_value: Any, name: str) -> Any:


# TODO RC1 Add debug_log
def minimal_thunder_jit(fn: Callable, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> Callable:
def minimal_thunder_jit(fn: Callable, /, *, record_history: bool = False, sharp_edges: SHARP_EDGES_OPTIONS) -> Callable:
ctx: MinimalCtx = MinimalCtx(sharp_edges=sharp_edges)
jfn = interpret(fn, fn_lookaside=_minimal_lookaside, callbacks=_minimal_callbacks)
jfn = interpret(fn, fn_lookaside=_minimal_lookaside, callbacks=_minimal_callbacks, record_history=record_history)

def fn_(*args, **kwargs):
try:
Expand Down Expand Up @@ -1395,7 +1395,15 @@ def _get_process_group_from(*fn_and_args) -> Optional["ProcessGroup"]:
return found_pg


def thunder_general_jit(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> TraceResults:
def thunder_general_jit(
fn: Callable,
args: tuple[Any, ...],
kwargs: dict[str, Any],
/,
*,
record_history: bool = False,
sharp_edges: SHARP_EDGES_OPTIONS,
) -> TraceResults:
# TODO: move into wrap_callback or so
if isinstance(fn, torch.nn.parallel.DistributedDataParallel):
raise NotImplementedError(
Expand Down Expand Up @@ -1432,6 +1440,7 @@ def thunder_general_jit(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDG
callbacks=general_jit_callbacks,
with_provenance_tracking=True,
uncacheable_classes=(torch.Tensor, int, float, str, NoneType),
record_history=record_history,
)

with general_jit_ctx(ctx):
Expand Down
10 changes: 6 additions & 4 deletions thunder/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,10 +419,12 @@ def leaf_test(x: Any) -> bool:


# Translates the Python function to a thunder program using the Python interpreter
def _python_interpreter(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> TraceResults:
def _python_interpreter(
fn: Callable, args, kwargs, /, *, record_history: bool = False, sharp_edges: SHARP_EDGES_OPTIONS
) -> TraceResults:
if sharp_edges is not SHARP_EDGES_OPTIONS.ALLOW:
raise ValueError(
f"Detecting sharp edges is not supported when using the Python interpreter. To detect sharp edges use another interpretation option."
"Detecting sharp edges is not supported when using the Python interpreter. To detect sharp edges use another interpretation option."
)

def _interpreter(fn_):
Expand All @@ -433,11 +435,11 @@ def _interpreter(fn_):

# Translates the Python function to a thunder program using the thunder interpreter
def _translate_functions_interpreter(
fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS
fn: Callable, args, kwargs, /, *, record_history: bool = False, sharp_edges: SHARP_EDGES_OPTIONS
) -> TraceResults:
from thunder.core.jit_ext import minimal_thunder_jit

pjit = partial(minimal_thunder_jit, sharp_edges=sharp_edges)
pjit = partial(minimal_thunder_jit, sharp_edges=sharp_edges, record_history=record_history)
return _eager_unpacking_interpreter(pjit, fn, args, kwargs, interpreter_name="translate functions")


Expand Down
Loading

0 comments on commit 23bb666

Please sign in to comment.