Skip to content

dynamic batches handling #10

@haiderasad

Description

@haiderasad

Hey, nice work,
the dynamic batching flow is a bit broken I think
it works fine if the engine is built on one shape, but when built with

min_shape = (1, 360, 640)   # Minimum shape with batch size 1
opt_shape = (6, 360, 640)   # Optimal shape with batch size 6
max_shape = (10, 360, 640)  # Maximum shape with batch size 10

when I give it an image of (6, 360, 640) it says
ValueError: could not broadcast input array from the shape (1382400,) into shape (230400,)

upon investigating I see that the shape and size of the inputs is set to (10, 360, 640) so
its expecting (10, 360, 640) , I don't know why , so are you aware of what the best practice in tensorrt to handle dynamic inputs?

below is my whole code

engine building


import tensorrt as trt

# Set up TensorRT logger, builder, and network
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

# Use the ONNX parser to load your ONNX model
parser = trt.OnnxParser(network, TRT_LOGGER)

# Path to the ONNX file generated
onnx_file_path = 'det_model.onnx'

# Parse the ONNX file
with open(onnx_file_path, 'rb') as model:
    if not parser.parse(model.read()):
        print('ERROR: Failed to parse the ONNX file.')
        for error in range(parser.num_errors()):
            print(parser.get_error(error))
        exit()

# Configure the builder and create an optimization profile
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)  # 4GB, adjust this as necessary

# Assuming your model's input shape can vary in batch size, you need to set the optimization profile accordingly
min_shape = (1, 360, 640)   # Minimum shape with batch size 1
opt_shape = (6, 360, 640)   # Optimal shape with batch size 6360, 640
max_shape = (10, 360, 640)  # Maximum shape with batch size 10

profile = builder.create_optimization_profile()
profile.set_shape(network.get_input(0).name, min=min_shape, opt=opt_shape, max=max_shape)
config.add_optimization_profile(profile)

# Build the TensorRT engine
engine = builder.build_serialized_network(network, config)

# Save the engine to a file
engine_file_path = 'det_model_dynamic.trt'
with open(engine_file_path, 'wb') as f:
    f.write(engine)

print("TensorRT model is successfully created and saved to", engine_file_path)

Main.py


import ctypes
import numpy as np
import tensorrt as trt
from cuda import cuda, cudart
import cv2 as cv
try:
    FileNotFoundError
except NameError:
    FileNotFoundError = IOError

EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def check_cuda_err(err):
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("Cuda Error: {}".format(err))
    if isinstance(err, cudart.cudaError_t):
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError("Cuda Runtime Error: {}".format(err))
    else:
        raise RuntimeError("Unknown error type: {}".format(err))

def cuda_call(call):
    err, res = call[0], call[1:]
    check_cuda_err(err)
    if len(res) == 1:
        res = res[0]
    return res

def GiB(val):
    return val * 1 << 30

class HostDeviceMem:
    def __init__(self, size: int, dtype: np.dtype, name= None, shape = None, format= None):
        nbytes = size * dtype.itemsize
        host_mem = cuda_call(cudart.cudaMallocHost(nbytes))
        pointer_type = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))

        self._host = np.ctypeslib.as_array(ctypes.cast(host_mem, pointer_type), (size,))
        self._device = cuda_call(cudart.cudaMalloc(nbytes))
        self._nbytes = nbytes
        self._name = name
        self._shape = shape
        self._format = format
        self._dtype = dtype

    @property
    def host(self) -> np.ndarray:
        return self._host

    @host.setter
    def host(self, arr: np.ndarray):
        if arr.size > self.host.size:
            raise ValueError(f"Tried to fit an array of size {arr.size} into host memory of size {self.host.size}")
        np.copyto(self.host[:arr.size], arr.flat, casting='safe')

    @property
    def device(self) -> int:
        return self._device

    @property
    def nbytes(self) -> int:
        return self._nbytes

    @property
    def name(self):
        return self._name

    @property
    def shape(self):
        return self._shape

    @property
    def format(self):
        return self._format

    @property
    def dtype(self) -> np.dtype:
        return self._dtype

    def __str__(self):
        return f"Host:\n{self.host}\nDevice:\n{self.device}\nSize:\n{self.nbytes}\n"

    def __repr__(self):
        return self.__str__()

    def free(self):
        cuda_call(cudart.cudaFree(self.device))
        cuda_call(cudart.cudaFreeHost(self.host.ctypes.data))

