diff --git a/examples/.config/model_params_onnxrt.json b/examples/.config/model_params_onnxrt.json
index e7e0cbda6..f06a063e2 100644
--- a/examples/.config/model_params_onnxrt.json
+++ b/examples/.config/model_params_onnxrt.json
@@ -83,6 +83,13 @@
       "input_model": "/tf_dataset2/models/onnx/resnet50-v1-12/resnet50-v1-12.onnx",
       "main_script": "main.py",
       "batch_size": 1
+    },
+    "sd-v1-5-sq": {
+      "model_src_dir": "nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static",
+      "dataset_location": "",
+      "input_model": "/tf_dataset2/models/onnx/sd_v1_5",
+      "main_script": "main.py",
+      "batch_size": 1
+    }
   }
 }
diff --git a/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/README.md b/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/README.md
new file mode 100644
index 000000000..17c89883d
--- /dev/null
+++ b/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/README.md
@@ -0,0 +1,47 @@
+Step-by-Step
+============
+
+This example shows how to quantize the unet model of [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) with SmoothQuant and generate images with the quantized unet.
+
+# Prerequisite
+
+## 1. Environment
+```shell
+pip install -r requirements.txt
+```
+> Note: Validated ONNX Runtime [Version](/docs/installation_guide.md#validated-software-environment).
+
+## 2. Prepare Model
+
+
+```bash
+git clone https://github.com/huggingface/diffusers.git
+cd diffusers/scripts
+python convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path stable-diffusion
+```
+
+# Run
+
+## 1. Quantization
+
+```bash
+bash run_quant.sh --input_model=/path/to/stable-diffusion \ # folder path of stable-diffusion
+                  --output_model=/path/to/save/unet_model \ # model path as *.onnx
+                  --alpha=0.7 # optional
+```
+
+## 2. Benchmark
+
+```bash
+bash run_benchmark.sh --input_model=/path/to/stable-diffusion \ # folder path of stable-diffusion
+                      --quantized_unet_path=/path/to/quantized/unet.onnx \ # optional, runs the fp32 model if not provided
+                      --prompt="a photo of an astronaut riding a horse on mars" \ # optional
+                      --image_path=image.png # optional
+```
+
+The benchmark prints the throughput data and saves the generated image.
+Our test results with the default parameters are (fp32 vs int8):
+<table>
+<tr><td>fp32</td><td>int8</td></tr>
+<tr><td><img src="./imgs/fp32.png"></td><td><img src="./imgs/int8.png"></td></tr>
+</table>
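+
+To generate images with the quantized unet outside of `run_benchmark.sh`, you can plug it into `OnnxStableDiffusionPipeline` yourself. The snippet below is a minimal sketch of what `main.py` does in its benchmark path; the paths are placeholders for your own locations:
+
+```python
+import onnxruntime as ort
+from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline
+
+# load the exported fp32 pipeline folder produced by convert_stable_diffusion_checkpoint_to_onnx.py
+pipe = OnnxStableDiffusionPipeline.from_pretrained("/path/to/stable-diffusion", provider="CPUExecutionProvider")
+
+# swap in the quantized unet produced by run_quant.sh
+pipe.unet = OnnxRuntimeModel(model=ort.InferenceSession("/path/to/quantized/unet.onnx", providers=["CPUExecutionProvider"]))
+
+image = pipe(prompt="a photo of an astronaut riding a horse on mars").images[0]
+image.save("image.png")
+```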
diff --git a/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/imgs/fp32.png b/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/imgs/fp32.png new file mode 100644 index 000000000..4ae187712 Binary files /dev/null and b/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/imgs/fp32.png differ diff --git a/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/imgs/int8.png b/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/imgs/int8.png new file mode 100644 index 000000000..486b76821 Binary files /dev/null and b/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/imgs/int8.png differ diff --git a/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/main.py b/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/main.py new file mode 100644 index 000000000..0eeca7ea6 --- /dev/null +++ b/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/main.py @@ -0,0 +1,263 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint:disable=redefined-outer-name,logging-format-interpolation
+import argparse
+import inspect
+import logging
+import os
+import time
+from typing import List
+
+import numpy as np
+import onnx
+import onnxruntime as ort
+import torch
+from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline
+
+from onnx_neural_compressor import data_reader
+from onnx_neural_compressor.quantization import QuantType, config, quantize
+
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.WARN
+)
+
+parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument(
+    "--model_path",
+    type=str,
+    help="Folder path of the ONNX Stable Diffusion model; it contains model_index.json and the sub-model folders.",
+)
+parser.add_argument("--quantized_unet_path", type=str, default=None, help="Path of the quantized unet model.")
+parser.add_argument("--benchmark", action="store_true", default=False)
+parser.add_argument("--tune", action="store_true", default=False, help="whether to quantize the model")
+parser.add_argument("--output_model", type=str, default=None, help="output model path")
+parser.add_argument("--image_path", type=str, default="image.png", help="generated image path")
+parser.add_argument(
+    "--batch_size",
+    default=1,
+    type=int,
+)
+parser.add_argument("--prompt", type=str, default="a photo of an astronaut riding a horse on mars")
+parser.add_argument("--alpha", type=float, default=0.7)
+parser.add_argument("--seed", type=int, default=1234, help="random seed for generation")
+parser.add_argument("--provider", type=str, default="CPUExecutionProvider")
+args = parser.parse_args()
+
+ORT_TO_NP_TYPE = {
+    "tensor(bool)": np.bool_,
+    "tensor(int8)": np.int8,
+    "tensor(uint8)": np.uint8,
+    "tensor(int16)": np.int16,
+    "tensor(uint16)": np.uint16,
+    "tensor(int32)": np.int32,
+    "tensor(uint32)": np.uint32,
+    "tensor(int64)": np.int64,
+    "tensor(uint64)": np.uint64,
+    "tensor(float16)": np.float16,
+    "tensor(float)": np.float32,
+    "tensor(double)": np.float64,
+}
+
+np.random.seed(args.seed)
+
+
+def benchmark(model):
+    generator = None if args.seed is None else np.random.RandomState(args.seed)
+
+    pipe = OnnxStableDiffusionPipeline.from_pretrained(model, provider=args.provider)
+    if args.quantized_unet_path is not None:
+        unet = OnnxRuntimeModel(model=ort.InferenceSession(args.quantized_unet_path, providers=[args.provider]))
+        pipe.unet = unet
+
+    image = None
+
+    tic = time.time()
+    image = pipe(prompt=args.prompt, generator=generator).images[0]
+    toc = time.time()
+
+    if image is not None:
+        image.save(args.image_path)
+        print("Generated image is saved as " + args.image_path)
+
+    print("\n", "-" * 10, "Summary:", "-" * 10)
+    throughput = 1 / (toc - tic)
+    print("Throughput: {} samples/s".format(throughput))
+
+
+class DataReader(data_reader.CalibrationDataReader):
+
+    def __init__(self, model_path, batch_size=1):
+        self.encoded_list = []
+        self.batch_size = batch_size
+
+        model = onnx.load(os.path.join(model_path, "unet/model.onnx"), load_external_data=False)
+        inputs_names = [input.name for input in model.graph.input]
+
+        generator = np.random
+        pipe = OnnxStableDiffusionPipeline.from_pretrained(model_path, provider="CPUExecutionProvider")
+        prompt = "A cat holding a sign that says hello world"
+        self.batch_size = batch_size
+        guidance_scale = 7.5
+        do_classifier_free_guidance = guidance_scale > 1.0
+        num_images_per_prompt = 1
+        negative_prompt_embeds = None
negative_prompt = None + callback = None + eta = 0.0 + latents = None + prompt_embeds = None + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = pipe.tokenizer( + prompt, + padding="max_length", + max_length=pipe.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = pipe.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = pipe.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = pipe.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + # get the initial random noise unless the user supplied it + latents_dtype = prompt_embeds.dtype + latents_shape = (batch_size * num_images_per_prompt, 4, 512 // 8, 512 // 8) + if latents is None: + latents = generator.randn(*latents_shape).astype(latents_dtype) + elif latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + # set timesteps + pipe.scheduler.set_timesteps(50) + + latents = latents * np.float64(pipe.scheduler.init_noise_sigma) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(pipe.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + timestep_dtype = next( + (input.type for input in pipe.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + for i, t in enumerate(pipe.scheduler.timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = pipe.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + ort_input = {} + for name, inp in zip(inputs_names, [latent_model_input, timestep, prompt_embeds]): + ort_input[name] = inp + self.encoded_list.append(ort_input) + noise_pred = pipe.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds) + noise_pred = noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = pipe.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if callback is not None and i % 1 == 0: + step_idx = i // getattr(pipe.scheduler, "order", 1) + callback(step_idx, t, latents) + + self.iter_next = iter(self.encoded_list) + + def get_next(self): + return next(self.iter_next, None) + + def rewind(self): + self.iter_next = iter(self.encoded_list) + + +if __name__ == "__main__": + if args.benchmark: + benchmark(args.model_path) + + if args.tune: + data_reader = DataReader(args.model_path) + cfg = config.StaticQuantConfig( + data_reader, + weight_type=QuantType.QInt8, + activation_type=QuantType.QUInt8, + op_types_to_quantize=["MatMul", "Gemm"], + per_channel=True, + extra_options={ + "SmoothQuant": True, + "SmoothQuantAlpha": args.alpha, + "WeightSymmetric": True, + "ActivationSymmetric": False, + "OpTypesToExcludeOutputQuantization": ["MatMul", "Gemm"], + }, + ) + input_path = os.path.join(args.model_path, "unet/model.onnx") + quantize(input_path, args.output_model, cfg, optimization_level=ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED) diff --git a/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/requirements.txt b/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/requirements.txt new file mode 100644 index 000000000..d42b1a34c --- /dev/null +++ b/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/requirements.txt @@ -0,0 +1,7 @@ +torch +diffusers +onnx +onnxruntime +onnxruntime-extensions +onnx_neural_compressor +transformers==4.42.0 # restricted by model export diff --git a/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/run_benchmark.sh b/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/run_benchmark.sh new file mode 100644 index 000000000..64974fb90 --- /dev/null +++ 
b/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/run_benchmark.sh
@@ -0,0 +1,67 @@
+#!/bin/bash
+set -x
+
+function main {
+
+    init_params "$@"
+    run_benchmark
+
+}
+
+# init params
+function init_params {
+    for var in "$@"
+    do
+        case $var in
+            --input_model=*)
+                input_model=$(echo "$var" |cut -f2 -d=)
+            ;;
+            --quantized_unet_path=*)
+                quantized_unet_path=$(echo "$var" |cut -f2 -d=)
+            ;;
+            --batch_size=*)
+                batch_size=$(echo "$var" |cut -f2 -d=)
+            ;;
+            --prompt=*)
+                prompt=$(echo "$var" |cut -f2 -d=)
+            ;;
+            --image_path=*)
+                image_path=$(echo "$var" |cut -f2 -d=)
+            ;;
+        esac
+    done
+
+}
+
+# run_benchmark
+function run_benchmark {
+
+    # Check if the input_model ends with the filename extension ".onnx"
+    if [[ $input_model =~ \.onnx$ ]]; then
+        # If the string ends with the filename extension, get the path of the file
+        input_model=$(dirname "$input_model")
+    fi
+
+    extra_cmd=""
+
+    if [ "$quantized_unet_path" ]; then
+        extra_cmd=$extra_cmd"--quantized_unet_path=${quantized_unet_path} "
+    fi
+
+    if [ "$prompt" ]; then
+        # keep the prompt quoted so that a multi-word prompt survives the eval below
+        extra_cmd=$extra_cmd"--prompt=\"${prompt}\" "
+    fi
+
+    if [ "$image_path" ]; then
+        extra_cmd=$extra_cmd"--image_path=${image_path} "
+    fi
+
+    if [ "$batch_size" ]; then
+        extra_cmd=$extra_cmd"--batch_size=${batch_size} "
+    fi
+    extra_cmd=$extra_cmd"--benchmark"
+    eval "python main.py --model_path=${input_model} ${extra_cmd}"
+}
+
+main "$@"
+
diff --git a/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/run_quant.sh b/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/run_quant.sh
new file mode 100644
index 000000000..71b04947f
--- /dev/null
+++ b/examples/nlp/huggingface_model/text_to_image/stable_diffusion_v1_5/quantization/ptq_static/run_quant.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+set -x
+
+function main {
+    init_params "$@"
+    run_tuning
+}
+
+# init params
+function init_params {
+    for var in "$@"
+    do
+        case $var in
+            --input_model=*)
+                input_model=$(echo "$var" |cut -f2 -d=)
+            ;;
+            --output_model=*)
+                output_model=$(echo "$var" |cut -f2 -d=)
+            ;;
+            --alpha=*)
+                alpha=$(echo "$var" |cut -f2 -d=)
+            ;;
+        esac
+    done
+
+}
+
+# run_tuning
+function run_tuning {
+
+    # Check if the input_model ends with the filename extension ".onnx"
+    if [[ $input_model =~ \.onnx$ ]]; then
+        # If the string ends with the filename extension, get the path of the file
+        input_model=$(dirname "$input_model")
+    fi
+
+    # Check if the directory exists
-d "$(dirname "$output_model")" ]; then + # If the directory doesn't exist, create it + mkdir -p "$(dirname "$output_model")" + echo "Created directory $(dirname "$output_model")" + fi + + python main.py \ + --model_path "${input_model}" \ + --output_model "${output_model}" \ + --alpha "${alpha-0.7}" \ + --tune +} + +main "$@" + diff --git a/onnx_neural_compressor/algorithms/smoother/core.py b/onnx_neural_compressor/algorithms/smoother/core.py index bcf830f1a..0726a2685 100644 --- a/onnx_neural_compressor/algorithms/smoother/core.py +++ b/onnx_neural_compressor/algorithms/smoother/core.py @@ -272,7 +272,14 @@ def mul(node, scale): # pragma: no cover if self.model.model_path is not None else onnx.numpy_helper.to_array(tensor) * scale ) - self.model.set_initializer(inp, new_tensor) + # set_initializer requires the dims of old & new initializers are same + # Mul operator has broadcast mechanism + self.model.remove_initializer(tensor) + self.model.add_initializer( + onnx.helper.make_tensor( + inp, tensor.data_type, list(new_tensor.shape), new_tensor.flatten().tolist() + ) + ) self.tensor_scales_info[key] = ( 1.0 / scale if key not in self.tensor_scales_info