forked from wang-xinyu/tensorrtx
-
Notifications
You must be signed in to change notification settings - Fork 2
/
yololayer.h
158 lines (120 loc) · 4.9 KB
/
yololayer.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#ifndef _YOLO_LAYER_H
#define _YOLO_LAYER_H
#include <assert.h>
#include <string.h>

#include <cmath>
#include <iostream>
#include <string>
#include <vector>

#include <cublas_v2.h>
#include "NvInfer.h"
#include "Utils.h"
namespace Yolo
{
    // ---- Compile-time configuration -------------------------------------

    // Number of anchor pairs per YoloKernel; the anchors array below holds
    // CHECK_COUNT * 2 floats.
    static constexpr int CHECK_COUNT = 3;

    // Threshold constant. NOTE(review): presumably the confidence below which
    // a candidate is ignored - confirm at the point of use.
    static constexpr float IGNORE_THRESH = 0.1f;

    // NOTE(review): presumably the cap on boxes written to the output - confirm.
    static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000;

    // NOTE(review): presumably the number of object classes - confirm.
    static constexpr int CLASS_NUM = 80;

    // Length of Detection::bbox (x, y, w, h).
    static constexpr int LOCATIONS = 4;

    // ---- Per-scale kernel description -----------------------------------

    // Plain aggregate describing one YOLO output scale.
    struct YoloKernel
    {
        int width;                        // -1 means "dynamic" (see presets below)
        int height;                       // -1 means "dynamic" (see presets below)
        int stride;
        float anchors[CHECK_COUNT * 2];   // CHECK_COUNT pairs, stored flat
    };

    // Preset with stride 32; width and height are dynamic (-1).
    static constexpr YoloKernel yolo1 = {
        -1,
        -1,
        32,
        {116.0f, 90.0f, 156.0f, 198.0f, 373.0f, 326.0f}
    };

    // Preset with stride 16; width and height are dynamic (-1).
    static constexpr YoloKernel yolo2 = {
        -1,
        -1,
        16,
        {30.0f, 61.0f, 62.0f, 45.0f, 59.0f, 119.0f}
    };

    // Preset with stride 8; width and height are dynamic (-1).
    static constexpr YoloKernel yolo3 = {
        -1,
        -1,
        8,
        {10.0f, 13.0f, 16.0f, 30.0f, 33.0f, 23.0f}
    };

    // ---- Output record --------------------------------------------------

    // One detection record made of seven floats, aligned like a float.
    struct alignas(float) Detection
    {
        float bbox[LOCATIONS];    // x y w h
        float det_confidence;
        float class_id;           // stored as float, like every other field
        float class_confidence;
    };
}
namespace nvinfer1
{
class YoloLayerPlugin: public IPluginV2DynamicExt
{
public:
explicit YoloLayerPlugin();
YoloLayerPlugin(const void* data, size_t length);
~YoloLayerPlugin();
int getNbOutputs() const override
{
return 1;
}
//virtual Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) final;
virtual DimsExprs getOutputDimensions(int outputIndex, const DimsExprs* inputs, int nbInputs, IExprBuilder& exprBuilder) override;
int initialize() override;
virtual void terminate() override {};
//virtual size_t getWorkspaceSize(int maxBatchSize) const override { return 0;}
size_t getWorkspaceSize(const PluginTensorDesc* inputs, int nbInputs, const PluginTensorDesc* outputs, int nbOutputs) const override { return 0; }
//virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override;
int enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) override;
virtual size_t getSerializationSize() const override;
virtual void serialize(void* buffer) const override;
bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) override {
return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
}
const char* getPluginType() const override;
const char* getPluginVersion() const override;
void destroy() override;
IPluginV2DynamicExt* clone() const override;
void setPluginNamespace(const char* pluginNamespace) override;
const char* getPluginNamespace() const override;
DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;
void attachToContext(
cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override;
void configurePlugin(const DynamicPluginTensorDesc* in, int nbInputs, const DynamicPluginTensorDesc* out, int nbOutputs) override;
void detachFromContext() override;
private:
void forwardGpu(const float *const * inputs,float * output, cudaStream_t stream,int batchSize = 1);
int mClassCount;
int mKernelCount;
std::vector<Yolo::YoloKernel> mYoloKernel;
int mThreadCount = 256;
void** mAnchor;
const char* mPluginNamespace;
};
/**
 * Creator (factory) for the YOLO layer plugin, implementing IPluginCreator.
 *
 * The plugin-namespace accessors are defined inline and backed by an owned
 * std::string; every other member is declared here and defined out of line.
 */
class YoloPluginCreator : public IPluginCreator
{
public:
    YoloPluginCreator();
    ~YoloPluginCreator() override = default;

    const char* getPluginName() const override;
    const char* getPluginVersion() const override;
    const PluginFieldCollection* getFieldNames() override;

    IPluginV2DynamicExt* createPlugin(const char* name, const PluginFieldCollection* fc) override;
    IPluginV2DynamicExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;

    // Stores a copy of the namespace string.
    // Assigning a null `const char*` to std::string is undefined behaviour,
    // so a null argument is stored as the empty namespace instead.
    void setPluginNamespace(const char* libNamespace) override
    {
        mNamespace = (libNamespace != nullptr) ? libNamespace : "";
    }

    // Returns a pointer into mNamespace; it stays valid until the namespace
    // is set again or this creator is destroyed.
    const char* getPluginNamespace() const override
    {
        return mNamespace.c_str();
    }

private:
    std::string mNamespace;
    // Static members shared by all creator instances; defined out of line.
    static PluginFieldCollection mFC;
    static std::vector<PluginField> mPluginAttributes;
};
// Expands the plugin-registration macro for YoloPluginCreator (the macro is
// not defined in this file; presumably it comes from the TensorRT headers).
// NOTE(review): because this expansion lives in a header, every translation
// unit that includes the header expands it again - confirm that repeated
// registration is acceptable, or move this line into a single source file.
REGISTER_TENSORRT_PLUGIN(YoloPluginCreator);
};
#endif