def allocate_buffers(engine: trt.ICudaEngine, profile_idx= None):
    inputs = []
    outputs = []
    bindings = []
    stream = cuda_call(cudart.cudaStreamCreate())
    tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
    for binding in tensor_names:
        format = engine.get_tensor_format(binding)
        
        
        shape = engine.get_tensor_shape(binding) if profile_idx is None else engine.get_tensor_profile_shape(binding, profile_idx)[0]
        shape_valid = np.all([s >= 0 for s in shape])
        if not shape_valid and profile_idx is None:
            raise ValueError(f"Binding {binding} has dynamic shape, but no profile was specified.")
        size = trt.volume(shape)
        dtype = np.dtype(trt.nptype(engine.get_tensor_dtype(binding)))

        print(shape)
        binding_memory = HostDeviceMem(size, dtype, name=binding, shape=shape, format=format)

        bindings.append(int(binding_memory.device))

        if engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT:
            inputs.append(binding_memory)
        else:
            outputs.append(binding_memory)
    return inputs, outputs, bindings, stream


def memcpy_host_to_device(device_ptr: int, host_arr: np.ndarray):
    nbytes = host_arr.size * host_arr.itemsize
    cuda_call(cudart.cudaMemcpy(device_ptr, host_arr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice))

def memcpy_device_to_host(host_arr: np.ndarray, device_ptr: int):
    nbytes = host_arr.size * host_arr.itemsize
    cuda_call(cudart.cudaMemcpy(host_arr, device_ptr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost))

def _do_inference_base(inputs, outputs, stream, execute_async_func):
    kind = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
    [cuda_call(cudart.cudaMemcpyAsync(inp.device, inp.host, inp.nbytes, kind, stream)) for inp in inputs]
    execute_async_func()
    kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
    [cuda_call(cudart.cudaMemcpyAsync(out.host, out.device, out.nbytes, kind, stream)) for out in outputs]
    cuda_call(cudart.cudaStreamSynchronize(stream))
    return [out.host for out in outputs]

def do_inference(context, engine, bindings, inputs, outputs, stream):
    def execute_async_func():
        context.execute_async_v3(stream_handle=stream)

    num_io = engine.num_io_tensors
    context.set_input_shape('input', (6, 360, 640))
    for i in range(num_io):
        context.set_tensor_address(engine.get_tensor_name(i), bindings[i])
        # if engine.get_tensor_name(i)=='input':
        #     context.set_input_shape('input', (6, 360, 640))
        
    #print(context.all_binding_shapes_specified)
    return _do_inference_base(inputs, outputs, stream, execute_async_func)

def load_engine(engine_file_path):
    with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())

def preprocess_images(images, width=1280 // 2, height=720 // 2):
    shapes = [img.shape for img in images]
    images = [cv.resize(img, (width, height)) for img in images]
    images = np.stack(images)
    images = images / 128.0 - 1
    return images

engine_file_path = 'det_model_dynamic.trt'
engine = load_engine(engine_file_path)
inputs, outputs, bindings, stream = allocate_buffers(engine=engine,profile_idx=0)
images = [np.random.rand(360, 640, 6).astype(np.float32) for _ in range(1)]  # Adjust the batch size as needed
preprocessed_images = preprocess_images(images)

#print(inputs[0].shape)
for host_device_buffer in inputs:
    np.copyto(host_device_buffer.host, preprocessed_images.flatten())
    
context = engine.create_execution_context()
masks = do_inference(context=context, engine=engine, inputs=inputs, outputs=outputs, bindings=bindings, stream=stream)
#print(len(masks))
for mask in masks:
    print(mask.shape)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions