Improve SFD detect and batch_detect #369

Open

yassineAlouini opened this issue Feb 14, 2025 · 0 comments
yassineAlouini commented Feb 14, 2025

In progress...
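
Draft of a vectorized rewrite of the SFD detect / batch_detect path: preprocessing happens on-device in torch, and get_predictions thresholds every scale across the whole batch with NumPy instead of looping per pixel.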

import torch
import torch.nn.functional as F
import cv2
import numpy as np

from .bbox import decode  # assume decode supports vectorized inputs

def detect(net, img, device):
    # Transpose from (H, W, C) to (C, H, W)
    img = img.transpose(2, 0, 1)
    # Create a batch of 1; np.ascontiguousarray copies the transposed view
    # into contiguous memory before handing it to torch.from_numpy.
    img = np.expand_dims(np.ascontiguousarray(img), 0)
    img = torch.from_numpy(img).to(device, dtype=torch.float32)
    return batch_detect(net, img, device)


def batch_detect(net, img_batch, device):
    """
    Inputs:
        - img_batch: a torch.Tensor of shape (Batch size, Channels, Height, Width)
    """
    # It is better to set cudnn.benchmark globally (outside the function)
    # rather than on every call (if using CUDA).
    if 'cuda' in str(device):  # str() so torch.device inputs also work
        torch.backends.cudnn.benchmark = True

    # Make sure img_batch is on the correct device and in float32.
    img_batch = img_batch.to(device, dtype=torch.float32)

    # Convert RGB (assumed input) to BGR by flipping the channel dimension.
    # (Could also use explicit channel indexing like img_batch = img_batch[:, [2,1,0],:,:])
    img_batch = img_batch.flip(-3)
    
    # Subtract the mean
    mean = torch.tensor([104.0, 117.0, 123.0], device=device).view(1, 3, 1, 1)
    img_batch = img_batch - mean

    with torch.no_grad():
        olist = net(img_batch)

    # Apply softmax on all classification outputs. Assuming that every even-index output 
    # is a classification output:
    olist = [F.softmax(o, dim=1) if idx % 2 == 0 else o for idx, o in enumerate(olist)]
    # Transfer outputs to the CPU and convert to numpy.
    olist = [o.cpu().numpy() for o in olist]
    
    bboxlists = get_predictions(olist, img_batch.size(0))
    return bboxlists
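
# Usage sketch for batch_detect (hypothetical `paths` and `net`; all frames
# must share one size, and channels are expected in RGB order):
#
#   imgs = [cv2.cvtColor(cv2.imread(p), cv2.COLOR_BGR2RGB) for p in paths]
#   batch = torch.from_numpy(np.stack([im.transpose(2, 0, 1) for im in imgs]))
#   dets = batch_detect(net, batch, device)  # one (N_i, 5) array per image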


def get_predictions(olist, batch_size):
    """
    Vectorized version that obtains candidate detections from the network outputs.
    It groups detections per batch sample.

    Returns a list of arrays, one per image in the batch, where each array is
    of shape (N, 5) representing the 4 bounding box coordinates and the final score.
    """
    # Create a list to hold detections for every image
    detections_by_image = [[] for _ in range(batch_size)]
    # Variances used in decoding
    variances = [0.1, 0.2]

    num_scales = len(olist) // 2
    for i in range(num_scales):
        # Get classification and regression results for this scale.
        ocls = olist[i * 2]      # shape: (batch, num_classes, H, W)
        oreg = olist[i * 2 + 1]  # shape: (batch, 4, H, W)
        # Define the stride (note that 2**(i+2) gives 4,8,16,32,...)
        stride = 2 ** (i + 2)

        # Use vectorized thresholding: obtain all positions (across the batch) with score > 0.05
        # Note: np.where returns a tuple (batch_inds, h_inds, w_inds)
        batch_inds, h_inds, w_inds = np.where(ocls[:, 1, :, :] > 0.05)
        if batch_inds.size == 0:
            continue

        # Compute the center coordinates based on stride.
        axc = stride / 2 + w_inds * stride
        ayc = stride / 2 + h_inds * stride
        # Each candidate uses the same prior box dimensions at this scale.
        priors = np.vstack((
            axc,
            ayc,
            np.full_like(axc, stride * 4),
            np.full_like(ayc, stride * 4)
        )).T  # shape: (N, 4)

        # Gather the scores (expand dims for concatenation later)
        scores = ocls[batch_inds, 1, h_inds, w_inds][:, None]  # shape: (N, 1)
        # Gather regression outputs for the same positions.
        # Here, indexing is done on every detection: from oreg (batch, 4, H, W)
        locs = oreg[batch_inds, :, h_inds, w_inds]  # shape: (N, 4)

        # Decode the location predictions using the priors and provided variances.
        # (Assuming that decode is implemented to work with vectorized inputs.)
        boxes = decode(locs, priors, variances)  # expected shape: (N, 4)

        # Concatenate the boxes with their scores.
        detections = np.concatenate((boxes, scores), axis=1)  # shape: (N, 5)

        # Group detections by the image index
        for b, det in zip(batch_inds, detections):
            detections_by_image[b].append(det)

    # For every image in the batch, convert list of detections into a numpy array.
    for i in range(batch_size):
        if detections_by_image[i]:
            detections_by_image[i] = np.stack(detections_by_image[i], axis=0)
        else:
            # If no candidates, return an empty array with shape (0, 5)
            detections_by_image[i] = np.empty((0, 5))
    return detections_by_image
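
# For reference, a sketch of the vectorized decode assumed above. This is the
# standard SSD-style decoding; the real implementation is whatever .bbox.decode
# provides, and _decode_sketch is only illustrative.
def _decode_sketch(loc, priors, variances):
    # priors are (cx, cy, w, h) in pixels; loc holds the regression offsets.
    boxes = np.concatenate((
        priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
        priors[:, 2:] * np.exp(loc[:, 2:] * variances[1])), axis=1)
    boxes[:, :2] -= boxes[:, 2:] / 2  # (cx, cy) -> (x_min, y_min)
    boxes[:, 2:] += boxes[:, :2]      # (w, h)   -> (x_max, y_max)
    return boxes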


def flip_detect(net, img, device):
    # Run detection on the horizontally flipped image, then mirror the
    # boxes back into the original image's coordinate frame.
    flipped = cv2.flip(img, 1)
    b = detect(net, flipped, device)[0]

    if b.size == 0:
        return np.empty((0, 5))

    bboxlist = np.zeros_like(b)
    bboxlist[:, 0] = img.shape[1] - b[:, 2]  # x_min = W - flipped x_max
    bboxlist[:, 1] = b[:, 1]                 # y_min is unchanged
    bboxlist[:, 2] = img.shape[1] - b[:, 0]  # x_max = W - flipped x_min
    bboxlist[:, 3] = b[:, 3]                 # y_max is unchanged
    bboxlist[:, 4] = b[:, 4]                 # score is unchanged
    return bboxlist


def pts_to_bb(pts):
    # Converts a set of points to a bounding box
    min_xy = np.min(pts, axis=0)
    max_xy = np.max(pts, axis=0)
    return np.array([min_xy[0], min_xy[1], max_xy[0], max_xy[1]])
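
For trying this out end to end, a minimal sketch (the model loading is a placeholder; load_sfd_model is not a real helper here, and detect expects an RGB image):

import cv2
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = load_sfd_model().to(device).eval()  # placeholder: load your SFD weights

img = cv2.imread('face.jpg')                # BGR uint8
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # batch_detect flips RGB -> BGR itself
dets = detect(net, img, device)[0]          # (N, 5): x_min, y_min, x_max, y_max, score
dets = dets[dets[:, 4] > 0.5]               # keep confident boxes (threshold is arbitrary)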