From baecd89279a1e524d94b1229e067ef4fd6437d9d Mon Sep 17 00:00:00 2001 From: Nuzhny007 Date: Sun, 25 Feb 2024 22:52:01 +0300 Subject: [PATCH] YOLOv9 works with TensorRT backend --- README.md | 6 ++---- combined/combined.cpp | 17 +++++++++++++++-- example/examples.h | 15 +++++++++++++-- src/Detector/OCVDNNDetector.cpp | 5 +++-- src/Detector/OCVDNNDetector.h | 3 ++- src/Detector/YoloTensorRTDetector.cpp | 1 + src/Detector/tensorrt_yolo/YoloONNX.cpp | 2 +- src/Detector/tensorrt_yolo/class_detector.cpp | 13 ++++++++++--- src/Detector/tensorrt_yolo/class_detector.h | 3 ++- src/Detector/tensorrt_yolo/ds_image.cpp | 6 ++++-- 10 files changed, 53 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 825dfbab..bfec439c 100644 --- a/README.md +++ b/README.md @@ -3,14 +3,12 @@ # Last changes +* YOLOv9 detector worked with TensorRT! Export pretrained Pytorch models [here (WongKinYiu/yolov9)](https://github.com/WongKinYiu/yolov9) to onnx format and run Multitarget-tracker with -e=6 example + * YOLOv8 instance segmentation models worked with TensorRT! Export pretrained Pytorch models [here (ultralytics/ultralytics)](https://github.com/ultralytics/ultralytics) to onnx format and run Multitarget-tracker with -e=6 example * Re-identification model osnet_x0_25_msmt17 from [mikel-brostrom/yolo_tracking](https://github.com/mikel-brostrom/yolo_tracking) -* YOLOv8 detector worked with TensorRT! Export pretrained Pytorch models [here (ultralytics/ultralytics)](https://github.com/ultralytics/ultralytics) to onnx format and run Multitarget-tracker with -e=6 example - -* Some experiments with YOLOv7_mask and results with rotated rectangles: detector works tracker in progress - # New videos! * YOLOv7 instance segmentation diff --git a/combined/combined.cpp b/combined/combined.cpp index edd9f18b..af521b49 100644 --- a/combined/combined.cpp +++ b/combined/combined.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -327,7 +327,8 @@ bool CombinedDetector::InitDetector(cv::UMat frame) YOLOv7, YOLOv7Mask, YOLOv8, - YOLOv8Mask + YOLOv8Mask, + YOLOv9 }; YOLOModels usedModel = YOLOModels::YOLOv8; switch (usedModel) @@ -439,6 +440,18 @@ bool CombinedDetector::InitDetector(cv::UMat frame) maxBatch = 1; configDNN.emplace("maxCropRatio", "-1"); break; + + case YOLOModels::YOLOv9: + configDNN.emplace("modelConfiguration", pathToModel + "yolov9-c.onnx"); + configDNN.emplace("modelBinary", pathToModel + "yolov9-c.onnx"); + configDNN.emplace("confidenceThreshold", "0.2"); + configDNN.emplace("inference_precision", "FP16"); + configDNN.emplace("net_type", "YOLOV9"); + configDNN.emplace("inWidth", "640"); + configDNN.emplace("inHeight", "640"); + maxBatch = 1; + configDNN.emplace("maxCropRatio", "-1"); + break; } configDNN.emplace("maxBatch", std::to_string(maxBatch)); configDNN.emplace("classNames", pathToModel + "coco.names"); diff --git a/example/examples.h b/example/examples.h index 8e96b57e..56f03bbd 100644 --- a/example/examples.h +++ b/example/examples.h @@ -834,9 +834,10 @@ class YoloTensorRTExample final : public VideoExample YOLOv7, YOLOv7Mask, YOLOv8, - YOLOv8Mask + YOLOv8Mask, + YOLOv9 }; - YOLOModels usedModel = YOLOModels::YOLOv4; + YOLOModels usedModel = YOLOModels::YOLOv9; switch (usedModel) { case YOLOModels::TinyYOLOv3: @@ -938,6 +939,16 @@ class YoloTensorRTExample final : public VideoExample maxBatch = 1; config.emplace("maxCropRatio", "-1"); break; + + case YOLOModels::YOLOv9: + config.emplace("modelConfiguration", pathToModel + "yolov9-c.onnx"); + config.emplace("modelBinary", pathToModel + "yolov9-c.onnx"); + config.emplace("confidenceThreshold", "0.2"); + config.emplace("inference_precision", "FP32"); + config.emplace("net_type", "YOLOV9"); + maxBatch = 1; + config.emplace("maxCropRatio", "-1"); + break; } if (maxBatch < m_batchSize) maxBatch = m_batchSize; diff --git a/src/Detector/OCVDNNDetector.cpp b/src/Detector/OCVDNNDetector.cpp index 945bb0c5..737b1227 100644 --- a/src/Detector/OCVDNNDetector.cpp +++ b/src/Detector/OCVDNNDetector.cpp @@ -139,6 +139,7 @@ bool OCVDNNDetector::Init(const config_t& config) dictNetType["YOLOV7Mask"] = ModelType::YOLOV7Mask; dictNetType["YOLOV8"] = ModelType::YOLOV8; dictNetType["YOLOV8Mask"] = ModelType::YOLOV8Mask; + dictNetType["YOLOV9"] = ModelType::YOLOV9; auto netType = dictNetType.find(net_type->second); if (netType != dictNetType.end()) @@ -345,7 +346,7 @@ void OCVDNNDetector::DetectInCrop(const cv::UMat& colorFrame, const cv::Rect& cr } else { - if (m_netType == ModelType::YOLOV8 || m_netType == ModelType::YOLOV5) + if (m_netType == ModelType::YOLOV8 || m_netType == ModelType::YOLOV5 || m_netType == ModelType::YOLOV9) { int rows = detections[0].size[1]; int dimensions = detections[0].size[2]; @@ -367,7 +368,7 @@ void OCVDNNDetector::DetectInCrop(const cv::UMat& colorFrame, const cv::Rect& cr for (int i = 0; i < rows; ++i) { - if (m_netType == ModelType::YOLOV8) + if (m_netType == ModelType::YOLOV8 || m_netType == ModelType::YOLOV9) { float* classes_scores = data + 4; diff --git a/src/Detector/OCVDNNDetector.h b/src/Detector/OCVDNNDetector.h index 7677f7e1..6a014379 100644 --- a/src/Detector/OCVDNNDetector.h +++ b/src/Detector/OCVDNNDetector.h @@ -39,7 +39,8 @@ class OCVDNNDetector final : public BaseDetector YOLOV7, YOLOV7Mask, YOLOV8, - YOLOV8Mask + YOLOV8Mask, + YOLOV9 }; cv::dnn::Net m_net; diff --git a/src/Detector/YoloTensorRTDetector.cpp b/src/Detector/YoloTensorRTDetector.cpp index 76c91ca4..399772d3 100644 --- a/src/Detector/YoloTensorRTDetector.cpp +++ b/src/Detector/YoloTensorRTDetector.cpp @@ -102,6 +102,7 @@ bool YoloTensorRTDetector::Init(const config_t& config) dictNetType["YOLOV7Mask"] = tensor_rt::YOLOV7Mask; dictNetType["YOLOV8"] = tensor_rt::YOLOV8; dictNetType["YOLOV8Mask"] = tensor_rt::YOLOV8Mask; + dictNetType["YOLOV9"] = tensor_rt::YOLOV9; auto netType = dictNetType.find(net_type->second); if (netType != dictNetType.end()) diff --git a/src/Detector/tensorrt_yolo/YoloONNX.cpp b/src/Detector/tensorrt_yolo/YoloONNX.cpp index 5d99bb67..b4883cf3 100644 --- a/src/Detector/tensorrt_yolo/YoloONNX.cpp +++ b/src/Detector/tensorrt_yolo/YoloONNX.cpp @@ -1087,7 +1087,7 @@ void YoloONNX::ProcessBBoxesOutput(size_t imgIdx, const std::vector& out } else if (outputs.size() == 1) { - if (m_params.m_netType == tensor_rt::ModelType::YOLOV8) + if (m_params.m_netType == tensor_rt::ModelType::YOLOV8 || m_params.m_netType == tensor_rt::ModelType::YOLOV9) { //0: name: images, size: 1x3x640x640 //1: name: output0, size: 1x84x8400 diff --git a/src/Detector/tensorrt_yolo/class_detector.cpp b/src/Detector/tensorrt_yolo/class_detector.cpp index ee5b9bda..e18837c7 100644 --- a/src/Detector/tensorrt_yolo/class_detector.cpp +++ b/src/Detector/tensorrt_yolo/class_detector.cpp @@ -56,7 +56,9 @@ namespace tensor_rt // Input tensor name of ONNX file & engine file if (config.net_type == ModelType::YOLOV6) m_params.inputTensorNames.push_back("image_arrays"); - else if (config.net_type == ModelType::YOLOV7 || config.net_type == ModelType::YOLOV7Mask || config.net_type == ModelType::YOLOV8 || config.net_type == ModelType::YOLOV8Mask) + else if (config.net_type == ModelType::YOLOV7 || config.net_type == ModelType::YOLOV7Mask || + config.net_type == ModelType::YOLOV8 || config.net_type == ModelType::YOLOV8Mask || + config.net_type == ModelType::YOLOV9) m_params.inputTensorNames.push_back("images"); // Threshold values @@ -76,7 +78,9 @@ namespace tensor_rt { m_params.outputTensorNames.push_back("outputs"); } - else if (config.net_type == ModelType::YOLOV7 || config.net_type == ModelType::YOLOV7Mask || config.net_type == ModelType::YOLOV8 || config.net_type == ModelType::YOLOV8Mask) + else if (config.net_type == ModelType::YOLOV7 || config.net_type == ModelType::YOLOV7Mask || + config.net_type == ModelType::YOLOV8 || config.net_type == ModelType::YOLOV8Mask || + config.net_type == ModelType::YOLOV9) { //if (config.batch_size == 1) //{ @@ -161,7 +165,10 @@ namespace tensor_rt if (m_impl) delete m_impl; - if (config.net_type == ModelType::YOLOV6 || config.net_type == ModelType::YOLOV7 || config.net_type == ModelType::YOLOV7Mask || config.net_type == ModelType::YOLOV8 || config.net_type == ModelType::YOLOV8Mask) + if (config.net_type == ModelType::YOLOV6 || + config.net_type == ModelType::YOLOV7 || config.net_type == ModelType::YOLOV7Mask || + config.net_type == ModelType::YOLOV8 || config.net_type == ModelType::YOLOV8Mask || + config.net_type == ModelType::YOLOV9) m_impl = new YoloONNXImpl(); else m_impl = new YoloDectectorImpl(); diff --git a/src/Detector/tensorrt_yolo/class_detector.h b/src/Detector/tensorrt_yolo/class_detector.h index b0652856..a6e869b6 100644 --- a/src/Detector/tensorrt_yolo/class_detector.h +++ b/src/Detector/tensorrt_yolo/class_detector.h @@ -51,7 +51,8 @@ namespace tensor_rt YOLOV7, YOLOV7Mask, YOLOV8, - YOLOV8Mask + YOLOV8Mask, + YOLOV9 }; /// diff --git a/src/Detector/tensorrt_yolo/ds_image.cpp b/src/Detector/tensorrt_yolo/ds_image.cpp index 6f3843a8..1d1c4d9a 100644 --- a/src/Detector/tensorrt_yolo/ds_image.cpp +++ b/src/Detector/tensorrt_yolo/ds_image.cpp @@ -49,7 +49,8 @@ DsImage::DsImage(const cv::Mat& mat_image_, tensor_rt::ModelType net_type, const } if (tensor_rt::ModelType::YOLOV5 == net_type || tensor_rt::ModelType::YOLOV6 == net_type || tensor_rt::ModelType::YOLOV7 == net_type || tensor_rt::ModelType::YOLOV7Mask == net_type || - tensor_rt::ModelType::YOLOV8 == net_type || tensor_rt::ModelType::YOLOV8Mask == net_type) + tensor_rt::ModelType::YOLOV8 == net_type || tensor_rt::ModelType::YOLOV8Mask == net_type || + tensor_rt::ModelType::YOLOV9 == net_type) { // resize the DsImage with scale float r = std::min(static_cast(inputH) / static_cast(m_Height), static_cast(inputW) / static_cast(m_Width)); @@ -99,7 +100,8 @@ DsImage::DsImage(const std::string& path, tensor_rt::ModelType net_type, const i if (tensor_rt::ModelType::YOLOV5 == net_type || tensor_rt::ModelType::YOLOV6 == net_type || tensor_rt::ModelType::YOLOV7 == net_type || tensor_rt::ModelType::YOLOV7Mask == net_type || - tensor_rt::ModelType::YOLOV8 == net_type || tensor_rt::ModelType::YOLOV8Mask == net_type) + tensor_rt::ModelType::YOLOV8 == net_type || tensor_rt::ModelType::YOLOV8Mask == net_type || + tensor_rt::ModelType::YOLOV9 == net_type) { // resize the DsImage with scale float dim = std::max(m_Height, m_Width);