forked from PaddlePaddle/FastDeploy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinfer.py
142 lines (122 loc) · 4.93 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os
def parse_arguments(argv=None):
    """Parse the command-line options of the PP-OCRv3 inference demo.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, which is what the script uses.

    Returns:
        argparse.Namespace with det_model, cls_model, rec_model,
        rec_label_file, image, device, backend, device_id and cpu_thread_num.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--det_model", required=True, help="Path of Detection model of PPOCR.")
    parser.add_argument(
        "--cls_model",
        required=True,
        help="Path of Classification model of PPOCR.")
    parser.add_argument(
        "--rec_model",
        required=True,
        help="Path of Recognition model of PPOCR.")
    parser.add_argument(
        "--rec_label_file",
        required=True,
        help="Path of the label file used by the Recognition model of PPOCR.")
    parser.add_argument(
        "--image", type=str, required=True, help="Path of test image file.")
    parser.add_argument(
        "--device",
        type=str,
        default='cpu',
        help="Type of inference device, support 'cpu' or 'gpu'.")
    parser.add_argument(
        "--backend",
        type=str,
        default="default",
        help="Type of inference backend, support ort/trt/paddle/openvino, default 'openvino' for cpu, 'tensorrt' for gpu"
    )
    parser.add_argument(
        "--device_id",
        type=int,
        default=0,
        help="Define which GPU card used to run model.")
    parser.add_argument(
        "--cpu_thread_num",
        type=int,
        default=9,
        help="Number of threads while inference on CPU.")
    return parser.parse_args(argv)
def build_option(args):
    """Build a FastDeploy RuntimeOption from the parsed command-line arguments.

    Args:
        args: Namespace returned by parse_arguments(); reads ``device``,
            ``device_id``, ``cpu_thread_num`` and ``backend``.

    Returns:
        A newly created fd.RuntimeOption configured for the requested
        device and backend.

    Raises:
        ValueError: if the backend is incompatible with the device
            ("trt" requires device "gpu", "openvino" requires device "cpu").
    """
    option = fd.RuntimeOption()
    device = args.device.lower()
    backend = args.backend.lower()

    if device == "gpu":
        # Honour --device_id; previously card 0 was always used.
        option.use_gpu(args.device_id)

    option.set_cpu_thread_num(args.cpu_thread_num)

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if backend == "trt":
        if device != "gpu":
            raise ValueError("TensorRT backend require inference on device GPU.")
        option.use_trt_backend()
    elif backend == "ort":
        option.use_ort_backend()
    elif backend == "paddle":
        option.use_paddle_backend()
    elif backend == "openvino":
        if device != "cpu":
            raise ValueError("OpenVINO backend require inference on device CPU.")
        option.use_openvino_backend()
    # Any other value (e.g. "default") calls no use_*_backend(), so the
    # RuntimeOption keeps whatever backend it selects by default.
    return option
args = parse_arguments()

# Detection model: locates text boxes in the image.
det_model_file = os.path.join(args.det_model, "inference.pdmodel")
det_params_file = os.path.join(args.det_model, "inference.pdiparams")
# Classification model: text direction classifier (optional).
cls_model_file = os.path.join(args.cls_model, "inference.pdmodel")
cls_params_file = os.path.join(args.cls_model, "inference.pdiparams")
# Recognition model: recognizes the text inside each box.
rec_model_file = os.path.join(args.rec_model, "inference.pdmodel")
rec_params_file = os.path.join(args.rec_model, "inference.pdiparams")
rec_label_file = args.rec_label_file

# All three models use the same deployment configuration, but each gets its
# OWN RuntimeOption object. Previously the three names aliased one object, so
# all three set_trt_input_shape("x", ...) calls below landed on the same
# option before any model was created, and the recognizer's shape range could
# replace the detector's. Users may also configure each option separately.
det_option = build_option(args)
cls_option = build_option(args)
rec_option = build_option(args)

# Dynamic input shapes (min, opt, max) for each model's "x" input, used when
# running with the TensorRT backend.
det_option.set_trt_input_shape("x", [1, 3, 50, 50], [1, 3, 640, 640],
                               [1, 3, 1536, 1536])
cls_option.set_trt_input_shape("x", [1, 3, 48, 10], [1, 3, 48, 320],
                               [1, 3, 48, 1024])
rec_option.set_trt_input_shape("x", [1, 3, 48, 10], [1, 3, 48, 320],
                               [1, 3, 48, 2304])
# Optionally save the TensorRT engine files locally:
# det_option.set_trt_cache_file(args.det_model + "/det_trt_cache.trt")
# cls_option.set_trt_cache_file(args.cls_model + "/cls_trt_cache.trt")
# rec_option.set_trt_cache_file(args.rec_model + "/rec_trt_cache.trt")

det_model = fd.vision.ocr.DBDetector(
    det_model_file, det_params_file, runtime_option=det_option)
cls_model = fd.vision.ocr.Classifier(
    cls_model_file, cls_params_file, runtime_option=cls_option)
rec_model = fd.vision.ocr.Recognizer(
    rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option)

# Build the PP-OCRv3 pipeline chaining the 3 models. cls_model is optional
# and may be set to None if direction classification is not needed.
ppocr_v3 = fd.vision.ocr.PPOCRv3(
    det_model=det_model, cls_model=cls_model, rec_model=rec_model)

# Load the input image. cv2.imread returns None (it does not raise) when the
# file is missing or unreadable, so fail early with a clear message.
im = cv2.imread(args.image)
if im is None:
    raise SystemExit(f"Failed to read image: {args.image}")

# Run prediction and print the result.
result = ppocr_v3.predict(im)
print(result)

# Visualize the result.
vis_im = fd.vision.vis_ppocr(im, result)
cv2.imwrite("visualized_result.jpg", vis_im)
print("Visualized result save in ./visualized_result.jpg")