Skip to content

Commit

Permalink
resolve merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
Bizordec committed Oct 19, 2024
2 parents ed8c39c + a9f5869 commit 1ea3ea3
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 123 deletions.
2 changes: 1 addition & 1 deletion src/slipcover/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def merge_files(args):
with args.out.open("w", encoding='utf-8') as jf:
json.dump(merged, jf)
except Exception as e:
warnings.warn(e)
warnings.warn(str(e))
return 1

return 0
Expand Down
49 changes: 29 additions & 20 deletions src/slipcover/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ def encode_branch(from_line, to_line):
def decode_branch(line):
return ((line>>15)&0x7FFF, line&0x7FFF)

EXIT = 0

def preinstrument(tree: ast.AST) -> ast.AST:
def preinstrument(tree: ast.Module) -> ast.Module:
"""Prepares an AST for Slipcover instrumentation, inserting assignments indicating where branches happen."""

class SlipcoverTransformer(ast.NodeTransformer):
Expand All @@ -30,10 +31,10 @@ def _mark_branch(self, from_line: int, to_line: int) -> List[ast.stmt]:
# Using a constant Expr allows the compiler to optimize this to a NOP
mark = ast.Expr(ast.Constant(None))
for node in ast.walk(mark):
node.lineno = node.end_lineno = encode_branch(from_line, to_line)
node.lineno = node.end_lineno = encode_branch(from_line, to_line) # type: ignore[attr-defined]
# Leaving the columns unitialized can lead to invalid positions despite
# our use of ast.fix_missing_locations
node.col_offset = node.end_col_offset = -1
node.col_offset = node.end_col_offset = -1 # type: ignore[attr-defined]
else:
mark = ast.Assign([ast.Name(BRANCH_NAME, ast.Store())],
ast.Tuple([ast.Constant(from_line), ast.Constant(to_line)], ast.Load()))
Expand All @@ -42,7 +43,7 @@ def _mark_branch(self, from_line: int, to_line: int) -> List[ast.stmt]:
node.lineno = 0 # we ignore line 0, so this avoids generating extra line probes
else:
for node in ast.walk(mark):
node.lineno = from_line
node.lineno = from_line # type: ignore[attr-defined]

return [mark]

Expand All @@ -58,13 +59,13 @@ def visit_FunctionDef(self, node: Union[ast.AsyncFunctionDef, ast.FunctionDef])
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
return self.visit_FunctionDef(node)

def _mark_branches(self, node: ast.AST) -> ast.AST:
def _mark_branches(self, node: Union[ast.If, ast.For, ast.AsyncFor, ast.While]) -> ast.AST:
node.body = self._mark_branch(node.lineno, node.body[0].lineno) + node.body

if node.orelse:
node.orelse = self._mark_branch(node.lineno, node.orelse[0].lineno) + node.orelse
else:
to_line = node.next_node.lineno if node.next_node else 0 # exit
to_line = node.next_node.lineno if node.next_node else EXIT # type: ignore[union-attr]
node.orelse = self._mark_branch(node.lineno, to_line)

super().generic_visit(node)
Expand Down Expand Up @@ -93,59 +94,67 @@ def visit_Match(self, node: ast.Match) -> ast.Match:

has_wildcard = case.guard is None and isinstance(last_pattern, ast.MatchAs) and last_pattern.pattern is None
if not has_wildcard:
to_line = node.next_node.lineno if node.next_node else 0 # exit
to_line = node.next_node.lineno if node.next_node else EXIT # type: ignore[attr-defined]
node.cases.append(ast.match_case(ast.MatchAs(),
body=self._mark_branch(node.lineno, to_line)))

super().generic_visit(node)
return node


match_type = ast.Match if sys.version_info >= (3,10) else tuple() # empty tuple matches nothing
try_type = (ast.Try, ast.TryStar) if sys.version_info >= (3,11) else ast.Try
if sys.version_info >= (3,10):
match_type = ast.Match
else:
match_type = tuple() # matches nothing

if sys.version_info >= (3,11):
try_type = (ast.Try, ast.TryStar)
else:
try_type = ast.Try

