-
Notifications
You must be signed in to change notification settings - Fork 35
Closed
Description
Besides, I found that it does not work with `interpolate`. Test code as follows:
import torch.nn.functional as F
from torch import nn
import tensorrt as trt
class TestModel(torch.nn.Module):
    """Minimal module that doubles the spatial size of its input.

    Used as the reproduction case: the target size handed to
    ``F.interpolate`` is computed from the shape of the incoming tensor.
    """

    def forward(self, x):
        """Upsample ``x`` by 2x along its last two dimensions.

        Args:
            x: Input tensor; the last two dimensions are doubled. The
                script feeds it 4-D tensors such as ``(1, 3, 300, 400)``.

        Returns:
            ``x`` resized with nearest-neighbour interpolation so that its
            last two dimensions are twice as large.
        """
        # The output size is derived from the runtime shape and converted to
        # plain Python ints before being passed as an explicit ``size``.
        xh, xw = int(x.shape[-2] * 2), int(x.shape[-1] * 2)
        x = F.interpolate(x, size=(xh, xw), mode="nearest")
        return x
# Build the reference model and a random input on the GPU, then run the
# model once in eager mode as the baseline.
input_shape = (1, 3, 300, 400)
test_model = TestModel()
test_model = test_model.cuda()
dummy_tensor = torch.randn(*input_shape, dtype=torch.float32)
dummy_tensor = dummy_tensor.cuda()
# Eager PyTorch doubles H and W, so this prints a (1, 3, 600, 800) shape.
print(test_model(dummy_tensor).shape)
# convert test model to trt
import tensorrt as trt

# Dynamic-shape ranges handed to the converter as [min, opt, max].
# NOTE(review): presumably one [min, opt, max] triple per model input;
# confirm against the torch2trt version in use.
opt_shape_param = [
    [
        [1, 3, 160, 240],  # min
        [1, 3, 800, 1200],  # opt
        [1, 3, 1600, 2400],  # max
    ]
]

# Run the conversion under no_grad so no autograd graph is recorded.
with torch.no_grad():
    trt_model = torch2trt(
        test_model,
        [dummy_tensor],
        fp16_mode=False,
        opt_shape_param=opt_shape_param,
    )
# Run the converted model on an input whose height differs from the
# (1, 3, 300, 400) tensor that was used for conversion.
dummy_tensor = torch.randn(1, 3, 400, 400, dtype=torch.float32)
dummy_tensor = dummy_tensor.cuda()
# Expected output shape is (1, 3, 800, 800), but the actual output shape is
# still (1, 3, 600, 800), i.e. the size produced by the conversion input.
print(trt_model(dummy_tensor).shape)
Metadata
Metadata
Assignees
Labels
No labels