diff --git a/.gitmodules b/.gitmodules index 37f7a6e4..b34ee409 100644 --- a/.gitmodules +++ b/.gitmodules @@ -43,3 +43,6 @@ [submodule "deps/Zonos"] path = deps/Zonos url = https://github.com/weedge/Zonos.git +[submodule "deps/StepAudio"] + path = deps/StepAudio + url = https://github.com/weedge/Step-Audio.git diff --git a/deps/StepAudio b/deps/StepAudio new file mode 160000 index 00000000..7ce0a889 --- /dev/null +++ b/deps/StepAudio @@ -0,0 +1 @@ +Subproject commit 7ce0a88996e91fe5b4956ab1f9447869ce60944a diff --git a/pyproject.toml b/pyproject.toml index f3fe9c52..c6fa6164 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -182,32 +182,14 @@ speech_vad_analyzer = [ rms_recorder = [] vad_recorder = ["achatbot[speech_vad]"] -# asr module tag -> pkgs -whisper_asr = ["openai-whisper==20231117"] -whisper_timestamped_asr = ["whisper-timestamped~=1.14.2"] -whisper_faster_asr = ["faster-whisper~=1.0.2"] -whisper_transformers_asr = ["transformers[torch]>=4.40.2"] -whisper_mlx_asr = [ - "mlx_whisper~=0.2.0; sys_platform == 'darwin' and platform_machine == 'arm64'", -] -whisper_groq_asr = ["groq~=0.9.0"] -sense_voice_asr = [ - "torch~=2.2.2", - "funasr~=1.1.8", - "onnx", - "onnxconverter-common", -] -speech_asr = [ - "achatbot[whisper_asr,whisper_timestamped_asr,whisper_faster_asr,whisper_transformers_asr,whisper_mlx_asr,whisper_groq_asr,sense_voice_asr]", -] - +# --------------------------------- llm -------------------------- # llm module tag -> pkgs # init use cpu Pre-built Wheel to install, # if want to use other lib(cuda), see: https://github.com/abetlen/llama-cpp-python#installation-configuration llama_cpp = ["llama-cpp-python~=0.2.82"] llm_personalai_proxy = ["geocoder~=1.38.1"] -# vision +# vision llm llm_transformers_manual_vision = [ #"transformers@git+https://github.com/huggingface/transformers", # https://github.com/huggingface/transformers/releases/tag/v4.45.0 @@ -245,9 +227,65 @@ llm_transformers_manual_vision_deepseekvl2 = [ "timm>=0.9.16", ] +# voice llm +llm_transformers_manual_voice = [ + #"transformers@git+https://github.com/huggingface/transformers", + # https://github.com/huggingface/transformers/releases/tag/v4.45.transformers~=4.45.2 + "transformers~=4.45.2", + "torch~=2.2.2", + "torchaudio~=2.2.2", +] +llm_transformers_manual_voice_glm = [ + "achatbot[llm_transformers_manual_voice,tts_cosy_voice,gdown,matplotlib,conf]", +] +llm_transformers_manual_voice_freeze_omni = [ + "achatbot[llm_transformers_manual_voice,librosa,soundfile,yaml]", +] +# speech llm +llm_transformers_manual_speech_llasa = [ + "achatbot[llm_transformers_manual_voice]", +] +llm_transformers_manual_speech_step = [ + "achatbot[llm_transformers_manual_voice]", +] +# vision voice llm +llm_transformers_manual_vision_voice_minicpmo = [ + "achatbot[accelerate,librosa,soundfile]", + "torch~=2.2.2", + "torchaudio~=2.2.2", + "torchvision~=0.17.2", + "transformers==4.44.2", + #"librosa==0.9.0", + #"soundfile==0.12.1", + "vector-quantize-pytorch~=1.18.5", + "vocos~=0.1.0", + "decord", + "moviepy", +] + # core llms core_llm = ["achatbot[llama_cpp,llm_personalai_proxy]"] +# ----------------- asr ------------------ +# asr module tag -> pkgs +whisper_asr = ["openai-whisper==20231117"] +whisper_timestamped_asr = ["whisper-timestamped~=1.14.2"] +whisper_faster_asr = ["faster-whisper~=1.0.2"] +whisper_transformers_asr = ["transformers[torch]>=4.40.2"] +whisper_mlx_asr = [ + "mlx_whisper~=0.2.0; sys_platform == 'darwin' and platform_machine == 'arm64'", +] +whisper_groq_asr = ["groq~=0.9.0"] +sense_voice_asr = [ + 
"torch~=2.2.2", + "funasr~=1.1.8", + "onnx", + "onnxconverter-common", +] +speech_asr = [ + "achatbot[whisper_asr,whisper_timestamped_asr,whisper_faster_asr,whisper_transformers_asr,whisper_mlx_asr,whisper_groq_asr,sense_voice_asr]", +] + # -----------------codec------------------ # https://huggingface.co/kyutai/mimi/blob/main/config.json transformers_version codec_transformers_mimi = ["transformers[torch]~=4.45.1"] @@ -357,43 +395,31 @@ tts_zonos_hybrid = [ "mamba-ssm>=2.2.4", "causal-conv1d>=1.5.0.post8", ] +tts_step = [ + "torch==2.3.1", + "torchaudio==2.3.1", + "torchvision==0.18.1", + "transformers==4.48.3", + "accelerate==1.3.0", + "openai-whisper==20231117", + "sox==1.5.0", + "modelscope", + "six==1.16.0", + "hyperpyyaml", + "conformer==0.3.2", + "diffusers", + "onnxruntime-gpu==1.20.1", # cuda 12.5 + "sentencepiece", + "funasr>=1.1.3", + "protobuf==5.29.3", + "achatbot[conf,librosa]", +] # multi tts modules engine speech_tts = [ "achatbot[tts_coqui,tts_edge,tts_g,tts_pyttsx3,tts_cosy_voice,tts_chat,tts_f5,tts_openvoicev2,tts_kokoro]", ] -# voice -llm_transformers_manual_voice = [ - #"transformers@git+https://github.com/huggingface/transformers", - # https://github.com/huggingface/transformers/releases/tag/v4.45.transformers~=4.45.2 - "transformers~=4.45.2", - "torch~=2.2.2", - "torchaudio~=2.2.2", -] -llm_transformers_manual_voice_glm = [ - "achatbot[llm_transformers_manual_voice,tts_cosy_voice,gdown,matplotlib,conf]", -] -llm_transformers_manual_voice_freeze_omni = [ - "achatbot[llm_transformers_manual_voice,librosa,soundfile,yaml]", -] -llm_transformers_manual_speech_llasa = [ - "achatbot[llm_transformers_manual_voice]", -] -llm_transformers_manual_vision_voice_minicpmo = [ - "achatbot[accelerate,librosa,soundfile]", - "torch~=2.2.2", - "torchaudio~=2.2.2", - "torchvision~=0.17.2", - "transformers==4.44.2", - #"librosa==0.9.0", - #"soundfile==0.12.1", - "vector-quantize-pytorch~=1.18.5", - "vocos~=0.1.0", - "decord", - "moviepy", -] - # player module tag -> pkgs stream_player = [] diff --git a/src/cmd/bots/image/storytelling/assets/speakers/TingtingRAP_prompt.wav b/src/cmd/bots/image/storytelling/assets/speakers/TingtingRAP_prompt.wav new file mode 100644 index 00000000..b85907c7 Binary files /dev/null and b/src/cmd/bots/image/storytelling/assets/speakers/TingtingRAP_prompt.wav differ diff --git a/src/cmd/bots/image/storytelling/assets/speakers/Tingting_prompt.wav b/src/cmd/bots/image/storytelling/assets/speakers/Tingting_prompt.wav new file mode 100644 index 00000000..27cbca3b Binary files /dev/null and b/src/cmd/bots/image/storytelling/assets/speakers/Tingting_prompt.wav differ diff --git "a/src/cmd/bots/image/storytelling/assets/speakers/Tingting\345\223\274\345\224\261_prompt.wav" "b/src/cmd/bots/image/storytelling/assets/speakers/Tingting\345\223\274\345\224\261_prompt.wav" new file mode 100644 index 00000000..09eb15a5 Binary files /dev/null and "b/src/cmd/bots/image/storytelling/assets/speakers/Tingting\345\223\274\345\224\261_prompt.wav" differ diff --git a/src/cmd/bots/image/storytelling/assets/speakers/speakers_info.json b/src/cmd/bots/image/storytelling/assets/speakers/speakers_info.json new file mode 100644 index 00000000..0d396475 --- /dev/null +++ b/src/cmd/bots/image/storytelling/assets/speakers/speakers_info.json @@ -0,0 +1,5 @@ +{ + "TingtingRAP": "(RAP)远远甩开的笑他是陆行龟 他曾跌倒也曾吃过灰 他说有福的人才会多吃亏 他的爸爸让他小心交友可他偏偏钻进个垃圾堆 他说他明白How to play", + "Tingting哼唱": "(哼唱)你从一座叫 我 的小镇经过 刚好屋顶的雪化成雨飘落", + "Tingting": "那等我们到海洋馆之后,给妈妈买个礼物,好不好呀?" 
+} diff --git a/src/cmd/grpc/speaker/client.py b/src/cmd/grpc/speaker/client.py index dba42b2a..d19dab8c 100644 --- a/src/cmd/grpc/speaker/client.py +++ b/src/cmd/grpc/speaker/client.py @@ -62,7 +62,16 @@ def load_model(tts_stub: TTSStub): def synthesize_us(tts_stub: TTSStub): - request_data = SynthesizeRequest(tts_text="hello,你好,我是机器人") + tag = os.getenv("TTS_TAG", "tts_edge") + if tag not in TTSEnvInit.map_synthesize_config_func: + logging.warning(f"{tag} not in map_synthesize_config_func, use default config") + kwargs = TTSEnvInit.get_tts_synth_args() + else: + kwargs = TTSEnvInit.map_synthesize_config_func[tag]() + request_data = SynthesizeRequest( + tts_text="hello,你好,我是机器人", json_kwargs=json.dumps(kwargs) + ) + logging.debug(request_data) response_iterator = tts_stub.SynthesizeUS(request_data) for response in response_iterator: yield response.tts_audio @@ -111,7 +120,7 @@ def set_voice(tts_stub: TTSStub, voice: str): IS_RELOAD=1 python -m src.cmd.grpc.speaker.client TTS_TAG=tts_llasa IS_SAVE=1 IS_RELOAD=1 python -m src.cmd.grpc.speaker.client -TTS_TAG=tts_llasa IS_SAVE=1 IS_RELOAD=1 python -m src.cmd.grpc.speaker.client +TTS_TAG=tts_step IS_SAVE=1 IS_RELOAD=1 python -m src.cmd.grpc.speaker.client # instruct2speech TTS_TAG=tts_minicpmo \ @@ -134,6 +143,21 @@ def set_voice(tts_stub: TTSStub, voice: str): SPEAKER_EMBEDDING_MODEL_DIR=./models/Zyphra/Zonos-v0.1-speaker-embedding ZONOS_REF_AUDIO_PATH=./test/audio_files/asr_example_zh.wav \ IS_SAVE=1 IS_RELOAD=1 python -m src.cmd.grpc.speaker.client + +# tts lm gen +TTS_TAG=tts_step IS_SAVE=1 IS_RELOAD=1 \ + TTS_WARMUP_STEPS=2 TTS_LM_MODEL_PATH=./models/stepfun-ai/Step-Audio-TTS-3B \ + TTS_TOKENIZER_MODEL_PATH=./models/stepfun-ai/Step-Audio-Tokenizer \ + python -m src.cmd.grpc.speaker.client +# tts voice clone +TTS_TAG=tts_step IS_SAVE=1 IS_RELOAD=1 \ + TTS_WARMUP_STEPS=2 TTS_LM_MODEL_PATH=/content/models/stepfun-ai/Step-Audio-TTS-3B \ + TTS_TOKENIZER_MODEL_PATH=/content/models/stepfun-ai/Step-Audio-Tokenizer \ + TTS_STREAM_FACTOR=2 \ + TTS_MODE=voice_clone \ + SRC_AUDIO_PATH=./test/audio_files/asr_example_zh.wav \ + python -m src.cmd.grpc.speaker.client + """ if __name__ == "__main__": player = None diff --git a/src/common/interface.py b/src/common/interface.py index e61b4133..e562f02d 100644 --- a/src/common/interface.py +++ b/src/common/interface.py @@ -216,7 +216,7 @@ def get_stream_info(self) -> dict: raise NotImplementedError("must be implemented in the child class") @abstractmethod - def set_voice(self, voice: str): + def set_voice(self, voice: str, **kwargs): """ Note: - just simple voice set, don't support set voice with user id diff --git a/src/common/utils/task.py b/src/common/utils/task.py index cf57c401..21129773 100644 --- a/src/common/utils/task.py +++ b/src/common/utils/task.py @@ -1,5 +1,7 @@ #!/usr/bin/env python from concurrent.futures import ThreadPoolExecutor +import logging +import traceback from typing import Callable, Any import asyncio import queue @@ -17,9 +19,15 @@ async def async_task(sync_func: Callable, *args, **kwargs) -> Any: def fetch_async_items(queue: queue.Queue, asyncFunc, *args) -> None: async def get_items() -> None: - async for item in asyncFunc(*args): - queue.put(item) - queue.put(None) + try: + async for item in asyncFunc(*args): + queue.put(item) + queue.put(None) + except Exception as e: + error_message = traceback.format_exc() + logging.error(f"error:{e} trace: {error_message}") + + queue.put(None) loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) diff --git 
a/src/core/llm/transformers/manual_speech_llasa.py b/src/core/llm/transformers/manual_speech_llasa.py index 00481935..7f17d2ed 100644 --- a/src/core/llm/transformers/manual_speech_llasa.py +++ b/src/core/llm/transformers/manual_speech_llasa.py @@ -1,12 +1,9 @@ import logging -from threading import Thread -from queue import Queue - +from threading import Lock, Thread try: import torch from transformers import AutoTokenizer, AutoModelForCausalLM - from transformers.generation.streamers import BaseStreamer except ModuleNotFoundError as e: logging.error(f"Exception: {e}") logging.error( @@ -14,11 +11,11 @@ ) raise Exception(f"Missing module: {e}") - from src.common.utils.helper import get_device from src.common.session import Session from src.types.llm.transformers import TransformersSpeechLMArgs from .base import TransformersBaseLLM +from .streamer import TokenStreamer def ids_to_speech_tokens(speech_ids): @@ -41,43 +38,6 @@ def extract_speech_ids(speech_tokens_str): return speech_ids -class TokenStreamer(BaseStreamer): - def __init__(self, skip_prompt: bool = False, timeout=None): - self.skip_prompt = skip_prompt - - # variables used in the streaming process - self.token_queue = Queue() - self.stop_signal = None - self.next_tokens_are_prompt = True - self.timeout = timeout - - def put(self, value): - if len(value.shape) > 1 and value.shape[0] > 1: - raise ValueError("TextStreamer only supports batch size 1") - elif len(value.shape) > 1: - value = value[0] - - if self.skip_prompt and self.next_tokens_are_prompt: - self.next_tokens_are_prompt = False - return - - for token in value.tolist(): - self.token_queue.put(token) - - def end(self): - self.token_queue.put(self.stop_signal) - - def __iter__(self): - return self - - def __next__(self): - value = self.token_queue.get(timeout=self.timeout) - if value == self.stop_signal: - raise StopIteration() - else: - return value - - class TransformersManualSpeechLlasa(TransformersBaseLLM): """ TTS: text + ref audio -> llama2 -> vq code tokens @@ -93,7 +53,10 @@ def __init__(self, **args): self._model = AutoModelForCausalLM.from_pretrained(self.args.lm_model_name_or_path) self._model.eval().to(self.args.lm_device) self._tokenizer = AutoTokenizer.from_pretrained(self.args.lm_model_name_or_path) - self._streamer = TokenStreamer(skip_prompt=True) + + # session ctx dict with lock, maybe need a session class + self.session_lm_generat_lock = Lock() + self.session_lm_generated_ids = {} # session_id: ids(ptr) self.warmup() @@ -118,10 +81,11 @@ def warmup(self): input_ids = input_ids.to("cuda") speech_end_id = self._tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>") + streamer = TokenStreamer(skip_prompt=True) warmup_gen_kwargs = dict( input_ids=input_ids, eos_token_id=speech_end_id, - streamer=self._streamer, + streamer=streamer, min_new_tokens=self.args.lm_gen_min_new_tokens, max_new_tokens=self.args.lm_gen_max_new_tokens, top_k=self.args.lm_gen_top_k, @@ -134,7 +98,7 @@ def warmup(self): self._warmup( target=self._model.generate, kwargs=warmup_gen_kwargs, - streamer=self._streamer, + streamer=streamer, ) # @torch.no_grad() @@ -172,10 +136,11 @@ def generate(self, session: Session, **kwargs): ) input_ids = input_ids.to(self.args.lm_device) speech_end_id = self._tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>") + streamer = TokenStreamer(skip_prompt=True) generation_kwargs = dict( input_ids=input_ids, eos_token_id=speech_end_id, - streamer=self._streamer, + streamer=streamer, max_length=2048, # We trained our model with a max length of 2048 
min_new_tokens=kwargs["min_new_tokens"] if "min_new_tokens" in kwargs @@ -197,32 +162,40 @@ def generate(self, session: Session, **kwargs): thread = Thread(target=self._model.generate, kwargs=generation_kwargs) thread.start() - i, j = 0, 0 - generated_ids = [] - for token_id in self._streamer: + session_id = session.ctx.client_id + with self.session_lm_generat_lock: + self.session_lm_generated_ids[session_id] = [] + + for token_id in streamer: # print(token_id, end=",", flush=True) - generated_ids.append(token_id) - i += 1 + self.session_lm_generated_ids[session_id].append(token_id) - if i % self.args.lm_tokenizer_decode_batch_size == 0: + if ( + len(self.session_lm_generated_ids[session_id]) + % self.args.lm_tokenizer_decode_batch_size + == 0 + ): # print(generated_ids) speech_tokens = self._tokenizer.batch_decode( - torch.tensor(generated_ids).to(self.args.lm_device), + torch.tensor(self.session_lm_generated_ids[session_id]).to(self.args.lm_device), skip_special_tokens=True, ) # Convert token <|s_23456|> to int 23456 speech_tokens = extract_speech_ids(speech_tokens) speech_vq_tokens = torch.tensor(speech_tokens).to(self.args.lm_device) yield speech_vq_tokens - generated_ids = [] - j += 1 + with self.session_lm_generat_lock: + self.session_lm_generated_ids[session_id] = [] - if len(generated_ids) > 0: # last batch + if len(self.session_lm_generated_ids[session_id]) > 0: # last batch speech_tokens = self._tokenizer.batch_decode( - torch.tensor(generated_ids).to(self.args.lm_device), + torch.tensor(self.session_lm_generated_ids[session_id]).to(self.args.lm_device), skip_special_tokens=True, ) # Convert token <|s_23456|> to int 23456 speech_tokens = extract_speech_ids(speech_tokens) speech_vq_tokens = torch.tensor(speech_tokens).to(self.args.lm_device) yield speech_vq_tokens + + with self.session_lm_generat_lock: + self.session_lm_generated_ids.pop(session_id) diff --git a/src/core/llm/transformers/manual_speech_step.py b/src/core/llm/transformers/manual_speech_step.py new file mode 100644 index 00000000..4dc0f55c --- /dev/null +++ b/src/core/llm/transformers/manual_speech_step.py @@ -0,0 +1,177 @@ +import logging +import os +import sys +from threading import Thread + +import torch + +try: + from transformers import AutoModelForCausalLM, AutoTokenizer + from transformers.generation.logits_process import LogitsProcessor + from transformers.generation.utils import LogitsProcessorList +except ModuleNotFoundError as e: + logging.error(f"Exception: {e}") + logging.error( + "In order to use Step-Audio-TTS, you need to `pip install achatbot[llm_transformers_manual_speech_step]`. 
" + ) + raise Exception(f"Missing module: {e}") + +from src.common.utils.helper import get_device +from src.common.session import Session +from src.types.llm.transformers import TransformersLMArgs +from .base import TransformersBaseLLM +from .streamer import TokenStreamer + + +class RepetitionAwareLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + window_size = 10 + threshold = 0.1 + + window = input_ids[:, -window_size:] + if window.shape[1] < window_size: + return scores + + last_tokens = window[:, -1].unsqueeze(-1) + repeat_counts = (window == last_tokens).sum(dim=1) + repeat_ratios = repeat_counts.float() / window_size + + mask = repeat_ratios > threshold + scores[mask, last_tokens[mask].squeeze(-1)] = float("-inf") + return scores + + +class TransformersManualSpeechStep(TransformersBaseLLM): + """ + system prompt + (one short: text->speech(audio vq code) prompt) + tts prompt -> tokenizer encode -> token ids -> StepForCausalLM -> audio vq tokens + with TransformersLMArgs + """ + + TAG = "llm_transformers_manual_speech_step" + DEFAULT_SYS_PROMPT = "Convert the text to speech" + + def __init__(self, **args): + self.args = TransformersLMArgs(**args) + self.args.lm_device = self.args.lm_device or get_device() + logging.info("args: %s", self.args) + + if self.args.lm_device_map: + self._model = AutoModelForCausalLM.from_pretrained( + self.args.lm_model_name_or_path, + torch_dtype=self.args.lm_torch_dtype, + attn_implementation=self.args.lm_attn_impl, + #!NOTE: https://github.com/huggingface/transformers/issues/20896 + # device_map for multi cpu/gpu with accelerate + device_map=self.args.lm_device_map, + trust_remote_code=True, + ).eval() + else: + self._model = ( + AutoModelForCausalLM.from_pretrained( + self.args.lm_model_name_or_path, + torch_dtype=self.args.lm_torch_dtype, + attn_implementation=self.args.lm_attn_impl, + trust_remote_code=True, + ) + .eval() + .to(self.args.lm_device) + ) + + self._tokenizer = AutoTokenizer.from_pretrained( + self.args.lm_model_name_or_path, trust_remote_code=True + ) + self.end_token_id = 3 + end_token_ids = self._tokenizer.encode("<|EOT|>") + if len(end_token_ids) >= 1: + self.end_token_id = end_token_ids[-1] + + self.sys_prompt = self.DEFAULT_SYS_PROMPT + + self.warmup() + + def set_system_prompt(self, **kwargs): + # session sys settings + self.sys_prompt = kwargs.get("sys_prompt", self.sys_prompt) + + @torch.inference_mode() + def warmup(self): + if self.args.warmup_steps < 1: + return + logging.info(f"Warming up {self.__class__.__name__} device: {self._model.device}") + dummy_input_text = self.args.warnup_prompt.strip() + prompt = f"<|BOT|> system\n{self.sys_prompt}" + prompt += f"<|EOT|><|BOT|> human\n{dummy_input_text}" + prompt += "<|EOT|><|BOT|> assistant\n" + token_ids = self._tokenizer.encode(prompt) + + logging.debug(f"prompt:{prompt}") + logging.debug(f"token_ids:{token_ids}") + logging.debug(f"args:{self.args}") + logging.debug(f"end_token_id:{self.end_token_id}") + + # inference token streamer + streamer = TokenStreamer(skip_prompt=True) + + warmup_gen_kwargs = dict( + input_ids=torch.tensor([token_ids]).to(torch.long).to(self._model.device), + eos_token_id=self.end_token_id, + streamer=streamer, + min_new_tokens=self.args.lm_gen_min_new_tokens, + max_new_tokens=self.args.lm_gen_max_new_tokens, + do_sample=True if self.args.lm_gen_temperature > 0.0 else False, + top_k=self.args.lm_gen_top_k, + top_p=self.args.lm_gen_top_p, + 
temperature=self.args.lm_gen_temperature, + logits_processor=LogitsProcessorList([RepetitionAwareLogitsProcessor()]), + # repetition_penalty=self.args.lm_gen_repetition_penalty, + ) + + self._warmup( + target=self._model.generate, + kwargs=warmup_gen_kwargs, + streamer=streamer, + ) + + # @torch.no_grad() + @torch.inference_mode() + def generate(self, session: Session, **kwargs): + """ + system prompt + (one short: text->speech(audio code) prompt) + tts prompt -> tokenizer encode -> token ids -> step lm -> audio vq tokens + """ + one_shot_ref_text = session.ctx.state.get("ref_text", "") + one_shot_ref_audio = self._tokenizer.decode(session.ctx.state.get("ref_audio_code", [])) + tts_text = session.ctx.state.get("prompt", "") + prompt = f"<|BOT|> system\n{self.sys_prompt}" + prompt += f"<|EOT|><|BOT|> human\n{one_shot_ref_text}" if one_shot_ref_text else "" + prompt += f"<|EOT|><|BOT|> assistant\n{one_shot_ref_audio}" if one_shot_ref_audio else "" + prompt += f"<|EOT|><|BOT|> human\n{tts_text}" + prompt += "<|EOT|><|BOT|> assistant\n" + token_ids = self._tokenizer.encode(prompt) + logging.debug(f"prompt:{prompt}") + logging.debug(f"token_ids:{token_ids}") + logging.debug(f"args:{self.args}") + logging.debug(f"kwargs:{kwargs}") + logging.debug(f"end_token_id:{self.end_token_id}") + + # inference token streamer + streamer = TokenStreamer(skip_prompt=True) + + # inference token streamer + generation_kwargs = dict( + input_ids=torch.tensor([token_ids]).to(torch.long).to(self._model.device), + eos_token_id=self.end_token_id, + streamer=streamer, + min_new_tokens=kwargs.get("min_new_tokens", self.args.lm_gen_min_new_tokens), + max_new_tokens=kwargs.get("max_new_tokens", self.args.lm_gen_max_new_tokens), + do_sample=True if self.args.lm_gen_temperature > 0.0 else False, + top_k=kwargs.get("top_k", self.args.lm_gen_top_k), + top_p=kwargs.get("top_p", self.args.lm_gen_top_p), + temperature=kwargs.get("temperature", self.args.lm_gen_temperature), + logits_processor=LogitsProcessorList([RepetitionAwareLogitsProcessor()]), + # repetition_penalty=self.args.lm_gen_repetition_penalty, + ) + thread = Thread(target=self._model.generate, kwargs=generation_kwargs) + thread.start() + + for token_id in streamer: + yield token_id diff --git a/src/core/llm/transformers/manual_vision_voice_minicpmo.py b/src/core/llm/transformers/manual_vision_voice_minicpmo.py index 5d054978..91affc23 100644 --- a/src/core/llm/transformers/manual_vision_voice_minicpmo.py +++ b/src/core/llm/transformers/manual_vision_voice_minicpmo.py @@ -175,16 +175,16 @@ def reset_session(self): def set_system_prompt(self, **kwargs): # session sys settings # language - self.language = kwargs.pop("language", self.language) + self.language = kwargs.get("language", self.language) # interation mode # "default": default system prompt and not refer to any task # "omni": input video and audio simultaneously # "audio_assistant": Default voice-only mode, the model will use the ref_audio's voice to reply user's question as a helpful assistant. # "audio_roleplay": Roleplay voice-only mode, the model will use the ref_audio's voice to reply, and also role-play the character based on the audio prompt. # "voice_cloning": TTS mode, the model will clone the voice of ref_audio. 
- self.interaction_mode = kwargs.pop("interaction_mode", self.interaction_mode) + self.interaction_mode = kwargs.get("interaction_mode", self.interaction_mode) # reference audio - ref_audio_path = kwargs.pop("ref_audio_path", None) + ref_audio_path = kwargs.get("ref_audio_path", None) if ref_audio_path is not None: self.ref_audio, _ = librosa.load(ref_audio_path, sr=16000, mono=True) diff --git a/src/core/llm/transformers/manual_voice_glm.py b/src/core/llm/transformers/manual_voice_glm.py index a4cef81e..ea7a0cb1 100644 --- a/src/core/llm/transformers/manual_voice_glm.py +++ b/src/core/llm/transformers/manual_voice_glm.py @@ -6,7 +6,6 @@ try: from transformers import BitsAndBytesConfig, AutoModel, AutoTokenizer - from transformers.generation.streamers import BaseStreamer except ModuleNotFoundError as e: logging.error(f"Exception: {e}") logging.error( @@ -19,43 +18,7 @@ from src.common.session import Session from src.types.llm.transformers import TransformersLMArgs from .base import TransformersBaseLLM - - -class TokenStreamer(BaseStreamer): - def __init__(self, skip_prompt: bool = False, timeout=None): - self.skip_prompt = skip_prompt - - # variables used in the streaming process - self.token_queue = Queue() - self.stop_signal = None - self.next_tokens_are_prompt = True - self.timeout = timeout - - def put(self, value): - if len(value.shape) > 1 and value.shape[0] > 1: - raise ValueError("TextStreamer only supports batch size 1") - elif len(value.shape) > 1: - value = value[0] - - if self.skip_prompt and self.next_tokens_are_prompt: - self.next_tokens_are_prompt = False - return - - for token in value.tolist(): - self.token_queue.put(token) - - def end(self): - self.token_queue.put(self.stop_signal) - - def __iter__(self): - return self - - def __next__(self): - value = self.token_queue.get(timeout=self.timeout) - if value == self.stop_signal: - raise StopIteration() - else: - return value +from .streamer import TokenStreamer class TransformersManualVoicGLM(TransformersBaseLLM): diff --git a/src/core/llm/transformers/streamer.py b/src/core/llm/transformers/streamer.py new file mode 100644 index 00000000..aa375da7 --- /dev/null +++ b/src/core/llm/transformers/streamer.py @@ -0,0 +1,39 @@ +from queue import Queue + +from transformers.generation.streamers import BaseStreamer + +class TokenStreamer(BaseStreamer): + def __init__(self, skip_prompt: bool = False, timeout=None): + self.skip_prompt = skip_prompt + + # variables used in the streaming process + self.token_queue = Queue() + self.stop_signal = None + self.next_tokens_are_prompt = True + self.timeout = timeout + + def put(self, value): + if len(value.shape) > 1 and value.shape[0] > 1: + raise ValueError("TextStreamer only supports batch size 1") + elif len(value.shape) > 1: + value = value[0] + + if self.skip_prompt and self.next_tokens_are_prompt: + self.next_tokens_are_prompt = False + return + + for token in value.tolist(): + self.token_queue.put(token) + + def end(self): + self.token_queue.put(self.stop_signal) + + def __iter__(self): + return self + + def __next__(self): + value = self.token_queue.get(timeout=self.timeout) + if value == self.stop_signal: + raise StopIteration() + else: + return value \ No newline at end of file diff --git a/src/modules/speech/player/stream_player.py b/src/modules/speech/player/stream_player.py index b8fba5fc..75180ae6 100644 --- a/src/modules/speech/player/stream_player.py +++ b/src/modules/speech/player/stream_player.py @@ -255,6 +255,12 @@ class PlayStreamInit: "rate": 44100, "sample_width": 2, 
}, + "tts_step": { + "format": PYAUDIO_PAFLOAT32, + "channels": 1, + "rate": 22050, + "sample_width": 2, + }, "tts_daily_speaker": { "format": PYAUDIO_PAINT16, "channels": 1, diff --git a/src/modules/speech/tts/__init__.py b/src/modules/speech/tts/__init__.py index 725ed97c..cfaf298d 100644 --- a/src/modules/speech/tts/__init__.py +++ b/src/modules/speech/tts/__init__.py @@ -44,6 +44,8 @@ def getEngine(tag, **kwargs) -> interface.ITts | EngineClass: from . import minicpmo_tts elif "tts_zonos" == tag: from . import zonos_tts + elif "tts_step" == tag: + from . import step_tts # elif "tts_openai" in tag: # from . import openai_tts @@ -244,6 +246,32 @@ def get_tts_llasa_args() -> dict: ).__dict__ return kwargs + @staticmethod + def get_tts_step_args() -> dict: + from src.types.speech.tts.step import StepTTSArgs + from src.types.llm.transformers import TransformersSpeechLMArgs + + kwargs = StepTTSArgs( + lm_args=TransformersSpeechLMArgs( + lm_model_name_or_path=os.getenv( + "TTS_LM_MODEL_PATH", os.path.join(MODELS_DIR, "stepfun-ai/Step-Audio-TTS-3B") + ), + lm_device=os.getenv("TTS_LM_DEVICE", None), + warmup_steps=int(os.getenv("TTS_WARMUP_STEPS", "1")), + lm_gen_top_k=int(os.getenv("TTS_LM_GEN_TOP_K", "10")), + lm_gen_top_p=float(os.getenv("TTS_LM_GEN_TOP_P", "1.0")), + lm_gen_temperature=float(os.getenv("TTS_LM_GEN_TEMPERATURE", "0.8")), + lm_gen_repetition_penalty=float(os.getenv("TTS_LM_GEN_REPETITION_PENALTY", "1.1")), + ).__dict__, + stream_factor=int(os.getenv("TTS_STREAM_FACTOR", "2")), + tts_mode=os.getenv("TTS_MODE", "lm_gen"), + speech_tokenizer_model_path=os.getenv( + "TTS_TOKENIZER_MODEL_PATH", + os.path.join(MODELS_DIR, "stepfun-ai/Step-Audio-Tokenizer"), + ), + ).__dict__ + return kwargs + @staticmethod def get_tts_minicpmo_args() -> dict: kwargs = LLMEnvInit.get_llm_transformers_args() @@ -286,4 +314,21 @@ def get_tts_zonos_args() -> dict: "tts_g": get_tts_g_args, "tts_minicpmo": get_tts_minicpmo_args, "tts_zonos": get_tts_zonos_args, + "tts_step": get_tts_step_args, + } + + + @staticmethod + def get_tts_step_synth_args() -> dict: + kwargs = {} + kwargs["src_audio_path"] = os.getenv("SRC_AUDIO_PATH", None) + return kwargs + + @staticmethod + def get_tts_synth_args() -> dict: + kwargs = {} + return kwargs + + map_synthesize_config_func = { + "tts_step": get_tts_step_synth_args, } diff --git a/src/modules/speech/tts/step_tts.py b/src/modules/speech/tts/step_tts.py new file mode 100644 index 00000000..f2b54399 --- /dev/null +++ b/src/modules/speech/tts/step_tts.py @@ -0,0 +1,347 @@ +import json +import logging +import math +from pathlib import Path +import re +import sys +from threading import Lock +from typing import AsyncGenerator +import os +import io + +import numpy as np +import torch +import torchaudio + + +try: + cur_dir = os.path.dirname(__file__) + if bool(os.getenv("ACHATBOT_PKG", "")): + sys.path.insert(1, os.path.join(cur_dir, "../../../StepAudio")) + else: + sys.path.insert(1, os.path.join(cur_dir, "../../../../deps/StepAudio")) + from src.core.llm.transformers.manual_speech_step import TransformersManualSpeechStep + from deps.StepAudio.tokenizer import StepAudioTokenizer + from deps.StepAudio.cosyvoice.cli.cosyvoice import CosyVoice +except ModuleNotFoundError as e: + logging.error("In order to use step-tts, you need to `pip install achatbot[tts_step]`.") + raise Exception(f"Missing module: {e}") + +from src.common.random import set_all_random_seed +from src.common.types import PYAUDIO_PAFLOAT32, ASSETS_DIR +from src.common.interface import ITts +from 
src.common.session import Session +from src.types.speech.tts.step import StepTTSArgs +from src.types.llm.transformers import TransformersSpeechLMArgs +from .base import BaseTTS + + +class StepTTS(BaseTTS, ITts): + """ + support tts mode: + - lm_gen: text+ref audio waveform lm gen audio wav code to gen waveform with static batch stream: + text+ref audio waveform -> tokenizer -> text+audio token ids -> step1 lm -> audio token ids (wav_code) -> flow(CFM) -> mel - vocoder(hifi) -> waveform + - voice_clone: voice clone w/o lm gen, decode wav code: + src+ref audio waveform -> speech tokenizer-> audio token ids (wav_code) -> flow(CFM) -> mel - vocoder(hifi) -> clone ref audio waveform + """ + + TAG = "tts_step" + + def __init__(self, **kwargs) -> None: + self.args = StepTTSArgs(**kwargs) + assert ( + self.args.stream_factor >= 2 + ), "stream_factor must >=2 increase for better speech quality, but rtf slow (speech quality vs rtf)" + + self.encoder = StepAudioTokenizer(self.args.speech_tokenizer_model_path) + + self.lm_args = TransformersSpeechLMArgs(**self.args.lm_args) + self.lm_model = TransformersManualSpeechStep(**self.lm_args.__dict__) + # session ctx dict with lock, maybe need a session class + self.session_lm_generat_lock = Lock() + self.session_lm_generated_ids = {} # session_id: ids(ptr) + + self.common_cosy_model = CosyVoice( + os.path.join(self.lm_args.lm_model_name_or_path, "CosyVoice-300M-25Hz") + ) + self.music_cosy_model = CosyVoice( + os.path.join(self.lm_args.lm_model_name_or_path, "CosyVoice-300M-25Hz-Music") + ) + + self.sys_prompt_dict = { + "sys_prompt_for_rap": "请参考对话历史里的音色,用RAP方式将文本内容大声说唱出来。", + "sys_prompt_for_vocal": "请参考对话历史里的音色,用哼唱的方式将文本内容大声唱出来。", + "sys_prompt_wo_spk": '作为一名卓越的声优演员,你的任务是根据文本中()或()括号内标注的情感、语种或方言、音乐哼唱、语音调整等标签,以丰富细腻的情感和自然顺畅的语调来朗读文本。\n# 情感标签涵盖了多种情绪状态,包括但不限于:\n- "高兴1"\n- "高兴2"\n- "生气1"\n- "生气2"\n- "悲伤1"\n- "撒娇1"\n\n# 语种或方言标签包含多种语言或方言,包括但不限于:\n- "中文"\n- "英文"\n- "韩语"\n- "日语"\n- "四川话"\n- "粤语"\n- "广东话"\n\n# 音乐哼唱标签包含多种类型歌曲哼唱,包括但不限于:\n- "RAP"\n- "哼唱"\n\n# 语音调整标签,包括但不限于:\n- "慢速1"\n- "慢速2"\n- "快速1"\n- "快速2"\n\n请在朗读时,根据这些情感标签的指示,调整你的情感、语气、语调和哼唱节奏,以确保文本的情感和意义得到准确而生动的传达,如果没有()或()括号,则根据文本语义内容自由演绎。', + "sys_prompt_with_spk": '作为一名卓越的声优演员,你的任务是根据文本中()或()括号内标注的情感、语种或方言、音乐哼唱、语音调整等标签,以丰富细腻的情感和自然顺畅的语调来朗读文本。\n# 情感标签涵盖了多种情绪状态,包括但不限于:\n- "高兴1"\n- "高兴2"\n- "生气1"\n- "生气2"\n- "悲伤1"\n- "撒娇1"\n\n# 语种或方言标签包含多种语言或方言,包括但不限于:\n- "中文"\n- "英文"\n- "韩语"\n- "日语"\n- "四川话"\n- "粤语"\n- "广东话"\n\n# 音乐哼唱标签包含多种类型歌曲哼唱,包括但不限于:\n- "RAP"\n- "哼唱"\n\n# 语音调整标签,包括但不限于:\n- "慢速1"\n- "慢速2"\n- "快速1"\n- "快速2"\n\n请在朗读时,使用[{}]的声音,根据这些情感标签的指示,调整你的情感、语气、语调和哼唱节奏,以确保文本的情感和意义得到准确而生动的传达,如果没有()或()括号,则根据文本语义内容自由演绎。', + } + + self.speakers_info = {} + self.register_speakers() + + # lm model gen warmup, codec model decode(flow + hifi) don't to warmup + + def register_speakers(self): + self.speakers_info = {} + + speackers_info_path = os.path.join(ASSETS_DIR, "speakers/speakers_info.json") + with open(speackers_info_path, "r") as f: + speakers_info = json.load(f) + + for speaker_id, prompt_text in speakers_info.items(): + prompt_wav_path = os.path.join(ASSETS_DIR, f"speakers/{speaker_id}_prompt.wav") + ( + ref_audio_code, + ref_audio_token, + ref_audio_token_len, + ref_speech_feat, + ref_speech_feat_len, + ref_speech_embedding, + ) = self.preprocess_prompt_wav(prompt_wav_path) + + self.speakers_info[speaker_id] = { + "ref_text": prompt_text, + "ref_audio_code": ref_audio_code, + "ref_speech_feat": ref_speech_feat.to(torch.bfloat16), + "ref_speech_feat_len": ref_speech_feat_len, + "ref_speech_embedding": 
ref_speech_embedding.to(torch.bfloat16), + "ref_audio_token": ref_audio_token, + "ref_audio_token_len": ref_audio_token_len, + } + logging.info(f"Registered speaker: {speaker_id}") + + def wav2code(self, prompt_wav_path: str): + prompt_wav, prompt_wav_sr = torchaudio.load(prompt_wav_path) + if prompt_wav.shape[0] > 1: + prompt_wav = prompt_wav.mean(dim=0, keepdim=True) # multi-channel to mono + prompt_code, _, _ = self.encoder.wav2token(prompt_wav, prompt_wav_sr) + return prompt_code + + def preprocess_prompt_wav(self, prompt_wav_path: str): + prompt_wav, prompt_wav_sr = torchaudio.load(prompt_wav_path) + if prompt_wav.shape[0] > 1: + prompt_wav = prompt_wav.mean(dim=0, keepdim=True) # multi-channel to mono + prompt_wav_16k = torchaudio.transforms.Resample(orig_freq=prompt_wav_sr, new_freq=16000)( + prompt_wav + ) + prompt_wav_22k = torchaudio.transforms.Resample(orig_freq=prompt_wav_sr, new_freq=22050)( + prompt_wav + ) + + speech_feat, speech_feat_len = self.common_cosy_model.frontend._extract_speech_feat( + prompt_wav_22k + ) + speech_embedding = self.common_cosy_model.frontend._extract_spk_embedding(prompt_wav_16k) + + prompt_code, _, _ = self.encoder.wav2token(prompt_wav, prompt_wav_sr) + prompt_token = torch.tensor([prompt_code], dtype=torch.long) - 65536 + prompt_token_len = torch.tensor([prompt_token.shape[1]], dtype=torch.long) + + return ( + prompt_code, + prompt_token, + prompt_token_len, + speech_feat, + speech_feat_len, + speech_embedding, + ) + + def set_voice(self, ref_audio_path: str, **kwargs): + """ + - save to speacker info dict + TODO: save dict to dist kv store + """ + assert os.path.exists(ref_audio_path), "ref_audio_path is not exists" + assert kwargs.get( + "ref_speaker", None + ), "ref_speaker is not exists" # maybe use random speaker + assert kwargs.get( + "ref_text", None + ), "ref_text is not exists" # maybe use asr to get ref_text + ref_speaker = kwargs.get("ref_speaker") + ref_text = kwargs.get("ref_text") + + ( + ref_audio_code, + ref_audio_token, + ref_audio_token_len, + ref_speech_feat, + ref_speech_feat_len, + ref_speech_embedding, + ) = self.preprocess_prompt_wav(ref_audio_path) + + self.speakers_info[ref_speaker] = { + "ref_text": ref_text, + "ref_audio_code": ref_audio_code, + "ref_speech_feat": ref_speech_feat.to(torch.bfloat16), + "ref_speech_feat_len": ref_speech_feat_len, + "ref_speech_embedding": ref_speech_embedding.to(torch.bfloat16), + "ref_audio_token": ref_audio_token, + "ref_audio_token_len": ref_audio_token_len, + } + + def get_voices(self) -> list: + return list(self.speakers_info.keys()) + + def get_stream_info(self) -> dict: + return { + # "format": PYAUDIO_PAINT16, + "format": PYAUDIO_PAFLOAT32, + "channels": 1, + "rate": 22050, + "sample_width": 2, + # "np_dtype": np.int16, + "np_dtype": np.float32, + } + + def detect_instruction_name(self, text): + instruction_name = "" + match_group = re.match(r"^([(\(][^\(\)()]*[)\)]).*$", text, re.DOTALL) + if match_group is not None: + instruction = match_group.group(1) + instruction_name = instruction.strip("()()") + return instruction_name + + def set_system_prompt(self, text, ref_speaker: str = "Tingting"): + sys_prompt = self.sys_prompt_dict["sys_prompt_wo_spk"] + instruction_name = self.detect_instruction_name(text) + if instruction_name: + if "哼唱" in text: + sys_prompt = self.sys_prompt_dict["sys_prompt_for_vocal"] + else: + sys_prompt = self.sys_prompt_dict["sys_prompt_for_rap"] + elif ref_speaker: + sys_prompt = self.sys_prompt_dict["sys_prompt_with_spk"].format(ref_speaker) + 
self.lm_model.set_system_prompt(sys_prompt=sys_prompt) + + def voice_clone(self, session: Session, ref_speaker: str, cosy_model, **kwargs): + """ + - voice_clone: voice clone w/o lm gen, decode wav code: + src+ref audio waveform -> speech tokenizer-> audio token ids (wav_code) -> flow(CFM) -> mel - vocoder(hifi) -> clone ref audio waveform + """ + src_audio_path = kwargs.get("src_audio_path", None) + if not src_audio_path or not os.path.exists(src_audio_path): + logging.error(f"{src_audio_path} is not exists") + return None + + src_audio_code = self.wav2code(src_audio_path) + tensor_audio_token_ids = torch.tensor([src_audio_code]).to(torch.long).to("cuda") - 65536 + tts_speech = cosy_model.token_to_wav_offline( + tensor_audio_token_ids, + self.speakers_info[ref_speaker]["ref_speech_feat"].to(torch.bfloat16), + self.speakers_info[ref_speaker]["ref_speech_feat_len"], + self.speakers_info[ref_speaker]["ref_audio_token"], + self.speakers_info[ref_speaker]["ref_audio_token_len"], + self.speakers_info[ref_speaker]["ref_speech_embedding"].to(torch.bfloat16), + ) + return tts_speech.float().numpy().tobytes() + + async def lm_gen( + self, + session: Session, + text: str, + ref_speaker: str, + batch_size: int, + cosy_model, + **kwargs, + ) -> AsyncGenerator[bytes, None]: + """ + - lm_gen: text+ref audio waveform lm gen audio wav code to gen waveform with static batch stream: + text+ref audio waveform -> tokenizer -> text+audio token ids -> step1 lm -> audio token ids (wav_code) -> flow(CFM) -> mel - vocoder(hifi) -> waveform + """ + session_id = session.ctx.client_id + + self.set_system_prompt(text, ref_speaker=ref_speaker) + session.ctx.state["ref_text"] = self.speakers_info[ref_speaker]["ref_text"] + session.ctx.state["ref_audio_code"] = self.speakers_info[ref_speaker]["ref_audio_code"] + session.ctx.state["prompt"] = text + audio_vq_tokens = self.lm_model.generate(session, **kwargs) + for token_id in audio_vq_tokens: + # print(token_id, end=",", flush=True) + if token_id == self.lm_model.end_token_id: # skip <|EOT|>, break + break + self.session_lm_generated_ids[session_id].append(token_id) + if len(self.session_lm_generated_ids[session_id]) % batch_size == 0: + batch = ( + torch.tensor(self.session_lm_generated_ids[session_id]) + .unsqueeze(0) + .to(cosy_model.model.device) + - 65536 + ) # [T] -> [1,T] + logging.debug(f"batch: {batch}") + # Process each batch + sub_tts_speech = cosy_model.token_to_wav_offline( + batch, + self.speakers_info[ref_speaker]["ref_speech_feat"].to(torch.bfloat16), + self.speakers_info[ref_speaker]["ref_speech_feat_len"], + self.speakers_info[ref_speaker]["ref_audio_token"], + self.speakers_info[ref_speaker]["ref_audio_token_len"], + self.speakers_info[ref_speaker]["ref_speech_embedding"].to(torch.bfloat16), + ) + yield sub_tts_speech.float().numpy().tobytes() + with self.session_lm_generat_lock: + self.session_lm_generated_ids[session_id] = [] + + if len(self.session_lm_generated_ids[session_id]) > 0: + batch = ( + torch.tensor(self.session_lm_generated_ids[session_id]) + .unsqueeze(0) + .to(cosy_model.model.device) + - 65536 + ) # [T] -> [1,T] + logging.debug(f"batch: {batch}") + # Process each batch + sub_tts_speech = cosy_model.token_to_wav_offline( + batch, + self.speakers_info[ref_speaker]["ref_speech_feat"].to(torch.bfloat16), + self.speakers_info[ref_speaker]["ref_speech_feat_len"], + self.speakers_info[ref_speaker]["ref_audio_token"], + self.speakers_info[ref_speaker]["ref_audio_token_len"], + 
self.speakers_info[ref_speaker]["ref_speech_embedding"].to(torch.bfloat16), + ) + yield sub_tts_speech.float().numpy().tobytes() + + async def _inference( + self, session: Session, text: str, **kwargs + ) -> AsyncGenerator[bytes, None]: + if "cuda" in str(self.lm_model._model.device): + torch.cuda.empty_cache() + seed = kwargs.get("seed", self.lm_args.lm_gen_seed) + set_all_random_seed(seed) + + ref_speaker = kwargs.pop("ref_speaker", "Tingting") + instruction_name = self.detect_instruction_name(text) + cosy_model = self.common_cosy_model + if instruction_name in ["RAP", "哼唱"]: + cosy_model = self.music_cosy_model + ref_speaker = f"{ref_speaker}{instruction_name}" + if ref_speaker not in self.speakers_info: + ref_speaker = f"Tingting{instruction_name}" + if ref_speaker and ref_speaker not in self.speakers_info: + ref_speaker = "Tingting" + logging.debug(f"use ref_speaker: {ref_speaker}") + + assert ( + kwargs.get("stream_factor", self.args.stream_factor) >= 2 + ), "stream_factor must >=2 increase for better speech quality, but rtf slow (speech quality vs rtf)" + batch_size = math.ceil( + kwargs.get("stream_factor", self.args.stream_factor) + * cosy_model.model.flow.input_frame_rate + ) + + session_id = session.ctx.client_id + with self.session_lm_generat_lock: + self.session_lm_generated_ids[session_id] = [] + + tts_mode = kwargs.get("tts_mode", self.args.tts_mode) + if tts_mode == "voice_clone": + tts_speech = self.voice_clone(session, ref_speaker, cosy_model, **kwargs) + yield tts_speech + else: # lm_gen + async for item in self.lm_gen( + session, text, ref_speaker, batch_size, cosy_model, **kwargs + ): + yield item + + with self.session_lm_generat_lock: + self.session_lm_generated_ids.pop(session_id) + torch.cuda.empty_cache() diff --git a/src/types/speech/tts/step.py b/src/types/speech/tts/step.py new file mode 100644 index 00000000..867ace2d --- /dev/null +++ b/src/types/speech/tts/step.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass, field +import os + +from src.common.types import MODELS_DIR +from src.types.llm.transformers import TransformersSpeechLMArgs + + +@dataclass +class StepTTSArgs: + """ + TransformersManualSpeechLlasa LM + Linguistic and Semantic Tokenizer(ref audio encoder) -> Step-Audio TTS + """ + + lm_args: dict = field(default_factory=TransformersSpeechLMArgs().__dict__) + # >=2 increase for better speech quality, but rtf slow (speech quality vs rtf) + stream_factor: int = 2 + + tts_mode: str = "lm_gen" # lm_gen(lm_gen->flow->hifi), voice_clone(no lm_gen, flow->hifi) + + # Linguistic and Semantic speech Tokenizer(ref audio encoder) args + speech_tokenizer_model_path: str = os.path.join(MODELS_DIR, "stepfun-ai/Step-Audio-Tokenizer") diff --git a/test/modules/speech/tts/test_step.py b/test/modules/speech/tts/test_step.py new file mode 100644 index 00000000..e130905c --- /dev/null +++ b/test/modules/speech/tts/test_step.py @@ -0,0 +1,184 @@ +import os + +import numpy as np +import soundfile +import unittest + +from src.common.interface import ITts +from src.modules.speech.tts import TTSEnvInit +from src.common.logger import Logger +from src.common.session import Session +from src.common.types import RECORDS_DIR, SessionCtx, TEST_DIR + +r""" +# ---- TTS_MODE: voice_clone ---- + +python -m unittest test.modules.speech.tts.test_step.TestStepTTS.test_get_voices +REF_AUDIO_PATH=./test/audio_files/asr_example_zh.wav \ + REF_TEXT="欢迎大家来体验达摩院推出的语音识别模型" \ + python -m unittest test.modules.speech.tts.test_step.TestStepTTS.test_set_voice + +python -m unittest 
test.modules.speech.tts.test_step.TestStepTTS.test_synthesize +python -m unittest test.modules.speech.tts.test_step.TestStepTTS.test_synthesize_speak + +# ref audio +TTS_STREAM_FACTOR=4 \ + REF_AUDIO_PATH=./test/audio_files/asr_example_zh.wav \ + REF_TEXT="欢迎大家来体验达摩院推出的语音识别模型" \ + TTS_TEXT="万物之始,大道至简,衍化至繁。君不见黄河之水天上来,奔流到海不复回。君不见高堂明镜悲白发,朝如青丝暮成雪。人生得意须尽欢,莫使金樽空对月。天生我材必有用,千金散尽还复来。" \ + python -m unittest test.modules.speech.tts.test_step.TestStepTTS.test_synthesize + +TTS_STREAM_FACTOR=4 \ +REF_AUDIO_PATH=./test/audio_files/asr_example_zh.wav \ + REF_TEXT="欢迎大家来体验达摩院推出的语音识别模型" \ + TTS_TEXT="万物之始,大道至简,衍化至繁。君不见黄河之水天上来,奔流到海不复回。君不见高堂明镜悲白发,朝如青丝暮成雪。人生得意须尽欢,莫使金樽空对月。天生我材必有用,千金散尽还复来。" \ + python -m unittest test.modules.speech.tts.test_step.TestStepTTS.test_synthesize_speak + +# ---- TTS_MODE: voice_clone ---- +# src audio + default ref audio +SRC_AUDIO_PATH=./test/audio_files/asr_example_zh.wav \ + python -m unittest test.modules.speech.tts.test_step.TestStepTTS.test_synthesize + +""" + + +class TestStepTTS(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.tts_tag = os.getenv("TTS_TAG", "tts_step") + cls.src_audio_path = os.getenv( + "SRC_AUDIO_PATH", + # os.path.join(TEST_DIR, "audio_files/asr_example_zh.wav"), + "", + ) + cls.ref_audio_path = os.getenv( + "REF_AUDIO_PATH", + # os.path.join(TEST_DIR, "audio_files/asr_example_zh.wav"), + "", + ) + cls.ref_text = os.getenv( + "REF_TEXT", + # "欢迎大家来体验达摩院推出的语音识别模型", + "", + ) + cls.tts_text = os.getenv( + "TTS_TEXT", + "你好,hello.", + ) + + Logger.init(os.getenv("LOG_LEVEL", "debug").upper(), is_file=False) + + @classmethod + def tearDownClass(cls): + pass + + def setUp(self): + self.tts: ITts = TTSEnvInit.initTTSEngine(self.tts_tag) + self.session = Session(**SessionCtx("test_tts_client_id").__dict__) + self.pyaudio_instance = None + self.audio_stream = None + + def tearDown(self): + self.audio_stream and self.audio_stream.stop_stream() + self.audio_stream and self.audio_stream.close() + self.pyaudio_instance and self.pyaudio_instance.terminate() + + def test_get_voices(self): + voices = self.tts.get_voices() + self.assertGreaterEqual(len(voices), 0) + print(voices) + + def test_set_voice(self): + voices = self.tts.get_voices() + self.assertGreaterEqual(len(voices), 0) + print(voices) + + self.tts.set_voice( + self.ref_audio_path, + ref_speaker="test_speaker", + ref_text=self.ref_text, + ) + add_voices = self.tts.get_voices() + self.assertEqual(len(add_voices), len(voices) + 1) + print(add_voices) + + def test_synthesize(self): + ref_speaker = "" + if os.path.exists(self.ref_audio_path): + ref_speaker = "test_speaker" + self.tts.set_voice( + self.ref_audio_path, + ref_speaker=ref_speaker, + ref_text=self.ref_text, + ) + self.session.ctx.state["ref_speaker"] = ref_speaker + else: + voices = self.tts.get_voices() + self.assertGreaterEqual(len(voices), 0) + print(f"use default voices: {voices}") + + if os.path.exists(self.src_audio_path): + self.session.ctx.state["src_audio_path"] = self.src_audio_path + + self.session.ctx.state["tts_text"] = self.tts_text + print(self.session.ctx) + iter = self.tts.synthesize_sync(self.session) + res = bytearray() + for i, chunk in enumerate(iter): + print(i, len(chunk)) + res.extend(chunk) + + stream_info = self.tts.get_stream_info() + print(f"stream_info:{stream_info}") + + file_name = f"test_{self.tts.TAG}.wav" + os.makedirs(RECORDS_DIR, exist_ok=True) + file_path = os.path.join(RECORDS_DIR, file_name) + data = np.frombuffer(res, dtype=stream_info["np_dtype"]) + soundfile.write(file_path, data, 
stream_info["rate"]) + + print(file_path) + + def test_synthesize_speak(self): + import pyaudio + + stream_info = self.tts.get_stream_info() + self.pyaudio_instance = pyaudio.PyAudio() + self.audio_stream = self.pyaudio_instance.open( + format=stream_info["format"], + channels=stream_info["channels"], + rate=stream_info["rate"], + output_device_index=None, + output=True, + ) + + ref_speaker = "" + if os.path.exists(self.ref_audio_path): + ref_speaker = "test_speaker" + self.tts.set_voice( + self.ref_audio_path, + ref_speaker=ref_speaker, + ref_text=self.ref_text, + ) + self.session.ctx.state["ref_speaker"] = ref_speaker + else: + voices = self.tts.get_voices() + self.assertGreaterEqual(len(voices), 0) + print(f"use default voices: {voices}") + + if os.path.exists(self.src_audio_path): + self.session.ctx.state["src_audio_path"] = self.src_audio_path + + self.session.ctx.state["tts_text"] = self.tts_text + print(self.session.ctx) + iter = self.tts.synthesize_sync(self.session) + sub_chunk_size = 1024 + for i, chunk in enumerate(iter): + print(f"get {i} chunk {len(chunk)}") + self.assertGreaterEqual(len(chunk), 0) + if len(chunk) / sub_chunk_size < 100: + self.audio_stream.write(chunk) + continue + for i in range(0, len(chunk), sub_chunk_size): + sub_chunk = chunk[i : i + sub_chunk_size] + self.audio_stream.write(sub_chunk)