* `--model` : Name of the model file (can be a TorchScript module or a TensorRT engine, ending in the `.plan` extension). If the backend is `dynamo` or `torch_compile`, the input should be a PyTorch module (instead of a TorchScript module).
* `--model_torch` : Name of the PyTorch model file (optional; only necessary if `dynamo` or `torch_compile` is the chosen backend)
* `--onnx` : ONNX model file, which bypasses the step of exporting ONNX from `model_torch`. If this argument is provided, the ONNX is converted directly to a TRT engine
* `--inputs` : List of input shapes & dtypes, e.g. `(1, 3, 224, 224)@fp32` for ResNet or `(1, 128)@int32;(1, 128)@int32` for BERT
* `--batch_size` : Batch size
* `--precision` : Comma-separated list of precisions to build TensorRT engines with, e.g. `fp32,fp16`
* `--device` : Device ID
* `--truncate` : Truncate long and double weights in the network in Torch-TensorRT
* `--is_trt_engine` : Boolean flag to be enabled if the model file provided is a TensorRT engine
* `--report` : Path of the output file where the performance summary is written
* `--optimization_level` : Builder optimization level for TensorRT (from 1 to 5, with 5 being the highest optimization)
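The `--inputs` format above packs each input's shape and dtype into a `shape@dtype` token, with multiple inputs separated by `;`. As a minimal sketch of how that specification could be parsed (this helper is hypothetical, not part of the benchmark script):

```python
def parse_inputs(spec):
    """Parse a "--inputs" spec such as "(1, 3, 224, 224)@fp32" or
    "(1, 128)@int32;(1, 128)@int32" into a list of (shape, dtype) tuples."""
    parsed = []
    for entry in spec.split(";"):
        shape_str, _, dtype = entry.partition("@")
        # Strip the surrounding parentheses/spaces, then split dimensions
        shape = tuple(int(dim) for dim in shape_str.strip("() ").split(","))
        parsed.append((shape, dtype or "fp32"))  # assume fp32 if dtype omitted
    return parsed

print(parse_inputs("(1, 3, 224, 224)@fp32"))
print(parse_inputs("(1, 128)@int32;(1, 128)@int32"))
```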
The corresponding argument definitions in the diff (the `--onnx` and `--truncate` names are reconstructed from the argument list above; portions outside the diff hunks are elided):

```python
arg_parser.add_argument(
    "--onnx",
    help="ONNX model file which helps bypass the step of exporting ONNX from torchscript model. If this argument is provided, the ONNX will be directly converted to TRT engine",
)
arg_parser.add_argument(
    "--inputs",
    type=str,
    # ... (remaining options elided in the diff)
)
# ...
arg_parser.add_argument(
    "--truncate",
    action="store_true",
    help="Truncate long and double weights in the network in Torch-TensorRT",
)
arg_parser.add_argument(
    "--optimization_level",
    type=int,
    default=3,
    help="Builder optimization level for TensorRT",
)
arg_parser.add_argument(
    "--is_trt_engine",
    action="store_true",
    # ... (remaining options elided in the diff)
)
```

And the updated model-loading logic in `run()`:

```python
# Load TorchScript model, if provided
if os.path.exists(model_name):
    if params["is_trt_engine"]:
        with open(model_name, "rb") as f:
            model = f.read()
        print("Loading user provided trt engine: ", model_name)
    else:
        print("Loading user provided torchscript model: ", model_name)
        model = torch.jit.load(model_name).cuda().eval()
```
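The loading path dispatches on `is_trt_engine`: a serialized TensorRT engine is read as raw bytes, while a TorchScript module would go through `torch.jit.load(...).cuda().eval()`. A self-contained sketch of that dispatch, with a placeholder standing in for the TorchScript branch so it runs without a GPU (the `load_model` name is hypothetical):

```python
import os
import tempfile

def load_model(model_name, is_trt_engine):
    """Return raw engine bytes, or a placeholder for a TorchScript load.

    Mirrors the branch above; real code would return
    torch.jit.load(model_name).cuda().eval() in the else case.
    """
    if not os.path.exists(model_name):
        raise FileNotFoundError(model_name)
    if is_trt_engine:
        with open(model_name, "rb") as f:
            return f.read()  # serialized engine is passed around as bytes
    return "torchscript"  # placeholder for torch.jit.load(...)

# Demo with a stand-in ".plan" file
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "model.plan")
    with open(path, "wb") as f:
        f.write(b"\x00engine")
    print(type(load_model(path, is_trt_engine=True)))
```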