diff --git a/hackathon/metrics.py b/hackathon/metrics.py
new file mode 100644
index 0000000..bc4b6c7
--- /dev/null
+++ b/hackathon/metrics.py
@@ -0,0 +1,72 @@
+def compute_map(preds, gt_bboxes, gt_labels, iou_threshold=0.5):
+    """
+    Computes a simplified per-image detection score at the given IoU threshold.
+
+    Args:
+        preds: List of dicts with keys: 'bbox', 'score', 'label'
+        gt_bboxes: List of [cx, cy, w, h]
+        gt_labels: List of int
+
+    Returns:
+        float: precision * recall for one image, a crude stand-in for AP
+    """
+    matched = 0
+    total_pred = len(preds)
+    total_gt = len(gt_bboxes)
+
+    used_gt = set()  # ground-truth indices already claimed by a prediction
+
+    for pred in sorted(preds, key=lambda p: p["score"], reverse=True):
+        pred_box = pred["bbox"]
+        pred_label = pred["label"]
+
+        for i, (gt_box, gt_label) in enumerate(zip(gt_bboxes, gt_labels)):
+            if i in used_gt:
+                continue
+            if pred_label != gt_label:
+                continue
+
+            iou = compute_iou(pred_box, gt_box)
+            if iou >= iou_threshold:
+                matched += 1
+                used_gt.add(i)
+                break
+
+    precision = matched / (total_pred + 1e-6)
+    recall = matched / (total_gt + 1e-6)
+    ap = precision * recall  # Simplified stand-in for a full PR-curve AP
+
+    return ap
+
+
+def compute_iou(box1, box2):
+    """
+    Computes IoU between two bounding boxes in [cx, cy, w, h] format.
+
+    Returns:
+        float: IoU value
+    """
+    def to_corners(box):
+        cx, cy, w, h = box
+        return [
+            cx - w / 2,
+            cy - h / 2,
+            cx + w / 2,
+            cy + h / 2,
+        ]
+
+    x1_min, y1_min, x1_max, y1_max = to_corners(box1)
+    x2_min, y2_min, x2_max, y2_max = to_corners(box2)
+
+    inter_xmin = max(x1_min, x2_min)
+    inter_ymin = max(y1_min, y2_min)
+    inter_xmax = min(x1_max, x2_max)
+    inter_ymax = min(y1_max, y2_max)
+
+    inter_area = max(0.0, inter_xmax - inter_xmin) * max(0.0, inter_ymax - inter_ymin)
+    area1 = (x1_max - x1_min) * (y1_max - y1_min)
+    area2 = (x2_max - x2_min) * (y2_max - y2_min)
+
+    union = area1 + area2 - inter_area
+
+    return inter_area / (union + 1e-6)
diff --git a/hackathon/objectdetection.py b/hackathon/objectdetection.py
index 76af579..030c944 100644
--- a/hackathon/objectdetection.py
+++ b/hackathon/objectdetection.py
@@ -22,6 +22,8 @@
     prepare_image,
 )
 
+from hackathon.metrics import compute_map
+
 
 def preprocessing(
     sample,
@@ -148,6 +150,45 @@ def loss_for_step(p):
         return params, opt_state, loss
 
 
+def postprocess_predictions(logits, conf_threshold=0.3):
+    """
+    Converts YOLO-style model output into detection predictions.
+
+    Args:
+        logits: jnp.array of shape [H, W, 4 + 1 + C], one anchor per cell
+        conf_threshold: float, minimum score to keep a prediction
+
+    Returns:
+        List of dicts with keys: 'bbox', 'score', 'label'
+    """
+    H, W, D = logits.shape
+    num_classes = D - 5
+
+    outputs = jax.nn.sigmoid(logits)  # squash boxes, objectness, class scores into [0, 1]
+
+    bboxes = jax.device_get(outputs[..., :4])  # (cx, cy, w, h) per cell, host copy
+    objectness = outputs[..., 4]  # shape [H, W]
+    class_probs = outputs[..., 5:]  # shape [H, W, C]
+
+    scores = jax.device_get(objectness[..., None] * class_probs)  # shape [H, W, C]
+
+    # device_get above pulls boxes/scores to the host once, avoiding a sync per element below.
+    results = []
+    for i in range(H):
+        for j in range(W):
+            for c in range(num_classes):
+                score = scores[i, j, c]
+                if score > conf_threshold:
+                    cx, cy, w, h = bboxes[i, j]
+                    bbox = [float(cx), float(cy), float(w), float(h)]
+                    results.append({
+                        "bbox": bbox,
+                        "score": float(score),
+                        "label": int(c),
+                    })
+    return results
+
+
 def evaluate(dataset, params, static, key, config, seed):
 
     dataset_identifier = get_identifier(dataset)
@@ -232,8 +273,32 @@ def evaluate(dataset, params, static, key, config, seed):
     )
     for batch in progress_bar:
         key, subkey = jr.split(key)
-        # Test on a given metric on the validation set (e.g. mAP)
-        raise NotImplementedError
+
+        images = batch["global_crops"][:, 0]  # (B, H, W, C)
+        images = nhwc_to_nchw(images)  # (B, C, H, W) since the model expects NCHW
+
+        true_bboxes = batch["objects"]["bboxes"]
+        true_labels = batch["objects"]["labels"]
+
+        # Rebuild the model from its trained parameters and static structure.
+        model = eqx.combine(params, static)
+        preds = jax.vmap(model)(images)  # (B, H, W, 5 + C)
+
+        # Score each sample in the batch independently.
+        batch_map = []
+        for i in range(images.shape[0]):
+            pred_logits = preds[i]
+            gt_boxes = true_bboxes[i]
+            gt_labels = true_labels[i]
+
+            pred_objs = postprocess_predictions(pred_logits, conf_threshold=0.3)
+
+            # Greedy one-to-one matching against ground truth for this sample.
+            sample_ap = compute_map(pred_objs, gt_boxes, gt_labels)
+            batch_map.append(sample_ap)
+
+        avg_map = sum(batch_map) / len(batch_map)
+        progress_bar.set_postfix({"mAP@0.5": f"{avg_map:.4f}"})
 
     logger.info("Evaluation completed!")
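
A quick way to sanity-check the new metric, assuming the patch above is applied and `hackathon` is importable (a minimal sketch; the `[cx, cy, w, h]` box values are made up for illustration):

```python
from hackathon.metrics import compute_iou, compute_map

# Identical boxes overlap fully (IoU ~1 up to the 1e-6 smoothing term).
assert abs(compute_iou([0.5, 0.5, 0.2, 0.2], [0.5, 0.5, 0.2, 0.2]) - 1.0) < 1e-3

# Disjoint boxes have zero intersection, hence IoU == 0.
assert compute_iou([0.2, 0.2, 0.1, 0.1], [0.8, 0.8, 0.1, 0.1]) == 0.0

# One prediction matching the single ground truth of the same class:
# precision = recall ~ 1, so the simplified score is ~1.
preds = [{"bbox": [0.5, 0.5, 0.2, 0.2], "score": 0.9, "label": 3}]
score = compute_map(preds, gt_bboxes=[[0.5, 0.5, 0.2, 0.2]], gt_labels=[3])
print(f"simplified AP: {score:.3f}")  # ~1.000
```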
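
If the per-cell Python loop in `postprocess_predictions` ever becomes a bottleneck, a vectorized variant is possible with `jnp.argwhere`; this is a hypothetical alternative sketch, not part of the patch, assuming the same `[H, W, 4 + 1 + C]` layout and returning the same dict format:

```python
import jax
import jax.numpy as jnp


def postprocess_predictions_vectorized(logits, conf_threshold=0.3):
    """Hypothetical loop-free equivalent of postprocess_predictions."""
    outputs = jax.nn.sigmoid(logits)
    bboxes = outputs[..., :4]                      # [H, W, 4]
    scores = outputs[..., 4:5] * outputs[..., 5:]  # objectness * class probs, [H, W, C]

    # Find all (i, j, c) above threshold on-device, then transfer to host once.
    idx = jax.device_get(jnp.argwhere(scores > conf_threshold))  # [N, 3]
    bboxes, scores = jax.device_get((bboxes, scores))

    return [
        {
            "bbox": [float(v) for v in bboxes[i, j]],
            "score": float(scores[i, j, c]),
            "label": int(c),
        }
        for i, j, c in idx
    ]
```

Note this runs eagerly: `jnp.argwhere` produces a data-dependent shape, so it would need a static `size=` argument to work under `jax.jit`.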