diff --git a/src/seamless_communication/cli/m4t/predict/predict.py b/src/seamless_communication/cli/m4t/predict/predict.py index 9fd6bc00..e77bd746 100644 --- a/src/seamless_communication/cli/m4t/predict/predict.py +++ b/src/seamless_communication/cli/m4t/predict/predict.py @@ -27,11 +27,13 @@ def add_inference_arguments(parser: argparse.ArgumentParser) -> argparse.Argumen parser.add_argument( "--task", type=str, - choices=["ASR", "S2ST", "S2TT"], + choices=["ASR", "S2ST", "S2TT", "T2ST", "T2TT"], help=( "* `ASR` -- automatic speech recognition (transcription);" "* `S2ST` -- speech to speech translation;" "* `S2TT` -- speech to text translation;" + "* `T2ST` -- text to speech translation;" + "* `T2TT` -- text to text translation;" ) ) parser.add_argument( @@ -201,6 +203,11 @@ def main() -> None: "Please provide required arguments for evaluation - task, tgt_lang" ) + if args.task.upper() in {"T2TT", "T2ST"} and args.src_lang is None: + raise ValueError( + "src_lang must be provided for text input tasks (T2TT, T2ST)" + ) + if args.task.upper() in {"S2ST", "T2ST"} and args.output_path is None: raise ValueError("output_path must be provided to save the generated audio")