From b3c7aaf12ad56234a3b066e63e05349d3ec4a96f Mon Sep 17 00:00:00 2001
From: JadenFiottoKaufman
Date: Fri, 30 Aug 2024 17:31:24 -0400
Subject: [PATCH 1/5] Update walkthrough link

---
 docs/source/notebooks/tutorials/walkthrough.ipynb | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/notebooks/tutorials/walkthrough.ipynb b/docs/source/notebooks/tutorials/walkthrough.ipynb
index c90806bc..978e4c13 100644
--- a/docs/source/notebooks/tutorials/walkthrough.ipynb
+++ b/docs/source/notebooks/tutorials/walkthrough.ipynb
@@ -14,7 +14,7 @@
    },
    "source": [
     "An interactive version of this walkthrough can be found\n",
-    "[here](https://colab.research.google.com/github/ndif-team/nnsight/blob/main/NNsight_Walkthough.ipynb)\n",
+    "[here](https://colab.research.google.com/github/ndif-team/nnsight/blob/main/NNsight_Walkthrough.ipynb)\n",
     "\n",
     "In this era of large-scale deep learning, the most interesting AI models are\n",
     "massive black boxes that are hard to run. Ordinary commercial inference service\n",

From 7a5bbff8963f16e091614d3be25cd3a3825a7b96 Mon Sep 17 00:00:00 2001
From: "jadenfk@outlook.com"
Date: Sun, 1 Sep 2024 17:37:44 -0400
Subject: [PATCH 2/5] When globally patching a class (vs. a function), we need
 to 1) patch the __init__ method, not the class itself, to preserve type
 checks like isinstance, and 2) still make sure the fn applied by nnsight
 tracing is the class, not the __init__ method

---
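A minimal, self-contained sketch of the isinstance problem this commit fixes.
It is not part of the patch: make_wrapper is a hypothetical stand-in for
global_patch with the tracing machinery stripped out. Replacing the class
attribute with a wrapper function (the old behavior) makes any
isinstance(x, torch.nn.Parameter) check inside traced code raise TypeError,
while swapping only __init__ keeps torch.nn.Parameter a real class.

import torch

def make_wrapper(target):
    # Forward construction to `target`, as the traced wrapper does.
    def inner(*args, **kwargs):
        return target(*args, **kwargs)
    return inner

# Old behavior: torch.nn.Parameter itself is replaced by a function.
saved = torch.nn.Parameter
torch.nn.Parameter = make_wrapper(saved)
try:
    isinstance(torch.ones(1), torch.nn.Parameter)
except TypeError as error:
    print("broken:", error)  # isinstance() arg 2 must be a type
finally:
    torch.nn.Parameter = saved

# New behavior: only __init__ is wrapped (with the class itself handed to the
# tracer as applied_fn), so the attribute stays a class and isinstance works.
param = torch.nn.Parameter(torch.ones(1))
print(isinstance(param, torch.nn.Parameter))  # True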
""" - - + proxy_value = inspect._empty - + if validate is False: - + proxy_value = None - + return self.graph.create( target=target, proxy_value=proxy_value, @@ -182,12 +181,12 @@ def list(self, *args, **kwargs) -> InterventionProxy: """NNsight helper method to create a traceable list.""" return self.apply(list, *args, **kwargs) - + def set(self, *args, **kwargs) -> InterventionProxy: """NNsight helper method to create a traceable set.""" return self.apply(set, *args, **kwargs) - + def dict(self, *args, **kwargs) -> InterventionProxy: """NNsight helper method to create a traceable dictionary.""" @@ -247,15 +246,17 @@ def bridge_backend_handle(self, bridge: Bridge) -> None: from torch.utils import data -def global_patch(fn): +def global_patch(fn, applied_fn=None): + + if applied_fn is None: + + applied_fn = fn @wraps(fn) def inner(*args, **kwargs): return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply( - fn, - *args, - **kwargs + applied_fn, *args, **kwargs ) return inner @@ -272,8 +273,18 @@ class GlobalTracingContext(GraphBasedContext): TORCH_HANDLER: GlobalTracingContext.GlobalTracingTorchHandler PATCHER: Patcher = Patcher( [ - Patch(torch.nn, global_patch(torch.nn.Parameter), "Parameter"), - Patch(data, global_patch(data.DataLoader), "DataLoader"), + Patch( + torch.nn.Parameter, + global_patch( + torch.nn.Parameter.__init__, applied_fn=torch.nn.Parameter + ), + "__init__", + ), + Patch( + data.DataLoader, + global_patch(data.DataLoader.__init__, applied_fn=data.DataLoader), + "__init__", + ), Patch(torch, global_patch(torch.arange), "arange"), Patch(torch, global_patch(torch.empty), "empty"), Patch(torch, global_patch(torch.eye), "eye"), @@ -288,7 +299,7 @@ class GlobalTracingContext(GraphBasedContext): Patch(torch, global_patch(torch.zeros), "zeros"), ] + [ - Patch(torch.optim, global_patch(value), key) + Patch(value, global_patch(value.__init__, applied_fn=value), "__init__") for key, value in getmembers(torch.optim, isclass) if issubclass(value, torch.optim.Optimizer) ] @@ -304,9 +315,7 @@ def __torch_function__(self, func, types, args, kwargs=None): if "_VariableFunctionsClass" in func.__qualname__: return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply( - func, - *args, - **kwargs + func, *args, **kwargs ) return func(*args, **kwargs) @@ -391,9 +400,7 @@ def register(graph_based_context: GraphBasedContext) -> None: assert GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph is None - GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph = ( - graph_based_context.graph - ) + GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph = graph_based_context.graph GlobalTracingContext.TORCH_HANDLER.__enter__() GlobalTracingContext.PATCHER.__enter__() @@ -440,6 +447,4 @@ def __getattribute__(self, name: str) -> Any: GlobalTracingContext.GLOBAL_TRACING_CONTEXT = GlobalTracingContext() -GlobalTracingContext.TORCH_HANDLER = ( - GlobalTracingContext.GlobalTracingTorchHandler() -) +GlobalTracingContext.TORCH_HANDLER = GlobalTracingContext.GlobalTracingTorchHandler() From cd8fa105374ebd22c129606134df9bad559d5701 Mon Sep 17 00:00:00 2001 From: "jadenfk@outlook.com" Date: Sun, 1 Sep 2024 17:39:31 -0400 Subject: [PATCH 3/5] Define __new__ in NNsight so type hint that methods on the NNsight objects should also reflect the underlying Envoy --- src/nnsight/models/NNsightModel.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/nnsight/models/NNsightModel.py b/src/nnsight/models/NNsightModel.py index d7b1c376..f29ebaf6 100755 --- a/src/nnsight/models/NNsightModel.py +++ 
 src/nnsight/models/NNsightModel.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/src/nnsight/models/NNsightModel.py b/src/nnsight/models/NNsightModel.py
index d7b1c376..f29ebaf6 100755
--- a/src/nnsight/models/NNsightModel.py
+++ b/src/nnsight/models/NNsightModel.py
@@ -71,6 +71,10 @@ class NNsight:
 
     proxy_class: Type[InterventionProxy] = InterventionProxy
 
+    # For type hinting Envoy
+    def __new__(cls, *args, **kwargs) -> Self | Envoy:
+        return super().__new__(cls)
+
     def __init__(
         self,
         model_key: Union[str, torch.nn.Module],
@@ -79,7 +83,6 @@ def __init__(
         meta_buffers: bool = True,
         **kwargs,
     ) -> None:
-
         super().__init__()
 
         self._model_key = model_key
@@ -457,9 +460,7 @@ def interleave(
             intervention_graph, batch_groups, batch_size
         )
 
-        module_paths = InterventionProtocol.get_interventions(
-            intervention_graph
-        ).keys()
+        module_paths = InterventionProtocol.get_interventions(intervention_graph).keys()
 
         with HookHandler(
             self._model,
@@ -522,15 +523,13 @@ def __repr__(self) -> str:
     def __setattr__(self, key: Any, value: Any) -> None:
         """Overload setattr to create and set an Envoy when trying to set a torch Module."""
 
-        if key not in ("_model", "_model_key") and isinstance(
-            value, torch.nn.Module
-        ):
+        if key not in ("_model", "_model_key") and isinstance(value, torch.nn.Module):
 
             setattr(self._envoy, key, value)
 
         else:
 
-            super().__setattr__(key, value)
+            object.__setattr__(self, key, value)
 
     def __getattr__(self, key: Any) -> Union[Envoy, InterventionProxy, Any]:
         """Wrapper of ._envoy's attributes to access module's inputs and outputs.
@@ -560,9 +559,7 @@ def _load(self, repo_id: str, *args, **kwargs) -> torch.nn.Module:
 
             return AutoModel.from_config(config, trust_remote_code=True)
 
-        return accelerate.load_checkpoint_and_dispatch(
-            self._model, repo_id, **kwargs
-        )
+        return accelerate.load_checkpoint_and_dispatch(self._model, repo_id, **kwargs)
 
     def _execute(self, *prepared_inputs: Any, **kwargs) -> Any:
         """Virtual method to run the underlying ._model with some inputs.

From 4e1a5e51e2c6c2ce3f9f164547500c6b8e47e2ea Mon Sep 17 00:00:00 2001
From: "jadenfk@outlook.com"
Date: Sun, 1 Sep 2024 18:38:08 -0400
Subject: [PATCH 4/5] Update DiffusionModel for 0.3, plus some other
 nice-to-haves

---
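A small usage sketch of the determinism the new seed keyword enables. It is
not part of the patch: the repo id and prompt are placeholders, and it drives
the diffusers pipeline directly rather than through DiffusionModel.

import torch
from diffusers import DiffusionPipeline

# Hypothetical repo id; any text-to-image diffusers pipeline behaves the same.
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# A manually seeded generator makes sampling reproducible: the same seed
# always yields the same image. _execute_generate wires this in via `seed`.
generator = torch.Generator().manual_seed(42)
image = pipeline("a photo of a cat", generator=generator).images[0]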
 src/nnsight/models/DiffusionModel.py | 197 +++++----------------------
 1 file changed, 32 insertions(+), 165 deletions(-)

diff --git a/src/nnsight/models/DiffusionModel.py b/src/nnsight/models/DiffusionModel.py
index b92f5b15..7b4fd939 100755
--- a/src/nnsight/models/DiffusionModel.py
+++ b/src/nnsight/models/DiffusionModel.py
@@ -3,14 +3,13 @@
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
-from diffusers import DiffusionPipeline, SchedulerMixin
-from PIL import Image
-from transformers import BatchEncoding, CLIPTextModel, CLIPTokenizer
+from diffusers import DiffusionPipeline
+from transformers import BatchEncoding
 
+from .. import util
 from .mixins import GenerationMixin
 from .NNsightModel import NNsight
-from torch._guards import detect_fake_mode
-from .. import util
+
 
 class Diffuser(util.WrapperModule):
     def __init__(self, *args, **kwargs) -> None:
@@ -23,135 +22,11 @@ def __init__(self, *args, **kwargs) -> None:
             setattr(self, key, value)
 
         self.tokenizer = self.pipeline.tokenizer
-
-    @torch.no_grad()
-    def scan(
-        self,
-        prompt: Union[str, List[str]] = None,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        num_inference_steps: int = 50,
-        guidance_scale: float = 7.5,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        num_images_per_prompt: Optional[int] = 1,
-        eta: float = 0.0,
-        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        output_type: Optional[str] = "pil",
-        return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
-        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        guidance_rescale: float = 0.0,
-    ):
-
-        # 0. Default height and width to unet
-        height = (
-            height
-            or self.pipeline.unet.config.sample_size * self.pipeline.vae_scale_factor
-        )
-        width = (
-            width
-            or self.pipeline.unet.config.sample_size * self.pipeline.vae_scale_factor
-        )
-
-        # 1. Check inputs. Raise error if not correct
-        self.pipeline.check_inputs(
-            prompt,
-            height,
-            width,
-            callback_steps,
-            negative_prompt,
-            prompt_embeds,
-            negative_prompt_embeds,
-        )
-
-        # 2. Define call parameters
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        do_classifier_free_guidance = guidance_scale > 1.0
-
-        # 3. Encode input prompt
-        text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None)
-            if cross_attention_kwargs is not None
-            else None
-        )
-        prompt_embeds, negative_prompt_embeds = self.pipeline.encode_prompt(
-            prompt,
-            "meta",
-            num_images_per_prompt,
-            do_classifier_free_guidance,
-            negative_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
-        )
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        if do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-
-        # 4. Prepare timesteps
-        timesteps = self.pipeline.scheduler.timesteps
-
-        # 5. Prepare latent variables
-        num_channels_latents = self.pipeline.unet.config.in_channels
-        latents = self.pipeline.prepare_latents(
-            batch_size * num_images_per_prompt,
-            num_channels_latents,
-            height,
-            width,
-            prompt_embeds.dtype,
-            "meta",
-            generator,
-            latents,
-        )
-
-        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-        extra_step_kwargs = self.pipeline.prepare_extra_step_kwargs(generator, eta)
-
-        # expand the latents if we are doing classifier free guidance
-        latent_model_input = (
-            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-        )
-
-        # predict the noise residual
-        noise_pred = self.pipeline.unet(
-            latent_model_input,
-            0,
-            encoder_hidden_states=prompt_embeds,
-            cross_attention_kwargs=cross_attention_kwargs,
-            return_dict=False,
-        )[0]
-
-        # perform guidance
-        if do_classifier_free_guidance:
-            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + guidance_scale * (
-                noise_pred_text - noise_pred_uncond
-            )
-
-        if not output_type == "latent":
-            image = self.pipeline.vae.decode(
-                latents / self.pipeline.vae.config.scaling_factor, return_dict=False
-            )[0]
-        else:
-            image = latents
-            has_nsfw_concept = None
 
 
 class DiffusionModel(GenerationMixin, NNsight):
     def __init__(self, *args, **kwargs) -> None:
-
+
         self._model: Diffuser = None
 
         super().__init__(*args, **kwargs)
@@ -162,7 +37,6 @@ def _load(self, repo_id: str, device_map=None, **kwargs) -> Diffuser:
         model = Diffuser(
             repo_id,
-            trust_remote_code=True,
             device_map=None,
             low_cpu_mem_usage=False,
             **kwargs,
         )
@@ -178,57 +52,50 @@ def _prepare_inputs(
         self,
         inputs: Union[str, List[str]],
     ) -> Any:
-
+
         if isinstance(inputs, str):
 
             inputs = [inputs]
-
-        return (inputs,), len(inputs)
-
-    # def _forward(self, inputs, *args, n_imgs=1, img_size=512, **kwargs) -> None:
-    #     text_tokens, latents = inputs
-
-    #     text_embeddings = self.meta_model.get_text_embeddings(text_tokens, n_imgs)
-    #     latents = torch.cat([latents] * 2).to("meta")
+
+        return (inputs,), len(inputs)
 
-    #     return self.meta_model.unet(
-    #         latents,
-    #         torch.zeros((1,), device="meta"),
-    #         encoder_hidden_states=text_embeddings,
-    #     ).sample
+    def _batch_inputs(
+        self,
+        batched_inputs: Optional[Dict[str, Any]],
+        prepared_inputs: BatchEncoding,
+    ) -> torch.Tensor:
 
-    def _batch_inputs(self, batched_inputs: Optional[Dict[str, Any]],
-        prepared_inputs: BatchEncoding,) -> torch.Tensor:
-
         if batched_inputs is None:
-
+
             return prepared_inputs
-
+
         return batched_inputs + prepared_inputs
-
-    def _execute_forward(self, prepared_inputs: Any, *args, **kwargs):
-        device = next(self._model.parameters()).device
 
+    def _execute_forward(self, prepared_inputs: Any, *args, **kwargs):
         return self._model.unet(
             prepared_inputs,
             *args,
             **kwargs,
         )
 
     def _execute_generate(
-        self, prepared_inputs: Any, *args, **kwargs
+        self, prepared_inputs: Any, *args, seed: int = None, **kwargs
     ):
-        device = next(self._model.parameters()).device
-
-        if detect_fake_mode(prepared_inputs):
-
-            output = self._model.scan(*prepared_inputs)
-
-        else:
-
-            output = self._model.pipeline(prepared_inputs, *args, **kwargs)
-
+
+        if self._scanning():
+
+            kwargs["num_inference_steps"] = 1
+
+        generator = torch.Generator()
+
+        if seed is not None:
+
+            generator = generator.manual_seed(seed)
+
+        output = self._model.pipeline(
+            prepared_inputs, *args, generator=generator, **kwargs
+        )
+
         output = self._model(output)
 
         return output

From 0b4714f7c8afd116b96728a53d5e8b223cd35b9a Mon Sep 17 00:00:00 2001
From: "jadenfk@outlook.com"
Date: Sun, 1 Sep 2024 18:52:46 -0400
Subject: [PATCH 5/5] Need to add the hinting to subclasses as well

---
 src/nnsight/models/DiffusionModel.py | 51 ++++++++++++++++------------
 src/nnsight/models/LanguageModel.py  | 27 ++++++---------
 2 files changed, 40 insertions(+), 38 deletions(-)

diff --git a/src/nnsight/models/DiffusionModel.py b/src/nnsight/models/DiffusionModel.py
index b92f5b15..4638ba2c 100755
--- a/src/nnsight/models/DiffusionModel.py
+++ b/src/nnsight/models/DiffusionModel.py
@@ -5,12 +5,15 @@
 import torch
 from diffusers import DiffusionPipeline, SchedulerMixin
 from PIL import Image
+from torch._guards import detect_fake_mode
 from transformers import BatchEncoding, CLIPTextModel, CLIPTokenizer
+from typing_extensions import Self
 
+from .. import util
+from ..envoy import Envoy
 from .mixins import GenerationMixin
 from .NNsightModel import NNsight
-from torch._guards import detect_fake_mode
-from .. import util
+
 
 class Diffuser(util.WrapperModule):
     def __init__(self, *args, **kwargs) -> None:
@@ -23,7 +26,7 @@ def __init__(self, *args, **kwargs) -> None:
             setattr(self, key, value)
 
         self.tokenizer = self.pipeline.tokenizer
-
+
     @torch.no_grad()
     def scan(
         self,
@@ -123,7 +126,7 @@ def scan(
         latent_model_input = (
             torch.cat([latents] * 2) if do_classifier_free_guidance else latents
         )
-
+
         # predict the noise residual
         noise_pred = self.pipeline.unet(
             latent_model_input,
@@ -150,8 +153,12 @@ def scan(
 
 
 class DiffusionModel(GenerationMixin, NNsight):
+
+    def __new__(cls, *args, **kwargs) -> Self | Envoy:
+        return object.__new__(cls)
+
     def __init__(self, *args, **kwargs) -> None:
-
+
         self._model: Diffuser = None
 
         super().__init__(*args, **kwargs)
@@ -178,10 +185,10 @@ def _prepare_inputs(
         self,
         inputs: Union[str, List[str]],
    ) -> Any:
-
+
         if isinstance(inputs, str):
 
             inputs = [inputs]
-
+
         return (inputs,), len(inputs)
 
     # def _forward(self, inputs, *args, n_imgs=1, img_size=512, **kwargs) -> None:
@@ -197,38 +204,38 @@ def _prepare_inputs(
     #         encoder_hidden_states=text_embeddings,
     #     ).sample
 
-    def _batch_inputs(self, batched_inputs: Optional[Dict[str, Any]],
-        prepared_inputs: BatchEncoding,) -> torch.Tensor:
-
+    def _batch_inputs(
+        self,
+        batched_inputs: Optional[Dict[str, Any]],
+        prepared_inputs: BatchEncoding,
+    ) -> torch.Tensor:
+
         if batched_inputs is None:
-
+
             return prepared_inputs
-
+
         return batched_inputs + prepared_inputs
-
+
     def _execute_forward(self, prepared_inputs: Any, *args, **kwargs):
         device = next(self._model.parameters()).device
 
         return self._model.unet(
             prepared_inputs,
-            *args
-            **kwargs,
+            *args, **kwargs,
         )
 
-    def _execute_generate(
-        self, prepared_inputs: Any, *args, **kwargs
-    ):
+    def _execute_generate(self, prepared_inputs: Any, *args, **kwargs):
         device = next(self._model.parameters()).device
-
+
         if detect_fake_mode(prepared_inputs):
-
+
             output = self._model.scan(*prepared_inputs)
-
+
         else:
 
             output = self._model.pipeline(prepared_inputs, *args, **kwargs)
-
+
         output = self._model(output)
 
         return output

diff --git a/src/nnsight/models/LanguageModel.py b/src/nnsight/models/LanguageModel.py
index 829eae0a..164630fe 100755
--- a/src/nnsight/models/LanguageModel.py
+++ b/src/nnsight/models/LanguageModel.py
@@ -18,6 +18,8 @@
 from transformers.models.llama.configuration_llama import LlamaConfig
 from typing_extensions import Self
 
+from nnsight.envoy import Envoy
+
 from ..intervention import InterventionProxy
 from ..util import WrapperModule
 from . import NNsight
@@ -48,9 +50,7 @@ def __getitem__(self, key: int) -> LanguageModelProxy:
 
         return self.proxy[:, key]
 
-    def __setitem__(
-        self, key: int, value: Union[LanguageModelProxy, Any]
-    ) -> None:
+    def __setitem__(self, key: int, value: Union[LanguageModelProxy, Any]) -> None:
 
         key = self.convert_idx(key)
 
         self.proxy[:, key] = value
@@ -134,6 +134,9 @@ class LanguageModel(GenerationMixin, RemoteableMixin, NNsight):
 
     proxy_class = LanguageModelProxy
 
+    def __new__(cls, *args, **kwargs) -> Self | Envoy:
+        return object.__new__(cls)
+
     def __init__(
         self,
         model_key: Union[str, torch.nn.Module],
@@ -176,7 +179,7 @@ def _load(
                 repo_id, config=config, **tokenizer_kwargs
             )
 
-        if not hasattr(self.tokenizer.pad_token, 'pad_token'):
+        if not hasattr(self.tokenizer.pad_token, "pad_token"):
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
         if (
@@ -235,9 +238,7 @@ def _tokenize(
             inputs = [{"input_ids": ids} for ids in inputs]
             return self.tokenizer.pad(inputs, return_tensors="pt", **kwargs)
 
-        return self.tokenizer(
-            inputs, return_tensors="pt", padding=True, **kwargs
-        )
+        return self.tokenizer(inputs, return_tensors="pt", padding=True, **kwargs)
 
     def _prepare_inputs(
         self,
@@ -268,9 +269,7 @@ def _prepare_inputs(
                     ai, -len(attn_mask) :
                 ] = attn_mask
 
-            new_inputs["attention_mask"] = tokenized_inputs[
-                "attention_mask"
-            ]
+            new_inputs["attention_mask"] = tokenized_inputs["attention_mask"]
 
         if "labels" in inputs:
             labels = self._tokenize(inputs["labels"], **kwargs)
@@ -312,9 +311,7 @@ def _batch_inputs(
         if "labels" in prepared_inputs:
             batched_inputs["labels"].extend(prepared_inputs["labels"])
         if "attention_mask" in prepared_inputs:
-            batched_inputs["attention_mask"].extend(
-                prepared_inputs["attention_mask"]
-            )
+            batched_inputs["attention_mask"].extend(prepared_inputs["attention_mask"])
 
         return (batched_inputs,)
 
@@ -347,9 +344,7 @@ def _execute_generate(
     def _remoteable_model_key(self) -> str:
 
         return json.dumps(
-            {
-                "repo_id": self._model_key
-            } # , "torch_dtype": str(self._model.dtype)}
+            {"repo_id": self._model_key} # , "torch_dtype": str(self._model.dtype)}
         )
 
     @classmethod