diff --git a/reloading/reloading.py b/reloading/reloading.py index ef79636..2a8aa41 100644 --- a/reloading/reloading.py +++ b/reloading/reloading.py @@ -53,68 +53,62 @@ def reloading(fn_or_seq=None, every=1, forever=None): def unique_name(used): + # get the longest element of the used names and append a "0" return max(used, key=len) + "0" -def tuple_ast_as_name(tup): - if isinstance( - tup, ast.Name - ): # handle the case that there only is a single loop var - return tup.id +def format_itervars(ast_node): + """Formats a an `ast_node` of loop iteration variables as string, e.g. 'a, b'""" + + # handle the case that there only is a single loop var + if isinstance(ast_node, ast.Name): + return ast_node.id + names = [] - for child in tup.elts: + for child in ast_node.elts: if isinstance(child, ast.Name): names.append(child.id) elif isinstance(child, ast.Tuple): - names.append( - "({})".format(tuple_ast_as_name(child)) - ) # if its another tuple, like "a, (b, c)", recurse. + # if its another tuple, like "a, (b, c)", recurse + names.append("({})".format(format_itervars(child))) + return ", ".join(names) def load_file(path): src = "" - while ( - src == "" - ): # while loop here since while saving, the file may sometimes be empty. + # while loop here since while saving, the file may sometimes be empty. + while (src == ""): with open(path, "r") as f: src = f.read() return src + "\n" -def load_ast_parse(path): +def parse_file_until_successful(path): source = load_file(path) while True: try: tree = ast.parse(source) - break + return tree except SyntaxError: handle_exception(path) source = load_file(path) - return tree -def isolate_loop_ast(tree, lineno=None): - """Strip ast from anything but the loop body, also returning the loop vars.""" - for child in ast.walk(tree): - # i hope this is enough checks - if ( - getattr(child, "lineno", None) == lineno - and child.iter.func.id == "reloading" - ): - itervars = tuple_ast_as_name(child.target) - # replace the original body with the loop body - tree.body = child.body - return itervars +def isolate_loop_body_and_get_itervars(tree, lineno): + """Modifies tree inplace as unclear how to create ast.Module. + Returns itervars""" + for node in ast.walk(tree): + if (getattr(node, "lineno", None) == lineno and node.iter.func.id == "reloading"): + tree.body = node.body + return node.target def get_loop_code(loop_frame_info): fpath = loop_frame_info[1] - # find the loop body in the caller module's source - tree = load_ast_parse(fpath) - # same working principle as the functio nversion, strip the ast of everything but the loop body. - itervars = isolate_loop_ast(tree, lineno=loop_frame_info[2]) - return compile(tree, filename="", mode="exec"), itervars + tree = parse_file_until_successful(fpath) + itervars = isolate_loop_body_and_get_itervars(tree, lineno=loop_frame_info[2]) + return compile(tree, filename="", mode="exec"), format_itervars(itervars) def handle_exception(fpath): @@ -151,56 +145,55 @@ def _reloading_loop(seq, every=1): return [] -def ast_get_decorator_name(dec): - if hasattr(dec, "id"): - return dec.id - return dec.func.id +def get_decorator_name(dec_node): + if hasattr(dec_node, "id"): + return dec_node.id + return dec_node.func.id -def ast_filter_decorator(func): - """Filter out the reloading decorator, inplace.""" +def strip_reloading_decorator(func): + """Remove the reloading decorator in-place""" func.decorator_list = [ - dec for dec in func.decorator_list if ast_get_decorator_name(dec) != "reloading" + dec for dec in func.decorator_list if get_decorator_name(dec) != "reloading" ] -def isolate_func_ast(funcname, tree): - """Remove everything but the function definition from the ast.""" - for child in ast.walk(tree): +def isolate_function_def(funcname, tree): + """Strip everything but the function definition from the ast in-place. + Also strips the reloading decorator from the function definition""" + for node in ast.walk(tree): if ( - isinstance(child, ast.FunctionDef) - and child.name == funcname - and len( - [ - dec - for dec in child.decorator_list - if ast_get_decorator_name(dec) == "reloading" - ] - ) - == 1 + isinstance(node, ast.FunctionDef) + and node.name == funcname + and "reloading" in [ + get_decorator_name(dec) + for dec in node.decorator_list + ] ): - ast_filter_decorator(child) - tree.body = [ - child - ] # reassign body, i would create a new ast if i knew how to create ast.Module objects + strip_reloading_decorator(node) + tree.body = [ node ] + return True + return False def get_function_def_code(fpath, fn): - tree = load_ast_parse(fpath) - # these both work inplace and modify the ast - isolate_func_ast(fn.__name__, tree) + tree = parse_file_until_successful(fpath) + found = isolate_function_def(fn.__name__, tree) + if not found: + return None compiled = compile(tree, filename="", mode="exec") return compiled def get_reloaded_function(caller_globals, caller_locals, fpath, fn): code = get_function_def_code(fpath, fn) + if code is None: + return None # need to copy locals, otherwise the exec will overwrite the decorated with the undecorated new version # this became a need after removing the reloading decorator from the newly defined version - caller_locals = caller_locals.copy() - exec(code, caller_globals, caller_locals) - func = caller_locals[fn.__name__] - # get the newly defined function from the caller_locals copy + caller_locals_copy = caller_locals.copy() + exec(code, caller_globals, caller_locals_copy) + func = caller_locals_copy[fn.__name__] return func @@ -212,24 +205,21 @@ def _reloading_function(fn, every=1): # crutch to use dict as python2 doesn't support nonlocal state = { - "func": get_reloaded_function(caller_globals, caller_locals, fpath, fn), - "reloads": 1, + "func": None, + "reloads": 0, } def wrapped(*args, **kwargs): if state["reloads"] % every == 0: - state["func"] = get_reloaded_function(caller_globals, caller_locals, fpath, fn) + state["func"] = get_reloaded_function(caller_globals, caller_locals, fpath, fn) or state["func"] state["reloads"] += 1 while True: try: result = state["func"](*args, **kwargs) - break + return result except Exception: handle_exception(fpath) - state["func"] = get_reloaded_function( - caller_globals, caller_locals, fpath, fn - ) - return result + state["func"] = get_reloaded_function(caller_globals, caller_locals, fpath, fn) or state["func"] caller_locals[fn.__name__] = wrapped - return wrapped \ No newline at end of file + return wrapped