diff --git a/README.md b/README.md
index 67bfc5fd..fcc18a6e 100644
--- a/README.md
+++ b/README.md
@@ -161,7 +161,7 @@ Please check out above [section](#seamlessexpressive-models) on how to acquire `
 ### W2v-BERT 2.0 speech encoder
 | Model Name | #params | checkpoint |
 | ----------------- | ------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| W2v-BERT 2.0 | 600M | [🤗 Model card](https://huggingface.co/facebook/conformer-shaw) - [checkpoint](https://huggingface.co/facebook/conformer-shaw/resolve/main/conformer_shaw.pt)
+| W2v-BERT 2.0 | 600M | [🤗 Model card](https://huggingface.co/facebook/w2v-bert-2.0) - [checkpoint](https://huggingface.co/facebook/w2v-bert-2.0/resolve/main/conformer_shaw.pt)
 
 Here's how you should do a foward pass through the speech encoder:
 
diff --git a/demo/expressive/app.py b/demo/expressive/app.py
index 4b7d869b..ec525c4e 100644
--- a/demo/expressive/app.py
+++ b/demo/expressive/app.py
@@ -29,7 +29,7 @@
     load_gcmvn_stats,
     load_unity_unit_tokenizer,
 )
-from seamless_communication.cli.expressivity.predict.pretssel_generator import PretsselGenerator
+from seamless_communication.inference.pretssel_generator import PretsselGenerator
 from typing import Tuple
 
 from utils import LANGUAGE_CODE_TO_NAME
diff --git a/src/seamless_communication/cli/expressivity/evaluate/evaluate.py b/src/seamless_communication/cli/expressivity/evaluate/evaluate.py
index 84f41c3e..088de9b8 100644
--- a/src/seamless_communication/cli/expressivity/evaluate/evaluate.py
+++ b/src/seamless_communication/cli/expressivity/evaluate/evaluate.py
@@ -25,9 +25,6 @@
 from torch import Tensor
 from tqdm import tqdm
 
-from seamless_communication.cli.expressivity.predict.pretssel_generator import (
-    PretsselGenerator,
-)
 from seamless_communication.cli.m4t.evaluate.evaluate import (
     adjust_output_for_corrupted_inputs,
     count_lines,
@@ -36,11 +33,8 @@
     add_inference_arguments,
     set_generation_opts,
 )
-from seamless_communication.inference import BatchedSpeechOutput, Translator
-from seamless_communication.models.unity import (
-    load_gcmvn_stats,
-    load_unity_unit_tokenizer,
-)
+from seamless_communication.inference import BatchedSpeechOutput, ExpressiveTranslator
+from seamless_communication.models.unity import load_unity_unit_tokenizer
 from seamless_communication.store import add_gated_assets
 
 logging.basicConfig(
@@ -55,8 +49,6 @@ def build_data_pipeline(
     args: Namespace,
     device: Device,
     dtype: DataType,
-    gcmvn_mean: Tensor,
-    gcmvn_std: Tensor,
 ) -> DataPipeline:
     with open(args.data_file, "r") as f:
         header = f.readline().strip("\n").split("\t")
@@ -89,15 +81,8 @@
         dtype=dtype,
     )
 
-    def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
-        fbank = data["fbank"]
-        std, mean = torch.std_mean(fbank, dim=0)
-        data["fbank"] = fbank.subtract(mean).divide(std)
-        data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
-        return data
-
     pipeline_builder.map(
-        [decode_audio, convert_to_fbank, normalize_fbank],
+        [decode_audio, convert_to_fbank],
         selector=f"{args.audio_field}.data",
         num_parallel_calls=n_parallel,
     )
@@ -176,17 +161,10 @@ def main() -> None:
 
     unit_tokenizer = load_unity_unit_tokenizer(args.model_name)
 
-    _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
-    gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
-    gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
-
-    pipeline = build_data_pipeline(args, device, dtype, gcmvn_mean, gcmvn_std)
+    pipeline = build_data_pipeline(args, device, dtype)
 
-    translator = Translator(
-        args.model_name,
-        vocoder_name_or_card=None,
-        device=device,
-        dtype=dtype,
+    expressive_translator = ExpressiveTranslator(
+        args.model_name, args.vocoder_name, device, dtype
     )
 
     text_generation_opts, unit_generation_opts = set_generation_opts(args)
@@ -197,13 +175,6 @@ def main() -> None:
         f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
     )
 
-    pretssel_generator = PretsselGenerator(
-        args.vocoder_name,
-        vocab_info=unit_tokenizer.vocab_info,
-        device=device,
-        dtype=dtype,
-    )
-
     total_steps = count_lines(args.data_file) - 1
     progress_bar = tqdm(total=total_steps)
 
@@ -240,28 +211,16 @@ def main() -> None:
             src["seqs"] = src["seqs"][valid_sequences]
             src["seq_lens"] = src["seq_lens"][valid_sequences]
 
-        # Skip performing inference when the input is entirely corrupted.
+        # Skip inference when the input is entirely corrupted.
         if src["seqs"].numel() > 0:
-            prosody_encoder_input = example[args.audio_field]["data"]["gcmvn_fbank"]
-            text_output, unit_output = translator.predict(
+            text_output, speech_output = expressive_translator.predict(
                 src,
-                "s2st",
                 args.tgt_lang,
-                src_lang=args.src_lang,
-                text_generation_opts=text_generation_opts,
-                unit_generation_opts=unit_generation_opts,
-                unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
-                duration_factor=args.duration_factor,
-                prosody_encoder_input=prosody_encoder_input,
+                text_generation_opts,
+                unit_generation_opts,
+                args.unit_generation_ngram_filtering,
+                args.duration_factor,
             )
-
-            assert unit_output is not None
-            speech_output = pretssel_generator.predict(
-                unit_output.units,
-                tgt_lang=args.tgt_lang,
-                prosody_encoder_input=prosody_encoder_input,
-            )
-
         else:
             text_output = []
             speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
@@ -273,7 +232,7 @@ def main() -> None:
                 speech_output,
             )
 
-        hyps += [str(s) for s in text_output]
+        hyps += [s for s in text_output]
         if args.ref_field is not None and args.ref_field in example:
             refs += [str(s) for s in example[args.ref_field]]
 
diff --git a/src/seamless_communication/cli/expressivity/predict/predict.py b/src/seamless_communication/cli/expressivity/predict/predict.py
index 7ad0e718..dac99887 100644
--- a/src/seamless_communication/cli/expressivity/predict/predict.py
+++ b/src/seamless_communication/cli/expressivity/predict/predict.py
@@ -6,31 +6,18 @@
 
 import argparse
 import logging
-import torch
-import torchaudio
 from pathlib import Path
 
-from fairseq2.data import SequenceData
-from fairseq2.data.audio import WaveformToFbankConverter
+import torch
+import torchaudio
 
-from seamless_communication.cli.expressivity.predict.pretssel_generator import (
-    PretsselGenerator,
-)
 from seamless_communication.cli.m4t.predict import (
     add_inference_arguments,
     set_generation_opts,
 )
-from seamless_communication.inference import Translator
-from seamless_communication.models.unity import (
-    load_gcmvn_stats,
-    load_unity_unit_tokenizer,
-)
+from seamless_communication.inference import ExpressiveTranslator
 from seamless_communication.store import add_gated_assets
 
-
-AUDIO_SAMPLE_RATE = 16000
-
-
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
@@ -39,16 +26,11 @@
 logger = logging.getLogger(__name__)
 
-def remove_prosody_tokens_from_text(text: str) -> str:
-    # filter out prosody tokens, there is only emphasis '*', and pause '='
-    text = text.replace("*", "").replace("=", "")
-    text = " ".join(text.split())
-    return text
-
-
 def main() -> None:
-    parser = argparse.ArgumentParser(description="Running SeamlessExpressive inference.")
-    parser.add_argument("input", type=str, help="Audio WAV file path.")
+    parser = argparse.ArgumentParser(
+        description="Running SeamlessExpressive inference."
+    )
+    parser.add_argument("input", type=Path, help="Audio WAV file path.")
 
     parser = add_inference_arguments(parser)
 
     parser.add_argument(
@@ -69,10 +51,10 @@ def main() -> None:
         raise Exception(
             "--tgt_lang, --output_path must be provided for SeamlessExpressive inference."
         )
-    
+
     if args.gated_model_dir:
         add_gated_assets(args.gated_model_dir)
-    
+
     if torch.cuda.is_available():
         device = torch.device("cuda:0")
         dtype = torch.float16
@@ -82,59 +64,8 @@ def main() -> None:
 
     logger.info(f"Running inference on {device=} with {dtype=}.")
 
-    unit_tokenizer = load_unity_unit_tokenizer(args.model_name)
-
-    translator = Translator(
-        args.model_name,
-        vocoder_name_or_card=None,
-        device=device,
-        dtype=dtype,
-    )
-
-    pretssel_generator = PretsselGenerator(
-        args.vocoder_name,
-        vocab_info=unit_tokenizer.vocab_info,
-        device=device,
-        dtype=dtype,
-    )
-
-    fbank_extractor = WaveformToFbankConverter(
-        num_mel_bins=80,
-        waveform_scale=2**15,
-        channel_last=True,
-        standardize=False,
-        device=device,
-        dtype=dtype,
-    )
-
-    _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
-    gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
-    gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
-
-    wav, sample_rate = torchaudio.load(args.input)
-    wav = torchaudio.functional.resample(wav, orig_freq=sample_rate, new_freq=16_000)
-    wav = wav.transpose(0, 1)
-
-    data = fbank_extractor(
-        {
-            "waveform": wav,
-            "sample_rate": 16000,
-        }
-    )
-    fbank = data["fbank"]
-    gcmvn_fbank = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
-    std, mean = torch.std_mean(fbank, dim=0)
-    fbank = fbank.subtract(mean).divide(std)
-
-    src = SequenceData(
-        seqs=fbank.unsqueeze(0),
-        seq_lens=torch.LongTensor([fbank.shape[0]]),
-        is_ragged=False,
-    )
-    src_gcmvn = SequenceData(
-        seqs=gcmvn_fbank.unsqueeze(0),
-        seq_lens=torch.LongTensor([gcmvn_fbank.shape[0]]),
-        is_ragged=False,
+    expressive_translator = ExpressiveTranslator(
+        args.model_name, args.vocoder_name, device, dtype
     )
 
     text_generation_opts, unit_generation_opts = set_generation_opts(args)
@@ -145,22 +76,13 @@ def main() -> None:
         f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
     )
 
-    text_output, unit_output = translator.predict(
-        src,
-        "s2st",
+    text_output, speech_output = expressive_translator.predict(
+        args.input,
         args.tgt_lang,
-        text_generation_opts=text_generation_opts,
-        unit_generation_opts=unit_generation_opts,
-        unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
-        duration_factor=args.duration_factor,
-        prosody_encoder_input=src_gcmvn,
-    )
-
-    assert unit_output is not None
-    speech_output = pretssel_generator.predict(
-        unit_output.units,
-        tgt_lang=args.tgt_lang,
-        prosody_encoder_input=src_gcmvn,
+        text_generation_opts,
+        unit_generation_opts,
+        args.unit_generation_ngram_filtering,
+        args.duration_factor,
     )
 
     logger.info(f"Saving expressive translated audio in {args.tgt_lang}")
@@ -170,9 +92,7 @@ def main() -> None:
         sample_rate=speech_output.sample_rate,
     )
 
-    text_out = remove_prosody_tokens_from_text(str(text_output[0]))
-
-    logger.info(f"Translated text in {args.tgt_lang}: {text_out}")
+    logger.info(f"Translated text in {args.tgt_lang}: {text_output[0]}")
 
 
 if __name__ == "__main__":
diff --git a/src/seamless_communication/inference/__init__.py b/src/seamless_communication/inference/__init__.py
index f5c24ca1..fafe835c 100644
--- a/src/seamless_communication/inference/__init__.py
+++ b/src/seamless_communication/inference/__init__.py
@@ -4,6 +4,9 @@
 # This source code is licensed under the license found in the
 # MIT_LICENSE file in the root directory of this source tree.
 
+from seamless_communication.inference.expressive_translator import (
+    ExpressiveTranslator as ExpressiveTranslator,
+)
 from seamless_communication.inference.generator import (
     SequenceGeneratorOptions as SequenceGeneratorOptions,
 )
diff --git a/src/seamless_communication/inference/expressive_translator.py b/src/seamless_communication/inference/expressive_translator.py
new file mode 100644
index 00000000..c84de7cc
--- /dev/null
+++ b/src/seamless_communication/inference/expressive_translator.py
@@ -0,0 +1,175 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+from copy import deepcopy
+from pathlib import Path
+from typing import List, Optional, Tuple, Union, cast
+
+import torch
+import torchaudio
+from fairseq2.assets.card import AssetCard
+from fairseq2.data import SequenceData, StringLike
+from fairseq2.data.audio import WaveformToFbankConverter
+from fairseq2.nn.padding import apply_padding_mask, get_seqs_and_padding_mask
+from fairseq2.typing import DataType, Device
+from torch.nn import Module
+
+from seamless_communication.inference.generator import SequenceGeneratorOptions
+from seamless_communication.inference.pretssel_generator import PretsselGenerator
+from seamless_communication.inference.translator import BatchedSpeechOutput, Translator
+from seamless_communication.models.unity import (
+    load_gcmvn_stats,
+    load_unity_unit_tokenizer,
+)
+
+AUDIO_SAMPLE_RATE = 16000
+
+
+class ExpressiveTranslator(Module):
+    def __init__(
+        self,
+        model_name_or_card: Union[str, AssetCard],
+        vocoder_name_or_card: Union[str, AssetCard, None],
+        device: Device,
+        dtype: DataType,
+    ):
+        super().__init__()
+
+        unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
+
+        self.translator = Translator(
+            model_name_or_card,
+            vocoder_name_or_card=None,
+            device=device,
+            dtype=dtype,
+        )
+
+        self.pretssel_generator = PretsselGenerator(
+            vocoder_name_or_card,
+            vocab_info=unit_tokenizer.vocab_info,
+            device=device,
+            dtype=dtype,
+        )
+
+        self.fbank_extractor = WaveformToFbankConverter(
+            num_mel_bins=80,
+            waveform_scale=2**15,
+            channel_last=True,
+            standardize=False,
+            device=device,
+            dtype=dtype,
+        )
+
+        _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(vocoder_name_or_card)
+        self.gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
+        self.gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
+
+    @staticmethod
+    def remove_prosody_tokens_from_text(text_output: List[str]) -> List[str]:
+        modified_text_output = []
+        for text in text_output:
+            # filter out prosody tokens, there is only emphasis '*', and pause '='
+            text = str(text).replace("*", "").replace("=", "")
+            text = " ".join(text.split())
+            modified_text_output.append(text)
+        return modified_text_output
+
+    @torch.inference_mode()
+    def predict(
+        self,
+        input: Union[Path, SequenceData],
+        tgt_lang: str,
+        text_generation_opts: Optional[SequenceGeneratorOptions] = None,
+        unit_generation_opts: Optional[SequenceGeneratorOptions] = None,
+        unit_generation_ngram_filtering: bool = False,
+        duration_factor: float = 1.0,
+    ) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
+        """
+        The main method used to perform inference on all tasks.
+
+        :param input:
+            Either path to audio or audio Tensor.
+        :param tgt_lang:
+            Target language to decode into.
+        :param text_generation_opts:
+            Text generation hyperparameters for incremental decoding.
+        :param unit_generation_opts:
+            Unit generation hyperparameters for incremental decoding.
+        :param unit_generation_ngram_filtering:
+            If True, removes consecutive repeated ngrams
+            from the decoded unit output.
+
+        :returns:
+            - Batched list of Translated text.
+            - Translated BatchedSpeechOutput.
+        """
+        if isinstance(input, dict):
+            src = cast(SequenceData, input)
+            src_gcmvn = deepcopy(src)
+
+            fbank, padding_mask = get_seqs_and_padding_mask(src)
+            gcmvn_fbank = fbank.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
+            # due to padding, batched std_mean calculation is wrong
+            mean = torch.zeros_like(fbank[:, 0])
+            std = torch.zeros_like(fbank[:, 0])
+            for i, (i_fbank, i_seq_len) in enumerate(zip(fbank, src["seq_lens"])):
+                std[i], mean[i] = torch.std_mean(i_fbank[:i_seq_len], dim=0)
+
+            ucmvn_fbank = fbank.subtract(mean.unsqueeze(1)).divide(std.unsqueeze(1))
+            src["seqs"] = apply_padding_mask(ucmvn_fbank, padding_mask)
+            src_gcmvn["seqs"] = apply_padding_mask(gcmvn_fbank, padding_mask)
+
+        elif isinstance(input, Path):
+            # TODO: Replace with fairseq2.data once re-sampling is implemented.
+            wav, sample_rate = torchaudio.load(input)
+            wav = torchaudio.functional.resample(
+                wav,
+                orig_freq=sample_rate,
+                new_freq=AUDIO_SAMPLE_RATE,
+            )
+            wav = wav.transpose(0, 1)
+
+            data = self.fbank_extractor(
+                {
+                    "waveform": wav,
+                    "sample_rate": AUDIO_SAMPLE_RATE,
+                }
+            )
+
+            fbank = data["fbank"]
+            gcmvn_fbank = fbank.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
+            std, mean = torch.std_mean(fbank, dim=0)
+            fbank = fbank.subtract(mean).divide(std)
+
+            src = SequenceData(
+                seqs=fbank.unsqueeze(0),
+                seq_lens=torch.LongTensor([fbank.shape[0]]),
+                is_ragged=False,
+            )
+            src_gcmvn = SequenceData(
+                seqs=gcmvn_fbank.unsqueeze(0),
+                seq_lens=torch.LongTensor([gcmvn_fbank.shape[0]]),
+                is_ragged=False,
+            )
+
+        text_output, unit_output = self.translator.predict(
+            src,
+            "s2st",
+            tgt_lang,
+            text_generation_opts=text_generation_opts,
+            unit_generation_opts=unit_generation_opts,
+            unit_generation_ngram_filtering=unit_generation_ngram_filtering,
+            duration_factor=duration_factor,
+            prosody_encoder_input=src_gcmvn,
+        )
+        text_output = self.remove_prosody_tokens_from_text(text_output)
+
+        assert unit_output is not None
+        speech_output = self.pretssel_generator.predict(
+            unit_output.units,
+            tgt_lang=tgt_lang,
+            prosody_encoder_input=src_gcmvn,
+        )
+        return text_output, speech_output
diff --git a/src/seamless_communication/cli/expressivity/predict/pretssel_generator.py b/src/seamless_communication/inference/pretssel_generator.py
similarity index 95%
rename from src/seamless_communication/cli/expressivity/predict/pretssel_generator.py
rename to src/seamless_communication/inference/pretssel_generator.py
index 0754e339..4a2f8074 100644
--- a/src/seamless_communication/cli/expressivity/predict/pretssel_generator.py
+++ b/src/seamless_communication/inference/pretssel_generator.py
@@ -6,19 +6,13 @@
 from typing import List
 
 import torch
-from torch.nn import Module
-
-from fairseq2.typing import DataType, Device
-
 from fairseq2.assets import asset_store
-from fairseq2.data import (
-    Collater,
-    SequenceData,
-    VocabularyInfo,
-)
+from fairseq2.data import Collater, SequenceData, VocabularyInfo
 from fairseq2.nn.padding import get_seqs_and_padding_mask
+from fairseq2.typing import DataType, Device
+from torch.nn import Module
 
-from seamless_communication.inference import BatchedSpeechOutput
+from seamless_communication.inference.translator import BatchedSpeechOutput
 from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
 
 
@@ -60,7 +54,6 @@ def predict(
         tgt_lang: str,
         prosody_encoder_input: SequenceData,
     ) -> BatchedSpeechOutput:
-
         units_batch, durations = [], []
         for u in units:
             unit = torch.tensor(u).to(self.unit_eos_token)
diff --git a/src/seamless_communication/inference/translator.py b/src/seamless_communication/inference/translator.py
index 57bea931..3513f33c 100644
--- a/src/seamless_communication/inference/translator.py
+++ b/src/seamless_communication/inference/translator.py
@@ -10,7 +10,7 @@
 from typing import List, Optional, Tuple, Union, cast
 
 import torch
-import torch.nn as nn
+from torch.nn import Module
 from fairseq2.assets import asset_store
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import Collater, SequenceData, StringLike
@@ -75,7 +75,7 @@ class BatchedSpeechOutput:
     """Sample rate of the audio waveforms."""
 
 
-class Translator(nn.Module):
+class Translator(Module):
     def __init__(
         self,
         model_name_or_card: Union[str, AssetCard],
diff --git a/src/seamless_communication/toxicity/mintox.py b/src/seamless_communication/toxicity/mintox.py
index aa772be4..ea24f62d 100644
--- a/src/seamless_communication/toxicity/mintox.py
+++ b/src/seamless_communication/toxicity/mintox.py
@@ -7,26 +7,19 @@
 import logging
 from typing import List, Optional, Tuple
 
-from torch import Tensor
 import torch
-from torch.nn import functional as F
-
-
-from seamless_communication.inference import SequenceGeneratorOptions
-from seamless_communication.toxicity.etox_bad_word_checker import (
-    ETOXBadWordChecker,
-)
-from fairseq2.generation import BannedSequenceProcessor
+from fairseq2.data import SequenceData
 from fairseq2.data.text.text_tokenizer import TextTokenizer
 from fairseq2.data.typing import StringLike
-from fairseq2.typing import Device
-from fairseq2.data import SequenceData
+from fairseq2.generation import BannedSequenceProcessor
 from fairseq2.nn.padding import get_seqs_and_padding_mask
-from seamless_communication.models.unity import (
-    UnitTokenizer,
-    UnitYModel,
-)
+from fairseq2.typing import Device
+from torch import Tensor
+from torch.nn import functional as F
+from seamless_communication.inference.generator import SequenceGeneratorOptions
+from seamless_communication.models.unity import UnitTokenizer, UnitYModel
+from seamless_communication.toxicity.etox_bad_word_checker import ETOXBadWordChecker
 
 
 logger = logging.getLogger(__name__)
 
@@ -84,9 +77,7 @@ def _replace_with_new_unit_output_in_batch(
             )
         else:
             # pad on the new units
-            new_units = F.pad(
-                new_units, pad=nb_pads, mode="constant", value=pad_idx
-            )
+            new_units = F.pad(new_units, pad=nb_pads, mode="constant", value=pad_idx)
             original_units[indices_with_toxicity_tensor] = new_units
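
For readers of this patch, here is a minimal usage sketch of the consolidated `ExpressiveTranslator` API introduced above. It mirrors the refactored `predict.py`; the asset names (`seamless_expressivity`, `vocoder_pretssel`), file paths, and target language are illustrative assumptions rather than part of this change.

```python
from pathlib import Path

import torch
import torchaudio

from seamless_communication.inference import ExpressiveTranslator

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

# ExpressiveTranslator now bundles the Translator, PretsselGenerator, fbank
# extraction and gcmvn normalization that the CLI scripts previously wired up
# by hand. Model/vocoder names below are assumed; use the gated assets you have.
translator = ExpressiveTranslator("seamless_expressivity", "vocoder_pretssel", device, dtype)

# Passing a Path takes the load/resample/fbank branch of predict();
# a fairseq2 SequenceData batch takes the pre-extracted-fbank branch instead.
text_output, speech_output = translator.predict(Path("input.wav"), "fra")

# Prosody tokens ('*' and '=') are already stripped from the text output.
print(str(text_output[0]))

torchaudio.save(
    "output.wav",
    speech_output.audio_wavs[0][0].to(torch.float32).cpu(),
    sample_rate=speech_output.sample_rate,
)
```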