Commit

Merge pull request #31 from habbasian/pr
Updating TensorRT-introduction to work with TRT7 and Dynamic Shape
harrism authored Apr 8, 2020
2 parents bfe1d42 + c4da909 commit c2e58d9
Showing 6 changed files with 139 additions and 107 deletions.
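For orientation, the core of this change is the move from TensorRT's implicit-batch API (createNetwork, setMaxBatchSize, buildCudaEngine, enqueue) to the TRT 7 explicit-batch path with an IBuilderConfig and an optimization profile. Below is a condensed sketch of the build-time sequence as it appears in the updated createCudaEngine(); gLogger, Destroy<> and MAX_WORKSPACE_SIZE are helpers already defined in the sample, and the 1/1/32 x 3 x 256 x 256 profile bounds are the values hard-coded by this PR, not general requirements.

// Sketch: building an engine with an explicit-batch network and a dynamic-shape profile (TensorRT 7).
const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
unique_ptr<IBuilder, Destroy<IBuilder>> builder{createInferBuilder(gLogger)};
unique_ptr<INetworkDefinition, Destroy<INetworkDefinition>> network{builder->createNetworkV2(explicitBatch)};
unique_ptr<nvonnxparser::IParser, Destroy<nvonnxparser::IParser>> parser{nvonnxparser::createParser(*network, gLogger)};
unique_ptr<IBuilderConfig, Destroy<IBuilderConfig>> config{builder->createBuilderConfig()};

// Parse the ONNX model into the network definition.
parser->parseFromFile(onnxModelPath.c_str(), static_cast<int>(ILogger::Severity::kINFO));
config->setMaxWorkspaceSize(MAX_WORKSPACE_SIZE);

// With an explicit-batch network the input batch dimension is dynamic, so the
// builder needs an optimization profile describing the allowed input range.
IOptimizationProfile* profile = builder->createOptimizationProfile();
profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMIN, Dims4{1, 3, 256, 256});
profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kOPT, Dims4{1, 3, 256, 256});
profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMAX, Dims4{32, 3, 256, 256});
config->addOptimizationProfile(profile);

ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
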
8 changes: 4 additions & 4 deletions posts/TensorRT-introduction/Makefile
@@ -25,16 +25,16 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
CUDA_INSTALL_DIR=/usr/local/cuda

CXXFLAGS=-std=c++11 -Wall -I$(CUDA_INSTALL_DIR)/include
LDFLAGS=-L$(CUDA_INSTALL_DIR)/lib64 -L$(CUDA_INSTALL_DIR)/lib64/stubs
LDLIBS=-Wl,--start-group -lnvinfer -lnvonnxparser -lcudart_static -lrt -ldl -lpthread -lonnx -lonnx_proto -lprotobuf -lstdc++ -lm -Wl,--end-group
CXXFLAGS=-std=c++11 -DONNX_ML=1 -Wall -I$(CUDA_INSTALL_DIR)/include
LDFLAGS=-L$(CUDA_INSTALL_DIR)/lib64 -L$(CUDA_INSTALL_DIR)/lib64/stubs -L/usr/local/lib
LDLIBS=-Wl,--start-group -lnvonnxparser -lnvinfer -lcudart_static -lonnx -lonnx_proto -lprotobuf -lstdc++ -lm -lrt -ldl -lpthread -Wl,--end-group

HEADERS=${wildcard *.h}
TARGET_SRCS=$(wildcard simpleOnnx*.cpp)
TARGET_OBJS=${TARGET_SRCS:.cpp=.o}
TARGETS=${TARGET_OBJS:.o=}


all: $(TARGETS)

$(TARGETS): %: %.o ioHelper.o
6 changes: 2 additions & 4 deletions posts/TensorRT-introduction/ioHelper.cpp
@@ -24,14 +24,13 @@
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "ioHelper.h"
#include <algorithm>
#include <fstream>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <iterator>
#include <onnx/onnx_pb.h>

#include "ioHelper.h"
using namespace std;

namespace nvinfer1
@@ -83,8 +82,7 @@ size_t readTensor(vector<string> const& tensorProtoPaths, vector<float>& buffer)

for (size_t i = 0; i < tensorProtoPaths.size(); ++i)
{
size_t elements = readTensorProto(tensorProtoPaths[i], &buffer[totalElements]);
if (!elements)
size_t elements = readTensorProto(tensorProtoPaths[i], &buffer[totalElements]); if (!elements)
{
cout << "ERROR: could not read tensor from file " << tensorProtoPaths[i] << endl;
break;
Binary file added posts/TensorRT-introduction/ioHelper.o
71 changes: 35 additions & 36 deletions posts/TensorRT-introduction/simpleOnnx.cpp
@@ -24,16 +24,20 @@
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <NvInfer.h>
#include "cudaWrapper.h"
#include "ioHelper.h"
#include <NvInfer.h>
#include <NvOnnxParser.h>
#include <algorithm>
#include <functional>
#include <cmath>
#include <cassert>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <numeric>
#include <math.h>

using namespace nvinfer1;
using namespace std;
@@ -51,25 +55,30 @@ constexpr double REL_EPSILON = 0.05;
constexpr size_t MAX_WORKSPACE_SIZE = 1ULL << 30; // 1 GB

ICudaEngine* createCudaEngine(string const& onnxModelPath, int batchSize)
{
unique_ptr<IBuilder, Destroy<IBuilder>> builder{createInferBuilder(gLogger)};
unique_ptr<INetworkDefinition, Destroy<INetworkDefinition>> network{builder->createNetwork()};
{
const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
unique_ptr<nvinfer1::IBuilder, Destroy<nvinfer1::IBuilder>> builder{nvinfer1::createInferBuilder(gLogger)};
unique_ptr<nvinfer1::INetworkDefinition, Destroy<nvinfer1::INetworkDefinition>> network{builder->createNetworkV2(explicitBatch)};
unique_ptr<nvonnxparser::IParser, Destroy<nvonnxparser::IParser>> parser{nvonnxparser::createParser(*network, gLogger)};
unique_ptr<nvinfer1::IBuilderConfig,Destroy<nvinfer1::IBuilderConfig>> config{builder->createBuilderConfig()};

if (!parser->parseFromFile(onnxModelPath.c_str(), static_cast<int>(ILogger::Severity::kINFO)))
{
cout << "ERROR: could not parse input engine." << endl;
return nullptr;
}

// Build TensorRT engine optimized based on for batch size of input data provided.
builder->setMaxBatchSize(batchSize);
// Allow TensorRT to use fp16 mode kernels internally.
// Note that Input and Output tensors will still use 32 bit float type by default.
config->setMaxWorkspaceSize(MAX_WORKSPACE_SIZE);
builder->setFp16Mode(builder->platformHasFastFp16());
builder->setMaxWorkspaceSize(MAX_WORKSPACE_SIZE);

return builder->buildCudaEngine(*network); // Build and return TensorRT engine.
builder->setMaxBatchSize(batchSize);

auto profile = builder->createOptimizationProfile();
profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMIN, Dims4{1, 3, 256 , 256});
profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kOPT, Dims4{1, 3, 256 , 256});
profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMAX, Dims4{32, 3, 256 , 256});
config->addOptimizationProfile(profile);

return builder->buildEngineWithConfig(*network, *config);
}

ICudaEngine* getCudaEngine(string const& onnxModelPath, int batchSize)
@@ -78,11 +87,13 @@ ICudaEngine* getCudaEngine(string const& onnxModelPath, int batchSize)
ICudaEngine* engine{nullptr};

string buffer = readBuffer(enginePath);

if (buffer.size())
{
// Try to deserialize engine.
unique_ptr<IRuntime, Destroy<IRuntime>> runtime{createInferRuntime(gLogger)};
engine = runtime->deserializeCudaEngine(buffer.data(), buffer.size(), nullptr);

}

if (!engine)
@@ -110,7 +121,7 @@ void launchInference(IExecutionContext* context, cudaStream_t stream, vector<flo
int inputId = getBindingInputIndex(context);

cudaMemcpyAsync(bindings[inputId], inputTensor.data(), inputTensor.size() * sizeof(float), cudaMemcpyHostToDevice, stream);
context->enqueue(batchSize, bindings, stream, nullptr);
context->enqueueV2(bindings, stream, nullptr);
cudaMemcpyAsync(outputTensor.data(), bindings[1 - inputId], outputTensor.size() * sizeof(float), cudaMemcpyDeviceToHost, stream);
}

@@ -139,23 +150,9 @@ void doInference(IExecutionContext* context, cudaStream_t stream, vector<float>
cout << "Inference batch size " << batchSize << " average over " << ITERATIONS << " runs is " << totalTime / ITERATIONS << "ms" << endl;
}

void softmax(vector<float>& tensor, int batchSize)
void verifyOutput(vector<float> const& outputTensor, vector<float> const& referenceTensor, int size)
{
size_t batchElements = tensor.size() / batchSize;

for (int i = 0; i < batchSize; ++i)
{
float* batchVector = &tensor[i * batchElements];
double maxValue = *max_element(batchVector, batchVector + batchElements);
double expSum = accumulate(batchVector, batchVector + batchElements, 0.0, [=](double acc, float value) { return acc + exp(value - maxValue); });

transform(batchVector, batchVector + batchElements, batchVector, [=](float input) { return static_cast<float>(std::exp(input - maxValue) / expSum); });
}
}

void verifyOutput(vector<float> const& outputTensor, vector<float> const& referenceTensor)
{
for (size_t i = 0; i < referenceTensor.size(); ++i)
for (size_t i = 0; i < size; ++i)
{
double reference = static_cast<double>(referenceTensor[i]);
// Check absolute and relative tolerance.
@@ -207,9 +204,9 @@ int main(int argc, char* argv[])
for (int i = 0; i < engine->getNbBindings(); ++i)
{
Dims dims{engine->getBindingDimensions(i)};
size_t size = accumulate(dims.d, dims.d + dims.nbDims, batchSize, multiplies<size_t>());
size_t size = std::accumulate(dims.d+1, dims.d + dims.nbDims, batchSize, multiplies<size_t>());
// Create CUDA buffer for Tensor.
cudaMalloc(&bindings[i], size * sizeof(float));
cudaMalloc(&bindings[i], batchSize * size * sizeof(float));

// Resize CPU buffers to fit Tensor.
if (engine->bindingIsInput(i))
@@ -228,6 +225,10 @@ int main(int argc, char* argv[])
// Create Execution Context.
context.reset(engine->createExecutionContext());

Dims dims_i{engine->getBindingDimensions(0)};
Dims4 inputDims{batchSize, dims_i.d[1], dims_i.d[2], dims_i.d[3]};
context->setBindingDimensions(0, inputDims);

doInference(context.get(), stream, inputTensor, outputTensor, bindings, batchSize);

vector<string> referenceFiles;
@@ -240,12 +241,10 @@ int main(int argc, char* argv[])
cout << "Couldn't read reference Tensor" << endl;
return 1;
}

// Apply a softmax on the CPU to create a normalized distribution suitable for measuring relative error in probabilities.
softmax(outputTensor, batchSize);
softmax(referenceTensor, batchSize);

verifyOutput(outputTensor, referenceTensor);

Dims dims_o{engine->getBindingDimensions(1)};
int size = batchSize * dims_o.d[2] * dims_o.d[3];
verifyOutput(outputTensor, referenceTensor, size);

for (void* ptr : bindings)
cudaFree(ptr);
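At run time the updated samples set the concrete batch dimension on the execution context before launching work, and call enqueueV2() instead of the implicit-batch enqueue(). A condensed sketch of that sequence, using the same names that appear in the diff above (engine, context, bindings, stream, inputTensor, outputTensor, batchSize):

// Sketch: running inference with dynamic shapes (TensorRT 7).
Dims dims_i{engine->getBindingDimensions(0)};
context->setBindingDimensions(0, Dims4{batchSize, dims_i.d[1], dims_i.d[2], dims_i.d[3]});

int inputId = getBindingInputIndex(context);
cudaMemcpyAsync(bindings[inputId], inputTensor.data(), inputTensor.size() * sizeof(float), cudaMemcpyHostToDevice, stream);
context->enqueueV2(bindings, stream, nullptr);   // batch size now comes from the binding dimensions
cudaMemcpyAsync(outputTensor.data(), bindings[1 - inputId], outputTensor.size() * sizeof(float), cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
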
98 changes: 65 additions & 33 deletions posts/TensorRT-introduction/simpleOnnx_1.cpp
@@ -24,16 +24,19 @@
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <NvInfer.h>
#include "cudaWrapper.h"
#include "ioHelper.h"
#include <NvInfer.h>
#include <NvOnnxParser.h>
#include <algorithm>
#include <cassert>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <numeric>
#include <math.h>
#include <cmath>

using namespace nvinfer1;
using namespace std;
@@ -46,52 +49,49 @@ constexpr double ABS_EPSILON = 0.005;
// Maxmimum relative tolerance for output tensor comparison against reference.
constexpr double REL_EPSILON = 0.05;

ICudaEngine* createCudaEngine(string const& onnxModelPath, int batchSize)
nvinfer1::ICudaEngine* createCudaEngine(string const& onnxModelPath, int batchSize)
{
unique_ptr<IBuilder, Destroy<IBuilder>> builder{createInferBuilder(gLogger)};
unique_ptr<INetworkDefinition, Destroy<INetworkDefinition>> network{builder->createNetwork()};
const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
unique_ptr<nvinfer1::IBuilder, Destroy<nvinfer1::IBuilder>> builder{nvinfer1::createInferBuilder(gLogger)};
unique_ptr<nvinfer1::INetworkDefinition, Destroy<nvinfer1::INetworkDefinition>> network{builder->createNetworkV2(explicitBatch)};
unique_ptr<nvonnxparser::IParser, Destroy<nvonnxparser::IParser>> parser{nvonnxparser::createParser(*network, gLogger)};
unique_ptr<nvinfer1::IBuilderConfig,Destroy<nvinfer1::IBuilderConfig>> config{builder->createBuilderConfig()};

if (!parser->parseFromFile(onnxModelPath.c_str(), static_cast<int>(ILogger::Severity::kINFO)))
{
cout << "ERROR: could not parse input engine." << endl;
return nullptr;
}

return builder->buildCudaEngine(*network); // Build and return TensorRT engine.
builder->setMaxBatchSize(batchSize);
config->setMaxWorkspaceSize((1 << 30));

auto profile = builder->createOptimizationProfile();
profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMIN, Dims4{1, 3, 256 , 256});
profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kOPT, Dims4{1, 3, 256 , 256});
profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMAX, Dims4{32, 3, 256 , 256});
config->addOptimizationProfile(profile);

return builder->buildEngineWithConfig(*network, *config);
}

static int getBindingInputIndex(IExecutionContext* context)
static int getBindingInputIndex(nvinfer1::IExecutionContext* context)
{
return !context->getEngine().bindingIsInput(0); // 0 (false) if bindingIsInput(0), 1 (true) otherwise
}

void launchInference(IExecutionContext* context, cudaStream_t stream, vector<float> const& inputTensor, vector<float>& outputTensor, void** bindings, int batchSize)
{
int inputId = getBindingInputIndex(context);

cudaMemcpyAsync(bindings[inputId], inputTensor.data(), inputTensor.size() * sizeof(float), cudaMemcpyHostToDevice, stream);
context->enqueue(batchSize, bindings, stream, nullptr);
context->enqueueV2(bindings, stream, nullptr);
cudaMemcpyAsync(outputTensor.data(), bindings[1 - inputId], outputTensor.size() * sizeof(float), cudaMemcpyDeviceToHost, stream);
}

void softmax(vector<float>& tensor, int batchSize)
{
size_t batchElements = tensor.size() / batchSize;

for (int i = 0; i < batchSize; ++i)
{
float* batchVector = &tensor[i * batchElements];
double maxValue = *max_element(batchVector, batchVector + batchElements);
double expSum = accumulate(batchVector, batchVector + batchElements, 0.0, [=](double acc, float value) { return acc + exp(value - maxValue); });

transform(batchVector, batchVector + batchElements, batchVector, [=](float input) { return static_cast<float>(std::exp(input - maxValue) / expSum); });
}
}

void verifyOutput(vector<float> const& outputTensor, vector<float> const& referenceTensor)
void verifyOutput(vector<float> const& outputTensor, vector<float> const& referenceTensor, int size)
{
for (size_t i = 0; i < referenceTensor.size(); ++i)
for (size_t i = 0; i < size; ++i)
{
double reference = static_cast<double>(referenceTensor[i]);
// Check absolute and relative tolerance.
@@ -102,8 +102,31 @@ void verifyOutput(vector<float> const& outputTensor, vector<float> const& refere
return;
}
}
cout << "OK" << endl;
}

cout << "OK" << endl;
void saveImageAsPGM(vector<float>& outputTensor,int H, int W)
{
FILE* pgmimg;
pgmimg = fopen("output.pgm", "wb");

fprintf(pgmimg, "P2\n");
// Writing Width and Height
fprintf(pgmimg, "%d %d\n", H, W);
// Writing the maximum gray value
fprintf(pgmimg, "255\n");

for (int i=0; i< H; ++i)
{
for(int j=0; j<W; ++j)
{
int temp = round(255* outputTensor[i*H + j]);
fprintf(pgmimg, "%d ", temp);
}
fprintf(pgmimg, "\n");
}

fclose(pgmimg);
}

int main(int argc, char* argv[])
@@ -141,13 +164,14 @@ int main(int argc, char* argv[])
for (int i = 0; i < engine->getNbBindings(); ++i)
{
Dims dims{engine->getBindingDimensions(i)};
size_t size = accumulate(dims.d, dims.d + dims.nbDims, batchSize, multiplies<size_t>());
size_t size = accumulate(dims.d+1, dims.d + dims.nbDims, batchSize, multiplies<size_t>());
// Create CUDA buffer for Tensor.
cudaMalloc(&bindings[i], size * sizeof(float));
cudaMalloc(&bindings[i], batchSize * size * sizeof(float));

// Resize CPU buffers to fit Tensor.
if (engine->bindingIsInput(i))
if (engine->bindingIsInput(i)){
inputTensor.resize(size);
}
else
outputTensor.resize(size);
}
@@ -158,31 +182,39 @@
cout << "Couldn't read input Tensor" << endl;
return 1;
}


// Create Execution Context.
context.reset(engine->createExecutionContext());

Dims dims_i{engine->getBindingDimensions(0)};
Dims4 inputDims{batchSize, dims_i.d[1], dims_i.d[2], dims_i.d[3]};
context->setBindingDimensions(0, inputDims);

launchInference(context.get(), stream, inputTensor, outputTensor, bindings, batchSize);

Dims dims{engine->getBindingDimensions(1)};
saveImageAsPGM(outputTensor, dims.d[2], dims.d[3]);
// Wait until the work is finished.
cudaStreamSynchronize(stream);

vector<string> referenceFiles;
for (string path : inputFiles)
referenceFiles.push_back(path.replace(path.rfind("input"), 5, "output"));
// Try to read and compare against reference tensor from protobuf file.


referenceTensor.resize(outputTensor.size());
if (readTensor(referenceFiles, referenceTensor) != referenceTensor.size())
{
cout << "Couldn't read reference Tensor" << endl;
return 1;
}

// Apply a softmax on the CPU to create a normalized distribution suitable for measuring relative error in probabilities.
softmax(outputTensor, batchSize);
softmax(referenceTensor, batchSize);

verifyOutput(outputTensor, referenceTensor);

Dims dims_o{engine->getBindingDimensions(1)};
int size = batchSize * dims_o.d[2] * dims_o.d[3];
verifyOutput(outputTensor, referenceTensor, size);

for (void* ptr : bindings)
cudaFree(ptr);
