Commit c2e58d9

Merge pull request #31 from habbasian/pr
Updating TensorRT-introduction to work with TRT7 and Dynamic Shape
2 parents bfe1d42 + c4da909 commit c2e58d9

6 files changed: +139 -107 lines changed
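The change ports both samples from the implicit-batch API (createNetwork, setMaxBatchSize, buildCudaEngine, enqueue) to the TensorRT 7 explicit-batch API with dynamic shapes: the network is created with createNetworkV2 and the kEXPLICIT_BATCH flag, build options move into an IBuilderConfig, an optimization profile declares the supported input range, and inference uses setBindingDimensions plus enqueueV2. A condensed sketch of that build flow, distilled from the diffs below (gLogger, the Destroy<> deleter, and the 1x3x256x256 to 32x3x256x256 profile bounds are the sample's own):

    // Build an explicit-batch engine with a dynamic-shape optimization profile (TRT 7).
    const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    unique_ptr<IBuilder, Destroy<IBuilder>> builder{createInferBuilder(gLogger)};
    unique_ptr<INetworkDefinition, Destroy<INetworkDefinition>> network{builder->createNetworkV2(explicitBatch)};
    unique_ptr<nvonnxparser::IParser, Destroy<nvonnxparser::IParser>> parser{nvonnxparser::createParser(*network, gLogger)};
    unique_ptr<IBuilderConfig, Destroy<IBuilderConfig>> config{builder->createBuilderConfig()};

    parser->parseFromFile(onnxModelPath.c_str(), static_cast<int>(ILogger::Severity::kINFO));

    // The profile tells the builder which input shapes the engine must support.
    auto profile = builder->createOptimizationProfile();
    profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMIN, Dims4{1, 3, 256, 256});
    profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kOPT, Dims4{1, 3, 256, 256});
    profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMAX, Dims4{32, 3, 256, 256});
    config->addOptimizationProfile(profile);
    config->setMaxWorkspaceSize(1ULL << 30); // 1 GB

    ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);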

Diff for: posts/TensorRT-introduction/Makefile

+4 -4
@@ -25,16 +25,16 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 CUDA_INSTALL_DIR=/usr/local/cuda
 
-CXXFLAGS=-std=c++11 -Wall -I$(CUDA_INSTALL_DIR)/include
-LDFLAGS=-L$(CUDA_INSTALL_DIR)/lib64 -L$(CUDA_INSTALL_DIR)/lib64/stubs
-LDLIBS=-Wl,--start-group -lnvinfer -lnvonnxparser -lcudart_static -lrt -ldl -lpthread -lonnx -lonnx_proto -lprotobuf -lstdc++ -lm -Wl,--end-group
+CXXFLAGS=-std=c++11 -DONNX_ML=1 -Wall -I$(CUDA_INSTALL_DIR)/include
+LDFLAGS=-L$(CUDA_INSTALL_DIR)/lib64 -L$(CUDA_INSTALL_DIR)/lib64/stubs -L/usr/local/lib
+LDLIBS=-Wl,--start-group -lnvonnxparser -lnvinfer -lcudart_static -lonnx -lonnx_proto -lprotobuf -lstdc++ -lm -lrt -ldl -lpthread -Wl,--end-group
 
 HEADERS=${wildcard *.h}
 TARGET_SRCS=$(wildcard simpleOnnx*.cpp)
 TARGET_OBJS=${TARGET_SRCS:.cpp=.o}
 TARGETS=${TARGET_OBJS:.o=}
 
-
+
 all: $(TARGETS)
 
 $(TARGETS): %: %.o ioHelper.o

Diff for: posts/TensorRT-introduction/ioHelper.cpp

+2 -4
@@ -24,14 +24,13 @@
  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  */
-#include "ioHelper.h"
 #include <algorithm>
 #include <fstream>
 #include <google/protobuf/io/coded_stream.h>
 #include <google/protobuf/io/zero_copy_stream_impl.h>
 #include <iterator>
 #include <onnx/onnx_pb.h>
-
+#include "ioHelper.h"
 using namespace std;
 
 namespace nvinfer1
@@ -83,8 +82,7 @@ size_t readTensor(vector<string> const& tensorProtoPaths, vector<float>& buffer)
 
     for (size_t i = 0; i < tensorProtoPaths.size(); ++i)
     {
-        size_t elements = readTensorProto(tensorProtoPaths[i], &buffer[totalElements]);
-        if (!elements)
+        size_t elements = readTensorProto(tensorProtoPaths[i], &buffer[totalElements]); if (!elements)
         {
             cout << "ERROR: could not read tensor from file " << tensorProtoPaths[i] << endl;
             break;

Diff for: posts/TensorRT-introduction/ioHelper.o

30.6 KB
Binary file not shown.

Diff for: posts/TensorRT-introduction/simpleOnnx.cpp

+35 -36
@@ -24,16 +24,20 @@
  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  */
+#include <NvInfer.h>
 #include "cudaWrapper.h"
 #include "ioHelper.h"
-#include <NvInfer.h>
 #include <NvOnnxParser.h>
 #include <algorithm>
+#include <functional>
+#include <cmath>
 #include <cassert>
 #include <iostream>
 #include <memory>
 #include <string>
 #include <vector>
+#include <numeric>
+#include <math.h>
 
 using namespace nvinfer1;
 using namespace std;
@@ -51,25 +55,30 @@ constexpr double REL_EPSILON = 0.05;
 constexpr size_t MAX_WORKSPACE_SIZE = 1ULL << 30; // 1 GB
 
 ICudaEngine* createCudaEngine(string const& onnxModelPath, int batchSize)
-{
-    unique_ptr<IBuilder, Destroy<IBuilder>> builder{createInferBuilder(gLogger)};
-    unique_ptr<INetworkDefinition, Destroy<INetworkDefinition>> network{builder->createNetwork()};
+{
+    const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+    unique_ptr<nvinfer1::IBuilder, Destroy<nvinfer1::IBuilder>> builder{nvinfer1::createInferBuilder(gLogger)};
+    unique_ptr<nvinfer1::INetworkDefinition, Destroy<nvinfer1::INetworkDefinition>> network{builder->createNetworkV2(explicitBatch)};
     unique_ptr<nvonnxparser::IParser, Destroy<nvonnxparser::IParser>> parser{nvonnxparser::createParser(*network, gLogger)};
+    unique_ptr<nvinfer1::IBuilderConfig,Destroy<nvinfer1::IBuilderConfig>> config{builder->createBuilderConfig()};
 
     if (!parser->parseFromFile(onnxModelPath.c_str(), static_cast<int>(ILogger::Severity::kINFO)))
     {
         cout << "ERROR: could not parse input engine." << endl;
         return nullptr;
     }
 
-    // Build TensorRT engine optimized based on for batch size of input data provided.
-    builder->setMaxBatchSize(batchSize);
-    // Allow TensorRT to use fp16 mode kernels internally.
-    // Note that Input and Output tensors will still use 32 bit float type by default.
+    config->setMaxWorkspaceSize(MAX_WORKSPACE_SIZE);
     builder->setFp16Mode(builder->platformHasFastFp16());
-    builder->setMaxWorkspaceSize(MAX_WORKSPACE_SIZE);
-
-    return builder->buildCudaEngine(*network); // Build and return TensorRT engine.
+    builder->setMaxBatchSize(batchSize);
+
+    auto profile = builder->createOptimizationProfile();
+    profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMIN, Dims4{1, 3, 256 , 256});
+    profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kOPT, Dims4{1, 3, 256 , 256});
+    profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMAX, Dims4{32, 3, 256 , 256});
+    config->addOptimizationProfile(profile);
+
+    return builder->buildEngineWithConfig(*network, *config);
 }
 
 ICudaEngine* getCudaEngine(string const& onnxModelPath, int batchSize)
@@ -78,11 +87,13 @@ ICudaEngine* getCudaEngine(string const& onnxModelPath, int batchSize)
     ICudaEngine* engine{nullptr};
 
     string buffer = readBuffer(enginePath);
+
     if (buffer.size())
     {
         // Try to deserialize engine.
         unique_ptr<IRuntime, Destroy<IRuntime>> runtime{createInferRuntime(gLogger)};
         engine = runtime->deserializeCudaEngine(buffer.data(), buffer.size(), nullptr);
+
     }
 
     if (!engine)
@@ -110,7 +121,7 @@ void launchInference(IExecutionContext* context, cudaStream_t stream, vector<flo
     int inputId = getBindingInputIndex(context);
 
     cudaMemcpyAsync(bindings[inputId], inputTensor.data(), inputTensor.size() * sizeof(float), cudaMemcpyHostToDevice, stream);
-    context->enqueue(batchSize, bindings, stream, nullptr);
+    context->enqueueV2(bindings, stream, nullptr);
     cudaMemcpyAsync(outputTensor.data(), bindings[1 - inputId], outputTensor.size() * sizeof(float), cudaMemcpyDeviceToHost, stream);
 }
 
@@ -139,23 +150,9 @@ void doInference(IExecutionContext* context, cudaStream_t stream, vector<float>
     cout << "Inference batch size " << batchSize << " average over " << ITERATIONS << " runs is " << totalTime / ITERATIONS << "ms" << endl;
 }
 
-void softmax(vector<float>& tensor, int batchSize)
+void verifyOutput(vector<float> const& outputTensor, vector<float> const& referenceTensor, int size)
 {
-    size_t batchElements = tensor.size() / batchSize;
-
-    for (int i = 0; i < batchSize; ++i)
-    {
-        float* batchVector = &tensor[i * batchElements];
-        double maxValue = *max_element(batchVector, batchVector + batchElements);
-        double expSum = accumulate(batchVector, batchVector + batchElements, 0.0, [=](double acc, float value) { return acc + exp(value - maxValue); });
-
-        transform(batchVector, batchVector + batchElements, batchVector, [=](float input) { return static_cast<float>(std::exp(input - maxValue) / expSum); });
-    }
-}
-
-void verifyOutput(vector<float> const& outputTensor, vector<float> const& referenceTensor)
-{
-    for (size_t i = 0; i < referenceTensor.size(); ++i)
+    for (size_t i = 0; i < size; ++i)
     {
         double reference = static_cast<double>(referenceTensor[i]);
         // Check absolute and relative tolerance.
@@ -207,9 +204,9 @@ int main(int argc, char* argv[])
     for (int i = 0; i < engine->getNbBindings(); ++i)
     {
         Dims dims{engine->getBindingDimensions(i)};
-        size_t size = accumulate(dims.d, dims.d + dims.nbDims, batchSize, multiplies<size_t>());
+        size_t size = std::accumulate(dims.d+1, dims.d + dims.nbDims, batchSize, multiplies<size_t>());
         // Create CUDA buffer for Tensor.
-        cudaMalloc(&bindings[i], size * sizeof(float));
+        cudaMalloc(&bindings[i], batchSize * size * sizeof(float));
 
         // Resize CPU buffers to fit Tensor.
         if (engine->bindingIsInput(i))
@@ -228,6 +225,10 @@ int main(int argc, char* argv[])
     // Create Execution Context.
     context.reset(engine->createExecutionContext());
 
+    Dims dims_i{engine->getBindingDimensions(0)};
+    Dims4 inputDims{batchSize, dims_i.d[1], dims_i.d[2], dims_i.d[3]};
+    context->setBindingDimensions(0, inputDims);
+
     doInference(context.get(), stream, inputTensor, outputTensor, bindings, batchSize);
 
     vector<string> referenceFiles;
@@ -240,12 +241,10 @@ int main(int argc, char* argv[])
         cout << "Couldn't read reference Tensor" << endl;
         return 1;
     }
-
-    // Apply a softmax on the CPU to create a normalized distribution suitable for measuring relative error in probabilities.
-    softmax(outputTensor, batchSize);
-    softmax(referenceTensor, batchSize);
-
-    verifyOutput(outputTensor, referenceTensor);
+
+    Dims dims_o{engine->getBindingDimensions(1)};
+    int size = batchSize * dims_o.d[2] * dims_o.d[3];
+    verifyOutput(outputTensor, referenceTensor, size);
 
     for (void* ptr : bindings)
         cudaFree(ptr);
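Correspondingly, at run time the explicit-batch engine needs the actual input shape set on the execution context before enqueueV2 is called, and per-binding buffers are sized from the binding dimensions times the batch size. A minimal sketch of the inference path simpleOnnx.cpp now follows (names as in the diff above):

    // Fix the dynamic input shape for this run, then launch with enqueueV2.
    Dims dims_i{engine->getBindingDimensions(0)};
    context->setBindingDimensions(0, Dims4{batchSize, dims_i.d[1], dims_i.d[2], dims_i.d[3]});

    cudaMemcpyAsync(bindings[inputId], inputTensor.data(), inputTensor.size() * sizeof(float), cudaMemcpyHostToDevice, stream);
    context->enqueueV2(bindings, stream, nullptr); // no batchSize argument, unlike the old enqueue()
    cudaMemcpyAsync(outputTensor.data(), bindings[1 - inputId], outputTensor.size() * sizeof(float), cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);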

Diff for: posts/TensorRT-introduction/simpleOnnx_1.cpp

+65 -33
@@ -24,16 +24,19 @@
  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  */
+#include <NvInfer.h>
 #include "cudaWrapper.h"
 #include "ioHelper.h"
-#include <NvInfer.h>
 #include <NvOnnxParser.h>
 #include <algorithm>
 #include <cassert>
 #include <iostream>
 #include <memory>
 #include <string>
 #include <vector>
+#include <numeric>
+#include <math.h>
+#include <cmath>
 
 using namespace nvinfer1;
 using namespace std;
@@ -46,52 +49,49 @@ constexpr double ABS_EPSILON = 0.005;
 // Maxmimum relative tolerance for output tensor comparison against reference.
 constexpr double REL_EPSILON = 0.05;
 
-ICudaEngine* createCudaEngine(string const& onnxModelPath, int batchSize)
+nvinfer1::ICudaEngine* createCudaEngine(string const& onnxModelPath, int batchSize)
 {
-    unique_ptr<IBuilder, Destroy<IBuilder>> builder{createInferBuilder(gLogger)};
-    unique_ptr<INetworkDefinition, Destroy<INetworkDefinition>> network{builder->createNetwork()};
+    const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+    unique_ptr<nvinfer1::IBuilder, Destroy<nvinfer1::IBuilder>> builder{nvinfer1::createInferBuilder(gLogger)};
+    unique_ptr<nvinfer1::INetworkDefinition, Destroy<nvinfer1::INetworkDefinition>> network{builder->createNetworkV2(explicitBatch)};
     unique_ptr<nvonnxparser::IParser, Destroy<nvonnxparser::IParser>> parser{nvonnxparser::createParser(*network, gLogger)};
+    unique_ptr<nvinfer1::IBuilderConfig,Destroy<nvinfer1::IBuilderConfig>> config{builder->createBuilderConfig()};
 
     if (!parser->parseFromFile(onnxModelPath.c_str(), static_cast<int>(ILogger::Severity::kINFO)))
     {
         cout << "ERROR: could not parse input engine." << endl;
         return nullptr;
     }
 
-    return builder->buildCudaEngine(*network); // Build and return TensorRT engine.
+    builder->setMaxBatchSize(batchSize);
+    config->setMaxWorkspaceSize((1 << 30));
+
+    auto profile = builder->createOptimizationProfile();
+    profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMIN, Dims4{1, 3, 256 , 256});
+    profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kOPT, Dims4{1, 3, 256 , 256});
+    profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMAX, Dims4{32, 3, 256 , 256});
+    config->addOptimizationProfile(profile);
+
+    return builder->buildEngineWithConfig(*network, *config);
 }
 
-static int getBindingInputIndex(IExecutionContext* context)
+static int getBindingInputIndex(nvinfer1::IExecutionContext* context)
 {
     return !context->getEngine().bindingIsInput(0); // 0 (false) if bindingIsInput(0), 1 (true) otherwise
 }
 
 void launchInference(IExecutionContext* context, cudaStream_t stream, vector<float> const& inputTensor, vector<float>& outputTensor, void** bindings, int batchSize)
 {
     int inputId = getBindingInputIndex(context);
-
     cudaMemcpyAsync(bindings[inputId], inputTensor.data(), inputTensor.size() * sizeof(float), cudaMemcpyHostToDevice, stream);
-    context->enqueue(batchSize, bindings, stream, nullptr);
+    context->enqueueV2(bindings, stream, nullptr);
     cudaMemcpyAsync(outputTensor.data(), bindings[1 - inputId], outputTensor.size() * sizeof(float), cudaMemcpyDeviceToHost, stream);
-}
 
-void softmax(vector<float>& tensor, int batchSize)
-{
-    size_t batchElements = tensor.size() / batchSize;
-
-    for (int i = 0; i < batchSize; ++i)
-    {
-        float* batchVector = &tensor[i * batchElements];
-        double maxValue = *max_element(batchVector, batchVector + batchElements);
-        double expSum = accumulate(batchVector, batchVector + batchElements, 0.0, [=](double acc, float value) { return acc + exp(value - maxValue); });
-
-        transform(batchVector, batchVector + batchElements, batchVector, [=](float input) { return static_cast<float>(std::exp(input - maxValue) / expSum); });
-    }
 }
 
-void verifyOutput(vector<float> const& outputTensor, vector<float> const& referenceTensor)
+void verifyOutput(vector<float> const& outputTensor, vector<float> const& referenceTensor, int size)
 {
-    for (size_t i = 0; i < referenceTensor.size(); ++i)
+    for (size_t i = 0; i < size; ++i)
     {
         double reference = static_cast<double>(referenceTensor[i]);
         // Check absolute and relative tolerance.
@@ -102,8 +102,31 @@ void verifyOutput(vector<float> const& outputTensor, vector<float> const& refere
             return;
         }
     }
+    cout << "OK" << endl;
+}
 
-    cout << "OK" << endl;
+void saveImageAsPGM(vector<float>& outputTensor,int H, int W)
+{
+    FILE* pgmimg;
+    pgmimg = fopen("output.pgm", "wb");
+
+    fprintf(pgmimg, "P2\n");
+    // Writing Width and Height
+    fprintf(pgmimg, "%d %d\n", H, W);
+    // Writing the maximum gray value
+    fprintf(pgmimg, "255\n");
+
+    for (int i=0; i< H; ++i)
+    {
+        for(int j=0; j<W; ++j)
+        {
+            int temp = round(255* outputTensor[i*H + j]);
+            fprintf(pgmimg, "%d ", temp);
+        }
+        fprintf(pgmimg, "\n");
+    }
+
+    fclose(pgmimg);
 }
 
 int main(int argc, char* argv[])
@@ -141,13 +164,14 @@ int main(int argc, char* argv[])
     for (int i = 0; i < engine->getNbBindings(); ++i)
     {
         Dims dims{engine->getBindingDimensions(i)};
-        size_t size = accumulate(dims.d, dims.d + dims.nbDims, batchSize, multiplies<size_t>());
+        size_t size = accumulate(dims.d+1, dims.d + dims.nbDims, batchSize, multiplies<size_t>());
         // Create CUDA buffer for Tensor.
-        cudaMalloc(&bindings[i], size * sizeof(float));
+        cudaMalloc(&bindings[i], batchSize * size * sizeof(float));
 
         // Resize CPU buffers to fit Tensor.
-        if (engine->bindingIsInput(i))
+        if (engine->bindingIsInput(i)){
             inputTensor.resize(size);
+        }
         else
             outputTensor.resize(size);
     }
@@ -158,31 +182,39 @@ int main(int argc, char* argv[])
         cout << "Couldn't read input Tensor" << endl;
         return 1;
     }
+
 
     // Create Execution Context.
     context.reset(engine->createExecutionContext());
+
+    Dims dims_i{engine->getBindingDimensions(0)};
+    Dims4 inputDims{batchSize, dims_i.d[1], dims_i.d[2], dims_i.d[3]};
+    context->setBindingDimensions(0, inputDims);
 
     launchInference(context.get(), stream, inputTensor, outputTensor, bindings, batchSize);
+
+    Dims dims{engine->getBindingDimensions(1)};
+    saveImageAsPGM(outputTensor, dims.d[2], dims.d[3]);
     // Wait until the work is finished.
     cudaStreamSynchronize(stream);
 
     vector<string> referenceFiles;
     for (string path : inputFiles)
         referenceFiles.push_back(path.replace(path.rfind("input"), 5, "output"));
     // Try to read and compare against reference tensor from protobuf file.
+
+
     referenceTensor.resize(outputTensor.size());
     if (readTensor(referenceFiles, referenceTensor) != referenceTensor.size())
     {
         cout << "Couldn't read reference Tensor" << endl;
         return 1;
     }
 
-    // Apply a softmax on the CPU to create a normalized distribution suitable for measuring relative error in probabilities.
-    softmax(outputTensor, batchSize);
-    softmax(referenceTensor, batchSize);
-
-    verifyOutput(outputTensor, referenceTensor);
-
+    Dims dims_o{engine->getBindingDimensions(1)};
+    int size = batchSize * dims_o.d[2] * dims_o.d[3];
+    verifyOutput(outputTensor, referenceTensor, size);
+
     for (void* ptr : bindings)
         cudaFree(ptr);
 
