Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions dlclive/pose_estimation_pytorch/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
import warnings

import numpy as np
import torch
Expand Down Expand Up @@ -131,15 +132,25 @@ def __init__(
path: str | Path,
device: str = "auto",
precision: Literal["FP16", "FP32"] = "FP32",
single_animal: bool = True,
single_animal: bool | None = None,
dynamic: dict | dynamic_cropping.DynamicCropper | None = None,
top_down_config: dict | TopDownConfig | None = None,
) -> None:
super().__init__(path)
self.device = _parse_device(device)
self.precision = precision
if single_animal is not None:
warnings.warn(
"The `single_animal` parameter is deprecated and will be removed "
"in a future version. The number of individuals will be automatically inferred "
"from the model configuration. Remove argument `single_animal` or set "
"`single_animal=None` to accept the inferred value and silence this warning.",
DeprecationWarning,
stacklevel=2,
)
self.single_animal = single_animal

self.n_individuals = None
self.n_bodyparts = None
self.cfg = None
self.detector = None
self.model = None
Expand Down Expand Up @@ -191,9 +202,14 @@ def get_pose(self, frame: np.ndarray) -> np.ndarray:

frame_batch, offsets_and_scales = self._prepare_top_down(tensor, detections)
if len(frame_batch) == 0:
offsets_and_scales = [(0, 0), 1]
else:
tensor = frame_batch # still CHW, batched
zero_pose = (
np.zeros((self.n_bodyparts, 3))
if self.n_individuals < 2 else
np.zeros((self.n_individuals, self.n_bodyparts, 3))
Comment on lines +206 to +208
Copy link

Copilot AI Jan 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The early return uses self.n_individuals < 2 to determine the output shape, while the regular path uses self.single_animal at line 240. These two conditions could differ if a user explicitly sets single_animal to a value that doesn't match the model configuration. For consistency, both code paths should use the same condition (either both use self.single_animal or both use self.n_individuals < 2).

Suggested change
np.zeros((self.n_bodyparts, 3))
if self.n_individuals < 2 else
np.zeros((self.n_individuals, self.n_bodyparts, 3))
np.zeros((self.n_bodyparts, 3))
if self.single_animal
else np.zeros((self.n_individuals, self.n_bodyparts, 3))

Copilot uses AI. Check for mistakes.
)
Copy link

Copilot AI Jan 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When returning early due to no detections, the skip_frames state is not updated. This means the age counter doesn't get incremented and the cached detections don't get cleared. This could lead to stale detections being used in subsequent frames or incorrect frame skipping behavior. Consider updating the skip_frames state before returning, or restructuring the logic to ensure skip_frames.update() is always called when skip_frames is enabled.

Suggested change
)
)
if self.top_down_config.skip_frames is not None:
zero_pose_tensor = torch.from_numpy(zero_pose).to(self.device)
self.top_down_config.skip_frames.update(zero_pose_tensor, w, h)

Copilot uses AI. Check for mistakes.
Comment on lines +205 to +209
Copy link

Copilot AI Jan 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If get_pose is called before load_model is called, and the model has a detector with no detections, this will fail with a TypeError when comparing self.n_individuals (which is None) with 2. While calling get_pose before load_model is already a misuse of the API, the error message could be improved. Consider adding a check that self.n_individuals and self.n_bodyparts are not None, or document that load_model must be called first.

Copilot uses AI. Check for mistakes.
return zero_pose

tensor = frame_batch # still CHW, batched

if self.dynamic is not None:
tensor = self.dynamic.crop(tensor)
Expand Down Expand Up @@ -260,6 +276,15 @@ def load_model(self) -> None:
raw_data = torch.load(self.path, map_location="cpu", weights_only=True)

self.cfg = raw_data["config"]

# Infer single animal mode and n_bodyparts from model configuration
individuals = self.cfg.get("metadata", {}).get("individuals", ['idv1'])
bodyparts = self.cfg.get("metadata", {}).get("bodyparts", [])
self.n_individuals = len(individuals)
self.n_bodyparts = len(bodyparts)
if self.single_animal is None:
self.single_animal = self.n_individuals == 1

self.model = models.PoseModel.build(self.cfg["model"])
self.model.load_state_dict(raw_data["pose"])
self.model = self.model.to(self.device)
Expand Down