Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 50 additions & 24 deletions src/seamless_communication/cli/m4t/finetune/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from seamless_communication.models.unity import (
UnitYModel,
UnitYT2UModel,
UnitYNART2UModel,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -100,7 +101,8 @@ def __init__(self, model: UnitYModel, mode: FinetuneMode, device: Device):
def forward(
self, batch: dataloader.MultimodalSeqsBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
dummy_context = contextmanager(lambda: iter([None]))()
# dummy_context = contextmanager(lambda: iter([None]))()
dummy_context = nullcontext()
with torch.no_grad() if self.freeze_s2t else dummy_context: # type:ignore
assert batch.speech_to_text.src_tokens is not None
seqs = batch.speech_to_text.src_tokens.to(self.device)
Expand All @@ -125,30 +127,54 @@ def forward(
return (text_logits, None)
assert self.model.t2u_model is not None
assert batch.text_to_units.prev_output_tokens is not None
dummy_context = contextmanager(lambda: iter([None]))()
with torch.no_grad() if self.freeze_t2u else dummy_context: # type:ignore
if not isinstance(self.model.t2u_model, UnitYT2UModel):
dummy_context = nullcontext()
with torch.no_grad() if self.freeze_t2u else nullcontext():
if isinstance(self.model.t2u_model, UnitYNART2UModel):
assert batch.speech_to_text.target_tokens is not None
text_seqs = batch.speech_to_text.target_tokens.to(self.device)
unit_output, unit_decoder_padding_mask, durations = self.model.t2u_model.forward(
text_decoder_output=text_decoder_out,
text_decoder_padding_mask=text_decoder_padding_mask,
text_seqs=text_seqs,
duration_factor=1.0,
)
unit_logits = unit_output.logits
target_length = batch.text_to_units.target_tokens.size(1)
generated_length = unit_logits.size(1)
if generated_length != target_length:
if generated_length < target_length:
pad_size = target_length - generated_length
unit_logits = torch.nn.functional.pad(
unit_logits,
(0, 0, 0, pad_size),
value=0
)
else:
unit_logits = unit_logits[:, :target_length, :]

elif isinstance(self.model.t2u_model, UnitYT2UModel):
(
unit_encoder_out,
unit_encoder_padding_mask,
) = self.model.t2u_model.encode(
seqs=text_decoder_out,
padding_mask=text_decoder_padding_mask,
)
seqs = batch.text_to_units.prev_output_tokens.to(self.device)
assert batch.text_to_units.target_lengths is not None
seq_lens = batch.text_to_units.target_lengths.to(self.device)
unit_decoder_out, _ = self.model.t2u_model.decode(
seqs=seqs,
padding_mask=PaddingMask(seq_lens, seqs.size(1)),
encoder_output=unit_encoder_out,
encoder_padding_mask=unit_encoder_padding_mask,
)
unit_logits = self.model.t2u_model.final_proj(unit_decoder_out)
else:
raise NotImplementedError(
"T2U finetuning implemented only for UnitYT2UModel"
f"T2U finetuning not implemented for {type(self.model.t2u_model).__name__}. "
f"Supported models: UnitYT2UModel (v1), UnitYNART2UModel (v2)"
)
(
unit_encoder_out,
unit_encoder_padding_mask,
) = self.model.t2u_model.encode(
seqs=text_decoder_out,
padding_mask=text_decoder_padding_mask,
)
seqs = batch.text_to_units.prev_output_tokens.to(self.device)
assert batch.text_to_units.target_lengths is not None
seq_lens = batch.text_to_units.target_lengths.to(self.device)
unit_decoder_out, _ = self.model.t2u_model.decode(
seqs=seqs,
padding_mask=PaddingMask(seq_lens, seqs.size(1)),
encoder_output=unit_encoder_out,
encoder_padding_mask=unit_encoder_padding_mask,
)
unit_logits = self.model.t2u_model.final_proj(unit_decoder_out)

return (text_logits, unit_logits)


Expand Down Expand Up @@ -436,4 +462,4 @@ def run(self) -> None:
)
break

self.epoch_idx += 1
self.epoch_idx += 1