Skip to content

Commit 7fbebaa

Browse files
committed
Move make_coords in dinov3 RoPE to a free fn so it can be wrapped for fx
1 parent 508df64 commit 7fbebaa

File tree

1 file changed

+54
-44
lines changed

1 file changed

+54
-44
lines changed

timm/layers/pos_embed_sincos.py

Lines changed: 54 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,54 @@ def no_weight_decay(self):
885885
return {'freqs'}
886886

887887

888+
@torch.fx.wrap
889+
@register_notrace_function
890+
def make_coords_dinov3(
891+
height: int,
892+
width: int,
893+
normalize_coords: str = 'separate',
894+
grid_indexing: str = 'ij',
895+
grid_offset: float = 0.,
896+
device: torch.device = 'cpu',
897+
dtype: torch.dtype = torch.float32,
898+
) -> torch.Tensor:
899+
"""Make coordinate grid matching offset and normalization of original.
900+
Returns: coords with shape (HW, 2) in [-1, 1].
901+
"""
902+
# 0.5-centered indices with optional offset
903+
coords_h = torch.arange(0.5, height, device=device, dtype=dtype) + grid_offset
904+
coords_w = torch.arange(0.5, width, device=device, dtype=dtype) + grid_offset
905+
906+
# Normalization denominators
907+
if normalize_coords == "max":
908+
denom = float(max(height, width))
909+
h_denom = denom
910+
w_denom = denom
911+
elif normalize_coords == "min":
912+
denom = float(min(height, width))
913+
h_denom = denom
914+
w_denom = denom
915+
elif normalize_coords == "separate":
916+
h_denom = float(height)
917+
w_denom = float(width)
918+
else:
919+
raise ValueError(f"Unknown normalize_coords: {normalize_coords}")
920+
921+
# Normalize to [0, 1]
922+
coords_h = coords_h / h_denom
923+
coords_w = coords_w / w_denom
924+
925+
# Create grid then map to [-1, 1]
926+
if grid_indexing == "xy":
927+
grid_w, grid_h = torch.meshgrid(coords_w, coords_h, indexing="xy")
928+
coords = torch.stack([grid_h, grid_w], dim=-1) # (H, W, 2) -> (h, w order)
929+
else:
930+
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) # (H, W, 2)
931+
coords = coords.flatten(0, 1) # (HW, 2)
932+
coords = 2.0 * coords - 1.0 # (H, W, 2) in [-1, 1]
933+
return coords
934+
935+
888936
class RotaryEmbeddingDinoV3(nn.Module):
889937
"""RoPE for timm DinoV3 port, numerically matching original.
890938
@@ -960,49 +1008,6 @@ def _compute_periods(self, device: torch.device = 'cpu', dtype: torch.dtype = to
9601008

9611009
return periods
9621010

963-
def _make_coords(
964-
self,
965-
height: int,
966-
width: int,
967-
device: torch.device = 'cpu',
968-
dtype: torch.dtype = torch.float32,
969-
) -> torch.Tensor:
970-
"""Make coordinate grid matching offset and normalization of original.
971-
Returns: coords with shape (HW, 2) in [-1, 1].
972-
"""
973-
# 0.5-centered indices with optional offset
974-
coords_h = torch.arange(0.5, height, device=device, dtype=dtype) + self.grid_offset
975-
coords_w = torch.arange(0.5, width, device=device, dtype=dtype) + self.grid_offset
976-
977-
# Normalization denominators
978-
if self.normalize_coords == "max":
979-
denom = float(max(height, width))
980-
h_denom = denom
981-
w_denom = denom
982-
elif self.normalize_coords == "min":
983-
denom = float(min(height, width))
984-
h_denom = denom
985-
w_denom = denom
986-
elif self.normalize_coords == "separate":
987-
h_denom = float(height)
988-
w_denom = float(width)
989-
else:
990-
raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
991-
992-
# Normalize to [0, 1]
993-
coords_h = coords_h / h_denom
994-
coords_w = coords_w / w_denom
995-
996-
# Create grid then map to [-1, 1]
997-
if self.grid_indexing == "xy":
998-
grid_w, grid_h = torch.meshgrid(coords_w, coords_h, indexing="xy")
999-
coords = torch.stack([grid_h, grid_w], dim=-1) # (H, W, 2) -> (h, w order)
1000-
else:
1001-
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) # (H, W, 2)
1002-
coords = coords.flatten(0, 1) # (HW, 2)
1003-
coords = 2.0 * coords - 1.0 # (H, W, 2) in [-1, 1]
1004-
return coords
1005-
10061011
def _apply_coord_augs(self, coords: torch.Tensor) -> torch.Tensor:
10071012
"""Apply shift/jitter/rescale train time augmentations."""
10081013
if not self.training or not self.aug_active:
@@ -1063,7 +1068,12 @@ def _get_pos_embed_from_coords(self, coords: torch.Tensor) -> Tuple[torch.Tenso
10631068

10641069
def _create_embed(self, feat_shape: List[int], no_aug: bool = False) -> torch.Tensor:
10651070
H, W = feat_shape
1066-
coords = self._make_coords(H, W) # (HW, 2)
1071+
coords = make_coords_dinov3(
1072+
H, W,
1073+
normalize_coords=self.normalize_coords,
1074+
grid_indexing=self.grid_indexing,
1075+
grid_offset=self.grid_offset
1076+
) # (HW, 2)
10671077
if not no_aug:
10681078
coords = self._apply_coord_augs(coords)
10691079
sin, cos = self._get_pos_embed_from_coords(coords) # 2 * (HW, dim)

0 commit comments

Comments
 (0)