Skip to content

Commit

Permalink
Merge pull request #227 from ndif-team/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
JadenFiotto-Kaufman authored Sep 3, 2024
2 parents 0f3340d + 77e71d7 commit 60bca6b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
24 changes: 20 additions & 4 deletions src/nnsight/contexts/GraphBasedContext.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,19 +253,31 @@ def global_patch(root, name: str) -> Patch:
@wraps(fn)
def inner(*args, **kwargs):

return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply(fn, *args, **kwargs)
return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply(
fn, *args, **kwargs
)

return Patch(root, inner, name)


def global_patch_class(cls: type) -> Patch:

if cls.__new__ is object.__new__:

def super_new(cls, *args, **kwargs):

return object.__new__(cls)

cls.__new__ = super_new

fn = cls.__new__

@wraps(fn)
def inner(cls, *args, **kwargs):

return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply(cls, *args, **kwargs)
return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply(
cls, *args, **kwargs
)

return Patch(cls, inner, "__new__")

Expand Down Expand Up @@ -398,7 +410,9 @@ def register(graph_based_context: GraphBasedContext) -> None:

assert GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph is None

GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph = graph_based_context.graph
GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph = (
graph_based_context.graph
)

GlobalTracingContext.TORCH_HANDLER.__enter__()
GlobalTracingContext.PATCHER.__enter__()
Expand Down Expand Up @@ -445,4 +459,6 @@ def __getattribute__(self, name: str) -> Any:


GlobalTracingContext.GLOBAL_TRACING_CONTEXT = GlobalTracingContext()
GlobalTracingContext.TORCH_HANDLER = GlobalTracingContext.GlobalTracingTorchHandler()
GlobalTracingContext.TORCH_HANDLER = (
GlobalTracingContext.GlobalTracingTorchHandler()
)
2 changes: 1 addition & 1 deletion src/nnsight/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
remote_logger = logging.getLogger("nnsight_remote")
remote_handler = logging.StreamHandler()
remote_handler.setFormatter(
logging.Formatter("%(asctime)s %(processName)-10s %(name)s %(levelname)-8s %(message)s")
logging.Formatter("%(asctime)s %(message)s")
)
remote_handler.setLevel(logging.INFO)
remote_logger.addHandler(remote_handler)
Expand Down

0 comments on commit 60bca6b

Please sign in to comment.