# Compute the "next" statement in case a branch flows control out of a node.
# We need a parent node's "next" computed before its siblings, so we compute it here, in BFS;
# note that visit() doesn't guarantee any specific order.
tree.next_node = None
tree.next_node = None # type: ignore[attr-defined]
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
# no next node, yields (..., 0), i.e., "->exit" branch
node.next_node = None
node.next_node = None # type: ignore[union-attr]

for name, field in ast.iter_fields(node):
if isinstance(field, ast.AST):
# if a field is just a node, any execution continues after our node
field.next_node = node.next_node
field.next_node = node.next_node # type: ignore[attr-defined]
elif isinstance(node, match_type) and name == 'cases':
# each case continues after the 'match'
for item in field:
item.next_node = node.next_node
item.next_node = node.next_node # type: ignore[attr-defined]
elif isinstance(node, try_type) and name == 'handlers':
# each 'except' continues either in 'finally', or after the 'try'
for h in field:
h.next_node = node.finalbody[0] if node.finalbody else node.next_node
h.next_node = node.finalbody[0] if node.finalbody else node.next_node # type: ignore[attr-defined,union-attr]
elif isinstance(field, list):
# if a field is a list, each item but the last one continues with the next item
prev = None
for item in field:
if isinstance(item, ast.AST):
if prev:
prev.next_node = item
prev.next_node = item # type: ignore[attr-defined]
prev = item

if prev:
if isinstance(node, (ast.For, ast.While)):
prev.next_node = node # loops back
# loops back
prev.next_node = node # type: ignore[attr-defined]
elif isinstance(node, try_type) and (name in ('body', 'orelse')):
if name == 'body' and node.orelse:
prev.next_node = node.orelse[0]
prev.next_node = node.orelse[0] # type: ignore[attr-defined]
elif node.finalbody:
prev.next_node = node.finalbody[0]
prev.next_node = node.finalbody[0] # type: ignore[attr-defined]
else:
prev.next_node = node.next_node
prev.next_node = node.next_node # type: ignore[attr-defined, union-attr]
else:
prev.next_node = node.next_node
prev.next_node = node.next_node # type: ignore[attr-defined]

tree = SlipcoverTransformer().visit(tree)
ast.fix_missing_locations(tree)
Expand Down
14 changes: 7 additions & 7 deletions src/slipcover/bytecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
import dis
import types
from typing import List, Tuple
from typing import List, Tuple, Iterator

# FIXME provide __all__

Expand Down Expand Up @@ -34,7 +34,7 @@ def branch2offset(arg: int) -> int:
op_PRECALL = dis.opmap["PRECALL"]
op_CALL = dis.opmap["CALL"]
op_CACHE = dis.opmap["CACHE"]
is_EXTENDED_ARG.append(dis._all_opmap["EXTENDED_ARG_QUICK"])
is_EXTENDED_ARG.append(dis._all_opmap["EXTENDED_ARG_QUICK"]) # type: ignore[attr-defined]
else:
op_RESUME = None
op_PUSH_NULL = None
Expand Down Expand Up @@ -63,11 +63,11 @@ def opcode_arg(opcode: int, arg: int, min_ext : int = 0) -> List[int]:
)
bytecode.extend([opcode, arg & 0xFF])
if sys.version_info >= (3,11):
bytecode.extend([op_CACHE, 0] * dis._inline_cache_entries[opcode])
bytecode.extend([op_CACHE, 0] * dis._inline_cache_entries[opcode]) # type: ignore[attr-defined]
return bytecode


