@@ -885,6 +885,54 @@ def no_weight_decay(self):
885
885
return {'freqs' }
886
886
887
887
888
@torch.fx.wrap
@register_notrace_function
def make_coords_dinov3(
        height: int,
        width: int,
        normalize_coords: str = 'separate',
        grid_indexing: str = 'ij',
        grid_offset: float = 0.,
        device: torch.device = 'cpu',
        dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Build the flattened (h, w) coordinate grid used for DinoV3 RoPE.

    Cell centres are placed at ``i + 0.5 + grid_offset`` along each axis,
    divided by a normalization denominator, and then mapped with ``2 * x - 1``.

    Args:
        height: Number of grid rows (H).
        width: Number of grid columns (W).
        normalize_coords: Which denominator to divide by. 'separate' uses H for
            rows and W for columns; 'max' uses max(H, W) for both axes; 'min'
            uses min(H, W) for both axes.
        grid_indexing: 'xy' builds the mesh with xy indexing, any other value
            uses ij indexing. Both paths emit the same (h, w) channel order.
        grid_offset: Constant added to the 0.5-centred indices before
            normalization.
        device: Device on which the coordinates are created.
        dtype: Dtype of the created coordinates.

    Returns:
        Tensor of shape (H * W, 2), rows iterating fastest over width, with the
        last dim ordered (h, w). Values lie inside [-1, 1] when grid_offset is 0
        and normalize_coords is 'separate' or 'max'; 'min' normalization or a
        non-zero offset can push values outside that interval.

    Raises:
        ValueError: If normalize_coords is not 'max', 'min' or 'separate'.
    """
    tensor_kwargs = dict(device=device, dtype=dtype)

    # Pixel-centre style indices: 0.5, 1.5, ... shifted by the requested offset.
    ys = torch.arange(0.5, height, **tensor_kwargs) + grid_offset
    xs = torch.arange(0.5, width, **tensor_kwargs) + grid_offset

    # Pick the per-axis divisor for the requested normalization mode.
    if normalize_coords == "separate":
        y_scale, x_scale = float(height), float(width)
    elif normalize_coords == "max":
        y_scale = x_scale = float(max(height, width))
    elif normalize_coords == "min":
        y_scale = x_scale = float(min(height, width))
    else:
        raise ValueError(f"Unknown normalize_coords: {normalize_coords}")

    ys = ys / y_scale
    xs = xs / x_scale

    # Either indexing mode produces (H, W) meshes; channels are always stacked
    # as (h, w) so downstream code sees one layout.
    if grid_indexing == "xy":
        mesh_x, mesh_y = torch.meshgrid(xs, ys, indexing="xy")
    else:
        mesh_y, mesh_x = torch.meshgrid(ys, xs, indexing="ij")
    grid = torch.stack([mesh_y, mesh_x], dim=-1)  # (H, W, 2)

    # Collapse the spatial dims, then rescale: (HW, 2).
    return 2.0 * grid.flatten(0, 1) - 1.0
934
+
935
+
888
936
class RotaryEmbeddingDinoV3 (nn .Module ):
889
937
"""RoPE for timm DinoV3 port, numerically matching original.
890
938
@@ -960,49 +1008,6 @@ def _compute_periods(self, device: torch.device = 'cpu', dtype: torch.dtype = to
960
1008
961
1009
return periods
962
1010
963
- def _make_coords (
964
- self ,
965
- height : int ,
966
- width : int ,
967
- device : torch .device = 'cpu' ,
968
- dtype : torch .dtype = torch .float32 ,
969
- ) -> torch .Tensor :
970
- """Make coordinate grid matching offset and normalization of original.
971
- Returns: coords with shape (HW, 2) in [-1, 1].
972
- """
973
- # 0.5-centered indices with optional offset
974
- coords_h = torch .arange (0.5 , height , device = device , dtype = dtype ) + self .grid_offset
975
- coords_w = torch .arange (0.5 , width , device = device , dtype = dtype ) + self .grid_offset
976
-
977
- # Normalization denominators
978
- if self .normalize_coords == "max" :
979
- denom = float (max (height , width ))
980
- h_denom = denom
981
- w_denom = denom
982
- elif self .normalize_coords == "min" :
983
- denom = float (min (height , width ))
984
- h_denom = denom
985
- w_denom = denom
986
- elif self .normalize_coords == "separate" :
987
- h_denom = float (height )
988
- w_denom = float (width )
989
- else :
990
- raise ValueError (f"Unknown normalize_coords: { self .normalize_coords } " )
991
-
992
- # Normalize to [0, 1]
993
- coords_h = coords_h / h_denom
994
- coords_w = coords_w / w_denom
995
-
996
- # Create grid then map to [-1, 1]
997
- if self .grid_indexing == "xy" :
998
- grid_w , grid_h = torch .meshgrid (coords_w , coords_h , indexing = "xy" )
999
- coords = torch .stack ([grid_h , grid_w ], dim = - 1 ) # (H, W, 2) -> (h, w order)
1000
- else :
1001
- coords = torch .stack (torch .meshgrid (coords_h , coords_w , indexing = "ij" ), dim = - 1 ) # (H, W, 2)
1002
- coords = coords .flatten (0 , 1 ) # (HW, 2)
1003
- coords = 2.0 * coords - 1.0 # (H, W, 2) in [-1, 1]
1004
- return coords
1005
-
1006
1011
def _apply_coord_augs (self , coords : torch .Tensor ) -> torch .Tensor :
1007
1012
"""Apply shift/jitter/rescale train time augmentations."""
1008
1013
if not self .training or not self .aug_active :
@@ -1063,7 +1068,12 @@ def _get_pos_embed_from_coords(self, coords: torch.Tensor) -> Tuple[torch.Tenso
1063
1068
1064
1069
def _create_embed (self , feat_shape : List [int ], no_aug : bool = False ) -> torch .Tensor :
1065
1070
H , W = feat_shape
1066
- coords = self ._make_coords (H , W ) # (HW, 2)
1071
+ coords = make_coords_dinov3 (
1072
+ H , W ,
1073
+ normalize_coords = self .normalize_coords ,
1074
+ grid_indexing = self .grid_indexing ,
1075
+ grid_offset = self .grid_offset
1076
+ ) # (HW, 2)
1067
1077
if not no_aug :
1068
1078
coords = self ._apply_coord_augs (coords )
1069
1079
sin , cos = self ._get_pos_embed_from_coords (coords ) # 2 * (HW, dim)
0 commit comments