diff --git a/src/seamless_communication/models/aligner/model.py b/src/seamless_communication/models/aligner/model.py index 5981da5c..52e5f65f 100644 --- a/src/seamless_communication/models/aligner/model.py +++ b/src/seamless_communication/models/aligner/model.py @@ -297,8 +297,8 @@ def forward(self, input_text: Tensor, input_unit: Tensor) -> Tuple[Tensor, Tenso attn_lprob, attn_hard_dur = self.alignment_encoder( embs_text, embs_unit, - torch.tensor([embs_text.size(1)]).to(embs_text).int(), - torch.tensor([embs_unit.size(1)]).to(embs_unit).int(), + torch.tensor([embs_text.size(1)]).to(device=embs_text.device).int(), + torch.tensor([embs_unit.size(1)]).to(device=embs_unit.device).int(), ) return attn_lprob, attn_hard_dur