diff --git a/src/seamless_communication/cli/m4t/finetune/trainer.py b/src/seamless_communication/cli/m4t/finetune/trainer.py
index d7c2dd95..5ff09128 100644
--- a/src/seamless_communication/cli/m4t/finetune/trainer.py
+++ b/src/seamless_communication/cli/m4t/finetune/trainer.py
@@ -7,3 +7,3 @@
 import logging
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
@@ -28,6 +28,7 @@
 from seamless_communication.models.unity import (
     UnitYModel,
     UnitYT2UModel,
+    UnitYNART2UModel,
 )
 
 logger = logging.getLogger(__name__)
@@ -100,7 +101,7 @@ def __init__(self, model: UnitYModel, mode: FinetuneMode, device: Device):
     def forward(
         self, batch: dataloader.MultimodalSeqsBatch
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        dummy_context = contextmanager(lambda: iter([None]))()
+        dummy_context = nullcontext()
         with torch.no_grad() if self.freeze_s2t else dummy_context:  # type:ignore
             assert batch.speech_to_text.src_tokens is not None
             seqs = batch.speech_to_text.src_tokens.to(self.device)
@@ -125,28 +126,56 @@ def forward(
             return (text_logits, None)
         assert self.model.t2u_model is not None
         assert batch.text_to_units.prev_output_tokens is not None
-        dummy_context = contextmanager(lambda: iter([None]))()
-        with torch.no_grad() if self.freeze_t2u else dummy_context:  # type:ignore
-            if not isinstance(self.model.t2u_model, UnitYT2UModel):
+        with torch.no_grad() if self.freeze_t2u else nullcontext():
+            if isinstance(self.model.t2u_model, UnitYNART2UModel):
+                # Non-autoregressive (v2) T2U: predict unit logits directly
+                # from the text decoder output, conditioned on target text.
+                assert batch.speech_to_text.target_tokens is not None
+                text_seqs = batch.speech_to_text.target_tokens.to(self.device)
+                unit_output, _, _ = self.model.t2u_model.forward(
+                    text_decoder_output=text_decoder_out,
+                    text_decoder_padding_mask=text_decoder_padding_mask,
+                    text_seqs=text_seqs,
+                    duration_factor=1.0,
+                )
+                unit_logits = unit_output.logits
+                # The NAR duration predictor may emit more or fewer unit
+                # frames than the reference; zero-pad or truncate so the
+                # logits align with the target units for the loss.
+                assert batch.text_to_units.target_tokens is not None
+                target_length = batch.text_to_units.target_tokens.size(1)
+                generated_length = unit_logits.size(1)
+                if generated_length < target_length:
+                    unit_logits = torch.nn.functional.pad(
+                        unit_logits,
+                        (0, 0, 0, target_length - generated_length),
+                        value=0,
+                    )
+                elif generated_length > target_length:
+                    unit_logits = unit_logits[:, :target_length, :]
+            elif isinstance(self.model.t2u_model, UnitYT2UModel):
+                # Autoregressive (v1) T2U: teacher-force the unit decoder
+                # on the reference unit sequence.
+                (
+                    unit_encoder_out,
+                    unit_encoder_padding_mask,
+                ) = self.model.t2u_model.encode(
+                    seqs=text_decoder_out,
+                    padding_mask=text_decoder_padding_mask,
+                )
+                seqs = batch.text_to_units.prev_output_tokens.to(self.device)
+                assert batch.text_to_units.target_lengths is not None
+                seq_lens = batch.text_to_units.target_lengths.to(self.device)
+                unit_decoder_out, _ = self.model.t2u_model.decode(
+                    seqs=seqs,
+                    padding_mask=PaddingMask(seq_lens, seqs.size(1)),
+                    encoder_output=unit_encoder_out,
+                    encoder_padding_mask=unit_encoder_padding_mask,
+                )
+                unit_logits = self.model.t2u_model.final_proj(unit_decoder_out)
+            else:
                 raise NotImplementedError(
-                    "T2U finetuning implemented only for UnitYT2UModel"
+                    f"T2U finetuning not implemented for {type(self.model.t2u_model).__name__}. "
+                    "Supported models: UnitYT2UModel (v1), UnitYNART2UModel (v2)"
                 )
-            (
-                unit_encoder_out,
-                unit_encoder_padding_mask,
-            ) = self.model.t2u_model.encode(
-                seqs=text_decoder_out,
-                padding_mask=text_decoder_padding_mask,
-            )
-            seqs = batch.text_to_units.prev_output_tokens.to(self.device)
-            assert batch.text_to_units.target_lengths is not None
-            seq_lens = batch.text_to_units.target_lengths.to(self.device)
-            unit_decoder_out, _ = self.model.t2u_model.decode(
-                seqs=seqs,
-                padding_mask=PaddingMask(seq_lens, seqs.size(1)),
-                encoder_output=unit_encoder_out,
-                encoder_padding_mask=unit_encoder_padding_mask,
-            )
-            unit_logits = self.model.t2u_model.final_proj(unit_decoder_out)
-
         return (text_logits, unit_logits)
@@ -436,4 +465,4 @@ def run(self) -> None:
             )
             break
 
-        self.epoch_idx += 1
\ No newline at end of file
+        self.epoch_idx += 1