71 changes: 64 additions & 7 deletions demo/predictor.py
@@ -1,16 +1,38 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import numpy as np
from maskrcnn_benchmark.structures.keypoint import PersonKeypoints
import matplotlib.pyplot as plt
import cv2
import torch
from torchvision import transforms as T

from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.structures.image_list import to_image_list
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker
from maskrcnn_benchmark import layers as L
from maskrcnn_benchmark.utils import cv2_util


class FeatureExtractorFromBoxes(torch.nn.Module):
"""
Uses a GeneralizedRCNN model to re-compute the
image features and extracts out the roi-aligned/pooled
from the ground truth boxes
"""

def __init__(self, grcnn_model):
super().__init__()
self.mdl = grcnn_model

def forward(self, images, gtboxes):
features = self.mdl.backbone(images.tensors)
x, result, detector_losses = self.mdl.roi_heads(
features, gtboxes, None)
return x, result
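
For orientation, here is a minimal sketch of driving this module on its own; it assumes a config has been set up as elsewhere in the demo, and `cfg_path`, `image_tensor`, `boxes_xyxy`, `w`, and `h` are hypothetical placeholders rather than names from this PR:

```python
import torch

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.image_list import to_image_list

cfg.merge_from_file(cfg_path)  # hypothetical path to a model yaml
model = build_detection_model(cfg).to(cfg.MODEL.DEVICE).eval()
extractor = FeatureExtractorFromBoxes(model)

# one BoxList of xyxy boxes per image, in the input tensor's (w, h) space
image_list = to_image_list(image_tensor, cfg.DATALOADER.SIZE_DIVISIBILITY)
gt = [BoxList(boxes_xyxy, (w, h), mode='xyxy').to(cfg.MODEL.DEVICE)]
with torch.no_grad():
    feats, detections = extractor(image_list, gt)  # one feature entry per box
```

In practice the `COCODemo` wiring below takes care of this setup, so callers would normally go through `compute_features_from_bbox` instead.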


class COCODemo(object):
    # COCO categories for pretty print
    CATEGORIES = [
@@ -110,10 +132,14 @@ def __init__(
        self.model.eval()
        self.device = torch.device(cfg.MODEL.DEVICE)
        self.model.to(self.device)
        # the feature extractor wraps the same model instance, so the
        # checkpoint loaded below populates its weights as well
        self.feat_extractor = FeatureExtractorFromBoxes(self.model)

        self.min_image_size = min_image_size

        save_dir = cfg.OUTPUT_DIR
        checkpointer = DetectronCheckpointer(
            cfg, self.model, save_dir=save_dir)
        _ = checkpointer.load(cfg.MODEL.WEIGHT)

        self.transforms = self.build_transform()
@@ -198,7 +224,8 @@ def compute_prediction(self, original_image):
        image = self.transforms(original_image)
        # convert to an ImageList, padded so that it is divisible by
        # cfg.DATALOADER.SIZE_DIVISIBILITY
        image_list = to_image_list(
            image, self.cfg.DATALOADER.SIZE_DIVISIBILITY)
        image_list = image_list.to(self.device)
        # compute predictions
        with torch.no_grad():
@@ -221,6 +248,37 @@ def compute_prediction(self, original_image):
prediction.add_field("mask", masks)
return prediction

    def compute_features_from_bbox(self, original_image, gt_boxes):
        """
        Extracts features for the given ground-truth boxes.
        Assumes the ground-truth boxes are given in xyxy format, in the
        coordinate space of the original image.

        Arguments:
            original_image (np.ndarray): an image as returned by OpenCV
            gt_boxes (tensor or list): Nx4 ground-truth boxes in xyxy format

        Returns:
            features (BoxList): the ground-truth boxes with the pooled
                features accessible via features.get_field("features")
        """
        # convert gt boxes to a BoxList in the original image's (w, h) space
        gt_box_list = BoxList(
            gt_boxes, (original_image.shape[1], original_image.shape[0]),
            mode='xyxy').to(self.device)
        # preprocess the image as in `run_on_opencv_image`
        image = self.transforms(original_image)
        # rescale the boxes to the transformed image size and wrap them in
        # a list, since the model expects one BoxList per image
        gt_box_list = [gt_box_list.resize((image.size(2), image.size(1)))]
        image_list = to_image_list(
            image, self.cfg.DATALOADER.SIZE_DIVISIBILITY)
        image_list = image_list.to(self.device)
        with torch.no_grad():
            # the extractor returns (pooled features, detections);
            # keep only the features
            features, _ = self.feat_extractor(image_list, gt_box_list)

        # sanity check: one pooled feature per ground-truth box
        assert len(features) == len(gt_box_list[0].bbox)
        # attach the features to the resized ground-truth boxes
        feats = gt_box_list[0]
        feats.add_field('features', features)
        return feats
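
As an end-to-end usage sketch, assuming a config file and checkpoint are available; the file names and box values below are placeholders, not part of this PR:

```python
import cv2
import torch

from maskrcnn_benchmark.config import cfg
from predictor import COCODemo  # as in the other demo scripts

cfg.merge_from_file("e2e_mask_rcnn_R_50_FPN_1x.yaml")  # hypothetical config
demo = COCODemo(cfg, confidence_threshold=0.7, min_image_size=800)

img = cv2.imread("image.jpg")                   # BGR image, as OpenCV returns
boxes = torch.tensor([[30., 40., 200., 220.]])  # xyxy, original-image pixels
feats = demo.compute_features_from_bbox(img, boxes)
pooled = feats.get_field("features")            # features for each input box
```

The shape of the returned features depends on which ROI heads are enabled, since `roi_heads` returns the output of the last head that ran (e.g. the mask head when MASK_ON is set).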

    def select_top_predictions(self, predictions):
        """
        Select only predictions which have a `score` > self.confidence_threshold,
@@ -328,7 +386,8 @@ def create_mask_montage(self, image, predictions):
        masks = masks[:max_masks]
        # handle the case where we have fewer detections than max_masks
        if len(masks) < max_masks:
            masks_padded = torch.zeros(
                max_masks, 1, height, width, dtype=torch.uint8)
            masks_padded[: len(masks)] = masks
            masks = masks_padded
        masks = masks.reshape(masks_per_dim, masks_per_dim, height, width)
@@ -364,14 +423,12 @@ def overlay_class_names(self, image, predictions):
            x, y = box[:2]
            s = template.format(label, score)
            cv2.putText(
                image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1
            )

        return image


def vis_keypoints(img, kps, kp_thresh=2, alpha=0.7):
"""Visualizes keypoints (adapted from vis_one_image).