From 87187ea2dd8d6cdc02f2314f6c38b9c5a8238076 Mon Sep 17 00:00:00 2001 From: unknown Date: Sat, 7 Oct 2023 17:47:00 -0700 Subject: [PATCH] add batch translation for text inputs --- .../models/inference/translator.py | 31 ++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/src/seamless_communication/models/inference/translator.py b/src/seamless_communication/models/inference/translator.py index ad76de10..e9e56428 100644 --- a/src/seamless_communication/models/inference/translator.py +++ b/src/seamless_communication/models/inference/translator.py @@ -4,10 +4,11 @@ # LICENSE file in the root directory of this source tree. from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union, Any import torch import torch.nn as nn +from torch.nn.functional import pad as pad_tensor from fairseq2.assets.card import AssetCard from fairseq2.data import Collater from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter @@ -146,6 +147,16 @@ def get_prediction( ngram_filtering=ngram_filtering, ) + def batch_tensors(self, tensors: List[Tensor], pad_value: Any) -> Tensor: + padding_size = max(tensor.shape[0] for tensor in tensors) + dims = len(tensors[0].shape) + padded_tensors = [] + for tensor in tensors: + padding = [0] * 2 * dims + padding[-1] = padding_size - tensor.shape[0] + padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value)) + return torch.stack([tensor for tensor in padded_tensors], dim=0) + def get_modalities_from_task(self, task: Task) -> Tuple[Modality, Modality]: if task == Task.S2ST: return Modality.SPEECH, Modality.SPEECH @@ -160,7 +171,7 @@ def get_modalities_from_task(self, task: Task) -> Tuple[Modality, Modality]: @torch.inference_mode() def predict( self, - input: Union[str, Tensor], + input: Union[str, List[str], Tensor], task_str: str, tgt_lang: str, src_lang: Optional[str] = None, @@ -220,7 +231,16 @@ def predict( self.token_encoder = self.text_tokenizer.create_encoder( task="translation", lang=src_lang, mode="source", device=self.device ) - src = self.collate(self.token_encoder(text)) + + if isinstance(text, str): + src = self.collate(self.token_encoder(text)) + else: + collated = [self.collate(self.token_encoder(t)) for t in text] + + src = { + 'seqs': self.batch_tensors([item['seqs'].squeeze() for item in collated], self.text_tokenizer.vocab_info.pad_idx), + 'seq_lens': torch.cat([item['seq_lens'] for item in collated]) + } result = self.get_prediction( self.model, @@ -240,7 +260,10 @@ def predict( text_out = result[0] unit_out = result[1] if output_modality == Modality.TEXT: - return text_out.sentences[0], None, None + if len(text_out.sentences) == 0: + return text_out.sentences[0], None, None + else: + return text_out.sentences, None, None else: units = unit_out.units[:, 1:][0].cpu().numpy().tolist() wav_out = self.vocoder(units, tgt_lang, spkr, dur_prediction=True)