Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions groundingdino/util/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ def load_model(model_config_path: str, model_checkpoint_path: str, device: str =


def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
"""
Load an image and apply transformations.
This function takes the path to an image file, loads the image, and applies a series of transformations to it.
The transformations include resizing the image, converting it to a tensor, and normalizing its pixel values.

Parameters:
image_path (str): The path to the image file.

Returns:
Tuple[np.array, torch.Tensor]: A tuple containing the original image as a NumPy array and the transformed image as a PyTorch tensor.
"""
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
Expand All @@ -49,6 +60,29 @@ def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
image_transformed, _ = transform(image_source, None)
return image, image_transformed

def transform_image(PIL_image: PIL.Image.Image) -> Tuple[np.array, torch.Tensor]:
    """
    Transform an in-memory PIL image and convert it to a tensor.

    This function takes a PIL Image that is already loaded (no file access), makes sure it is in
    RGB mode, and runs it through the preprocessing pipeline: resize (``T.RandomResize([800],
    max_size=1333)``), conversion to a tensor, and per-channel normalization.

    Parameters:
    PIL_image (PIL.Image.Image): The input image. Images whose mode is not "RGB" are converted
    to RGB first, because the normalization step is configured with exactly three channels.

    Returns:
    Tuple[np.array, torch.Tensor]: A tuple containing the (RGB) image as a NumPy array and the
    transformed image as a PyTorch tensor.
    """
    # Normalize below is given three mean/std values, so the pipeline only works on
    # three-channel input. Convert other modes up front instead of failing deep inside
    # the transform; images that are already RGB are passed through untouched.
    if PIL_image.mode != "RGB":
        PIL_image = PIL_image.convert("RGB")

    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    # The array is taken from the same (possibly converted) image that feeds the
    # transform, so both returned values describe the same pixels.
    image = np.asarray(PIL_image)
    # The transform takes (image, target); no target is supplied here and the
    # returned target is discarded.
    image_transformed, _ = transform(PIL_image, None)
    return image, image_transformed

def predict(
model,
Expand Down