diff --git a/src/seamless_communication/cli/m4t/finetune/dataloader.py b/src/seamless_communication/cli/m4t/finetune/dataloader.py
index 0e58c2d1..5db5fc8e 100644
--- a/src/seamless_communication/cli/m4t/finetune/dataloader.py
+++ b/src/seamless_communication/cli/m4t/finetune/dataloader.py
@@ -8,7 +8,7 @@
 import json
 import logging
 from dataclasses import dataclass
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -100,6 +100,7 @@ def __init__(
         unit_tokenizer: UnitTokenizer,
         dataset_manifest_path: str,
         batching_config: BatchingConfig,
+        max_src_tokens_per_batch: int = 100000,
     ):
         self.text_tokenizer = text_tokenizer
         self.text_encoders_per_lang: Dict[str, TextTokenEncoder] = {}
@@ -115,6 +116,7 @@
             "dtype": self.batching_config.float_dtype,
         }
         self.dataset = self._load_manifest(dataset_manifest_path)
+        self.max_src_tokens_per_batch = max_src_tokens_per_batch
 
     def get_dataloader(self) -> DataLoader[SeqsBatch]:
         subset = split_dataset_by_node(
@@ -156,9 +158,9 @@ def _get_tokenized_target_text(self, sample: LangPairSample) -> Tensor:
         """Expected sequence is [<eos>, <lang_tok>, ..text tokens.., <eos>]"""
         target_lang = sample.target.lang
         if target_lang not in self.text_encoders_per_lang:
-            self.text_encoders_per_lang[
-                target_lang
-            ] = self.text_tokenizer.create_encoder(lang=target_lang, mode="target")
+            self.text_encoders_per_lang[target_lang] = (
+                self.text_tokenizer.create_encoder(lang=target_lang, mode="target")
+            )
         tokens = self.text_encoders_per_lang[target_lang](sample.target.text)
         eos_idx = self.text_tokenizer.vocab_info.eos_idx
         tokens = torch.concat([tokens, torch.LongTensor([eos_idx])])
@@ -170,9 +172,9 @@ def _get_tokenized_units(self, sample: LangPairSample) -> Optional[Tensor]:
             return None
         target_lang = sample.target.lang
         if target_lang not in self.unit_encoders_per_lang:
-            self.unit_encoders_per_lang[
-                target_lang
-            ] = self.unit_tokenizer.create_encoder(lang=target_lang)
+            self.unit_encoders_per_lang[target_lang] = (
+                self.unit_tokenizer.create_encoder(lang=target_lang)
+            )
         tokens = self.unit_encoders_per_lang[target_lang](
             torch.LongTensor(sample.target.units).unsqueeze(0)
         )
@@ -195,20 +197,41 @@ def _is_long_src_audio(self, sample: LangPairSample) -> bool:
         length_s: float = max(wav.shape) / sample_rate
         return length_s > self.batching_config.max_audio_length_sec
 
+    def _drop_overflow_samples(
+        self, samples_with_fbanks: List[Tuple[LangPairSample, torch.Tensor]]
+    ) -> List[Tuple[LangPairSample, torch.Tensor]]:
+        if not samples_with_fbanks:
+            return samples_with_fbanks
+        # Sort by fbank length, longest first: after padding, every sample
+        # in the batch costs as many src tokens as the longest one.
+        samples_with_fbanks = sorted(
+            samples_with_fbanks, key=lambda sb: -sb[1].shape[0]
+        )
+        batch_width = samples_with_fbanks[0][1].shape[0]
+        # Keep at least one sample, even if it alone exceeds the budget.
+        max_samples_for_batch = max(1, self.max_src_tokens_per_batch // batch_width)
+        if max_samples_for_batch < len(samples_with_fbanks):
+            samples_with_fbanks = samples_with_fbanks[:max_samples_for_batch]
+        return samples_with_fbanks
+
     def _prepare_batch(self, raw_samples: List[Dict[str, Any]]) -> MultimodalSeqsBatch:
         samples = [LangPairSample.from_json(sample) for sample in raw_samples]
         # input speech
         # - filter long audio samples
-        filtered_samples = [sample for sample in samples if not self._is_long_src_audio(sample)]
-        samples = filtered_samples if filtered_samples else [samples[0]]  # keep at least one sample
-        src_tokens_list = [self._get_source_fbank(sample) for sample in samples]
+        filtered_samples = [
+            sample for sample in samples if not self._is_long_src_audio(sample)
+        ]
+        samples = (
+            filtered_samples if filtered_samples else [samples[0]]
+        )  # keep at least one sample
+        with_fbanks = [(sample, self._get_source_fbank(sample)) for sample in samples]
         # - filter NaNs in fbanks
-        with_nans = [fbank.isnan().any().item() for fbank in src_tokens_list]
-        samples = [sample for sample, skip in zip(samples, with_nans) if not skip]
-        assert len(samples) > 0
-        src_tokens_list = [
-            src_toks for src_toks, skip in zip(src_tokens_list, with_nans) if not skip
+        filtered = [
+            (sample, fbank)
+            for sample, fbank in with_fbanks
+            if not fbank.isnan().any().item()
         ]
+        filtered = self._drop_overflow_samples(filtered)
+
+        samples = [sample for sample, _ in filtered]
+        src_tokens_list = [src_tokens for _, src_tokens in filtered]
+        assert len(samples) > 0
         src_tokens = self._batch_tensors(
            src_tokens_list, pad_value=self.batching_config.fbank_feats_pad_idx
        ).to(self.batching_config.float_dtype)
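Note on the `_drop_overflow_samples` hunk: since `_batch_tensors` pads every fbank to the longest one in the batch, a batch's real source-token cost is `num_samples * longest_fbank_length`. A standalone sketch of the same budget rule (function and variable names here are illustrative, not part of the patch):

    # Sketch of the token-budget rule from _drop_overflow_samples.
    # `lengths` are per-sample fbank frame counts; `budget` mirrors
    # max_src_tokens_per_batch.
    from typing import List

    def cap_batch(lengths: List[int], budget: int) -> List[int]:
        # Longest first: after padding, every kept sample costs lengths[0] frames.
        lengths = sorted(lengths, reverse=True)
        width = lengths[0]
        keep = max(1, budget // width)  # always keep at least one sample
        return lengths[:keep]

    assert cap_batch([900, 500, 300, 100], budget=2000) == [900, 500]  # 3 * 900 > 2000
    assert cap_batch([3000], budget=2000) == [3000]  # over budget, but kept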
diff --git a/src/seamless_communication/cli/m4t/finetune/dataset.py b/src/seamless_communication/cli/m4t/finetune/dataset.py
index 392160f7..787c44b6 100644
--- a/src/seamless_communication/cli/m4t/finetune/dataset.py
+++ b/src/seamless_communication/cli/m4t/finetune/dataset.py
@@ -16,6 +16,7 @@
 from seamless_communication.datasets.huggingface import (
     Speech2SpeechFleursDatasetBuilder,
+    Speech2SpeechGigaSpeechDatasetBuilder,
     SpeechTokenizer,
 )
 from seamless_communication.models.unit_extractor import UnitExtractor
@@ -123,6 +124,7 @@ def download_fleurs_dataset(
     target_lang: str,
     split: str,
     save_directory: str,
+    max_samples: int = 100_000,
 ) -> str:
     _check_lang_code_mapping(source_lang)
     _check_lang_code_mapping(target_lang)
@@ -130,18 +132,23 @@
         torch.device("cuda:0") if torch.cuda.device_count() > 0 else torch.device("cpu")
     )
     tokenizer = UnitSpeechTokenizer(device=device)
-    dataset_iterator = Speech2SpeechFleursDatasetBuilder(
-        source_lang=UNITY_TO_FLEURS_LANG_MAPPING[source_lang],
-        target_lang=UNITY_TO_FLEURS_LANG_MAPPING[target_lang],
-        dataset_cache_dir=save_directory,
-        speech_tokenizer=tokenizer,
-        skip_source_audio=True,  # don't extract units from source audio
-        skip_target_audio=False,
-        split=split,
-    )
+    use_gigaspeech = True  # temporary switch while experimenting with GigaSpeech
+    if use_gigaspeech:
+        dataset_iterator = Speech2SpeechGigaSpeechDatasetBuilder(
+            split=split, dataset_cache_dir=save_directory
+        )
+    else:
+        dataset_iterator = Speech2SpeechFleursDatasetBuilder(
+            source_lang=UNITY_TO_FLEURS_LANG_MAPPING[source_lang],
+            target_lang=UNITY_TO_FLEURS_LANG_MAPPING[target_lang],
+            dataset_cache_dir=save_directory,
+            speech_tokenizer=tokenizer,
+            skip_source_audio=True,  # don't extract units from source audio
+            skip_target_audio=False,
+            split=split,
+        )
     manifest_path: str = os.path.join(save_directory, f"{split}_manifest.json")
     with open(manifest_path, "w") as fp_out:
         for idx, sample in enumerate(dataset_iterator.__iter__(), start=1):
+            if idx > max_samples:
+                break
             # correction as FleursDatasetBuilder returns fleurs lang codes
             sample.source.lang = source_lang
             sample.target.lang = target_lang
@@ -183,6 +190,12 @@ def init_parser() -> argparse.ArgumentParser:
         required=True,
         help="Directory where the datasets will be stored with HuggingFace datasets cache files",
     )
+    parser.add_argument(
+        "--max_samples",
+        type=int,
+        default=100_000,
+        help="Max number of samples to write to the manifest",
+    )
     return parser
@@ -193,6 +206,7 @@ def main() -> None:
         target_lang=args.target_lang,
         split=args.split,
         save_directory=args.save_dir,
+        max_samples=args.max_samples,
     )
     logger.info(f"Manifest saved to: {manifest_path}")
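Note on the `max_samples` cap: GigaSpeech splits are streamed, so the cap simply stops iteration early. The index check can equivalently be written with `itertools.islice` (a sketch only; `take` is a hypothetical helper, and the GigaSpeech builder is assumed to yield the same sample objects as the FLEURS builder):

    # Capping a (possibly unbounded) streaming iterator after max_samples items.
    import itertools
    from typing import Iterable, Iterator, TypeVar

    T = TypeVar("T")

    def take(iterable: Iterable[T], max_samples: int) -> Iterator[T]:
        # islice stops after max_samples items without exhausting the stream.
        yield from itertools.islice(iterable, max_samples)

    assert list(take(range(10), 3)) == [0, 1, 2]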
diff --git a/src/seamless_communication/cli/m4t/finetune/finetune.py b/src/seamless_communication/cli/m4t/finetune/finetune.py
index 6bb1e40a..8c17ce06 100644
--- a/src/seamless_communication/cli/m4t/finetune/finetune.py
+++ b/src/seamless_communication/cli/m4t/finetune/finetune.py
@@ -133,11 +133,13 @@ def main() -> None:
     dist_utils.init_distributed([logger, trainer.logger])
     text_tokenizer: NllbTokenizer = load_unity_text_tokenizer(args.model_name)
     unit_tokenizer: UnitTokenizer = load_unity_unit_tokenizer(args.model_name)
+    float_dtype = torch.float16 if torch.device(args.device).type != "cpu" else torch.bfloat16
+    logger.info(f"Training precision: {float_dtype}")
     finetune_params = trainer.FinetuneParams(
         finetune_mode=args.mode,
         save_model_path=args.save_model_to,
         device=torch.device(args.device),
-        float_dtype=torch.float16 if torch.device(args.device).type != "cpu" else torch.bfloat16,
+        float_dtype=float_dtype,
         train_batch_size=args.batch_size,
         eval_batch_size=args.batch_size,
         patience=args.patience,
@@ -174,6 +176,7 @@
             float_dtype=finetune_params.float_dtype,
         ),
         dataset_manifest_path=args.train_dataset,
+        max_src_tokens_per_batch=7000,  # hard-coded for now; train loader only
     )
     eval_dataloader = dataloader.UnitYDataLoader(
         text_tokenizer=text_tokenizer,
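Note on the precision change: the patch selects fp16 autocast on accelerators and bf16 on CPU. A possible refinement, not part of the patch, is to prefer bf16 on CUDA devices that support it, since bf16 keeps fp32's exponent range and is less prone to loss-scale overflow; `pick_dtype` below is a hypothetical helper:

    import torch

    def pick_dtype(device: torch.device) -> torch.dtype:
        if device.type == "cpu":
            return torch.bfloat16  # fp16 autocast on CPU is poorly supported
        if device.type == "cuda" and torch.cuda.is_bf16_supported():
            return torch.bfloat16  # Ampere and newer
        return torch.float16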
diff --git a/src/seamless_communication/cli/m4t/finetune/mini_eval.py b/src/seamless_communication/cli/m4t/finetune/mini_eval.py
new file mode 100644
index 00000000..63fd0e37
--- /dev/null
+++ b/src/seamless_communication/cli/m4t/finetune/mini_eval.py
@@ -0,0 +1,106 @@
+import logging
+import os
+from typing import Any, Dict, Iterable, Tuple
+
+import torch
+from datasets import load_dataset
+from jiwer import wer
+from whisper.normalizers import EnglishTextNormalizer
+
+logging.basicConfig(level=logging.INFO)
+
+from seamless_communication.inference import Translator
+from seamless_communication.models.unity import UnitYModel
+
+log = logging.getLogger("l")
+
+TOKEN = "dummy"  # fallback when HF_TOKEN is not set in the environment
+MAX_SAMPLES = 100
+CHCK_PATH = os.path.expanduser("~/tune_chck/chck.pt")
+
+norm = EnglishTextNormalizer()
+
+DATASET = []  # type:ignore  # in-memory cache of the streamed test split
+
+
+def __iterate_test_ds() -> Iterable[Tuple[torch.Tensor, str]]:
+    ds = load_dataset(
+        "speechcolab/gigaspeech",
+        "s",
+        token=os.environ.get("HF_TOKEN", TOKEN),
+        split="test",
+        streaming=True,
+        trust_remote_code=True,
+    )
+    for idx, item in enumerate(ds):
+        if idx >= MAX_SAMPLES:
+            break
+        assert item["audio"]["sampling_rate"] == 16000
+        yield (torch.from_numpy(item["audio"]["array"]), item["text"])
+
+
+def _iterate_test_ds() -> Iterable[Tuple[torch.Tensor, str]]:
+    global DATASET
+    if not DATASET:
+        DATASET = list(__iterate_test_ds())
+    yield from DATASET
+
+
+def _eval(translator: Translator) -> float:
+    references = []
+    predictions = []
+    for idx, (wav, text) in enumerate(_iterate_test_ds()):
+        reference = norm(text)
+        if not reference:  # jiwer rejects empty reference strings
+            reference = "."
+        references.append(reference)
+        prediction = str(
+            translator.predict(
+                input=wav,
+                task_str="s2tt",
+                tgt_lang="eng",
+                src_lang="eng",
+            )[0][0]
+        )
+        prediction = norm(prediction)
+        if not prediction:
+            prediction = "."
+        log.info(idx)
+        log.info(f"REF: {reference}")
+        log.info(f"PRE: {prediction}")
+        log.info("----")
+        predictions.append(prediction)
+    return wer(reference=references, hypothesis=predictions)  # type:ignore
+
+
+def _select_keys(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
+    return {
+        key.replace(prefix, ""): value
+        for key, value in state_dict.items()
+        if key.startswith(prefix)
+    }
+
+
+def load_checkpoint(model: UnitYModel, chck_path: str) -> None:
+    state_dict = torch.load(chck_path, map_location="cpu")
+    model.speech_encoder_frontend.load_state_dict(
+        _select_keys(state_dict, "model.speech_encoder_frontend.")
+    )
+    model.speech_encoder.load_state_dict(
+        _select_keys(state_dict, "model.speech_encoder.")
+    )
+    assert model.text_decoder_frontend is not None
+    model.text_decoder_frontend.load_state_dict(
+        _select_keys(state_dict, "model.text_decoder_frontend.")
+    )
+    assert model.text_decoder is not None
+    model.text_decoder.load_state_dict(
+        _select_keys(state_dict, "model.text_decoder.")
+    )
+
+
+def main() -> None:
+    translator = Translator(
+        model_name_or_card="seamlessM4T_medium",
+        vocoder_name_or_card=None,
+        device=torch.device("cuda"),
+    )
+    non_tuned_wer = _eval(translator)
+
+    load_checkpoint(translator.model, CHCK_PATH)
+    tuned_wer = _eval(translator)
+
+    log.info(f"WER non-tuned: {non_tuned_wer:.3f}")
+    log.info(f"WER tuned: {tuned_wer:.3f}")
+
+
+if __name__ == "__main__":
+    main()
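Note on checkpoint loading: the trainer saves the wrapped model, so every key carries a `model.` prefix plus the submodule path, and `_select_keys` strips that prefix before `load_state_dict`. A toy illustration (keys invented for the example); note that slicing with `key[len(prefix):]` is stricter than `str.replace`, which would also rewrite a repeated occurrence of the prefix inside a key:

    # Toy illustration of prefix-based key selection; keys are invented.
    from typing import Any, Dict

    def select_keys(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
        # key[len(prefix):] strips only the leading prefix.
        return {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    full = {
        "model.speech_encoder.layers.0.weight": 1,
        "model.text_decoder.layers.0.weight": 2,
    }
    assert select_keys(full, "model.speech_encoder.") == {"layers.0.weight": 1}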
diff --git a/src/seamless_communication/cli/m4t/finetune/trainer.py b/src/seamless_communication/cli/m4t/finetune/trainer.py
index dfeda956..c73df4ca 100644
--- a/src/seamless_communication/cli/m4t/finetune/trainer.py
+++ b/src/seamless_communication/cli/m4t/finetune/trainer.py
@@ -11,7 +11,7 @@
 from enum import Enum
 from tqdm import tqdm
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -21,7 +21,7 @@
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.optim.lr_scheduler import MyleLR
 from fairseq2.typing import Device
-from torch.optim import AdamW
+from torch.optim import Adam, AdamW
 
 from seamless_communication.cli.m4t.finetune import dataloader, dist_utils
 from seamless_communication.models.unity import (
@@ -88,11 +88,17 @@ class UnitYFinetuneWrapper(nn.Module):
     def __init__(self, model: UnitYModel, mode: FinetuneMode, device: Device):
         super().__init__()
         self.model: UnitYModel = model
+        # self._freeze_module(self.model.speech_encoder_frontend)
+        # self._freeze_module(self.model.speech_encoder)
         self.freeze_s2t: bool = mode == FinetuneMode.TEXT_TO_SPEECH
         self.freeze_t2u: bool = mode == FinetuneMode.SPEECH_TO_TEXT
         logger.info(f"Freeze s2t: {self.freeze_s2t}, freeze t2u: {self.freeze_t2u}")
         self.device = device
 
+    def _freeze_module(self, module: torch.nn.Module) -> None:
+        for param in module.parameters():
+            param.requires_grad = False
+
     def forward(
         self, batch: dataloader.MultimodalSeqsBatch
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -329,12 +335,11 @@ def _eval_model(self) -> None:
             assert batch.speech_to_text.src_tokens is not None
             with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
                 loss = self.calc_loss(batch, *self.model(batch))
-            if loss.isnan():
-                logger.warning("Eval loss value is NaN, setting to inf")
-                loss_val = float("Inf")
-            else:
-                loss_val = loss.item()
             del batch  # force memory release
+            if loss.isnan():
+                logger.warning("Eval batch loss value is NaN, skipping")
+                continue
+            loss_val = loss.item()
             loss_hist.update(1, loss_val)
         eval_loss = loss_hist.reduce()
         self._update_eval_stats(eval_loss)
@@ -351,13 +356,18 @@
             f"last lr={self.lr_scheduler.get_last_lr()[0]:.2E}"
         )
 
-    def _train_step(self, batch: dataloader.MultimodalSeqsBatch) -> None:
+    def _train_step(self, batches: List[dataloader.MultimodalSeqsBatch]) -> None:
         """Run one train step"""
         self.model.train()
         self.optimizer.zero_grad()
-        with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
-            tokens, units = self.model(batch)
-            loss = self.calc_loss(batch, tokens, units)
+        # logger.info(f"forward start {torch.cuda.memory_allocated(0) >> 30}g")
+        losses = []
+        for batch in batches:
+            with torch.autocast(device_type=self.params.device.type, dtype=self.params.float_dtype):
+                tokens, units = self.model(batch)
+                # logger.info(f"forward done {torch.cuda.memory_allocated(0) >> 30}g")
+                losses.append(self.calc_loss(batch, tokens, units))
+        loss = sum(losses) / len(losses)
         if loss.isnan().any().item():
             logger.error(batch.speech_to_text)
             raise RuntimeError("Loss is Nan. Terminating.")
         self.grad_scaler.scale(loss).backward()
@@ -365,6 +375,7 @@ def _train_step(self, batch: dataloader.MultimodalSeqsBatch) -> None:
         self.grad_scaler.step(self.optimizer)
         self.grad_scaler.update()
         self.lr_scheduler.step()
+        # logger.info(f"backward done {torch.cuda.memory_allocated(0) >> 30}g")
         assert batch.speech_to_text.src_tokens is not None
         self.train_loss_hist.update(1, loss.item())
         self._train_step_log()
@@ -385,19 +396,24 @@ def run(self) -> None:
         self._reset_stats()
         self._eval_model()
         batch_itr = self.train_data_loader.get_dataloader()
+        batches_per_iter = 1
         while self.epoch_idx < self.params.max_epochs and self.patience_left:
+            train_batches = []
             for train_batch in batch_itr:
-                self._train_step(batch=train_batch)
-                if self.update_idx and self.update_idx % self.params.eval_steps == 0:
-                    self._eval_model()
-                    if self.is_best_state:
-                        self._save_model()
-                    elif not self.patience_left:
-                        no_improve_steps = self.params.eval_steps * self.params.patience
-                        logger.info(
-                            "Early termination, as eval loss did not improve "
-                            f"over last {no_improve_steps} updates"
-                        )
-                        break
-                self.update_idx += 1
+                train_batches.append(train_batch)
+                if len(train_batches) >= batches_per_iter:
+                    self._train_step(batches=train_batches)
+                    train_batches = []
+                    if self.update_idx and self.update_idx % self.params.eval_steps == 0:
+                        self._eval_model()
+                        if self.is_best_state:
+                            self._save_model()
+                        elif not self.patience_left:
+                            no_improve_steps = self.params.eval_steps * self.params.patience
+                            logger.info(
+                                "Early termination, as eval loss did not improve "
+                                f"over last {no_improve_steps} updates"
+                            )
+                            break
+                    self.update_idx += 1
+            # note: a partial group left in train_batches at epoch end is dropped
             self.epoch_idx += 1
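Note on the multi-batch `_train_step`: it runs all forward passes in the group and then a single backward over the averaged loss, so every micro-batch's autograd graph stays in memory until the step ends. The conventional gradient-accumulation pattern backpropagates each scaled loss immediately, freeing each graph as soon as its backward completes, and yields the same accumulated gradients. A sketch under that assumption (`calc_loss` here is a hypothetical callable, not the trainer method):

    import torch

    def accumulate_step(model, batches, calc_loss, optimizer, scaler) -> None:
        # calc_loss(model, batch) is assumed to return a scalar loss tensor.
        optimizer.zero_grad()
        for batch in batches:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                # Pre-divide so the accumulated gradient equals that of the
                # averaged loss used by the patch.
                loss = calc_loss(model, batch) / len(batches)
            scaler.scale(loss).backward()  # frees this micro-batch's graph
        scaler.step(optimizer)
        scaler.update()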