Skip to content

Commit

Permalink
Refactor hot reloading
Browse files Browse the repository at this point in the history
Minor refactoring to simplify the code base and isolate the UI from the
logic. Incidentally, this fixes a logging bug due to removing
BuilderUI.run_async().
  • Loading branch information
DubiousCactus committed Nov 3, 2024
1 parent cb15829 commit 6b195cf
Show file tree
Hide file tree
Showing 5 changed files with 402 additions and 294 deletions.
258 changes: 258 additions & 0 deletions bootstrap/hot_reloading/engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
import asyncio
import importlib
import inspect
import sys
import traceback
from types import FrameType
from typing import (
Callable,
Optional,
Tuple,
)

from rich.text import Text
from textual.app import App
from textual.widgets import (
RichLog,
)

from bootstrap.hot_reloading.module import MatchboxModule


class HotReloadingEngine:
def __init__(self, ui: App):
self.ui = ui

@classmethod
def get_class_frame(cls, func: Callable, exc_traceback) -> Optional[FrameType]:
"""
Find the frame of the last callable within the scope of the MatchboxModule in
the traceback. In this instance, the MatchboxModule is a class so we want to find
the frame of the method that either (1) threw the exception, or (2) called a
function that threw (or originated) the exception.
"""
last_frame = None
for frame, _ in traceback.walk_tb(exc_traceback):
print(frame.f_code.co_qualname)
if frame.f_code.co_qualname == func.__name__:
print(
f"Found module.underlying_fn ({func.__name__}) in traceback, continuing..."
)
for name, val in inspect.getmembers(func):
if (
name == frame.f_code.co_name
and "self" in inspect.getargs(frame.f_code).args
):
print(f"Found method {val} in traceback, continuing...")
last_frame = frame
return last_frame

@classmethod
def get_lambda_child_frame(
cls, func: Callable, exc_traceback
) -> Tuple[Optional[FrameType], Optional[str]]:
"""
Find the frame of the last callable within the scope of the MatchboxModule in
the traceback. In this instance, the MatchboxModule is a lambda function so we want
to find the frame of the first function called by the lambda.
"""
lambda_args = inspect.getargs(func.__code__)
potential_matches = {}
print(f"Lambda args: {lambda_args}")
for frame, _ in traceback.walk_tb(exc_traceback):
print(frame.f_code.co_qualname)
assert lambda_args is not None
frame_args = inspect.getargvalues(frame)
for name, val in potential_matches.items():
print(
f"Checking candidate {name}={val} to match against {frame.f_code.co_qualname}"
)
if val == frame.f_code.co_qualname:
print(f"Matched {name}={val} to {frame.f_code.co_qualname}")
return frame, name
elif hasattr(val, frame.f_code.co_name):
print(f"Matched {frame.f_code.co_qualname} to member of {val}")
return frame, name
for name in lambda_args.args:
print(f"Lambda arg '{name}'")
if name in frame_args.args:
print(f"Frame has arg {name} with value {frame_args.locals[name]}")
# TODO: Find the argument which initiated the call that threw!
# Which is somewhere deeper in the stack, which
# frame.f_code.co_qualname must match one of the
# frame_args.args!
# NOTE: We know the next frame in the loop WILL match one of
# this frame's arguments, either in the qual_name directly or in
# the qual_name base (the class)
potential_matches[name] = frame_args.locals[name]
return None, None

@classmethod
def get_function_frame(cls, func: Callable, exc_traceback) -> Optional[FrameType]:
print("============= get_function_frame() =========")
last_frame = None
for frame, _ in traceback.walk_tb(exc_traceback):
print(frame.f_code.co_qualname)
if frame.f_code.co_qualname == func.__name__:
print(
f"Found module.underlying_fn ({func.__name__}) in traceback, continuing..."
)
for name, val in inspect.getmembers(func.__module__):
if name == frame.f_code.co_name:
print(f"Found function {val} in traceback, continuing...")
last_frame = frame
print("============================================")
return last_frame

