
Commit 47b45dd

add 2nd part code change
1 parent 3700087 commit 47b45dd

File tree

5 files changed, +415 -58 lines changed


Sources/FluidAudio/TextToSpeech/Kokoro/Assets/Lexicon/KokoroVocabulary.swift

Lines changed: 34 additions & 14 deletions

@@ -13,6 +13,7 @@ public actor KokoroVocabulary
     private let logger = AppLogger(subsystem: "com.fluidaudio.tts", category: "KokoroVocabulary")
     private var vocabulary: [String: Int32] = [:]
     private var isLoaded = false
+    private var overrideURL: URL? = nil

     /// Get the full vocabulary dictionary, loading it from disk (and downloading if required).
     public func getVocabulary() async throws -> [String: Int32] {
@@ -23,21 +24,31 @@ public actor KokoroVocabulary
     }

     private func loadVocabulary() async throws {
-        let cacheDir = try TtsModels.cacheDirectoryURL()
-        let kokoroDir = cacheDir.appendingPathComponent("Models/kokoro")
-        let vocabURL = kokoroDir.appendingPathComponent("vocab_index.json")
-
-        if !FileManager.default.fileExists(atPath: vocabURL.path) {
-            logger.info("Vocabulary file not found in cache, downloading...")
-            try await downloadVocabularyFile(to: cacheDir)
-        }
-
         let data: Data
-        do {
-            data = try Data(contentsOf: vocabURL)
-        } catch {
-            logger.error("Failed to read vocabulary at \(vocabURL.path): \(error.localizedDescription)")
-            throw TTSError.processingFailed("Failed to read Kokoro vocabulary: \(error.localizedDescription)")
+        if let overrideURL, FileManager.default.fileExists(atPath: overrideURL.path) {
+            do {
+                data = try Data(contentsOf: overrideURL)
+                logger.info("Loaded vocabulary override from: \(overrideURL.path)")
+            } catch {
+                logger.error("Failed to read override vocabulary at \(overrideURL.path): \(error.localizedDescription)")
+                throw TTSError.processingFailed("Failed to read Kokoro vocabulary override: \(error.localizedDescription)")
+            }
+        } else {
+            let cacheDir = try TtsModels.cacheDirectoryURL()
+            let kokoroDir = cacheDir.appendingPathComponent("Models/kokoro")
+            let vocabURL = kokoroDir.appendingPathComponent("vocab_index.json")
+
+            if !FileManager.default.fileExists(atPath: vocabURL.path) {
+                logger.info("Vocabulary file not found in cache, downloading...")
+                try await downloadVocabularyFile(to: cacheDir)
+            }
+
+            do {
+                data = try Data(contentsOf: vocabURL)
+            } catch {
+                logger.error("Failed to read vocabulary at \(vocabURL.path): \(error.localizedDescription)")
+                throw TTSError.processingFailed("Failed to read Kokoro vocabulary: \(error.localizedDescription)")
+            }
         }

         guard let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] else {
@@ -93,4 +104,13 @@ public actor KokoroVocabulary
             throw TTSError.downloadFailed("Failed to obtain Kokoro vocabulary: \(error.localizedDescription)")
         }
     }
+
+    /// Set an optional override file to load the vocabulary from.
+    /// If set, this file will be used instead of downloading the default vocab_index.json.
+    public func setOverrideURL(_ url: URL?) {
+        overrideURL = url
+        isLoaded = false
+        if let url { logger.info("Vocabulary override set to: \(url.path)") } else { logger.info("Vocabulary override cleared") }
+    }
 }
+

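The override gives a way to experiment with alternative token tables without touching the model cache. A minimal call-site sketch, assuming a local vocabulary file (the path is made up; KokoroVocabulary.shared is the accessor the synthesizer uses below):

    // Hypothetical: load the vocabulary from a local file instead of the
    // cached/downloaded vocab_index.json. Setting the override also resets
    // isLoaded, so the next getVocabulary() call re-reads from disk.
    let customVocab = URL(fileURLWithPath: "/tmp/vocab_index.json")
    await KokoroVocabulary.shared.setOverrideURL(customVocab)
    let vocab = try await KokoroVocabulary.shared.getVocabulary()

    // Passing nil clears the override and restores the cache/download path.
    await KokoroVocabulary.shared.setOverrideURL(nil)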
Sources/FluidAudio/TextToSpeech/Kokoro/Pipeline/Synthesize/KokoroSynthesizer.swift

Lines changed: 191 additions & 0 deletions

@@ -757,6 +757,197 @@ public struct KokoroSynthesizer
         )
     }

+    /// Synthesize directly from a Kokoro-zh phoneme string (one codepoint per token).
+    /// This bypasses English lexicon + chunking and is useful for Mandarin.
+    public static func synthesizePhonemeStringDetailed(
+        phonemes: String,
+        voice: String = TtsConstants.recommendedVoice,
+        voiceSpeed: Float = 1.0,
+        variantPreference: ModelNames.TTS.Variant? = .fifteenSecond
+    ) async throws -> SynthesisResult {
+        logger.info("Starting synthesis from phoneme string; length=\(phonemes.count)")
+
+        try await ensureRequiredFiles()
+        if !isVoiceEmbeddingPayloadCached(for: voice) {
+            try? await TtsResourceDownloader.ensureVoiceEmbedding(voice: voice)
+        }
+        try await loadModel(variant: variantPreference)
+
+        let modelCache = try currentModelCache()
+        let vocabulary = try await KokoroVocabulary.shared.getVocabulary()
+        let capacities = try await capacities(for: variantPreference)
+        let lexiconMetrics = await lexiconCache.metrics()
+
+        // Build a single chunk from phoneme codepoints
+        let tokens: [String] = phonemes.map { String($0) }
+        let chunk = TextChunk(
+            words: [],
+            atoms: tokens,
+            phonemes: tokens,
+            totalFrames: 0,
+            pauseAfterMs: 0,
+            text: phonemes
+        )
+        let entries = try buildChunkEntries(
+            from: [chunk],
+            vocabulary: vocabulary,
+            preference: variantPreference,
+            capacities: capacities
+        )
+
+        struct ChunkSynthesisResult: Sendable { let index: Int; let samples: [Float]; let predictionTime: TimeInterval }
+
+        let embeddingDimension = try await modelCache.referenceEmbeddingDimension()
+        let embeddingCache = try prepareVoiceEmbeddingCache(
+            voice: voice,
+            entries: entries,
+            embeddingDimension: embeddingDimension
+        )
+
+        let totalChunks = entries.count
+        let groupedByTargetTokens = Dictionary(grouping: entries, by: { $0.template.targetTokens })
+        let phasesShape: [NSNumber] = [1, 9]
+        try await multiArrayPool.preallocate(shape: phasesShape, dataType: .float32, count: max(1, totalChunks), zeroFill: true)
+        for (targetTokens, group) in groupedByTargetTokens {
+            let shape: [NSNumber] = [1, NSNumber(value: targetTokens)]
+            try await multiArrayPool.preallocate(shape: shape, dataType: .int32, count: max(1, group.count * 2), zeroFill: false)
+        }
+        let refShape: [NSNumber] = [1, NSNumber(value: embeddingDimension)]
+        try await multiArrayPool.preallocate(shape: refShape, dataType: .float32, count: max(1, totalChunks), zeroFill: false)
+
+        let chunkTemplates = entries.map { $0.template }
+        var chunkSampleBuffers = Array(repeating: [Float](), count: totalChunks)
+        var allSamples: [Float] = []
+        let crossfadeMs = 8
+        let samplesPerMillisecond = Double(TtsConstants.audioSampleRate) / 1_000.0
+        let crossfadeN = max(0, Int(Double(crossfadeMs) * samplesPerMillisecond))
+        var totalPredictionTime: TimeInterval = 0
+
+        let chunkOutputs = try await withThrowingTaskGroup(of: ChunkSynthesisResult.self) { group in
+            for (index, entry) in entries.enumerated() {
+                let chunk = entry.chunk
+                let inputIds = entry.inputIds
+                let template = entry.template
+                let chunkIndex = index
+                guard let embeddingData = embeddingCache[inputIds.count] else {
+                    throw TTSError.processingFailed("Missing voice embedding for chunk \(index + 1) with \(inputIds.count) tokens")
+                }
+                let referenceVector = embeddingData.vector
+                group.addTask(priority: .userInitiated) {
+                    let (samples, t) = try await synthesizeChunk(
+                        chunk,
+                        inputIds: inputIds,
+                        variant: template.variant,
+                        targetTokens: template.targetTokens,
+                        referenceVector: referenceVector
+                    )
+                    return ChunkSynthesisResult(index: chunkIndex, samples: samples, predictionTime: t)
+                }
+            }
+            var results: [ChunkSynthesisResult] = []
+            results.reserveCapacity(totalChunks)
+            for try await r in group { results.append(r) }
+            return results
+        }
+
+        let sorted = chunkOutputs.sorted { $0.index < $1.index }
+        var totalFrameCount = 0
+        for output in sorted {
+            let idx = output.index
+            let samples = output.samples
+            chunkSampleBuffers[idx] = samples
+            totalPredictionTime += output.predictionTime
+            if TtsConstants.kokoroFrameSamples > 0 {
+                totalFrameCount += samples.count / TtsConstants.kokoroFrameSamples
+            }
+            if idx == 0 { allSamples.append(contentsOf: samples); continue }
+            let prevPause = entries[idx - 1].chunk.pauseAfterMs
+            if prevPause > 0 {
+                let silenceCount = Int(Double(prevPause) * samplesPerMillisecond)
+                if silenceCount > 0 { allSamples.append(contentsOf: repeatElement(0.0, count: silenceCount)) }
+                allSamples.append(contentsOf: samples)
+            } else {
+                let n = min(crossfadeN, allSamples.count, samples.count)
+                if n > 0 {
+                    let tailStartIndex = allSamples.count - n
+                    var fadeIn = [Float](repeating: 0, count: n)
+                    if n == 1 { fadeIn[0] = 1 } else { var start: Float = 0; var step: Float = 1.0/Float(n-1); vDSP_vramp(&start,&step,&fadeIn,1,vDSP_Length(n)) }
+                    var fadeOut = [Float](repeating: 1, count: n)
+                    // Avoid overlapping in-place access: compute fadeOut = 1 - fadeIn via simple loop
+                    if n == 1 {
+                        fadeOut[0] = 1 - fadeIn[0]
+                    } else {
+                        for j in 0..<n { fadeOut[j] = 1 - fadeIn[j] }
+                    }
+                    allSamples.withUnsafeMutableBufferPointer { allBuf in
+                        let tail = allBuf.baseAddress!.advanced(by: tailStartIndex)
+                        vDSP_vmul(tail, 1, fadeOut, 1, tail, 1, vDSP_Length(n))
+                    }
+                    vDSP_vma(Array(samples[0..<n]), 1, fadeIn, 1, Array(allSamples[(allSamples.count - n)...]), 1, &allSamples[(allSamples.count - n)], 1, vDSP_Length(n))
+                    if samples.count > n { allSamples.append(contentsOf: samples[n...]) }
+                } else {
+                    allSamples.append(contentsOf: samples)
+                }
+            }
+        }
+
+        guard !allSamples.isEmpty else { throw TTSError.processingFailed("Synthesis produced no samples") }
+        var maxMag: Float = 0
+        vDSP_maxmgv(allSamples, 1, &maxMag, vDSP_Length(allSamples.count))
+        if maxMag > 0 {
+            let d = maxMag
+            if d > 0 {
+                let inv = 1.0 / d
+                // Safe element-wise scaling to avoid overlapping accesses
+                for k in allSamples.indices { allSamples[k] *= inv }
+                for idx in chunkSampleBuffers.indices {
+                    var buf = chunkSampleBuffers[idx]
+                    for k in buf.indices { buf[k] *= inv }
+                    chunkSampleBuffers[idx] = buf
+                }
+            }
+        }
+        // If total audio is shorter than 5.0s, trim 4 frames from the end (Mandarin zh guard)
+        do {
+            let fiveSecSamples = Int(5.0 * Double(TtsConstants.audioSampleRate))
+            let trimSamples = TtsConstants.shortVariantGuardFrameCount * TtsConstants.kokoroFrameSamples
+            if allSamples.count > trimSamples, allSamples.count < fiveSecSamples {
+                allSamples.removeLast(trimSamples)
+                if let lastIdx = chunkSampleBuffers.indices.last {
+                    var last = chunkSampleBuffers[lastIdx]
+                    if last.count > trimSamples { last.removeLast(trimSamples) }
+                    chunkSampleBuffers[lastIdx] = last
+                }
+            }
+        }
+        let audioData = try AudioWAV.data(from: allSamples, sampleRate: Double(TtsConstants.audioSampleRate))
+        let chunkInfos = zip(chunkTemplates, chunkSampleBuffers).map { t, s in
+            ChunkInfo(index: t.index, text: t.text, wordCount: t.wordCount, words: t.words, atoms: t.atoms, pauseAfterMs: t.pauseAfterMs, tokenCount: t.tokenCount, samples: s, variant: t.variant)
+        }
+        var footprints: [ModelNames.TTS.Variant: Int] = [:]
+        for v in Set(entries.map { $0.template.variant }) {
+            if let url = try? modelBundleURL(for: v) { footprints[v] = directorySize(at: url) }
+        }
+        let diagnostics = Diagnostics(
+            variantFootprints: footprints,
+            lexiconEntryCount: lexiconMetrics.entryCount,
+            lexiconEstimatedBytes: lexiconMetrics.estimatedBytes,
+            audioSampleBytes: allSamples.count * MemoryLayout<Float>.size,
+            outputWavBytes: audioData.count
+        )
+        let base = SynthesisResult(audio: audioData, chunks: chunkInfos, diagnostics: diagnostics)
+        let factor = max(0.1, voiceSpeed)
+        if abs(factor - 1.0) < 0.01 { return base }
+        let adjustedChunks = base.chunks.map { c -> ChunkInfo in
+            let stretched = adjustSamples(c.samples, factor: factor)
+            return ChunkInfo(index: c.index, text: c.text, wordCount: c.wordCount, words: c.words, atoms: c.atoms, pauseAfterMs: c.pauseAfterMs, tokenCount: c.tokenCount, samples: stretched, variant: c.variant)
+        }
+        let combined = adjustedChunks.flatMap { $0.samples }
+        let adjustedAudio = try AudioWAV.data(from: combined, sampleRate: Double(TtsConstants.audioSampleRate))
+        let updatedDiag = base.diagnostics?.updating(audioSampleBytes: combined.count * MemoryLayout<Float>.size, outputWavBytes: adjustedAudio.count)
+        return SynthesisResult(audio: adjustedAudio, chunks: adjustedChunks, diagnostics: updatedDiag)
+    }
+
     private static func adjustSamples(_ samples: [Float], factor: Float) -> [Float] {
         let clamped = max(0.1, factor)
         if abs(clamped - 1.0) < 0.01 { return samples }

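A note on the chunk construction above: phonemes.map { String($0) } walks Swift Characters (grapheme clusters), which matches the "one codepoint per token" intent as long as every phoneme symbol is a single scalar. A quick illustration with a made-up phoneme string, not real Kokoro-zh output:

    // Character-wise tokenization, as used to build the single TextChunk.
    let phonemes = "tʂʰa˥"                    // hypothetical phoneme string
    let tokens = phonemes.map { String($0) }  // ["t", "ʂ", "ʰ", "a", "˥"]

    // If strict one-scalar-per-token behavior were ever required, walking
    // unicodeScalars would be the more literal reading of "codepoint":
    let scalarTokens = phonemes.unicodeScalars.map { String($0) }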
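The 8 ms chunk-boundary stitch is a plain linear crossfade: over the overlap, out[i] = tail[i] * (1 - ramp[i]) + head[i] * ramp[i], with ramp running 0 to 1. A self-contained sketch of the same math, using the vDSP ramp but a scalar blend loop for clarity (an illustration, not the committed code path):

    import Accelerate

    // Linear crossfade of `head` into the last `n` samples of `tail`,
    // mirroring the stitch in synthesizePhonemeStringDetailed.
    func crossfade(tail: [Float], head: [Float], overlap: Int) -> [Float] {
        let n = min(overlap, tail.count, head.count)
        guard n > 1 else { return tail + head }

        // ramp = 0, 1/(n-1), ..., 1  (fade-in curve; fade-out is 1 - ramp)
        var ramp = [Float](repeating: 0, count: n)
        var start: Float = 0
        var step = Float(1) / Float(n - 1)
        vDSP_vramp(&start, &step, &ramp, 1, vDSP_Length(n))

        var blended = [Float](repeating: 0, count: n)
        for i in 0..<n {
            blended[i] = tail[tail.count - n + i] * (1 - ramp[i]) + head[i] * ramp[i]
        }
        return Array(tail.dropLast(n)) + blended + Array(head.dropFirst(n))
    }

With TtsConstants.audioSampleRate samples per second, the 8 ms window works out to crossfadeN = 0.008 * sampleRate overlapping samples; boundaries that follow an explicit pause skip the blend and insert silence instead.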
Sources/FluidAudio/TextToSpeech/TtsManager.swift

Lines changed: 25 additions & 0 deletions

@@ -144,6 +144,31 @@ public final class TtSManager
         logger.notice("Saved synthesized audio to: \(outputURL.lastPathComponent)")
     }

+    /// Synthesize directly from a Kokoro-zh phoneme string (single-codepoint tokens).
+    public func synthesizePhonemesDetailed(
+        phonemes: String,
+        voice: String? = nil,
+        voiceSpeed: Float = 1.0,
+        speakerId: Int = 0,
+        variantPreference: ModelNames.TTS.Variant? = nil
+    ) async throws -> KokoroSynthesizer.SynthesisResult {
+        guard isInitialized else { throw TTSError.modelNotFound("Kokoro model not initialized") }
+        try await prepareLexiconAssetsIfNeeded()
+        let selectedVoice = resolveVoice(voice, speakerId: speakerId)
+        try await ensureVoiceEmbeddingIfNeeded(for: selectedVoice)
+
+        return try await KokoroSynthesizer.withLexiconAssets(lexiconAssets) {
+            try await KokoroSynthesizer.withModelCache(modelCache) {
+                try await KokoroSynthesizer.synthesizePhonemeStringDetailed(
+                    phonemes: phonemes,
+                    voice: selectedVoice,
+                    voiceSpeed: voiceSpeed,
+                    variantPreference: variantPreference
+                )
+            }
+        }
+    }
+
     public func setDefaultVoice(_ voice: String, speakerId: Int = 0) async throws {
         let normalized = Self.normalizeVoice(voice)
         try await ensureVoiceEmbeddingIfNeeded(for: normalized)

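A call-site sketch at the manager level. The constructor and setup call are assumptions (whatever entry point sets isInitialized in a given build), and the phoneme payload is a placeholder:

    let tts = TtSManager()
    try await tts.initialize()            // assumed setup entry point
    let result = try await tts.synthesizePhonemesDetailed(
        phonemes: "...",                  // placeholder Kokoro-zh phoneme string
        voiceSpeed: 1.0
    )
    try result.audio.write(to: URL(fileURLWithPath: "zh.wav"))

Voice resolution (including speakerId) and the lexicon/model-cache scoping via withLexiconAssets/withModelCache happen inside the manager, so callers do not need to touch KokoroSynthesizer directly.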
Sources/FluidAudio/TextToSpeech/TtsModels.swift

Lines changed: 39 additions & 0 deletions

@@ -81,6 +81,45 @@ public struct TtsModels
         return TtsModels(models: loaded)
     }

+    /// Load a Kokoro CoreML model from a local path (either .mlmodelc or .mlpackage).
+    /// The loaded model is registered under the 15s variant by default.
+    public static func loadLocal(at path: String) async throws -> TtsModels {
+        let expanded = (path as NSString).expandingTildeInPath
+        let url = URL(fileURLWithPath: expanded)
+        let fm = FileManager.default
+
+        guard fm.fileExists(atPath: url.path) else {
+            throw TTSError.modelNotFound("Local model not found at: \(url.path)")
+        }
+
+        let modelURL: URL
+        do {
+            let isDir = (try? url.resourceValues(forKeys: [.isDirectoryKey]).isDirectory) == true
+            if url.pathExtension == "mlmodelc" || (isDir && url.path.hasSuffix(".mlmodelc")) {
+                modelURL = url
+            } else if url.pathExtension == "mlpackage" || isDir {
+                modelURL = try await MLModel.compileModel(at: url)
+            } else {
+                // Try loading directly; if it fails, attempt compile
+                do {
+                    _ = try MLModel(contentsOf: url)
+                    modelURL = url
+                } catch {
+                    modelURL = try await MLModel.compileModel(at: url)
+                }
+            }
+        } catch {
+            throw TTSError.processingFailed("Failed to prepare local model: \(error.localizedDescription)")
+        }
+
+        do {
+            let model = try MLModel(contentsOf: modelURL)
+            return TtsModels(models: [.fifteenSecond: model])
+        } catch {
+            throw TTSError.processingFailed("Failed to load local model: \(error.localizedDescription)")
+        }
+    }
+
     private static func getCacheDirectory() throws -> URL {
         let baseDirectory: URL
         #if os(macOS)

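A sketch of driving the new loader from a test or CLI; the path is illustrative. An .mlpackage (or any input that fails to load directly) is compiled on the fly with MLModel.compileModel(at:) before loading:

    // Hypothetical: point at a locally exported Kokoro bundle instead of
    // the downloaded models; it is registered under the 15-second variant.
    let models = try await TtsModels.loadLocal(at: "~/Models/Kokoro_15s.mlpackage")

Since only the .fifteenSecond slot is populated, a 15s variant preference is the natural pairing when synthesizing with a locally loaded model.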