Merge pull request #222 from ndif-team/dev
Dev
JadenFiotto-Kaufman authored Sep 1, 2024
2 parents 70b804d + 9325a8f commit 5850fe9
Showing 5 changed files with 85 additions and 215 deletions.
2 changes: 1 addition & 1 deletion docs/source/notebooks/tutorials/walkthrough.ipynb
@@ -14,7 +14,7 @@
},
"source": [
"An interactive version of this walkthrough can be found\n",
"[here](https://colab.research.google.com/github/ndif-team/nnsight/blob/main/NNsight_Walkthough.ipynb)\n",
"[here](https://colab.research.google.com/github/ndif-team/nnsight/blob/main/NNsight_Walkthrough.ipynb)\n",
"\n",
"In this era of large-scale deep learning, the most interesting AI models are\n",
"massive black boxes that are hard to run. Ordinary commercial inference service\n",
51 changes: 28 additions & 23 deletions src/nnsight/contexts/GraphBasedContext.py
Expand Up @@ -53,14 +53,13 @@ def apply(
Returns:
InterventionProxy: Proxy of applying that function.
"""



proxy_value = inspect._empty

if validate is False:

proxy_value = None

return self.graph.create(
target=target,
proxy_value=proxy_value,
@@ -182,12 +181,12 @@ def list(self, *args, **kwargs) -> InterventionProxy:
        """NNsight helper method to create a traceable list."""

        return self.apply(list, *args, **kwargs)

    def set(self, *args, **kwargs) -> InterventionProxy:
        """NNsight helper method to create a traceable set."""

        return self.apply(set, *args, **kwargs)

    def dict(self, *args, **kwargs) -> InterventionProxy:
        """NNsight helper method to create a traceable dictionary."""
@@ -247,15 +246,17 @@ def bridge_backend_handle(self, bridge: Bridge) -> None:

from torch.utils import data


-def global_patch(fn):
+def global_patch(fn, applied_fn=None):
+
+    if applied_fn is None:
+
+        applied_fn = fn

    @wraps(fn)
    def inner(*args, **kwargs):

        return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply(
-            fn,
-            *args,
-            **kwargs
+            applied_fn, *args, **kwargs
        )

    return inner
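
For context, a minimal self-contained sketch of the new wrapper contract (demo_patch is a hypothetical name, not part of this commit): fn only donates its metadata to @wraps, while applied_fn is the callable that is actually dispatched. This split is what lets the patches below wrap a class's __init__ yet apply the class itself, so the class object is never replaced and isinstance checks keep working outside a trace.

import torch
from functools import wraps


def demo_patch(fn, applied_fn=None):
    # Same shape as the new global_patch: fn supplies the wrapped
    # signature; applied_fn (defaulting to fn) is what actually runs.
    # In nnsight, inner would instead route through
    # GLOBAL_TRACING_CONTEXT.apply(applied_fn, ...).
    if applied_fn is None:
        applied_fn = fn

    @wraps(fn)
    def inner(*args, **kwargs):
        print(f"intercepted {fn.__qualname__}")
        return applied_fn(*args, **kwargs)

    return inner


patched_arange = demo_patch(torch.arange)  # applied_fn defaults to fn
t = patched_arange(5)  # prints, then returns tensor([0, 1, 2, 3, 4])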
@@ -272,8 +273,18 @@ class GlobalTracingContext(GraphBasedContext):
    TORCH_HANDLER: GlobalTracingContext.GlobalTracingTorchHandler
    PATCHER: Patcher = Patcher(
        [
-            Patch(torch.nn, global_patch(torch.nn.Parameter), "Parameter"),
-            Patch(data, global_patch(data.DataLoader), "DataLoader"),
+            Patch(
+                torch.nn.Parameter,
+                global_patch(
+                    torch.nn.Parameter.__init__, applied_fn=torch.nn.Parameter
+                ),
+                "__init__",
+            ),
+            Patch(
+                data.DataLoader,
+                global_patch(data.DataLoader.__init__, applied_fn=data.DataLoader),
+                "__init__",
+            ),
            Patch(torch, global_patch(torch.arange), "arange"),
            Patch(torch, global_patch(torch.empty), "empty"),
            Patch(torch, global_patch(torch.eye), "eye"),
@@ -288,7 +299,7 @@ class GlobalTracingContext(GraphBasedContext):
            Patch(torch, global_patch(torch.zeros), "zeros"),
        ]
        + [
-            Patch(torch.optim, global_patch(value), key)
+            Patch(value, global_patch(value.__init__, applied_fn=value), "__init__")
            for key, value in getmembers(torch.optim, isclass)
            if issubclass(value, torch.optim.Optimizer)
        ]
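
Assuming nnsight's usual entry points (LanguageModel and the trace context manager), the practical effect of these __init__ patches is that constructing a Parameter, DataLoader, or optimizer inside a trace is recorded in the intervention graph instead of executing eagerly. A rough usage sketch, with the model id chosen only for illustration:

import torch
from nnsight import LanguageModel  # assumed public API

model = LanguageModel("openai-community/gpt2", device_map="cpu")

with model.trace("hello"):
    # With PATCHER active, the patched __init__ wrappers re-dispatch
    # these constructors through GLOBAL_TRACING_CONTEXT.apply(...),
    # so both names below are graph proxies rather than concrete objects.
    p = torch.nn.Parameter(torch.zeros(4))
    opt = torch.optim.SGD([p], lr=0.1)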
@@ -304,9 +315,7 @@ def __torch_function__(self, func, types, args, kwargs=None):

        if "_VariableFunctionsClass" in func.__qualname__:
            return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply(
-                func,
-                *args,
-                **kwargs
+                func, *args, **kwargs
            )

        return func(*args, **kwargs)
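
The _VariableFunctionsClass check works because torch's top-level factory functions are implemented on that internal class, so its name appears in their __qualname__; a quick probe (exact string assumed from current torch builds):

import torch

print(torch.arange.__qualname__)  # typically '_VariableFunctionsClass.arange'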
@@ -391,9 +400,7 @@ def register(graph_based_context: GraphBasedContext) -> None:

        assert GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph is None

-        GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph = (
-            graph_based_context.graph
-        )
+        GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph = graph_based_context.graph

        GlobalTracingContext.TORCH_HANDLER.__enter__()
        GlobalTracingContext.PATCHER.__enter__()
@@ -440,6 +447,4 @@ def __getattribute__(self, name: str) -> Any:


GlobalTracingContext.GLOBAL_TRACING_CONTEXT = GlobalTracingContext()
-GlobalTracingContext.TORCH_HANDLER = (
-    GlobalTracingContext.GlobalTracingTorchHandler()
-)
+GlobalTracingContext.TORCH_HANDLER = GlobalTracingContext.GlobalTracingTorchHandler()
201 changes: 37 additions & 164 deletions src/nnsight/models/DiffusionModel.py
@@ -3,14 +3,15 @@
from typing import Any, Callable, Dict, List, Optional, Union

import torch
-from diffusers import DiffusionPipeline, SchedulerMixin
-from PIL import Image
-from transformers import BatchEncoding, CLIPTextModel, CLIPTokenizer
+from diffusers import DiffusionPipeline
+from transformers import BatchEncoding
from typing_extensions import Self

-from .. import util
+from ..envoy import Envoy
from .mixins import GenerationMixin
from .NNsightModel import NNsight
from torch._guards import detect_fake_mode
+from .. import util
+

class Diffuser(util.WrapperModule):
    def __init__(self, *args, **kwargs) -> None:
@@ -23,135 +24,15 @@ def __init__(self, *args, **kwargs) -> None:
            setattr(self, key, value)

        self.tokenizer = self.pipeline.tokenizer

-    @torch.no_grad()
-    def scan(
-        self,
-        prompt: Union[str, List[str]] = None,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        num_inference_steps: int = 50,
-        guidance_scale: float = 7.5,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        num_images_per_prompt: Optional[int] = 1,
-        eta: float = 0.0,
-        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        output_type: Optional[str] = "pil",
-        return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
-        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        guidance_rescale: float = 0.0,
-    ):
-
-        # 0. Default height and width to unet
-        height = (
-            height
-            or self.pipeline.unet.config.sample_size * self.pipeline.vae_scale_factor
-        )
-        width = (
-            width
-            or self.pipeline.unet.config.sample_size * self.pipeline.vae_scale_factor
-        )
-
-        # 1. Check inputs. Raise error if not correct
-        self.pipeline.check_inputs(
-            prompt,
-            height,
-            width,
-            callback_steps,
-            negative_prompt,
-            prompt_embeds,
-            negative_prompt_embeds,
-        )
-
-        # 2. Define call parameters
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        do_classifier_free_guidance = guidance_scale > 1.0
-
-        # 3. Encode input prompt
-        text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None)
-            if cross_attention_kwargs is not None
-            else None
-        )
-        prompt_embeds, negative_prompt_embeds = self.pipeline.encode_prompt(
-            prompt,
-            "meta",
-            num_images_per_prompt,
-            do_classifier_free_guidance,
-            negative_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
-        )
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        if do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-
-        # 4. Prepare timesteps
-        timesteps = self.pipeline.scheduler.timesteps
-
-        # 5. Prepare latent variables
-        num_channels_latents = self.pipeline.unet.config.in_channels
-        latents = self.pipeline.prepare_latents(
-            batch_size * num_images_per_prompt,
-            num_channels_latents,
-            height,
-            width,
-            prompt_embeds.dtype,
-            "meta",
-            generator,
-            latents,
-        )
-
-        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-        extra_step_kwargs = self.pipeline.prepare_extra_step_kwargs(generator, eta)
-
-        # expand the latents if we are doing classifier free guidance
-        latent_model_input = (
-            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-        )
-
-        # predict the noise residual
-        noise_pred = self.pipeline.unet(
-            latent_model_input,
-            0,
-            encoder_hidden_states=prompt_embeds,
-            cross_attention_kwargs=cross_attention_kwargs,
-            return_dict=False,
-        )[0]
-
-        # perform guidance
-        if do_classifier_free_guidance:
-            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + guidance_scale * (
-                noise_pred_text - noise_pred_uncond
-            )
-
-        if not output_type == "latent":
-            image = self.pipeline.vae.decode(
-                latents / self.pipeline.vae.config.scaling_factor, return_dict=False
-            )[0]
-        else:
-            image = latents
-        has_nsfw_concept = None


class DiffusionModel(GenerationMixin, NNsight):

+    def __new__(cls, *args, **kwargs) -> Self | Envoy:
+        return object.__new__(cls)
+
    def __init__(self, *args, **kwargs) -> None:

        self._model: Diffuser = None

        super().__init__(*args, **kwargs)
@@ -162,7 +43,6 @@ def _load(self, repo_id: str, device_map=None, **kwargs) -> Diffuser:

        model = Diffuser(
            repo_id,
-            trust_remote_code=True,
            device_map=None,
            low_cpu_mem_usage=False,
            **kwargs,
@@ -178,57 +58,50 @@ def _prepare_inputs(
        self,
        inputs: Union[str, List[str]],
    ) -> Any:

        if isinstance(inputs, str):
            inputs = [inputs]

        return (inputs,), len(inputs)

-    # def _forward(self, inputs, *args, n_imgs=1, img_size=512, **kwargs) -> None:
-    #     text_tokens, latents = inputs
-
-    #     text_embeddings = self.meta_model.get_text_embeddings(text_tokens, n_imgs)
-
-    #     latents = torch.cat([latents] * 2).to("meta")
-
-    #     return self.meta_model.unet(
-    #         latents,
-    #         torch.zeros((1,), device="meta"),
-    #         encoder_hidden_states=text_embeddings,
-    #     ).sample

-    def _batch_inputs(self, batched_inputs: Optional[Dict[str, Any]],
-        prepared_inputs: BatchEncoding,) -> torch.Tensor:
+    def _batch_inputs(
+        self,
+        batched_inputs: Optional[Dict[str, Any]],
+        prepared_inputs: BatchEncoding,
+    ) -> torch.Tensor:

        if batched_inputs is None:

            return prepared_inputs

        return batched_inputs + prepared_inputs

    def _execute_forward(self, prepared_inputs: Any, *args, **kwargs):

-        device = next(self._model.parameters()).device

        return self._model.unet(
            prepared_inputs,
-            *args
+            *args,
+            **kwargs,
        )

    def _execute_generate(
-        self, prepared_inputs: Any, *args, **kwargs
+        self, prepared_inputs: Any, *args, seed: int = None, **kwargs
    ):
-        device = next(self._model.parameters()).device

-        if detect_fake_mode(prepared_inputs):
-
-            output = self._model.scan(*prepared_inputs)
-
-        else:
-
-            output = self._model.pipeline(prepared_inputs, *args, **kwargs)

+        if self._scanning():

+            kwargs["num_inference_steps"] = 1

+        generator = torch.Generator()

+        if seed is not None:

+            generator = generator.manual_seed(seed)

+        output = self._model.pipeline(
+            prepared_inputs, *args, generator=generator, **kwargs
+        )

        output = self._model(output)

        return output
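
Taken together, the new generate path forces num_inference_steps to 1 while scanning (so the shape-probing pass stays cheap) and threads an optional seed into a torch.Generator handed to the diffusers pipeline, making sampling reproducible. A hypothetical usage sketch; the import path, model id, and generate call shape are assumptions, not part of this commit:

from nnsight.models.DiffusionModel import DiffusionModel

model = DiffusionModel("runwayml/stable-diffusion-v1-5")

# Two runs with the same seed should draw identical initial latents,
# and therefore produce matching images.
with model.generate("an astronaut riding a horse", seed=42):
    pass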