Skip to content

Commit 8d311ca

Browse files
authored
Merge pull request #52 from RainbowLinLin/main
Jetpack 5.1
2 parents 4ebb525 + 1159558 commit 8d311ca

14 files changed

+84
-689
lines changed

.vscode/settings.json

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"files.associations": {
3+
"cmath": "cpp",
4+
"chrono": "cpp"
5+
}
6+
}

CMakeLists.txt

+11-27
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,23 @@
11
cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
2-
set(PROJECT_NAME mtcnn_facenet_cpp_tensorRT)
2+
set(PROJECT_NAME face_recogition_tensorRT)
33
project(${PROJECT_NAME})# LANGUAGES CXX CUDA)
44

55
set (CMAKE_CXX_STANDARD 11)
66

77
# OpenCV
88
find_package(OpenCV REQUIRED)
99

10+
# setup CUDA
1011
find_package(CUDA)
1112
message("-- CUDA version: ${CUDA_VERSION}")
1213

1314
set(
1415
CUDA_NVCC_FLAGS
1516
${CUDA_NVCC_FLAGS};
1617
-O3
17-
-gencode arch=compute_53,code=sm_53
18-
-gencode arch=compute_62,code=sm_62
18+
-gencode arch=compute_87,code=sm_87
1919
)
2020

21-
if(CUDA_VERSION_MAJOR GREATER 9)
22-
message("-- CUDA ${CUDA_VERSION_MAJOR} detected, enabling SM_72")
23-
24-
set(
25-
CUDA_NVCC_FLAGS
26-
${CUDA_NVCC_FLAGS};
27-
-gencode arch=compute_72,code=sm_72
28-
)
29-
30-
endif()
31-
3221
# tensorRT
3322
message("CUDA_TOOLKIT_ROOT_DIR = ${CUDA_TOOLKIT_ROOT_DIR}")
3423

@@ -38,14 +27,20 @@ find_path(TENSORRT_INCLUDE_DIR NvInfer.h
3827
find_path(TENSORRT_INCLUDE_DIR NvInferPlugin.h
3928
HINTS ${TENSORRT_ROOT} ${CUDA_TOOLKIT_ROOT_DIR}
4029
PATH_SUFFIXES include)
30+
find_path(TENSORRT_INCLUDE_DIR NvCaffeParser.h
31+
HINTS ${TENSORRT_ROOT} ${CUDA_TOOLKIT_ROOT_DIR}
32+
PATH_SUFFIXES include)
4133
MESSAGE(STATUS "Found TensorRT headers at ${TENSORRT_INCLUDE_DIR}")
4234
find_library(TENSORRT_LIBRARY_INFER nvinfer
4335
HINTS ${TENSORRT_ROOT} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR}
4436
PATH_SUFFIXES lib lib64 lib/x64 lib/aarch64-linux-gnu)
4537
find_library(TENSORRT_LIBRARY_INFER_PLUGIN nvinfer_plugin
4638
HINTS ${TENSORRT_ROOT} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR}
4739
PATH_SUFFIXES lib lib64 lib/x64 lib/aarch64-linux-gnu)
48-
find_library(TENSORRT_LIBRARY_PARSER nvparsers
40+
find_library(TENSORRT_LIBRARY_CAFFE_PARSER nvcaffe_parser
41+
HINTS ${TENSORRT_ROOT} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR}
42+
PATH_SUFFIXES lib lib64 lib/x64 lib/aarch64-linux-gnu)
43+
find_library(TENSORRT_LIBRARY_PARSER nvparsers
4944
HINTS ${TENSORRT_ROOT} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR}
5045
PATH_SUFFIXES lib lib64 lib/x64 lib/aarch64-linux-gnu)
5146
set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN} ${TENSORRT_LIBRARY_PARSER})
@@ -58,22 +53,11 @@ if(NOT TENSORRT_FOUND)
5853
"Cannot find TensorRT library.")
5954
endif()
6055

61-
# l2norm_helper plugin
62-
add_subdirectory(trt_l2norm_helper)
63-
include_directories(
64-
trt_l2norm_helper
65-
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
66-
${TENSORRT_INCLUDE_DIR}
67-
)
68-
6956
message("TENSORRT_LIBRARY = ${TENSORRT_LIBRARY}")
7057

