Skip to content

Commit

Permalink
expose config options for globals tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
amakelov committed Aug 21, 2024
1 parent d43e302 commit 000ea27
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 8 deletions.
17 changes: 12 additions & 5 deletions mandala/deps/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,10 @@ def __init__(
self._representation = representation

@staticmethod
def from_obj(obj: Any, dep_key: DepKey) -> "GlobalVarNode":
representation = GlobalVarNode.represent(obj=obj)
def from_obj(obj: Any, dep_key: DepKey,
skip_unhashable: bool = False,
skip_silently: bool = False,) -> "GlobalVarNode":
representation = GlobalVarNode.represent(obj=obj, skip_unhashable=skip_unhashable, skip_silently=skip_silently)
return GlobalVarNode(
module_name=dep_key[0],
obj_name=dep_key[1],
Expand All @@ -199,7 +201,9 @@ def representation(self) -> Tuple[str, str]:
return self._representation

@staticmethod
def represent(obj: Any, allow_fallback: bool = True) -> Tuple[str, str]:
def represent(obj: Any, skip_unhashable: bool = True,
skip_silently: bool = False,
) -> Tuple[str, str]:
"""
Return a hash of this global variable's value + a truncated
representation useful for debugging/printing.
Expand All @@ -217,9 +221,12 @@ def represent(obj: Any, allow_fallback: bool = True) -> Tuple[str, str]:
except Exception as e:
shortened_exception = textwrap.shorten(text=str(e), width=80)
msg = f"Failed to hash global variable {truncated_repr} of type {type(obj)}, because {shortened_exception}"
if allow_fallback:
if skip_unhashable:
content_hash = UNKNOWN_GLOBAL_VAR
logger.warning(msg)
if skip_silently:
logger.debug(msg)
else:
logger.warning(msg)
else:
raise RuntimeError(msg)
return content_hash, truncated_repr
Expand Down
8 changes: 7 additions & 1 deletion mandala/deps/tracers/dec_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,15 @@ def __init__(
graph: Optional[DependencyGraph] = None,
strict: bool = True,
allow_methods: bool = False,
skip_unhashable_globals: bool = True,
skip_globals_silently: bool = False,
):
self.call_stack: List[CallableNode] = []
self.graph = DependencyGraph() if graph is None else graph
self.paths = paths
self.strict = strict
self.skip_unhashable_globals = skip_unhashable_globals
self.skip_globals_silently = skip_globals_silently
self.allow_methods = allow_methods

self._traced = {}
Expand Down Expand Up @@ -248,7 +252,9 @@ def register_global_access(self, key: str, value: Any):
assert len(self.call_stack) > 0
calling_node = self.call_stack[-1]
node = GlobalVarNode.from_obj(
obj=value, dep_key=(calling_node.module_name, key)
obj=value, dep_key=(calling_node.module_name, key),
skip_unhashable=self.skip_unhashable_globals,
skip_silently=self.skip_globals_silently
)
self.graph.add_edge(calling_node, node)

Expand Down
10 changes: 8 additions & 2 deletions mandala/deps/versioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,18 @@ def __init__(
self,
paths: List[Path],
TracerCls: type,
strict,
track_methods,
strict: bool,
skip_unhashable_globals: bool,
skip_globals_silently: bool,
track_methods: bool,
package_name: Optional[str] = None,
):
assert len(paths) in [0, 1]
self.paths = paths
self.TracerCls = TracerCls
self.strict = strict
self.skip_unhashable_globals = skip_unhashable_globals
self.skip_globals_silently = skip_globals_silently
self.allow_methods = track_methods
self.package_name = package_name
self.global_topology: DependencyGraph = DependencyGraph()
Expand Down Expand Up @@ -116,6 +120,8 @@ def make_tracer(self) -> TracerABC:
paths=[Config.mandala_path] + self.paths,
strict=self.strict,
allow_methods=self.allow_methods,
skip_unhashable_globals=self.skip_unhashable_globals,
skip_globals_silently=self.skip_globals_silently,
)

def guess_code_state(self) -> CodeState:
Expand Down
4 changes: 4 additions & 0 deletions mandala/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def __init__(self, db_path: str = ":memory:",
deps_path: Optional[Union[str, Path]] = None,
tracer_impl: Optional[type] = None,
strict_tracing: bool = False,
skip_unhashable_globals: bool = True,
skip_globals_silently: bool = False,
deps_package: Optional[str] = None,
):
self.db = DBAdapter(db_path=db_path)
Expand Down Expand Up @@ -74,6 +76,8 @@ def __init__(self, db_path: str = ":memory:",
strict=strict_tracing,
track_methods=True,
package_name=deps_package,
skip_unhashable_globals=skip_unhashable_globals,
skip_globals_silently=skip_globals_silently
)
self.sources["versioner"] = versioner
else:
Expand Down

0 comments on commit 000ea27

Please sign in to comment.