72 changes: 72 additions & 0 deletions hackathon/metrics.py
@@ -0,0 +1,72 @@
def compute_map(preds, gt_bboxes, gt_labels, iou_threshold=0.5):
    """
    Computes a simplified per-image average precision at IoU=0.5.

    Predictions are greedily matched, highest score first, to unmatched
    ground-truth boxes of the same class. The returned value is
    precision * recall, a rough stand-in for true AP rather than the
    area under a precision-recall curve.

    Args:
        preds: List of dicts with keys: 'bbox', 'score', 'label'
        gt_bboxes: List of [cx, cy, w, h]
        gt_labels: List of int
        iou_threshold: float, minimum IoU for a prediction to count as a match

    Returns:
        float: approximate average precision for one image
    """
    matched = 0
    total_pred = len(preds)
    total_gt = len(gt_bboxes)

    used_gt = set()

    # Greedy matching is order-dependent, so visit confident predictions first
    for pred in sorted(preds, key=lambda p: p["score"], reverse=True):
        pred_box = pred["bbox"]
        pred_label = pred["label"]

        for i, (gt_box, gt_label) in enumerate(zip(gt_bboxes, gt_labels)):
            if i in used_gt:
                continue
            if pred_label != gt_label:
                continue

            iou = compute_iou(pred_box, gt_box)
            if iou >= iou_threshold:
                matched += 1
                used_gt.add(i)
                break

    precision = matched / (total_pred + 1e-6)
    recall = matched / (total_gt + 1e-6)
    ap = precision * recall  # Simplified stand-in for integrated AP

    return ap


def compute_iou(box1, box2):
    """
    Computes IoU between two bounding boxes in [cx, cy, w, h] format

    Returns:
        float: IoU value
    """
    def to_corners(box):
        cx, cy, w, h = box
        return [
            cx - w / 2,
            cy - h / 2,
            cx + w / 2,
            cy + h / 2,
        ]

    x1_min, y1_min, x1_max, y1_max = to_corners(box1)
    x2_min, y2_min, x2_max, y2_max = to_corners(box2)

    inter_xmin = max(x1_min, x2_min)
    inter_ymin = max(y1_min, y2_min)
    inter_xmax = min(x1_max, x2_max)
    inter_ymax = min(y1_max, y2_max)

    inter_area = max(0.0, inter_xmax - inter_xmin) * max(0.0, inter_ymax - inter_ymin)
    area1 = (x1_max - x1_min) * (y1_max - y1_min)
    area2 = (x2_max - x2_min) * (y2_max - y2_min)

    union = area1 + area2 - inter_area

    return inter_area / (union + 1e-6)
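
A quick sanity check of the matching logic (hypothetical values, not part of this diff): one correct detection plus one false positive should give precision 0.5 and recall 1.0, so the simplified AP comes out near 0.5.

from hackathon.metrics import compute_iou, compute_map

# Identical boxes -> IoU of 1.0 (up to the 1e-6 epsilon)
print(compute_iou([0.5, 0.5, 0.2, 0.2], [0.5, 0.5, 0.2, 0.2]))

preds = [
    {"bbox": [0.5, 0.5, 0.2, 0.2], "score": 0.9, "label": 1},  # true positive
    {"bbox": [0.1, 0.1, 0.1, 0.1], "score": 0.4, "label": 0},  # false positive
]
gt_bboxes = [[0.5, 0.5, 0.2, 0.2]]
gt_labels = [1]

# precision = 1/2, recall = 1/1 -> ap ~ 0.5
print(compute_map(preds, gt_bboxes, gt_labels))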
69 changes: 67 additions & 2 deletions hackathon/objectdetection.py
@@ -22,6 +22,8 @@
    prepare_image,
)

from hackathon.metrics import compute_map


def preprocessing(
    sample,
@@ -148,6 +150,45 @@ def loss_for_step(p):
    return params, opt_state, loss


def postprocess_predictions(logits, conf_threshold=0.3):
    """
    Converts YOLO-style model output into detection predictions.

    Args:
        logits: jnp.array of shape [H, W, 4 + 1 + C]
        conf_threshold: float, minimum score to keep a prediction

    Returns:
        List of dicts with keys: 'bbox', 'score', 'label'
    """
    H, W, D = logits.shape
    num_classes = D - 5

    # Squash all outputs to (0, 1); assumes box coordinates are normalized
    logits = jax.nn.sigmoid(logits)

    bboxes = logits[..., :4]       # center_x, center_y, width, height
    objectness = logits[..., 4]    # shape [H, W]
    class_probs = logits[..., 5:]  # shape [H, W, C]

    scores = objectness[..., None] * class_probs  # shape [H, W, C]

    # Pull arrays to host once; element-wise indexing of device arrays
    # inside a Python loop would otherwise trigger a transfer per access
    scores = jax.device_get(scores)
    bboxes = jax.device_get(bboxes)

    results = []

    for i in range(H):
        for j in range(W):
            for c in range(num_classes):
                score = scores[i, j, c]
                if score > conf_threshold:
                    cx, cy, w, h = bboxes[i, j]
                    bbox = [float(cx), float(cy), float(w), float(h)]
                    results.append({
                        "bbox": bbox,
                        "score": float(score),
                        "label": int(c),
                    })
    return results
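
For reference, a minimal decode sketch on random logits (the grid size and class count here are assumptions for illustration, not values from this PR):

import jax

key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (7, 7, 8))  # 7x7 grid, 3 classes -> D = 4 + 1 + 3
dets = postprocess_predictions(logits, conf_threshold=0.3)
# Each entry looks like {'bbox': [cx, cy, w, h], 'score': 0.41, 'label': 2}
print(len(dets))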


def evaluate(dataset, params, static, key, config, seed):
    dataset_identifier = get_identifier(dataset)

@@ -232,8 +273,32 @@ def evaluate(dataset, params, static, key, config, seed):
    )
    for batch in progress_bar:
        key, subkey = jr.split(key)

        images = batch["global_crops"][:, 0]  # (B, H, W, C)
        images = nhwc_to_nchw(images)  # (B, C, H, W); the model expects NCHW

        true_bboxes = batch["objects"]["bboxes"]
        true_labels = batch["objects"]["labels"]

        model = eqx.combine(params, static)
        preds = jax.vmap(model)(images)  # (B, H, W, A*(C+5))

        batch_map = []
        for i in range(images.shape[0]):
            pred_logits = preds[i]
            gt_boxes = true_bboxes[i]
            gt_labels = true_labels[i]

            pred_objs = postprocess_predictions(pred_logits, conf_threshold=0.3)
            sample_ap = compute_map(pred_objs, gt_boxes, gt_labels)
            batch_map.append(sample_ap)

        avg_map = sum(batch_map) / len(batch_map)
        progress_bar.set_postfix({"mAP@0.5": f"{avg_map:.4f}"})

    logger.info("Evaluation completed!")
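
If a single dataset-level number is wanted once the loop finishes, one option (a sketch, not part of this diff; the helper name is hypothetical) is to collect every per-sample AP across batches and average once at the end:

def aggregate_map(per_sample_aps):
    # Hypothetical helper: dataset-level mean of per-sample APs,
    # guarding against an empty evaluation set
    if not per_sample_aps:
        return 0.0
    return sum(per_sample_aps) / len(per_sample_aps)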
