import tensorflow as tf

from .metrics import bbox_iou
from .utils import swap_xy, convert_to_xywh


def random_flip_horizontal(image, boxes):
    """Flips the image and its boxes horizontally with a 50% chance.

    `boxes` are expected in normalized `[xmin, ymin, xmax, ymax]` format.
    """
    if tf.random.uniform(()) > 0.5:
        image = tf.image.flip_left_right(image)
        boxes = tf.stack([1 - boxes[:, 2], boxes[:, 1], 1 - boxes[:, 0], boxes[:, 3]], axis=-1)
    return image, boxes


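# A quick worked example of the flip above (illustrative values only): a
# normalized box [xmin=0.2, ymin=0.3, xmax=0.5, ymax=0.7] becomes
# [1 - 0.5, 0.3, 1 - 0.2, 0.7] = [0.5, 0.3, 0.8, 0.7], so the box still covers
# the same object in the mirrored image.

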
def resize_and_pad_image(image, min_side=800.0, max_side=1333.0, jitter=(640, 1024), stride=128.0):
    """Resizes and pads an image while preserving its aspect ratio.

    The smaller side is resized to `min_side` (or to a value sampled uniformly
    from `jitter` when jitter is enabled). If the larger side then exceeds
    `max_side`, the image is rescaled so that the larger side equals `max_side`.
    Finally the image is padded with zeros on the right and bottom so that both
    dimensions are multiples of `stride`.

    Returns:
      image: The resized and padded image.
      image_shape: The shape of the image before padding.
      ratio: The scaling factor applied to the original image.
    """
    image_shape = tf.cast(tf.shape(image)[:2], dtype=tf.float32)
    if jitter is not None:
        min_side = tf.random.uniform((), jitter[0], jitter[1], dtype=tf.float32)
    ratio = min_side / tf.reduce_min(image_shape)
    if ratio * tf.reduce_max(image_shape) > max_side:
        ratio = max_side / tf.reduce_max(image_shape)
    image_shape = ratio * image_shape
    image = tf.image.resize(image, tf.cast(image_shape, dtype=tf.int32))
    padded_image_shape = tf.cast(
        tf.math.ceil(image_shape / stride) * stride, dtype=tf.int32
    )
    image = tf.image.pad_to_bounding_box(
        image, 0, 0, padded_image_shape[0], padded_image_shape[1]
    )
    return image, image_shape, ratio


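# A minimal usage sketch (illustrative, not part of the training pipeline): for
# inference you would typically disable jitter and keep the returned ratio so
# predicted boxes can be mapped back to the original image, e.g.:
#
#     image, _, ratio = resize_and_pad_image(raw_image, jitter=None)
#     image = tf.expand_dims(image, axis=0)  # add a batch dimension
#     # ... run the detector, then divide predicted box coordinates by `ratio`
#     # to express them in the original image's pixel space.

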
def preprocess_data(sample):
    """Applies preprocessing to a single sample.

    Swaps the box coordinates to `[xmin, ymin, xmax, ymax]`, applies a random
    horizontal flip, resizes and pads the image, and converts the boxes to the
    absolute `[x, y, width, height]` format expected by the label encoder.

    Returns:
      image: The resized and padded image.
      bbox: Bounding boxes in `[x, y, width, height]` format.
      class_id: The class ids of the objects.
    """
    image = sample["image"]
    bbox = swap_xy(sample["objects"]["bbox"])
    class_id = tf.cast(sample["objects"]["label"], dtype=tf.int32)

    image, bbox = random_flip_horizontal(image, bbox)
    image, image_shape, _ = resize_and_pad_image(image)

    bbox = tf.stack(
        [
            bbox[:, 0] * image_shape[1],
            bbox[:, 1] * image_shape[0],
            bbox[:, 2] * image_shape[1],
            bbox[:, 3] * image_shape[0],
        ],
        axis=-1,
    )
    bbox = convert_to_xywh(bbox)
    return image, bbox, class_id


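# A rough sketch of how this preprocessing might be wired into a tf.data
# pipeline. The dataset choice ("coco/2017") and split names are assumptions,
# not part of this module:
#
#     import tensorflow_datasets as tfds
#
#     (train_ds, val_ds), info = tfds.load(
#         "coco/2017", split=["train", "validation"], with_info=True
#     )
#     train_ds = train_ds.map(preprocess_data, num_parallel_calls=tf.data.AUTOTUNE)

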
class AnchorBox:
    """Generates anchor boxes.

    This class has operations to generate anchor boxes for feature maps at
    strides `[8, 16, 32, 64, 128]`, where each anchor box is of the format
    `[x, y, width, height]`.

    Attributes:
      aspect_ratios: A list of float values representing the aspect ratios of
        the anchor boxes at each location on the feature map.
      scales: A list of float values representing the scale of the anchor boxes
        at each location on the feature map.
      num_anchors: The number of anchor boxes at each location on the feature map.
      areas: A list of float values representing the areas of the anchor
        boxes for each feature map in the feature pyramid.
      strides: A list of float values representing the strides for each feature
        map in the feature pyramid.
    """

    def __init__(self):
        self.aspect_ratios = [0.5, 1.0, 2.0]
        self.scales = [2 ** x for x in [0, 1 / 3, 2 / 3]]

        self._num_anchors = len(self.aspect_ratios) * len(self.scales)
        self._strides = [2 ** i for i in range(3, 8)]
        self._areas = [x ** 2 for x in [32.0, 64.0, 128.0, 256.0, 512.0]]
        self._anchor_dims = self._compute_dims()

    def _compute_dims(self):
        """Computes anchor box dimensions for all ratios and scales at all levels."""
        anchor_dims_all = []
        for area in self._areas:
            anchor_dims = []
            for ratio in self.aspect_ratios:
                anchor_height = tf.math.sqrt(area / ratio)
                anchor_width = area / anchor_height
                dims = tf.convert_to_tensor([[[anchor_width, anchor_height]]])

                for scale in self.scales:
                    anchor_dims.append(scale * dims)
            anchor_dims = tf.stack(anchor_dims, axis=-2)
            anchor_dims_all.append(anchor_dims)
        return anchor_dims_all

    def _get_anchors(self, feature_height, feature_width, level):
        """Generates anchor boxes for a given feature map size and level.

        Arguments:
          feature_height: An integer representing the height of the feature map.
          feature_width: An integer representing the width of the feature map.
          level: An integer representing the level of the feature map in the
            feature pyramid.

        Returns:
          anchor boxes with the shape
          `(feature_height * feature_width * num_anchors, 4)`
        """
        rx = tf.range(feature_width, dtype=tf.float32) + 0.5
        ry = tf.range(feature_height, dtype=tf.float32) + 0.5
        centers = tf.stack(tf.meshgrid(rx, ry), axis=-1) * self._strides[level - 3]
        centers = tf.expand_dims(centers, axis=-2)
        centers = tf.tile(centers, [1, 1, self._num_anchors, 1])
        dims = tf.tile(
            self._anchor_dims[level - 3], [feature_height, feature_width, 1, 1]
        )
        anchors = tf.concat([centers, dims], axis=-1)
        return tf.reshape(
            anchors, [feature_height * feature_width * self._num_anchors, 4]
        )

    def get_anchors(self, image_height, image_width):
        """Generates anchor boxes for all the feature maps of the feature pyramid.

        Arguments:
          image_height: Height of the input image.
          image_width: Width of the input image.

        Returns:
          anchor boxes for all the feature maps, stacked as a single tensor
          with shape `(total_anchors, 4)`
        """
        anchors = [
            self._get_anchors(
                tf.math.ceil(image_height / 2 ** i),
                tf.math.ceil(image_width / 2 ** i),
                i,
            )
            for i in range(3, 8)
        ]
        return tf.concat(anchors, axis=0)


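# A quick sanity check of the anchor count (illustrative): for a 640x640 input,
# the feature maps at strides 8, 16, 32, 64 and 128 have spatial sizes 80x80,
# 40x40, 20x20, 10x10 and 5x5. With 9 anchors per location this gives
# (6400 + 1600 + 400 + 100 + 25) * 9 = 76725 anchors in total, e.g.:
#
#     anchors = AnchorBox().get_anchors(640, 640)
#     # anchors.shape == (76725, 4)

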
class LabelEncoder:
    """Transforms the raw labels into targets for training.

    This class has operations to generate targets for a batch of samples, which
    is made up of the input images, the bounding boxes for the objects present,
    and their class ids.
    """

    def __init__(self, preprocessing_fn):
        self.preprocessing_fn = preprocessing_fn

        self._anchor_box = AnchorBox()

        # The scaling factors used to scale the bounding box targets.
        self._box_variance = tf.convert_to_tensor(
            [0.1, 0.1, 0.2, 0.2], dtype=tf.float32
        )

    def _match_anchor_boxes(
        self, anchor_boxes, gt_boxes, match_iou=0.5, ignore_iou=0.4
    ):
        """Matches ground truth boxes to anchor boxes based on IOU.

        1. Calculates the pairwise IOU for the M `anchor_boxes` and N `gt_boxes`
           to get a `(M, N)` shaped matrix.
        2. The ground truth box with the maximum IOU in each row is assigned to
           the anchor box, provided the IOU is greater than or equal to `match_iou`.
        3. If the maximum IOU in a row is less than `ignore_iou`, the anchor
           box is assigned the background class.
        4. The remaining anchor boxes that do not have any class assigned are
           ignored during training.

        Arguments:
          anchor_boxes: A float tensor with shape `(total_anchors, 4)`, where
            each box is of the format `[x, y, width, height]`.
          gt_boxes: A float tensor with shape `(num_objects, 4)`, where each
            box is of the format `[x, y, width, height]`.
          match_iou: A float value representing the minimum IOU threshold for
            determining if a ground truth box can be assigned to an anchor box.
          ignore_iou: A float value representing the IOU threshold under which
            an anchor box is assigned to the background class.

        Returns:
          matched_gt_idx: Index of the matched object.
          positive_mask: A mask for anchor boxes that have been assigned ground
            truth boxes.
          ignore_mask: A mask for anchor boxes that need to be ignored during
            training.
        """
        iou_matrix = bbox_iou(anchor_boxes, gt_boxes)
        max_iou = tf.reduce_max(iou_matrix, axis=1)
        matched_gt_idx = tf.argmax(iou_matrix, axis=1)
        positive_mask = tf.greater_equal(max_iou, match_iou)
        negative_mask = tf.less(max_iou, ignore_iou)
        ignore_mask = tf.logical_not(tf.logical_or(positive_mask, negative_mask))
        return (
            matched_gt_idx,
            tf.cast(positive_mask, dtype=tf.float32),
            tf.cast(ignore_mask, dtype=tf.float32),
        )

    def _compute_box_target(self, anchor_boxes, matched_gt_boxes):
        """Transforms the ground truth boxes into targets for training"""
        box_target = tf.concat(
            [
                (matched_gt_boxes[:, :2] - anchor_boxes[:, :2]) / anchor_boxes[:, 2:],
                tf.math.log(matched_gt_boxes[:, 2:] / anchor_boxes[:, 2:]),
            ],
            axis=-1,
        )
        box_target = box_target / self._box_variance
        return box_target

    def _encode_sample(self, image_shape, gt_boxes, cls_ids):
        """Creates box and classification targets for a single sample"""
        anchor_boxes = self._anchor_box.get_anchors(image_shape[1], image_shape[2])
        cls_ids = tf.cast(cls_ids, dtype=tf.float32)
        matched_gt_idx, positive_mask, ignore_mask = self._match_anchor_boxes(
            anchor_boxes, gt_boxes
        )
        matched_gt_boxes = tf.gather(gt_boxes, matched_gt_idx)
        box_target = self._compute_box_target(anchor_boxes, matched_gt_boxes)
        matched_gt_cls_ids = tf.gather(cls_ids, matched_gt_idx)
        cls_target = tf.where(
            tf.not_equal(positive_mask, 1.0), -1.0, matched_gt_cls_ids
        )
        cls_target = tf.where(tf.equal(ignore_mask, 1.0), -2.0, cls_target)
        cls_target = tf.expand_dims(cls_target, axis=-1)
        label = tf.concat([box_target, cls_target], axis=-1)
        return label

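    # To make the class-target encoding above concrete (illustrative values): an
    # anchor whose best IOU with any ground truth box is 0.7 keeps the matched
    # class id, one with a best IOU of 0.2 gets the background marker -1.0, and
    # one with a best IOU of 0.45 (between ignore_iou and match_iou) gets -2.0
    # so the loss can exclude it from training.
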
    def encode_batch(self, batch_images, gt_boxes, cls_ids):
        """Creates box and classification targets for a batch"""
        images_shape = tf.shape(batch_images)
        batch_size = images_shape[0]

        labels = tf.TensorArray(dtype=tf.float32, size=batch_size, dynamic_size=True)
        for i in range(batch_size):
            label = self._encode_sample(images_shape, gt_boxes[i], cls_ids[i])
            labels = labels.write(i, label)
        batch_images = self.preprocessing_fn(batch_images)
        return batch_images, labels.stack()
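

# A rough end-to-end sketch of how the encoder might be used (the preprocessing
# function, batch size, and padding values here are assumptions, not part of
# this module):
#
#     label_encoder = LabelEncoder(tf.keras.applications.resnet.preprocess_input)
#     train_ds = (
#         train_ds.map(preprocess_data, num_parallel_calls=tf.data.AUTOTUNE)
#         .padded_batch(batch_size=2, padding_values=(0.0, 1e-8, -1), drop_remainder=True)
#         .map(label_encoder.encode_batch, num_parallel_calls=tf.data.AUTOTUNE)
#         .prefetch(tf.data.AUTOTUNE)
#     )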