Commit 0aa2368

Fix some cosmos fp8 issues.
1 parent: cca96a8

2 files changed, 8 insertions(+), 8 deletions(-)

comfy/ldm/cosmos/model.py (+1 -1)

@@ -293,7 +293,7 @@ def prepare_embedded_sequence(
         x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
 
         if self.extra_per_block_abs_pos_emb:
-            extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device)
+            extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
         else:
             extra_pos_emb = None
 
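The model.py hunk threads the activation's dtype into the position embedder alongside its device. A minimal sketch of the failure mode this addresses (illustrative names only, not the ComfyUI code; assumes PyTorch 2.1+ for the float8 dtypes): float8 tensors do not participate in type promotion with other float dtypes, so a position table left in its fp8 storage dtype cannot simply be added to a bf16 activation.

import torch

# Hypothetical fp8-stored position table and a bf16 activation.
pos_emb = torch.empty(8, 16).to(torch.float8_e4m3fn)
x = torch.randn(4, 16, dtype=torch.bfloat16)

emb_fp8 = pos_emb[:4].to(device=x.device)                  # still float8_e4m3fn
emb_cast = pos_emb[:4].to(device=x.device, dtype=x.dtype)  # bfloat16, matches x

# x + emb_fp8 raises, since float8 has no defined promotion with bfloat16;
# x + emb_cast is an ordinary bf16 add.
out = x + emb_cast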
comfy/ldm/cosmos/position_embedding.py (+7 -7)

@@ -41,12 +41,12 @@ def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0)
 
 
 class VideoPositionEmb(nn.Module):
-    def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
+    def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
         """
         It delegates the embedding generation to generate_embeddings function.
         """
         B_T_H_W_C = x_B_T_H_W_C.shape
-        embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device)
+        embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)
 
         return embeddings
 
@@ -104,6 +104,7 @@ def generate_embeddings(
         w_ntk_factor: Optional[float] = None,
         t_ntk_factor: Optional[float] = None,
         device=None,
+        dtype=None,
     ):
         """
         Generate embeddings for the given input size.
@@ -189,13 +190,12 @@ def __init__(
         self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
         self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
 
-
-    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
+    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
         B, T, H, W, _ = B_T_H_W_C
         if self.interpolation == "crop":
-            emb_h_H = self.pos_emb_h[:H].to(device=device)
-            emb_w_W = self.pos_emb_w[:W].to(device=device)
-            emb_t_T = self.pos_emb_t[:T].to(device=device)
+            emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype)
+            emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype)
+            emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype)
             emb = (
                 repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
                 + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
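Taken together, the position_embedding.py hunks pass one extra dtype=None argument down the whole chain: forward → generate_embeddings → the .to(...) casts on the learned per-axis tables. A condensed sketch of that pattern (a hypothetical minimal module, not the actual classes in this file):

import torch
from torch import nn

class AxisPosEmb(nn.Module):
    # Hypothetical reduction of the learnable-axis embedder in the diff above.
    def __init__(self, length: int, channels: int):
        super().__init__()
        self.pos_emb = nn.Parameter(torch.empty(length, channels))

    def forward(self, x: torch.Tensor, device=None, dtype=None) -> torch.Tensor:
        # Mirrors pos_emb_h[:H].to(device=device, dtype=dtype): the slice of
        # the (possibly fp8) table is cast at use time, so downstream
        # arithmetic runs in the caller's activation precision.
        t = x.shape[1]
        return self.pos_emb[:t].to(device=device, dtype=dtype)

Casting at use time, rather than converting the stored parameter once, keeps the fp8 memory savings while making the returned embedding match whatever precision the caller is computing in (x_B_C_T_H_W.dtype at the call site in model.py).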
