-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Minor refactoring to simplify the code base and isolate the UI from the logic. Incidentally, this fixes a logging bug due to removing BuilderUI.run_async().
- Loading branch information
1 parent
cb15829
commit 6b195cf
Showing
5 changed files
with
402 additions
and
294 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,258 @@ | ||
import asyncio | ||
import importlib | ||
import inspect | ||
import sys | ||
import traceback | ||
from types import FrameType | ||
from typing import ( | ||
Callable, | ||
Optional, | ||
Tuple, | ||
) | ||
|
||
from rich.text import Text | ||
from textual.app import App | ||
from textual.widgets import ( | ||
RichLog, | ||
) | ||
|
||
from bootstrap.hot_reloading.module import MatchboxModule | ||
|
||
|
||
class HotReloadingEngine: | ||
def __init__(self, ui: App): | ||
self.ui = ui | ||
|
||
@classmethod | ||
def get_class_frame(cls, func: Callable, exc_traceback) -> Optional[FrameType]: | ||
""" | ||
Find the frame of the last callable within the scope of the MatchboxModule in | ||
the traceback. In this instance, the MatchboxModule is a class so we want to find | ||
the frame of the method that either (1) threw the exception, or (2) called a | ||
function that threw (or originated) the exception. | ||
""" | ||
last_frame = None | ||
for frame, _ in traceback.walk_tb(exc_traceback): | ||
print(frame.f_code.co_qualname) | ||
if frame.f_code.co_qualname == func.__name__: | ||
print( | ||
f"Found module.underlying_fn ({func.__name__}) in traceback, continuing..." | ||
) | ||
for name, val in inspect.getmembers(func): | ||
if ( | ||
name == frame.f_code.co_name | ||
and "self" in inspect.getargs(frame.f_code).args | ||
): | ||
print(f"Found method {val} in traceback, continuing...") | ||
last_frame = frame | ||
return last_frame | ||
|
||
@classmethod | ||
def get_lambda_child_frame( | ||
cls, func: Callable, exc_traceback | ||
) -> Tuple[Optional[FrameType], Optional[str]]: | ||
""" | ||
Find the frame of the last callable within the scope of the MatchboxModule in | ||
the traceback. In this instance, the MatchboxModule is a lambda function so we want | ||
to find the frame of the first function called by the lambda. | ||
""" | ||
lambda_args = inspect.getargs(func.__code__) | ||
potential_matches = {} | ||
print(f"Lambda args: {lambda_args}") | ||
for frame, _ in traceback.walk_tb(exc_traceback): | ||
print(frame.f_code.co_qualname) | ||
assert lambda_args is not None | ||
frame_args = inspect.getargvalues(frame) | ||
for name, val in potential_matches.items(): | ||
print( | ||
f"Checking candidate {name}={val} to match against {frame.f_code.co_qualname}" | ||
) | ||
if val == frame.f_code.co_qualname: | ||
print(f"Matched {name}={val} to {frame.f_code.co_qualname}") | ||
return frame, name | ||
elif hasattr(val, frame.f_code.co_name): | ||
print(f"Matched {frame.f_code.co_qualname} to member of {val}") | ||
return frame, name | ||
for name in lambda_args.args: | ||
print(f"Lambda arg '{name}'") | ||
if name in frame_args.args: | ||
print(f"Frame has arg {name} with value {frame_args.locals[name]}") | ||
# TODO: Find the argument which initiated the call that threw! | ||
# Which is somewhere deeper in the stack, which | ||
# frame.f_code.co_qualname must match one of the | ||
# frame_args.args! | ||
# NOTE: We know the next frame in the loop WILL match one of | ||
# this frame's arguments, either in the qual_name directly or in | ||
# the qual_name base (the class) | ||
potential_matches[name] = frame_args.locals[name] | ||
return None, None | ||
|
||
@classmethod | ||
def get_function_frame(cls, func: Callable, exc_traceback) -> Optional[FrameType]: | ||
print("============= get_function_frame() =========") | ||
last_frame = None | ||
for frame, _ in traceback.walk_tb(exc_traceback): | ||
print(frame.f_code.co_qualname) | ||
if frame.f_code.co_qualname == func.__name__: | ||
print( | ||
f"Found module.underlying_fn ({func.__name__}) in traceback, continuing..." | ||
) | ||
for name, val in inspect.getmembers(func.__module__): | ||
if name == frame.f_code.co_name: | ||
print(f"Found function {val} in traceback, continuing...") | ||
last_frame = frame | ||
print("============================================") | ||
return last_frame | ||
|
||
async def catch_and_hang(self, module: MatchboxModule, *args, **kwargs): | ||
try: | ||
self.ui.print_info(f"Calling MatchboxModule({module.underlying_fn}) with") | ||
self.ui.print_pretty( | ||
{ | ||
"args": args, | ||
"kwargs": kwargs, | ||
"partial.args": module.partial.args, | ||
"partial.kwargs": module.partial.keywords, | ||
} | ||
) | ||
output = await asyncio.to_thread(module, *args, **kwargs) | ||
self.ui.print_info("Output:") | ||
self.ui.print_pretty(output) | ||
return output | ||
except Exception as exception: | ||
# If the exception came from the wrapper itself, we should not catch it! | ||
exc_type, exc_value, exc_traceback = sys.exc_info() | ||
if exc_traceback.tb_next is None: | ||
self.ui.print_err( | ||
"[ERROR] Could not find the next frame in the call stack!" | ||
) | ||
elif exc_traceback.tb_next.tb_frame.f_code.co_name == "catch_and_hang": | ||
self.ui.print_err( | ||
f"[ERROR] Caught exception in the Builder: {exception}", | ||
) | ||
else: | ||
self.ui.print_err( | ||
f"Caught exception: {exception}", | ||
) | ||
self.ui.query_one("#traceback", RichLog).write(traceback.format_exc()) | ||
func = module.underlying_fn | ||
# NOTE: This frame is for the given function, which is the root of the | ||
# call tree (our MatchboxModule's underlying function). What we want is | ||
# to go down to the function that threw, and reload that only if it | ||
# wasn't called anywhere in the frozen module's call tree. | ||
frame = None | ||
if inspect.isclass(func): | ||
frame = self.get_class_frame(func, exc_traceback) | ||
elif inspect.isfunction(func) and func.__name__ == "<lambda>": | ||
frame, lambda_argname = self.get_lambda_child_frame( | ||
func, exc_traceback | ||
) | ||
module.throw_lambda_argname = lambda_argname | ||
elif inspect.isfunction(func): | ||
frame = self.get_function_frame(func, exc_traceback) | ||
else: | ||
raise NotImplementedError() | ||
if not frame: | ||
self.ui.print_err( | ||
f"Could not find the frame of the original function {func} in the traceback." | ||
) | ||
module.throw_frame = frame | ||
self.ui.print_info("Exception thrown in:") | ||
self.ui.print_pretty(frame) | ||
module.to_reload = True | ||
self.ui.print_info("Hanged.") | ||
await self.ui.hang(threw=True) | ||
|
||
async def reload_module(self, module: MatchboxModule): | ||
if module.throw_frame is None: | ||
self.ui.exit(1) | ||
raise RuntimeError( | ||
f"Module {module} is set to reload but we don't have the frame that threw!" | ||
) | ||
self.ui.log_tracer( | ||
Text( | ||
f"Reloading code from {module.throw_frame.f_code.co_filename}", | ||
style="purple", | ||
) | ||
) | ||
code_obj = module.throw_frame.f_code | ||
print(code_obj.co_qualname, inspect.getmodule(code_obj)) | ||
code_module = inspect.getmodule(code_obj) | ||
if code_module is None: | ||
self.ui.exit(1) | ||
raise RuntimeError( | ||
f"Could not find the module for the code object {code_obj}." | ||
) | ||
rld_module = importlib.reload(code_module) | ||
if code_obj.co_qualname.endswith("__init__"): | ||
class_name = code_obj.co_qualname.split(".")[0] | ||
self.ui.log_tracer( | ||
Text( | ||
f"-> Reloading class {class_name} from module {code_module}", | ||
style="purple", | ||
) | ||
) | ||
rld_callable = getattr(rld_module, class_name) | ||
if rld_callable is not None: | ||
self.ui.log_tracer( | ||
Text( | ||
f"-> Reloaded class {code_obj.co_qualname} from module {code_module.__name__}", | ||
style="cyan", | ||
) | ||
) | ||
print(inspect.getsource(rld_callable)) | ||
module.reload(rld_callable) | ||
return | ||
|
||
else: | ||
if code_obj.co_qualname.find(".") != -1: | ||
class_name, _ = code_obj.co_qualname.split(".") | ||
self.ui.log_tracer( | ||
Text( | ||
f"-> Reloading class {class_name} from module {code_module}", | ||
style="purple", | ||
) | ||
) | ||
rld_class = getattr(rld_module, class_name) | ||
rld_callable = None | ||
# Now find the method in the reloaded class, and replace the | ||
# with the reloaded one. | ||
for name, val in inspect.getmembers(rld_class): | ||
if inspect.isfunction(val) and val.__name__ == code_obj.co_name: | ||
self.ui.print_info( | ||
f" -> Reloading method '{name}'", | ||
) | ||
rld_callable = val | ||
if rld_callable is not None: | ||
self.ui.log_tracer( | ||
Text( | ||
f"-> Reloaded class-level method {code_obj.co_qualname} from module {code_module.__name__}", | ||
style="cyan", | ||
) | ||
) | ||
if module.underlying_fn.__name__ == "<lambda>": | ||
assert module.throw_lambda_argname is not None | ||
module.reload_surgically_in_lambda( | ||
module.throw_lambda_argname, code_obj.co_name, rld_callable | ||
) | ||
else: | ||
module.reload_surgically(code_obj.co_name, rld_callable) | ||
return | ||
else: | ||
print(code_module, code_obj, code_obj.co_name) | ||
self.ui.log_tracer( | ||
Text( | ||
f"-> Reloading module-level function {code_obj.co_name} from module {code_module.__name__}", | ||
style="purple", | ||
) | ||
) | ||
func = getattr(rld_module, code_obj.co_name) | ||
if func is not None: | ||
self.ui.print_info( | ||
f" -> Reloaded module level function {code_obj.co_name}", | ||
) | ||
print(inspect.getsource(func)) | ||
module.reload(func) | ||
return | ||
while True: | ||
await asyncio.sleep(1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import uuid | ||
from functools import partial | ||
from types import FrameType | ||
from typing import Any, Callable, List, Optional | ||
|
||
from hydra_zen.typing import Partial | ||
|
||
|
||
class MatchboxModule: | ||
def __init__(self, name: str, fn: Callable | Partial, *args, **kwargs): | ||
self._str_rep = name | ||
self._uid = uuid.uuid4().hex | ||
self.underlying_fn: Callable = fn.func if isinstance(fn, partial) else fn | ||
self.partial = partial(fn, *args, **kwargs) | ||
self.to_reload = False | ||
self.result = None | ||
self.is_frozen = False | ||
self.throw_frame: Optional[FrameType] = None | ||
self.throw_lambda_argname: Optional[str] = None | ||
|
||
def reload(self, new_func: Callable) -> None: | ||
print(f"Replacing {self.underlying_fn} with {new_func}") | ||
self.underlying_fn = new_func | ||
self.partial = partial( | ||
self.underlying_fn, *self.partial.args, **self.partial.keywords | ||
) | ||
self.to_reload = False | ||
|
||
def reload_surgically(self, method_name: str, method: Callable) -> None: | ||
print(f"Replacing {method_name} which was {self.underlying_fn} with {method}") | ||
setattr(self.underlying_fn, method_name, method) | ||
self.partial = partial( | ||
self.underlying_fn, *self.partial.args, **self.partial.keywords | ||
) | ||
self.to_reload = False | ||
|
||
def reload_surgically_in_lambda( | ||
self, arg_name: str, method_name: str, method: Callable | ||
) -> None: | ||
print( | ||
f"Replacing {method_name} as argument {arg_name} in lambda's {self.partial.args} or {self.partial.keywords} with {method}" | ||
) | ||
if arg_name not in self.partial.keywords.keys(): | ||
raise KeyError( | ||
"Could not find the argument to replace in the partial kwargs!" | ||
) | ||
for k, v in self.partial.keywords.items(): | ||
print(f"Updating lambda arg {v} and") | ||
print(f"re-passing self reference via partial {partial(method, v)}") | ||
setattr(v, method_name, partial(method, v)) | ||
self.partial.keywords[k] = v # Need to update when using dict iterator | ||
self.partial = partial( | ||
self.underlying_fn, *self.partial.args, **self.partial.keywords | ||
) | ||
self.to_reload = False | ||
|
||
def __call__(self, module_chain: List) -> Any: | ||
""" | ||
Args: | ||
module_chain: List[MatchboxModule] | ||
""" | ||
|
||
def _find_module_result(module_chain: List, uid: str) -> Any: | ||
for module in module_chain: | ||
if module.uid == uid: | ||
return module.result | ||
return None | ||
|
||
for i, arg in enumerate(self.partial.args): | ||
if isinstance(arg, MatchboxModule): | ||
self.partial.args[i] = _find_module_result(module_chain, arg.uid) | ||
for key, value in self.partial.keywords.items(): | ||
if isinstance(value, MatchboxModule): | ||
self.partial.keywords[key] = _find_module_result( | ||
module_chain, value.uid | ||
) | ||
return self.partial() | ||
|
||
def __str__(self) -> str: | ||
return self._str_rep | ||
|
||
@property | ||
def uid(self) -> str: | ||
return f"uid-{self._uid}" |
Oops, something went wrong.