Sanity test and code cleaning
daniil-lyakhov committed Jun 25, 2024
1 parent 4641f8d commit 4c2b568
Showing 13 changed files with 519 additions and 1,538 deletions.
117 changes: 55 additions & 62 deletions examples/post_training_quantization/openvino/yolov8/main.py
@@ -18,7 +18,7 @@
import time
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Tuple
from typing import Dict, Tuple

import numpy as np
import openvino as ov
@@ -29,13 +29,10 @@
from torch.fx.passes.graph_drawer import FxGraphDrawer
from tqdm import tqdm
from ultralytics.cfg import get_cfg
from ultralytics.data.converter import coco80_to_coco91_class
from ultralytics.data.utils import check_det_dataset
from ultralytics.engine.validator import BaseValidator as Validator
from ultralytics.models.yolo import YOLO
from ultralytics.utils import DATASETS_DIR
from ultralytics.utils import DEFAULT_CFG
from ultralytics.utils.metrics import ConfusionMatrix
from ultralytics.utils.torch_utils import de_parallel

import nncf
@@ -55,15 +52,18 @@ def measure_time(model, example_inputs, num_iters=500):
    return average_time


def validate_fx_ult_method(model: ov.Model) -> Tuple[Dict, int, int]:
    """
    Uses the ultralytics .val method instead of a dataloader loop.
    For some reason this shows better metrics on torch.compile'd models.
    """
    yolo = YOLO(f"{ROOT}/{MODEL_NAME}.pt")
    yolo.model = model
    result = yolo.val(data="coco128.yaml", batch=1, rect=False)
    return result.results_dict
def measure_time_ov(model, example_inputs, num_iters=1000):
    ie = ov.Core()
    compiled_model = ie.compile_model(model, "CPU")
    infer_request = compiled_model.create_infer_request()
    infer_request.infer(example_inputs)
    total_time = 0
    for i in range(0, num_iters):
        start_time = time.time()
        infer_request.infer(example_inputs)
        total_time += time.time() - start_time
    average_time = (total_time / num_iters) * 1000
    return average_time


def validate_fx(
@@ -100,10 +100,10 @@ def print_statistics_short(stats: np.ndarray) -> None:
def validate_ov(
    model: ov.Model, data_loader: torch.utils.data.DataLoader, validator: Validator, num_samples: int = None
) -> Tuple[Dict, int, int]:
    validator.seen = 0
    validator.jdict = []
    validator.stats = []
    validator.confusion_matrix = ConfusionMatrix(nc=validator.nc)
    # validator.seen = 0
    # validator.jdict = []
    # validator.stats = []
    # validator.confusion_matrix = ConfusionMatrix(nc=validator.nc)
    model.reshape({0: [1, 3, -1, -1]})
    compiled_model = ov.compile_model(model)
    output_layer = compiled_model.output(0)
@@ -131,7 +131,7 @@ def print_statistics(stats: np.ndarray, total_images: int, total_objects: int) -> None:
    print(pf % ("all", total_images, total_objects, mp, mr, map50, mean_ap))


def prepare_validation_new(model: YOLO, data: str) -> Tuple[Validator, torch.utils.data.DataLoader]:
def prepare_validation(model: YOLO, data: str) -> Tuple[Validator, torch.utils.data.DataLoader]:
    # custom = {"rect": True, "batch": 1}  # method defaults
    # rect: False forces all input pictures to be resized to one size
    custom = {"rect": False, "batch": 1}  # method defaults
@@ -148,25 +148,6 @@ def prepare_validation_new(model: YOLO, data: str) -> Tuple[Validator, torch.utils.data.DataLoader]:
    return validator, data_loader


def prepare_validation(model: YOLO, args: Any) -> Tuple[Validator, torch.utils.data.DataLoader]:
    validator = model.smart_load("validator")(args)
    validator.data = check_det_dataset(args.data)
    dataset = validator.data["val"]
    print(f"{dataset}")

    data_loader = validator.get_dataloader(f"{DATASETS_DIR}/coco128", 1)

    validator = model.smart_load("validator")(args)

    validator.is_coco = True
    validator.class_map = coco80_to_coco91_class()
    validator.names = model.model.names
    validator.metrics.names = validator.names
    validator.nc = model.model.model[-1].nc

    return validator, data_loader


def benchmark_performance(model_path, config) -> float:
command = f"benchmark_app -m {model_path} -d CPU -api async -t 30"
command += f' -shape "[1,3,{config.imgsz},{config.imgsz}]"'
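For reference, a hedged illustration of the command these lines assemble, assuming a hypothetical run with config.imgsz == 640 and a placeholder model path:

benchmark_app -m yolov8n_openvino_model/yolov8n.xml -d CPU -api async -t 30 -shape "[1,3,640,640]"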
@@ -221,7 +202,7 @@ def transform_fn(data_item: Dict):
    return quantized_model


NNCF_QUANTIZATION = True
NNCF_QUANTIZATION = False


def quantize_impl(exported_model, val_loader, validator):
@@ -290,26 +271,25 @@ def main():
    # args.data = "coco128.yaml"
    # Prepare validation dataset and helper

    validator, data_loader = prepare_validation_new(model, "coco128.yaml")
    validator, data_loader = prepare_validation(model, "coco128.yaml")

    # Convert to OpenVINO model
    if TORCH_FX:
        batch = next(iter(data_loader))
        batch = validator.preprocess(batch)
    batch = next(iter(data_loader))
    batch = validator.preprocess(batch)

    if TORCH_FX:
        fp_stats, total_images, total_objects = validate_fx(model.model, tqdm(data_loader), validator)
        print("Floating-point Torch model validation results:")
        print_statistics(fp_stats, total_images, total_objects)

        fp32_compiled_model = torch.compile(model.model, backend="openvino")
        if NNCF_QUANTIZATION:
            fp32_compiled_model = torch.compile(model.model, backend="openvino")
        else:
            fp32_compiled_model = torch.compile(model.model)
        fp32_stats, total_images, total_objects = validate_fx(fp32_compiled_model, tqdm(data_loader), validator)
        print("FP32 FX model validation results:")
        print_statistics(fp32_stats, total_images, total_objects)

        # result = validate_fx_ult_method(fp32_compiled_model)
        # print("FX FP32 model .val validation")
        # print_statistics_short(result)

        print("Start quantization...")
        # Rebuild model to reset ultralytics cache
        model = YOLO(f"{ROOT}/{MODEL_NAME}.pt")
@@ -323,10 +303,6 @@ def main():
        )
        quantized_model = quantize_impl(deepcopy(exported_model), data_loader, validator)

        # result = validate_fx_ult_method(quantized_model)
        # print("FX INT8 model .val validation")
        # print_statistics_short(result)

        int8_stats, total_images, total_objects = validate_fx(quantized_model, tqdm(data_loader), validator)
        print("INT8 FX model validation results:")
        print_statistics(int8_stats, total_images, total_objects)
@@ -360,35 +336,52 @@ def main():
print("Quantized model validation results:")
print_statistics(q_stats, total_images, total_objects)

# Benchmark performance of FP32 model
fp_model_perf = benchmark_performance(ov_model_path, args)
print(f"Floating-point model performance: {fp_model_perf} FPS")

# Benchmark performance of quantized model
quantized_model_perf = benchmark_performance(quantized_model_path, args)
print(f"Quantized model performance: {quantized_model_perf} FPS")
fps = True
latency = True
fp_model_perf = -1
quantized_model_perf = -1
if fps:
# Benchmark performance of FP32 model
fp_model_perf = benchmark_performance(ov_model_path, args)
print(f"Floating-point model performance: {fp_model_perf} FPS")

# Benchmark performance of quantized model
quantized_model_perf = benchmark_performance(quantized_model_path, args)
print(f"Quantized model performance: {quantized_model_perf} FPS")
if latency:
fp_model_latency = measure_time_ov(ov_model, batch["img"])
print(f"FP32 OV model latency: {fp_model_latency}")
int8_model_latency = measure_time_ov(quantized_model, batch["img"])
print(f"INT8 OV model latency: {int8_model_latency}")

return fp_stats["metrics/mAP50-95(B)"], q_stats["metrics/mAP50-95(B)"], fp_model_perf, quantized_model_perf


def check_export_not_strict():
def main_export_not_strict():
    model = YOLO(f"{ROOT}/{MODEL_NAME}.pt")

    # Prepare validation dataset and helper
    validator, data_loader = prepare_validation_new(model, "coco128.yaml")
    validator, data_loader = prepare_validation(model, "coco128.yaml")

    batch = next(iter(data_loader))
    batch = validator.preprocess(batch)

    model.model(batch["img"])
    ex_model = torch.export.export(model.model, args=(batch["img"],), strict=False)
    ex_model = capture_pre_autograd_graph(ex_model.module(), args=(batch["img"],))
    ex_model = torch.compile(ex_model)

    fp_stats, total_images, total_objects = validate_fx(ex_model, tqdm(data_loader), validator)
    print("Floating-point ex strict=False")
    print_statistics(fp_stats, total_images, total_objects)

    quantized_model = quantize_impl(deepcopy(ex_model), data_loader, validator)
    int8_stats, total_images, total_objects = validate_fx(quantized_model, tqdm(data_loader), validator)
    print("Int8 ex strict=False")
    print_statistics(int8_stats, total_images, total_objects)
    # No quantizers were inserted, metrics are OK


if __name__ == "__main__":
    check_export_not_strict()
    # main()
    # main_export_not_strict()
    main()
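The commit title mentions a sanity test; a minimal sketch of such a check on top of main() is below, assuming a tolerance value and an FP32-vs-INT8 mAP comparison, neither of which is part of this diff:

def sanity_check(map_drop_tolerance: float = 0.01) -> None:
    # main() returns (fp32_map, int8_map, fp32_fps, int8_fps); the tolerance is an assumption.
    fp32_map, int8_map, fp32_fps, int8_fps = main()
    # The quantized model is expected to stay within a small mAP50-95 drop of the FP32 baseline.
    assert fp32_map - int8_map <= map_drop_tolerance, (
        f"INT8 mAP50-95 regressed by more than {map_drop_tolerance}: {fp32_map} -> {int8_map}"
    )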
153 changes: 1 addition & 152 deletions nncf/experimental/torch_fx/model_transformer.py
@@ -10,48 +10,23 @@
# limitations under the License.

from collections import defaultdict
from dataclasses import dataclass

# from functools import partial
from typing import Callable, List, Optional, Union
from typing import Callable, List, Union

import torch
import torch.fx
from torch.ao.quantization.fx.utils import create_getattr_from_value
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
from torch.ao.quantization.pt2e.qat_utils import _fold_conv_bn_qat
from torch.ao.quantization.pt2e.utils import _disallow_eval_train
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_manager import PassManager
from torch.fx.passes.split_utils import split_by_tags

from nncf.common.graph.model_transformer import ModelTransformer

# from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
# from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.transformations.commands import Command
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.transformations.commands import TransformationType

# from torch import Tensor
# from torch import nn
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint

# from nncf.torch.graph.transformations.commands import PTTargetPoint
# from nncf.torch.graph.transformations.commands import PTWeightUpdateCommand
from nncf.torch.graph.transformations.layout import PTTransformationLayout

# from torch.nn.parameter import Parameter
# from nncf.torch.model_graph_manager import update_fused_bias
# from nncf.torch.nncf_network import PTInsertionPoint
# from nncf.torch.nncf_network import compression_module_type_to_attr_name
# from nncf.torch.utils import get_model_device
# from nncf.torch.utils import is_multidevice


class FXModuleInsertionCommand(Command):
    def __init__(
@@ -206,129 +181,3 @@ def _apply_transformation(
        for transformation in transformations:
            transformation.tranformation_fn(model)
        return model


@dataclass
class QPARAMSPerTensor:
    scale: float
    zero_point: int
    quant_min: int
    quant_max: int
    dtype: torch.dtype


@dataclass
class QPARAMPerChannel:
    scales: torch.Tensor
    zero_points: Optional[torch.Tensor]
    axis: int
    quant_min: int
    quant_max: int
    dtype: torch.dtype


def insert_qdq_to_model(model: torch.fx.GraphModule, qsetup) -> torch.fx.GraphModule:
    # from prepare
    _fuse_conv_bn_(model)

    # from convert
    original_graph_meta = model.meta
    _insert_qdq_to_model(model, qsetup)

    # Magic. Without this call the compiled model
    # is not performant
    model = GraphModule(model, model.graph)

    model = _fold_conv_bn_qat(model)
    pm = PassManager([DuplicateDQPass()])

    model = pm(model).graph_module
    pm = PassManager([PortNodeMetaForQDQ()])
    model = pm(model).graph_module

    model.meta.update(original_graph_meta)
    model = _disallow_eval_train(model)
    return model
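A minimal hand-built qsetup for insert_qdq_to_model() might look like the sketch below; the node names ("conv2d", "relu") and the quantization parameter values are invented for illustration, only the field names come from the dataclasses above:

# Hypothetical qsetup dict mapping FX node names to the quantization-parameter dataclasses.
example_qsetup = {
    "conv2d": QPARAMPerChannel(
        scales=torch.tensor([0.02, 0.015, 0.03]),
        zero_points=torch.zeros(3, dtype=torch.int32),
        axis=0,
        quant_min=-128,
        quant_max=127,
        dtype=torch.int8,
    ),
    "relu": QPARAMSPerTensor(
        scale=0.05,
        zero_point=0,
        quant_min=-128,
        quant_max=127,
        dtype=torch.int8,
    ),
}
# quantized_gm = insert_qdq_to_model(fx_graph_module, example_qsetup)  # fx_graph_module: a captured GraphModule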


def _insert_qdq_to_model(model: torch.fx.GraphModule, qsetup) -> torch.fx.GraphModule:
    for idx, node in enumerate(list(model.graph.nodes)):
        if node.name not in qsetup:
            continue
        # 1. extract information for inserting q/dq node from activation_post_process
        params = qsetup[node.name]
        node_type = "call_function"
        quantize_op: Optional[Callable] = None
        # scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
        if isinstance(params, QPARAMPerChannel):
            quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default
            qparams = {
                "_scale_": params.scales,
                "_zero_point_": params.zero_points,
                "_axis_": params.axis,
                "_quant_min_": params.quant_min,
                "_quant_max_": params.quant_max,
                "_dtype_": params.dtype,
            }
        elif isinstance(params, QPARAMSPerTensor):
            quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
            qparams = {
                "_scale_": params.scale,
                "_zero_point_": params.zero_point,
                "_quant_min_": params.quant_min,
                "_quant_max_": params.quant_max,
                "_dtype_": params.dtype,
            }

        else:
            raise RuntimeError(f"params {params} are unknown")
        # 2. replace activation_post_process node with quantize and dequantize
        graph = model.graph

        # TODO: use metatype to get correct input_port_id
        # Do not quantize already quantized nodes
        # inserting_before handles only the order in the generated graph code,
        # so insert quantize-dequantize and all constant nodes before the usage of the node
        with graph.inserting_before(node):
            quantize_op_inputs = [node]
            for key, value_or_node in qparams.items():
                # TODO: we can add the information of whether a value needs to
                # be registered as an attribute in the qparams dict itself
                if key in ["_scale_", "_zero_point_"] and (not isinstance(value_or_node, (float, int))):
                    # For scale and zero_point values we register them as buffers in the root module.
                    # However, note that when the values are not tensors, as in the case of
                    # per_tensor quantization, they will be treated as literals.
                    # However, registering them as a node seems to cause issues with dynamo
                    # tracing where it may consider a tensor overload as opposed to the default one.
                    # With the extra check of scale and zero_point being scalars, it makes
                    # sure that the default overload can be used.
                    # TODO: maybe need a more complex attr name here
                    qparam_node = create_getattr_from_value(model, graph, str(idx) + key, value_or_node)
                    quantize_op_inputs.append(qparam_node)
                else:
                    # for qparams that are not scale/zero_point (like axis, dtype) we store
                    # them as literals in the graph.
                    quantize_op_inputs.append(value_or_node)

        with graph.inserting_after(node):
            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
        # use the same qparams from the quantize op
        dq_inputs = [quantized_node] + quantize_op_inputs[1:]
        user_dq_nodes = []
        with graph.inserting_after(quantized_node):
            for user in node.users:
                if user is quantized_node:
                    continue
                user_dq_nodes.append((user, graph.call_function(dequantize_op, tuple(dq_inputs), {})))

        for user, dq_node in user_dq_nodes:
            user.replace_input_with(node, dq_node)

        # node.replace_all_uses_with(dequantized_node)
        # graph.erase_node(node)
    from torch.fx.passes.graph_drawer import FxGraphDrawer

    g = FxGraphDrawer(model, "model_after_qdq_insertion")
    g.get_dot_graph().write_svg("model_after_qdq_insertion.svg")
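For intuition about what each inserted quantize/dequantize pair computes, here is a standalone per-tensor round trip with the same decomposed ops; the values are illustrative, and the explicit _decomposed import is assumed to be what registers the ops:

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  # registers torch.ops.quantized_decomposed.*

x = torch.randn(1, 3, 8, 8)
scale, zero_point = 0.05, 0
# Numerically, this is what a quantize -> dequantize pair inserted by _insert_qdq_to_model evaluates to.
q = torch.ops.quantized_decomposed.quantize_per_tensor(x, scale, zero_point, -128, 127, torch.int8)
dq = torch.ops.quantized_decomposed.dequantize_per_tensor(q, scale, zero_point, -128, 127, torch.int8)
print((x - dq).abs().max())  # quantization error, roughly bounded by scale / 2 inside the clipping range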