Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
julvo committed Jun 3, 2021
1 parent 67658a9 commit d5fbe65
Showing 1 changed file with 62 additions and 72 deletions.
134 changes: 62 additions & 72 deletions reloading/reloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,68 +53,62 @@ def reloading(fn_or_seq=None, every=1, forever=None):


def unique_name(used):
# get the longest element of the used names and append a "0"
return max(used, key=len) + "0"


def tuple_ast_as_name(tup):
if isinstance(
tup, ast.Name
): # handle the case that there only is a single loop var
return tup.id
def format_itervars(ast_node):
"""Formats a an `ast_node` of loop iteration variables as string, e.g. 'a, b'"""

# handle the case that there only is a single loop var
if isinstance(ast_node, ast.Name):
return ast_node.id

names = []
for child in tup.elts:
for child in ast_node.elts:
if isinstance(child, ast.Name):
names.append(child.id)
elif isinstance(child, ast.Tuple):
names.append(
"({})".format(tuple_ast_as_name(child))
) # if its another tuple, like "a, (b, c)", recurse.
# if its another tuple, like "a, (b, c)", recurse
names.append("({})".format(format_itervars(child)))

return ", ".join(names)


def load_file(path):
src = ""
while (
src == ""
): # while loop here since while saving, the file may sometimes be empty.
# while loop here since while saving, the file may sometimes be empty.
while (src == ""):
with open(path, "r") as f:
src = f.read()
return src + "\n"


def load_ast_parse(path):
def parse_file_until_successful(path):
source = load_file(path)
while True:
try:
tree = ast.parse(source)
break
return tree
except SyntaxError:
handle_exception(path)
source = load_file(path)
return tree


def isolate_loop_ast(tree, lineno=None):
"""Strip ast from anything but the loop body, also returning the loop vars."""
for child in ast.walk(tree):
# i hope this is enough checks
if (
getattr(child, "lineno", None) == lineno
and child.iter.func.id == "reloading"
):
itervars = tuple_ast_as_name(child.target)
# replace the original body with the loop body
tree.body = child.body
return itervars
def isolate_loop_body_and_get_itervars(tree, lineno):
"""Modifies tree inplace as unclear how to create ast.Module.
Returns itervars"""
for node in ast.walk(tree):
if (getattr(node, "lineno", None) == lineno and node.iter.func.id == "reloading"):
tree.body = node.body
return node.target


def get_loop_code(loop_frame_info):
fpath = loop_frame_info[1]
# find the loop body in the caller module's source
tree = load_ast_parse(fpath)
# same working principle as the functio nversion, strip the ast of everything but the loop body.
itervars = isolate_loop_ast(tree, lineno=loop_frame_info[2])
return compile(tree, filename="", mode="exec"), itervars
tree = parse_file_until_successful(fpath)
itervars = isolate_loop_body_and_get_itervars(tree, lineno=loop_frame_info[2])
return compile(tree, filename="", mode="exec"), format_itervars(itervars)


def handle_exception(fpath):
Expand Down Expand Up @@ -151,56 +145,55 @@ def _reloading_loop(seq, every=1):
return []


def ast_get_decorator_name(dec):
if hasattr(dec, "id"):
return dec.id
return dec.func.id
def get_decorator_name(dec_node):
if hasattr(dec_node, "id"):
return dec_node.id
return dec_node.func.id


def ast_filter_decorator(func):
"""Filter out the reloading decorator, inplace."""
def strip_reloading_decorator(func):
"""Remove the reloading decorator in-place"""
func.decorator_list = [
dec for dec in func.decorator_list if ast_get_decorator_name(dec) != "reloading"
dec for dec in func.decorator_list if get_decorator_name(dec) != "reloading"
]


def isolate_func_ast(funcname, tree):
"""Remove everything but the function definition from the ast."""
for child in ast.walk(tree):
def isolate_function_def(funcname, tree):
"""Strip everything but the function definition from the ast in-place.
Also strips the reloading decorator from the function definition"""
for node in ast.walk(tree):
if (
isinstance(child, ast.FunctionDef)
and child.name == funcname
and len(
[
dec
for dec in child.decorator_list
if ast_get_decorator_name(dec) == "reloading"
]
)
== 1
isinstance(node, ast.FunctionDef)
and node.name == funcname
and "reloading" in [
get_decorator_name(dec)
for dec in node.decorator_list
]
):
ast_filter_decorator(child)
tree.body = [
child
] # reassign body, i would create a new ast if i knew how to create ast.Module objects
strip_reloading_decorator(node)
tree.body = [ node ]
return True
return False


def get_function_def_code(fpath, fn):
tree = load_ast_parse(fpath)
# these both work inplace and modify the ast
isolate_func_ast(fn.__name__, tree)
tree = parse_file_until_successful(fpath)
found = isolate_function_def(fn.__name__, tree)
if not found:
return None
compiled = compile(tree, filename="", mode="exec")
return compiled


def get_reloaded_function(caller_globals, caller_locals, fpath, fn):
code = get_function_def_code(fpath, fn)
if code is None:
return None
# need to copy locals, otherwise the exec will overwrite the decorated with the undecorated new version
# this became a need after removing the reloading decorator from the newly defined version
caller_locals = caller_locals.copy()
exec(code, caller_globals, caller_locals)
func = caller_locals[fn.__name__]
# get the newly defined function from the caller_locals copy
caller_locals_copy = caller_locals.copy()
exec(code, caller_globals, caller_locals_copy)
func = caller_locals_copy[fn.__name__]
return func


Expand All @@ -212,24 +205,21 @@ def _reloading_function(fn, every=1):

# crutch to use dict as python2 doesn't support nonlocal
state = {
"func": get_reloaded_function(caller_globals, caller_locals, fpath, fn),
"reloads": 1,
"func": None,
"reloads": 0,
}

def wrapped(*args, **kwargs):
if state["reloads"] % every == 0:
state["func"] = get_reloaded_function(caller_globals, caller_locals, fpath, fn)
state["func"] = get_reloaded_function(caller_globals, caller_locals, fpath, fn) or state["func"]
state["reloads"] += 1
while True:
try:
result = state["func"](*args, **kwargs)
break
return result
except Exception:
handle_exception(fpath)
state["func"] = get_reloaded_function(
caller_globals, caller_locals, fpath, fn
)
return result
state["func"] = get_reloaded_function(caller_globals, caller_locals, fpath, fn) or state["func"]

caller_locals[fn.__name__] = wrapped
return wrapped
return wrapped

0 comments on commit d5fbe65

Please sign in to comment.