async def catch_and_hang(self, module: MatchboxModule, *args, **kwargs):
try:
self.ui.print_info(f"Calling MatchboxModule({module.underlying_fn}) with")
self.ui.print_pretty(
{
"args": args,
"kwargs": kwargs,
"partial.args": module.partial.args,
"partial.kwargs": module.partial.keywords,
}
)
output = await asyncio.to_thread(module, *args, **kwargs)
self.ui.print_info("Output:")
self.ui.print_pretty(output)
return output
except Exception as exception:
# If the exception came from the wrapper itself, we should not catch it!
exc_type, exc_value, exc_traceback = sys.exc_info()
if exc_traceback.tb_next is None:
self.ui.print_err(
"[ERROR] Could not find the next frame in the call stack!"
)
elif exc_traceback.tb_next.tb_frame.f_code.co_name == "catch_and_hang":
self.ui.print_err(
f"[ERROR] Caught exception in the Builder: {exception}",
)
else:
self.ui.print_err(
f"Caught exception: {exception}",
)
self.ui.query_one("#traceback", RichLog).write(traceback.format_exc())
func = module.underlying_fn
# NOTE: This frame is for the given function, which is the root of the
# call tree (our MatchboxModule's underlying function). What we want is
# to go down to the function that threw, and reload that only if it
# wasn't called anywhere in the frozen module's call tree.
frame = None
if inspect.isclass(func):
frame = self.get_class_frame(func, exc_traceback)
elif inspect.isfunction(func) and func.__name__ == "<lambda>":
frame, lambda_argname = self.get_lambda_child_frame(
func, exc_traceback
)
module.throw_lambda_argname = lambda_argname
elif inspect.isfunction(func):
frame = self.get_function_frame(func, exc_traceback)
else:
raise NotImplementedError()
if not frame:
self.ui.print_err(
f"Could not find the frame of the original function {func} in the traceback."
)
module.throw_frame = frame
self.ui.print_info("Exception thrown in:")
self.ui.print_pretty(frame)
module.to_reload = True
self.ui.print_info("Hanged.")
await self.ui.hang(threw=True)

async def reload_module(self, module: MatchboxModule):
if module.throw_frame is None:
self.ui.exit(1)
raise RuntimeError(
f"Module {module} is set to reload but we don't have the frame that threw!"
)
self.ui.log_tracer(
Text(
f"Reloading code from {module.throw_frame.f_code.co_filename}",
style="purple",
)
)
code_obj = module.throw_frame.f_code
print(code_obj.co_qualname, inspect.getmodule(code_obj))
code_module = inspect.getmodule(code_obj)
if code_module is None:
self.ui.exit(1)
raise RuntimeError(
f"Could not find the module for the code object {code_obj}."
)
rld_module = importlib.reload(code_module)
if code_obj.co_qualname.endswith("__init__"):
class_name = code_obj.co_qualname.split(".")[0]
self.ui.log_tracer(
Text(
f"-> Reloading class {class_name} from module {code_module}",
style="purple",
)
)
rld_callable = getattr(rld_module, class_name)
if rld_callable is not None:
self.ui.log_tracer(
Text(
f"-> Reloaded class {code_obj.co_qualname} from module {code_module.__name__}",
style="cyan",
)
)
print(inspect.getsource(rld_callable))
module.reload(rld_callable)
return

