@@ -41,12 +41,12 @@ def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0)
41
41
42
42
43
43
class VideoPositionEmb (nn .Module ):
44
- def forward (self , x_B_T_H_W_C : torch .Tensor , fps = Optional [torch .Tensor ], device = None ) -> torch .Tensor :
44
+ def forward (self , x_B_T_H_W_C : torch .Tensor , fps = Optional [torch .Tensor ], device = None , dtype = None ) -> torch .Tensor :
45
45
"""
46
46
It delegates the embedding generation to generate_embeddings function.
47
47
"""
48
48
B_T_H_W_C = x_B_T_H_W_C .shape
49
- embeddings = self .generate_embeddings (B_T_H_W_C , fps = fps , device = device )
49
+ embeddings = self .generate_embeddings (B_T_H_W_C , fps = fps , device = device , dtype = dtype )
50
50
51
51
return embeddings
52
52
@@ -104,6 +104,7 @@ def generate_embeddings(
104
104
w_ntk_factor : Optional [float ] = None ,
105
105
t_ntk_factor : Optional [float ] = None ,
106
106
device = None ,
107
+ dtype = None ,
107
108
):
108
109
"""
109
110
Generate embeddings for the given input size.
@@ -189,13 +190,12 @@ def __init__(
189
190
self .pos_emb_w = nn .Parameter (torch .empty (len_w , model_channels , device = device , dtype = dtype ))
190
191
self .pos_emb_t = nn .Parameter (torch .empty (len_t , model_channels , device = device , dtype = dtype ))
191
192
192
-
193
- def generate_embeddings (self , B_T_H_W_C : torch .Size , fps = Optional [torch .Tensor ], device = None ) -> torch .Tensor :
193
+ def generate_embeddings (self , B_T_H_W_C : torch .Size , fps = Optional [torch .Tensor ], device = None , dtype = None ) -> torch .Tensor :
194
194
B , T , H , W , _ = B_T_H_W_C
195
195
if self .interpolation == "crop" :
196
- emb_h_H = self .pos_emb_h [:H ].to (device = device )
197
- emb_w_W = self .pos_emb_w [:W ].to (device = device )
198
- emb_t_T = self .pos_emb_t [:T ].to (device = device )
196
+ emb_h_H = self .pos_emb_h [:H ].to (device = device , dtype = dtype )
197
+ emb_w_W = self .pos_emb_w [:W ].to (device = device , dtype = dtype )
198
+ emb_t_T = self .pos_emb_t [:T ].to (device = device , dtype = dtype )
199
199
emb = (
200
200
repeat (emb_t_T , "t d-> b t h w d" , b = B , h = H , w = W )
201
201
+ repeat (emb_h_H , "h d-> b t h w d" , b = B , t = T , w = W )
0 commit comments