7158
AUX_SOURCE_DIRECTORY(./src DIR_SRCS)
7259
message("DIR_SRCS = ${DIR_SRCS}")
7360
cuda_add_executable(${PROJECT_NAME} ${DIR_SRCS})
7461

75-
target_link_libraries(${PROJECT_NAME}
76-
trt_l2norm_helper
77-
${TENSORRT_LIBRARY}
78-
)
62+
target_link_libraries(${PROJECT_NAME} ${TENSORRT_LIBRARY})
7963
target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS})

README.md

+26-51
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,29 @@
1-
# Face Recognition for NVIDIA Jetson (Nano) using TensorRT
2-
Face recognition with [Google FaceNet](https://arxiv.org/abs/1503.03832)
3-
architecture and retrained model by David Sandberg
4-
([github.com/davidsandberg/facenet](https://github.com/davidsandberg/facenet))
5-
using TensorRT and OpenCV. <br> This project is based on the
6-
implementation of l2norm helper functions which are needed in the output
7-
layer of the FaceNet model. Link to the repo:
8-
[github.com/r7vme/tensorrt_l2norm_helper](https://github.com/r7vme/tensorrt_l2norm_helper). <br>
9-
Moreover, this project uses an adapted version of [PKUZHOU's implementation](https://github.com/PKUZHOU/MTCNN_FaceDetection_TensorRT)
1+
# Face Recognition for NVIDIA Jetson AGX Orin using TensorRT
2+
- This project is based on the implementation of this repo:
3+
[Face Recognition for NVIDIA Jetson (Nano) using TensorRT](https://github.com/nwesem/mtcnn_facenet_cpp_tensorRT). Since the original author is no longer updating his content, and many of the original content cannot be applied to the new Jetpack version and the new Jetson device. Therefore, I have modified the original author's content slightly to make it work for face recognition on the Jetson AGX Orin.
4+
- Face recognition with [Google FaceNet](https://arxiv.org/abs/1503.03832) architecture and retrained model by David Sandberg ([github.com/davidsandberg/facenet](https://github.com/davidsandberg/facenet)) using TensorRT and OpenCV.
5+
- Moreover, this project uses an adapted version of [PKUZHOU's implementation](https://github.com/PKUZHOU/MTCNN_FaceDetection_TensorRT)
106
of the mtCNN for face detection. More info below.
117

128
## Hardware
13-
* NVIDIA Jetson Nano
14-
* Raspberry Pi v2 camera
9+
- Nvidia Jetson AGX Orin DVK
10+
- Logitech C922 Pro HD Stream Webcam
1511

16-
If you want to use a USB camera instead of Raspi Camera set the boolean _isCSICam_ to false in [main.cpp](./src/main.cpp).
12+
If you want to use a CSI camera instead of USB Camera, set the boolean _isCSICam_ to true in [main.cpp](./src/main.cpp).
1713

1814

1915
## Dependencies
20-
cuda 10.2 + cudnn 8.0 <br> TensorRT 7.x <br> OpenCV 4.1.1 <br>
21-
TensorFlow r1.14 (for Python to convert model from .pb to .uff)
16+
- JetPack 5.1
17+
- CUDA 11.4.19 + cuDNN 8.6.0
18+
- TensorRT 8.5.2
19+
- OpenCV 4.5.4
20+
- Tensorflow 2.11
2221

23-
## Update
24-
This master branch now uses Jetpack 4.4, so dependencies have slightly changed and tensorflow is not preinstalled anymore. So there is an extra step that takes a few minutes more than before. <br>
25-
In case you would like to use older versions of Jetpack there is a tag jp4.2.2, that can links to the older implementation.
2622

2723
## Installation
28-
#### 1. Install Cuda, CudNN, TensorRT, and TensorFlow for Python
29-
You can check [NVIDIA website](https://developer.nvidia.com/) for help.
30-
Installation procedures are very well documented.<br><br>**If you are
31-
using NVIDIA Jetson (Nano, TX1/2, Xavier) with Jetpack 4.4**, most needed packages
32-
should be installed if the Jetson was correctly flashed using SDK
33-
Manager or the SD card image, you will only need to install cmake, openblas and tensorflow:
34-
```bash
35-
sudo apt install cmake libopenblas-dev
36-
```
37-
#### 2. Install Tensorflow
38-
The following shows the steps to install Tensorflow for Jetpack 4.4. This was copied from the official [NVIDIA documentation](https://docs.nvidia.com/deeplearning/frameworks/install-tf-jetson-platform/index.html). I'm assuming you don't need to install it in a virtual environment. If yes, please refer to the documentation linked above. If you are not installing this on a jetson, please refer to the official tensorflow documentation.
24+
25+
#### 1. Install Tensorflow
26+
The following shows the steps to install Tensorflow for Jetpack 5.1. This was copied from the official [NVIDIA documentation](https://docs.nvidia.com/deeplearning/frameworks/install-tf-jetson-platform/index.html). I'm assuming you don't need to install it in a virtual environment. If yes, please refer to the documentation linked above. If you are not installing this on a jetson, please refer to the official tensorflow documentation.
3927

4028
```bash
4129
# Install system packages required by TensorFlow:
@@ -44,13 +32,14 @@ sudo apt install libhdf5-serial-dev hdf5-tools libhdf5-dev zlib1g-dev zip libjpe
4432

4533
# Install and upgrade pip3
4634
sudo apt install python3-pip
47-
sudo pip3 install -U pip testresources setuptools
35+
sudo python3 -m pip install --upgrade pip
36+
sudo pip3 install -U testresources setuptools==65.5.0
4837

4938
# Install the Python package dependencies
50-
sudo pip3 install -U numpy==1.16.1 future==0.18.2 mock==3.0.5 h5py==2.10.0 keras_preprocessing==1.1.1 keras_applications==1.0.8 gast==0.2.2 futures protobuf pybind11
39+
sudo pip3 install -U numpy==1.22 future==0.18.2 mock==3.0.5 keras_preprocessing==1.1.2 keras_applications==1.0.8 gast==0.4.0 protobuf pybind11 cython pkgconfig packaging h5py==3.6.0
5140

52-
# Install TensorFlow using the pip3 command. This command will install the latest version of TensorFlow compatible with JetPack 4.4.
53-
sudo pip3 install --pre --extra-index-url https://developer.download.nvidia.com/compute/redist/jp/v44 'tensorflow<2'
41+
# Install TensorFlow using the pip3 command. This command will install the latest version of TensorFlow compatible with JetPack 5.1.
42+
sudo pip3 install --extra-index-url https://developer.download.nvidia.com/compute/redist/jp/v51 tensorflow==2.11.0+nv23.01
5443
```
5544

5645

@@ -127,7 +116,7 @@ Put images of people in the imgs folder. Please only use images that contain one
127116
the OpenCV GUI, press "**N**" on your keyboard to add a new face. The camera input will stop until
128117
you have opened your terminal and put in the name of the person you want to add.
129118
```bash
130-
./mtcnn_facenet_cpp_tensorRT
119+
./face_recogition_tensorRT
131120
```
132121
Press "**Q**" to quit and to show the stats (fps).
133122

@@ -136,26 +125,12 @@ now parses and serializes the model from .uff to a runtime engine
136125
(.engine file).
137126

138127
## Performance
139-
Performance on **NVIDIA Jetson Nano**
140-
* ~60ms +/- 20ms for face detection using mtCNN
141-
* ~22ms +/- 2ms per face for facenet inference
142-
* **Total:** ~15fps
143-
144-
Performance on **NVIDIA Jetson AGX Xavier**:
145-
* ~40ms +/- 20ms for mtCNN
146-
* ~9ms +/- 1ms per face for inference of facenet
147-
* **Total:** ~22fps
128+
Performance on **NVIDIA Jetson AGX Orin**
129+
* ~24ms for face detection using mtCNN
130+
* ~4ms per face for facenet inference
131+
* **Total:** ~30fps
148132

149133
## License
150134
Please respect all licenses of OpenCV and the data the machine learning models (mtCNN and Google FaceNet)
151135
were trained on.
152136

153-
## FAQ
154-
Sometimes the camera driver doesn't close properly that means you will have to restart the __nvargus-daemon__:
155-
```bash
156-
sudo systemctl restart nvargus-daemon
157-
```
158-
159-
## Info
160-
Niclas Wesemann <br>
161-

src/baseEngine.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,10 @@ void baseEngine::caffeToGIEModel(const std::string &deployFile, /
6666
else {
6767
// create the builder
6868
IBuilder *builder = createInferBuilder(gLogger);
69+
IBuilderConfig* config = builder->createBuilderConfig();
6970

7071
// parse the caffe model to populate the network, then set the outputs
71-
INetworkDefinition *network = builder->createNetwork();
72+
INetworkDefinition *network = builder->createNetworkV2(0U);
7273
ICaffeParser *parser = createCaffeParser();
7374

7475
const IBlobNameToTensor *blobNameToTensor = parser->parse(deployFile.c_str(),
@@ -81,8 +82,8 @@ void baseEngine::caffeToGIEModel(const std::string &deployFile, /
8182

8283
// Build the engine
8384
builder->setMaxBatchSize(maxBatchSize);
84-
builder->setMaxWorkspaceSize(1 << 25);
85-
ICudaEngine *engine = builder->buildCudaEngine(*network);
85+
config->setMaxWorkspaceSize(1 << 25);
86+
ICudaEngine *engine = builder->buildEngineWithConfig(*network, *config);
8687
assert(engine);
8788

8889
context = engine->createExecutionContext();

src/common.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
class Logger : public nvinfer1::ILogger
3232
{
3333
public:
34-
void log(nvinfer1::ILogger::Severity severity, const char* msg) override
34+
void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override
3535
{
3636
// suppress info-level messages
3737
//if (severity == Severity::kINFO) return;

src/faceNet.cpp

+26-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#include "faceNet.h"
2+
#include <vector>
3+
#include <cmath>
24

35
int FaceNetClassifier::m_classCount = 0;
46

@@ -51,10 +53,11 @@ void FaceNetClassifier::createOrLoadEngine() {
5153
}
5254
else {
5355
IBuilder *builder = createInferBuilder(m_gLogger);
54-
INetworkDefinition *network = builder->createNetwork();
56+
IBuilderConfig* config = builder->createBuilderConfig();
57+
INetworkDefinition *network = builder->createNetworkV2(0U);
5558
IUffParser *parser = createUffParser();
56-
parser->registerInput("input", DimsCHW(160, 160, 3), UffInputOrder::kNHWC);
57-
parser->registerOutput("embeddings");
59+
parser->registerInput("input", Dims3(160, 160, 3), UffInputOrder::kNHWC);
60+
parser->registerOutput("Bottleneck/BatchNorm/batchnorm/add_1");
5861

5962
if (!parser->parse(m_uffFile.c_str(), *network, m_dtype))
6063
{
@@ -68,23 +71,23 @@ void FaceNetClassifier::createOrLoadEngine() {
6871
/* build engine */
6972
if (m_dtype == DataType::kHALF)
7073
{
71-
builder->setFp16Mode(true);
74+
config->setFlag(BuilderFlag::kFP16);
7275
}
7376
else if (m_dtype == DataType::kINT8) {
74-
builder->setInt8Mode(true);
77+
config->setFlag(BuilderFlag::kINT8);
7578
// ToDo
7679
//builder->setInt8Calibrator()
7780
}
7881
builder->setMaxBatchSize(m_batchSize);
79-
builder->setMaxWorkspaceSize(1<<30);
82+
config->setMaxWorkspaceSize(1<<30);
8083
// strict will force selected datatype, even when another was faster
8184
//builder->setStrictTypeConstraints(true);
8285
// Disable DLA, because many layers are still not supported
8386
// and this causes additional latency.
8487
//builder->allowGPUFallback(true);
8588
//builder->setDefaultDeviceType(DeviceType::kDLA);
8689
//builder->setDLACore(1);
87-
m_engine = builder->buildCudaEngine(*network);
90+
m_engine = builder->buildEngineWithConfig(*network, *config);
8891

8992
/* serialize engine and write to file */
9093
if(m_serializeEngine) {
@@ -155,7 +158,7 @@ void FaceNetClassifier::doInference(float* inputData, float* output) {
155158
int size_of_single_input = 3 * 160 * 160 * sizeof(float);
156159
int size_of_single_output = 128 * sizeof(float);
157160
int inputIndex = m_engine->getBindingIndex("input");
158-
int outputIndex = m_engine->getBindingIndex("embeddings");
161+
int outputIndex = m_engine->getBindingIndex("Bottleneck/BatchNorm/batchnorm/add_1");
159162

160163
void* buffers[2];
161164

@@ -262,10 +265,24 @@ FaceNetClassifier::~FaceNetClassifier() {
262265
// std::cout << "FaceNet was destructed" << std::endl;
263266
}
264267

268+
std::vector<float> l2Normalize(const std::vector<float>& vec) {
269+
float norm = 0.0;
270+
for (const auto& element : vec) {
271+
norm += element * element;
272+
}
273+
norm = std::sqrt(norm);
274+
std::vector<float> normalizedVec(vec.size());
275+
for (std::size_t i = 0; i < vec.size(); ++i) {
276+
normalizedVec[i] = vec[i] / norm;
277+
}
278+
return normalizedVec;
279+
}
265280

266281
// HELPER FUNCTIONS
267282
// Computes the distance between two std::vectors
268-
float vectors_distance(const std::vector<float>& a, const std::vector<float>& b) {
283+
float vectors_distance(const std::vector<float>& aa, const std::vector<float>& bb) {
284+
std::vector<float> a = l2Normalize(aa);
285+
std::vector<float> b = l2Normalize(bb);
269286
std::vector<double> auxiliary;
270287
std::transform (a.begin(), a.end(), b.begin(), std::back_inserter(auxiliary),//
271288
[](float element1, float element2) {return pow((element1-element2),2);});

src/faceNet.h

-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#include <NvInfer.h>
1616
#include <NvUffParser.h>
1717
#include <NvInferPlugin.h>
18-
#include <l2norm_helper.h>
1918
#include "common.h"
2019
#include "pBox.h"
2120

src/main.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include <chrono>
44
#include <NvInfer.h>
55
#include <NvInferPlugin.h>
6-
#include <l2norm_helper.h>
76
#include <opencv2/highgui.hpp>
87
#include "faceNet.h"
98
#include "videoStreamer.h"
@@ -35,7 +34,7 @@ int main()
3534
int videoFrameHeight = 480;
3635
int maxFacesPerScene = 5;
3736
float knownPersonThreshold = 1.;
38-
bool isCSICam = true;
37+
bool isCSICam = false;
3938

4039
// init facenet
4140
FaceNetClassifier faceNet = FaceNetClassifier(gLogger, dtype, uffFile, engineFile, batchSize, serializeEngine,
@@ -69,6 +68,7 @@ int main()
6968
// loop over frames with inference
7069
auto globalTimeStart = chrono::steady_clock::now();
7170
while (true) {
71+
auto fps_start = chrono::steady_clock::now();
7272
videoStreamer.getFrame(frame);
7373
if (frame.empty()) {
7474
std::cout << "Empty frame! Exiting...\n Try restarting nvargus-daemon by "
@@ -86,6 +86,12 @@ int main()
8686
auto endFeatM = chrono::steady_clock::now();
8787
faceNet.resetVariables();
8888

89+
auto fps_end = chrono::steady_clock::now();
90+
auto milliseconds = chrono::duration_cast<chrono::milliseconds>(fps_end-fps_start).count();
91+
float fps = (1000/milliseconds);
92+
std::string label = cv::format("FPS: %.2f ", fps);
93+
cv::putText(frame, label, cv::Point(15, 30), cv::FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(0, 0, 0), 2);
94+
8995
cv::imshow("VideoSource", frame);
9096
nbFrames++;
9197
outputBbox.clear();

0 commit comments

Comments
 (0)