Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 38 additions & 15 deletions src/seamless_communication/cli/m4t/finetune/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
import torch
Expand Down Expand Up @@ -100,6 +100,7 @@ def __init__(
unit_tokenizer: UnitTokenizer,
dataset_manifest_path: str,
batching_config: BatchingConfig,
max_src_tokens_per_batch: int = 100000
):
self.text_tokenizer = text_tokenizer
self.text_encoders_per_lang: Dict[str, TextTokenEncoder] = {}
Expand All @@ -115,6 +116,7 @@ def __init__(
"dtype": self.batching_config.float_dtype,
}
self.dataset = self._load_manifest(dataset_manifest_path)
self.max_src_tokens_per_batch = max_src_tokens_per_batch

def get_dataloader(self) -> DataLoader[SeqsBatch]:
subset = split_dataset_by_node(
Expand Down Expand Up @@ -156,9 +158,9 @@ def _get_tokenized_target_text(self, sample: LangPairSample) -> Tensor:
"""Expected sequence is [<eos>, <lang_tok> , ..text tokens.., <eos>]"""
target_lang = sample.target.lang
if target_lang not in self.text_encoders_per_lang:
self.text_encoders_per_lang[
target_lang
] = self.text_tokenizer.create_encoder(lang=target_lang, mode="target")
self.text_encoders_per_lang[target_lang] = (
self.text_tokenizer.create_encoder(lang=target_lang, mode="target")
)
tokens = self.text_encoders_per_lang[target_lang](sample.target.text)
eos_idx = self.text_tokenizer.vocab_info.eos_idx
tokens = torch.concat([tokens, torch.LongTensor([eos_idx])])
Expand All @@ -170,9 +172,9 @@ def _get_tokenized_units(self, sample: LangPairSample) -> Optional[Tensor]:
return None
target_lang = sample.target.lang
if target_lang not in self.unit_encoders_per_lang:
self.unit_encoders_per_lang[
target_lang
] = self.unit_tokenizer.create_encoder(lang=target_lang)
self.unit_encoders_per_lang[target_lang] = (
self.unit_tokenizer.create_encoder(lang=target_lang)
)
tokens = self.unit_encoders_per_lang[target_lang](
torch.LongTensor(sample.target.units).unsqueeze(0)
)
Expand All @@ -195,20 +197,41 @@ def _is_long_src_audio(self, sample: LangPairSample) -> bool:
length_s: float = max(wav.shape) / sample_rate
return length_s > self.batching_config.max_audio_length_sec

def _drop_overflow_samples(
self, samples_with_fbanks: List[Tuple[LangPairSample, torch.Tensor]]
) -> List[Tuple[LangPairSample, torch.Tensor]]:
# filter by src_tokens length (reverse)
samples_with_fbanks = sorted(
samples_with_fbanks, key=lambda sb: -sb[1].shape[0]
)
bwd = samples_with_fbanks[0][1].shape[0]
max_samples_for_batch = min(1, self.max_src_tokens_per_batch // bwd)
if max_samples_for_batch < len(samples_with_fbanks):
samples_with_fbanks = samples_with_fbanks[:max_samples_for_batch]
return samples_with_fbanks

def _prepare_batch(self, raw_samples: List[Dict[str, Any]]) -> MultimodalSeqsBatch:
samples = [LangPairSample.from_json(sample) for sample in raw_samples]
# input speech
# - filter long audio samples
filtered_samples = [sample for sample in samples if not self._is_long_src_audio(sample)]
samples = filtered_samples if filtered_samples else [samples[0]] # keep at least one sample
src_tokens_list = [self._get_source_fbank(sample) for sample in samples]
filtered_samples = [
sample for sample in samples if not self._is_long_src_audio(sample)
]
samples = (
filtered_samples if filtered_samples else [samples[0]]
) # keep at least one sample
with_fbanks = [(sample, self._get_source_fbank(sample)) for sample in samples]
# - filter NaNs in fbanks
with_nans = [fbank.isnan().any().item() for fbank in src_tokens_list]
samples = [sample for sample, skip in zip(samples, with_nans) if not skip]
assert len(samples) > 0
src_tokens_list = [
src_toks for src_toks, skip in zip(src_tokens_list, with_nans) if not skip
filtered = [
(sample, fbank)
for sample, fbank in with_fbanks
if not fbank.isnan().any().item()
]
filtered = self._drop_overflow_samples(filtered)

samples = [sample for sample, _ in filtered]
src_tokens_list = [src_tokens for _, src_tokens in filtered]
assert len(samples) > 0
src_tokens = self._batch_tensors(
src_tokens_list, pad_value=self.batching_config.fbank_feats_pad_idx
).to(self.batching_config.float_dtype)
Expand Down
32 changes: 23 additions & 9 deletions src/seamless_communication/cli/m4t/finetune/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from seamless_communication.datasets.huggingface import (
Speech2SpeechFleursDatasetBuilder,
Speech2SpeechGigaSpeechDatasetBuilder,
SpeechTokenizer,
)
from seamless_communication.models.unit_extractor import UnitExtractor
Expand Down Expand Up @@ -123,25 +124,31 @@ def download_fleurs_dataset(
target_lang: str,
split: str,
save_directory: str,
max_samples: int = 100_000,
) -> str:
_check_lang_code_mapping(source_lang)
_check_lang_code_mapping(target_lang)
device = (
torch.device("cuda:0") if torch.cuda.device_count() > 0 else torch.device("cpu")
)
tokenizer = UnitSpeechTokenizer(device=device)
dataset_iterator = Speech2SpeechFleursDatasetBuilder(
source_lang=UNITY_TO_FLEURS_LANG_MAPPING[source_lang],
target_lang=UNITY_TO_FLEURS_LANG_MAPPING[target_lang],
dataset_cache_dir=save_directory,
speech_tokenizer=tokenizer,
skip_source_audio=True, # don't extract units from source audio
skip_target_audio=False,
split=split,
)
if 1:
dataset_iterator = Speech2SpeechGigaSpeechDatasetBuilder(split=split, dataset_cache_dir=save_directory)
else:
dataset_iterator = Speech2SpeechFleursDatasetBuilder(
source_lang=UNITY_TO_FLEURS_LANG_MAPPING[source_lang],
target_lang=UNITY_TO_FLEURS_LANG_MAPPING[target_lang],
dataset_cache_dir=save_directory,
speech_tokenizer=tokenizer,
skip_source_audio=True, # don't extract units from source audio
skip_target_audio=False,
split=split,
)
manifest_path: str = os.path.join(save_directory, f"{split}_manifest.json")
with open(manifest_path, "w") as fp_out:
for idx, sample in enumerate(dataset_iterator.__iter__(), start=1):
if idx >= max_samples:
break
# correction, as the FleursDatasetBuilder returns Fleurs lang codes
sample.source.lang = source_lang
sample.target.lang = target_lang
Expand Down Expand Up @@ -183,6 +190,12 @@ def init_parser() -> argparse.ArgumentParser:
required=True,
help="Directory where the datastets will be stored with HuggingFace datasets cache files",
)
parser.add_argument(
"--max_samples",
type=int,
default=100_000,
help="Max samples to use",
)
return parser


Expand All @@ -193,6 +206,7 @@ def main() -> None:
target_lang=args.target_lang,
split=args.split,
save_directory=args.save_dir,
max_samples=args.max_samples,
)
logger.info(f"Manifest saved to: {manifest_path}")

Expand Down
5 changes: 4 additions & 1 deletion src/seamless_communication/cli/m4t/finetune/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,13 @@ def main() -> None:
dist_utils.init_distributed([logger, trainer.logger])
text_tokenizer: NllbTokenizer = load_unity_text_tokenizer(args.model_name)
unit_tokenizer: UnitTokenizer = load_unity_unit_tokenizer(args.model_name)
float_dtype = torch.float16 if torch.device(args.device).type != "cpu" else torch.bfloat16
logger.info(f"Training precision: {float_dtype}")
finetune_params = trainer.FinetuneParams(
finetune_mode=args.mode,
save_model_path=args.save_model_to,
device=torch.device(args.device),
float_dtype=torch.float16 if torch.device(args.device).type != "cpu" else torch.bfloat16,
float_dtype=float_dtype,
train_batch_size=args.batch_size,
eval_batch_size=args.batch_size,
patience=args.patience,
Expand Down Expand Up @@ -174,6 +176,7 @@ def main() -> None:
float_dtype=finetune_params.float_dtype,
),
dataset_manifest_path=args.train_dataset,
max_src_tokens_per_batch=7000,
)
eval_dataloader = dataloader.UnitYDataLoader(
text_tokenizer=text_tokenizer,
Expand Down
106 changes: 106 additions & 0 deletions src/seamless_communication/cli/m4t/finetune/mini_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import torch
from datasets import load_dataset
from jiwer import wer
import os
from typing import Tuple, Iterable, Dict, Any
import logging
from whisper.normalizers import EnglishTextNormalizer

logging.basicConfig(level=logging.INFO)

from seamless_communication.models.unity import UnitYModel
from seamless_communication.inference import Translator

log = logging.getLogger("l")

# Fallback HuggingFace access token used when HF_TOKEN is absent from the
# environment; "dummy" is a placeholder and will not grant access to gated data.
TOKEN = "dummy"
# Number of test utterances each evaluation pass scores.
MAX_SAMPLES = 100
# Location of the finetuned checkpoint to compare against the stock model.
CHCK_PATH = os.path.expanduser("~/tune_chck/chck.pt")

# Whisper's English text normalizer, applied to references and hypotheses alike.
norm = EnglishTextNormalizer()


# In-process cache of the test set so both eval passes see identical samples.
DATASET = [] # type:ignore


def __iterate_test_ds() -> Iterable[Tuple[torch.Tensor, str]]:
    """Stream up to MAX_SAMPLES (waveform, transcript) pairs from GigaSpeech test.

    Uses the streaming loader so the full split is never downloaded; audio is
    yielded as a float tensor built from the decoded numpy array.
    """
    ds = load_dataset(
        "speechcolab/gigaspeech",
        "s",
        token=os.environ.get("HF_TOKEN", TOKEN),
        split="test",
        streaming=True,
        trust_remote_code=True,
    )
    for seen, record in enumerate(ds):
        if seen >= MAX_SAMPLES:
            break
        audio = record["audio"]
        # Downstream s2tt inference expects 16 kHz input.
        assert audio["sampling_rate"] == 16000
        yield (torch.from_numpy(audio["array"]), record["text"])


def _iterate_test_ds() -> Iterable[Tuple[torch.Tensor, str]]:
    """Yield test samples, materializing them into the DATASET cache on first use.

    Caching guarantees the tuned and non-tuned evaluations score exactly the
    same utterances (streaming order could otherwise differ between passes).
    """
    global DATASET
    if not DATASET:
        DATASET = list(__iterate_test_ds())
    for sample in DATASET:
        yield sample


def _eval(translator: Translator) -> float:
    """Transcribe every cached test sample with *translator* and return corpus WER.

    Both references and hypotheses go through the Whisper English normalizer;
    empty strings are replaced by "." because jiwer cannot score empty inputs.
    """
    references = []
    predictions = []
    for idx, (wav, text) in enumerate(_iterate_test_ds()):
        reference = norm(text) or "."
        raw_prediction = translator.predict(
            input=wav,
            task_str="s2tt",
            tgt_lang="eng",
            src_lang="eng",
        )[0][0]
        prediction = norm(str(raw_prediction)) or "."
        log.info(idx)
        log.info(f"REF: {reference}")
        log.info(f"PRE: {prediction}")
        log.info("----")
        references.append(reference)
        predictions.append(prediction)
    return wer(reference=references, hypothesis=predictions)  # type:ignore


def _select_keys(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
return {key.replace(prefix, ""): value for key, value in state_dict.items() if key.startswith(prefix)}


def load_checkpoint(model: UnitYModel, chck_path: str) -> None:
    """Overwrite the speech-to-text submodules of *model* with finetuned weights.

    Only the S2T path is restored (speech encoder + text decoder); vocoder and
    T2U weights are left untouched.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints (consider weights_only=True on recent torch).
    full_state = torch.load(chck_path, map_location="cpu")
    model.speech_encoder_frontend.load_state_dict(
        _select_keys(full_state, "model.speech_encoder_frontend.")
    )
    model.speech_encoder.load_state_dict(
        _select_keys(full_state, "model.speech_encoder.")
    )
    assert model.text_decoder_frontend is not None
    model.text_decoder_frontend.load_state_dict(
        _select_keys(full_state, "model.text_decoder_frontend.")
    )
    assert model.text_decoder is not None
    model.text_decoder.load_state_dict(
        _select_keys(full_state, "model.text_decoder.")
    )


def main() -> None:
    """Score s2tt WER on GigaSpeech test before and after the finetuned checkpoint."""
    translator = Translator(
        model_name_or_card="seamlessM4T_medium",
        vocoder_name_or_card=None,
        device=torch.device("cuda"),
    )
    # Baseline pass with the stock pretrained weights.
    non_tuned_wer = _eval(translator)

    # Swap in the finetuned weights and re-score the identical cached samples.
    load_checkpoint(translator.model, CHCK_PATH)
    tuned_wer = _eval(translator)

    log.info(f"WER non-tuned: {non_tuned_wer:.3f}")
    log.info(f"WER tuned: {tuned_wer:.3f}")


# Allow running the evaluation as a standalone script.
if __name__ == "__main__":
    main()
Loading