From c0c2f07f1e2358e035285ead6d77be076dfa0854 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Ku=C4=8Dera?= Date: Fri, 31 Oct 2025 02:30:34 +0100 Subject: [PATCH 01/13] Implement smart text splitting for TTS processing Added text splitting functionality to handle long inputs efficiently. --- gradio_tts_app.py | 111 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 98 insertions(+), 13 deletions(-) diff --git a/gradio_tts_app.py b/gradio_tts_app.py index cda7912b..b34a1936 100644 --- a/gradio_tts_app.py +++ b/gradio_tts_app.py @@ -3,9 +3,11 @@ import torch import gradio as gr from chatterbox.tts import ChatterboxTTS +import re DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +MAX_CHARS_PER_CHUNK = 400 # Adjust this based on testing def set_seed(seed: int): @@ -21,35 +23,118 @@ def load_model(): return model +def split_text_smart(text, max_chars=MAX_CHARS_PER_CHUNK): + """Split text into chunks at sentence boundaries""" + if len(text) <= max_chars: + return [text] + + # Split by sentences (., !, ?) + sentences = re.split(r'([.!?]+\s*)', text) + + chunks = [] + current_chunk = "" + + for i in range(0, len(sentences), 2): + sentence = sentences[i] + punctuation = sentences[i + 1] if i + 1 < len(sentences) else "" + full_sentence = sentence + punctuation + + # If adding this sentence would exceed limit + if len(current_chunk) + len(full_sentence) > max_chars: + if current_chunk: + chunks.append(current_chunk.strip()) + current_chunk = full_sentence + else: + # Single sentence is too long, split by words + words = full_sentence.split() + temp_chunk = "" + for word in words: + if len(temp_chunk) + len(word) + 1 <= max_chars: + temp_chunk += word + " " + else: + if temp_chunk: + chunks.append(temp_chunk.strip()) + temp_chunk = word + " " + current_chunk = temp_chunk + else: + current_chunk += full_sentence + + if current_chunk.strip(): + chunks.append(current_chunk.strip()) + + return chunks + + def generate(model, text, audio_prompt_path, exaggeration, temperature, seed_num, cfgw, min_p, top_p, repetition_penalty): if model is None: model = ChatterboxTTS.from_pretrained(DEVICE) + if not text or not text.strip(): + return None + + # Set seed if specified if seed_num != 0: set_seed(int(seed_num)) - wav = model.generate( - text, - audio_prompt_path=audio_prompt_path, - exaggeration=exaggeration, - temperature=temperature, - cfg_weight=cfgw, - min_p=min_p, - top_p=top_p, - repetition_penalty=repetition_penalty, - ) - return (model.sr, wav.squeeze(0).numpy()) + # Split text into chunks + chunks = split_text_smart(text.strip()) + + print(f"Processing {len(chunks)} chunk(s)...") + + try: + audio_chunks = [] + + for i, chunk in enumerate(chunks): + print(f"Generating chunk {i+1}/{len(chunks)}: {chunk[:50]}...") + + wav = model.generate( + chunk, + audio_prompt_path=audio_prompt_path, + exaggeration=exaggeration, + temperature=temperature, + cfg_weight=cfgw, + min_p=min_p, + top_p=top_p, + repetition_penalty=repetition_penalty, + ) + + # Convert to numpy and add to chunks + audio_chunks.append(wav.squeeze(0).numpy()) + + # Concatenate all audio chunks + if len(audio_chunks) > 1: + # Add small silence between chunks (0.1 seconds) + silence = np.zeros(int(model.sr * 0.1)) + final_audio = audio_chunks[0] + for chunk in audio_chunks[1:]: + final_audio = np.concatenate([final_audio, silence, chunk]) + print(f"Successfully generated {len(chunks)} chunks!") + return (model.sr, final_audio) + else: + return (model.sr, audio_chunks[0]) + + except Exception as e: + print(f"Error during 
generation: {str(e)}") + return None with gr.Blocks() as demo: model_state = gr.State(None) # Loaded once per session/user + gr.Markdown(""" + # Chatterbox TTS + + **Note:** Long texts are automatically split into chunks for processing. + Each chunk is limited to ~400 characters for optimal quality. + """) + with gr.Row(): with gr.Column(): text = gr.Textbox( value="Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.", - label="Text to synthesize (max chars 300)", - max_lines=5 + label="Text to synthesize (automatically chunked if too long)", + lines=8, + max_lines=15 ) ref_wav = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Reference Audio File", value=None) exaggeration = gr.Slider(0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5) From e4f6c325e2a5284a49d2c2b99c35b8cc50e3b7b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Ku=C4=8Dera?= Date: Mon, 3 Nov 2025 19:14:31 -0600 Subject: [PATCH 02/13] add multimodel gradio app --- tts_multi_gradio.cmd | 9 ++ tts_multi_gradio.py | 338 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 347 insertions(+) create mode 100644 tts_multi_gradio.cmd create mode 100644 tts_multi_gradio.py diff --git a/tts_multi_gradio.cmd b/tts_multi_gradio.cmd new file mode 100644 index 00000000..e10ccd2d --- /dev/null +++ b/tts_multi_gradio.cmd @@ -0,0 +1,9 @@ +@echo off +echo Starting Chatterbox TTS Gradio App... +echo. + +call .venv\Scripts\activate.bat +start http://127.0.0.1:7860 +python tts_multi_gradio.py + +if errorlevel 1 pause \ No newline at end of file diff --git a/tts_multi_gradio.py b/tts_multi_gradio.py new file mode 100644 index 00000000..fec3e37f --- /dev/null +++ b/tts_multi_gradio.py @@ -0,0 +1,338 @@ +import random +import numpy as np +import torch +import gradio as gr +from chatterbox.tts import ChatterboxTTS +from transformers import BarkModel +import re +import os +from pathlib import Path + + +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +MAX_CHARS_PER_CHUNK = 400 # Adjust this based on testing + +# Custom Bark model configuration +CUSTOM_BARK_MODELS = { + "Default (suno/bark-small)": None, + "Fine-tuned": "C:/ChatterboxTraining/checkpoints/final_bark_model", + # Add more custom models here: + # "My Voice Clone": "C:/path/to/another/model", +} + + +def set_seed(seed: int): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + random.seed(seed) + np.random.seed(seed) + + +def load_model(): + """Load ChatterboxTTS with default Bark model""" + print(f"Loading Chatterbox TTS on {DEVICE}...") + model = ChatterboxTTS.from_pretrained(DEVICE) + print("✓ Chatterbox TTS loaded with default Bark model") + return model + + +def switch_bark_model(model, model_choice): + """Switch the Bark model to a different version""" + if model is None: + model = ChatterboxTTS.from_pretrained(DEVICE) + + custom_path = CUSTOM_BARK_MODELS.get(model_choice) + + if custom_path: + # Check if path exists + if not os.path.exists(custom_path): + return model, f"❌ Error: Model path not found: {custom_path}" + + print(f"Loading custom Bark model from: {custom_path}") + try: + # Replace the Bark model + model.bark_model = BarkModel.from_pretrained(custom_path).to(DEVICE) + print(f"✓ Loaded custom Bark model: {model_choice}") + return model, f"✓ Loaded: 
{model_choice}" + except Exception as e: + return model, f"❌ Error loading model: {str(e)}" + else: + print("Loading default Bark model...") + try: + # Reload default model + model.bark_model = BarkModel.from_pretrained("suno/bark-small").to(DEVICE) + print("✓ Loaded default Bark model") + return model, "✓ Loaded: Default Bark model" + except Exception as e: + return model, f"❌ Error loading model: {str(e)}" + + +def split_text_smart(text, max_chars=MAX_CHARS_PER_CHUNK): + """Split text into chunks at sentence boundaries""" + if len(text) <= max_chars: + return [text] + + # Split by sentences (., !, ?) + sentences = re.split(r'([.!?]+\s*)', text) + + chunks = [] + current_chunk = "" + + for i in range(0, len(sentences), 2): + sentence = sentences[i] + punctuation = sentences[i + 1] if i + 1 < len(sentences) else "" + full_sentence = sentence + punctuation + + # If adding this sentence would exceed limit + if len(current_chunk) + len(full_sentence) > max_chars: + if current_chunk: + chunks.append(current_chunk.strip()) + current_chunk = full_sentence + else: + # Single sentence is too long, split by words + words = full_sentence.split() + temp_chunk = "" + for word in words: + if len(temp_chunk) + len(word) + 1 <= max_chars: + temp_chunk += word + " " + else: + if temp_chunk: + chunks.append(temp_chunk.strip()) + temp_chunk = word + " " + current_chunk = temp_chunk + else: + current_chunk += full_sentence + + if current_chunk.strip(): + chunks.append(current_chunk.strip()) + + return chunks + + +def generate(model, text, audio_prompt_path, exaggeration, temperature, seed_num, cfgw, min_p, top_p, repetition_penalty): + if model is None: + model = ChatterboxTTS.from_pretrained(DEVICE) + + if not text or not text.strip(): + return None + + # Set seed if specified + if seed_num != 0: + set_seed(int(seed_num)) + + # Split text into chunks + chunks = split_text_smart(text.strip()) + + print(f"Processing {len(chunks)} chunk(s)...") + + try: + audio_chunks = [] + + for i, chunk in enumerate(chunks): + print(f"Generating chunk {i+1}/{len(chunks)}: {chunk[:50]}...") + + wav = model.generate( + chunk, + audio_prompt_path=audio_prompt_path, + exaggeration=exaggeration, + temperature=temperature, + cfg_weight=cfgw, + min_p=min_p, + top_p=top_p, + repetition_penalty=repetition_penalty, + ) + + # Convert to numpy and add to chunks + audio_chunks.append(wav.squeeze(0).numpy()) + + # Concatenate all audio chunks + if len(audio_chunks) > 1: + # Add small silence between chunks (0.1 seconds) + silence = np.zeros(int(model.sr * 0.1)) + final_audio = audio_chunks[0] + for chunk in audio_chunks[1:]: + final_audio = np.concatenate([final_audio, silence, chunk]) + print(f"Successfully generated {len(chunks)} chunks!") + return (model.sr, final_audio) + else: + return (model.sr, audio_chunks[0]) + + except Exception as e: + print(f"Error during generation: {str(e)}") + return None + + +with gr.Blocks(theme=gr.themes.Soft()) as demo: + model_state = gr.State(None) # Loaded once per session/user + + gr.Markdown(""" + # 🎙️ Chatterbox TTS + + Advanced text-to-speech with custom Bark model support. + + **Note:** Long texts are automatically split into chunks for processing. + Each chunk is limited to ~400 characters for optimal quality. 
+ """) + + with gr.Row(): + with gr.Column(scale=1): + # Model Selection + gr.Markdown("### 🔧 Model Configuration") + model_dropdown = gr.Dropdown( + choices=list(CUSTOM_BARK_MODELS.keys()), + value="Default (suno/bark-small)", + label="Bark Model", + info="Select which Bark model to use" + ) + model_status = gr.Textbox( + label="Model Status", + value="Default model loaded", + interactive=False, + lines=1 + ) + load_model_btn = gr.Button("🔄 Load Selected Model", variant="secondary", size="sm") + + gr.Markdown("---") + + # Text Input + gr.Markdown("### 📝 Text Input") + text = gr.Textbox( + value="Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.", + label="Text to synthesize", + lines=6, + max_lines=12, + placeholder="Enter text to convert to speech..." + ) + + # Reference Audio + gr.Markdown("### 🎵 Voice Reference") + ref_wav = gr.Audio( + sources=["upload", "microphone"], + type="filepath", + label="Reference Audio File", + value=None + ) + + # Main Controls + gr.Markdown("### 🎛️ Voice Controls") + exaggeration = gr.Slider( + 0.25, 2, + step=.05, + label="Exaggeration", + info="Neutral = 0.5, extreme values can be unstable", + value=.5 + ) + cfg_weight = gr.Slider( + 0.0, 1, + step=.05, + label="CFG Weight / Pace", + info="Controls generation guidance", + value=0.5 + ) + + # Advanced Options + with gr.Accordion("⚙️ Advanced Options", open=False): + seed_num = gr.Number( + value=0, + label="Random Seed", + info="0 for random, set number for reproducible results" + ) + temp = gr.Slider( + 0.05, 5, + step=.05, + label="Temperature", + info="Higher = more creative/variable", + value=.8 + ) + min_p = gr.Slider( + 0.00, 1.00, + step=0.01, + label="Min P", + info="Newer sampler. 0.02-0.1 recommended. 0.00 disables", + value=0.05 + ) + top_p = gr.Slider( + 0.00, 1.00, + step=0.01, + label="Top P", + info="Original sampler. 
1.0 disables (recommended)", + value=1.00 + ) + repetition_penalty = gr.Slider( + 1.00, 2.00, + step=0.1, + label="Repetition Penalty", + info="Reduces repeated phrases", + value=1.2 + ) + + # Generate Button + run_btn = gr.Button("🎬 Generate Speech", variant="primary", size="lg") + + with gr.Column(scale=1): + gr.Markdown("### 🔊 Output") + audio_output = gr.Audio(label="Generated Audio") + + gr.Markdown(""" + --- + ### 💡 Tips + + **Model Selection:** + - **Default**: Original Bark model from Suno AI + - **Fine-tuned**: Your custom trained model + + **Voice Cloning:** + - Upload 5-10 seconds of clear reference audio + - Single speaker, minimal background noise + + **Text Processing:** + - Texts over 400 chars are auto-chunked + - Use proper punctuation for better prosody + + **Parameter Guide:** + - **Exaggeration**: Controls emotion intensity + - **Temperature**: Higher = more variation + - **CFG Weight**: Affects pacing and adherence + """) + + # Event Handlers + demo.load(fn=load_model, inputs=[], outputs=model_state) + + load_model_btn.click( + fn=switch_bark_model, + inputs=[model_state, model_dropdown], + outputs=[model_state, model_status] + ) + + run_btn.click( + fn=generate, + inputs=[ + model_state, + text, + ref_wav, + exaggeration, + temp, + seed_num, + cfg_weight, + min_p, + top_p, + repetition_penalty, + ], + outputs=audio_output, + ) + +if __name__ == "__main__": + print("\n" + "="*60) + print("🎙️ CHATTERBOX TTS - CUSTOM MODEL EDITION") + print("="*60) + print(f"Device: {DEVICE}") + print(f"Available models: {len(CUSTOM_BARK_MODELS)}") + for model_name in CUSTOM_BARK_MODELS.keys(): + print(f" - {model_name}") + print("="*60 + "\n") + + demo.queue( + max_size=50, + default_concurrency_limit=1, + ).launch(share=True) \ No newline at end of file From 1e3e092585788aec2a13062a7a7a886dd5b03e8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Ku=C4=8Dera?= Date: Tue, 4 Nov 2025 02:15:47 +0100 Subject: [PATCH 03/13] Update README.md --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index d6651ef7..c0c36ab0 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,8 @@ # Chatterbox TTS +## For multimodel use tts_multi_gradio.py + [![Alt Text](https://img.shields.io/badge/listen-demo_samples-blue)](https://resemble-ai.github.io/chatterbox_demopage/) [![Alt Text](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/ResembleAI/Chatterbox) [![Alt Text](https://static-public.podonos.com/badges/insight-on-pdns-sm-dark.svg)](https://podonos.com/resembleai/chatterbox) From 78cf39b678114d465ea4fe2e2127a296bfb5ff73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Ku=C4=8Dera?= Date: Tue, 4 Nov 2025 02:16:44 +0100 Subject: [PATCH 04/13] Document fork of Chatterbox-TTS-Extended Added a note about forking the original project due to output quality issues. 
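The tts_multi_gradio.py entry point that this README note points to (added in PATCH 02 above) keeps its selectable checkpoints in a plain dict, so exposing another model in the dropdown is a one-line config edit. A sketch of an extended registry; the "My Voice Clone" entry and its path are illustrative, not part of the patch:

```python
# Hypothetical extension of CUSTOM_BARK_MODELS in tts_multi_gradio.py.
# Keys become dropdown labels; values are local checkpoint directories;
# None selects the stock suno/bark-small weights.
CUSTOM_BARK_MODELS = {
    "Default (suno/bark-small)": None,
    "Fine-tuned": "C:/ChatterboxTraining/checkpoints/final_bark_model",
    "My Voice Clone": "C:/path/to/another/model",  # example entry
}
```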
---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index c0c36ab0..e81dd823 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,7 @@
 # Chatterbox TTS
 
 ## For multimodel use tts_multi_gradio.py
+## I forked the original because Chatterbox-TTS-Extended managed to lower the output quality
 
 [![Alt Text](https://img.shields.io/badge/listen-demo_samples-blue)](https://resemble-ai.github.io/chatterbox_demopage/)
 [![Alt Text](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/ResembleAI/Chatterbox)
 [![Alt Text](https://static-public.podonos.com/badges/insight-on-pdns-sm-dark.svg)](https://podonos.com/resembleai/chatterbox)

From 079a88d606045698b76602dcda42e553763c8025 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20Ku=C4=8Dera?=
Date: Tue, 4 Nov 2025 02:16:57 +0100
Subject: [PATCH 05/13] Update README.md

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index e81dd823..1b2fa2af 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,7 @@
 # Chatterbox TTS
 
 ## For multimodel use tts_multi_gradio.py
+
 ## I forked the original because Chatterbox-TTS-Extended managed to lower the output quality
 
 [![Alt Text](https://img.shields.io/badge/listen-demo_samples-blue)](https://resemble-ai.github.io/chatterbox_demopage/)

From 35b9f858eeecea3cb79fc45425685815ef0a316c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20Ku=C4=8Dera?=
Date: Wed, 12 Nov 2025 03:05:20 -0600
Subject: [PATCH 06/13] custom model support

---
 multilingual_app.py | 196 ++++++++++++++++++++++++++++++++------------
 1 file changed, 142 insertions(+), 54 deletions(-)

diff --git a/multilingual_app.py b/multilingual_app.py
index 51e9c693..edca2060 100644
--- a/multilingual_app.py
+++ b/multilingual_app.py
@@ -3,10 +3,20 @@
 import torch
 from chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
 import gradio as gr
+from safetensors.torch import load_file as load_safetensors
+import os
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"🚀 Running on device: {DEVICE}")
 
+# --- Custom T3 Model Configuration ---
+CUSTOM_T3_MODELS = {
+    "Default": None,
+    "Czech (t3_cs)": "t3_cs",  # Path to your safetensors file
+    # Add more custom models here:
+    # "Another Language": "path/to/model.safetensors",
+}
+
 # --- Global Model Initialization ---
 MODEL = None
 
@@ -15,6 +25,10 @@
         "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ar_f/ar_prompts2.flac",
         "text": "في الشهر الماضي، وصلنا إلى معلم جديد بمليارين من المشاهدات على قناتنا على يوتيوب."
     },
+    "cs": {  # Add Czech language
+        "audio": None,
+        "text": "Dobrý den, vítáme vás v našem testu syntézy řeči"
+    },
     "da": {
         "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/da_m1.flac",
         "text": "Sidste måned nåede vi en ny milepæl med to milliarder visninger på vores YouTube-kanal."
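The hunks that follow add a switch_t3_model() helper that reads the registry above, loads the checkpoint with load_safetensors, and swaps it into MODEL.t3 via load_state_dict. A minimal pre-flight check for such a checkpoint, sketched under the assumption (consistent with this diff) that model.t3 is an ordinary torch.nn.Module; the helper name and path argument are illustrative:

```python
# Sketch: compare a fine-tuned T3 checkpoint's tensor names against the
# live module before swapping weights in, to fail fast on a wrong file.
from safetensors.torch import load_file as load_safetensors

def t3_checkpoint_matches(model, path: str) -> bool:
    """Return True if the checkpoint keys match model.t3's state_dict."""
    state = load_safetensors(path, device="cpu")
    expected = set(model.t3.state_dict().keys())
    found = set(state.keys())
    missing, unexpected = expected - found, found - expected
    if missing:
        print(f"{len(missing)} missing keys, e.g. {sorted(missing)[:3]}")
    if unexpected:
        print(f"{len(unexpected)} unexpected keys, e.g. {sorted(unexpected)[:3]}")
    return not (missing or unexpected)
```

A check like this catches the most common failure (pointing the registry at a checkpoint for a different module) before load_state_dict raises inside the helper's try/except.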
@@ -116,8 +130,12 @@ def default_text_for_ui(lang: str) -> str: def get_supported_languages_display() -> str: """Generate a formatted display of all supported languages.""" + # Combine base supported languages with any custom ones + all_langs = dict(SUPPORTED_LANGUAGES) + all_langs.update({"cs": "Czech"}) # Add custom languages here + language_items = [] - for code, name in sorted(SUPPORTED_LANGUAGES.items()): + for code, name in sorted(all_langs.items()): language_items.append(f"**{name}** (`{code}`)") # Split into 2 lines @@ -126,7 +144,7 @@ def get_supported_languages_display() -> str: line2 = " • ".join(language_items[mid:]) return f""" -### 🌍 Supported Languages ({len(SUPPORTED_LANGUAGES)} total) +### 🌍 Supported Languages ({len(all_langs)} total) {line1} {line2} @@ -134,8 +152,7 @@ def get_supported_languages_display() -> str: def get_or_load_model(): - """Loads the ChatterboxMultilingualTTS model if it hasn't been loaded already, - and ensures it's on the correct device.""" + """Loads the ChatterboxMultilingualTTS model if it hasn't been loaded already.""" global MODEL if MODEL is None: print("Model not loaded, initializing...") @@ -149,12 +166,49 @@ def get_or_load_model(): raise return MODEL + +def switch_t3_model(model_choice: str): + """Switch the T3 model to a custom version""" + global MODEL + + if MODEL is None: + MODEL = get_or_load_model() + + custom_path = CUSTOM_T3_MODELS.get(model_choice) + + if custom_path: + # Check if path exists + if not os.path.exists(custom_path): + return f"❌ Error: Model path not found: {custom_path}" + + print(f"Loading custom T3 model from: {custom_path}") + try: + # Load the custom T3 state dict + t3_state = load_safetensors(custom_path, device="cpu") + MODEL.t3.load_state_dict(t3_state) + MODEL.t3.to(DEVICE).eval() + print(f"✓ Loaded custom T3 model: {model_choice}") + return f"✓ Loaded: {model_choice}" + except Exception as e: + return f"❌ Error loading model: {str(e)}" + else: + print("Reloading default T3 model...") + try: + # Reload the entire model to get default T3 + MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE) + print("✓ Loaded default T3 model") + return "✓ Loaded: Default T3 model" + except Exception as e: + return f"❌ Error loading model: {str(e)}" + + # Attempt to load the model at startup. try: get_or_load_model() except Exception as e: print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}") + def set_seed(seed: int): """Sets the random seed for reproducibility across torch, numpy, and random.""" torch.manual_seed(seed) @@ -163,16 +217,6 @@ def set_seed(seed: int): torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) - -def resolve_audio_prompt(language_id: str, provided_path: str | None) -> str | None: - """ - Decide which audio prompt to use: - - If user provided a path (upload/mic/url), use it. - - Else, fall back to language-specific default (if any). - """ - if provided_path and str(provided_path).strip(): - return provided_path - return LANGUAGE_CONFIG.get(language_id, {}).get("audio") def generate_tts_audio( @@ -184,26 +228,7 @@ def generate_tts_audio( seed_num_input: int = 0, cfgw_input: float = 0.5 ) -> tuple[int, np.ndarray]: - """ - Generate high-quality speech audio from text using Chatterbox Multilingual model with optional reference audio styling. - Supported languages: English, French, German, Spanish, Italian, Portuguese, and Hindi. - - This tool synthesizes natural-sounding speech from input text. 
When a reference audio file - is provided, it captures the speaker's voice characteristics and speaking style. The generated audio - maintains the prosody, tone, and vocal qualities of the reference speaker, or uses default voice if no reference is provided. - - Args: - text_input (str): The text to synthesize into speech (maximum 300 characters) - language_id (str): The language code for synthesis (eg. en, fr, de, es, it, pt, hi) - audio_prompt_path_input (str, optional): File path or URL to the reference audio file that defines the target voice style. Defaults to None. - exaggeration_input (float, optional): Controls speech expressiveness (0.25-2.0, neutral=0.5, extreme values may be unstable). Defaults to 0.5. - temperature_input (float, optional): Controls randomness in generation (0.05-5.0, higher=more varied). Defaults to 0.8. - seed_num_input (int, optional): Random seed for reproducible results (0 for random generation). Defaults to 0. - cfgw_input (float, optional): CFG/Pace weight controlling generation guidance (0.2-1.0). Defaults to 0.5, 0 for language transfer. - - Returns: - tuple[int, np.ndarray]: A tuple containing the sample rate (int) and the generated audio waveform (numpy.ndarray) - """ + """Generate TTS audio with custom T3 model support""" current_model = get_or_load_model() if current_model is None: @@ -212,7 +237,7 @@ def generate_tts_audio( if seed_num_input != 0: set_seed(int(seed_num_input)) - print(f"Generating audio for text: '{text_input[:50]}...'") + print(f"Generating audio for text: '{text_input[:50]}...' in language: {language_id}") # Handle optional audio prompt chosen_prompt = audio_prompt_path_input or default_audio_for_ui(language_id) @@ -236,27 +261,50 @@ def generate_tts_audio( print("Audio generation complete.") return (current_model.sr, wav.squeeze(0).numpy()) -with gr.Blocks() as demo: + +with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown( """ - # Chatterbox Multilingual Demo - Generate high-quality multilingual speech from text with reference audio styling, supporting 23 languages. + # 🎙️ Chatterbox Multilingual Demo with Custom T3 Support + Generate high-quality multilingual speech from text with reference audio styling and custom model support. """ ) # Display supported languages gr.Markdown(get_supported_languages_display()) + with gr.Row(): with gr.Column(): - initial_lang = "fr" + # Model Selection Section + gr.Markdown("### 🔧 Model Configuration") + t3_model_dropdown = gr.Dropdown( + choices=list(CUSTOM_T3_MODELS.keys()), + value="Default", + label="T3 Model", + info="Select which T3 model to use" + ) + model_status = gr.Textbox( + label="Model Status", + value="Default model loaded", + interactive=False, + lines=1 + ) + load_t3_btn = gr.Button("🔄 Load Selected T3 Model", variant="secondary", size="sm") + + gr.Markdown("---") + + # TTS Controls + initial_lang = "cs" # Default to Czech for testing text = gr.Textbox( value=default_text_for_ui(initial_lang), label="Text to synthesize (max chars 300)", max_lines=5 ) + # Get all supported languages including custom ones + all_language_codes = list(SUPPORTED_LANGUAGES.keys()) + ["cs"] language_id = gr.Dropdown( - choices=list(ChatterboxMultilingualTTS.get_supported_languages().keys()), + choices=sorted(set(all_language_codes)), value=initial_lang, label="Language", info="Select the language for text-to-speech synthesis" @@ -270,35 +318,62 @@ def generate_tts_audio( ) gr.Markdown( - "💡 **Note**: Ensure that the reference clip matches the specified language tag. 
Otherwise, language transfer outputs may inherit the accent of the reference clip's language. To mitigate this, set the CFG weight to 0.", + "💡 **Note**: Ensure that the reference clip matches the specified language tag. For custom languages, set CFG weight to 0 if experiencing accent issues.", elem_classes=["audio-note"] ) exaggeration = gr.Slider( - 0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5 + 0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5)", value=.5 ) cfg_weight = gr.Slider( - 0.2, 1, step=.05, label="CFG/Pace", value=0.5 + 0.0, 1, step=.05, label="CFG/Pace (0 for language transfer)", value=0.5 ) with gr.Accordion("More options", open=False): seed_num = gr.Number(value=0, label="Random seed (0 for random)") temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8) - run_btn = gr.Button("Generate", variant="primary") + run_btn = gr.Button("🎬 Generate Speech", variant="primary", size="lg") with gr.Column(): - audio_output = gr.Audio(label="Output Audio") + gr.Markdown("### 📊 Output") + audio_output = gr.Audio(label="Generated Audio") + + gr.Markdown(""" + --- + ### 💡 Tips + + **Custom T3 Models:** + - Load your fine-tuned T3 models for new languages + - Place `.safetensors` files in the working directory + - Switch between models without restarting + + **Voice Cloning:** + - Upload 5-10 seconds of clear reference audio + - Single speaker, minimal background noise + - Match reference language to target language + + **Parameters:** + - **Exaggeration**: Controls emotion intensity + - **Temperature**: Higher = more variation + - **CFG Weight**: Set to 0 for language transfer without accent + """) - def on_language_change(lang, current_ref, current_text): - return default_audio_for_ui(lang), default_text_for_ui(lang) + def on_language_change(lang, current_ref, current_text): + return default_audio_for_ui(lang), default_text_for_ui(lang) - language_id.change( - fn=on_language_change, - inputs=[language_id, ref_wav, text], - outputs=[ref_wav, text], - show_progress=False - ) + language_id.change( + fn=on_language_change, + inputs=[language_id, ref_wav, text], + outputs=[ref_wav, text], + show_progress=False + ) + + load_t3_btn.click( + fn=switch_t3_model, + inputs=[t3_model_dropdown], + outputs=[model_status] + ) run_btn.click( fn=generate_tts_audio, @@ -314,4 +389,17 @@ def on_language_change(lang, current_ref, current_text): outputs=[audio_output], ) -demo.launch(mcp_server=True) +if __name__ == "__main__": + print("\n" + "="*60) + print("🎙️ CHATTERBOX MULTILINGUAL TTS - CUSTOM T3 EDITION") + print("="*60) + print(f"Device: {DEVICE}") + print(f"Available T3 models: {len(CUSTOM_T3_MODELS)}") + for model_name in CUSTOM_T3_MODELS.keys(): + print(f" - {model_name}") + print("="*60 + "\n") + + demo.queue( + max_size=50, + default_concurrency_limit=1, + ).launch(share=True) \ No newline at end of file From d4167a2869780e0831ab02b2b3aa1203f99b30a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Ku=C4=8Dera?= Date: Wed, 12 Nov 2025 03:07:40 -0600 Subject: [PATCH 07/13] bugfix --- multilingual_app.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/multilingual_app.py b/multilingual_app.py index edca2060..706035af 100644 --- a/multilingual_app.py +++ b/multilingual_app.py @@ -12,9 +12,9 @@ # --- Custom T3 Model Configuration --- CUSTOM_T3_MODELS = { "Default": None, - "Czech (t3_cs)": "t3_cs", # Path to your safetensors file + "Czech (t3_cs)": "C:/ChatterboxTraining/t3/t3_cs", # Path 
to your safetensors file (no extension) # Add more custom models here: - # "Another Language": "path/to/model.safetensors", + # "Another Language": "C:/path/to/model", } # --- Global Model Initialization --- @@ -177,19 +177,28 @@ def switch_t3_model(model_choice: str): custom_path = CUSTOM_T3_MODELS.get(model_choice) if custom_path: - # Check if path exists - if not os.path.exists(custom_path): - return f"❌ Error: Model path not found: {custom_path}" + # The path should point directly to the safetensors file + # Try with no extension first, then .safetensors + if not custom_path.endswith('.safetensors'): + safetensors_path = custom_path + '.safetensors' if os.path.exists(custom_path + '.safetensors') else custom_path + else: + safetensors_path = custom_path + + if not os.path.exists(safetensors_path): + return f"❌ Error: Model file not found: {safetensors_path}" - print(f"Loading custom T3 model from: {custom_path}") + print(f"Loading custom T3 model from: {safetensors_path}") try: - # Load the custom T3 state dict - t3_state = load_safetensors(custom_path, device="cpu") + # Load the custom T3 state dict (exactly like your friend's code) + t3_state = load_safetensors(safetensors_path, device="cpu") MODEL.t3.load_state_dict(t3_state) MODEL.t3.to(DEVICE).eval() print(f"✓ Loaded custom T3 model: {model_choice}") - return f"✓ Loaded: {model_choice}" + return f"✓ Loaded: {model_choice} from {safetensors_path}" except Exception as e: + import traceback + error_details = traceback.format_exc() + print(f"Error details:\n{error_details}") return f"❌ Error loading model: {str(e)}" else: print("Reloading default T3 model...") @@ -345,7 +354,7 @@ def generate_tts_audio( **Custom T3 Models:** - Load your fine-tuned T3 models for new languages - - Place `.safetensors` files in the working directory + - Point directly to the `.safetensors` file (e.g., `C:/path/to/t3_cs` or `C:/path/to/t3_cs.safetensors`) - Switch between models without restarting **Voice Cloning:** From 0c88f2a6f10450843c5daa7e6cb3899bbb0f7b2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Ku=C4=8Dera?= Date: Wed, 12 Nov 2025 03:26:04 -0600 Subject: [PATCH 08/13] Add files via upload --- multilingual_app.cmd | 9 ++++ multilingual_app.py | 100 +++++++++++++++++++++++++++++++------------ tts_multi_gradio.py | 4 +- 3 files changed, 83 insertions(+), 30 deletions(-) create mode 100644 multilingual_app.cmd diff --git a/multilingual_app.cmd b/multilingual_app.cmd new file mode 100644 index 00000000..a07a961a --- /dev/null +++ b/multilingual_app.cmd @@ -0,0 +1,9 @@ +@echo off +echo Starting Chatterbox TTS Gradio App... +echo. + +call .venv\Scripts\activate.bat +start http://127.0.0.1:7860 +python multilingual_app.py + +if errorlevel 1 pause \ No newline at end of file diff --git a/multilingual_app.py b/multilingual_app.py index 706035af..21cd76a3 100644 --- a/multilingual_app.py +++ b/multilingual_app.py @@ -12,9 +12,9 @@ # --- Custom T3 Model Configuration --- CUSTOM_T3_MODELS = { "Default": None, - "Czech (t3_cs)": "C:/ChatterboxTraining/t3/t3_cs", # Path to your safetensors file (no extension) + "Czech (t3_cs)": "C:/ChatterboxTraining/t3/t3_cs.safetensors", # FIXED: Full path with extension # Add more custom models here: - # "Another Language": "C:/path/to/model", + # "Another Language": "C:/path/to/model.safetensors", } # --- Global Model Initialization --- @@ -115,7 +115,7 @@ }, "zh": { "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/zh_f2.flac", - "text": "上个月,我们达到了一个新的里程碑. 
我们的YouTube频道观看次数达到了二十亿次,这绝对令人难以置信。" + "text": "上个月,我们达到了一个新的里程碑. 我们的YouTube频道观看次数达到了二十亿次,这绝对令人难以置信。" }, } @@ -158,11 +158,33 @@ def get_or_load_model(): print("Model not loaded, initializing...") try: MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE) - if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE: - MODEL.to(DEVICE) - print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}") + # FIXED: Force model to device explicitly + if hasattr(MODEL, 'to'): + MODEL = MODEL.to(DEVICE) + print(f"✓ Model moved to {DEVICE}") + + # Also move submodules if they exist + if hasattr(MODEL, 't3') and hasattr(MODEL.t3, 'to'): + MODEL.t3 = MODEL.t3.to(DEVICE) + if hasattr(MODEL, 't2') and hasattr(MODEL.t2, 'to'): + MODEL.t2 = MODEL.t2.to(DEVICE) + if hasattr(MODEL, 't1') and hasattr(MODEL.t1, 'to'): + MODEL.t1 = MODEL.t1.to(DEVICE) + + print(f"✓ Model loaded successfully on {DEVICE}") + + # Print actual device locations for debugging + if hasattr(MODEL, 't3'): + try: + t3_device = next(MODEL.t3.parameters()).device + print(f" T3 device: {t3_device}") + except: + pass + except Exception as e: print(f"Error loading model: {e}") + import traceback + traceback.print_exc() raise return MODEL @@ -177,36 +199,43 @@ def switch_t3_model(model_choice: str): custom_path = CUSTOM_T3_MODELS.get(model_choice) if custom_path: - # The path should point directly to the safetensors file - # Try with no extension first, then .safetensors - if not custom_path.endswith('.safetensors'): - safetensors_path = custom_path + '.safetensors' if os.path.exists(custom_path + '.safetensors') else custom_path - else: - safetensors_path = custom_path - - if not os.path.exists(safetensors_path): - return f"❌ Error: Model file not found: {safetensors_path}" + # FIXED: Better path handling + if not os.path.exists(custom_path): + # Try adding .safetensors if not present + if not custom_path.endswith('.safetensors'): + alt_path = custom_path + '.safetensors' + if os.path.exists(alt_path): + custom_path = alt_path + else: + return f"❌ Error: Model file not found at {custom_path} or {alt_path}" + else: + return f"❌ Error: Model file not found: {custom_path}" - print(f"Loading custom T3 model from: {safetensors_path}") + print(f"Loading custom T3 model from: {custom_path}") try: - # Load the custom T3 state dict (exactly like your friend's code) - t3_state = load_safetensors(safetensors_path, device="cpu") - MODEL.t3.load_state_dict(t3_state) + # FIXED: Load directly to target device, not CPU first + t3_state = load_safetensors(custom_path, device=str(DEVICE)) + MODEL.t3.load_state_dict(t3_state, strict=False) # Added strict=False for safety MODEL.t3.to(DEVICE).eval() - print(f"✓ Loaded custom T3 model: {model_choice}") - return f"✓ Loaded: {model_choice} from {safetensors_path}" + + # Verify device + t3_device = next(MODEL.t3.parameters()).device + print(f"✓ Loaded custom T3 model: {model_choice} on {t3_device}") + return f"✓ Loaded: {model_choice}\n📍 Device: {t3_device}\n📁 From: {custom_path}" except Exception as e: import traceback error_details = traceback.format_exc() print(f"Error details:\n{error_details}") - return f"❌ Error loading model: {str(e)}" + return f"❌ Error loading model: {str(e)}\n\nFull traceback in console." 
else: print("Reloading default T3 model...") try: # Reload the entire model to get default T3 MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE) + MODEL = MODEL.to(DEVICE) + MODEL.t3.to(DEVICE).eval() print("✓ Loaded default T3 model") - return "✓ Loaded: Default T3 model" + return f"✓ Loaded: Default T3 model\n📍 Device: {DEVICE}" except Exception as e: return f"❌ Error loading model: {str(e)}" @@ -248,6 +277,13 @@ def generate_tts_audio( print(f"Generating audio for text: '{text_input[:50]}...' in language: {language_id}") + # FIXED: Verify model is on correct device before generation + if hasattr(current_model, 't3'): + t3_device = next(current_model.t3.parameters()).device + print(f" T3 currently on: {t3_device}") + if str(t3_device) != DEVICE and DEVICE == "cuda": + print(f" ⚠️ WARNING: T3 is on {t3_device} but should be on {DEVICE}") + # Handle optional audio prompt chosen_prompt = audio_prompt_path_input or default_audio_for_ui(language_id) @@ -268,7 +304,7 @@ def generate_tts_audio( **generate_kwargs ) print("Audio generation complete.") - return (current_model.sr, wav.squeeze(0).numpy()) + return (current_model.sr, wav.squeeze(0).cpu().numpy()) # FIXED: Added .cpu() before .numpy() with gr.Blocks(theme=gr.themes.Soft()) as demo: @@ -296,7 +332,7 @@ def generate_tts_audio( label="Model Status", value="Default model loaded", interactive=False, - lines=1 + lines=2 ) load_t3_btn = gr.Button("🔄 Load Selected T3 Model", variant="secondary", size="sm") @@ -354,7 +390,8 @@ def generate_tts_audio( **Custom T3 Models:** - Load your fine-tuned T3 models for new languages - - Point directly to the `.safetensors` file (e.g., `C:/path/to/t3_cs` or `C:/path/to/t3_cs.safetensors`) + - Use full path with `.safetensors` extension + - Example: `C:/ChatterboxTraining/t3/t3_cs.safetensors` - Switch between models without restarting **Voice Cloning:** @@ -403,9 +440,16 @@ def on_language_change(lang, current_ref, current_text): print("🎙️ CHATTERBOX MULTILINGUAL TTS - CUSTOM T3 EDITION") print("="*60) print(f"Device: {DEVICE}") + if DEVICE == "cuda": + print(f"GPU: {torch.cuda.get_device_name(0)}") + print(f"CUDA Version: {torch.version.cuda}") print(f"Available T3 models: {len(CUSTOM_T3_MODELS)}") - for model_name in CUSTOM_T3_MODELS.keys(): - print(f" - {model_name}") + for model_name, path in CUSTOM_T3_MODELS.items(): + if path: + exists = "✓" if os.path.exists(path) or os.path.exists(path + ".safetensors") else "✗" + print(f" {exists} {model_name}: {path}") + else: + print(f" ✓ {model_name} (built-in)") print("="*60 + "\n") demo.queue( diff --git a/tts_multi_gradio.py b/tts_multi_gradio.py index fec3e37f..2cfd92fa 100644 --- a/tts_multi_gradio.py +++ b/tts_multi_gradio.py @@ -15,7 +15,7 @@ # Custom Bark model configuration CUSTOM_BARK_MODELS = { "Default (suno/bark-small)": None, - "Fine-tuned": "C:/ChatterboxTraining/checkpoints/final_bark_model", + "czt3": "C:/ChatterboxTraining/t3", # Add more custom models here: # "My Voice Clone": "C:/path/to/another/model", } @@ -228,7 +228,7 @@ def generate(model, text, audio_prompt_path, exaggeration, temperature, seed_num step=.05, label="CFG Weight / Pace", info="Controls generation guidance", - value=0.5 + value=0.3 ) # Advanced Options From b4eded814ee1be3b0920f1bef1cff143aa179ce9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Ku=C4=8Dera?= Date: Wed, 12 Nov 2025 10:10:43 -0600 Subject: [PATCH 09/13] Add files via upload --- gui.cmd | 3 + gui.py | 275 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ main.py | 22 +++++ 3 files changed, 300 
insertions(+) create mode 100644 gui.cmd create mode 100644 gui.py create mode 100644 main.py diff --git a/gui.cmd b/gui.cmd new file mode 100644 index 00000000..dce4e948 --- /dev/null +++ b/gui.cmd @@ -0,0 +1,3 @@ +call venv\Scripts\activate +python gui.py +pause \ No newline at end of file diff --git a/gui.py b/gui.py new file mode 100644 index 00000000..92d2df7b --- /dev/null +++ b/gui.py @@ -0,0 +1,275 @@ +import sys +import torch +import torchaudio as ta +from pathlib import Path +from safetensors.torch import load_file +from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, + QHBoxLayout, QPushButton, QLineEdit, QTextEdit, + QComboBox, QLabel, QGroupBox, QFileDialog, + QMessageBox, QProgressBar) +from PySide6.QtCore import QThread, Signal, Qt +from PySide6.QtMultimedia import QAudioOutput, QMediaPlayer +from PySide6.QtCore import QUrl +from chatterbox_git.src.chatterbox import mtl_tts + +class TTSThread(QThread): + finished = Signal(str) + error = Signal(str) + progress = Signal(str) + + def __init__(self, text, language_id, device, model_path, output_path, audio_prompt_path=None): + super().__init__() + self.text = text + self.language_id = language_id + self.device = device + self.model_path = model_path + self.output_path = output_path + self.audio_prompt_path = audio_prompt_path + + def run(self): + try: + self.progress.emit("Loading multilingual model...") + model = mtl_tts.ChatterboxMultilingualTTS.from_pretrained(device=self.device) + + self.progress.emit(f"Loading language weights...") + t3_state = load_file(self.model_path, device=self.device) + model.t3.load_state_dict(t3_state) + model.t3.to(self.device).eval() + + self.progress.emit("Generating speech...") + # CORRECT: Pass audio_prompt_path directly to generate() + wav = model.generate( + self.text, + language_id=self.language_id, + audio_prompt_path=self.audio_prompt_path # Direct path reference + ) + + Path(self.output_path).parent.mkdir(parents=True, exist_ok=True) + ta.save(self.output_path, wav, model.sr) + + self.finished.emit(self.output_path) + except Exception as e: + self.error.emit(str(e)) + +class ChatterboxTTSGUI(QMainWindow): + def __init__(self): + super().__init__() + self.setWindowTitle("Chatterbox Multilingual TTS") + self.setMinimumSize(700, 450) + self.media_player = QMediaPlayer() + self.audio_output = QAudioOutput() + self.media_player.setAudioOutput(self.audio_output) + self.init_ui() + + # === CORRECT HELPER METHODS === + + def browse_model_file(self): + path, _ = QFileDialog.getOpenFileName(self, "Select Model", "", "SafeTensors (*.safetensors)") + if path: + self.model_path_edit.setText(path) + + def browse_prompt_file(self): + path, _ = QFileDialog.getOpenFileName(self, "Select Reference Audio", "", "Audio Files (*.wav *.mp3 *.flac)") + if path: + self.prompt_path_edit.setText(path) + + def browse_output_file(self): + path, _ = QFileDialog.getSaveFileName(self, "Save Audio", "", "WAV (*.wav)") + if path: + if not path.endswith('.wav'): + path += '.wav' + self.output_path_edit.setText(path) + + def generate_speech(self): + # Validate inputs + text = self.text_edit.toPlainText().strip() + if not text: + QMessageBox.warning(self, "Input Error", "Please enter text to synthesize.") + return + + model_path = self.model_path_edit.text().strip() + if not Path(model_path).exists(): + QMessageBox.warning(self, "Error", f"Model file not found:\n{model_path}") + return + + output_path = self.output_path_edit.text().strip() + if not output_path: + QMessageBox.warning(self, "Error", 
"Please specify an output path") + return + + # Get audio prompt path (if any) + audio_prompt_path = self.prompt_path_edit.text().strip() + if audio_prompt_path and not Path(audio_prompt_path).exists(): + QMessageBox.warning(self, "Error", f"Reference audio not found:\n{audio_prompt_path}") + return + + # Disable UI and start generation + self.set_ui_enabled(False) + self.progress_bar.setVisible(True) + self.progress_bar.setRange(0, 0) + self.status_label.setText("Generating...") + + self.tts_thread = TTSThread( + text, + self.lang_combo.currentText(), + self.device_combo.currentText(), + model_path, + output_path, + audio_prompt_path if audio_prompt_path else None + ) + self.tts_thread.finished.connect(self.on_finished) + self.tts_thread.error.connect(self.on_error) + self.tts_thread.progress.connect(self.status_label.setText) + self.tts_thread.start() + + def on_finished(self, path): + self.set_ui_enabled(True) + self.progress_bar.setVisible(False) + self.status_label.setText(f"Saved: {Path(path).name}") + self.play_btn.setEnabled(True) + + def on_error(self, error): + self.set_ui_enabled(True) + self.progress_bar.setVisible(False) + self.status_label.setText("Error") + QMessageBox.critical(self, "Generation Error", error) + + def set_ui_enabled(self, enabled): + widgets = [self.generate_btn, self.device_combo, self.lang_combo, self.text_edit, + self.model_path_edit, self.browse_model_btn, self.prompt_path_edit, + self.browse_prompt_btn, self.output_path_edit, self.browse_output_btn] + for widget in widgets: + widget.setEnabled(enabled) + + def play_audio(self): + path = self.output_path_edit.text() + if Path(path).exists(): + self.media_player.setSource(QUrl.fromLocalFile(path)) + self.media_player.play() + self.play_btn.setEnabled(False) + self.stop_btn.setEnabled(True) + else: + QMessageBox.warning(self, "Error", "Audio file not found. 
Generate it first.") + + def stop_audio(self): + self.media_player.stop() + self.play_btn.setEnabled(True) + self.stop_btn.setEnabled(False) + + # === UI SETUP === + + def init_ui(self): + central = QWidget() + self.setCentralWidget(central) + layout = QVBoxLayout(central) + + # Device Settings + device_group = QGroupBox("Device Settings") + device_layout = QHBoxLayout() + device_layout.addWidget(QLabel("Device:")) + self.device_combo = QComboBox() + self.device_combo.addItems(["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + + (["mps"] if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() else [])) + device_layout.addWidget(self.device_combo) + device_layout.addStretch() + device_group.setLayout(device_layout) + layout.addWidget(device_group) + + # Model Settings + model_group = QGroupBox("Model Settings") + model_layout = QVBoxLayout() + + # Language selection + lang_layout = QHBoxLayout() + lang_layout.addWidget(QLabel("Language:")) + self.lang_combo = QComboBox() + self.lang_combo.addItems(["cs", "en", "de", "es", "fr", "it", "pl", "pt", "ru", "tr", "ja", "ko", "zh"]) + self.lang_combo.setCurrentText("cs") + lang_layout.addWidget(self.lang_combo) + lang_layout.addStretch() + model_layout.addLayout(lang_layout) + + # Language model path + model_path_layout = QHBoxLayout() + self.model_path_edit = QLineEdit() + self.model_path_edit.setPlaceholderText("Path to language model (e.g., t3_cs.safetensors)...") + model_path_layout.addWidget(self.model_path_edit) + self.browse_model_btn = QPushButton("Browse...") + self.browse_model_btn.clicked.connect(self.browse_model_file) + model_path_layout.addWidget(self.browse_model_btn) + model_layout.addLayout(model_path_layout) + + # Audio prompt path (Speaker conditioning) + prompt_group = QGroupBox("Speaker Reference (Optional)") + prompt_layout = QVBoxLayout() + + prompt_path_layout = QHBoxLayout() + self.prompt_path_edit = QLineEdit() + self.prompt_path_edit.setPlaceholderText("Path to reference audio file (optional)...") + prompt_path_layout.addWidget(self.prompt_path_edit) + self.browse_prompt_btn = QPushButton("Browse...") + self.browse_prompt_btn.clicked.connect(self.browse_prompt_file) + prompt_path_layout.addWidget(self.browse_prompt_btn) + prompt_layout.addLayout(prompt_path_layout) + + prompt_group.setLayout(prompt_layout) + model_layout.addWidget(prompt_group) + model_group.setLayout(model_layout) + layout.addWidget(model_group) + + # Text Input + text_group = QGroupBox("Text to Speech") + text_layout = QVBoxLayout() + self.text_edit = QTextEdit() + self.text_edit.setPlaceholderText("Enter text here...") + self.text_edit.setText("Dobrý den, vítáme vás v našem testu syntézy řeči") + text_layout.addWidget(self.text_edit) + text_group.setLayout(text_layout) + layout.addWidget(text_group) + + # Output + output_group = QGroupBox("Output") + output_layout = QVBoxLayout() + output_path_layout = QHBoxLayout() + self.output_path_edit = QLineEdit() + self.output_path_edit.setText(str(Path.cwd() / "output.wav")) + output_path_layout.addWidget(self.output_path_edit) + self.browse_output_btn = QPushButton("Browse...") + self.browse_output_btn.clicked.connect(self.browse_output_file) + output_path_layout.addWidget(self.browse_output_btn) + output_layout.addLayout(output_path_layout) + output_group.setLayout(output_layout) + layout.addWidget(output_group) + + # Progress + self.progress_bar = QProgressBar() + self.progress_bar.setVisible(False) + layout.addWidget(self.progress_bar) + + self.status_label = QLabel("Ready") + 
self.status_label.setAlignment(Qt.AlignCenter) + layout.addWidget(self.status_label) + + # Buttons + button_layout = QHBoxLayout() + self.generate_btn = QPushButton("Generate Speech") + self.generate_btn.clicked.connect(self.generate_speech) + button_layout.addWidget(self.generate_btn) + self.play_btn = QPushButton("Play") + self.play_btn.clicked.connect(self.play_audio) + self.play_btn.setEnabled(False) + button_layout.addWidget(self.play_btn) + self.stop_btn = QPushButton("Stop") + self.stop_btn.clicked.connect(self.stop_audio) + self.stop_btn.setEnabled(False) + button_layout.addWidget(self.stop_btn) + layout.addLayout(button_layout) + +def main(): + app = QApplication(sys.argv) + window = ChatterboxTTSGUI() + window.show() + sys.exit(app.exec()) + +if __name__ == "__main__": + main() diff --git a/main.py b/main.py new file mode 100644 index 00000000..f32d6870 --- /dev/null +++ b/main.py @@ -0,0 +1,22 @@ +from chatterbox_git.src.chatterbox import mtl_tts +import torchaudio as ta +from safetensors.torch import load_file as load_safetensors + +device = "cuda" # or mps or cuda + +multilingual_model = mtl_tts.ChatterboxMultilingualTTS.from_pretrained(device=device) + +# ---- +# Then download the file from huggingface and place it in the current directory. +# ---- + + + +t3_state = load_safetensors("t3_cs.safetensors", device="cuda") +multilingual_model.t3.load_state_dict(t3_state) +multilingual_model.t3.to(device).eval() + +czech_text = "Přečtěte si krátký text a odpovězte na několik otázek, které testují porozumění. Můžete se začíst do krátkých úryvků z článků nebo do některého z našich krátkých a vtipných příběhů. Pozor, vybraný text můžete řešit pouze jednou v daný den." +wav_czech = multilingual_model.generate(czech_text, language_id="cs") +ta.save("test-cs.wav", wav_czech, multilingual_model.sr) + From a4edd02a57656e06946a1bf84e38bb3e6b8ac375 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Ku=C4=8Dera?= Date: Wed, 19 Nov 2025 17:43:27 -0600 Subject: [PATCH 10/13] Add files via upload --- demo_text.txt | 8 + gui.py | 662 +++++++++++++++++++++++++++++-------------------- reference.cmd | 3 + reference.py | 22 ++ wav_to_mp3.cmd | 2 + 5 files changed, 422 insertions(+), 275 deletions(-) create mode 100644 demo_text.txt create mode 100644 reference.cmd create mode 100644 reference.py create mode 100644 wav_to_mp3.cmd diff --git a/demo_text.txt b/demo_text.txt new file mode 100644 index 00000000..70ecead4 --- /dev/null +++ b/demo_text.txt @@ -0,0 +1,8 @@ +Text 1 (Cca 300 znaků) +Praha, srdce Evropy, se probouzí do chladného podzimního rána. Listí šustí pod nohama spěchajících chodců na Karlově mostě, kde se první turisté shlukují, aby zachytili magický východ slunce nad Hradčany. Historie dýchá z každého kamene starého města a vůně čerstvě upečeného trdelníku se line ulicemi. Je to den plný možností a objevů pro každého, kdo sem zavítá. Česká metropole má vždy co nabídnout, ať už jde o umění, architekturu nebo kulturu. + +Text 2 (Cca 300 znaků) +Vývoj umělé inteligence (AI) je v současnosti jedním z nejdůležitějších technologických témat. Otevírá dveře k revolučním změnám v mnoha odvětvích, od medicíny po průmysl. Zároveň s sebou přináší i etické a společenské otázky, které je třeba zodpovědět. Jak zajistit spravedlivé využití AI a ochranu osobních dat? Tyto debaty jsou klíčové pro budoucí směřování naší společnosti v digitální éře. Je nutné hledat rovnováhu. + +Text 3 (Cca 300 znaků) +V Krkonoších napadl první sníh a horské chaty se připravují na zimní sezónu. 
Vzduch je svěží a mrazivý, ideální pro dlouhé túry s výhledy na zasněžené vrcholky. Lyžařská střediska finišují s údržbou a netrpělivě čekají na první nedočkavé sportovce. Pohyb v horách v zimě vyžaduje respekt a dobrou výbavu, ale odměnou je nezapomenutelný zážitek a pocit svobody v tiché, bílé krajině. \ No newline at end of file diff --git a/gui.py b/gui.py index 92d2df7b..85d7efce 100644 --- a/gui.py +++ b/gui.py @@ -1,275 +1,387 @@ -import sys -import torch -import torchaudio as ta -from pathlib import Path -from safetensors.torch import load_file -from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, - QHBoxLayout, QPushButton, QLineEdit, QTextEdit, - QComboBox, QLabel, QGroupBox, QFileDialog, - QMessageBox, QProgressBar) -from PySide6.QtCore import QThread, Signal, Qt -from PySide6.QtMultimedia import QAudioOutput, QMediaPlayer -from PySide6.QtCore import QUrl -from chatterbox_git.src.chatterbox import mtl_tts - -class TTSThread(QThread): - finished = Signal(str) - error = Signal(str) - progress = Signal(str) - - def __init__(self, text, language_id, device, model_path, output_path, audio_prompt_path=None): - super().__init__() - self.text = text - self.language_id = language_id - self.device = device - self.model_path = model_path - self.output_path = output_path - self.audio_prompt_path = audio_prompt_path - - def run(self): - try: - self.progress.emit("Loading multilingual model...") - model = mtl_tts.ChatterboxMultilingualTTS.from_pretrained(device=self.device) - - self.progress.emit(f"Loading language weights...") - t3_state = load_file(self.model_path, device=self.device) - model.t3.load_state_dict(t3_state) - model.t3.to(self.device).eval() - - self.progress.emit("Generating speech...") - # CORRECT: Pass audio_prompt_path directly to generate() - wav = model.generate( - self.text, - language_id=self.language_id, - audio_prompt_path=self.audio_prompt_path # Direct path reference - ) - - Path(self.output_path).parent.mkdir(parents=True, exist_ok=True) - ta.save(self.output_path, wav, model.sr) - - self.finished.emit(self.output_path) - except Exception as e: - self.error.emit(str(e)) - -class ChatterboxTTSGUI(QMainWindow): - def __init__(self): - super().__init__() - self.setWindowTitle("Chatterbox Multilingual TTS") - self.setMinimumSize(700, 450) - self.media_player = QMediaPlayer() - self.audio_output = QAudioOutput() - self.media_player.setAudioOutput(self.audio_output) - self.init_ui() - - # === CORRECT HELPER METHODS === - - def browse_model_file(self): - path, _ = QFileDialog.getOpenFileName(self, "Select Model", "", "SafeTensors (*.safetensors)") - if path: - self.model_path_edit.setText(path) - - def browse_prompt_file(self): - path, _ = QFileDialog.getOpenFileName(self, "Select Reference Audio", "", "Audio Files (*.wav *.mp3 *.flac)") - if path: - self.prompt_path_edit.setText(path) - - def browse_output_file(self): - path, _ = QFileDialog.getSaveFileName(self, "Save Audio", "", "WAV (*.wav)") - if path: - if not path.endswith('.wav'): - path += '.wav' - self.output_path_edit.setText(path) - - def generate_speech(self): - # Validate inputs - text = self.text_edit.toPlainText().strip() - if not text: - QMessageBox.warning(self, "Input Error", "Please enter text to synthesize.") - return - - model_path = self.model_path_edit.text().strip() - if not Path(model_path).exists(): - QMessageBox.warning(self, "Error", f"Model file not found:\n{model_path}") - return - - output_path = self.output_path_edit.text().strip() - if not output_path: - 
QMessageBox.warning(self, "Error", "Please specify an output path") - return - - # Get audio prompt path (if any) - audio_prompt_path = self.prompt_path_edit.text().strip() - if audio_prompt_path and not Path(audio_prompt_path).exists(): - QMessageBox.warning(self, "Error", f"Reference audio not found:\n{audio_prompt_path}") - return - - # Disable UI and start generation - self.set_ui_enabled(False) - self.progress_bar.setVisible(True) - self.progress_bar.setRange(0, 0) - self.status_label.setText("Generating...") - - self.tts_thread = TTSThread( - text, - self.lang_combo.currentText(), - self.device_combo.currentText(), - model_path, - output_path, - audio_prompt_path if audio_prompt_path else None - ) - self.tts_thread.finished.connect(self.on_finished) - self.tts_thread.error.connect(self.on_error) - self.tts_thread.progress.connect(self.status_label.setText) - self.tts_thread.start() - - def on_finished(self, path): - self.set_ui_enabled(True) - self.progress_bar.setVisible(False) - self.status_label.setText(f"Saved: {Path(path).name}") - self.play_btn.setEnabled(True) - - def on_error(self, error): - self.set_ui_enabled(True) - self.progress_bar.setVisible(False) - self.status_label.setText("Error") - QMessageBox.critical(self, "Generation Error", error) - - def set_ui_enabled(self, enabled): - widgets = [self.generate_btn, self.device_combo, self.lang_combo, self.text_edit, - self.model_path_edit, self.browse_model_btn, self.prompt_path_edit, - self.browse_prompt_btn, self.output_path_edit, self.browse_output_btn] - for widget in widgets: - widget.setEnabled(enabled) - - def play_audio(self): - path = self.output_path_edit.text() - if Path(path).exists(): - self.media_player.setSource(QUrl.fromLocalFile(path)) - self.media_player.play() - self.play_btn.setEnabled(False) - self.stop_btn.setEnabled(True) - else: - QMessageBox.warning(self, "Error", "Audio file not found. 
Generate it first.") - - def stop_audio(self): - self.media_player.stop() - self.play_btn.setEnabled(True) - self.stop_btn.setEnabled(False) - - # === UI SETUP === - - def init_ui(self): - central = QWidget() - self.setCentralWidget(central) - layout = QVBoxLayout(central) - - # Device Settings - device_group = QGroupBox("Device Settings") - device_layout = QHBoxLayout() - device_layout.addWidget(QLabel("Device:")) - self.device_combo = QComboBox() - self.device_combo.addItems(["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + - (["mps"] if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() else [])) - device_layout.addWidget(self.device_combo) - device_layout.addStretch() - device_group.setLayout(device_layout) - layout.addWidget(device_group) - - # Model Settings - model_group = QGroupBox("Model Settings") - model_layout = QVBoxLayout() - - # Language selection - lang_layout = QHBoxLayout() - lang_layout.addWidget(QLabel("Language:")) - self.lang_combo = QComboBox() - self.lang_combo.addItems(["cs", "en", "de", "es", "fr", "it", "pl", "pt", "ru", "tr", "ja", "ko", "zh"]) - self.lang_combo.setCurrentText("cs") - lang_layout.addWidget(self.lang_combo) - lang_layout.addStretch() - model_layout.addLayout(lang_layout) - - # Language model path - model_path_layout = QHBoxLayout() - self.model_path_edit = QLineEdit() - self.model_path_edit.setPlaceholderText("Path to language model (e.g., t3_cs.safetensors)...") - model_path_layout.addWidget(self.model_path_edit) - self.browse_model_btn = QPushButton("Browse...") - self.browse_model_btn.clicked.connect(self.browse_model_file) - model_path_layout.addWidget(self.browse_model_btn) - model_layout.addLayout(model_path_layout) - - # Audio prompt path (Speaker conditioning) - prompt_group = QGroupBox("Speaker Reference (Optional)") - prompt_layout = QVBoxLayout() - - prompt_path_layout = QHBoxLayout() - self.prompt_path_edit = QLineEdit() - self.prompt_path_edit.setPlaceholderText("Path to reference audio file (optional)...") - prompt_path_layout.addWidget(self.prompt_path_edit) - self.browse_prompt_btn = QPushButton("Browse...") - self.browse_prompt_btn.clicked.connect(self.browse_prompt_file) - prompt_path_layout.addWidget(self.browse_prompt_btn) - prompt_layout.addLayout(prompt_path_layout) - - prompt_group.setLayout(prompt_layout) - model_layout.addWidget(prompt_group) - model_group.setLayout(model_layout) - layout.addWidget(model_group) - - # Text Input - text_group = QGroupBox("Text to Speech") - text_layout = QVBoxLayout() - self.text_edit = QTextEdit() - self.text_edit.setPlaceholderText("Enter text here...") - self.text_edit.setText("Dobrý den, vítáme vás v našem testu syntézy řeči") - text_layout.addWidget(self.text_edit) - text_group.setLayout(text_layout) - layout.addWidget(text_group) - - # Output - output_group = QGroupBox("Output") - output_layout = QVBoxLayout() - output_path_layout = QHBoxLayout() - self.output_path_edit = QLineEdit() - self.output_path_edit.setText(str(Path.cwd() / "output.wav")) - output_path_layout.addWidget(self.output_path_edit) - self.browse_output_btn = QPushButton("Browse...") - self.browse_output_btn.clicked.connect(self.browse_output_file) - output_path_layout.addWidget(self.browse_output_btn) - output_layout.addLayout(output_path_layout) - output_group.setLayout(output_layout) - layout.addWidget(output_group) - - # Progress - self.progress_bar = QProgressBar() - self.progress_bar.setVisible(False) - layout.addWidget(self.progress_bar) - - self.status_label = QLabel("Ready") - 
self.status_label.setAlignment(Qt.AlignCenter) - layout.addWidget(self.status_label) - - # Buttons - button_layout = QHBoxLayout() - self.generate_btn = QPushButton("Generate Speech") - self.generate_btn.clicked.connect(self.generate_speech) - button_layout.addWidget(self.generate_btn) - self.play_btn = QPushButton("Play") - self.play_btn.clicked.connect(self.play_audio) - self.play_btn.setEnabled(False) - button_layout.addWidget(self.play_btn) - self.stop_btn = QPushButton("Stop") - self.stop_btn.clicked.connect(self.stop_audio) - self.stop_btn.setEnabled(False) - button_layout.addWidget(self.stop_btn) - layout.addLayout(button_layout) - -def main(): - app = QApplication(sys.argv) - window = ChatterboxTTSGUI() - window.show() - sys.exit(app.exec()) - -if __name__ == "__main__": - main() +import sys +import re +import torch +import torchaudio as ta +from pathlib import Path +from safetensors.torch import load_file +from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, + QHBoxLayout, QPushButton, QLineEdit, QTextEdit, + QComboBox, QLabel, QGroupBox, QFileDialog, + QMessageBox, QCheckBox, QDoubleSpinBox) +from PySide6.QtCore import QThread, Signal, QSettings +from chatterbox_git.src.chatterbox import mtl_tts + + +def split_into_sentences(text): + """Simple sentence splitter for batching""" + # Split on period, exclamation, question mark followed by space or end + sentences = re.split(r'(?<=[.!?])\s+', text.strip()) + return [s.strip() for s in sentences if s.strip()] + + +class TTSThread(QThread): + finished = Signal(str) + error = Signal(str) + progress = Signal(str) + + def __init__(self, text, language_id, device, model_path, output_path, + audio_prompt_path=None, use_batching=False, output_format="wav", + temperature=0.8, cfg_weight=0.3): + super().__init__() + self.text = text + self.language_id = language_id + self.device = device + self.model_path = model_path + self.output_path = output_path + self.audio_prompt_path = audio_prompt_path + self.use_batching = use_batching + self.output_format = output_format.lower() + self.temperature = temperature + self.cfg_weight = cfg_weight + + def run(self): + try: + # Load model (EXACTLY like reference) + self.progress.emit("Loading model...") + multilingual_model = mtl_tts.ChatterboxMultilingualTTS.from_pretrained(device=self.device) + + self.progress.emit("Loading language weights...") + t3_state = load_file(self.model_path, device=self.device) + multilingual_model.t3.load_state_dict(t3_state) + multilingual_model.t3.to(self.device).eval() + + # Ensure output directory exists + Path(self.output_path).parent.mkdir(parents=True, exist_ok=True) + + # Generate audio + if self.use_batching: + self.progress.emit("Splitting into sentences...") + sentences = split_into_sentences(self.text) + + if len(sentences) == 0: + raise ValueError("No sentences found in text") + + self.progress.emit(f"Processing {len(sentences)} sentences...") + wav_chunks = [] + + for i, sentence in enumerate(sentences, 1): + self.progress.emit(f"Generating sentence {i}/{len(sentences)}: {sentence[:50]}...") + + wav_chunk = multilingual_model.generate( + sentence, + language_id=self.language_id, + audio_prompt_path=self.audio_prompt_path, + temperature=self.temperature, + cfg_weight=self.cfg_weight + ) + wav_chunks.append(wav_chunk) + + # Progressive saving - save accumulated chunks after each sentence + self.progress.emit(f"Saving progress ({i}/{len(sentences)})...") + wav_so_far = torch.cat(wav_chunks, dim=-1) + if self.output_format == "mp3": + 
ta.save(self.output_path, wav_so_far, multilingual_model.sr, format="mp3") + else: + ta.save(self.output_path, wav_so_far, multilingual_model.sr) + + # Final combined audio (already saved in loop) + wav = wav_so_far + else: + self.progress.emit("Generating audio...") + wav = multilingual_model.generate( + self.text, + language_id=self.language_id, + audio_prompt_path=self.audio_prompt_path, + temperature=self.temperature, + cfg_weight=self.cfg_weight + ) + + # Save audio + self.progress.emit("Saving audio...") + if self.output_format == "mp3": + ta.save(self.output_path, wav, multilingual_model.sr, format="mp3") + else: + ta.save(self.output_path, wav, multilingual_model.sr) + + self.finished.emit(self.output_path) + + except Exception as e: + import traceback + self.error.emit(f"{str(e)}\n\n{traceback.format_exc()}") + + +class ChatterboxTTSGUI(QMainWindow): + def __init__(self): + super().__init__() + self.setWindowTitle("Chatterbox TTS - Simple") + self.setMinimumWidth(600) + self.tts_thread = None + self.settings = QSettings("ChatterboxTTS", "SimpleGUI") + self.init_ui() + self.load_settings() + + def init_ui(self): + central_widget = QWidget() + self.setCentralWidget(central_widget) + layout = QVBoxLayout(central_widget) + + # Model Settings + model_group = QGroupBox("Model Settings") + model_layout = QVBoxLayout() + + # Language + lang_layout = QHBoxLayout() + lang_layout.addWidget(QLabel("Language ID:")) + self.language_edit = QLineEdit("cs") + self.language_edit.setPlaceholderText("e.g., cs, en, de") + lang_layout.addWidget(self.language_edit) + model_layout.addLayout(lang_layout) + + # Device + device_layout = QHBoxLayout() + device_layout.addWidget(QLabel("Device:")) + self.device_combo = QComboBox() + self.device_combo.addItems(["cuda", "cpu", "mps"]) + device_layout.addWidget(self.device_combo) + device_layout.addStretch() + model_layout.addLayout(device_layout) + + # Model Path + model_path_layout = QHBoxLayout() + model_path_layout.addWidget(QLabel("Model (.safetensors):")) + self.model_path_edit = QLineEdit("t3_cs.safetensors") + model_path_layout.addWidget(self.model_path_edit) + browse_btn = QPushButton("Browse...") + browse_btn.clicked.connect(self.browse_model) + model_path_layout.addWidget(browse_btn) + model_layout.addLayout(model_path_layout) + + # Reference Voice (Audio Prompt) + prompt_layout = QHBoxLayout() + prompt_layout.addWidget(QLabel("Reference Voice:")) + self.prompt_path_edit = QLineEdit() + self.prompt_path_edit.setPlaceholderText("Optional - leave empty for default voice") + prompt_layout.addWidget(self.prompt_path_edit) + browse_prompt_btn = QPushButton("Browse...") + browse_prompt_btn.clicked.connect(self.browse_prompt) + prompt_layout.addWidget(browse_prompt_btn) + model_layout.addLayout(prompt_layout) + + model_group.setLayout(model_layout) + layout.addWidget(model_group) + + # Text Input + text_group = QGroupBox("Text to Synthesize") + text_layout = QVBoxLayout() + self.text_edit = QTextEdit() + self.text_edit.setPlaceholderText("Enter text here...") + self.text_edit.setMaximumHeight(120) + text_layout.addWidget(self.text_edit) + + # Batching option + self.batch_checkbox = QCheckBox("Use sentence batching (for longer texts with short sentences)") + text_layout.addWidget(self.batch_checkbox) + + text_group.setLayout(text_layout) + layout.addWidget(text_group) + + # Generation Parameters + params_group = QGroupBox("Generation Parameters") + params_layout = QVBoxLayout() + + # Temperature + temp_layout = QHBoxLayout() + 
temp_layout.addWidget(QLabel("Temperature:")) + self.temperature_spin = QDoubleSpinBox() + self.temperature_spin.setRange(0.05, 5.0) + self.temperature_spin.setSingleStep(0.05) + self.temperature_spin.setValue(0.8) + temp_layout.addWidget(self.temperature_spin) + temp_layout.addWidget(QLabel("(Lower = more consistent, Higher = more varied)")) + temp_layout.addStretch() + params_layout.addLayout(temp_layout) + + # CFG Weight / Pace + cfg_layout = QHBoxLayout() + cfg_layout.addWidget(QLabel("CFG Weight / Pace:")) + self.cfg_weight_spin = QDoubleSpinBox() + self.cfg_weight_spin.setRange(0.0, 1.0) + self.cfg_weight_spin.setSingleStep(0.05) + self.cfg_weight_spin.setValue(0.3) + cfg_layout.addWidget(self.cfg_weight_spin) + cfg_layout.addWidget(QLabel("(Controls generation guidance and pacing)")) + cfg_layout.addStretch() + params_layout.addLayout(cfg_layout) + + params_group.setLayout(params_layout) + layout.addWidget(params_group) + + # Output Settings + output_group = QGroupBox("Output Settings") + output_group_layout = QVBoxLayout() + + # Format selector - MORE PROMINENT + format_layout = QHBoxLayout() + format_layout.addWidget(QLabel("Output Format:")) + self.format_combo = QComboBox() + self.format_combo.addItems(["WAV", "MP3"]) + self.format_combo.setMinimumWidth(100) + self.format_combo.currentTextChanged.connect(self.on_format_changed) + format_layout.addWidget(self.format_combo) + format_layout.addStretch() + output_group_layout.addLayout(format_layout) + + # Output Path + output_layout = QHBoxLayout() + output_layout.addWidget(QLabel("Output File:")) + self.output_path_edit = QLineEdit("output.wav") + output_layout.addWidget(self.output_path_edit) + browse_output_btn = QPushButton("Browse...") + browse_output_btn.clicked.connect(self.browse_output) + output_layout.addWidget(browse_output_btn) + output_group_layout.addLayout(output_layout) + + output_group.setLayout(output_group_layout) + layout.addWidget(output_group) + + # Status + self.status_label = QLabel("Ready") + layout.addWidget(self.status_label) + + # Generate Button + self.generate_btn = QPushButton("Generate Speech") + self.generate_btn.clicked.connect(self.generate_speech) + layout.addWidget(self.generate_btn) + + layout.addStretch() + + def browse_model(self): + file_path, _ = QFileDialog.getOpenFileName( + self, "Select Model File", "", "SafeTensors (*.safetensors);;All Files (*)" + ) + if file_path: + self.model_path_edit.setText(file_path) + + def browse_prompt(self): + file_path, _ = QFileDialog.getOpenFileName( + self, "Select Reference Voice", "", "Audio Files (*.wav *.mp3 *.flac);;All Files (*)" + ) + if file_path: + self.prompt_path_edit.setText(file_path) + + def browse_output(self): + format_filter = "WAV Files (*.wav);;MP3 Files (*.mp3);;All Files (*)" + file_path, _ = QFileDialog.getSaveFileName( + self, "Save Output", "", format_filter + ) + if file_path: + self.output_path_edit.setText(file_path) + # Auto-detect format from extension + if file_path.lower().endswith('.mp3'): + self.format_combo.setCurrentText("MP3") + else: + self.format_combo.setCurrentText("WAV") + + def on_format_changed(self, format_text): + """Update file extension when format changes""" + current_path = Path(self.output_path_edit.text()) + if format_text == "MP3": + new_path = current_path.with_suffix('.mp3') + else: + new_path = current_path.with_suffix('.wav') + self.output_path_edit.setText(str(new_path)) + + def load_settings(self): + """Load all settings from QSettings""" + 
self.language_edit.setText(self.settings.value("language", "cs")) + self.device_combo.setCurrentText(self.settings.value("device", "cuda")) + self.model_path_edit.setText(self.settings.value("model_path", "t3_cs.safetensors")) + self.prompt_path_edit.setText(self.settings.value("prompt_path", "")) + self.temperature_spin.setValue(float(self.settings.value("temperature", 0.8))) + self.cfg_weight_spin.setValue(float(self.settings.value("cfg_weight", 0.3))) + self.batch_checkbox.setChecked(self.settings.value("use_batching", False, type=bool)) + self.output_path_edit.setText(self.settings.value("output_path", "output.wav")) + self.format_combo.setCurrentText(self.settings.value("format", "WAV")) + self.text_edit.setText(self.settings.value("text", "")) + + def save_settings(self): + """Save all settings to QSettings""" + self.settings.setValue("language", self.language_edit.text()) + self.settings.setValue("device", self.device_combo.currentText()) + self.settings.setValue("model_path", self.model_path_edit.text()) + self.settings.setValue("prompt_path", self.prompt_path_edit.text()) + self.settings.setValue("temperature", self.temperature_spin.value()) + self.settings.setValue("cfg_weight", self.cfg_weight_spin.value()) + self.settings.setValue("use_batching", self.batch_checkbox.isChecked()) + self.settings.setValue("output_path", self.output_path_edit.text()) + self.settings.setValue("format", self.format_combo.currentText()) + self.settings.setValue("text", self.text_edit.toPlainText()) + + def generate_speech(self): + text = self.text_edit.toPlainText().strip() + if not text: + QMessageBox.warning(self, "Error", "Please enter some text") + return + + model_path = self.model_path_edit.text() + if not Path(model_path).exists(): + QMessageBox.warning(self, "Error", f"Model file not found: {model_path}") + return + + audio_prompt_path = self.prompt_path_edit.text().strip() or None + if audio_prompt_path and not Path(audio_prompt_path).exists(): + QMessageBox.warning(self, "Error", f"Reference voice file not found: {audio_prompt_path}") + return + + # Save settings before generation + self.save_settings() + + self.generate_btn.setEnabled(False) + self.status_label.setText("Starting generation...") + + self.tts_thread = TTSThread( + text=text, + language_id=self.language_edit.text(), + device=self.device_combo.currentText(), + model_path=model_path, + output_path=self.output_path_edit.text(), + audio_prompt_path=audio_prompt_path, + use_batching=self.batch_checkbox.isChecked(), + output_format=self.format_combo.currentText(), + temperature=self.temperature_spin.value(), + cfg_weight=self.cfg_weight_spin.value() + ) + + self.tts_thread.finished.connect(self.on_finished) + self.tts_thread.error.connect(self.on_error) + self.tts_thread.progress.connect(self.on_progress) + self.tts_thread.start() + + def on_progress(self, message): + self.status_label.setText(message) + + def on_finished(self, output_path): + self.generate_btn.setEnabled(True) + self.status_label.setText(f"Complete! 
Saved to: {output_path}") + QMessageBox.information(self, "Success", f"Audio saved to:\n{output_path}") + + def on_error(self, error_message): + self.generate_btn.setEnabled(True) + self.status_label.setText("Generation failed") + QMessageBox.critical(self, "Error", f"Error:\n{error_message}") + + def closeEvent(self, event): + """Save settings when window is closed""" + self.save_settings() + super().closeEvent(event) + + +def main(): + app = QApplication(sys.argv) + window = ChatterboxTTSGUI() + window.show() + sys.exit(app.exec()) + + +if __name__ == "__main__": + main() diff --git a/reference.cmd b/reference.cmd new file mode 100644 index 00000000..ee94da83 --- /dev/null +++ b/reference.cmd @@ -0,0 +1,3 @@ +call venv\Scripts\activate +python reference.py +pause \ No newline at end of file diff --git a/reference.py b/reference.py new file mode 100644 index 00000000..f32d6870 --- /dev/null +++ b/reference.py @@ -0,0 +1,22 @@ +from chatterbox_git.src.chatterbox import mtl_tts +import torchaudio as ta +from safetensors.torch import load_file as load_safetensors + +device = "cuda" # or mps or cuda + +multilingual_model = mtl_tts.ChatterboxMultilingualTTS.from_pretrained(device=device) + +# ---- +# Then download the file from huggingface and place it in the current directory. +# ---- + + + +t3_state = load_safetensors("t3_cs.safetensors", device="cuda") +multilingual_model.t3.load_state_dict(t3_state) +multilingual_model.t3.to(device).eval() + +czech_text = "Přečtěte si krátký text a odpovězte na několik otázek, které testují porozumění. Můžete se začíst do krátkých úryvků z článků nebo do některého z našich krátkých a vtipných příběhů. Pozor, vybraný text můžete řešit pouze jednou v daný den." +wav_czech = multilingual_model.generate(czech_text, language_id="cs") +ta.save("test-cs.wav", wav_czech, multilingual_model.sr) + diff --git a/wav_to_mp3.cmd b/wav_to_mp3.cmd new file mode 100644 index 00000000..77e6e126 --- /dev/null +++ b/wav_to_mp3.cmd @@ -0,0 +1,2 @@ +ffmpeg -i %1 %1.mp3 +pause \ No newline at end of file From b9d2c99613e89dca11d6fa1a67deacaec3aecd71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Ku=C4=8Dera?= Date: Wed, 19 Nov 2025 17:51:45 -0600 Subject: [PATCH 11/13] Add files via upload --- gui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gui.py b/gui.py index 85d7efce..b74d6a1c 100644 --- a/gui.py +++ b/gui.py @@ -15,7 +15,7 @@ def split_into_sentences(text): """Simple sentence splitter for batching""" # Split on period, exclamation, question mark followed by space or end - sentences = re.split(r'(?<=[.!?])\s+', text.strip()) + sentences = re.split(r'(? Date: Wed, 19 Nov 2025 18:08:58 -0600 Subject: [PATCH 12/13] Add files via upload --- gui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gui.py b/gui.py index b74d6a1c..c06e90f6 100644 --- a/gui.py +++ b/gui.py @@ -15,7 +15,7 @@ def split_into_sentences(text): """Simple sentence splitter for batching""" # Split on period, exclamation, question mark followed by space or end - sentences = re.split(r'(? 
Date: Thu, 20 Nov 2025 03:45:22 -0600 Subject: [PATCH 13/13] convert mono to stereo by duplication --- gui.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/gui.py b/gui.py index c06e90f6..788bc0eb 100644 --- a/gui.py +++ b/gui.py @@ -79,6 +79,10 @@ def run(self): # Progressive saving - save accumulated chunks after each sentence self.progress.emit(f"Saving progress ({i}/{len(sentences)})...") wav_so_far = torch.cat(wav_chunks, dim=-1) + + if wav_so_far.shape[0] == 1: + wav_so_far = wav_so_far.repeat(2, 1) + if self.output_format == "mp3": ta.save(self.output_path, wav_so_far, multilingual_model.sr, format="mp3") else:
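
# A minimal standalone sketch of the mono-to-stereo duplication shown above,
# assuming (as elsewhere in this series) that generate() returns a
# (channels, samples) float tensor. The file name and sample rate here are
# illustrative placeholders, not values taken from these patches.
import torch
import torchaudio as ta

wav = torch.zeros(1, 24000)        # stand-in for one mono chunk from generate()
if wav.shape[0] == 1:              # a single channel row means mono audio
    wav = wav.repeat(2, 1)         # duplicate the row -> shape (2, samples)
ta.save("stereo_out.wav", wav, 24000)  # both channels carry identical audio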