else:
if code_obj.co_qualname.find(".") != -1:
class_name, _ = code_obj.co_qualname.split(".")
self.ui.log_tracer(
Text(
f"-> Reloading class {class_name} from module {code_module}",
style="purple",
)
)
rld_class = getattr(rld_module, class_name)
rld_callable = None
# Now find the method in the reloaded class, and replace the
# with the reloaded one.
for name, val in inspect.getmembers(rld_class):
if inspect.isfunction(val) and val.__name__ == code_obj.co_name:
self.ui.print_info(
f" -> Reloading method '{name}'",
)
rld_callable = val
if rld_callable is not None:
self.ui.log_tracer(
Text(
f"-> Reloaded class-level method {code_obj.co_qualname} from module {code_module.__name__}",
style="cyan",
)
)
if module.underlying_fn.__name__ == "<lambda>":
assert module.throw_lambda_argname is not None
module.reload_surgically_in_lambda(
module.throw_lambda_argname, code_obj.co_name, rld_callable
)
else:
module.reload_surgically(code_obj.co_name, rld_callable)
return
else:
print(code_module, code_obj, code_obj.co_name)
self.ui.log_tracer(
Text(
f"-> Reloading module-level function {code_obj.co_name} from module {code_module.__name__}",
style="purple",
)
)
func = getattr(rld_module, code_obj.co_name)
if func is not None:
self.ui.print_info(
f" -> Reloaded module level function {code_obj.co_name}",
)
print(inspect.getsource(func))
module.reload(func)
return
while True:
await asyncio.sleep(1)
84 changes: 84 additions & 0 deletions bootstrap/hot_reloading/module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import uuid
from functools import partial
from types import FrameType
from typing import Any, Callable, List, Optional

from hydra_zen.typing import Partial


class MatchboxModule:
def __init__(self, name: str, fn: Callable | Partial, *args, **kwargs):
self._str_rep = name
self._uid = uuid.uuid4().hex
self.underlying_fn: Callable = fn.func if isinstance(fn, partial) else fn
self.partial = partial(fn, *args, **kwargs)
self.to_reload = False
self.result = None
self.is_frozen = False
self.throw_frame: Optional[FrameType] = None
self.throw_lambda_argname: Optional[str] = None

def reload(self, new_func: Callable) -> None:
print(f"Replacing {self.underlying_fn} with {new_func}")
self.underlying_fn = new_func
self.partial = partial(
self.underlying_fn, *self.partial.args, **self.partial.keywords
)
self.to_reload = False

def reload_surgically(self, method_name: str, method: Callable) -> None:
print(f"Replacing {method_name} which was {self.underlying_fn} with {method}")
setattr(self.underlying_fn, method_name, method)
self.partial = partial(
self.underlying_fn, *self.partial.args, **self.partial.keywords
)
self.to_reload = False

def reload_surgically_in_lambda(
self, arg_name: str, method_name: str, method: Callable
) -> None:
print(
f"Replacing {method_name} as argument {arg_name} in lambda's {self.partial.args} or {self.partial.keywords} with {method}"
)
if arg_name not in self.partial.keywords.keys():
raise KeyError(
"Could not find the argument to replace in the partial kwargs!"
)
for k, v in self.partial.keywords.items():
print(f"Updating lambda arg {v} and")
print(f"re-passing self reference via partial {partial(method, v)}")
setattr(v, method_name, partial(method, v))
self.partial.keywords[k] = v # Need to update when using dict iterator
self.partial = partial(
self.underlying_fn, *self.partial.args, **self.partial.keywords
)
self.to_reload = False

def __call__(self, module_chain: List) -> Any:
"""
Args:
module_chain: List[MatchboxModule]
"""

def _find_module_result(module_chain: List, uid: str) -> Any:
for module in module_chain:
if module.uid == uid:
return module.result
return None

for i, arg in enumerate(self.partial.args):
if isinstance(arg, MatchboxModule):
self.partial.args[i] = _find_module_result(module_chain, arg.uid)
for key, value in self.partial.keywords.items():
if isinstance(value, MatchboxModule):
self.partial.keywords[key] = _find_module_result(
module_chain, value.uid
)
return self.partial()

def __str__(self) -> str:
return self._str_rep

@property
def uid(self) -> str:
return f"uid-{self._uid}"
Loading

0 comments on commit 6b195cf

Please sign in to comment.