Skip to content

Commit

Permalink
Merge pull request #78 from JadenFiotto-Kaufman/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
JadenFiotto-Kaufman authored Feb 26, 2024
2 parents 7ee98ab + 63fe491 commit 24253c4
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 30 deletions.
6 changes: 4 additions & 2 deletions src/nnsight/contexts/Runner.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import io
import torch

import requests
import socketio
import torch
from tqdm import tqdm

from .. import CONFIG, pydantics
Expand Down Expand Up @@ -42,6 +43,7 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
if self.remote:
self.run_server()

self._graph.tracing = False
self._graph = None
else:
super().__exit__(exc_type, exc_val, exc_tb)
Expand Down Expand Up @@ -132,7 +134,7 @@ def blocking_request(self, request: pydantics.RequestModel):
response = requests.post(
f"https://{CONFIG.API.HOST}/request",
json=request.model_dump(exclude=["id", "received"]),
headers={'ndif-api-key' : CONFIG.API.APIKEY}
headers={"ndif-api-key": CONFIG.API.APIKEY},
)

if response.status_code == 200:
Expand Down
5 changes: 3 additions & 2 deletions src/nnsight/contexts/Tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(

self._kwargs = kwargs

self._graph = Graph(
self._graph: Graph = Graph(
self._model._envoy, proxy_class=model.proxy_class, validate=validate
)

Expand Down Expand Up @@ -63,14 +63,15 @@ def __enter__(self) -> Tracer:
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
    """Exit the tracing context: execute the recorded graph, then tear it down.

    If the body of the ``with`` block raised, that exception is re-raised
    immediately and the graph is never executed.

    Raises:
        BaseException: Whatever exception terminated the ``with`` block, if any.
    """
    if isinstance(exc_val, BaseException):
        raise exc_val

    # interleave() runs the model with the intervention graph applied.
    # Its return value was previously bound to an unused local (`output`);
    # don't bind it at all.
    self._model.interleave(
        self._model._execute,
        self._graph,
        *self._batched_input,
        **self._kwargs,
    )

    # Flip the tracing flag before dropping our reference so any surviving
    # proxies observing the graph via weakref see that tracing has ended.
    self._graph.tracing = False
    self._graph = None

def invoke(self, *inputs: Tuple[Any], **kwargs) -> Invoker:
Expand Down
38 changes: 26 additions & 12 deletions src/nnsight/intervention.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,14 @@ def grad(self) -> InterventionProxy:
if self._grad is None:

# We track how many times backward is called via an attribute on the Graph
if not hasattr(self.node.graph, 'n_backward_calls'):
if not hasattr(self.node.graph, "n_backward_calls"):

setattr(self.node.graph, 'n_backward_calls', 0)
setattr(self.node.graph, "n_backward_calls", 0)

self.__dict__["_grad"] = self.node.add(
value=self.node.proxy_value, target="grad", args=[self.node,self.node.graph.n_backward_calls]
value=self.node.proxy_value,
target="grad",
args=[self.node, self.node.graph.n_backward_calls],
)

return self._grad
Expand All @@ -106,26 +108,25 @@ def grad(self, value: Union[InterventionProxy, Any]) -> None:

def __call__(self, *args, **kwargs) -> Self:


# We don't want to call backward on fake tensors
if (
self.node.target is util.fetch_attr
and isinstance(self.node.args[1], str)
and self.node.args[1] == "backward"
):
# We track how many times backward is called via an attribute on the Graph
if not hasattr(self.node.graph, 'n_backward_calls'):
if not hasattr(self.node.graph, "n_backward_calls"):

setattr(self.node.graph, 'n_backward_calls', 0)
setattr(self.node.graph, "n_backward_calls", 0)

# Clear all .grad proxies
# Clear all .grad proxies
for node in self.node.graph.nodes.values():

try:

if node.proxy._grad is not None:

node.proxy.__dict__['_grad'] = None
node.proxy.__dict__["_grad"] = None

except ReferenceError:
pass
Expand All @@ -139,7 +140,6 @@ def __call__(self, *args, **kwargs) -> Self:
kwargs=kwargs,
)


return super().__call__(*args, **kwargs)

def __setattr__(
Expand All @@ -159,7 +159,7 @@ def shape(self) -> Collection[torch.Size]:
Union[torch.Size,Collection[torch.Size]]: Proxy value shape or collection of shapes.
"""

if self.node.is_graph_dereferenced():
if not self.node.is_tracing():

return util.apply(self.value, lambda x: x.shape, torch.Tensor)

Expand All @@ -170,15 +170,29 @@ def device(self) -> Collection[torch.device]:
"""Property to retrieve the device of the traced proxy value or real value.
Returns:
Union[torch.Size,Collection[torch.device]]: Proxy value shape or collection of shapes.
Union[torch.Size,Collection[torch.device]]: Proxy value device or collection of devices.
"""

if self.node.is_graph_dereferenced():
if not self.node.is_tracing():

return util.apply(self.value, lambda x: x.device, torch.Tensor)

return util.apply(self.node.proxy_value, lambda x: x.device, torch.Tensor)

@property
def dtype(self) -> Collection[torch.dtype]:
    """Property to retrieve the dtype of the traced proxy value or real value.

    Returns:
        Union[torch.dtype, Collection[torch.dtype]]: Proxy value dtype or collection of dtypes.
    """
    # BUG FIX: the annotation/docstring said `torch.device` / `torch.Size`,
    # copy-paste residue from the sibling `device` property — this is a dtype.

    # Once tracing is over the real value exists; inspect it directly.
    if not self.node.is_tracing():
        return util.apply(self.value, lambda x: x.dtype, torch.Tensor)

    # Still tracing: fall back to the fake/proxy value recorded on the node.
    return util.apply(self.node.proxy_value, lambda x: x.dtype, torch.Tensor)


def concat(
activations: Any,
Expand Down
3 changes: 3 additions & 0 deletions src/nnsight/tracing/Graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class Graph:
Attributes:
validate (bool): If to execute nodes as they are added with their proxy values in order to check if the executions are possible (i.e shape errors etc). Defaults to True.
proxy_class (Type[Proxy]): Proxy class to use. Defaults to Proxy.
tracing (bool): If currently tracing operations
nodes (Dict[str, Node]): Mapping of node name to node.
name_idx (Dict[str, int]): Mapping of node target_name to number of previous names with the same target_name.
Used so names are unique.
Expand All @@ -45,6 +46,8 @@ def __init__(
self.proxy_class = proxy_class
self.validate = validate

self.tracing = True

self.nodes: Dict[str, Node] = dict()
self.name_idx: Dict[str, int] = dict()

Expand Down
16 changes: 7 additions & 9 deletions src/nnsight/tracing/Node.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,21 +137,19 @@ def __init__(

self.execute()

def is_tracing(self) -> bool:
    """Checks to see if the weakref to the Graph is alive and tracing, or dead.

    (This span was a garbled diff interleaving the removed
    ``is_graph_dereferenced`` with the added ``is_tracing``; reconstructed here
    as the single post-change method.)

    Returns:
        bool: Is the Graph still tracing.
    """
    try:
        # self.graph is a weak reference; touching an attribute raises
        # ReferenceError once the Graph has been garbage collected.
        return self.graph.tracing
    except (ReferenceError, AttributeError):
        # Narrowed from a bare `except:` so that e.g. KeyboardInterrupt is
        # not silently swallowed. Dead weakref / missing graph => not tracing.
        return False

def add(
self,
Expand All @@ -167,7 +165,7 @@ def add(
Proxy: Proxy
"""

if self.is_graph_dereferenced():
if not self.is_tracing():

return Proxy(
Node(
Expand Down Expand Up @@ -285,7 +283,7 @@ def grad(value):

self.set_value(value)

if not self.is_graph_dereferenced():
if self.is_tracing():

value = self.graph.get_swap(value)

Expand Down
10 changes: 5 additions & 5 deletions src/nnsight/tracing/Proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def value(self) -> Any:

def __str__(self) -> str:

if self.node.is_graph_dereferenced():
if not self.node.is_tracing():

return str(self.value)

Expand All @@ -63,7 +63,7 @@ def __str__(self) -> str:

def __repr__(self) -> str:

if self.node.is_graph_dereferenced():
if not self.node.is_tracing():

return repr(self.value)

Expand Down Expand Up @@ -216,12 +216,12 @@ def __rtruediv__(self, other: Union[Proxy, Any]) -> Self:
args=[other, self.node],
)

def __index__(self) -> Self:
    # Pre-change version (removed by this commit): records an operator.index
    # operation on the tracing graph and returns a proxy for its result.
    # NOTE(review): Python expects __index__ to return a real int; returning a
    # proxy here presumably relied on callers never forcing concrete indexing
    # during tracing — the commit replaces this with a concrete delegation.
    return self.node.add(target=operator.index, args=[self.node])

def __bool__(self) -> bool:
    """Truthiness of the proxy, delegated to the underlying proxy value."""
    concrete_value = self.node.proxy_value
    return concrete_value.__bool__()

def __index__(self) -> int:
    """Return the concrete integer index of the underlying proxy value."""
    concrete_value = self.node.proxy_value
    return concrete_value.__index__()

def __instancecheck__(self, __instance: Any) -> bool:
    """Delegate isinstance() checks to the underlying proxy value."""
    concrete_value = self.node.proxy_value
    return concrete_value.__instancecheck__(__instance)

Expand Down

0 comments on commit 24253c4

Please sign in to comment.