def unpack_opargs(code: bytes) -> Tuple[int, int, int, int]:
def unpack_opargs(code: bytes) -> Iterator[Tuple[int, int, int, int]]:
"""Unpacks opcodes and their arguments, returning:
- the beginning offset, including that of the first EXTENDED_ARG, if any
Expand Down Expand Up @@ -444,8 +444,8 @@ def __init__(self, code):
self.patch = None

self.branches = None
self.ex_table = None
self.lines = None
self.ex_table = []
self.lines = []
self.inserts = []

self.max_addtl_stack = 0
Expand Down Expand Up @@ -641,7 +641,7 @@ def finish(self):
if not self.patch and not self.consts:
return self.orig_code

replace = {}
replace : dict = {}
if self.consts is not None:
replace["co_consts"] = tuple(self.consts)

Expand Down
25 changes: 14 additions & 11 deletions src/slipcover/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import sys
import sysconfig

from importlib.abc import MetaPathFinder, Loader
from importlib.abc import MetaPathFinder, Loader, ResourceReader
from importlib import machinery


Expand All @@ -29,14 +29,17 @@ def __init__(self, sci: Slipcover, orig_loader: Loader, origin: str):
delattr(self, "get_resource_reader")

# for compability with loaders supporting resources, used e.g. by sklearn
def get_resource_reader(self, fullname: str):
return self.orig_loader.get_resource_reader(fullname)
def get_resource_reader(self, fullname: str) -> Optional[ResourceReader]:
# FIXME deprecated in Python 3.12
if hasattr(self.orig_loader, 'get_resource_reader'):
return self.orig_loader.get_resource_reader(fullname)
return None

def create_module(self, spec):
return self.orig_loader.create_module(spec)

def get_code(self, name): # expected by pyrun
return self.orig_loader.get_code(name)
return self.orig_loader.get_code(name) # type: ignore[attr-defined]

def exec_module(self, module):
import ast
Expand All @@ -45,7 +48,7 @@ def exec_module(self, module):
t = br.preinstrument(ast.parse(self.origin.read_bytes()))
code = compile(t, str(self.origin), "exec")
else:
code = self.orig_loader.get_code(module.__name__)
code = self.orig_loader.get_code(module.__name__) # type: ignore[attr-defined]

self.sci.register_module(module)
code = self.sci.instrument(code)
Expand Down Expand Up @@ -133,7 +136,7 @@ def find_spec(self, fullname, path, target=None):
if isinstance(spec.loader, machinery.ExtensionFileLoader):
return None

if self.file_matcher.matches(spec.origin):
if spec.origin and self.file_matcher.matches(spec.origin):
if self.debug:
print(f"instrumenting {fullname} from {spec.origin}")
spec.loader = SlipcoverLoader(self.sci, spec.loader, spec.origin)
Expand All @@ -146,7 +149,7 @@ def find_spec(self, fullname, path, target=None):
class ImportManager:
"""A context manager that enables instrumentation while active."""

def __init__(self, sci: Slipcover, file_matcher: FileMatcher = None, debug: bool = False):
def __init__(self, sci: Slipcover, file_matcher: Optional[FileMatcher] = None, debug: bool = False):
self.mpf = SlipcoverMetaPathFinder(sci, file_matcher if file_matcher else MatchEverything(), debug)

def __enter__(self) -> "ImportManager":
Expand Down Expand Up @@ -196,7 +199,7 @@ def find_replacements(co):

find_replacements(code)

visited = set()
visited : set = set()
for f in Slipcover.find_functions(module.__dict__.values(), visited):
if (repl := replacement.get(f.__code__.co_name, None)):
assert f.__code__.co_firstlineno == repl.co_firstlineno # sanity check
Expand All @@ -215,7 +218,7 @@ def exec_wrapper(obj, g):
obj = sci.instrument(obj)
exec(obj, g)

pyrewrite._Slipcover_exec_wrapper = exec_wrapper
pyrewrite._Slipcover_exec_wrapper = exec_wrapper # type: ignore[attr-defined]

if sci.branch:
import inspect
Expand Down Expand Up @@ -249,11 +252,11 @@ def adjust_name(fn : Path) -> Path:

orig_read_pyc = pyrewrite._read_pyc
def read_pyc(*args, **kwargs):
return orig_read_pyc(*args[:1], adjust_name(args[1]), *args[2:], **kwargs)
return orig_read_pyc(*args[:1], adjust_name(args[1]), *args[2:], **kwargs) # type: ignore[call-arg]

orig_write_pyc = pyrewrite._write_pyc
def write_pyc(*args, **kwargs):
return orig_write_pyc(*args[:3], adjust_name(args[3]), *args[4:], **kwargs)
return orig_write_pyc(*args[:3], adjust_name(args[3]), *args[4:], **kwargs) # type: ignore[call-arg]

pyrewrite._read_pyc = read_pyc
pyrewrite._write_pyc = write_pyc
Expand Down
Loading

0 comments on commit 1ea3ea3

Please sign in to comment.