Skip to content

Commit 08e488d

Browse files
committed
fix(diarizer/sortformer): consume BNNS-fixed v3 models + config-mismatch guard (#726)
The root-level Sortformer CoreML models hit a BNNS graph-compile crash on newer BNNS ("tensor chunk_pre_encoder_embs_out as both an input and output"). The fixed rebuild lives at v3/fp16/ in the HF repo; point ModelNames there so downloads pick up the working models.

- ModelNames.Sortformer.modelsSubdirectory = "v3/fp16" (BNNS-fixed set); v3/palettized is the 6-bit, ~2.5x-smaller set for RAM-constrained devices.
- Add efficientV2_1 variant (chunk_len=25, ~2s latency, ~4x the RTFx of fast) + config preset.
- SortformerDiarizer now validates the diarizer config against the model's embedded metadata on init and logs a clear error on mismatch (a mismatch silently produced incorrect/slow diarization — #726). spkcacheUpdatePeriod is excluded from the comparison (host-clamped).
- CLI: `sortformer --config fast|efficient|low|high`; `sortformer-benchmark --collar`, `--onset`, `--offset` (the hardcoded collar=0 / onset=0.5 skewed reported DER).
1 parent ffefeec commit 08e488d

6 files changed

Lines changed: 161 additions & 5 deletions

File tree

Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizer.swift

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ public final class SortformerDiarizer: Diarizer {
108108
mainModelPath: mainModelPath
109109
)
110110

111+
validateConfigMatch(loadedModels)
112+
111113
// Use withLock helper to avoid direct NSLock usage in async context
112114
withLock {
113115
self._models = loadedModels
@@ -117,6 +119,34 @@ public final class SortformerDiarizer: Diarizer {
117119
logger.info("Sortformer initialized in \(String(format: "%.2f", loadedModels.compilationDuration))s")
118120
}
119121

122+
/// Compares the diarizer's `config` with the streaming parameters embedded in the loaded
/// model and logs an error when the two disagree.
///
/// A mismatch (e.g. a `.default` config against a `highContextV2_1` model) still runs, but
/// produces incorrect and much slower results — issue #726. Models that expose no embedded
/// configuration are skipped without logging anything.
///
/// - Parameter models: The loaded models whose `embeddedConfig` is compared with `config`.
private func validateConfigMatch(_ models: SortformerModels) {
    // Nothing to compare against when the model carries no embedded parameters.
    guard let modelParams = models.embeddedConfig else { return }

    // Project the host config onto the same five fields the model metadata records.
    let hostParams = SortformerModels.EmbeddedConfig(
        chunkLen: config.chunkLen,
        chunkLeftContext: config.chunkLeftContext,
        chunkRightContext: config.chunkRightContext,
        fifoLen: config.fifoLen,
        spkcacheLen: config.spkcacheLen
    )

    if hostParams == modelParams { return }

    logger.error(
        """
        Sortformer config mismatch — diarizer config does not match the loaded model. \
        This produces incorrect and much slower diarization (issue #726). \
        diarizer(chunkLen=\(hostParams.chunkLen), leftCtx=\(hostParams.chunkLeftContext), \
        rightCtx=\(hostParams.chunkRightContext), fifoLen=\(hostParams.fifoLen), \
        spkcacheLen=\(hostParams.spkcacheLen)) \
        vs model(chunkLen=\(modelParams.chunkLen), leftCtx=\(modelParams.chunkLeftContext), \
        rightCtx=\(modelParams.chunkRightContext), fifoLen=\(modelParams.fifoLen), \
        spkcacheLen=\(modelParams.spkcacheLen)). \
        Construct SortformerDiarizer with the SortformerConfig matching the model variant.
        """
    )
}
149+
120150
/// Execute a closure while holding the lock
121151
private func withLock<T>(_ body: () throws -> T) rethrows -> T {
122152
lock.lock()
@@ -126,6 +156,8 @@ public final class SortformerDiarizer: Diarizer {
126156

127157
/// Initialize with pre-loaded models.
128158
public func initialize(models: SortformerModels) {
159+
validateConfigMatch(models)
160+
129161
lock.lock()
130162
defer { lock.unlock() }
131163

Sources/FluidAudio/Diarizer/Sortformer/SortformerModelInference.swift

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,52 @@ extension SortformerModels {
156156
}
157157
}
158158

159+
// MARK: - Embedded Configuration
160+
161+
extension SortformerModels {

    /// Shape-defining streaming parameters that the converter stores in the CoreML model
    /// metadata. They determine the input tensor shapes, so the host config must agree
    /// with them.
    ///
    /// Note: `spkcache_update_period` is deliberately left out — `SortformerConfig.init` clamps
    /// it host-side (`max(min(period, fifoLen+chunkLen), chunkLen)`), so the host value can
    /// legitimately differ from the raw value baked into the model and is not a compatibility
    /// signal.
    public struct EmbeddedConfig: Equatable, Sendable {
        public let chunkLen: Int
        public let chunkLeftContext: Int
        public let chunkRightContext: Int
        public let fifoLen: Int
        public let spkcacheLen: Int
    }

    /// The variant-defining streaming parameters read from the main model's creator-defined
    /// metadata.
    ///
    /// Returns `nil` when the creator-defined metadata is absent, is not a string-to-string
    /// dictionary, or lacks any of the five expected integer entries (as with older exports).
    /// Used to detect a `SortformerConfig` that doesn't match the model — a mismatch yields
    /// incorrect and much slower diarization (issue #726).
    public var embeddedConfig: EmbeddedConfig? {
        let creatorDefined = mainModel.modelDescription.metadata[.creatorDefinedKey]
        guard let entries = creatorDefined as? [String: String] else {
            return nil
        }

        // Looks up `key` and parses it as a base-10 integer; `nil` if missing or non-numeric.
        let parse: (String) -> Int? = { key in
            guard let text = entries[key] else { return nil }
            return Int(text)
        }

        guard
            let chunkLen = parse("chunk_len"),
            let chunkLeftContext = parse("chunk_left_context"),
            let chunkRightContext = parse("chunk_right_context"),
            let fifoLen = parse("fifo_len"),
            let spkcacheLen = parse("spkcache_len")
        else {
            return nil
        }

        return EmbeddedConfig(
            chunkLen: chunkLen,
            chunkLeftContext: chunkLeftContext,
            chunkRightContext: chunkRightContext,
            fifoLen: fifoLen,
            spkcacheLen: spkcacheLen
        )
    }
}
204+
159205
// MARK: - Main Model Inference
160206

161207
extension SortformerModels {

Sources/FluidAudio/Diarizer/Sortformer/SortformerTypes.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,21 @@ public struct SortformerConfig: Sendable {
194194
spkcacheUpdatePeriod: 300
195195
)
196196

197+
/// Higher-throughput streaming config with Sortformer v2.1 weights (~2s output latency).
198+
/// Same context as `fastV2_1` but a larger 25-frame chunk: per-inference cost is dominated
199+
/// by the static speaker-cache + FIFO context, so a bigger chunk advances ~4x more audio per
200+
/// call at near-identical per-inference cost (~4x real-time factor vs `fastV2_1`). Use when ~2s output latency
201+
/// is acceptable and throughput matters.
202+
public static let efficientV2_1 = SortformerConfig(
203+
modelVariant: .efficientV2_1,
204+
chunkLen: 25,
205+
chunkLeftContext: 1,
206+
chunkRightContext: 7,
207+
fifoLen: 40,
208+
spkcacheLen: 188,
209+
spkcacheUpdatePeriod: 31
210+
)
211+
197212
/// - Warning: If you don't use one of the default configurations, you must use a local model converted with that configuration.
198213
public init(
199214
modelVariant: ModelVariant? = .fastV2_1,

Sources/FluidAudio/ModelNames.swift

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,9 @@ public enum ModelNames {
642642
case balancedV2_1
643643
case highContextV2
644644
case highContextV2_1
645+
/// Higher-throughput streaming: larger chunk (~2s output latency) for ~4x the
646+
/// real-time factor of `fastV2_1` at near-identical per-inference cost.
647+
case efficientV2_1
645648

646649
public var name: String {
647650
switch self {
@@ -657,6 +660,8 @@ public enum ModelNames {
657660
return "SortformerNvidiaHigh_v2"
658661
case .highContextV2_1:
659662
return "SortformerNvidiaHigh_v2.1"
663+
case .efficientV2_1:
664+
return "SortformerEfficient_v2.1"
660665
}
661666
}
662667

@@ -674,18 +679,26 @@ public enum ModelNames {
674679
return .highContextV2
675680
case .highContextV2_1:
676681
return .highContextV2_1
682+
case .efficientV2_1:
683+
return .efficientV2_1
677684
}
678685
}
679686

680687
public var fileName: String {
681-
return "\(name).mlmodelc"
688+
return "\(Sortformer.modelsSubdirectory)/\(name).mlmodelc"
682689
}
683690

684691
public func isCompatible(with config: SortformerConfig) -> Bool {
685692
defaultConfiguration.isCompatible(with: config)
686693
}
687694
}
688695

696+
/// Repo subdirectory holding the active model set. `v3/fp16` is the BNNS-fixed rebuild
697+
/// (the older root-level models hit a "tensor as both input and output" graph-compile
698+
/// crash on newer BNNS — issue #726). Use `v3/palettized` for the 6-bit, ~2.5x-smaller
699+
/// set (fixes RAM-driven crashes on older devices at ~+0.9pp DER).
700+
public static let modelsSubdirectory = "v3/fp16"
701+
689702
/// Lowest latency for streaming
690703
public static let defaultVariant: Variant = .fastV2_1
691704

Sources/FluidAudioCLI/Commands/SortformerBenchmark.swift

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ enum SortformerBenchmark {
6868
var singleFile: String?
6969
var maxFiles: Int?
7070
var threshold: Float = 0.5
71+
var collarSeconds: Double = 0
72+
var onsetThreshold: Float?
73+
var offsetThreshold: Float?
7174
var modelPath: String?
7275
var outputFile: String?
7376
var verbose = false
@@ -116,6 +119,21 @@ enum SortformerBenchmark {
116119
threshold = Float(arguments[i + 1]) ?? 0.5
117120
i += 1
118121
}
122+
case "--collar":
123+
if i + 1 < arguments.count {
124+
collarSeconds = Double(arguments[i + 1]) ?? 0
125+
i += 1
126+
}
127+
case "--onset":
128+
if i + 1 < arguments.count {
129+
onsetThreshold = Float(arguments[i + 1])
130+
i += 1
131+
}
132+
case "--offset":
133+
if i + 1 < arguments.count {
134+
offsetThreshold = Float(arguments[i + 1])
135+
i += 1
136+
}
119137
case "--model":
120138
if i + 1 < arguments.count {
121139
modelPath = arguments[i + 1]
@@ -294,7 +312,19 @@ enum SortformerBenchmark {
294312
if let v = weakBoostRate { config.weakBoostRate = v }
295313
if let v = minPosScoresRate { config.minPosScoresRate = v }
296314
if let v = spkcacheSilFramesPerSpk { config.spkcacheSilFramesPerSpk = v }
297-
let diarizer = SortformerDiarizer(config: config)
315+
// Allow overriding the timeline binarization thresholds (sortformerDefault = 0.5/0.5).
316+
let diarizer: SortformerDiarizer
317+
if onsetThreshold != nil || offsetThreshold != nil {
318+
let timeline = DiarizerTimelineConfig(
319+
numSpeakers: config.numSpeakers,
320+
frameDurationSeconds: Float(config.frameDurationSeconds),
321+
onsetThreshold: onsetThreshold ?? 0.5,
322+
offsetThreshold: offsetThreshold ?? onsetThreshold ?? 0.5
323+
)
324+
diarizer = SortformerDiarizer(config: config, timelineConfig: timeline)
325+
} else {
326+
diarizer = SortformerDiarizer(config: config)
327+
}
298328

299329
do {
300330
if useHuggingFace {
@@ -340,6 +370,7 @@ enum SortformerBenchmark {
340370
diarizer: diarizer,
341371
modelLoadTime: modelLoadTime,
342372
threshold: threshold,
373+
collarSeconds: collarSeconds,
343374
verbose: verbose
344375
)
345376

@@ -381,6 +412,7 @@ enum SortformerBenchmark {
381412
diarizer: SortformerDiarizer,
382413
modelLoadTime: Double,
383414
threshold: Float,
415+
collarSeconds: Double,
384416
verbose: Bool
385417
) async -> BenchmarkResult? {
386418

@@ -484,7 +516,7 @@ enum SortformerBenchmark {
484516
ref: referenceSegments,
485517
hyp: hypothesisSegments,
486518
frameStep: derFrameStepSeconds,
487-
collar: 0
519+
collar: collarSeconds
488520
)
489521
let totalRefSpeech = max(derResult.totalRefSpeech, .leastNonzeroMagnitude)
490522
let derPercent = Float(derResult.der * 100)

Sources/FluidAudioCLI/Commands/SortformerCommand.swift

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ enum SortformerCommand {
4141
var weakBoostRate: Float?
4242
var minPosScoresRate: Float?
4343
var spkcacheSilFramesPerSpk: Int?
44+
var configName = "default"
4445

4546
// Parse remaining arguments
4647
var i = 1
@@ -88,6 +89,11 @@ enum SortformerCommand {
8889
modelPath = arguments[i + 1]
8990
i += 1
9091
}
92+
case "--config":
93+
if i + 1 < arguments.count {
94+
configName = arguments[i + 1].lowercased()
95+
i += 1
96+
}
9197
case "--threshold":
9298
if i + 1 < arguments.count, let v = Float(arguments[i + 1]) {
9399
predScoreThreshold = v
@@ -132,8 +138,20 @@ enum SortformerCommand {
132138
print("Sortformer Streaming Diarization")
133139
print(" Audio: \(audioFile)")
134140

135-
// Initialize Sortformer with default config (NVIDIA low latency: 1.04s)
136-
var config = SortformerConfig.default
141+
// Select config (default = NVIDIA low latency ~1.04s). `--config efficient` = chunk_len=25 (~2s, higher throughput).
142+
var config: SortformerConfig
143+
switch configName {
144+
case "efficient":
145+
config = .efficientV2_1
146+
case "fast", "fastv2_1":
147+
config = .fastV2_1
148+
case "low", "balanced":
149+
config = .balancedV2_1
150+
case "high", "highcontext":
151+
config = .highContextV2_1
152+
default:
153+
config = .default
154+
}
137155
var postConfig = DiarizerTimelineConfig.sortformerDefault
138156
config.debugMode = debugMode
139157
if let v = predScoreThreshold { config.predScoreThreshold = v }

0 commit comments

Comments
 (0)