Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ Please check out above [section](#seamlessexpressive-models) on how to acquire `
### W2v-BERT 2.0 speech encoder
| Model Name | #params | checkpoint |
| ----------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| W2v-BERT 2.0 | 600M | [🤗 Model card](https://huggingface.co/facebook/conformer-shaw) - [checkpoint](https://huggingface.co/facebook/conformer-shaw/resolve/main/conformer_shaw.pt)
| W2v-BERT 2.0 | 600M | [🤗 Model card](https://huggingface.co/facebook/w2v-bert-2.0) - [checkpoint](https://huggingface.co/facebook/w2v-bert-2.0/resolve/main/conformer_shaw.pt)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we fix this in another PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure


Here's how you should do a forward pass through the speech encoder:

Expand Down
2 changes: 1 addition & 1 deletion demo/expressive/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
load_gcmvn_stats,
load_unity_unit_tokenizer,
)
from seamless_communication.cli.expressivity.predict.pretssel_generator import PretsselGenerator
from seamless_communication.inference.pretssel_generator import PretsselGenerator

from typing import Tuple
from utils import LANGUAGE_CODE_TO_NAME
Expand Down
67 changes: 13 additions & 54 deletions src/seamless_communication/cli/expressivity/evaluate/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@
from torch import Tensor
from tqdm import tqdm

from seamless_communication.cli.expressivity.predict.pretssel_generator import (
PretsselGenerator,
)
from seamless_communication.cli.m4t.evaluate.evaluate import (
adjust_output_for_corrupted_inputs,
count_lines,
Expand All @@ -36,11 +33,8 @@
add_inference_arguments,
set_generation_opts,
)
from seamless_communication.inference import BatchedSpeechOutput, Translator
from seamless_communication.models.unity import (
load_gcmvn_stats,
load_unity_unit_tokenizer,
)
from seamless_communication.inference import BatchedSpeechOutput, ExpressiveTranslator
from seamless_communication.models.unity import load_unity_unit_tokenizer
from seamless_communication.store import add_gated_assets

logging.basicConfig(
Expand All @@ -55,8 +49,6 @@ def build_data_pipeline(
args: Namespace,
device: Device,
dtype: DataType,
gcmvn_mean: Tensor,
gcmvn_std: Tensor,
) -> DataPipeline:
with open(args.data_file, "r") as f:
header = f.readline().strip("\n").split("\t")
Expand Down Expand Up @@ -89,15 +81,8 @@ def build_data_pipeline(
dtype=dtype,
)

def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
fbank = data["fbank"]
std, mean = torch.std_mean(fbank, dim=0)
data["fbank"] = fbank.subtract(mean).divide(std)
data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
return data

pipeline_builder.map(
[decode_audio, convert_to_fbank, normalize_fbank],
[decode_audio, convert_to_fbank],
selector=f"{args.audio_field}.data",
num_parallel_calls=n_parallel,
)
Expand Down Expand Up @@ -176,17 +161,10 @@ def main() -> None:

unit_tokenizer = load_unity_unit_tokenizer(args.model_name)

_gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)

pipeline = build_data_pipeline(args, device, dtype, gcmvn_mean, gcmvn_std)
pipeline = build_data_pipeline(args, device, dtype)

translator = Translator(
args.model_name,
vocoder_name_or_card=None,
device=device,
dtype=dtype,
expressive_translator = ExpressiveTranslator(
args.model_name, args.vocoder_name, device, dtype
)

text_generation_opts, unit_generation_opts = set_generation_opts(args)
Expand All @@ -197,13 +175,6 @@ def main() -> None:
f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
)

pretssel_generator = PretsselGenerator(
args.vocoder_name,
vocab_info=unit_tokenizer.vocab_info,
device=device,
dtype=dtype,
)

total_steps = count_lines(args.data_file) - 1
progress_bar = tqdm(total=total_steps)

Expand Down Expand Up @@ -240,28 +211,16 @@ def main() -> None:
src["seqs"] = src["seqs"][valid_sequences]
src["seq_lens"] = src["seq_lens"][valid_sequences]

# Skip performing inference when the input is entirely corrupted.
# Skip inference when the input is entirely corrupted.
if src["seqs"].numel() > 0:
prosody_encoder_input = example[args.audio_field]["data"]["gcmvn_fbank"]
text_output, unit_output = translator.predict(
text_output, speech_output = expressive_translator.predict(
src,
"s2st",
args.tgt_lang,
src_lang=args.src_lang,
text_generation_opts=text_generation_opts,
unit_generation_opts=unit_generation_opts,
unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
duration_factor=args.duration_factor,
prosody_encoder_input=prosody_encoder_input,
text_generation_opts,
unit_generation_opts,
args.unit_generation_ngram_filtering,
args.duration_factor,
)

assert unit_output is not None
speech_output = pretssel_generator.predict(
unit_output.units,
tgt_lang=args.tgt_lang,
prosody_encoder_input=prosody_encoder_input,
)

else:
text_output = []
speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
Expand All @@ -273,7 +232,7 @@ def main() -> None:
speech_output,
)

hyps += [str(s) for s in text_output]
hyps += [s for s in text_output]
if args.ref_field is not None and args.ref_field in example:
refs += [str(s) for s in example[args.ref_field]]

Expand Down
116 changes: 18 additions & 98 deletions src/seamless_communication/cli/expressivity/predict/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,18 @@

import argparse
import logging
import torch
import torchaudio
from pathlib import Path

from fairseq2.data import SequenceData
from fairseq2.data.audio import WaveformToFbankConverter
import torch
import torchaudio

from seamless_communication.cli.expressivity.predict.pretssel_generator import (
PretsselGenerator,
)
from seamless_communication.cli.m4t.predict import (
add_inference_arguments,
set_generation_opts,
)
from seamless_communication.inference import Translator
from seamless_communication.models.unity import (
load_gcmvn_stats,
load_unity_unit_tokenizer,
)
from seamless_communication.inference import ExpressiveTranslator
from seamless_communication.store import add_gated_assets


AUDIO_SAMPLE_RATE = 16000


logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
Expand All @@ -39,16 +26,11 @@
logger = logging.getLogger(__name__)


def remove_prosody_tokens_from_text(text: str) -> str:
# filter out prosody tokens, there is only emphasis '*', and pause '='
text = text.replace("*", "").replace("=", "")
text = " ".join(text.split())
return text


def main() -> None:
parser = argparse.ArgumentParser(description="Running SeamlessExpressive inference.")
parser.add_argument("input", type=str, help="Audio WAV file path.")
parser = argparse.ArgumentParser(
description="Running SeamlessExpressive inference."
)
parser.add_argument("input", type=Path, help="Audio WAV file path.")

parser = add_inference_arguments(parser)
parser.add_argument(
Expand All @@ -69,10 +51,10 @@ def main() -> None:
raise Exception(
"--tgt_lang, --output_path must be provided for SeamlessExpressive inference."
)

if args.gated_model_dir:
add_gated_assets(args.gated_model_dir)

if torch.cuda.is_available():
device = torch.device("cuda:0")
dtype = torch.float16
Expand All @@ -82,59 +64,8 @@ def main() -> None:

logger.info(f"Running inference on {device=} with {dtype=}.")

unit_tokenizer = load_unity_unit_tokenizer(args.model_name)

translator = Translator(
args.model_name,
vocoder_name_or_card=None,
device=device,
dtype=dtype,
)

pretssel_generator = PretsselGenerator(
args.vocoder_name,
vocab_info=unit_tokenizer.vocab_info,
device=device,
dtype=dtype,
)

fbank_extractor = WaveformToFbankConverter(
num_mel_bins=80,
waveform_scale=2**15,
channel_last=True,
standardize=False,
device=device,
dtype=dtype,
)

_gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)

wav, sample_rate = torchaudio.load(args.input)
wav = torchaudio.functional.resample(wav, orig_freq=sample_rate, new_freq=16_000)
wav = wav.transpose(0, 1)

data = fbank_extractor(
{
"waveform": wav,
"sample_rate": 16000,
}
)
fbank = data["fbank"]
gcmvn_fbank = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
std, mean = torch.std_mean(fbank, dim=0)
fbank = fbank.subtract(mean).divide(std)

src = SequenceData(
seqs=fbank.unsqueeze(0),
seq_lens=torch.LongTensor([fbank.shape[0]]),
is_ragged=False,
)
src_gcmvn = SequenceData(
seqs=gcmvn_fbank.unsqueeze(0),
seq_lens=torch.LongTensor([gcmvn_fbank.shape[0]]),
is_ragged=False,
expressive_translator = ExpressiveTranslator(
args.model_name, args.vocoder_name, device, dtype
)

text_generation_opts, unit_generation_opts = set_generation_opts(args)
Expand All @@ -145,22 +76,13 @@ def main() -> None:
f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
)

text_output, unit_output = translator.predict(
src,
"s2st",
text_output, speech_output = expressive_translator.predict(
args.input,
args.tgt_lang,
text_generation_opts=text_generation_opts,
unit_generation_opts=unit_generation_opts,
unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
duration_factor=args.duration_factor,
prosody_encoder_input=src_gcmvn,
)

assert unit_output is not None
speech_output = pretssel_generator.predict(
unit_output.units,
tgt_lang=args.tgt_lang,
prosody_encoder_input=src_gcmvn,
text_generation_opts,
unit_generation_opts,
args.unit_generation_ngram_filtering,
args.duration_factor,
)

logger.info(f"Saving expressive translated audio in {args.tgt_lang}")
Expand All @@ -170,9 +92,7 @@ def main() -> None:
sample_rate=speech_output.sample_rate,
)

text_out = remove_prosody_tokens_from_text(str(text_output[0]))

logger.info(f"Translated text in {args.tgt_lang}: {text_out}")
logger.info(f"Translated text in {args.tgt_lang}: {text_output[0]}")


if __name__ == "__main__":
Expand Down
3 changes: 3 additions & 0 deletions src/seamless_communication/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

from seamless_communication.inference.expressive_translator import (
ExpressiveTranslator as ExpressiveTranslator,
)
from seamless_communication.inference.generator import (
SequenceGeneratorOptions as SequenceGeneratorOptions,
)
Expand Down
Loading