diff --git a/.env.example b/.env.example index 2762a942..e1fc9593 100644 --- a/.env.example +++ b/.env.example @@ -66,4 +66,14 @@ METERED_TURN_CREDENTIAL= # https://developers.cloudflare.com/realtime/turn/ CLOUDFLARE_TURN_TOKEN= -CLOUDFLARE_TURN_API_TOKEN= \ No newline at end of file +CLOUDFLARE_TURN_API_TOKEN= + +# d1 +CLOUDFLARE_ACCOUNT_ID= +CLOUDFLARE_API_KEY= +PODCAST_D1_DB_ID= +# r2 +CLOUDFLARE_ACCESS_KEY= +CLOUDFLARE_SECRET_KEY= +CLOUDFLARE_REGION=apac +S3_BUCKET_URL= \ No newline at end of file diff --git a/.gitmodules b/.gitmodules index 4444ddbc..837756f4 100644 --- a/.gitmodules +++ b/.gitmodules @@ -78,3 +78,6 @@ [submodule "deps/HiggsAudio"] path = deps/HiggsAudio url = https://github.com/weedge/higgs-audio.git +[submodule "deps/StepAudio2"] + path = deps/StepAudio2 + url = https://github.com/weedge/Step-Audio2.git diff --git a/demo/cloudflare/rest_api.py b/demo/cloudflare/rest_api.py index c3edb93d..c752a3b5 100644 --- a/demo/cloudflare/rest_api.py +++ b/demo/cloudflare/rest_api.py @@ -42,7 +42,7 @@ def d1_table_query(db_id: str, sql: str, sql_params: List[str] = []) -> dict: data = res.read().decode("utf-8") # print(data) json_data = json.loads(data) - # logging.info(f"body:{body}, db_id:{db_id}, query res:{json_data}") + logging.debug(f"body:{body}, db_id:{db_id}, query res:{json_data}") return json_data @@ -85,7 +85,7 @@ def d1_db(db_id: str) -> dict: """ if __name__ == "__main__": logging.basicConfig( - level=logging.INFO, + level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(pathname)s:%(lineno)d - %(funcName)s - %(message)s", handlers=[logging.StreamHandler()], ) diff --git a/demo/content_parser_tts.py b/demo/content_parser_tts.py index 57acf190..345f320a 100644 --- a/demo/content_parser_tts.py +++ b/demo/content_parser_tts.py @@ -67,7 +67,15 @@ async def gen_podcast_tts_audios( podcast_index, role_index = (0, 0) pre_role = "" pre_cn, cur_cn = (0, 0) + title = "" + description = "" for extraction in data_models: + if title == "" and extraction.description: + title = extraction.title + print(f"title:{title}\n") + if description == "" and extraction.roles: + description = extraction.description + print(f"description:{description}\n") if not extraction.roles: continue p_save_dir = os.path.join(save_dir, str(podcast_index)) @@ -111,6 +119,7 @@ async def gen_podcast_tts_audios( role_index += 1 await edge_tts_conversion(role.content, output_file, voice) + # print(extraction) return extraction diff --git a/demo/gen_podcast.py b/demo/gen_podcast.py index 4158a202..90013dd9 100644 --- a/demo/gen_podcast.py +++ b/demo/gen_podcast.py @@ -64,6 +64,7 @@ def run( save_dir=save_dir, ) for data in data_list: + print(data) source = data[0] extraction: podcast.Podcast = data[1] audio_output_file = data[2] diff --git a/demo/insert_podcast.py b/demo/insert_podcast.py index ad9c87a6..3965398f 100644 --- a/demo/insert_podcast.py +++ b/demo/insert_podcast.py @@ -126,6 +126,7 @@ def get_podcast( raise # 如果达到最大重试次数,抛出异常 gen_img_prompt = f"podcast cover image which content is about {en_title}" + print(f"{gen_img_prompt}") img_file = save_gen_image(gen_img_prompt, uuid.uuid4().hex) cover_img_url = r2_upload("podcast", img_file) @@ -201,11 +202,20 @@ def insert_podcast_to_d1( formatted_time, podcast.audio_size, ] + # ==================== + # debug_sql = sql.replace("?", "{}").format( + # *[f"'{p}'" if isinstance(p, str) else str(p) for p in sql_params] + # ) + # print(f"Debug SQL: {debug_sql}") + # ==================== + res = d1_table_query(db_id, sql, sql_params) if 
res["success"] is True: logging.info( f"insert podcast success, url: https://podcast-997.pages.dev/podcast/{podcast.pid}" ) + else: + logging.error(f"insert podcast failed, res: {res}") return res["success"] @@ -224,6 +234,7 @@ def update_podcast_cover_to_d1( pid, ] res = d1_table_query(db_id, sql, sql_params) + print(res) return res["success"] diff --git a/deploy/modal/src/fastapi_webrtc_step2_voice_bot_serve.py b/deploy/modal/src/fastapi_webrtc_step2_voice_bot_serve.py new file mode 100644 index 00000000..97b0fdb5 --- /dev/null +++ b/deploy/modal/src/fastapi_webrtc_step2_voice_bot_serve.py @@ -0,0 +1,157 @@ +import os + +import modal + + +achatbot_version = os.getenv("ACHATBOT_VERSION", "0.0.25") +app = modal.App("step-audio2-voice-bot") +# fastapi_webrtc_bots | fastapi_webrtc_single_bot server +SERVER_TAG = os.getenv("SERVER_TAG", "fastapi_webrtc_bots") +IMAGE_GPU = os.getenv("IMAGE_GPU", "L4") +img = ( + # https://catalog.ngc.nvidia.com/orgs/nvidia/containers/cuda/tags + modal.Image.from_registry( + "nvcr.io/nvidia/cuda:12.9.1-cudnn-devel-ubuntu22.04", + add_python="3.10", + ) + .apt_install("git", "git-lfs", "ffmpeg") + .pip_install( + [ + "achatbot[" + "fastapi_bot_server," + "livekit,livekit-api,daily,agora," + "silero_vad_analyzer," + "sense_voice_asr,deepgram_asr_processor," + "tts_edge," + "queue" + f"]~={achatbot_version}", + ], + extra_index_url=os.getenv("EXTRA_INDEX_URL", "https://pypi.org/simple/"), + ) + .pip_install( + "transformers==4.49.0", + "torchaudio", + "librosa", + "onnxruntime", + "s3tokenizer", + "diffusers", + "hyperpyyaml", + "huggingface_hub", + ) + .env( + { + "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", + "ACHATBOT_PKG": "1", + "SERVER_TAG": SERVER_TAG, + "CONFIG_FILE": os.getenv( + "CONFIG_FILE", + "/root/.achatbot/config/bots/daily_step_audio2_aqaa_bot.json", + ), + } + ) +) + +# img = img.pip_install( +# f"achatbot==0.0.25.dev57", +# extra_index_url=os.getenv("EXTRA_INDEX_URL", "https://test.pypi.org/simple/"), +# ) + + +HF_MODEL_DIR = "/root/.achatbot/models" +hf_model_vol = modal.Volume.from_name("models", create_if_missing=True) +ASSETS_DIR = "/root/.achatbot/assets" +assets_vol = modal.Volume.from_name("assets", create_if_missing=True) +CONFIG_DIR = "/root/.achatbot/config" +config_vol = modal.Volume.from_name("config", create_if_missing=True) +RECORDS_DIR = "/root/.achatbot/records" +records_vol = modal.Volume.from_name("records", create_if_missing=True) + +TORCH_CACHE_DIR = "/root/.cache/torch" +torch_cache_vol = modal.Volume.from_name("torch_cache", create_if_missing=True) + + +# 128 MiB of memory and 0.125 CPU cores by default container runtime +@app.cls( + image=img, + gpu=os.getenv("IMAGE_GPU", None), + secrets=[modal.Secret.from_name("achatbot")], + volumes={ + HF_MODEL_DIR: hf_model_vol, + ASSETS_DIR: assets_vol, + CONFIG_DIR: config_vol, + RECORDS_DIR: records_vol, + TORCH_CACHE_DIR: torch_cache_vol, + }, + cpu=2.0, + timeout=1200, # default 300s + scaledown_window=1200, + max_containers=1, + # allow_concurrent_inputs=int(os.getenv("IMAGE_CONCURRENT_CN", "1")), +) +@modal.concurrent(max_inputs=int(os.getenv("IMAGE_CONCURRENT_CN", "1"))) # inputs per container +class Srv: + @modal.enter() + def enter(self): + # run container runtime to enter when container is starting + import subprocess + import torch + + subprocess.run("nvidia-smi --version", shell=True) + gpu_prop = None + if torch.cuda.is_available(): + gpu_prop = torch.cuda.get_device_properties("cuda:0") + print(gpu_prop) + torch.multiprocessing.set_start_method("spawn", 
force=True) + else: + print("CUDA is not available.") + + @modal.asgi_app() + def app(self): + SERVER_TAG = os.getenv("SERVER_TAG", "fastapi_webrtc_bots") + if SERVER_TAG == "fastapi_webrtc_single_bot": + from achatbot.cmd.http.server.fastapi_room_bot_serve import app as fastapi_app + + print("run fastapi_room_bot_serve(single bot)") + else: + from achatbot.cmd.http.server.fastapi_daily_bot_serve import app as fastapi_app + + print("run fastapi_daily_bot_serve(multi bots)") + + return fastapi_app + + +""" +# 0. download models and assets +modal run src/download_models.py --repo-ids "stepfun-ai/Step-Audio-2-mini" +modal run src/download_assets.py --asset-urls "https://raw.githubusercontent.com/stepfun-ai/Step-Audio2/refs/heads/main/assets/default_male.wav" +modal run src/download_assets.py --asset-urls "https://raw.githubusercontent.com/stepfun-ai/Step-Audio2/refs/heads/main/assets/default_female.wav" + +# 1. run webrtc room http bots server + +IMAGE_GPU=L4 SERVER_TAG=fastapi_webrtc_bots \ + ACHATBOT_VERSION=0.0.25 \ + modal serve src/fastapi_webrtc_step2_voice_bot_serve.py + +# 2. run webrtc room http signal bot server + +modal volume create config + +modal volume put config ./config/bots/daily_step_audio2_aqaa_bot.json /bots/ -f +modal volume put config ./config/bots/daily_step_audio2_aqaa_tools_bot.json /bots/ -f + +# run container with gpu +IMAGE_GPU=L4 SERVER_TAG=fastapi_webrtc_single_bot \ + ACHATBOT_VERSION=0.0.25 \ + CONFIG_FILE=/root/.achatbot/config/bots/daily_step_audio2_aqaa_bot.json \ + modal serve src/fastapi_webrtc_step2_voice_bot_serve.py +IMAGE_GPU=L4 SERVER_TAG=fastapi_webrtc_single_bot \ + ACHATBOT_VERSION=0.0.25 \ + CONFIG_FILE=/root/.achatbot/config/bots/daily_step_audio2_aqaa_tools_bot.json \ + modal serve src/fastapi_webrtc_step2_voice_bot_serve.py + +# cold start fastapi webrtc http server +curl -v -XGET "https://weedge--step-audio2-voice-bot-srv-app-dev.modal.run/health" + +# run bot +curl -XPOST "https://weedge--step-audio2-voice-bot-srv-app-dev.modal.run/bot_join/chat-room/DailyStepAudio2AQAABot" +""" diff --git a/deploy/modal/src/llm/transformers/step_audio2.py b/deploy/modal/src/llm/transformers/step_audio2.py new file mode 100644 index 00000000..9ddb5034 --- /dev/null +++ b/deploy/modal/src/llm/transformers/step_audio2.py @@ -0,0 +1,1685 @@ +import io +import math +import requests +import os +import sys +import json +import time +import asyncio +import subprocess +from pathlib import Path +from threading import Thread + + +import modal + + +app = modal.App("step-audio") +IMAGE_GPU = os.getenv("IMAGE_GPU", None) +img = ( + # https://catalog.ngc.nvidia.com/orgs/nvidia/containers/cuda/tags + modal.Image.from_registry( + "nvcr.io/nvidia/cuda:12.9.1-cudnn-devel-ubuntu22.04", + add_python="3.10", + ) + .apt_install("git", "git-lfs") + .pip_install( + "transformers==4.49.0", + "torchaudio", + "librosa", + "onnxruntime", + "s3tokenizer", + "diffusers", + "hyperpyyaml", + ) + .run_commands( + "git clone https://github.com/weedge/Step-Audio2.git -b torch_compile" + " && cd /Step-Audio2" + " && git checkout d340cd7b8318cb04ff231e5cf1adbe112e5097b1" + ) + .env( + { + "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", + "ACHATBOT_PKG": "1", + "LLM_MODEL": os.getenv("LLM_MODEL", "stepfun-ai/Step-Audio-2-mini"), + } + ) +) + +img = img.pip_install( + f"achatbot==0.0.25", + extra_index_url=os.getenv("EXTRA_INDEX_URL", "https://pypi.org/simple/"), +) + + +HF_MODEL_DIR = "/root/.achatbot/models" +hf_model_vol = modal.Volume.from_name("models", create_if_missing=True) 
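+# The named Modal volumes defined here persist model weights, audio assets, bot
+# configs, records and the torch hub cache across container runs; with the
+# default LLM_MODEL, MODEL_PATH below resolves to
+# /root/.achatbot/models/stepfun-ai/Step-Audio-2-mini once the download step
+# listed in the usage notes at the bottom of this file has been run.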
+ASSETS_DIR = "/root/.achatbot/assets" +assets_vol = modal.Volume.from_name("assets", create_if_missing=True) +CONFIG_DIR = "/root/.achatbot/config" +config_vol = modal.Volume.from_name("config", create_if_missing=True) +RECORDS_DIR = "/root/.achatbot/records" +records_vol = modal.Volume.from_name("records", create_if_missing=True) + +TORCH_CACHE_DIR = "/root/.cache/torch" +torch_cache_vol = modal.Volume.from_name("torch_cache", create_if_missing=True) + + +with img.imports(): + from queue import Queue + + import wave + import torch + from transformers import GenerationConfig + from transformers.generation.streamers import BaseStreamer + + sys.path.insert(1, "/Step-Audio2") + + from stepaudio2 import StepAudio2, StepAudio2Base + from token2wav import Token2wav + from utils import compute_token_num, load_audio, log_mel_spectrogram, padding_mels + + MODEL_ID = os.getenv("LLM_MODEL", "stepfun-ai/Step-Audio-2-mini") + MODEL_PATH = os.path.join(HF_MODEL_DIR, MODEL_ID) + os.makedirs(f"{ASSETS_DIR}/StepAudio2", exist_ok=True) + + CHUNK_SIZE = 25 + + # torch.set_float32_matmul_precision("high") + + +def print_model_params(model: torch.nn.Module, extra_info="", f=None): + # print the number of parameters in the model + model_million_params = sum(p.numel() for p in model.parameters()) / 1e6 + print(model, file=f) + print(f"{extra_info} {model_million_params} M parameters", file=f) + + +@app.function( + gpu=IMAGE_GPU, + cpu=2.0, + retries=1, + image=img, + secrets=[modal.Secret.from_name("achatbot")], + volumes={ + HF_MODEL_DIR: hf_model_vol, + ASSETS_DIR: assets_vol, + }, + timeout=1200, # default 300s + scaledown_window=1200, + max_containers=1, +) +async def run(func, **kwargs): + subprocess.run("nvidia-smi --version", shell=True) + subprocess.run("nvcc --version", shell=True) + gpu_prop = None + if torch.cuda.is_available(): + gpu_prop = torch.cuda.get_device_properties("cuda") + print(gpu_prop) + + if asyncio.iscoroutinefunction(func): + await func(gpu_prop, **kwargs) + else: + func(gpu_prop, **kwargs) + + +def dump_model(gpu_prop, **kwargs): + if "Base" in MODEL_ID: + model = StepAudio2Base(MODEL_PATH) # mini-base + else: + model = StepAudio2(MODEL_PATH) # mini + # print_model_params(model.llm, MODEL_ID) # AudioEncoder + Adpater + LLM decoder + + token2wav = Token2wav(f"{MODEL_PATH}/token2wav") + + print_model_params( + token2wav.flow, f"{MODEL_ID}/token2wav.flow" + ) # Flow-Matching audio_tokens->mels + print_model_params( + token2wav.flow.encoder, f"{MODEL_ID}/token2wav.flow.encoder" + ) # Flow-Matching encoder + print_model_params( + token2wav.flow.decoder, f"{MODEL_ID}/token2wav.flow.decoder" + ) # Flow-Matching decoder + + print_model_params(token2wav.hift, f"{MODEL_ID}/token2wav.hift") # Vocoder mels->waveform + + print_model_params( + token2wav.audio_tokenizer, f"{MODEL_ID}/token2wav.audio_tokenizer" + ) # for ref audio quantization (FSQ) + print( + token2wav.spk_model, f"{MODEL_ID}/token2wav.spk_model" + ) # for ref audio speaker embedding(fbank feat) + + # print(f"{model.llm_tokenizer=}") # text tokenizer with instruct specail token + print(f"{model.llm.config=}") + + print(model.llm_tokenizer.decode(49434)) + print(model.llm_tokenizer.decode(239)) + print(model.llm_tokenizer.decode([49434, 239])) + + +def tokenize(gpu_prop, **kwargs): + if "Base" in MODEL_ID: + model = StepAudio2Base(MODEL_PATH) # mini-base + messages = [ + "请记录下你所听到的语音内容。", + { + "type": "audio", + "audio": "/Step-Audio2/assets/give_me_a_brief_introduction_to_the_great_wall.wav", + }, + ] + else: + model = 
StepAudio2(MODEL_PATH) # mini + messages = [ + {"role": "system", "content": "请记录下你所听到的语音内容。"}, + { + "role": "human", + "content": [ + { + "type": "audio", + "audio": "/Step-Audio2/assets/give_me_a_brief_introduction_to_the_great_wall.wav", + } + ], + }, + {"role": "assistant", "content": None}, + ] + + res, mels = model.apply_chat_template(messages) + print(res) + print(mels) + + # Tokenize prompts + prompt_ids = [] + for msg in res: + if isinstance(msg, str): + prompt_ids.append( + model.llm_tokenizer(text=msg, return_tensors="pt", padding=True)["input_ids"] + ) + elif isinstance(msg, list): + prompt_ids.append(torch.tensor([msg], dtype=torch.int32)) + else: + raise ValueError(f"Unsupported content type: {type(msg)}") + prompt_ids = torch.cat(prompt_ids, dim=-1).cuda() + attention_mask = torch.ones_like(prompt_ids) + print(prompt_ids) + print(attention_mask) + + # mels = None if len(mels) == 0 else torch.stack(mels).cuda() + # mel_lengths = None if mels is None else torch.tensor([mel.shape[1] - 2 for mel in mels], dtype=torch.int32, device='cuda') + mels, mel_lengths = padding_mels(mels) + print(mels, mel_lengths) + + +# ASR +def asr_test(model, token2wav=None): + messages = [ + "请记录下你所听到的语音内容。", + { + "type": "audio", + "audio": "/Step-Audio2/assets/give_me_a_brief_introduction_to_the_great_wall.wav", + }, + ] + eos_token_id = model.llm_tokenizer.convert_tokens_to_ids("<|endoftext|>") + tokens, text, _ = model( + messages, + max_new_tokens=256, + temperature=0.1, + do_sample=True, + eos_token_id=[model.eos_token_id, eos_token_id], + ) + print(text) + + +# S2TT(support: en,zh,ja) +def s2tt_test(model, token2wav=None): + messages = [ + "请仔细聆听这段语音,然后将其内容翻译成中文", + # "Please listen carefully to this audio and then translate its content into Chinese.", + { + "type": "audio", + "audio": "/Step-Audio2/assets/give_me_a_brief_introduction_to_the_great_wall.wav", + }, + ] + eos_token_id = model.llm_tokenizer.convert_tokens_to_ids("<|endoftext|>") + tokens, text, _ = model( + messages, + max_new_tokens=256, + temperature=0.1, + do_sample=True, + eos_token_id=[model.eos_token_id, eos_token_id], + ) + print(text) + + +# audio caption +def audio_caption_test(model, token2wav=None): + messages = [ + "Please briefly explain the important events involved in this audio clip.", + { + "type": "audio", + "audio": "/Step-Audio2/assets/music_playing_followed_by_a_woman_speaking.wav", + }, + ] + eos_token_id = model.llm_tokenizer.convert_tokens_to_ids("<|endoftext|>") + tokens, text, _ = model( + messages, + max_new_tokens=256, + temperature=0.1, + do_sample=True, + eos_token_id=[model.eos_token_id, eos_token_id], + ) + print(text) + + +# TTS(support: en,zh,ja) +def tts_test(model, token2wav): + messages = [ + "以自然的语速读出下面的文字。\n", + # "Read this paragraph at a natural pace.\n", + "你好呀,我是你的AI助手,很高兴认识你!", + ] + tokens, text, audio = model(messages, max_tokens=2048, temperature=0.7, do_sample=True) + print(text) + # print(tokens) + audio = [x for x in audio if x < 6561] # remove audio padding + audio = token2wav(audio, prompt_wav="/Step-Audio2/assets/default_male.wav") + with open(f"{ASSETS_DIR}/StepAudio2/output-tts.wav", "wb") as f: + f.write(audio) + + +# T2ST(support: en,zh,ja) +def t2st_test(model, token2wav): + messages = [ + "将下面的文本翻译成英文,并用语音播报。\n", + # "Translate the following text into English and broadcast it with voice.\n", + "你好呀,我是你的AI助手,很高兴认识你!", + ] + tokens, text, audio = model(messages, max_tokens=2048, temperature=0.7, do_sample=True) + print(text) + # print(tokens) + audio = [x for x in audio if 
x < 6561] # remove audio padding + audio = token2wav(audio, prompt_wav="/Step-Audio2/assets/default_male.wav") + with open(f"{ASSETS_DIR}/StepAudio2/output-t2st.wav", "wb") as f: + f.write(audio) + + +# S2ST(support: en,zh) +def s2st_test(model, token2wav): + messages = [ + "请仔细聆听这段语音,然后将其内容翻译成中文并用语音播报。", + # "Please listen carefully to this audio and then translate its content into Chinese speech.", + { + "type": "audio", + "audio": "/Step-Audio2/assets/give_me_a_brief_introduction_to_the_great_wall.wav", + }, + "", + ] + tokens, text, audio = model(messages, max_tokens=2048, temperature=0.7, do_sample=True) + print(text) + # print(tokens) + audio = [x for x in audio if x < 6561] # remove audio padding + audio = token2wav(audio, prompt_wav="/Step-Audio2/assets/default_female.wav") + with open(f"{ASSETS_DIR}/StepAudio2/output-s2st.wav", "wb") as f: + f.write(audio) + + +# multi turn aqta +def multi_turn_aqta_test(model, token2wav=None): + history = [] + for round_idx, inp_audio in enumerate( + [ + "/Step-Audio2/assets/multi-turn-round1-听说荡口古镇从下个月开始取消门票了,你知道这事吗。.wav", + "/Step-Audio2/assets/multi-turn-round2-新闻说九月十九号就免费开放了。好像整个古镇都升级改造了,现在变成开放式街区了。.wav", + ] + ): + print("round: ", round_idx) + history.append({"type": "audio", "audio": inp_audio}) + tokens, text, _ = model(history, max_new_tokens=256, temperature=0.5, do_sample=True) + print(text) + history.append(text) + + +# multi turn aqaa +def multi_turn_aqaa_test(model, token2wav): + history = [] + for round_idx, inp_audio in enumerate( + [ + "/Step-Audio2/assets/multi-turn-round1-听说荡口古镇从下个月开始取消门票了,你知道这事吗。.wav", + "/Step-Audio2/assets/multi-turn-round2-新闻说九月十九号就免费开放了。好像整个古镇都升级改造了,现在变成开放式街区了。.wav", + ] + ): + print("round: ", round_idx) + history.append( + {"type": "audio", "audio": inp_audio}, + ) + history.append("") + tokens, text, audio = model(history, max_new_tokens=2048, temperature=0.7, do_sample=True) + print(text) + audio = [x for x in audio if x < 6561] # remove audio padding + audio = token2wav(audio, prompt_wav="/Step-Audio2/assets/default_female.wav") + with open(f"{ASSETS_DIR}/StepAudio2/output-round-{round_idx}.wav", "wb") as f: + f.write(audio) + history.append({"type": "token", "token": tokens}) + + +def test_base(gpu_prop, **kwargs): + model = StepAudio2Base(MODEL_PATH) + token2wav = Token2wav(f"{MODEL_PATH}/token2wav") + + test_func = kwargs.get("test_func", "asr_test") + globals()[test_func](model, token2wav) + + +# ------------------------------------------------------------------------------------------------- +# special Instruct + + +# ASR +def instruct_asr_test(model, token2wav): + messages = [ + {"role": "system", "content": "请记录下你所听到的语音内容。"}, + # {"role": "system", "content": "Please record the audio content you hear."}, + { + "role": "human", + "content": [ + { + "type": "audio", + "audio": "/Step-Audio2/assets/give_me_a_brief_introduction_to_the_great_wall.wav", + } + ], + }, + {"role": "assistant", "content": None}, + ] + tokens, text, _ = model(messages, max_new_tokens=256) + print(text) + + +# audio caption +def instruct_audio_caption_test(model, token2wav): + messages = [ + { + "role": "system", + "content": "Please briefly explain the important events involved in this audio clip.", + }, + { + "role": "human", + "content": [ + { + "type": "audio", + "audio": "/Step-Audio2/assets/music_playing_followed_by_a_woman_speaking.wav", + } + ], + }, + {"role": "assistant", "content": None}, + ] + tokens, text, _ = model(messages, max_new_tokens=256, temperature=0.1, do_sample=True) + print(text) + + +# 
S2TT(support: en,zh,ja) +def instruct_s2tt_test(model, token2wav): + messages = [ + {"role": "system", "content": "请仔细聆听这段语音,然后将其内容翻译成中文。"}, + # {"role": "system", "content":"Please listen carefully to this audio and then translate its content into Chinese."}, + { + "role": "human", + "content": [ + { + "type": "audio", + "audio": "/Step-Audio2/assets/give_me_a_brief_introduction_to_the_great_wall.wav", + } + ], + }, + {"role": "assistant", "content": None}, + ] + tokens, text, _ = model(messages, max_new_tokens=256, temperature=0.1, do_sample=True) + print(text) + + +# S2ST(support: en,zh) +def instruct_s2st_test(model, token2wav): + messages = [ + {"role": "system", "content": "请仔细聆听这段语音,然后将其内容翻译成中文并用语音播报。"}, + # {"role": "system", "content":"Please listen carefully to this audio and then translate its content into Chinese speech."}, + { + "role": "human", + "content": [ + { + "type": "audio", + "audio": "/Step-Audio2/assets/give_me_a_brief_introduction_to_the_great_wall.wav", + } + ], + }, + { + "role": "assistant", + "content": "", + "eot": False, + }, # Insert for speech response + ] + tokens, text, audio = model(messages, max_tokens=2048, temperature=0.7, do_sample=True) + print(text) + # print(tokens) + audio = [x for x in audio if x < 6561] # remove audio padding + audio = token2wav(audio, prompt_wav="/Step-Audio2/assets/default_female.wav") + with open(f"{ASSETS_DIR}/StepAudio2/output-s2st.wav", "wb") as f: + f.write(audio) + + +# multi turn tqta +def instruct_multi_turn_tqta_test(model, token2wav): + history = [{"role": "system", "content": "You are a helpful assistant."}] + for round_idx, input_text in enumerate( + [ + "听说荡口古镇从下个月开始取消门票了,你知道这事吗。", + "新闻说九月十九号就免费开放了。好像整个古镇都升级改造了,现在变成开放式街区了。", + ] + ): + print("round: ", round_idx) + history.append({"role": "human", "content": [{"type": "text", "text": input_text}]}) + history.append({"role": "assistant", "content": None}) + tokens, text, _ = model(history, max_new_tokens=256, temperature=0.5, do_sample=True) + print(text) + history.pop(-1) + history.append({"role": "assistant", "content": text}) + + +# multi turn tqaa +def instruct_multi_turn_tqaa_test(model, token2wav): + history = [{"role": "system", "content": "You are a helpful assistant."}] + for round_idx, input_text in enumerate( + [ + "听说荡口古镇从下个月开始取消门票了,你知道这事吗。", + "新闻说九月十九号就免费开放了。好像整个古镇都升级改造了,现在变成开放式街区了。", + ] + ): + print("round: ", round_idx) + history.append({"role": "human", "content": [{"type": "text", "text": input_text}]}) + history.append( + { + "role": "assistant", + "content": "", + "eot": False, + }, # Insert for speech response + ) + tokens, text, audio = model(history, max_new_tokens=256, temperature=0.5, do_sample=True) + print(tokens, model.llm_tokenizer.decode(tokens)) + print(text) + audio = [x for x in audio if x < 6561] # remove audio padding + audio = token2wav(audio, prompt_wav="/Step-Audio2/assets/default_female.wav") + with open(f"{ASSETS_DIR}/StepAudio2/output-round-tqaa-{round_idx}.wav", "wb") as f: + f.write(audio) + history.pop(-1) + history.append( + { + "role": "assistant", + "content": [ + {"type": "text", "text": ""}, + {"type": "token", "token": tokens}, + ], + } + ) + + +# multi turn aqta +def instruct_multi_turn_aqta_test(model, token2wav): + history = [{"role": "system", "content": "You are a helpful assistant."}] + for round_idx, inp_audio in enumerate( + [ + "/Step-Audio2/assets/multi-turn-round1-听说荡口古镇从下个月开始取消门票了,你知道这事吗。.wav", + "/Step-Audio2/assets/multi-turn-round2-新闻说九月十九号就免费开放了。好像整个古镇都升级改造了,现在变成开放式街区了。.wav", + ] + ): + 
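+        # Each round appends the user's audio turn plus an empty assistant slot,
+        # generates a text-only reply, then swaps the slot for the generated text
+        # so the next round sees the full conversation history.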
print("round: ", round_idx) + history.append({"role": "human", "content": [{"type": "audio", "audio": inp_audio}]}) + history.append({"role": "assistant", "content": None}) + tokens, text, _ = model(history, max_new_tokens=256, temperature=0.5, do_sample=True) + print(text) + history.pop(-1) + history.append({"role": "assistant", "content": text}) + + +# multi turn aqaa +def instruct_multi_turn_aqaa_test(model, token2wav): + history = [{"role": "system", "content": "You are a helpful assistant."}] + for round_idx, inp_audio in enumerate( + [ + "/Step-Audio2/assets/multi-turn-round1-听说荡口古镇从下个月开始取消门票了,你知道这事吗。.wav", + "/Step-Audio2/assets/multi-turn-round2-新闻说九月十九号就免费开放了。好像整个古镇都升级改造了,现在变成开放式街区了。.wav", + ] + ): + print("round: ", round_idx) + history.append({"role": "human", "content": [{"type": "audio", "audio": inp_audio}]}) + history.append( + { + "role": "assistant", + "content": "", + "eot": False, + }, # Insert for speech response + ) + tokens, text, audio = model(history, max_new_tokens=2048, temperature=0.7, do_sample=True) + print(tokens, model.llm_tokenizer.decode(tokens)) + print(text) + audio = [x for x in audio if x < 6561] # remove audio padding + audio = token2wav(audio, prompt_wav="/Step-Audio2/assets/default_female.wav") + with open(f"{ASSETS_DIR}/StepAudio2/output-round-aqaa-{round_idx}.wav", "wb") as f: + f.write(audio) + history.pop(-1) + history.append( + { + "role": "assistant", + "content": [ + {"type": "text", "text": ""}, + {"type": "token", "token": tokens}, + ], + } + ) + + +# Tool call & Web search +def instruct_tool_call_test(model, token2wav): + history = [ + { + "role": "system", + "content": "你的名字叫做小跃,是由阶跃星辰公司训练出来的语音大模型。\n你具备调用工具解决问题的能力,你需要根据用户的需求和上下文情景,自主选择是否调用系统提供的工具来协助用户。\n你情感细腻,观察能力强,擅长分析用户的内容,并作出善解人意的回复,说话的过程中时刻注意用户的感受,富有同理心,提供多样的情绪价值。\n今天是2025年8月28日,星期四\n请用默认女声与用户交流", + }, + { + "role": "tool_json_schemas", + "content": '[{"type": "function", "function": {"name": "search", "description": "搜索工具", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "搜索关键词"}}, "required": ["query"], "additionalProperties": false}}}]', + }, + { + "role": "human", + "content": [ + { + "type": "audio", + "audio": "/Step-Audio2/assets/帮我查一下今天上证指数的开盘价是多少.wav", + } + ], + }, + { + "role": "assistant", + "content": "", + "eot": False, + }, # Insert for speech response + ] + tokens, text, audio = model( + history, + max_new_tokens=4096, + repetition_penalty=1.05, + top_p=0.9, + temperature=0.7, + do_sample=True, + ) + print(text) + audio = [x for x in audio if x < 6561] # remove audio padding + audio = token2wav(audio, prompt_wav="/Step-Audio2/assets/default_female.wav") + with open(f"{ASSETS_DIR}/StepAudio2/output-tool-call-1.wav", "wb") as f: + f.write(audio) + history.pop(-1) + with open("/Step-Audio2/assets/search_result.txt") as f: + search_result = f.read().strip() + print(f"search result: {search_result}") + history += [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": ""}, + {"type": "token", "token": tokens}, + ], + }, + { + "role": "input", + "content": [ + {"type": "text", "text": search_result}, + { + "type": "text", + "text": "\n\n\n请用口语化形式总结检索结果,简短地回答用户的问题。", + }, + ], + }, + { + "role": "assistant", + "content": "", + "eot": False, + }, # Insert for speech response + ] + tokens, text, audio = model( + history, + max_new_tokens=4096, + repetition_penalty=1.05, + top_p=0.9, + temperature=0.7, + do_sample=True, + ) + print(text) + audio = [x for x in audio if x < 6561] # remove audio padding + audio = 
token2wav(audio, prompt_wav="/Step-Audio2/assets/default_female.wav") + with open(f"{ASSETS_DIR}/StepAudio2/output-tool-call-2.wav", "wb") as f: + f.write(audio) + + +# Paralingustic information understanding +def instruct_paralinguistic_test(model, token2wav): + messages = [ + {"role": "system", "content": "请用语音与我交流。"}, + { + "role": "human", + "content": [ + { + "type": "audio", + "audio": "/Step-Audio2/assets/paralinguistic_information_understanding.wav", + } + ], + }, + { + "role": "assistant", + "content": "", + "eot": False, + }, # Insert for speech response + ] + tokens, text, audio = model(messages, max_tokens=2048, temperature=0.7, do_sample=True) + print(text) + # print(tokens) + audio = [x for x in audio if x < 6561] # remove audio padding + audio = token2wav(audio, prompt_wav="/Step-Audio2/assets/default_female.wav") + with open(f"{ASSETS_DIR}/StepAudio2/output-paralinguistic.wav", "wb") as f: + f.write(audio) + + +# Audio understanding +def instruct_mmau_test(model, token2wav): + messages = [ + { + "role": "system", + "content": "You are an expert in audio analysis, please analyze the audio content and answer the questions accurately.", + }, + { + "role": "human", + "content": [ + {"type": "audio", "audio": "/Step-Audio2/assets/mmau_test.wav"}, + { + "type": "text", + "text": f"Which of the following best describes the male vocal in the audio? Please choose the answer from the following options: [Soft and melodic, Aggressive and talking, High-pitched and singing, Whispering] Output the final answer in .", + }, + ], + }, + {"role": "assistant", "content": None}, + ] + tokens, text, _ = model(messages, max_new_tokens=256, num_beams=2) + print(text) + + +# Audio understanding +def instruct_mmau_audio_answer_test(model, token2wav): + messages = [ + { + "role": "system", + "content": "You are an expert in audio analysis, please analyze the audio content and answer the questions accurately. \nPlease communicate with me via voice.\n", + }, + { + "role": "human", + "content": [ + {"type": "audio", "audio": "/Step-Audio2/assets/mmau_test.wav"}, + { + "type": "text", + "text": f"Which of the following best describes the male vocal in the audio? 
Please choose the answer from the following options: [Soft and melodic, Aggressive and talking, High-pitched and singing, Whispering].", + }, + ], + }, + { + "role": "assistant", + "content": "", + "eot": False, + }, # Insert for speech response + ] + tokens, text, audio = model(messages, max_tokens=2048, temperature=0.7, do_sample=True) + print(text) + # print(tokens) + audio = [x for x in audio if x < 6561] # remove audio padding + audio = token2wav(audio, prompt_wav="/Step-Audio2/assets/default_female.wav") + with open(f"{ASSETS_DIR}/StepAudio2/output-mmau.wav", "wb") as f: + f.write(audio) + + +def test_instruct(gpu_prop, **kwargs): + model = StepAudio2(MODEL_PATH) + token2wav = Token2wav(f"{MODEL_PATH}/token2wav") + + test_func = kwargs.get("test_func", "instruct_asr_test") + globals()[test_func](model, token2wav) + + +# ------------------------------------------------------------------------------------------------------------- +# generate stream + + +# ASR (not live asr) +def stream_asr_test(model, token2wav=None): + messages = [ + {"role": "system", "content": "请记录下你所听到的语音内容。"}, + { + "role": "human", + "content": [ + { + "type": "audio", + "audio": "/Step-Audio2/assets/give_me_a_brief_introduction_to_the_great_wall.wav", + } + ], + }, + {"role": "assistant", "content": None}, + ] + eos_token_id = model.llm_tokenizer.convert_tokens_to_ids("<|endoftext|>") + token_iter = model( + messages, + max_new_tokens=256, + temperature=0.1, + do_sample=True, + eos_token_id=[model.eos_token_id, eos_token_id], + ) + output_text_token_ids = [] + output_audio_token_ids = [] + output_text = "" + is_tag = False + for token_id in token_iter: + token = model.llm_tokenizer.decode(token_id) + print(token_id, token) + + if token_id == 27: + is_tag = True + continue + if token_id == 29: + is_tag = False + continue + if is_tag: + continue + if token_id in [model.eos_token_id, eos_token_id]: + break + + if token_id < 151688: + output_text_token_ids.append(token_id) + if token_id > 151695: + output_audio_token_ids.append(token_id - 151696) + output_text += token + print(output_text) + + +# TTS(support: en,zh,ja) +def stream_tts_test(model, token2wav): + messages = [ + {"role": "system", "content": "以自然的语速读出下面的文字。\n"}, + {"role": "human", "content": "你好呀,我是你的AI助手,很高兴认识你!"}, + { + "role": "assistant", + "content": "", + "eot": False, + }, # Insert for speech response + ] + token_iter = model(messages, max_tokens=2048, temperature=0.7, do_sample=True) + output_text_token_ids = [] + output_audio_token_ids = [] + output_token = "" + + # stream audio + buffer = [] + prompt_wav = "/Step-Audio2/assets/default_male.wav" + token2wav.set_stream_cache(prompt_wav) + output_stream = Path(f"{ASSETS_DIR}/StepAudio2/output-chunks-stream-tts.pcm") + output_stream.unlink(missing_ok=True) + for token_id in token_iter: + token = model.llm_tokenizer.decode(token_id) + print(token_id, token) + output_token += token + + if token_id < 151688: # text + output_text_token_ids.append(token_id) + if token_id > 151695: # audio + audio_token_id = token_id - 151696 + if audio_token_id < 6561: # remove audio padding + output_audio_token_ids.append(audio_token_id) + buffer.append(audio_token_id) + if len(buffer) >= CHUNK_SIZE + token2wav.flow.pre_lookahead_len: + start = time.time() + output = token2wav.stream( + buffer[: CHUNK_SIZE + token2wav.flow.pre_lookahead_len], + prompt_wav=prompt_wav, + last_chunk=False, + ) + print(len(buffer), len(output), output[:50], time.time() - start) + with open(output_stream, "ab") as f: + f.write(output) + 
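+                    # Only CHUNK_SIZE tokens are consumed per synthesis call; the
+                    # trailing pre_lookahead_len tokens stay in the buffer and are
+                    # re-sent as context for the next chunk, which is why the slice
+                    # below drops CHUNK_SIZE rather than the whole window.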
buffer = buffer[CHUNK_SIZE:] + + if len(buffer) > 0: + start = time.time() + output = token2wav.stream(buffer, prompt_wav=prompt_wav, last_chunk=True) + print("last_chunk", len(buffer), len(output), output[:50], time.time() - start) + with open(output_stream, "ab") as f: + f.write(output) + + with open(output_stream, "rb") as f: + pcm = f.read() + wav_path = output_stream.with_suffix(".wav") + with wave.open(str(wav_path), "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(pcm) + + print(output_token) + audio = token2wav(output_audio_token_ids, prompt_wav="/Step-Audio2/assets/default_male.wav") + with open(f"{ASSETS_DIR}/StepAudio2/output-stream-tts.wav", "wb") as f: + f.write(audio) + + +def stream_aqaa_test(model, token2wav): + history = [{"role": "system", "content": "You are a helpful assistant."}] + for round_idx, inp_audio in enumerate( + [ + "/Step-Audio2/assets/multi-turn-round1-听说荡口古镇从下个月开始取消门票了,你知道这事吗。.wav", + "/Step-Audio2/assets/multi-turn-round2-新闻说九月十九号就免费开放了。好像整个古镇都升级改造了,现在变成开放式街区了。.wav", + ] + ): + print("round: ", round_idx) + history.append({"role": "human", "content": [{"type": "audio", "audio": inp_audio}]}) + history.append( + { + "role": "assistant", + "content": "", + "eot": False, + }, # Insert for speech response + ) + + token_iter = model(history, max_tokens=2048, temperature=0.7, do_sample=True) + output_text_token_ids = [] + output_audio_token_ids = [] + output_token = "" + output_token_ids = [] + + # stream audio + buffer = [] + prompt_wav = "/Step-Audio2/assets/default_male.wav" + token2wav.set_stream_cache(prompt_wav) + output_stream = Path(f"{ASSETS_DIR}/StepAudio2/output-aqaa-{round_idx}-chunks-stream.pcm") + output_stream.unlink(missing_ok=True) + for token_id in token_iter: + output_token_ids.append(token_id) + token = model.llm_tokenizer.decode(token_id) + print(token_id, token) + output_token += token + + if token_id < 151688: # text + output_text_token_ids.append(token_id) + if token_id > 151695: # audio + audio_token_id = token_id - 151696 + if audio_token_id < 6561: # remove audio padding + output_audio_token_ids.append(audio_token_id) + buffer.append(audio_token_id) + if len(buffer) >= CHUNK_SIZE + token2wav.flow.pre_lookahead_len: + start = time.time() + output = token2wav.stream( + buffer[: CHUNK_SIZE + token2wav.flow.pre_lookahead_len], + prompt_wav=prompt_wav, + last_chunk=False, + ) + print(len(buffer), len(output), output[:50], time.time() - start) + with open(output_stream, "ab") as f: + f.write(output) + buffer = buffer[CHUNK_SIZE:] + + if len(buffer) > 0: + start = time.time() + output = token2wav.stream(buffer, prompt_wav=prompt_wav, last_chunk=True) + print("last_chunk", len(buffer), len(output), output[:50], time.time() - start) + with open(output_stream, "ab") as f: + f.write(output) + + with open(output_stream, "rb") as f: + pcm = f.read() + wav_path = output_stream.with_suffix(".wav") + with wave.open(str(wav_path), "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(pcm) + + print(output_token) + audio = token2wav(output_audio_token_ids, prompt_wav="/Step-Audio2/assets/default_male.wav") + with open(f"{ASSETS_DIR}/StepAudio2/output-stream-aqaa-{round_idx}.wav", "wb") as f: + f.write(audio) + + history.pop(-1) + history.append( + { + "role": "assistant", + "content": [ + {"type": "text", "text": ""}, + {"type": "token", "token": output_token_ids}, + ], + } + ) + + +def extract_function_info(tool_calls_token: str) -> tuple: + """ + 从 
tool_calls_token 字符串中提取 function_name 和 function_args + + 参数格式示例: + 'function\nweb_search\n{"query": "2025年8月28日 上证指数 开盘价"}' + + 返回: (function_name, function_args_dict) + """ + # 按换行符分割字符串 + parts = tool_calls_token.split("\n") + + # 验证格式是否正确 + if len(parts) < 3 or parts[0] != "function": + raise ValueError("无效的 tool_calls_token 格式") + + # 提取函数名 + function_name = parts[1] + + try: + # 合并剩余部分作为 JSON 字符串(处理可能的多行 JSON) + json_str = "\n".join(parts[2:]) + function_args = json.loads(json_str) + except json.JSONDecodeError: + raise ValueError("无法解析 function_args JSON") + + return function_name, function_args + + +def stream_aqaa_tools_test(model, token2wav): + history = [ + { + "role": "system", + "content": "你的名字叫做小跃,是由阶跃星辰公司训练出来的语音大模型。\n你具备调用工具解决问题的能力,你需要根据用户的需求和上下文情景,自主选择是否调用系统提供的工具来协助用户。\n你情感细腻,观察能力强,擅长分析用户的内容,并作出善解人意的回复,说话的过程中时刻注意用户的感受,富有同理心,提供多样的情绪价值。\n今天是2025年8月28日,星期五", + }, + { + "role": "tool_json_schemas", + "content": '[{"type": "function", "function": {"name": "web_search", "description": "搜索工具", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "搜索关键词"}}, "required": ["query"], "additionalProperties": false}}}]', + }, + ] + tool_calls_token_ids = [] + search_result = "" + for round_idx, inp_audio in enumerate( + [ + "/Step-Audio2/assets/帮我查一下今天上证指数的开盘价是多少.wav", + "/Step-Audio2/assets/multi-turn-round1-听说荡口古镇从下个月开始取消门票了,你知道这事吗。.wav", + "/Step-Audio2/assets/multi-turn-round2-新闻说九月十九号就免费开放了。好像整个古镇都升级改造了,现在变成开放式街区了。.wav", + ] + ): + print("round: ", round_idx) + tool_cn = 0 + while True: + if len(tool_calls_token_ids) > 0: + history.append( + { + "role": "input", + "content": [ + {"type": "text", "text": search_result}, + { + "type": "text", + "text": "\n\n\n请用口语化形式总结检索结果,简短地回答用户的问题。", + }, + ], + } + ) + tool_cn += 1 + else: + history.append( + {"role": "human", "content": [{"type": "audio", "audio": inp_audio}]} + ) + history.append( + { + "role": "assistant", + "content": "", + "eot": False, + }, # Insert for speech response + ) + + token_iter = model(history, max_tokens=2048, temperature=0.7, do_sample=True) + output_text_token_ids = [] + output_audio_token_ids = [] + output_token = "" + output_token_ids = [] + + # tools + is_tool = False + tool_calls_token_ids = [] + + # stream audio + buffer = [] + prompt_wav = "/Step-Audio2/assets/default_male.wav" + token2wav.set_stream_cache(prompt_wav) + output_stream = Path( + f"{ASSETS_DIR}/StepAudio2/output-aqaa-tools-{tool_cn}-{round_idx}-chunks-stream.pcm" + ) + output_stream.unlink(missing_ok=True) + for token_id in token_iter: + output_token_ids.append(token_id) + token = model.llm_tokenizer.decode(token_id) + print(token_id, token) + output_token += token + + if token_id < 151688: # text + if token_id == 151657: # + is_tool = True + continue + if token_id == 151658: # + is_tool = False + continue + if is_tool: + tool_calls_token_ids.append(token_id) + continue + output_text_token_ids.append(token_id) + + if token_id > 151695: # audio + audio_token_id = token_id - 151696 + if audio_token_id < 6561: # remove audio padding + output_audio_token_ids.append(audio_token_id) + buffer.append(audio_token_id) + if len(buffer) >= CHUNK_SIZE + token2wav.flow.pre_lookahead_len: + start = time.time() + output = token2wav.stream( + buffer[: CHUNK_SIZE + token2wav.flow.pre_lookahead_len], + prompt_wav=prompt_wav, + last_chunk=False, + ) + print(len(buffer), len(output), output[:50], time.time() - start) + with open(output_stream, "ab") as f: + f.write(output) + buffer = buffer[CHUNK_SIZE:] + + if len(buffer) > 
0: + start = time.time() + output = token2wav.stream(buffer, prompt_wav=prompt_wav, last_chunk=True) + print("last_chunk", len(buffer), len(output), output[:50], time.time() - start) + with open(output_stream, "ab") as f: + f.write(output) + + with open(output_stream, "rb") as f: + pcm = f.read() + wav_path = output_stream.with_suffix(".wav") + with wave.open(str(wav_path), "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(pcm) + + print(f"{output_token=}") + output_text_tokens = model.llm_tokenizer.decode(output_text_token_ids) + print(f"{output_text_tokens=}") + + audio = token2wav( + output_audio_token_ids, prompt_wav="/Step-Audio2/assets/default_male.wav" + ) + with open( + f"{ASSETS_DIR}/StepAudio2/output-stream-aqaa-tools-{tool_cn}-{round_idx}.wav", "wb" + ) as f: + f.write(audio) + + history.pop(-1) + history.append( + { + "role": "assistant", + "content": [ + {"type": "text", "text": ""}, + {"type": "token", "token": output_token_ids}, + ], + } + ) + + if len(tool_calls_token_ids) == 0: + break # break tool call while + + tool_calls_token = model.llm_tokenizer.decode(tool_calls_token_ids) + print(f"{tool_calls_token=}") + function_name, function_args = extract_function_info(tool_calls_token) + print(f"{function_name=}") + print(f"{function_args=}") + + # mock search + with open("/Step-Audio2/assets/search_result.txt") as f: + search_result = f.read().strip() + print(f"search result: {search_result}") + + +def generate_stream(gpu_prop, **kwargs): + class TokenStreamer(BaseStreamer): + def __init__(self, skip_prompt: bool = False, timeout=None): + self.skip_prompt = skip_prompt + + # variables used in the streaming process + self.token_queue = Queue() + self.stop_signal = None + self.next_tokens_are_prompt = True + self.timeout = timeout + + def put(self, value): + if len(value.shape) > 1 and value.shape[0] > 1: + raise ValueError("TextStreamer only supports batch size 1") + elif len(value.shape) > 1: + value = value[0] + + if self.skip_prompt and self.next_tokens_are_prompt: + self.next_tokens_are_prompt = False + return + + for token in value.tolist(): + self.token_queue.put(token) + + def end(self): + self.token_queue.put(self.stop_signal) + + def __iter__(self): + return self + + def __next__(self): + value = self.token_queue.get(timeout=self.timeout) + if value == self.stop_signal: + raise StopIteration() + else: + return value + + class StepAudio2Stream(StepAudio2): + def __call__(self, messages: list, **kwargs): + messages, mels = self.apply_chat_template(messages) + print(messages) + + # Tokenize prompts + prompt_ids = [] + for msg in messages: + if isinstance(msg, str): + prompt_ids.append( + self.llm_tokenizer(text=msg, return_tensors="pt", padding=True)["input_ids"] + ) + elif isinstance(msg, list): + prompt_ids.append(torch.tensor([msg], dtype=torch.int32)) + else: + raise ValueError(f"Unsupported content type: {type(msg)}") + prompt_ids = torch.cat(prompt_ids, dim=-1).cuda() + attention_mask = torch.ones_like(prompt_ids) + + # mels = None if len(mels) == 0 else torch.stack(mels).cuda() + # mel_lengths = None if mels is None else torch.tensor([mel.shape[1] - 2 for mel in mels], dtype=torch.int32, device='cuda') + if len(mels) == 0: + mels = None + mel_lengths = None + else: + mels, mel_lengths = padding_mels(mels) + mels = mels.cuda() + mel_lengths = mel_lengths.cuda() + + generation_config = dict( + max_new_tokens=2048, + pad_token_id=self.llm_tokenizer.pad_token_id, + eos_token_id=self.eos_token_id, + ) + 
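+            # Caller kwargs override these defaults; generation then runs on a
+            # background thread with a TokenStreamer so token ids can be yielded
+            # as soon as the model emits them instead of after the full sequence.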
generation_config.update(kwargs) + generation_config = GenerationConfig(**generation_config) + + streamer = TokenStreamer(skip_prompt=True) + + generation_kwargs = dict( + input_ids=prompt_ids, + wavs=mels, + wav_lens=mel_lengths, + attention_mask=attention_mask, + generation_config=generation_config, + streamer=streamer, + ) + + thread = Thread(target=self.llm.generate, kwargs=generation_kwargs) + thread.start() + + stop_ids = ( + [generation_config.eos_token_id] + if isinstance(generation_config.eos_token_id, int) + else generation_config.eos_token_id + ) + for token_id in streamer: + # print(token_id, end=",", flush=True) + if token_id in stop_ids: + break + yield token_id + + model = StepAudio2Stream(MODEL_PATH) + + token2wav = Token2wav(f"{MODEL_PATH}/token2wav") + no_experimental = torch._dynamo.list_backends() + print(f"{no_experimental=}") + experimental = torch._dynamo.list_backends(None) + print(f"{experimental=}") + token2wav.flow.scatter_cuda_graph(True) + + test_func = kwargs.get("test_func", "stream_asr_test") + globals()[test_func](model, token2wav) + + +async def achatbot_step_audio2_say(): + from apipeline.frames import AudioRawFrame, StartFrame, EndFrame, CancelFrame + from achatbot.types.frames import PathAudioRawFrame + + from achatbot.cmd.bots.voice.step_audio2.helper import ( + get_step_audio2_llm, + get_step_audio2_processor, + ) + from achatbot.types.ai_conf import AIConfig, LLMConfig + + processor = get_step_audio2_processor( + LLMConfig( + processor="StepAudio2TextAudioChatProcessor", + args={ + "init_system_prompt": "", + "prompt_wav": "/root/.achatbot/assets/default_male.wav", + "warmup_cn": 2, + "chat_history_size": None, + "text_stream_out": False, + "no_stream_sleep_time": 0.5, + "lm_model_name_or_path": MODEL_PATH, + "lm_gen_max_new_tokens": 64, + "lm_gen_temperature": 0.1, + "lm_gen_top_k": 20, + "lm_gen_top_p": 0.95, + "lm_gen_repetition_penalty": 1.1, + }, + ) + ) + await processor.start(StartFrame()) + + frame_iter = processor.generator_say( + "你好, 我是Step-Audio2,很高兴认识你。", is_push_frame=False + ) + audio = b"" + async for frame in frame_iter: + if isinstance(frame, AudioRawFrame): + audio += frame.audio + print(f"say gen_frame-->", frame) + print(f"say {processor._session.chat_history=}") + wav_path = Path(f"{ASSETS_DIR}/StepAudio2/output-processor-chunks-stream-say.wav") + with wave.open(str(wav_path), "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(audio) + await processor.stop(EndFrame()) + + +async def achatbot_step_audio2_t2st(processor_name: str): + from apipeline.frames import AudioRawFrame, StartFrame, EndFrame, CancelFrame, TextFrame + from achatbot.types.frames import PathAudioRawFrame + + from achatbot.cmd.bots.voice.step_audio2.helper import ( + get_step_audio2_llm, + get_step_audio2_processor, + ) + from achatbot.types.ai_conf import AIConfig, LLMConfig + + processor = get_step_audio2_processor( + LLMConfig( + processor="StepT2STProcessor", + args={ + "init_system_prompt": "", + "prompt_wav": "/root/.achatbot/assets/default_male.wav", + "warmup_cn": 2, + "chat_history_size": None, + "text_stream_out": False, + "no_stream_sleep_time": 0.5, + "lm_model_name_or_path": MODEL_PATH, + "lm_gen_max_new_tokens": 64, + "lm_gen_temperature": 0.1, + "lm_gen_top_k": 20, + "lm_gen_top_p": 0.95, + "lm_gen_repetition_penalty": 1.1, + }, + ) + ) + await processor.start(StartFrame()) + + frame_iter = processor.run_text(TextFrame(text="你好, 我是Step-Audio2,很高兴认识你。")) + audio = b"" + async for frame in frame_iter: + 
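+        # Accumulate streamed PCM from each AudioRawFrame so the whole reply can
+        # be written out as a single 24 kHz mono WAV once the iterator finishes.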
if isinstance(frame, AudioRawFrame): + audio += frame.audio + print(f"say gen_frame-->", frame) + print(f"say {processor._session.chat_history=}") + wav_path = Path(f"{ASSETS_DIR}/StepAudio2/output-processor-chunks-stream-t2st.wav") + with wave.open(str(wav_path), "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(audio) + await processor.stop(EndFrame()) + + +async def achatbot_step_audio2_audio2text(processor_name): + from apipeline.frames import AudioRawFrame, StartFrame, EndFrame, CancelFrame + from achatbot.types.frames import PathAudioRawFrame + + from achatbot.cmd.bots.voice.step_audio2.helper import ( + get_step_audio2_llm, + get_step_audio2_processor, + ) + from achatbot.types.ai_conf import AIConfig, LLMConfig + + processor = get_step_audio2_processor( + LLMConfig( + processor=processor_name, + args={ + "init_system_prompt": "", + # "prompt_wav": "/root/.achatbot/assets/default_male.wav", + "warmup_cn": 2, + "chat_history_size": None, + "text_stream_out": False, + "no_stream_sleep_time": 0.5, + "lm_model_name_or_path": MODEL_PATH, + "lm_gen_max_new_tokens": 1024, + "lm_gen_temperature": 0.1, + "lm_gen_top_k": 20, + "lm_gen_top_p": 0.9, + "lm_gen_repetition_penalty": 1.1, + "is_speaking": False, + }, + ) + ) + await processor.start(StartFrame()) + for round_idx, audio_path in enumerate( + [ + "/Step-Audio2/assets/give_me_a_brief_introduction_to_the_great_wall.wav", + "/Step-Audio2/assets/multi-turn-round1-听说荡口古镇从下个月开始取消门票了,你知道这事吗。.wav", + "/Step-Audio2/assets/multi-turn-round2-新闻说九月十九号就免费开放了。好像整个古镇都升级改造了,现在变成开放式街区了。.wav", + ] + ): + print("round: ", round_idx) + frame_iter = processor.run_voice( + PathAudioRawFrame( + path=audio_path, + audio=b"", + ) + ) + audio = b"" + async for frame in frame_iter: + if isinstance(frame, AudioRawFrame): + audio += frame.audio + print(f"{round_idx=} gen_frame-->", frame) + print(f"{round_idx=} {processor._session.chat_history=}") + if len(audio) > 0: + wav_path = Path( + f"{ASSETS_DIR}/StepAudio2/output-{processor_name}-chunks-stream-{round_idx}.wav" + ) + with wave.open(str(wav_path), "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(audio) + await processor.stop(EndFrame()) + + +async def achatbot_step_audio2_s2st(processor_name): + from apipeline.frames import AudioRawFrame, StartFrame, EndFrame, CancelFrame + from achatbot.types.frames import PathAudioRawFrame + + from achatbot.cmd.bots.voice.step_audio2.helper import ( + get_step_audio2_llm, + get_step_audio2_processor, + ) + from achatbot.types.ai_conf import AIConfig, LLMConfig + + processor_name = "StepS2STProcessor" + processor = get_step_audio2_processor( + LLMConfig( + processor=processor_name, + args={ + "init_system_prompt": "请仔细聆听这段语音,然后将其内容翻译成中文并用语音播报。", + # "init_system_prompt": "请仔细聆听这段语音,然后将其内容翻译成英文并用语音播报。", + "prompt_wav": "/root/.achatbot/assets/default_male.wav", + "warmup_cn": 2, + "chat_history_size": None, + "text_stream_out": False, + "no_stream_sleep_time": 0.5, + "lm_model_name_or_path": MODEL_PATH, + "lm_gen_max_new_tokens": 1024, + "lm_gen_temperature": 0.7, + "lm_gen_top_k": 20, + "lm_gen_top_p": 0.9, + "lm_gen_repetition_penalty": 1.1, + }, + ) + ) + await processor.start(StartFrame()) + for round_idx, audio_path in enumerate( + [ + "/Step-Audio2/assets/give_me_a_brief_introduction_to_the_great_wall.wav", + # "/Step-Audio2/assets/multi-turn-round1-听说荡口古镇从下个月开始取消门票了,你知道这事吗。.wav", + # "/Step-Audio2/assets/multi-turn-round2-新闻说九月十九号就免费开放了。好像整个古镇都升级改造了,现在变成开放式街区了。.wav", + ] + 
): + print("round: ", round_idx) + frame_iter = processor.run_voice( + PathAudioRawFrame( + path=audio_path, + audio=b"", + ) + ) + audio = b"" + async for frame in frame_iter: + if isinstance(frame, AudioRawFrame): + audio += frame.audio + print(f"{round_idx=} gen_frame-->", frame) + print(f"{round_idx=} {processor._session.chat_history=}") + if len(audio) > 0: + wav_path = Path( + f"{ASSETS_DIR}/StepAudio2/output-{processor_name}-chunks-stream-{round_idx}.wav" + ) + with wave.open(str(wav_path), "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(audio) + await processor.stop(EndFrame()) + + +async def achatbot_step_audio2_aqaa(processor_name): + from apipeline.frames import AudioRawFrame, StartFrame, EndFrame, CancelFrame + from achatbot.types.frames import PathAudioRawFrame + + from achatbot.cmd.bots.voice.step_audio2.helper import ( + get_step_audio2_llm, + get_step_audio2_processor, + ) + from achatbot.types.ai_conf import AIConfig, LLMConfig + + processor_name = "StepAudio2TextAudioChatProcessor" + processor = get_step_audio2_processor( + LLMConfig( + processor=processor_name, + args={ + "init_system_prompt": "", + "prompt_wav": "/root/.achatbot/assets/default_male.wav", + "warmup_cn": 2, + "chat_history_size": None, + "text_stream_out": False, + "no_stream_sleep_time": 0.5, + "lm_model_name_or_path": MODEL_PATH, + "lm_gen_max_new_tokens": 1024, + "lm_gen_temperature": 0.7, + "lm_gen_top_k": 20, + "lm_gen_top_p": 0.9, + "lm_gen_repetition_penalty": 1.1, + }, + ) + ) + await processor.start(StartFrame()) + for round_idx, audio_path in enumerate( + [ + "/Step-Audio2/assets/multi-turn-round1-听说荡口古镇从下个月开始取消门票了,你知道这事吗。.wav", + "/Step-Audio2/assets/multi-turn-round2-新闻说九月十九号就免费开放了。好像整个古镇都升级改造了,现在变成开放式街区了。.wav", + ] + ): + print("round: ", round_idx) + frame_iter = processor.run_voice( + PathAudioRawFrame( + path=audio_path, + audio=b"", + ) + ) + audio = b"" + async for frame in frame_iter: + if isinstance(frame, AudioRawFrame): + audio += frame.audio + print(f"{round_idx=} gen_frame-->", frame) + print(f"{round_idx=} {processor._session.chat_history=}") + if len(audio) > 0: + wav_path = Path( + f"{ASSETS_DIR}/StepAudio2/output-{processor_name}-chunks-stream-{round_idx}.wav" + ) + with wave.open(str(wav_path), "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(audio) + await processor.stop(EndFrame()) + + +async def achatbot_step_audio2_aqaa_tools(processor_name): + from apipeline.frames import AudioRawFrame, StartFrame, EndFrame, CancelFrame + from achatbot.types.frames import PathAudioRawFrame, FunctionCallFrame + + from achatbot.cmd.bots.voice.step_audio2.helper import ( + get_step_audio2_processor, + ) + from achatbot.types.ai_conf import AIConfig, LLMConfig + + processor_name = "StepAudio2TextAudioChatProcessor" + processor = get_step_audio2_processor( + LLMConfig( + processor=processor_name, + args={ + "init_system_prompt": "你的名字叫做小跃,是由阶跃星辰公司训练出来的语音大模型。\n你具备调用工具解决问题的能力,你需要根据用户的需求和上下文情景,自主选择是否调用系统提供的工具来协助用户。\n你情感细腻,观察能力强,擅长分析用户的内容,并作出善解人意的回复,说话的过程中时刻注意用户的感受,富有同理心,提供多样的情绪价值。\n今天是2025年9月12日,星期五", + "prompt_wav": "/root/.achatbot/assets/default_male.wav", + "warmup_cn": 2, + "chat_history_size": None, + "text_stream_out": False, + "no_stream_sleep_time": 0.5, + # "tools": ["web_search","get_weather"], + "tools": ["web_search"], + "lm_model_name_or_path": MODEL_PATH, + "lm_gen_max_new_tokens": 1024, + "lm_gen_temperature": 0.7, + "lm_gen_top_k": 20, + "lm_gen_top_p": 0.9, + 
"lm_gen_repetition_penalty": 1.1, + "verbose": True, + }, + ), + ) + print(f"{processor.chat_history=}") + await processor.start(StartFrame()) + for round_idx, audio_path in enumerate( + [ + "/Step-Audio2/assets/帮我查一下今天上证指数的开盘价是多少.wav", + # "/Step-Audio2/assets/multi-turn-round1-听说荡口古镇从下个月开始取消门票了,你知道这事吗。.wav", + # "/Step-Audio2/assets/multi-turn-round2-新闻说九月十九号就免费开放了。好像整个古镇都升级改造了,现在变成开放式街区了。.wav", + ] + ): + print("round: ", round_idx) + frame_iter = processor.run_voice( + PathAudioRawFrame( + path=audio_path, + audio=b"", + ) + ) + audio = b"" + tool_cn = 0 + async for frame in frame_iter: + if isinstance(frame, AudioRawFrame): + audio += frame.audio + if isinstance(frame, FunctionCallFrame): + tool_cn += 1 + print(f"{round_idx=} gen_frame-->", frame) + print(f"{round_idx=} {processor.chat_history=}") + if len(audio) > 0: + wav_path = Path( + f"{ASSETS_DIR}/StepAudio2/output-{processor_name}-tools-chunks-stream-{round_idx}.wav" + ) + with wave.open(str(wav_path), "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(audio) + await processor.stop(EndFrame()) + + +async def achatbot_step_audio2_processor(gpu_prop, **kwargs): + from achatbot.common.logger import Logger + + Logger.init(os.getenv("LOG_LEVEL", "info").upper(), is_file=False, is_console=True) + + test_func = kwargs.get("test_func", "achatbot_step_audio2_aqaa") + processor_name = kwargs.get("processor_name") or "StepASRProcessor" + await globals()[test_func](processor_name) + + +""" +modal run src/download_models.py --repo-ids "stepfun-ai/Step-Audio-2-mini" +modal run src/download_models.py --repo-ids "stepfun-ai/Step-Audio-2-mini-Base" + +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task dump_model +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task tokenize +LLM_MODEL=stepfun-ai/Step-Audio-2-mini-Base IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task tokenize + +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task test_base --test-func asr_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task test_base --test-func audio_caption_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task test_base --test-func tts_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task test_base --test-func s2st_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task test_base --test-func t2st_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task test_base --test-func multi_turn_aqta_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task test_base --test-func multi_turn_aqaa_test + +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task test_instruct --test-func instruct_asr_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task test_instruct --test-func instruct_audio_caption_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task test_instruct --test-func instruct_s2tt_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task test_instruct --test-func instruct_s2st_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task test_instruct --test-func instruct_multi_turn_tqta_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task test_instruct --test-func instruct_multi_turn_tqaa_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task test_instruct --test-func instruct_multi_turn_aqta_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py 
--task test_instruct --test-func instruct_multi_turn_aqaa_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task test_instruct --test-func instruct_tool_call_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task test_instruct --test-func instruct_paralinguistic_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task test_instruct --test-func instruct_mmau_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task test_instruct --test-func instruct_mmau_audio_answer_test + +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task generate_stream --test-func stream_asr_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task generate_stream --test-func stream_tts_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task generate_stream --test-func stream_aqaa_test +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task generate_stream --test-func stream_aqaa_tools_test + +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task achatbot_step_audio2_processor --test-func=achatbot_step_audio2_audio2text --processor-name=StepASRProcessor +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task achatbot_step_audio2_processor --test-func=achatbot_step_audio2_audio2text --processor-name=StepAudioCaptionProcessor +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task achatbot_step_audio2_processor --test-func=achatbot_step_audio2_audio2text --processor-name=StepS2TTProcessor +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task achatbot_step_audio2_processor --test-func=achatbot_step_audio2_say +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task achatbot_step_audio2_processor --test-func=achatbot_step_audio2_t2st +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task achatbot_step_audio2_processor --test-func=achatbot_step_audio2_s2st +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task achatbot_step_audio2_processor --test-func=achatbot_step_audio2_aqaa +IMAGE_GPU=L4 modal run src/llm/transformers/step_audio2.py --task achatbot_step_audio2_processor --test-func=achatbot_step_audio2_aqaa_tools +""" + + +@app.local_entrypoint() +def main(task: str = "dump_model", test_func="", processor_name=""): + tasks = { + "dump_model": dump_model, + "tokenize": tokenize, + "test_base": test_base, + "test_instruct": test_instruct, + "generate_stream": generate_stream, + "achatbot_step_audio2_processor": achatbot_step_audio2_processor, + } + if task not in tasks: + raise ValueError(f"task {task} not found") + print(f"running task {task}") + run.remote( + tasks[task], + test_func=test_func, + processor_name=processor_name, + ) diff --git a/deps/StepAudio2 b/deps/StepAudio2 new file mode 160000 index 00000000..709e5a43 --- /dev/null +++ b/deps/StepAudio2 @@ -0,0 +1 @@ +Subproject commit 709e5a438e51a880613e43241ee728632b379fba diff --git a/pyproject.toml b/pyproject.toml index 4132c1f0..29bfbb76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,13 +10,30 @@ # pip-compile --all-extras pyproject.toml [build-system] -requires = ["setuptools>=61.0", "setuptools-scm>=8.0"] +requires = ["setuptools>=61.2", "setuptools-scm>=8.0", "wheel"] build-backend = "setuptools.build_meta" +[tool.setuptools] +# By default, include-package-data is true in pyproject.toml, so you do +# NOT have to specify this line. Either remove this or ensure it's not +# set to false. 
+include-package-data = true + +# if use library, need add achatbot dir in pypi_build/app dir, change import codes +[tool.setuptools.packages.find] +# !NOTE: packages find .py file, other file don't to exclude +# All the following settings are optional: +where = ["pypi_build/app"] +#include = ["deps*", "src*"] +#exclude = ["ui*","docs*"] + +[tool.setuptools_scm] +local_scheme = "no-local-version" + [project] name = "achatbot" #dynamic = ["version"] -version = "0.0.24.post5" +version = "0.0.25" authors = [{ name = "weedge", email = "weege007@gmail.com" }] maintainers = [{ name = "weedge", email = "weege007@gmail.com" }] description = "An open source chat bot for voice (and multimodal) assistants" @@ -797,22 +814,12 @@ remote_grpc_tts_server = ["achatbot[grpc,speech_tts]"] test = ["sentence_transformers~=3.0.0", "pytest~=8.3.2", "pytest-mock~=3.14.0"] -# if use library, need add achatbot dir in pypi_build/app dir, change import codes -[tool.setuptools.packages.find] -# !NOTE: packages find .py file, other file don't to exclude -# All the following settings are optional: -where = ["pypi_build/app"] -#include = ["deps", "src", "tests"] -exclude = [] - - [tool.pytest.ini_options] pythonpath = ["tests"] #include = ["tests"] -[tool.setuptools_scm] -local_scheme = "no-local-version" +#------------------- linter ------------------------- # https://docs.astral.sh/ruff/configuration/ [tool.ruff] diff --git a/scripts/pypi_achatbot.sh b/scripts/pypi_achatbot.sh index d8487179..2bca5971 100644 --- a/scripts/pypi_achatbot.sh +++ b/scripts/pypi_achatbot.sh @@ -6,14 +6,16 @@ set -e pypi= if [ $# -ge 1 ] && [ "$1" == "prod" ]; then pypi=pypi + echo "upload to $pypi" fi if [ $# -ge 1 ] && [ "$1" == "test" ]; then pypi=testpypi + echo "upload to $pypi" fi if [ $# -ge 1 ] && [ "$1" == "dev" ]; then pypi=devpypi + echo "install to local package dev" fi -echo "upload to $pypi" rm -rf pypi_build/app/achatbot mkdir -p pypi_build/app/achatbot @@ -23,9 +25,19 @@ rm -f deps/CosyVoice/third_party/Matcha-TTS/data rm -f deps/GLM4Voice/third_party/Matcha-TTS/data rm -f deps/KimiAudio/kimia_infer/models/tokenizer/glm4/third_party/Matcha-TTS/data rm -f deps/VITAAudio/third_party/GLM-4-Voice/third_party/Matcha-TTS/data -cp -r deps/* pypi_build/app/achatbot/ +echo "copy deps ..." +#cp -r deps/* pypi_build/app/achatbot/ +find deps -type f -name "*.py" -exec sh -c ' + dest="pypi_build/app/achatbot/$(echo {} | sed "s|^deps/||")" + mkdir -p "$(dirname "$dest")" + cp {} "$dest" +' \; + + + +echo "replace ..." find pypi_build/app/achatbot/ | grep -E "(/__pycache__$|\.pyc$|\.pyo$)" | xargs rm -rf find pypi_build/app/achatbot/ -type f -print0 | xargs -0 perl -i -pe \ @@ -33,13 +45,17 @@ find pypi_build/app/achatbot/ -type f -print0 | xargs -0 perl -i -pe \ if [ -n "$pypi" ]; then + echo "build ..." pip install -q build rm -rf dist && python3 -m build if [ "$pypi" == "devpypi" ]; then + echo "install ..." pip install -U dist/*.whl else + echo "upload ..." twine upload --verbose --skip-existing --repository $pypi dist/* fi + echo "clear ..." 
rm -rf pypi_build/app/achatbot fi diff --git a/src/cmd/bots/__init__.py b/src/cmd/bots/__init__.py index e3ecaa8f..0fbabcbb 100644 --- a/src/cmd/bots/__init__.py +++ b/src/cmd/bots/__init__.py @@ -166,6 +166,10 @@ def import_bots(bot_name: str = "DummyBot"): if "DailyASRTranslateTTSBot" in bot_name: from .translation import daily_asr_translate_tts_bot + return True + if "DailyStepAudio2AQAABot" in bot_name: + from .voice.step_audio2 import daily_aqaa_bot + return True if "LivekitBot" in bot_name: from . import livekit_bot diff --git a/src/cmd/bots/base.py b/src/cmd/bots/base.py index 82d4fc82..0cdbcd4e 100644 --- a/src/cmd/bots/base.py +++ b/src/cmd/bots/base.py @@ -78,6 +78,7 @@ def __init__(self, **args) -> None: self._bot_config = self.args.bot_config self._handle_sigint = self.args.handle_sigint + def init_bot_config(self): try: logging.debug(f"args.bot_config: {self.args.bot_config}") diff --git a/src/cmd/bots/voice/step_audio2/daily_aqaa_bot.py b/src/cmd/bots/voice/step_audio2/daily_aqaa_bot.py new file mode 100644 index 00000000..7680026b --- /dev/null +++ b/src/cmd/bots/voice/step_audio2/daily_aqaa_bot.py @@ -0,0 +1,113 @@ +import logging + +from dotenv import load_dotenv +from apipeline.pipeline.pipeline import Pipeline +from apipeline.pipeline.task import PipelineParams, PipelineTask +from apipeline.pipeline.runner import PipelineRunner +from apipeline.processors.logger import FrameLogger +from apipeline.frames import AudioRawFrame, TextFrame + +from src.processors.speech.audio_save_processor import AudioSaveProcessor +from src.processors.aggregators.user_audio_response import UserAudioResponseAggregator +from src.cmd.bots.base_daily import DailyRoomBot +from src.common.types import DailyParams +from src.transports.daily import DailyTransport +from src.cmd.bots import register_ai_room_bots +from src.types.frames import PathAudioRawFrame, LLMGenedTokensFrame, BotSpeakingFrame +from .helper import get_step_audio2_processor, get_step_audio2_llm + + +load_dotenv(override=True) + + +@register_ai_room_bots.register +class DailyStepAudio2AQAABot(DailyRoomBot): + """ + - use daily audio stream(bytes) --> Step2 voice processor --> text/audio_bytes + """ + + def __init__(self, **args) -> None: + super().__init__(**args) + self.init_bot_config() + + self.vad_analyzer = None + self.audio_llm = None + + def load(self): + self.vad_analyzer = self.get_vad_analyzer() + self.audio_llm = get_step_audio2_llm(self.bot_config()) + + async def arun(self): + assert self.vad_analyzer is not None + assert self.audio_llm is not None + + self.params = DailyParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_enabled=True, + vad_analyzer=self.vad_analyzer, + vad_audio_passthrough=True, + ) + + # src/processors/voice/step_audio2_processor.py + self._voice_processor = get_step_audio2_processor( + bot_config=self._bot_config.voice_llm, + session=self.session, + audio_llm=self.audio_llm, + ) + if hasattr(self._voice_processor, "stream_info"): + stream_info = self._voice_processor.stream_info + self.params.audio_out_sample_rate = stream_info["sample_rate"] + self.params.audio_out_channels = stream_info["channels"] + logging.info(f"params: {self.params}") + + transport = DailyTransport( + self.args.room_url, + self.args.token, + self.args.bot_name, + self.params, + ) + + self.task = PipelineTask( + Pipeline( + [ + transport.input_processor(), + UserAudioResponseAggregator(), + FrameLogger(include_frame_types=[AudioRawFrame]), + # AudioSaveProcessor(prefix_name="user_audio_aggr"), + # 
FrameLogger(include_frame_types=[PathAudioRawFrame]), + self._voice_processor, + FrameLogger( + include_frame_types=[TextFrame, AudioRawFrame, LLMGenedTokensFrame] + ), + AudioSaveProcessor(prefix_name="bot_speak"), + # FrameLogger(include_frame_types=[BotSpeakingFrame]), + FrameLogger(include_frame_types=[AudioRawFrame]), + transport.output_processor(), # BotSpeakingFrame + ] + ), + params=PipelineParams( + allow_interruptions=False, + enable_metrics=False, + send_initial_empty_metrics=False, + ), + ) + + transport.add_event_handlers( + "on_first_participant_joined", + [self.on_first_participant_joined, self.on_first_participant_say_hi], + ) + transport.add_event_handler("on_participant_left", self.on_participant_left) + transport.add_event_handler("on_call_state_updated", self.on_call_state_updated) + + await PipelineRunner().run(self.task) + + async def on_first_participant_say_hi(self, transport: DailyTransport, participant): + await self._voice_processor.say( + "你好。我是一名助手,欢迎语音聊天!", + temperature=0.1, + max_new_tokens=1024, + top_k=20, + top_p=0.95, + repetition_penalty=1.1, + ) diff --git a/src/cmd/bots/voice/step_audio2/helper.py b/src/cmd/bots/voice/step_audio2/helper.py new file mode 100644 index 00000000..f25bf561 --- /dev/null +++ b/src/cmd/bots/voice/step_audio2/helper.py @@ -0,0 +1,69 @@ +import os +import importlib + +from src.common.interface import ILlm +from src.common.session import Session +from src.types.ai_conf import AIConfig, LLMConfig, BaseConfig +from src.common.types import MODELS_DIR +from src.processors.voice.step_audio2_processor import Token2wav, StepAudio2BaseProcessor + + +def get_step_audio2_llm(llm_config: BaseConfig): + from src.core.llm.transformers.manual_voice_step2 import TransformersManualVoiceStep2 + + lm_model_name_or_path = os.path.join(MODELS_DIR, "stepfun-ai/Step-Audio-2-mini") + args = llm_config.args if llm_config.args else {} + if args.get("lm_model_name_or_path", None) is None: + args["lm_model_name_or_path"] = lm_model_name_or_path + return TransformersManualVoiceStep2(**args) + + +def get_step_audio2_processor( + llm_config: BaseConfig, + session: Session | None = None, + token2wav: Token2wav | None = None, + audio_llm: ILlm | None = None, + processor_class_name: str | None = None, +) -> StepAudio2BaseProcessor: + if processor_class_name is None and hasattr(llm_config, "processor"): + processor_class_name = llm_config.processor + try: + if bool(os.getenv("ACHATBOT_PKG", "")): + module = importlib.import_module("achatbot.processors.voice.step_audio2_processor") + else: + module = importlib.import_module("src.processors.voice.step_audio2_processor") + processor_class = getattr(module, processor_class_name) + return processor_class( + session=session, + token2wav=token2wav, + audio_llm=audio_llm or get_step_audio2_llm(llm_config), + **llm_config.args, + ) + except (ImportError, AttributeError) as e: + raise ValueError(f"cannot import {processor_class_name}: {str(e)}") + + +""" +python -m src.cmd.bots.voice.step_audio2.helper +ACHATBOT_PKG=1 python -m src.cmd.bots.voice.step_audio2.helper +""" +if __name__ == "__main__": + get_step_audio2_processor( + LLMConfig( + processor="StepAudio2TextAudioChatProcessor", + args={ + "init_system_prompt": "", + "prompt_wav": "/root/.achatbot/assets/default_male.wav", + "warmup_cn": 2, + "chat_history_size": None, + "text_stream_out": False, + "no_stream_sleep_time": 0.5, + "lm_model_name_or_path": "stepfun-ai/Step-Audio-2-mini", + "lm_gen_max_new_tokens": 64, + "lm_gen_temperature": 0.1, + "lm_gen_top_k": 
20, + "lm_gen_top_p": 0.95, + "repetition_penalty": 1.1, + }, + ) + ) diff --git a/src/common/chat_history.py b/src/common/chat_history.py index 03dce4ae..a495d266 100644 --- a/src/common/chat_history.py +++ b/src/common/chat_history.py @@ -1,22 +1,25 @@ class ChatHistory: """ - buffer the local chat hostory with limit size using to avoid OOM issues. + buffer the local chat hostory with limit size using to avoid llm context too long. - if size is None, no limit - if size < 0, no history !TODO: use kv store history like mem0. @weedge """ - def __init__(self, size: int | None = None): + def __init__( + self, size: int | None = None, init_chat_message: dict = None, init_chat_tools: dict = None + ): self.size = size - self.init_chat_message = None + self.init_chat_message = init_chat_message + self.init_chat_tools = init_chat_tools # maxlen is necessary pair, # since a each new step we add an prompt and assitant answer self.buffer = [] def clear(self): self.buffer.clear() - + def append(self, item): if self.size and self.size < 0: return @@ -29,11 +32,41 @@ def append(self, item): self.buffer.pop(0) self.buffer.pop(0) + def pop(self, index: int = -1): + if self.size and self.size < 0: + return + if len(self.buffer) > 0: + self.buffer.pop(index) + def init(self, init_chat_message: dict): self.init_chat_message = init_chat_message + def init_tools(self, tools: dict): + self.init_chat_tools = tools + def to_list(self) -> list: if self.init_chat_message: - return [self.init_chat_message] + self.buffer + if self.init_chat_tools: + return [self.init_chat_message, self.init_chat_tools] + self.buffer + else: + return [self.init_chat_message] + self.buffer else: return self.buffer + + def __getstate__(self): + return { + "size": self.size, + "init_chat_message": self.init_chat_message, + "init_chat_tools": self.init_chat_tools, + "buffer": self.buffer, + } + + def __setstate__(self, state): + self.size = state["size"] + self.init_chat_message = state["init_chat_message"] + self.init_chat_tools = state["init_chat_tools"] + self.buffer = state["buffer"] + + def __repr__(self) -> str: + chat_history = self.__getstate__() + return f"{chat_history=}" diff --git a/src/common/register.py b/src/common/register.py index d3bd17d4..2231726f 100644 --- a/src/common/register.py +++ b/src/common/register.py @@ -46,3 +46,25 @@ def items(self): def dict(self): return self._dict + + +""" +python -m src.common.register +""" + +if __name__ == "__main__": + functions = Register("llm_function_calling") + + @functions.register("test") + def test(): + print("test") + + @functions.register + def test1(): + print("test") + + @functions.register + def test2(): + print("test") + + print(functions.items()) diff --git a/src/common/session.py b/src/common/session.py index dfff8e0c..81ae7e69 100644 --- a/src/common/session.py +++ b/src/common/session.py @@ -1,13 +1,21 @@ from .types import SessionCtx +from .chat_history import ChatHistory class Session: def __init__(self, **args) -> None: + chat_history_size = args.pop("chat_history_size", None) self.ctx = SessionCtx(**args) self.config = {} self.chat_round = 0 - # just for local history,@todo: use kv store history like mem0 - self.chat_history = [] + self.chat_history = ChatHistory(size=chat_history_size) + + def init_chat_message(self, init_chat_message: dict): + self.chat_history.init(init_chat_message) + + def reset(self): + self.chat_round = 0 + self.chat_history.clear() def __getstate__(self): return { @@ -24,14 +32,13 @@ def __setstate__(self, state): self.ctx = state["ctx"] 
def __repr__(self) -> str: - d = { + session = { "config": self.config, "chat_round": self.chat_round, "chat_history": self.chat_history, "ctx": self.ctx, } - s = f"session: {d}" - return s + return f"{session=}" def set_client_id(self, client_id): self.ctx.client_id = client_id @@ -78,3 +85,70 @@ def close(self): self.ctx.llm.close() if hasattr(self.ctx.tts, "close"): self.ctx.tts.close() + + +"""" +python -m src.common.session +""" + + +def test_chat_history(): + chat_history_size = 3 + session = Session( + chat_history_size=chat_history_size, + client_id="test_client", + asr=None, + llm=None, + tts=None, + vad=None, + waker=None, + buffering_strategy=None, + on_session_start=None, + on_session_end=None, + ) + print("init", session) + + for i in range(10): + session.chat_history.append({"role": "user", "content": f"Hello {i}"}) + session.chat_history.append( + {"role": "assistant", "content": f"Hi, how can I help you {i}?"} + ) + session.increment_chat_round() + print(f"{i=} chat", session) + if i < chat_history_size: + assert len(session.chat_history.to_list()) == (i + 1) * 2 + else: + assert len(session.chat_history.to_list()) == chat_history_size * 2 + + session.reset() + print("reset", session) + assert len(session.chat_history.to_list()) == 0 + + print(f"\ntest_chat_history pass\n\n") + + +def test_pickle(): + import pickle + + chat_history_size = 3 + session = Session( + chat_history_size=chat_history_size, + **SessionCtx(client_id="test_client").__dict__, + ) + print("init", session) + + pickle_data = pickle.dumps(session) + print("session dump", pickle_data) + + load_session = pickle.loads(pickle_data) + print("session load", load_session) + assert str(session) == str(load_session) + print(f"\ntest_pickle pass\n\n") + + +""" +python -m src.common.session +""" +if __name__ == "__main__": + test_pickle() + test_chat_history() diff --git a/src/common/utils/audio_utils.py b/src/common/utils/audio_utils.py index 897b7c74..57d957f1 100644 --- a/src/common/utils/audio_utils.py +++ b/src/common/utils/audio_utils.py @@ -45,7 +45,7 @@ def bytes2TorchTensorWith16(frames: bytes | bytearray): if waveform_tensor.ndim == 1: # float_data= float_data.reshape(1, -1) waveform_tensor = waveform_tensor.reshape(1, -1) - return waveform_tensor + return waveform_tensor # (1, size(time)) def npArray2bytes(np_arr: np.ndarray) -> bytearray: diff --git a/src/core/llm/__init__.py b/src/core/llm/__init__.py index bf7d53ba..7be2a447 100644 --- a/src/core/llm/__init__.py +++ b/src/core/llm/__init__.py @@ -87,6 +87,8 @@ def getEngine(tag, **kwargs) -> interface.ILlmGenerator | interface.ILlm | Engin from .transformers import manual_vision_ernie4v elif "llm_transformers_manual_vision_skyworkr1v" in tag: from .transformers import manual_vision_skyworkr1v + elif "llm_transformers_manual_voice_step2" in tag: + from .transformers import manual_voice_step2 elif "llm_transformers_manual" == tag: from .transformers import manual elif "llm_transformers_pipeline" == tag: @@ -524,6 +526,7 @@ def get_vita_audio_transformers_args() -> dict: "llm_transformers_manual_qwen2_5omni_vision_voice": get_qwen2_5omni_transformers_args, "llm_transformers_manual_qwen2_5omni_text_voice": get_qwen2_5omni_transformers_args, "llm_transformers_manual_qwen2_5omni_audio_voice": get_qwen2_5omni_transformers_args, + "llm_transformers_manual_voice_step2": get_llm_transformers_args, "llm_transformers_generator": get_llm_transformers_args, "llm_llamacpp_generator": get_llm_llamacpp_generator_args, "llm_vllm_generator": get_llm_vllm_generator_args, 
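The refactored ChatHistory/Session pair in this changeset composes the model context as [init_chat_message, init_chat_tools] followed by a size-bounded sliding window of turns. A minimal usage sketch of that flow, assuming the diff is applied and the snippet is run from the repo root; the message contents and the client id are illustrative, and the windowing behaviour follows the test_chat_history case added in src/common/session.py:

# Sketch of the new ChatHistory/Session flow (illustrative messages only).
from src.common.session import Session
from src.common.types import SessionCtx

session = Session(chat_history_size=2, **SessionCtx(client_id="demo").__dict__)
session.chat_history.init({"role": "system", "content": "You are a helpful assistant."})
session.chat_history.init_tools({"role": "tool_json_schemas", "content": "[]"})  # illustrative schema payload

for i in range(5):
    session.chat_history.append({"role": "human", "content": f"question {i}"})
    session.chat_history.append({"role": "assistant", "content": f"answer {i}"})

# to_list() -> [system message, tool schemas] + the last chat_history_size (2) human/assistant pairs
print(session.chat_history.to_list())
# pop(-1) drops the newest entry, e.g. a placeholder assistant turn before it is re-appended
session.chat_history.pop(-1)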
diff --git a/src/core/llm/transformers/manual_voice_step2.py b/src/core/llm/transformers/manual_voice_step2.py new file mode 100644 index 00000000..0cab390b --- /dev/null +++ b/src/core/llm/transformers/manual_voice_step2.py @@ -0,0 +1,86 @@ +import time +import logging + +import torch + +from src.common.session import Session +from src.types.llm.transformers import TransformersLMArgs +from .base import TransformersBaseLLM +from .models.step_audio2 import StepAudio2Stream + + +class TransformersManualVoiceStep2(TransformersBaseLLM): + """ + https://huggingface.co/stepfun-ai/Step-Audio-2-mini + """ + + TAG = "llm_transformers_manual_voice_step2" + RATE = 24000 + + def __init__(self, **args): + self.args = TransformersLMArgs() + self.args.update(**args) + logging.info(f"args: {self.args}") + self._audio_llm = StepAudio2Stream(model_path=self.args.lm_model_name_or_path) + self.eos_token_id = [ + self._audio_llm.eos_token_id, + self._audio_llm.llm_tokenizer.convert_tokens_to_ids("<|endoftext|>"), + ] + self.warmup() + + @property + def llm(self): + return self._audio_llm + + @property + def llm_tokenizer(self): + return self._audio_llm.llm_tokenizer + + @torch.inference_mode() + def warmup(self): + if self.args.warmup_steps < 1: + return + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "human", "content": self.args.warmup_prompt}, + {"role": "assistant", "content": None}, + ] + for step in range(self.args.warmup_steps): + token_iter = self._audio_llm( + messages, + max_new_tokens=128, + temperature=0.1, + do_sample=True, + eos_token_id=self.eos_token_id, + ) + first = True + start = time.time() + for _ in token_iter: + if first: + first = False + ttft = time.time() - start + total_time = time.time() - start + logging.info(f"warmup {step=} {ttft=:.3f}s {total_time=:.3f}s") + + @torch.inference_mode() + def generate(self, session: Session, **kwargs): + kwargs["max_new_tokens"] = kwargs.get("max_new_tokens", self.args.lm_gen_max_new_tokens) + kwargs["top_k"] = kwargs.get("top_k", self.args.lm_gen_top_k) + kwargs["top_p"] = kwargs.get("top_p", self.args.lm_gen_top_p) + kwargs["do_sample"] = ( + True if kwargs.get("temperature", self.args.lm_gen_temperature) > 0.0 else False + ) + kwargs["temperature"] = kwargs.get("temperature", self.args.lm_gen_temperature) + kwargs["repetition_penalty"] = kwargs.get( + "repetition_penalty", self.args.lm_gen_repetition_penalty + ) + kwargs["min_new_tokens"] = kwargs.get("min_new_tokens", self.args.lm_gen_min_new_tokens) + stop_ids = kwargs.pop("stop_ids", self.args.lm_gen_stop_ids) + for token_id in self._audio_llm( + messages=session.ctx.state["messages"], + **kwargs, + ): + if token_id in stop_ids: + break + yield token_id diff --git a/src/core/llm/transformers/models/step_audio2.py b/src/core/llm/transformers/models/step_audio2.py new file mode 100644 index 00000000..296aad35 --- /dev/null +++ b/src/core/llm/transformers/models/step_audio2.py @@ -0,0 +1,183 @@ +import os +import sys +import logging +from threading import Thread +from typing import BinaryIO + +try: + import torch + import numpy + + from transformers import GenerationConfig + + cur_dir = os.path.dirname(__file__) + if bool(os.getenv("ACHATBOT_PKG", "")): + sys.path.insert(1, os.path.join(cur_dir, "../../../../StepAudio2")) + else: + sys.path.insert(1, os.path.join(cur_dir, "../../../../../deps/StepAudio2")) + + from deps.StepAudio2.stepaudio2 import StepAudio2, StepAudio2Base + from deps.StepAudio2.utils import ( + compute_token_num, + load_audio, + 
log_mel_spectrogram, + padding_mels, + ) + + from src.core.llm.transformers.streamer import TokenStreamer + from src.common.utils.helper import get_device + +except ModuleNotFoundError as e: + raise Exception(f"Missing module: {e}") + + +class StepAudio2StreamBase(StepAudio2Base): + def apply_chat_template(self, messages: list): + """ + add np.ndarray/torch.Tensor audio msg support + - audio sample rate: 16000 + """ + results = [] + mels = [] + for msg in messages: + content = msg + if isinstance(content, str): + text_with_audio = content + results.append(text_with_audio) + elif isinstance(content, dict): + if content["type"] == "text": + results.append(f"{content['text']}") + elif content["type"] == "audio": + audio = content["audio"] + if isinstance(audio, (BinaryIO, str, os.PathLike)): + audio = load_audio(audio) + elif isinstance(audio, numpy.ndarray): + audio = torch.from_numpy(audio, dtype=torch.float32) + assert isinstance(audio, torch.Tensor), f"Unsupported audio type: {type(audio)}" + if len(audio.shape) > 1: # [1, size] + audio = audio.squeeze(0) # [size] + for i in range(0, audio.shape[0], 16000 * 25): + mel = log_mel_spectrogram( + audio[i : i + 16000 * 25], n_mels=128, padding=479 + ) + mels.append(mel) + audio_tokens = "" * compute_token_num(mel.shape[1]) + results.append(f"{audio_tokens}") + elif content["type"] == "token": + results.append(content["token"]) + else: + raise ValueError(f"Unsupported content type: {type(content)}") + return results, mels + + def __call__(self, messages: list, **kwargs): + messages, mels = self.apply_chat_template(messages) + logging.debug(f"messages: {messages}") + + # Tokenize prompts + prompt_ids = [] + for msg in messages: + if isinstance(msg, str): + prompt_ids.append( + self.llm_tokenizer(text=msg, return_tensors="pt", padding=True)["input_ids"] + ) + elif isinstance(msg, list): + prompt_ids.append(torch.tensor([msg], dtype=torch.int32)) + else: + raise ValueError(f"Unsupported content type: {type(msg)}") + prompt_ids = torch.cat(prompt_ids, dim=-1).cuda() + attention_mask = torch.ones_like(prompt_ids) + + # mels = None if len(mels) == 0 else torch.stack(mels).cuda() + # mel_lengths = None if mels is None else torch.tensor([mel.shape[1] - 2 for mel in mels], dtype=torch.int32, device='cuda') + if len(mels) == 0: + mels = None + mel_lengths = None + else: + mels, mel_lengths = padding_mels(mels) + mels = mels.cuda() + mel_lengths = mel_lengths.cuda() + + generation_config = dict( + # max_new_tokens=256, + pad_token_id=self.llm_tokenizer.pad_token_id, + eos_token_id=self.eos_token_id, + ) + generation_config.update(kwargs) + generation_config = GenerationConfig(**generation_config) + logging.debug(f"generation_config: {generation_config}") + + streamer = TokenStreamer(skip_prompt=True) + + generation_kwargs = dict( + input_ids=prompt_ids, + wavs=mels, + wav_lens=mel_lengths, + attention_mask=attention_mask, + generation_config=generation_config, + streamer=streamer, + ) + + thread = Thread(target=self.llm.generate, kwargs=generation_kwargs) + thread.start() + + stop_ids = ( + [generation_config.eos_token_id] + if isinstance(generation_config.eos_token_id, int) + else generation_config.eos_token_id + ) + for token_id in streamer: + if token_id in stop_ids: + break + yield token_id + + +class StepAudio2Stream(StepAudio2StreamBase): + def apply_chat_template(self, messages: list): + """ + add np.ndarray/torch.Tensor audio msg support + - audio sample rate: 16000 + """ + results = [] + mels = [] + for msg in messages: + role = msg["role"] + 
content = msg["content"] + if role == "user": + role = "human" + if isinstance(content, str): + text_with_audio = f"<|BOT|>{role}\n{content}" + text_with_audio += "<|EOT|>" if msg.get("eot", True) else "" + results.append(text_with_audio) + elif isinstance(content, list): + results.append(f"<|BOT|>{role}\n") + for item in content: + if item["type"] == "text": + results.append(f"{item['text']}") + elif item["type"] == "audio": + audio = item["audio"] + if isinstance(audio, (BinaryIO, str, os.PathLike)): + audio = load_audio(audio) + elif isinstance(audio, numpy.ndarray): + audio = torch.from_numpy(audio, dtype=torch.float32) + assert isinstance(audio, torch.Tensor), ( + f"Unsupported audio type: {type(audio)}" + ) + if len(audio.shape) > 1: # [1, size] + audio = audio.squeeze(0) # [size] + for i in range(0, audio.shape[0], 16000 * 25): + mel = log_mel_spectrogram( + audio[i : i + 16000 * 25], n_mels=128, padding=479 + ) + mels.append(mel) + audio_tokens = "" * compute_token_num(mel.shape[1]) + results.append(f"{audio_tokens}") + elif item["type"] == "token": + results.append(item["token"]) + if msg.get("eot", True): + results.append("<|EOT|>") + elif content is None: + results.append(f"<|BOT|>{role}\n") + else: + raise ValueError(f"Unsupported content type: {type(content)}") + # print(results) + return results, mels diff --git a/src/modules/functions/function.py b/src/modules/functions/function.py index b1a74390..0adf24fc 100644 --- a/src/modules/functions/function.py +++ b/src/modules/functions/function.py @@ -8,7 +8,7 @@ class FunctionManager: FunctionManager just as a namespace to use static methods. """ - functions = Register("llm_function_calling") + functions = Register("llm_function_callings") def __init__(self): raise RuntimeError("FunctionManager is not intended to be instantiated") @@ -20,7 +20,32 @@ def get_tool_calls() -> list[dict]: tool_calls.append(func_cls.get_tool_call()) return tool_calls + @staticmethod + def get_tool_calls_by_names(names: list[str]) -> list[dict]: + tool_calls = [] + func_list = FunctionManager.functions.items() + for name, func_cls in func_list: + if name in names: + tool_calls.append(func_cls.get_tool_call()) + return tool_calls + @staticmethod def execute(name: str, session, **args): func_cls = FunctionManager.functions[name] return func_cls.execute(session, **args) + + +""" +python -m src.modules.functions.function +""" +if __name__ == "__main__": + import src.modules.functions.search.api + import src.modules.functions.weather.api + from src.modules.functions.function import FunctionManager + + items = FunctionManager.functions.items() + print(items) + tool_calls = FunctionManager.get_tool_calls_by_names( + ["web_search", "web_search2", "get_weather"] + ) + print(tool_calls) diff --git a/src/processors/voice/helper.py b/src/processors/voice/helper.py new file mode 100644 index 00000000..a1e031e8 --- /dev/null +++ b/src/processors/voice/helper.py @@ -0,0 +1,33 @@ +import json + + +def extract_function_info(tool_calls_token: str) -> tuple: + """ + 从 tool_calls_token 字符串中提取 function_name 和 function_args + + 参数格式示例: + 'function\nweb_search\n{"query": "2025年8月28日 上证指数 开盘价"}' + + 返回: (function_name, function_args_dict) + """ + # 按换行符分割字符串 + parts = tool_calls_token.split("\n") + + # 验证格式是否正确 + if len(parts) < 3 or parts[0] != "function": + raise ValueError("无效的 tool_calls_token 格式") + + # 提取函数名 + function_name = parts[1] + + try: + # 合并剩余部分作为 JSON 字符串(处理可能的多行 JSON) + json_str = "\n".join(parts[2:]) + function_args = json.loads(json_str) + except 
json.JSONDecodeError: + raise ValueError("无法解析 function_args JSON") + + return function_name, function_args + + +# TODO: MCP diff --git a/src/processors/voice/step_audio2_processor.py b/src/processors/voice/step_audio2_processor.py new file mode 100644 index 00000000..c3ff76cc --- /dev/null +++ b/src/processors/voice/step_audio2_processor.py @@ -0,0 +1,789 @@ +import os +import sys +import time +import json +import queue +import asyncio +import logging +import threading +from typing import AsyncGenerator + +import uuid +import torch +from apipeline.frames import * + +try: + cur_dir = os.path.dirname(__file__) + if bool(os.getenv("ACHATBOT_PKG", "")): + sys.path.insert(1, os.path.join(cur_dir, "../../StepAudio2")) + else: + sys.path.insert(1, os.path.join(cur_dir, "../../../deps/StepAudio2")) + + from deps.StepAudio2.token2wav import Token2wav +except ModuleNotFoundError as e: + raise Exception(f"Missing module: {e}") + +from src.common.interface import ILlm +from src.common.types import ASSETS_DIR +from src.processors.voice.base import VoiceProcessorBase +from src.common.session import Session +from src.common.types import SessionCtx +from src.common.utils.audio_utils import bytes2TorchTensorWith16 +from src.types.frames import ( + TextQuestionsAudioRawFrame, + PathAudioRawFrame, + LLMGenedTokensFrame, + FunctionCallFrame, +) +import src.modules.functions.search.api +import src.modules.functions.weather.api +from src.modules.functions.function import FunctionManager +from .helper import extract_function_info + + +class StepAudio2BaseProcessor(VoiceProcessorBase): + """ """ + + CHUNK_SIZE = 25 + SYS_PROMPT = "You are a helpful assistant." + + def __init__( + self, + *, + init_system_prompt: str = "", + text_stream_out: bool = False, + prompt_wav: str = "", + warmup_cn: int = 1, + chat_history_size: int | None = None, + no_stream_sleep_time: float = 0.5, + chunk_size: int = 0, + session: Session | None = None, + audio_llm: ILlm | None = None, + token2wav: Token2wav | None = None, + is_speaking: bool = True, + tools: list = [], + verbose: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + assert audio_llm is not None, "audio_llm is None" + + from src.core.llm.transformers.manual_voice_step2 import TransformersManualVoiceStep2 + + assert isinstance(audio_llm, TransformersManualVoiceStep2), ( + "audio_llm is not TransformersManualVoiceStep2" + ) + + self._audio_llm = audio_llm + self._is_speaking = is_speaking + if is_speaking is True: + token2wav_path = os.path.join(audio_llm.args.lm_model_name_or_path, "token2wav") + self._token2wav = token2wav or Token2wav(token2wav_path) + if torch.cuda.is_available(): + self._token2wav.flow.scatter_cuda_graph(True) + + self._prompt_wav = prompt_wav or os.path.join(ASSETS_DIR, "default_female.wav") + self._token2wav.set_stream_cache(self._prompt_wav) + if warmup_cn > 0: + for i in range(warmup_cn): + start = time.time() + self._token2wav.warmup(self._prompt_wav) + logging.info(f"Token2wav warmup {i=} done in {time.time() - start:.3f}s") + + self._system_prompt = init_system_prompt or self.SYS_PROMPT + self._text_stream_out = text_stream_out + self._chunk_size = chunk_size or self.CHUNK_SIZE + + self._session = session or Session( + chat_history_size=chat_history_size, **SessionCtx(str(uuid.uuid4())).__dict__ + ) + self._session.chat_history.init({"role": "system", "content": self._system_prompt}) + tool_calls = FunctionManager.get_tool_calls_by_names(tools) + if len(tool_calls) > 0: + tool_json_schemas = json.dumps(tool_calls) + 
self._session.chat_history.init_tools( + {"role": "tool_json_schemas", "content": tool_json_schemas} + ) + + # for async generate to yield + self._queue = queue.Queue() + self._input_queue = queue.Queue() + self._generate_thread = None + + self._sleep_time = no_stream_sleep_time + self._verbose = verbose + + @property + def chat_history(self): + return self._session.chat_history + + def reset(self): + if self._is_speaking is True: + self._token2wav.cache = {} + self._session.reset() + + def _generate(self): + while True: + try: + item = self._input_queue.get() + if item is None: + self._queue.put(None) # Signal the end of the stream + break # Signal to stop the thread + session, kwargs = item + token_iter = self._audio_llm.generate(session, **kwargs) + self.put_out_audio_text(token_iter, is_out_text=True) + self._queue.put(None) # Signal the end of the stream + except Exception as e: + logging.error(f"Exception generate: {e}", exc_info=True) + self._queue.put(None) # Signal the end of the stream + break + + async def start(self, frame: StartFrame): + await super().start(frame) + self._generate_thread = threading.Thread(target=self._generate) + self._generate_thread.start() + logging.info("start done") + + async def stop(self, frame: EndFrame): + await super().stop(frame) + self._input_queue.put(None) # Signal the thread to stop + self._generate_thread.join() # Wait for the thread to finish + logging.info("stop done") + + async def cancel(self, frame: CancelFrame): + await super().cancel(frame) + self._input_queue.put(None) # Signal the thread to stop + self._generate_thread.join() # Wait for the thread to finish + logging.info("cancel done") + + async def gen(self, is_push_frame: bool = False) -> AsyncGenerator[Frame, None]: + while True: + try: + item = self._queue.get_nowait() + if item is None: + logging.info(f"generate done") + break # End of the stream + logging.info(f"generate data: {item}") + if is_push_frame is True: + await self.push_frame(item) + yield None + else: + yield item + except queue.Empty: + # yield asysncio.sleep to allow other tasks to run, e.g.: sink task (write audio) + await asyncio.sleep(self._sleep_time) + # logging.info(f"queue empty sleep {self._sleep_time}") + continue + + def send_input(self, session: Session, **kwargs): + self._input_queue.put((session, kwargs)) + + @property + def stream_info(self) -> dict: + """Return dict out stream info""" + return { + "sample_rate": self._audio_llm.RATE, + "channels": 1, + } + + async def say( + self, + text: str, + system_prompt: str = "以自然的语速读出下面的文字。\n", + **kwargs, + ): + async for item in self.generator_say( + text, + system_prompt=system_prompt, + is_push_frame=True, + **kwargs, + ): + pass + + async def generator_say( + self, + text: str, + system_prompt: str = "以自然的语速读出下面的文字。\n", + is_push_frame: str = True, + **kwargs, + ) -> AsyncGenerator[Frame, None]: + """ + support: en,zh,ja + """ + logging.info(f"say: {text}") + + messages = [ + {"role": "system", "content": system_prompt or self._system_prompt}, + {"role": "human", "content": text}, + { + "role": "assistant", + "content": "", + "eot": False, + }, # Insert for speech response + ] + self._session.ctx.state["messages"] = messages + self.send_input(self._session, **kwargs) + async for item in self.gen(is_push_frame=is_push_frame): + yield item + + def put_out_audio_text(self, token_iter, is_out_text: bool = True): + output_token_ids = [] + output_audio_token_ids = [] + out_text_token_ids = [] + is_tool = False + tool_calls_token_ids = [] + is_tag = False + 
buffer = [] + unicode_token_id = [] + for token_id in token_iter: + if self._verbose is True: + print(f"{token_id=} {self._audio_llm.llm_tokenizer.decode(token_id)=}") + output_token_ids.append(token_id) + if token_id < 151688: # text + if token_id == 151657: # + is_tool = True + continue + if token_id == 151658: # + is_tool = False + continue + if is_tool: + tool_calls_token_ids.append(token_id) + continue + + if token_id == 27: # < + is_tag = True + continue + if token_id == 29: # > + is_tag = False + continue + if is_tag: # <***> + continue + + out_text_token_ids.append(token_id) + if is_out_text is True and self._text_stream_out is True: + out_text = self._audio_llm.llm_tokenizer.decode(unicode_token_id + [token_id]) + if "�" in out_text: + unicode_token_id.append(token_id) + else: + unicode_token_id = [] + frame = TextFrame(text=out_text) + self._queue.put(frame) + if token_id > 151695 and self._is_speaking is True: # audio + audio_token_id = token_id - 151696 + if audio_token_id < 6561: # remove audio padding + output_audio_token_ids.append(audio_token_id) + buffer.append(audio_token_id) + if len(buffer) >= self._chunk_size + self._token2wav.flow.pre_lookahead_len: + out_bytes = self._token2wav.stream( + buffer[: self._chunk_size + self._token2wav.flow.pre_lookahead_len], + prompt_wav=self._prompt_wav, + last_chunk=False, + ) + frame = AudioRawFrame( + audio=out_bytes, + sample_rate=self._audio_llm.RATE, + num_channels=1, + ) + self._queue.put(frame) + buffer = buffer[self._chunk_size :] + if len(buffer) > 0 and self._is_speaking is True: + logging.info(f"last chunk size: {len(buffer)}") + out_bytes = self._token2wav.stream(buffer, prompt_wav=self._prompt_wav, last_chunk=True) + frame = AudioRawFrame( + audio=out_bytes, + sample_rate=self._audio_llm.RATE, + num_channels=1, + ) + self._queue.put(frame) + + out_text = "" + if len(out_text_token_ids) > 0 and is_out_text is True: + out_text = self._audio_llm.llm_tokenizer.decode(out_text_token_ids) + if self._text_stream_out is False: + frame = TextFrame(text=out_text) + self._queue.put(frame) + + self._queue.put(LLMGenedTokensFrame(token_ids=output_token_ids)) + + if len(tool_calls_token_ids) > 0: + tool_calls_token = self._audio_llm.llm_tokenizer.decode(tool_calls_token_ids) + # print(f"{tool_calls_token=}") + function_name, function_args = extract_function_info(tool_calls_token) + self._queue.put(FunctionCallFrame(function_name=function_name, arguments=function_args)) + + return output_token_ids, out_text + + +# -------------------------------------------------------------------------------- + + +# A1->T2 +class StepAudio2TextProcessor(StepAudio2BaseProcessor): + """ + audio -> audio_LLM -> text + - A1->T2 (ASR(trancribe), Audio Understanding, S2TT(support: en,zh,ja)) + - system prompt example: + - ASR: 请记录下你所听到的语音内容。 + - Audio Understanding: Please briefly explain the important events involved in this audio clip. 
+ - S2TT: 请仔细聆听这段语音,然后将其内容翻译成中文。 + """ + + def __init__(self, **kwargs): + kwargs["is_speaking"] = False + super().__init__(**kwargs) + + async def run_voice(self, frame: AudioRawFrame) -> AsyncGenerator[Frame, None]: + if isinstance(frame, PathAudioRawFrame): + audio = frame.path + else: + audio = bytes2TorchTensorWith16(frame.audio) + + if len(audio) == 0: + yield ErrorFrame("No audio tokens extracted") + return + + messages = [ + {"role": "system", "content": self._system_prompt}, + { + "role": "human", + "content": [{"type": "audio", "audio": audio}], + }, + {"role": "assistant", "content": None}, + ] + self._session.ctx.state["messages"] = messages + self.send_input(self._session) + async for item in self.gen(): + yield item + + +class StepASRProcessor(StepAudio2TextProcessor): + SYS_PROMPT = "请记录下你所听到的语音内容。" + + +class StepAudioCaptionProcessor(StepAudio2TextProcessor): + SYS_PROMPT = "Please briefly explain the important events involved in this audio clip." + + +class StepS2TTProcessor(StepAudio2TextProcessor): + SYS_PROMPT = "请仔细聆听这段语音,然后将其内容翻译成中文。" + + +# -------------------------------------------------------------------------------- + + +# T1->A2 +class StepText2AudioProcessor(StepAudio2BaseProcessor): + """ + text -> audio_LLM -> audio + - T1-A2 (TTS) + """ + + +class StepTTSProcessor(StepText2AudioProcessor): + async def run_text(self, frame: TextFrame) -> AsyncGenerator[Frame, None]: + user_input = frame.text.strip() + async for item in self.generator_say( + user_input, system_prompt="以自然的语速读出下面的文字。\n" + ): + yield item + + +# -------------------------------------------------------------------------------- + + +# T1-T2A2 +class StepText2TextAudioProcessor(StepAudio2BaseProcessor): + """ + text -> audio_LLM -> audio + - T1-T2A2 (T2ST) + """ + + +class StepT2STProcessor(StepText2TextAudioProcessor): + async def run_text(self, frame: TextFrame) -> AsyncGenerator[Frame, None]: + user_input = frame.text.strip() + async for item in self.generator_say( + user_input, + system_prompt="请将下面的文本翻译成英文,并用语音播报。\n", + is_push_frame=False, + ): + yield item + + +# -------------------------------------------------------------------------------- + + +# A1-T2A2 +class StepAudio2TextAudioProcessor(StepAudio2BaseProcessor): + """ + audio -> audio_LLM -> text and audio + - A1-T2A2 (S2ST, Paralingustic information understanding) + """ + + async def run_voice(self, frame: AudioRawFrame) -> AsyncGenerator[Frame, None]: + if isinstance(frame, PathAudioRawFrame): + audio = frame.path + else: + audio = bytes2TorchTensorWith16(frame.audio) + + if len(audio) == 0: + yield ErrorFrame("No audio tokens extracted") + return + + messages = [ + {"role": "system", "content": self._system_prompt}, + { + "role": "human", + "content": [{"type": "audio", "audio": audio}], + }, + { + "role": "assistant", + "content": "", + "eot": False, + }, # Insert for speech response + ] + self._session.ctx.state["messages"] = messages + self.send_input(self._session) + async for item in self.gen(): + yield item + + +class StepS2STProcessor(StepAudio2TextAudioProcessor): + SYS_PROMPT = "请仔细聆听这段语音,然后将其内容翻译成英文并用语音播报。" + + +class StepParalingusticInformationUnderstandingProcessor(StepAudio2TextAudioProcessor): + SYS_PROMPT = "请用语音与我交流。" + + +# -------------------------------------------------------------------------------- + + +# Chat: multi turn TQTA +class StepText2TextChatProcessor(StepAudio2BaseProcessor): + """ + text -> audio_LLM -> text + - T1->T2 (Text Query and Text Answer) + - system prompt example: + - TQTA: "You are a 
helpful assistant." + """ + + def __init__(self, **kwargs): + kwargs["is_speaking"] = False + super().__init__(**kwargs) + + async def run_text(self, frame: TextFrame) -> AsyncGenerator[Frame, None]: + user_input = frame.text.strip() + + self._session.chat_history.append( + {"role": "human", "content": [{"type": "text", "text": user_input}]} + ) + self._session.chat_history.append({"role": "assistant", "content": None}) + self._session.ctx.state["messages"] = self._session.chat_history.to_list() + self.send_input(self._session) + out_text = "" + async for item in self.gen(): + if isinstance(item, TextFrame): + out_text += item.text + yield item + self._session.chat_history.pop(-1) + self._session.chat_history.append({"role": "assistant", "content": out_text}) + + +# Chat: multi turn AQTA +class StepAudio2TextChatProcessor(StepAudio2BaseProcessor): + """ + audio -> audio_LLM -> text + - A1->T2 (Audio Query and Text Answer) + - system prompt example: + - AQTA: "You are a helpful assistant." + """ + + def __init__(self, **kwargs): + kwargs["is_speaking"] = False + super().__init__(**kwargs) + + async def run_voice(self, frame: AudioRawFrame) -> AsyncGenerator[Frame, None]: + if isinstance(frame, PathAudioRawFrame): + audio = frame.path + else: + audio = bytes2TorchTensorWith16(frame.audio) + + if audio is None or len(audio) == 0: + yield ErrorFrame("No audio tokens extracted") + return + + self._session.chat_history.append( + {"role": "human", "content": [{"type": "audio", "audio": audio}]} + ) + self._session.chat_history.append({"role": "assistant", "content": None}) + self._session.ctx.state["messages"] = self._session.chat_history.to_list() + self.send_input(self._session) + out_text = "" + async for item in self.gen(): + if isinstance(item, TextFrame): + out_text += item.text + yield item + self._session.chat_history.pop(-1) + self._session.chat_history.append({"role": "assistant", "content": out_text}) + + +# Chat: multi turn TQAA +class StepText2TextAudioChatProcessor(StepAudio2BaseProcessor): + """ + text -> audio_LLM -> audio and text + - T1-T2A2 (Text Query and Audio Answer) + - system prompt example: + - TQAA: "You are a helpful assistant." + """ + + async def run_text(self, frame: TextFrame) -> AsyncGenerator[Frame, None]: + user_input = frame.text.strip() + + self._session.chat_history.append( + {"role": "human", "content": [{"type": "text", "text": user_input}]} + ) + self._session.chat_history.append( + { + "role": "assistant", + "content": "", + "eot": False, + }, # Insert for speech response + ) + self._session.ctx.state["messages"] = self._session.chat_history.to_list() + self.send_input(self._session) + output_token_ids = [] + async for item in self.gen(): + if isinstance(item, LLMGenedTokensFrame): + output_token_ids = item.token_ids + self._session.chat_history.pop(-1) + self._session.chat_history.append( + { + "role": "assistant", + "content": [ + {"type": "text", "text": ""}, + {"type": "token", "token": output_token_ids}, + ], + } + ) + yield item + + +# Chat: multi turn AQAA +class StepAudio2TextAudioChatProcessor(StepAudio2BaseProcessor): + """ + Audio -> audio_LLM -> audio and text + - A1-T2A2 (Audio Query and Audio+text Answer) + - system prompt example: + - AQAA: "You are a helpful assistant." 
+ """ + + async def run_voice(self, frame: AudioRawFrame) -> AsyncGenerator[Frame, None]: + if isinstance(frame, PathAudioRawFrame): + audio = frame.path + else: + audio = bytes2TorchTensorWith16(frame.audio) + + if audio is None or len(audio) == 0: + yield ErrorFrame("No audio tokens extracted") + return + + self._session.chat_history.append( + {"role": "human", "content": [{"type": "audio", "audio": audio}]} + ) + self._session.chat_history.append( + { + "role": "assistant", + "content": "", + "eot": False, + }, # Insert for speech response + ) + self._session.ctx.state["messages"] = self._session.chat_history.to_list() + self.send_input(self._session) + output_token_ids = [] + async for item in self.gen(): + if isinstance(item, LLMGenedTokensFrame): + output_token_ids = item.token_ids + self._session.chat_history.pop(-1) + self._session.chat_history.append( + { + "role": "assistant", + "content": [ + {"type": "text", "text": ""}, + {"type": "token", "token": output_token_ids}, + ], + } + ) + if isinstance(item, FunctionCallFrame): # send input for function call + func_res = FunctionManager.execute( + item.function_name, self._session, **item.arguments + ) + self._session.chat_history.append( + { + "role": "input", + "content": [ + {"type": "text", "text": func_res}, + { + "type": "text", + "text": "\n\n\n请用口语化形式总结检索结果,简短地回答用户的问题。", + }, + ], + } + ) + self._session.chat_history.append( + { + "role": "assistant", + "content": "", + "eot": False, + }, # Insert for speech response + ) + self._session.ctx.state["messages"] = self._session.chat_history.to_list() + self.send_input(self._session) + yield item + + +# ---------------------------------------------------------------------------------- + + +# A1T1->T2 +class StepAudioText2TextProcessor(StepAudio2BaseProcessor): + """ + audio+text -> audio_LLM -> text + - A1T1->T2 (Audio Understanding with Text Query) + """ + + def __init__(self, **kwargs): + kwargs["is_speaking"] = False + super().__init__(**kwargs) + + async def run_voice(self, frame: TextQuestionsAudioRawFrame) -> AsyncGenerator[Frame, None]: + audio = bytes2TorchTensorWith16(frame.audio) + + if len(audio) == 0: + yield ErrorFrame("No audio tokens extracted") + return + + messages = [ + {"role": "system", "content": self._system_prompt}, + { + "role": "human", + "content": [ + {"type": "audio", "audio": audio}, + {"type": "text", "text": frame.text}, + ], + }, + {"role": "assistant", "content": None}, + ] + self._session.ctx.state["messages"] = messages + self.send_input(self._session) + async for item in self.gen(): + yield item + + +# ---------------------------------------------------------------------------------- + + +# A1T1->T2A2 +class StepAudioText2TextAudioProcessor(StepAudio2BaseProcessor): + """ + audio+text -> audio_LLM -> audio+text + - A1T1->T2A2 (Audio Understanding with Text Query) + """ + + async def run_voice(self, frame: TextQuestionsAudioRawFrame) -> AsyncGenerator[Frame, None]: + audio = bytes2TorchTensorWith16(frame.audio) + + if len(audio) == 0: + yield ErrorFrame("No audio tokens extracted") + return + + messages = [ + {"role": "system", "content": self._system_prompt}, + { + "role": "human", + "content": [ + {"type": "audio", "audio": audio}, + {"type": "text", "text": frame.text}, + ], + }, + { + "role": "assistant", + "content": "", + "eot": False, + }, # Insert for speech response + ] + self._session.ctx.state["messages"] = messages + self.send_input(self._session) + async for item in self.gen(): + yield item + + +# audio understanding with text query 
+class StepMMAUProcessor(StepAudioText2TextAudioProcessor): + SYS_PROMPT = "You are an expert in audio analysis, please analyze the audio content and answer the questions accurately." + + +""" +python -m src.processors.voice.step_audio2_processor +""" +if __name__ == "__main__": + from pathlib import Path + + import wave + from apipeline.frames import AudioRawFrame, StartFrame, EndFrame, CancelFrame + + from src.common.logger import Logger + from src.types.frames import PathAudioRawFrame + from src.cmd.bots.voice.step_audio2.helper import ( + get_step_audio2_processor, + ) + from src.types.ai_conf import AIConfig, LLMConfig + + Logger.init(os.getenv("LOG_LEVEL", "info").upper(), is_file=False, is_console=True) + + async def run_aqaa(): + processor = get_step_audio2_processor( + LLMConfig( + processor="StepAudio2TextAudioChatProcessor", + args={ + "init_system_prompt": "", + "prompt_wav": "./assets/default_male.wav", + "warmup_cn": 2, + "chat_history_size": None, + "text_stream_out": False, + "no_stream_sleep_time": 0.5, + "lm_model_name_or_path": "./models/stepfun-ai/Step-Audio-2-mini", + }, + ) + ) + await processor.start(StartFrame()) + for round_idx, audio_path in enumerate( + [ + "./deps/StepAudio2/assets/multi-turn-round1-听说荡口古镇从下个月开始取消门票了,你知道这事吗。.wav", + "./deps/StepAudio2/assets/multi-turn-round2-新闻说九月十九号就免费开放了。好像整个古镇都升级改造了,现在变成开放式街区了。.wav", + ] + ): + print("round: ", round_idx) + frame_iter = processor.run_voice( + PathAudioRawFrame( + path=audio_path, + audio=b"", + ) + ) + audio = b"" + async for frame in frame_iter: + if isinstance(frame, AudioRawFrame): + audio += frame.audio + print(f"{round_idx=} gen_frame-->", frame) + print(f"{round_idx=} {processor._session.chat_history=}") + wav_path = Path( + f"{ASSETS_DIR}/StepAudio2/output-processor-chunks-stream-{round_idx}.wav" + ) + with wave.open(str(wav_path), "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(audio) + await processor.stop(EndFrame()) + + asyncio.run(run_aqaa()) diff --git a/src/types/ai_conf.py b/src/types/ai_conf.py index 7c393cf7..133fffbe 100644 --- a/src/types/ai_conf.py +++ b/src/types/ai_conf.py @@ -29,87 +29,78 @@ DEFAULT_LLM_LANG = "zh" +class BaseConfig(BaseModel): + tag: Optional[str] = None + args: Optional[dict] = None + + +class MCPServerConfig(BaseModel): + transport: Optional[str] = "stdio" + parameters: Optional[Dict[str, Any]] = None + + class StreamConfig(BaseModel): tag: Optional[str] = "daily_room_stream" - args: Optional[dict] = None -class VADConfig(BaseModel): - tag: Optional[str] = None - args: Optional[dict] = None +class VADConfig(BaseConfig): + pass -class TurnConfig(BaseModel): - tag: Optional[str] = None - args: Optional[dict] = None +class TurnConfig(BaseConfig): + pass -class VisionDetectorConfig(BaseModel): - tag: Optional[str] = None - args: Optional[dict] = None +class VisionDetectorConfig(BaseConfig): + pass -class VisionOCRConfig(BaseModel): +class VisionOCRConfig(BaseConfig): trigger_texts: Optional[List[str]] = None - tag: Optional[str] = None - args: Optional[dict] = None -class ImageGenConfig(BaseModel): - tag: Optional[str] = None - args: Optional[dict] = None +class ImageGenConfig(BaseConfig): + pass -class ASRConfig(BaseModel): - tag: Optional[str] = None - args: Optional[dict] = None +class ASRConfig(BaseConfig): + pass -class PuncConfig(BaseModel): - tag: Optional[str] = None - args: Optional[dict] = None +class PuncConfig(BaseConfig): + pass -class AvatarConfig(BaseModel): - tag: Optional[str] = None - args: 
Optional[dict] = None +class AvatarConfig(BaseConfig): + pass -class LLMConfig(BaseModel): +class LLMConfig(BaseConfig): + init_prompt: Optional[str] = None + processor: Optional[str] = None base_url: Optional[str] = None model: Optional[str] = None language: Optional[str] = None messages: Optional[List[dict]] = None tools: Optional[List[dict]] = None # is_use_tools_description: Optional[bool] = False - tag: Optional[str] = None - args: Optional[dict] = None -class TranslateLLMConfig(BaseModel): +class TranslateLLMConfig(BaseConfig): init_prompt: Optional[str] = None model: Optional[str] = None src: Optional[str] = None target: Optional[str] = None streaming: Optional[bool] = False prompt_tpl: Optional[str] = None - tag: Optional[str] = None - args: Optional[dict] = None -class TTSConfig(BaseModel): +class TTSConfig(BaseConfig): voice: Optional[str] = None language: Optional[str] = None aggregate_sentences: Optional[bool] = True push_text_frames: Optional[bool] = True remove_punctuation: Optional[bool] = False - tag: Optional[str] = None - args: Optional[dict] = None - - -class MCPServerConfig(BaseModel): - transport: Optional[str] = "stdio" - parameters: Optional[Dict[str, Any]] = None class AIConfig(BaseModel): diff --git a/src/types/frames/data_frames.py b/src/types/frames/data_frames.py index cd7581c5..d0402f01 100644 --- a/src/types/frames/data_frames.py +++ b/src/types/frames/data_frames.py @@ -364,3 +364,30 @@ def __str__(self): return ( f"{super_str} animation_json: {self.animation_json} avatar_status: {self.avatar_status}" ) + + +@dataclass +class TextQuestionsAudioRawFrame(AudioRawFrame, TextFrame): + """text questions with audio frame""" + + +@dataclass +class LLMGenedTokensFrame(Frame): + """llm gened tokens frame""" + + token_ids: list[int] = field(default_factory=list) + + def __str__(self): + return f"{self.name}(token_ids: {self.token_ids})" + + +@dataclass +class FunctionCallFrame(Frame): + """llm gened tokens frame""" + + function_name: str = "" + tool_call_id: str = "" + arguments: dict = field(default_factory=dict) + + def __str__(self): + return f"{self.name}(function_name: {self.function_name}, tool_call_id: {self.tool_call_id}, arguments: {self.arguments})" diff --git a/src/types/llm/transformers.py b/src/types/llm/transformers.py index be0eb09b..595ebfe0 100644 --- a/src/types/llm/transformers.py +++ b/src/types/llm/transformers.py @@ -131,6 +131,21 @@ def to_dict(self) -> dict: return super().to_dict() +""" +python -m src.types.llm.transformers +""" if __name__ == "__main__": args = TransformersLMArgs() print(args.to_dict().items()) + + unset_args = args.update( + **{ + "lm_model_name_or_path": "HuggingFaceTB/SmolLM-360M", + "lm_model_name_or_path1": "HuggingFaceTB/SmolLM-360M-Instruct", + "lm_gen_repetition_penalty": 1.011, + } + ) + assert args.lm_model_name_or_path == "HuggingFaceTB/SmolLM-360M" + assert args.lm_gen_repetition_penalty == 1.011 + print("TransformersLMArgs", args) + print(f"{unset_args=}")
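The tool-call path added in this diff hands off in three steps: StepAudio2BaseProcessor.put_out_audio_text decodes the token ids emitted between the tool-call marker ids into a string, extract_function_info() splits that string into a function name and JSON arguments, and the AQAA chat processor wraps them in a FunctionCallFrame and runs them through FunctionManager. A minimal sketch of that hand-off, assuming the diff is applied and run from the repo root; the token string below is the illustrative example from the src/processors/voice/helper.py docstring, not real model output:

# Tool-call hand-off sketch (illustrative token string, see helper.py docstring).
from src.processors.voice.helper import extract_function_info
from src.types.frames import FunctionCallFrame

tool_calls_token = 'function\nweb_search\n{"query": "2025年8月28日 上证指数 开盘价"}'
name, args = extract_function_info(tool_calls_token)          # -> ("web_search", {"query": ...})
frame = FunctionCallFrame(function_name=name, arguments=args)  # what the processor pushes downstream
print(frame)

# StepAudio2TextAudioChatProcessor then executes the call and appends the result
# to chat history for a follow-up generation:
#   func_res = FunctionManager.execute(frame.function_name, session, **frame.arguments)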