forked from openvinotoolkit/nncf
Commit 16b3126 (1 parent: d45a55f): ssd_vgg300, ssdlite mobilenetv3, yolov8 repro
Showing 3 changed files with 297 additions and 26 deletions.

@@ -25,6 +25,7 @@
import openvino.torch  # noqa
import torch
from torch._export import capture_pre_autograd_graph
from torch.export import Dim  # noqa
from torch.fx.passes.graph_drawer import FxGraphDrawer
from tqdm import tqdm
from ultralytics.cfg import get_cfg

@@ -35,6 +36,7 @@
from ultralytics.utils import DATASETS_DIR
from ultralytics.utils import DEFAULT_CFG
from ultralytics.utils.metrics import ConfusionMatrix
from ultralytics.utils.torch_utils import de_parallel

import nncf


@@ -53,13 +55,24 @@ def measure_time(model, example_inputs, num_iters=500):
    return average_time
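
# measure_time() is defined above this hunk; only its return is visible here.
# A minimal sketch of what a latency helper with this signature typically does
# (an assumption, not the actual body from this file): warm up, then average
# the wall-clock time of num_iters forward passes.
#
#     import time
#
#     def measure_time(model, example_inputs, num_iters=500):
#         with torch.no_grad():
#             for _ in range(10):  # warm-up runs, excluded from timing
#                 model(*example_inputs)
#             start = time.perf_counter()
#             for _ in range(num_iters):
#                 model(*example_inputs)
#         average_time = (time.perf_counter() - start) / num_iters
#         return average_time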


def validate_fx_ult_method(model: ov.Model) -> Tuple[Dict, int, int]:
    """
    Uses the ultralytics .val() method instead of a manual dataloader loop.
    For some reason this shows better metrics on torch.compile'd models.
    """
    yolo = YOLO(f"{ROOT}/{MODEL_NAME}.pt")
    yolo.model = model
    result = yolo.val(data="coco128.yaml", batch=1, rect=False)
    return result.results_dict
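
# A minimal usage sketch for the helper above (illustrative only; the variable
# names are assumptions): hand a torch.compile'd model to ultralytics and let
# .val() drive the COCO128 evaluation.
#
#     compiled = torch.compile(YOLO(f"{ROOT}/{MODEL_NAME}.pt").model, backend="openvino")
#     stats = validate_fx_ult_method(compiled)
#     print(stats["metrics/mAP50-95(B)"])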


def validate_fx(
    model: ov.Model, data_loader: torch.utils.data.DataLoader, validator: Validator, num_samples: int = None
) -> Tuple[Dict, int, int]:
    validator.seen = 0
    validator.jdict = []
    validator.stats = []
    validator.confusion_matrix = ConfusionMatrix(nc=validator.nc)
    # validator.seen = 0
    # validator.jdict = []
    # validator.stats = []
    # validator.confusion_matrix = ConfusionMatrix(nc=validator.nc)
    for batch_i, batch in enumerate(data_loader):
        if num_samples is not None and batch_i == num_samples:
            break

@@ -71,7 +84,20 @@ def validate_fx(
    return stats, validator.seen, validator.nt_per_class.sum()
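
# The per-batch body of validate_fx() sits between these two hunks and is not
# shown. With the ultralytics Validator API it is typically along these lines
# (an assumption, not the code from this commit):
#
#         batch = validator.preprocess(batch)
#         preds = model(batch["img"])
#         preds = validator.postprocess(preds)
#         validator.update_metrics(preds, batch)
#     stats = validator.get_stats()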


def validate(
def print_statistics_short(stats: np.ndarray) -> None:
    mp, mr, map50, mean_ap = (
        stats["metrics/precision(B)"],
        stats["metrics/recall(B)"],
        stats["metrics/mAP50(B)"],
        stats["metrics/mAP50-95(B)"],
    )
    s = ("%20s" + "%12s" * 4) % ("Class", "Precision", "Recall", "mAP@.5", "mAP@.5:.95")
    print(s)
    pf = "%20s" + "%12.3g" * 4  # print format
    print(pf % ("all", mp, mr, map50, mean_ap))
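
# For reference, the two prints above produce a table of this shape (the values
# below are placeholders, not measured results):
#
#                    Class   Precision      Recall      mAP@.5  mAP@.5:.95
#                      all        <mp>        <mr>     <map50>   <mean_ap>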


def validate_ov(
    model: ov.Model, data_loader: torch.utils.data.DataLoader, validator: Validator, num_samples: int = None
) -> Tuple[Dict, int, int]:
    validator.seen = 0

@@ -105,6 +131,23 @@ def print_statistics(stats: np.ndarray, total_images: int, total_objects: int) -
    print(pf % ("all", total_images, total_objects, mp, mr, map50, mean_ap))


def prepare_validation_new(model: YOLO, data: str) -> Tuple[Validator, torch.utils.data.DataLoader]:
    # custom = {"rect": True, "batch": 1}  # method defaults
    # rect=False forces all input pictures to be resized to one size
    custom = {"rect": False, "batch": 1}  # method defaults
    args = {**model.overrides, **custom, "mode": "val"}  # highest priority args on the right

    validator = model._smart_load("validator")(args=args, _callbacks=model.callbacks)
    stride = 32  # default stride
    validator.stride = stride  # used in get_dataloader() for padding
    validator.data = check_det_dataset(data)
    validator.init_metrics(de_parallel(model))

    data_loader = validator.get_dataloader(validator.data.get(validator.args.split), validator.args.batch)

    return validator, data_loader


def prepare_validation(model: YOLO, args: Any) -> Tuple[Validator, torch.utils.data.DataLoader]:
    validator = model.smart_load("validator")(args)
    validator.data = check_det_dataset(args.data)

@@ -236,49 +279,65 @@ def transform_fn(x):


TORCH_FX = True
MODEL_NAME = "yolov8n"


def main():
    MODEL_NAME = "yolov8n"

    model = YOLO(f"{ROOT}/{MODEL_NAME}.pt")
    args = get_cfg(cfg=DEFAULT_CFG)
    args.data = "coco128.yaml"

    # args = get_cfg(cfg=DEFAULT_CFG)
    # args.data = "coco128.yaml"
    # Prepare validation dataset and helper
    validator, data_loader = prepare_validation(model, args)

    validator, data_loader = prepare_validation_new(model, "coco128.yaml")

    # Convert to OpenVINO model
    if TORCH_FX:
        batch = next(iter(data_loader))
        batch = validator.preprocess(batch)

        fp_stats, total_images, total_objects = validate_fx(model.model, tqdm(data_loader), validator)
        print("Floating-point Torch model validation results:")
        print_statistics(fp_stats, total_images, total_objects)

        fp32_compiled_model = torch.compile(model.model, backend="openvino")
        fp32_stats, total_images, total_objects = validate_fx(fp32_compiled_model, tqdm(data_loader), validator)
        print("FP32 FX model validation results:")
        print_statistics(fp32_stats, total_images, total_objects)

        # result = validate_fx_ult_method(fp32_compiled_model)
        # print("FX FP32 model .val validation")
        # print_statistics_short(result)

        print("Start quantization...")
        # Rebuild model to reset the ultralytics cache
        model = YOLO(f"{ROOT}/{MODEL_NAME}.pt")
        with torch.no_grad():
            # fp_stats, total_images, total_object = validate(model.model, tqdm(data_loader), validator)
            # print("Floating-point model validation results:")
            # print_statistics(fp_stats, total_images, total_objects)
            model.model.eval()
            model.model(batch["img"])
            exported_model = capture_pre_autograd_graph(model.model, args=(batch["img"],))
            # dynamic_shapes = ((None, None, Dim("H", min=1, max=29802), Dim("W", min=1, max=29802)),)
            dynamic_shapes = ((None, None, None, None),)
            exported_model = capture_pre_autograd_graph(
                model.model, args=(batch["img"],), dynamic_shapes=dynamic_shapes
            )
            quantized_model = quantize_impl(deepcopy(exported_model), data_loader, validator)
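
            # quantize_impl() is defined outside this hunk. A minimal sketch of an
            # NNCF post-training quantization call on the captured graph (an assumed
            # shape using the file's transform_fn, not the helper's actual body):
            #
            #     calibration_dataset = nncf.Dataset(data_loader, transform_fn)
            #     quantized_model = nncf.quantize(exported_model, calibration_dataset)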

        fp32_compiled_model = torch.compile(exported_model, backend="openvino")
        fp32_stats, total_images, total_objects = validate_fx(fp32_compiled_model, tqdm(data_loader), validator)
        # fp32_stats, total_images, total_objects = validate_fx(model.model, tqdm(data_loader), validator)
        print("FP32 model validation results:")
        print_statistics(fp32_stats, total_images, total_objects)
        # result = validate_fx_ult_method(quantized_model)
        # print("FX INT8 model .val validation")
        # print_statistics_short(result)

        int8_stats, total_images, total_objects = validate_fx(quantized_model, tqdm(data_loader), validator)
        print("INT8 model validation results:")
        print("INT8 FX model validation results:")
        print_statistics(int8_stats, total_images, total_objects)

        print("Start fp32 model benchmarking...")
        print("Start FX fp32 model benchmarking...")
        fp32_latency = measure_time(fp32_compiled_model, (batch["img"],))
        print(f"fp32 latency: {fp32_latency}")
        print(f"fp32 FX latency: {fp32_latency}")

        print("Start int8 model benchmarking...")
        print("Start FX int8 model benchmarking...")
        int8_latency = measure_time(quantized_model, (batch["img"],))
        print(f"int8 latency: {int8_latency}")
        print(f"FX int8 latency: {int8_latency}")
        print(f"Speed up: {fp32_latency / int8_latency}")
        return

@@ -289,13 +348,15 @@ def main():
    quantized_model_path = Path(f"{ROOT}/{MODEL_NAME}_openvino_model/{MODEL_NAME}_quantized.xml")
    ov.save_model(quantized_model, str(quantized_model_path), compress_to_fp16=False)
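
    # Once saved, the quantized IR can be reloaded and compiled with the OpenVINO
    # runtime. A minimal sketch (device name, input variable, and output handling
    # are assumptions, not part of this commit):
    #
    #     core = ov.Core()
    #     reloaded = core.read_model(quantized_model_path)
    #     compiled = core.compile_model(reloaded, "CPU")
    #     result = compiled(example_input)[compiled.output(0)]  # example_input: assumed NCHW numpy array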

    args = get_cfg(cfg=DEFAULT_CFG)
    args.data = "coco128.yaml"
    # Validate FP32 model
    fp_stats, total_images, total_objects = validate(ov_model, tqdm(data_loader), validator)
    fp_stats, total_images, total_objects = validate_ov(ov_model, tqdm(data_loader), validator)
    print("Floating-point model validation results:")
    print_statistics(fp_stats, total_images, total_objects)

    # Validate quantized model
    q_stats, total_images, total_objects = validate(quantized_model, tqdm(data_loader), validator)
    q_stats, total_images, total_objects = validate_ov(quantized_model, tqdm(data_loader), validator)
    print("Quantized model validation results:")
    print_statistics(q_stats, total_images, total_objects)