diff --git a/README.md b/README.md
index 29afabd1..a801234c 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,8 @@
 # Last changes
 
+* YOLOv8-obb detector works with TensorRT! Export pretrained PyTorch models [here (ultralytics/ultralytics)](https://github.com/ultralytics/ultralytics) to ONNX format and run Multitarget-tracker with the -e=6 example
+
 * YOLOv10 detector works with TensorRT! Export pretrained PyTorch models [here (THU-MIG/yolov10)](https://github.com/THU-MIG/yolov10) to ONNX format and run Multitarget-tracker with the -e=6 example
 
 * YOLOv9 detector works with TensorRT! Export pretrained PyTorch models [here (WongKinYiu/yolov9)](https://github.com/WongKinYiu/yolov9) to ONNX format and run Multitarget-tracker with the -e=6 example
 
@@ -13,6 +15,11 @@
 # New videos!
 
+* YOLOv8-obb detection with rotated boxes (DOTA v1.0 trained)
+
+[![YOLOv8-obb detection:](https://img.youtube.com/vi/1e6ur57Fhzs/0.jpg)](https://youtu.be/1e6ur57Fhzs)
+
+
 * YOLOv7 instance segmentation
 
 [![YOLOv7 instance segmentation:](https://img.youtube.com/vi/gZxuYyFz1dU/0.jpg)](https://youtu.be/gZxuYyFz1dU)
 
diff --git a/combined/combined.cpp b/combined/combined.cpp
index af521b49..bfcfdbf2 100644
--- a/combined/combined.cpp
+++ b/combined/combined.cpp
@@ -328,6 +328,7 @@ bool CombinedDetector::InitDetector(cv::UMat frame)
         YOLOv7Mask,
         YOLOv8,
         YOLOv8Mask,
+        YOLOV8_OBB,
         YOLOv9
     };
     YOLOModels usedModel = YOLOModels::YOLOv8;
diff --git a/data/DOTA.names b/data/DOTA.names
new file mode 100644
index 00000000..adea7619
--- /dev/null
+++ b/data/DOTA.names
@@ -0,0 +1,15 @@
+plane
+ship
+storage_tank
+baseball_diamond
+tennis_court
+basketball_court
+ground_track_field
+harbor
+bridge
+large_vehicle
+small_vehicle
+helicopter
+roundabout
+soccer_ball_field
+swimming_pool
\ No newline at end of file
diff --git a/data/coco/full.names b/data/coco/full.names
new file mode 100644
index 00000000..ca76c80b
--- /dev/null
+++ b/data/coco/full.names
@@ -0,0 +1,80 @@
+person
+bicycle
+car
+motorbike
+aeroplane
+bus
+train
+truck
+boat
+traffic light
+fire hydrant
+stop sign
+parking meter
+bench
+bird
+cat
+dog
+horse
+sheep
+cow
+elephant
+bear
+zebra
+giraffe
+backpack
+umbrella
+handbag
+tie
+suitcase
+frisbee
+skis
+snowboard
+sports ball
+kite
+baseball bat
+baseball glove
+skateboard
+surfboard
+tennis racket
+bottle
+wine glass
+cup
+fork
+knife
+spoon
+bowl
+banana
+apple
+sandwich
+orange
+broccoli
+carrot
+hot dog
+pizza
+donut
+cake
+chair
+sofa
+pottedplant
+bed
+diningtable
+toilet
+tvmonitor
+laptop
+mouse
+remote
+keyboard
+cell phone
+microwave
+oven
+toaster
+sink
+refrigerator
+book
+clock
+vase
+scissors
+teddy bear
+hair drier
+toothbrush
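The two `.names` files above are plain text: one class label per line, where the line index corresponds to the class id emitted by the detector. A minimal loader sketch under that assumption (`LoadNames` is a hypothetical helper, not a function from this repo):

```cpp
// Minimal sketch: load a .names file (one class label per line) so that
// the vector index matches the class id produced by the detector.
// Assumes only newline-separated labels, as in DOTA.names above.
#include <fstream>
#include <string>
#include <vector>

std::vector<std::string> LoadNames(const std::string& path)
{
    std::vector<std::string> names;
    std::ifstream file(path);
    for (std::string line; std::getline(file, line);)
    {
        if (!line.empty() && line.back() == '\r') // tolerate CRLF files
            line.pop_back();
        if (!line.empty())
            names.push_back(line); // names[classId] is the human-readable label
    }
    return names;
}
```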
diff --git a/data/settings_yolov10.ini b/data/settings_yolov10.ini
new file mode 100644
index 00000000..950982df
--- /dev/null
+++ b/data/settings_yolov10.ini
@@ -0,0 +1,142 @@
+[detection]
+
+#-----------------------------
+# opencv_dnn = 12
+# darknet_cudnn = 10
+# tensorrt = 11
+detector_backend = 12
+
+#-----------------------------
+# Target and backend for opencv_dnn detector
+# DNN_TARGET_CPU
+# DNN_TARGET_OPENCL
+# DNN_TARGET_OPENCL_FP16
+# DNN_TARGET_MYRIAD
+# DNN_TARGET_CUDA
+# DNN_TARGET_CUDA_FP16
+ocv_dnn_target = DNN_TARGET_CPU
+
+# DNN_BACKEND_DEFAULT
+# DNN_BACKEND_HALIDE
+# DNN_BACKEND_INFERENCE_ENGINE
+# DNN_BACKEND_OPENCV
+# DNN_BACKEND_VKCOM
+# DNN_BACKEND_CUDA
+# DNN_BACKEND_INFERENCE_ENGINE_NGRAPH
+# DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019
+ocv_dnn_backend = DNN_BACKEND_OPENCV
+
+#-----------------------------
+nn_weights = C:/work/home/mtracker/Multitarget-tracker/data/coco/yolov10s.onnx
+nn_config = C:/work/home/mtracker/Multitarget-tracker/data/coco/yolov10s.onnx
+class_names = C:/work/home/mtracker/Multitarget-tracker/data/coco/coco.names
+
+#-----------------------------
+confidence_threshold = 0.3
+
+max_crop_ratio = 0
+max_batch = 1
+gpu_id = 0
+
+#-----------------------------
+# YOLOV3
+# YOLOV4
+# YOLOV5
+net_type = YOLOV10
+
+#-----------------------------
+# INT8
+# FP16
+# FP32
+inference_precision = FP16
+
+
+[tracking]
+
+#-----------------------------
+# DistCenters = 0  // Euclidean distance between centers, pixels
+# DistRects = 1    // Euclidean distance between bounding rectangles, pixels
+# DistJaccard = 2  // Intersection over Union, IoU, [0, 1]
+# DistHist = 3     // Bhattacharyya distance between histograms, [0, 1]
+
+distance_type = 0
+
+#-----------------------------
+# KalmanLinear = 0
+# KalmanUnscented = 1
+
+kalman_type = 0
+
+#-----------------------------
+# FilterCenter = 0
+# FilterRect = 1
+# FilterRRect = 2
+
+filter_goal = 0
+
+#-----------------------------
+# TrackNone = 0
+# TrackKCF = 1
+# TrackMIL = 2
+# TrackMedianFlow = 3
+# TrackGOTURN = 4
+# TrackMOSSE = 5
+# TrackCSRT = 6
+# TrackDAT = 7
+# TrackSTAPLE = 8
+# TrackLDES = 9
+# TrackDaSiamRPN = 10
+# Used if filter_goal == FilterRect
+
+lost_track_type = 0
+
+#-----------------------------
+# MatchHungrian = 0
+# MatchBipart = 1
+
+match_type = 0
+
+#-----------------------------
+# Use constant acceleration motion model:
+# 0 - unused (stable)
+# 1 - use acceleration in Kalman filter (experimental)
+use_aceleration = 0
+
+#-----------------------------
+# Delta time for Kalman filter
+delta_time = 0.4
+
+#-----------------------------
+# Accel noise magnitude for Kalman filter
+accel_noise = 0.2
+
+#-----------------------------
+# Distance threshold between region and object on two frames
+dist_thresh = 0.8
+
+#-----------------------------
+# If this value > 0, a circle with this radius is used
+# If this value <= 0, an ellipse of size (3*vx, 3*vy) is used, where vx and vy are the horizontal and vertical speed in pixels
+min_area_radius_pix = -1
+
+#-----------------------------
+# Minimal area radius as a ratio of the object size. Used if min_area_radius_pix < 0
+min_area_radius_k = 0.8
+
+#-----------------------------
+# If the object has no assignment for more than this number of frames, it will be removed
+max_skip_frames = 50
+
+#-----------------------------
+# The maximum trajectory length
+max_trace_len = 50
+
+#-----------------------------
+# Detection of abandoned objects
+detect_abandoned = 0
+# After this time (in seconds) the object is considered abandoned
+min_static_time = 5
+# After this time (in seconds) the abandoned object will be removed
+max_static_time = 25
+# Speed in pixels. If the object moves faster than this value, it is considered non-static
+max_speed_for_static = 10
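For reference, the same detector settings can also be assembled programmatically as the key/value map the detectors' `Init()` consumes — the `config.emplace(...)` calls in `example/examples.h` below use exactly these keys. A sketch, assuming `config_t` is a string-to-string multimap (as the `emplace`/`find` usage in the sources suggests); `MakeYolov10Config` is a hypothetical helper:

```cpp
// Hypothetical sketch: build the equivalent of settings_yolov10.ini in code.
// Keys mirror the .ini above and the emplace() calls in example/examples.h.
#include <map>
#include <string>

using config_t = std::multimap<std::string, std::string>; // assumption about the repo's typedef

config_t MakeYolov10Config(const std::string& dataDir)
{
    config_t config;
    config.emplace("modelConfiguration", dataDir + "coco/yolov10s.onnx");
    config.emplace("modelBinary", dataDir + "coco/yolov10s.onnx");
    config.emplace("confidenceThreshold", "0.3");
    config.emplace("inference_precision", "FP16");
    config.emplace("net_type", "YOLOV10");
    config.emplace("maxCropRatio", "-1"); // <= 0: detect on the whole frame, no crops
    return config;
}
```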
diff --git a/data/settings_yolov8_obb.ini b/data/settings_yolov8_obb.ini
new file mode 100644
index 00000000..ea95b634
--- /dev/null
+++ b/data/settings_yolov8_obb.ini
@@ -0,0 +1,142 @@
+[detection]
+
+#-----------------------------
+# opencv_dnn = 12
+# darknet_cudnn = 10
+# tensorrt = 11
+detector_backend = 11
+
+#-----------------------------
+# Target and backend for opencv_dnn detector
+# DNN_TARGET_CPU
+# DNN_TARGET_OPENCL
+# DNN_TARGET_OPENCL_FP16
+# DNN_TARGET_MYRIAD
+# DNN_TARGET_CUDA
+# DNN_TARGET_CUDA_FP16
+ocv_dnn_target = DNN_TARGET_CPU
+
+# DNN_BACKEND_DEFAULT
+# DNN_BACKEND_HALIDE
+# DNN_BACKEND_INFERENCE_ENGINE
+# DNN_BACKEND_OPENCV
+# DNN_BACKEND_VKCOM
+# DNN_BACKEND_CUDA
+# DNN_BACKEND_INFERENCE_ENGINE_NGRAPH
+# DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019
+ocv_dnn_backend = DNN_BACKEND_OPENCV
+
+#-----------------------------
+nn_weights = C:/work/home/mtracker/Multitarget-tracker/data/yolov8x-obb.onnx
+nn_config = C:/work/home/mtracker/Multitarget-tracker/data/yolov8x-obb.onnx
+class_names = C:/work/home/mtracker/Multitarget-tracker/data/DOTA.names
+
+#-----------------------------
+confidence_threshold = 0.6
+
+max_crop_ratio = 1
+max_batch = 1
+gpu_id = 0
+
+#-----------------------------
+# YOLOV3
+# YOLOV4
+# YOLOV5
+net_type = YOLOV8_OBB
+
+#-----------------------------
+# INT8
+# FP16
+# FP32
+inference_precision = FP16
+
+
+[tracking]
+
+#-----------------------------
+# DistCenters = 0  // Euclidean distance between centers, pixels
+# DistRects = 1    // Euclidean distance between bounding rectangles, pixels
+# DistJaccard = 2  // Intersection over Union, IoU, [0, 1]
+# DistHist = 3     // Bhattacharyya distance between histograms, [0, 1]
+
+distance_type = 0
+
+#-----------------------------
+# KalmanLinear = 0
+# KalmanUnscented = 1
+
+kalman_type = 0
+
+#-----------------------------
+# FilterCenter = 0
+# FilterRect = 1
+# FilterRRect = 2
+
+filter_goal = 0
+
+#-----------------------------
+# TrackNone = 0
+# TrackKCF = 1
+# TrackMIL = 2
+# TrackMedianFlow = 3
+# TrackGOTURN = 4
+# TrackMOSSE = 5
+# TrackCSRT = 6
+# TrackDAT = 7
+# TrackSTAPLE = 8
+# TrackLDES = 9
+# TrackDaSiamRPN = 10
+# Used if filter_goal == FilterRect
+
+lost_track_type = 0
+
+#-----------------------------
+# MatchHungrian = 0
+# MatchBipart = 1
+
+match_type = 0
+
+#-----------------------------
+# Use constant acceleration motion model:
+# 0 - unused (stable)
+# 1 - use acceleration in Kalman filter (experimental)
+use_aceleration = 0
+
+#-----------------------------
+# Delta time for Kalman filter
+delta_time = 0.4
+
+#-----------------------------
+# Accel noise magnitude for Kalman filter
+accel_noise = 0.2
+
+#-----------------------------
+# Distance threshold between region and object on two frames
+dist_thresh = 0.8
+
+#-----------------------------
+# If this value > 0, a circle with this radius is used
+# If this value <= 0, an ellipse of size (3*vx, 3*vy) is used, where vx and vy are the horizontal and vertical speed in pixels
+min_area_radius_pix = -1
+
+#-----------------------------
+# Minimal area radius as a ratio of the object size. Used if min_area_radius_pix < 0
+min_area_radius_k = 0.8
+
+#-----------------------------
+# If the object has no assignment for more than this number of frames, it will be removed
+max_skip_frames = 50
+
+#-----------------------------
+# The maximum trajectory length
+max_trace_len = 50
+
+#-----------------------------
+# Detection of abandoned objects
+detect_abandoned = 0
+# After this time (in seconds) the object is considered abandoned
+min_static_time = 5
+# After this time (in seconds) the abandoned object will be removed
+max_static_time = 25
+# Speed in pixels. If the object moves faster than this value, it is considered non-static
+max_speed_for_static = 10
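Note that this OBB config keeps `distance_type = 0` (center distance), which carries over to rotated boxes unchanged, since only box centers are compared. If a Jaccard-style (IoU) distance were ever wanted for oriented boxes, OpenCV can intersect two `cv::RotatedRect`s directly. An illustrative sketch only — not something the tracker is shown to do here:

```cpp
// Illustration: IoU for rotated boxes via OpenCV's polygon intersection.
// A possible DistJaccard analogue for cv::RotatedRect regions.
#include <opencv2/imgproc.hpp>
#include <vector>

float RotatedIoU(const cv::RotatedRect& a, const cv::RotatedRect& b)
{
    std::vector<cv::Point2f> inter;
    // Returns INTERSECT_NONE / INTERSECT_PARTIAL / INTERSECT_FULL
    if (cv::rotatedRectangleIntersection(a, b, inter) == cv::INTERSECT_NONE || inter.size() < 3)
        return 0.f;
    const float interArea = static_cast<float>(cv::contourArea(inter));
    const float unionArea = a.size.area() + b.size.area() - interArea;
    return (unionArea > 0.f) ? interArea / unionArea : 0.f;
}
```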
diff --git a/data/settings_yolov9.ini b/data/settings_yolov9.ini
new file mode 100644
index 00000000..d34a44ea
--- /dev/null
+++ b/data/settings_yolov9.ini
@@ -0,0 +1,142 @@
+[detection]
+
+#-----------------------------
+# opencv_dnn = 12
+# darknet_cudnn = 10
+# tensorrt = 11
+detector_backend = 12
+
+#-----------------------------
+# Target and backend for opencv_dnn detector
+# DNN_TARGET_CPU
+# DNN_TARGET_OPENCL
+# DNN_TARGET_OPENCL_FP16
+# DNN_TARGET_MYRIAD
+# DNN_TARGET_CUDA
+# DNN_TARGET_CUDA_FP16
+ocv_dnn_target = DNN_TARGET_CUDA_FP16
+
+# DNN_BACKEND_DEFAULT
+# DNN_BACKEND_HALIDE
+# DNN_BACKEND_INFERENCE_ENGINE
+# DNN_BACKEND_OPENCV
+# DNN_BACKEND_VKCOM
+# DNN_BACKEND_CUDA
+# DNN_BACKEND_INFERENCE_ENGINE_NGRAPH
+# DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019
+ocv_dnn_backend = DNN_BACKEND_CUDA
+
+#-----------------------------
+nn_weights = C:/work/home/mtracker/Multitarget-tracker/data/coco/yolov9-e.onnx
+nn_config = C:/work/home/mtracker/Multitarget-tracker/data/coco/yolov9-e.onnx
+class_names = C:/work/home/mtracker/Multitarget-tracker/data/coco/coco.names
+
+#-----------------------------
+confidence_threshold = 0.3
+
+max_crop_ratio = 0
+max_batch = 1
+gpu_id = 0
+
+#-----------------------------
+# YOLOV3
+# YOLOV4
+# YOLOV5
+net_type = YOLOV9
+
+#-----------------------------
+# INT8
+# FP16
+# FP32
+inference_precision = FP16
+
+
+[tracking]
+
+#-----------------------------
+# DistCenters = 0  // Euclidean distance between centers, pixels
+# DistRects = 1    // Euclidean distance between bounding rectangles, pixels
+# DistJaccard = 2  // Intersection over Union, IoU, [0, 1]
+# DistHist = 3     // Bhattacharyya distance between histograms, [0, 1]
+
+distance_type = 0
+
+#-----------------------------
+# KalmanLinear = 0
+# KalmanUnscented = 1
+
+kalman_type = 0
+
+#-----------------------------
+# FilterCenter = 0
+# FilterRect = 1
+# FilterRRect = 2
+
+filter_goal = 0
+
+#-----------------------------
+# TrackNone = 0
+# TrackKCF = 1
+# TrackMIL = 2
+# TrackMedianFlow = 3
+# TrackGOTURN = 4
+# TrackMOSSE = 5
+# TrackCSRT = 6
+# TrackDAT = 7
+# TrackSTAPLE = 8
+# TrackLDES = 9
+# TrackDaSiamRPN = 10
+# Used if filter_goal == FilterRect
+
+lost_track_type = 0
+
+#-----------------------------
+# MatchHungrian = 0
+# MatchBipart = 1
+
+match_type = 0
+
+#-----------------------------
+# Use constant acceleration motion model:
+# 0 - unused (stable)
+# 1 - use acceleration in Kalman filter (experimental)
+use_aceleration = 0
+
+#-----------------------------
+# Delta time for Kalman filter
+delta_time = 0.4
+
+#-----------------------------
+# Accel noise magnitude for Kalman filter
+accel_noise = 0.2
+
+#-----------------------------
+# Distance threshold between region and object on two frames
+dist_thresh = 0.8
+
+#-----------------------------
+# If this value > 0, a circle with this radius is used
+# If this value <= 0, an ellipse of size (3*vx, 3*vy) is used, where vx and vy are the horizontal and vertical speed in pixels
+min_area_radius_pix = -1
+
+#-----------------------------
+# Minimal area radius as a ratio of the object size. Used if min_area_radius_pix < 0
+min_area_radius_k = 0.8
+
+#-----------------------------
+# If the object has no assignment for more than this number of frames, it will be removed
+max_skip_frames = 50
+
+#-----------------------------
+# The maximum trajectory length
+max_trace_len = 50
+
+#-----------------------------
+# Detection of abandoned objects
+detect_abandoned = 0
+# After this time (in seconds) the object is considered abandoned
+min_static_time = 5
+# After this time (in seconds) the abandoned object will be removed
+max_static_time = 25
+# Speed in pixels. If the object moves faster than this value, it is considered non-static
+max_speed_for_static = 10
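Unlike the other two configs, this one runs through the opencv_dnn backend (`detector_backend = 12`) with a CUDA target. In plain OpenCV terms, the `ocv_dnn_backend`/`ocv_dnn_target` pair corresponds roughly to the calls below; this is a sketch of the mapping, not the repo's actual `OCVDNNDetector` code:

```cpp
// Sketch: what the opencv_dnn settings above translate to in OpenCV calls.
// CUDA backend/targets require OpenCV built with CUDA support.
#include <opencv2/dnn.hpp>
#include <string>

cv::dnn::Net LoadYoloNet(const std::string& onnxPath, bool useCuda)
{
    cv::dnn::Net net = cv::dnn::readNetFromONNX(onnxPath); // nn_weights / nn_config
    if (useCuda)
    {
        net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);      // ocv_dnn_backend
        net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA_FP16);   // ocv_dnn_target
    }
    else
    {
        net.setPreferableBackend(cv::dnn::DNN_BACKEND_OPENCV);
        net.setPreferableTarget(cv::dnn::DNN_TARGET_CPU);
    }
    return net;
}
```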
diff --git a/example/examples.h b/example/examples.h
index 62861622..1be76399 100644
--- a/example/examples.h
+++ b/example/examples.h
@@ -649,6 +649,7 @@ class YoloTensorRTExample final : public VideoExample
         YOLOv7,
         YOLOv7Mask,
         YOLOv8,
+        YOLOV8_OBB,
         YOLOv8Mask,
         YOLOv9,
         YOLOv10
@@ -746,6 +747,16 @@ class YoloTensorRTExample final : public VideoExample
             config.emplace("maxCropRatio", "-1");
             break;
 
+        case YOLOModels::YOLOV8_OBB:
+            config.emplace("modelConfiguration", pathToModel + "yolov8s-obb.onnx");
+            config.emplace("modelBinary", pathToModel + "yolov8s-obb.onnx");
+            config.emplace("confidenceThreshold", "0.2");
+            config.emplace("inference_precision", "FP16");
+            config.emplace("net_type", "YOLOV8_OBB");
+            maxBatch = 1;
+            config.emplace("maxCropRatio", "-1");
+            break;
+
         case YOLOModels::YOLOv8Mask:
             config.emplace("modelConfiguration", pathToModel + "yolov8s-seg.onnx");
             config.emplace("modelBinary", pathToModel + "yolov8s-seg.onnx");
@@ -912,7 +923,7 @@ class YoloTensorRTExample final : public VideoExample
             }
         }
 
-        m_detector->CalcMotionMap(frame);
+        //m_detector->CalcMotionMap(frame);
     }
 };
 
diff --git a/src/Detector/BaseDetector.h b/src/Detector/BaseDetector.h
index fc60b6ba..c585426e 100644
--- a/src/Detector/BaseDetector.h
+++ b/src/Detector/BaseDetector.h
@@ -82,7 +82,7 @@ class BaseDetector
         for (size_t i = 0; i < frames.size(); ++i)
         {
             Detect(frames[i]);
-            auto res = GetDetects();
+            const auto& res = GetDetects();
             regions[i].assign(std::begin(res), std::end(res));
         }
     }
diff --git a/src/Detector/OCVDNNDetector.cpp b/src/Detector/OCVDNNDetector.cpp
index 82902be0..01d1102f 100644
--- a/src/Detector/OCVDNNDetector.cpp
+++ b/src/Detector/OCVDNNDetector.cpp
@@ -138,6 +138,7 @@ bool OCVDNNDetector::Init(const config_t& config)
     dictNetType["YOLOV7"] = ModelType::YOLOV7;
     dictNetType["YOLOV7Mask"] = ModelType::YOLOV7Mask;
     dictNetType["YOLOV8"] = ModelType::YOLOV8;
+    dictNetType["YOLOV8_OBB"] = ModelType::YOLOV8_OBB;
     dictNetType["YOLOV8Mask"] = ModelType::YOLOV8Mask;
     dictNetType["YOLOV9"] = ModelType::YOLOV9;
    dictNetType["YOLOV10"] = ModelType::YOLOV10;
diff --git a/src/Detector/OCVDNNDetector.h b/src/Detector/OCVDNNDetector.h
index ee9331ee..79842ba2 100644
--- a/src/Detector/OCVDNNDetector.h
+++ b/src/Detector/OCVDNNDetector.h
@@ -39,6 +39,7 @@ class OCVDNNDetector final : public BaseDetector
         YOLOV7,
         YOLOV7Mask,
         YOLOV8,
+        YOLOV8_OBB,
         YOLOV8Mask,
         YOLOV9,
         YOLOV10
diff --git a/src/Detector/YoloTensorRTDetector.cpp b/src/Detector/YoloTensorRTDetector.cpp
index 43aab47e..a0ebeb44 100644
--- a/src/Detector/YoloTensorRTDetector.cpp
+++ b/src/Detector/YoloTensorRTDetector.cpp
@@ -50,6 +50,8 @@ YoloTensorRTDetector::YoloTensorRTDetector(const cv::Mat& colorFrame)
 ///
 bool YoloTensorRTDetector::Init(const config_t& config)
 {
+    //std::cout << "YoloTensorRTDetector::Init" << std::endl;
+
     m_detector.reset();
 
     auto modelConfiguration = config.find("modelConfiguration");
@@ -101,6 +103,7 @@ bool YoloTensorRTDetector::Init(const config_t& config)
     dictNetType["YOLOV7"] = tensor_rt::YOLOV7;
     dictNetType["YOLOV7Mask"] = tensor_rt::YOLOV7Mask;
     dictNetType["YOLOV8"] = tensor_rt::YOLOV8;
+    dictNetType["YOLOV8_OBB"] = tensor_rt::YOLOV8_OBB;
     dictNetType["YOLOV8Mask"] = tensor_rt::YOLOV8Mask;
     dictNetType["YOLOV9"] = tensor_rt::YOLOV9;
     dictNetType["YOLOV10"] = tensor_rt::YOLOV10;
@@ -232,7 +235,12 @@ void YoloTensorRTDetector::Detect(const cv::UMat& colorFrame)
             for (const tensor_rt::Result& bbox : detects[j])
             {
                 if (m_classesWhiteList.empty() || m_classesWhiteList.find(T2T(bbox.m_id)) != std::end(m_classesWhiteList))
-                    tmpRegions.emplace_back(cv::Rect(bbox.m_brect.x + crop.x, bbox.m_brect.y + crop.y, bbox.m_brect.width, bbox.m_brect.height), T2T(bbox.m_id), bbox.m_prob);
+                {
+                    cv::RotatedRect newRRect(bbox.m_rrect);
+                    newRRect.center.x += crop.x;
+                    newRRect.center.y += crop.y;
+                    tmpRegions.emplace_back(newRRect, T2T(bbox.m_id), bbox.m_prob);
+                }
             }
         }
     }
@@ -276,8 +284,8 @@ void YoloTensorRTDetector::Detect(const std::vector<cv::UMat>& frames, std::vector<regions_t>& regions)
         const tensor_rt::BatchResult& dets = detects[i];
         for (const tensor_rt::Result& bbox : dets)
         {
-            if (m_classesWhiteList.empty() || m_classesWhiteList.find(T2T(bbox.m_id)) != std::end(m_classesWhiteList))
-                regions[i].emplace_back(bbox.m_brect, T2T(bbox.m_id), bbox.m_prob);
+            if (m_classesWhiteList.empty() || m_classesWhiteList.find(T2T(bbox.m_id)) != std::end(m_classesWhiteList))
+                regions[i].emplace_back(bbox.m_rrect, T2T(bbox.m_id), bbox.m_prob);
         }
     }
     m_regions.assign(std::begin(regions.back()), std::end(regions.back()));
diff --git a/src/Detector/tensorrt_yolo/YoloONNXv8_obb.hpp b/src/Detector/tensorrt_yolo/YoloONNXv8_obb.hpp
new file mode 100644
index 00000000..4c39c5a4
--- /dev/null
+++ b/src/Detector/tensorrt_yolo/YoloONNXv8_obb.hpp
@@ -0,0 +1,124 @@
+#pragma once
+
+#include "YoloONNX.hpp"
+
+///
+/// \brief The YOLOv8_obb_onnx class
+///
+class YOLOv8_obb_onnx : public YoloONNX
+{
+protected:
+    ///
+    /// \brief GetResult
+    /// \param output
+    /// \return
+    ///
+    std::vector<tensor_rt::Result> GetResult(size_t imgIdx, int /*keep_topk*/, const std::vector<float*>& outputs, cv::Size frameSize)
+    {
+        std::vector<tensor_rt::Result> resBoxes;
+
+        //0: name: images, size: 1x3x1024x1024
+        //1: name: output0, size: 1x20x21504
+        //20: 15 DOTA classes + x + y + w + h + a
+        constexpr int shapeDataSize = 5;
+
+        const float fw = static_cast<float>(frameSize.width) / static_cast<float>(m_inputDims.d[3]);
+        const float fh = static_cast<float>(frameSize.height) / static_cast<float>(m_inputDims.d[2]);
+
+        auto output = outputs[0];
+
+        size_t ncInd = 1;
+        size_t lenInd = 2;
+        int nc = m_outpuDims[0].d[ncInd] - shapeDataSize;
+        int dimensions = nc + shapeDataSize;
+        size_t len = static_cast<size_t>(m_outpuDims[0].d[lenInd]) / m_params.explicitBatchSize;
+        //auto Volume = [](const nvinfer1::Dims& d)
+        //{
+        //    return std::accumulate(d.d, d.d + d.nbDims, 1, std::multiplies<int>());
+        //};
+        auto volume = len * m_outpuDims[0].d[ncInd]; // Volume(m_outpuDims[0]);
+        output += volume * imgIdx;
+        //std::cout << "len = " << len << ", nc = " << nc << ", m_params.confThreshold = " << m_params.confThreshold << ", volume = " << volume << std::endl;
+
+        cv::Mat rawMemory(1, dimensions * static_cast<int>(len), CV_32FC1, output);
+        rawMemory = rawMemory.reshape(1, dimensions);
+        cv::transpose(rawMemory, rawMemory);
+        output = (float*)rawMemory.data;
+
+        //std::cout << "output[0] mem:\n";
+        //for (size_t ii = 0; ii < 100; ++ii)
+        //{
+        //    std::cout << ii << ": ";
+        //    for (size_t jj = 0; jj < 20; ++jj)
+        //    {
+        //        std::cout << output[ii * 20 + jj] << " ";
+        //    }
+        //    std::cout << ";" << std::endl;
+        //}
+        //std::cout << ";" << std::endl;
+
+        std::vector<int> classIds;
+        std::vector<float> confidences;
+        std::vector<cv::RotatedRect> rectBoxes;
+        classIds.reserve(len);
+        confidences.reserve(len);
+        rectBoxes.reserve(len);
+
+        for (size_t i = 0; i < len; ++i)
+        {
+            // Box layout per candidate: x, y, w, h, <nc class scores>, angle
+            size_t k = i * (nc + shapeDataSize);
+
+            int classId = -1;
+            float objectConf = 0.f;
+            for (int j = 0; j < nc; ++j)
+            {
+                const float classConf = output[k + 4 + j];
+                if (classConf > objectConf)
+                {
+                    classId = j;
+                    objectConf = classConf;
+                }
+            }
+
+            //if (i == 0)
+            //{
+            //    for (int jj = 0; jj < 20; ++jj)
+            //    {
+            //        std::cout << output[jj] << " ";
+            //    }
+            //    std::cout << std::endl;
+            //}
+
+            if (objectConf >= m_params.confThreshold)
+            {
+                classIds.push_back(classId);
+                confidences.push_back(objectConf);
+
+                // (center x, center y, width, height)
+                float cx = fw * output[k];
+                float cy = fh * output[k + 1];
+                float width = fw * output[k + 2];
+                float height = fh * output[k + 3];
+                float angle = 180.f * output[k + nc + shapeDataSize - 1] / M_PI; // radians -> degrees for cv::RotatedRect
+                rectBoxes.emplace_back(cv::Point2f(cx, cy), cv::Size2f(width, height), angle);
+
+                //if (rectBoxes.size() == 1)
+                //    std::cout << i << ": object_conf = " << objectConf << ", classId = " << classId << ", rect = " << rectBoxes.back().boundingRect() << ", angle = " << angle << std::endl;
+            }
+        }
+
+        // Non-maximum suppression to eliminate redundant overlapping boxes
+        // (left disabled here: all candidates above the confidence threshold are returned)
+        //std::vector<int> indices;
+        //cv::dnn::NMSBoxes(rectBoxes, confidences, m_params.confThreshold, m_params.nmsThreshold, indices);
+        //resBoxes.reserve(indices.size());
+
+        resBoxes.reserve(rectBoxes.size());
+        for (size_t bi = 0; bi < rectBoxes.size(); ++bi)
+        {
+            resBoxes.emplace_back(classIds[bi], confidences[bi], rectBoxes[bi]);
+        }
+
+        return resBoxes;
+    }
+};
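The decoder above converts the model's angle from radians to degrees (`180.f * ... / M_PI`) because `cv::RotatedRect` stores its angle in degrees. A usage sketch for visualizing the resulting detections; `DrawRotatedBox` is a hypothetical helper, not part of the repo:

```cpp
// Usage sketch: draw one decoded oriented box. RotatedRect::points()
// yields the 4 corners of the rotated rectangle in order.
#include <opencv2/imgproc.hpp>

void DrawRotatedBox(cv::Mat& frame, const cv::RotatedRect& rrect, const cv::Scalar& color)
{
    cv::Point2f corners[4];
    rrect.points(corners);
    for (int i = 0; i < 4; ++i)
        cv::line(frame, corners[i], corners[(i + 1) % 4], color, 2);
}
```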
diff --git a/src/Detector/tensorrt_yolo/YoloONNXv9_bb.hpp b/src/Detector/tensorrt_yolo/YoloONNXv9_bb.hpp
index 6c821351..f4c99ebd 100644
--- a/src/Detector/tensorrt_yolo/YoloONNXv9_bb.hpp
+++ b/src/Detector/tensorrt_yolo/YoloONNXv9_bb.hpp
@@ -19,6 +19,8 @@ class YOLOv9_bb_onnx : public YoloONNX
 
         //0: name: images, size: 1x3x640x640
         //1: name: output0, size: 1x84x8400
+        //84: 80 COCO classes + x + y + w + h
+        constexpr int shapeDataSize = 4;
 
         const float fw = static_cast<float>(frameSize.width) / static_cast<float>(m_inputDims.d[3]);
         const float fh = static_cast<float>(frameSize.height) / static_cast<float>(m_inputDims.d[2]);
@@ -27,8 +29,8 @@ class YOLOv9_bb_onnx : public YoloONNX
 
         size_t ncInd = 1;
         size_t lenInd = 2;
-        int nc = m_outpuDims[0].d[ncInd] - 4;
-        int dimensions = nc + 4;
+        int nc = m_outpuDims[0].d[ncInd] - shapeDataSize;
+        int dimensions = nc + shapeDataSize;
         size_t len = static_cast<size_t>(m_outpuDims[0].d[lenInd]) / m_params.explicitBatchSize;
         //auto Volume = [](const nvinfer1::Dims& d)
         //{
@@ -65,13 +67,13 @@ class YOLOv9_bb_onnx : public YoloONNX
         for (size_t i = 0; i < len; ++i)
         {
             // Box
-            size_t k = i * (nc + 4);
+            size_t k = i * (nc + shapeDataSize);
 
             int classId = -1;
             float objectConf = 0.f;
             for (int j = 0; j < nc; ++j)
             {
-                const float classConf = output[k + 4 + j];
+                const float classConf = output[k + shapeDataSize + j];
                 if (classConf > objectConf)
                 {
                     classId = j;
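The only real difference between this decoder and the OBB one is the per-candidate layout, captured by `shapeDataSize`. Written out once as compile-time arithmetic (the class counts are the COCO/DOTA defaults the comments above assume):

```cpp
// Offset arithmetic behind shapeDataSize in the two decoders above.
// Per-candidate layout after the transpose:
//   YOLOv9     (shapeDataSize = 4): x y w h | nc class scores
//   YOLOv8-OBB (shapeDataSize = 5): x y w h | nc class scores | angle
constexpr int nc_v9 = 80, shape_v9 = 4;
static_assert(nc_v9 + shape_v9 == 84, "1x84x8400: 84 = 80 classes + x,y,w,h");

constexpr int nc_obb = 15, shape_obb = 5;
static_assert(nc_obb + shape_obb == 20, "1x20x21504: 20 = 15 classes + x,y,w,h + angle");
// Class scores start at offset 4 in both; the OBB angle sits at
// offset nc_obb + shape_obb - 1 == 19, the last float of the candidate.
```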
diff --git a/src/Detector/tensorrt_yolo/class_detector.cpp b/src/Detector/tensorrt_yolo/class_detector.cpp
index d8d58b18..f7a18e23 100644
--- a/src/Detector/tensorrt_yolo/class_detector.cpp
+++ b/src/Detector/tensorrt_yolo/class_detector.cpp
@@ -6,6 +6,7 @@
 #include "YoloONNXv7_bb.hpp"
 #include "YoloONNXv7_instance.hpp"
 #include "YoloONNXv8_bb.hpp"
+#include "YoloONNXv8_obb.hpp"
 #include "YoloONNXv8_instance.hpp"
 #include "YoloONNXv9_bb.hpp"
 #include "YoloONNXv10_bb.hpp"
@@ -88,6 +89,11 @@ namespace tensor_rt
             m_params.outputTensorNames.push_back("output0");
             m_detector = std::make_unique<YOLOv8_bb_onnx>();
             break;
+        case ModelType::YOLOV8_OBB:
+            m_params.inputTensorNames.push_back("images");
+            m_params.outputTensorNames.push_back("output0");
+            m_detector = std::make_unique<YOLOv8_obb_onnx>();
+            break;
         case ModelType::YOLOV8Mask:
             m_params.inputTensorNames.push_back("images");
             m_params.outputTensorNames.push_back("output0");
@@ -186,7 +192,7 @@ namespace tensor_rt
         if (config.net_type == ModelType::YOLOV6 || config.net_type == ModelType::YOLOV7 || config.net_type == ModelType::YOLOV7Mask ||
-            config.net_type == ModelType::YOLOV8 || config.net_type == ModelType::YOLOV8Mask ||
+            config.net_type == ModelType::YOLOV8 || config.net_type == ModelType::YOLOV8_OBB || config.net_type == ModelType::YOLOV8Mask ||
             config.net_type == ModelType::YOLOV9 || config.net_type == ModelType::YOLOV10)
             m_impl = new YoloONNXImpl();
         else
diff --git a/src/Detector/tensorrt_yolo/class_detector.h b/src/Detector/tensorrt_yolo/class_detector.h
index 4c8e2911..1dd85d70 100644
--- a/src/Detector/tensorrt_yolo/class_detector.h
+++ b/src/Detector/tensorrt_yolo/class_detector.h
@@ -51,6 +51,7 @@ namespace tensor_rt
         YOLOV7,
         YOLOV7Mask,
         YOLOV8,
+        YOLOV8_OBB,
         YOLOV8Mask,
         YOLOV9,
         YOLOV10
diff --git a/src/Detector/tensorrt_yolo/ds_image.cpp b/src/Detector/tensorrt_yolo/ds_image.cpp
index 82c69b2a..b801b874 100644
--- a/src/Detector/tensorrt_yolo/ds_image.cpp
+++ b/src/Detector/tensorrt_yolo/ds_image.cpp
@@ -49,7 +49,7 @@ DsImage::DsImage(const cv::Mat& mat_image_, tensor_rt::ModelType net_type, const
     }
     if (tensor_rt::ModelType::YOLOV5 == net_type || tensor_rt::ModelType::YOLOV6 == net_type ||
         tensor_rt::ModelType::YOLOV7 == net_type || tensor_rt::ModelType::YOLOV7Mask == net_type ||
-        tensor_rt::ModelType::YOLOV8 == net_type || tensor_rt::ModelType::YOLOV8Mask == net_type ||
+        tensor_rt::ModelType::YOLOV8 == net_type || tensor_rt::ModelType::YOLOV8_OBB == net_type || tensor_rt::ModelType::YOLOV8Mask == net_type ||
         tensor_rt::ModelType::YOLOV9 == net_type || tensor_rt::ModelType::YOLOV10 == net_type)
     {
         // resize the DsImage with scale
@@ -100,7 +100,7 @@ DsImage::DsImage(const std::string& path, tensor_rt::ModelType net_type, const i
 
     if (tensor_rt::ModelType::YOLOV5 == net_type || tensor_rt::ModelType::YOLOV6 == net_type ||
         tensor_rt::ModelType::YOLOV7 == net_type || tensor_rt::ModelType::YOLOV7Mask == net_type ||
-        tensor_rt::ModelType::YOLOV8 == net_type || tensor_rt::ModelType::YOLOV8Mask == net_type ||
+        tensor_rt::ModelType::YOLOV8 == net_type || tensor_rt::ModelType::YOLOV8_OBB == net_type || tensor_rt::ModelType::YOLOV8Mask == net_type ||
         tensor_rt::ModelType::YOLOV9 == net_type || tensor_rt::ModelType::YOLOV10 == net_type)
     {
         // resize the DsImage with scale
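Both `DsImage` constructors route the listed model types (now including `YOLOV8_OBB`) through a scale-preserving resize ("resize the DsImage with scale"). For orientation, a generic YOLO-style letterbox looks like the sketch below; this illustrates the technique only and is not `ds_image.cpp`'s exact padding or placement:

```cpp
// Letterbox sketch: scale to fit inputW x inputH preserving aspect ratio,
// then pad the remainder with a neutral color. Placement (top-left vs.
// centered) and pad value vary between implementations.
#include <opencv2/imgproc.hpp>
#include <algorithm>

cv::Mat Letterbox(const cv::Mat& src, int inputW, int inputH)
{
    const float scale = std::min(inputW / static_cast<float>(src.cols),
                                 inputH / static_cast<float>(src.rows));
    const int newW = static_cast<int>(src.cols * scale);
    const int newH = static_cast<int>(src.rows * scale);

    cv::Mat resized;
    cv::resize(src, resized, cv::Size(newW, newH));

    cv::Mat canvas(inputH, inputW, src.type(), cv::Scalar(128, 128, 128));
    resized.copyTo(canvas(cv::Rect(0, 0, newW, newH))); // top-left placement
    return canvas;
}
```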