diff --git a/src/nnsight/contexts/Runner.py b/src/nnsight/contexts/Runner.py
index 31366502..59428888 100644
--- a/src/nnsight/contexts/Runner.py
+++ b/src/nnsight/contexts/Runner.py
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
 import io
-import torch
+
 import requests
 import socketio
+import torch
 from tqdm import tqdm
 
 from .. import CONFIG, pydantics
@@ -42,6 +43,7 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         if self.remote:
             self.run_server()
 
+            self._graph.tracing = False
             self._graph = None
         else:
             super().__exit__(exc_type, exc_val, exc_tb)
@@ -132,7 +134,7 @@ def blocking_request(self, request: pydantics.RequestModel):
         response = requests.post(
             f"https://{CONFIG.API.HOST}/request",
             json=request.model_dump(exclude=["id", "received"]),
-            headers={'ndif-api-key' : CONFIG.API.APIKEY}
+            headers={"ndif-api-key": CONFIG.API.APIKEY},
         )
 
         if response.status_code == 200:
diff --git a/src/nnsight/contexts/Tracer.py b/src/nnsight/contexts/Tracer.py
index 2110774b..0cc4405e 100644
--- a/src/nnsight/contexts/Tracer.py
+++ b/src/nnsight/contexts/Tracer.py
@@ -35,7 +35,7 @@ def __init__(
 
         self._kwargs = kwargs
 
-        self._graph = Graph(
+        self._graph: Graph = Graph(
             self._model._envoy, proxy_class=model.proxy_class, validate=validate
         )
 
@@ -63,7 +63,7 @@ def __enter__(self) -> Tracer:
 
     def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         if isinstance(exc_val, BaseException):
             raise exc_val
-        
+
         output = self._model.interleave(
@@ -71,6 +71,7 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
             self._model._execute,
             self._graph,
             **self._kwargs,
         )
 
+        self._graph.tracing = False
         self._graph = None
 
     def invoke(self, *inputs: Tuple[Any], **kwargs) -> Invoker:
diff --git a/src/nnsight/intervention.py b/src/nnsight/intervention.py
index c2b2af8f..5977e68a 100644
--- a/src/nnsight/intervention.py
+++ b/src/nnsight/intervention.py
@@ -84,12 +84,14 @@ def grad(self) -> InterventionProxy:
         if self._grad is None:
 
             # We track how many times backward is called via an attribute on the Graph
-            if not hasattr(self.node.graph, 'n_backward_calls'):
+            if not hasattr(self.node.graph, "n_backward_calls"):
 
-                setattr(self.node.graph, 'n_backward_calls', 0)
+                setattr(self.node.graph, "n_backward_calls", 0)
 
             self.__dict__["_grad"] = self.node.add(
-                value=self.node.proxy_value, target="grad", args=[self.node,self.node.graph.n_backward_calls]
+                value=self.node.proxy_value,
+                target="grad",
+                args=[self.node, self.node.graph.n_backward_calls],
             )
 
         return self._grad
@@ -106,7 +108,6 @@ def grad(self, value: Union[InterventionProxy, Any]) -> None:
 
     def __call__(self, *args, **kwargs) -> Self:
-
         # We don't want to call backward on fake tensors
         if (
             self.node.target is util.fetch_attr
@@ -114,18 +115,18 @@ def __call__(self, *args, **kwargs) -> Self:
             and self.node.args[1] == "backward"
         ):
             # We track how many times backward is called via an attribute on the Graph
-            if not hasattr(self.node.graph, 'n_backward_calls'):
+            if not hasattr(self.node.graph, "n_backward_calls"):
 
-                setattr(self.node.graph, 'n_backward_calls', 0)
+                setattr(self.node.graph, "n_backward_calls", 0)
 
-            # Clear all .grad proxies 
+            # Clear all .grad proxies
             for node in self.node.graph.nodes.values():
 
                 try:
 
                     if node.proxy._grad is not None:
 
-                        node.proxy.__dict__['_grad'] = None
+                        node.proxy.__dict__["_grad"] = None
 
                 except ReferenceError:
 
                     pass
@@ -139,7 +140,6 @@ def __call__(self, *args, **kwargs) -> Self:
                 kwargs=kwargs,
             )
 
-
         return super().__call__(*args, **kwargs)
 
     def __setattr__(
@@ -159,7 +159,7 @@ def shape(self) -> Collection[torch.Size]:
             Union[torch.Size,Collection[torch.Size]]: Proxy value shape or collection of shapes.
         """
 
-        if self.node.is_graph_dereferenced():
+        if not self.node.is_tracing():
 
             return util.apply(self.value, lambda x: x.shape, torch.Tensor)
 
@@ -170,15 +170,29 @@ def device(self) -> Collection[torch.device]:
         """Property to retrieve the device of the traced proxy value or real value.
 
         Returns:
-            Union[torch.Size,Collection[torch.device]]: Proxy value shape or collection of shapes.
+            Union[torch.device, Collection[torch.device]]: Proxy value device or collection of devices.
         """
 
-        if self.node.is_graph_dereferenced():
+        if not self.node.is_tracing():
 
             return util.apply(self.value, lambda x: x.device, torch.Tensor)
 
         return util.apply(self.node.proxy_value, lambda x: x.device, torch.Tensor)
 
+    @property
+    def dtype(self) -> Collection[torch.dtype]:
+        """Property to retrieve the dtype of the traced proxy value or real value.
+
+        Returns:
+            Union[torch.dtype, Collection[torch.dtype]]: Proxy value dtype or collection of dtypes.
+        """
+
+        if not self.node.is_tracing():
+
+            return util.apply(self.value, lambda x: x.dtype, torch.Tensor)
+
+        return util.apply(self.node.proxy_value, lambda x: x.dtype, torch.Tensor)
+
 
 def concat(
     activations: Any,
diff --git a/src/nnsight/tracing/Graph.py b/src/nnsight/tracing/Graph.py
index 3c63e00a..d8647434 100644
--- a/src/nnsight/tracing/Graph.py
+++ b/src/nnsight/tracing/Graph.py
@@ -26,6 +26,7 @@ class Graph:
     Attributes:
         validate (bool): If to execute nodes as they are added with their proxy values in order to check if the executions are possible (i.e shape errors etc). Defaults to True.
         proxy_class (Type[Proxy]): Proxy class to use. Defaults to Proxy.
+        tracing (bool): If currently tracing operations.
         nodes (Dict[str, Node]): Mapping of node name to node.
         name_idx (Dict[str, int]): Mapping of node target_name to number of previous names with the same target_name. Used so names are unique.
 
@@ -45,6 +46,8 @@ def __init__(
         self.proxy_class = proxy_class
         self.validate = validate
 
+        self.tracing = True
+
         self.nodes: Dict[str, Node] = dict()
         self.name_idx: Dict[str, int] = dict()
 
diff --git a/src/nnsight/tracing/Node.py b/src/nnsight/tracing/Node.py
index 3a789e14..aff0dd77 100644
--- a/src/nnsight/tracing/Node.py
+++ b/src/nnsight/tracing/Node.py
@@ -137,21 +137,19 @@ def __init__(
 
             self.execute()
 
-    def is_graph_dereferenced(self) -> bool:
-        """Checks to see if the weakref to the Graph is deleted. If it is, we're no longer tracing.
+    def is_tracing(self) -> bool:
+        """Checks to see if the weakref to the Graph is alive and the Graph is still tracing.
 
         Returns:
-            bool: Is Graph dereferenced.
+            bool: Is Graph tracing.
""" try: - self.graph.add + return self.graph.tracing except: - return True - - return False + return False def add( self, @@ -167,7 +165,7 @@ def add( Proxy: Proxy """ - if self.is_graph_dereferenced(): + if not self.is_tracing(): return Proxy( Node( @@ -285,7 +283,7 @@ def grad(value): self.set_value(value) - if not self.is_graph_dereferenced(): + if self.is_tracing(): value = self.graph.get_swap(value) diff --git a/src/nnsight/tracing/Proxy.py b/src/nnsight/tracing/Proxy.py index 002a79e4..705d51b0 100644 --- a/src/nnsight/tracing/Proxy.py +++ b/src/nnsight/tracing/Proxy.py @@ -53,7 +53,7 @@ def value(self) -> Any: def __str__(self) -> str: - if self.node.is_graph_dereferenced(): + if not self.node.is_tracing(): return str(self.value) @@ -63,7 +63,7 @@ def __str__(self) -> str: def __repr__(self) -> str: - if self.node.is_graph_dereferenced(): + if not self.node.is_tracing(): return repr(self.value) @@ -216,12 +216,12 @@ def __rtruediv__(self, other: Union[Proxy, Any]) -> Self: args=[other, self.node], ) + def __index__(self) -> Self: + return self.node.add(target=operator.index, args=[self.node]) + def __bool__(self) -> bool: return self.node.proxy_value.__bool__() - def __index__(self) -> int: - return self.node.proxy_value.__index__() - def __instancecheck__(self, __instance: Any) -> bool: return self.node.proxy_value.__instancecheck__(__instance)