Skip to content

Commit d213e75

Browse files
committed
fix: model stats aggregation for org and quantization variants
- Added _strip_quantization_suffix and _build_model_with_size_index to enable accurate aggregation of stats across quantization variants while preserving model size. - Updated IndexedHordeModelStats to build and use these indexes, ensuring that canonical names with org prefixes and quantization variants are correctly matched. - Added tests to verify correct aggregation and breakdown of stats for models with org prefixes and quantization variants. [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci lint: fix
1 parent e8c3a60 commit d213e75

File tree

2 files changed

+214
-20
lines changed

2 files changed

+214
-20
lines changed

src/horde_model_reference/integrations/horde_api_models.py

Lines changed: 153 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,41 @@ class _StatsLookup(BaseModel):
236236
total: dict[str, int] = Field(default_factory=dict)
237237

238238

239+
def _strip_quantization_suffix(model_name: str) -> str:
240+
"""Strip quantization suffix from a model name, preserving size.
241+
242+
This is different from get_base_model_name which strips BOTH size and quantization.
243+
This function only strips quantization, keeping the size suffix.
244+
245+
Args:
246+
model_name: Model name potentially with quantization suffix.
247+
248+
Returns:
249+
Model name without quantization suffix, but with size preserved.
250+
251+
Example:
252+
"Lumimaid-v0.2-8B-Q8_0" -> "Lumimaid-v0.2-8B"
253+
"Lumimaid-v0.2-8B" -> "Lumimaid-v0.2-8B"
254+
"koboldcpp/Lumimaid-v0.2-8B-Q4_K_M" -> "koboldcpp/Lumimaid-v0.2-8B"
255+
"""
256+
import re
257+
258+
# Quantization patterns to strip (same as text_model_parser.QUANT_PATTERNS but as suffix)
259+
quant_suffix_patterns = [
260+
r"[-_](Q[2-8]_K(?:_[SMLH])?)$", # -Q4_K_M, -Q5_K_S
261+
r"[-_](Q[2-8]_[01])$", # -Q4_0, -Q5_0, -Q8_0
262+
r"[-_](Q[2-8])$", # -Q4, -Q8
263+
r"[-_](GGUF|GGML|GPTQ|AWQ|EXL2)$",
264+
r"[-_](fp16|fp32|int8|int4)$",
265+
]
266+
267+
result = model_name
268+
for pattern in quant_suffix_patterns:
269+
result = re.sub(pattern, "", result, flags=re.IGNORECASE)
270+
271+
return result
272+
273+
239274
def _build_base_name_index(model_names: list[str]) -> dict[str, list[str]]:
240275
"""Build an index mapping base model names to all matching model names.
241276
@@ -283,14 +318,85 @@ def _build_base_name_index(model_names: list[str]) -> dict[str, list[str]]:
283318
return base_name_index
284319

285320

321+
def _build_model_with_size_index(model_names: list[str]) -> dict[str, list[str]]:
322+
"""Build an index mapping model names (with size, without quant) to all matching names.
323+
324+
This enables aggregating stats across quantization variants only (e.g., Q4_K_M, Q8_0)
325+
while keeping different sizes separate.
326+
327+
Unlike _build_base_name_index which groups ALL variants (including different sizes),
328+
this index only groups quantization variants of the SAME sized model.
329+
330+
The key normalizes:
331+
- Backend prefix (stripped for matching, but preserved in values)
332+
- Org prefix (stripped for matching)
333+
- Quantization suffix (stripped for matching)
334+
335+
But preserves:
336+
- Size suffix (8B, 12B, etc.)
337+
338+
Args:
339+
model_names: List of model names from API stats (may include backend prefixes
340+
and quantization suffixes).
341+
342+
Returns:
343+
Dictionary mapping normalized model names (backend/model-size) to lists of
344+
original model names (lowercase) that match that model.
345+
346+
Example:
347+
Input: ["koboldcpp/Lumimaid-v0.2-8B", "koboldcpp/Lumimaid-v0.2-8B-Q8_0",
348+
"koboldcpp/Lumimaid-v0.2-12B", "aphrodite/NeverSleep/Lumimaid-v0.2-8B"]
349+
Output: {
350+
"koboldcpp/lumimaid-v0.2-8b": [
351+
"koboldcpp/lumimaid-v0.2-8b",
352+
"koboldcpp/lumimaid-v0.2-8b-q8_0"
353+
],
354+
"koboldcpp/lumimaid-v0.2-12b": ["koboldcpp/lumimaid-v0.2-12b"],
355+
"aphrodite/lumimaid-v0.2-8b": ["aphrodite/neversleep/lumimaid-v0.2-8b"]
356+
}
357+
"""
358+
model_with_size_index: dict[str, list[str]] = {}
359+
360+
for model_name in model_names:
361+
model_name_lower = model_name.lower()
362+
363+
# Extract backend prefix if present
364+
backend_prefix = ""
365+
stripped = model_name_lower
366+
if stripped.startswith("aphrodite/"):
367+
backend_prefix = "aphrodite/"
368+
stripped = stripped[len("aphrodite/") :]
369+
elif stripped.startswith("koboldcpp/"):
370+
backend_prefix = "koboldcpp/"
371+
stripped = stripped[len("koboldcpp/") :]
372+
373+
# Strip org prefix (e.g., "neversleep/lumimaid-v0.2-8b" -> "lumimaid-v0.2-8b")
374+
if "/" in stripped:
375+
stripped = stripped.split("/")[-1]
376+
377+
# Strip quantization suffix
378+
stripped_no_quant = _strip_quantization_suffix(stripped)
379+
380+
# Build key: backend_prefix + model_name (no org, no quant, but with size)
381+
key = f"{backend_prefix}{stripped_no_quant}"
382+
383+
if key not in model_with_size_index:
384+
model_with_size_index[key] = []
385+
if model_name_lower not in model_with_size_index[key]:
386+
model_with_size_index[key].append(model_name_lower)
387+
388+
return model_with_size_index
389+
390+
286391
class IndexedHordeModelStats(RootModel[_StatsLookup]):
287392
"""Indexed model stats for O(1) lookups by model name.
288393
289394
This wraps the stats response and provides case-insensitive dictionary access.
290395
Time complexity: O(1) for lookups instead of O(n) for dict iteration.
291396
292-
Also builds a base-name index for aggregating stats across quantization variants
293-
and different backend prefixes.
397+
Two indexes are built:
398+
- _base_name_index: Groups ALL variants (including different sizes) for group-level aggregation
399+
- _model_with_size_index: Groups only quantization variants for per-model stats
294400
295401
Usage:
296402
indexed = IndexedHordeModelStats(stats_response)
@@ -300,6 +406,7 @@ class IndexedHordeModelStats(RootModel[_StatsLookup]):
300406

301407
root: _StatsLookup
302408
_base_name_index: dict[str, list[str]] = {}
409+
_model_with_size_index: dict[str, list[str]] = {}
303410

304411
def __init__(self, stats_response: HordeModelStatsResponse) -> None:
305412
"""Build indexed lookups from stats response.
@@ -315,11 +422,13 @@ def __init__(self, stats_response: HordeModelStatsResponse) -> None:
315422
)
316423
super().__init__(root=lookups)
317424

318-
# Build base name index from all unique model names across all time periods
425+
# Build indexes from all unique model names across all time periods
319426
all_model_names = (
320427
set(stats_response.day.keys()) | set(stats_response.month.keys()) | set(stats_response.total.keys())
321428
)
322-
self._base_name_index = _build_base_name_index(list(all_model_names))
429+
model_names_list = list(all_model_names)
430+
self._base_name_index = _build_base_name_index(model_names_list)
431+
self._model_with_size_index = _build_model_with_size_index(model_names_list)
323432

324433
def get_day(self, model_name: str) -> int | None:
325434
"""Get day count for a model (case-insensitive). O(1)."""
@@ -371,7 +480,9 @@ def get_aggregated_stats(self, canonical_name: str) -> tuple[int, int, int]:
371480

372481
# Then, add all model names that share the same base model name
373482
# This catches quantization variants and org-prefixed variants
374-
base_name = get_base_model_name(canonical_name).lower()
483+
# Strip org prefix from canonical name if present (e.g., "NeverSleep/Lumimaid-v0.2" -> "Lumimaid-v0.2")
484+
canonical_without_org = canonical_name.split("/")[-1] if "/" in canonical_name else canonical_name
485+
base_name = get_base_model_name(canonical_without_org).lower()
375486
if base_name in self._base_name_index:
376487
for api_model_name in self._base_name_index[base_name]:
377488
names_to_aggregate.add(api_model_name)
@@ -391,37 +502,59 @@ def get_aggregated_stats(self, canonical_name: str) -> tuple[int, int, int]:
391502
def get_stats_with_variations(
392503
self, canonical_name: str
393504
) -> tuple[tuple[int, int, int], dict[str, tuple[int, int, int]]]:
394-
"""Get aggregated stats and individual backend variations.
505+
"""Get stats for a specific model broken down by backend.
506+
507+
Unlike get_aggregated_stats which aggregates across all models with the same
508+
base name (e.g., all Lumimaid-v0.2 sizes), this method returns stats only for
509+
the exact model specified (including its quantization variants), broken down
510+
by backend prefix.
395511
396-
This method returns both the aggregated stats (same as get_aggregated_stats)
397-
and a dictionary of individual backend stats keyed by backend name.
398-
Now includes quantization variants and org-prefixed variants via base name matching.
512+
This enables showing per-model stats in the UI when displaying grouped models,
513+
where each model variant (8B, 12B, etc.) shows its own stats by backend.
399514
400515
Args:
401516
canonical_name: The canonical model name from the model reference.
402517
403518
Returns:
404519
Tuple of (aggregated_stats, variations_dict) where:
405-
- aggregated_stats: (day_total, month_total, total_total) aggregated
520+
- aggregated_stats: (day_total, month_total, total_total) for this exact model
406521
- variations_dict: Dict of backend_name -> (day, month, total)
407522
Keys are 'canonical', 'aphrodite', 'koboldcpp' depending on what's found
408523
"""
409-
from horde_model_reference.analytics.text_model_parser import get_base_model_name
410524
from horde_model_reference.meta_consts import get_model_name_variants
411525

412-
# Collect all model names to aggregate (use set to avoid double-counting)
526+
# Collect all model names that are variants of this specific model
527+
# Use _model_with_size_index to include quantization variants, but NOT size variants
413528
names_to_aggregate: set[str] = set()
414529

415-
# First, add exact variants from get_model_name_variants
530+
# Get exact backend-prefixed variants
416531
variants = get_model_name_variants(canonical_name)
417532
for variant in variants:
418-
names_to_aggregate.add(variant.lower())
419-
420-
# Then, add all model names that share the same base model name
421-
base_name = get_base_model_name(canonical_name).lower()
422-
if base_name in self._base_name_index:
423-
for api_model_name in self._base_name_index[base_name]:
424-
names_to_aggregate.add(api_model_name)
533+
variant_lower = variant.lower()
534+
names_to_aggregate.add(variant_lower)
535+
536+
# Build the normalized key to look up in _model_with_size_index
537+
# The key format is: [backend_prefix/]model_name (no org, no quant)
538+
backend_prefix = ""
539+
stripped = variant_lower
540+
if stripped.startswith("aphrodite/"):
541+
backend_prefix = "aphrodite/"
542+
stripped = stripped[len("aphrodite/") :]
543+
elif stripped.startswith("koboldcpp/"):
544+
backend_prefix = "koboldcpp/"
545+
stripped = stripped[len("koboldcpp/") :]
546+
547+
# Strip org prefix if present
548+
if "/" in stripped:
549+
stripped = stripped.split("/")[-1]
550+
551+
# Strip quantization suffix and build key
552+
stripped_no_quant = _strip_quantization_suffix(stripped)
553+
key = f"{backend_prefix}{stripped_no_quant}"
554+
555+
if key in self._model_with_size_index:
556+
for api_model_name in self._model_with_size_index[key]:
557+
names_to_aggregate.add(api_model_name)
425558

426559
# Track stats by backend for variations dict
427560
backend_stats: dict[str, tuple[int, int, int]] = {

tests/integrations/test_stats_aggregation.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,64 @@ def test_aggregate_stats_model_not_found(self) -> None:
184184
assert day == 0
185185
assert month == 0
186186
assert total == 0
187+
188+
def test_aggregate_stats_canonical_name_with_org_prefix(self) -> None:
189+
"""Test that canonical names with org prefix correctly match API stats.
190+
191+
This is a critical test case because model reference entries often have
192+
org prefixes (e.g., "NeverSleep/Lumimaid-v0.2") but API stats may have
193+
different prefixing patterns.
194+
"""
195+
stats = HordeModelStatsResponse(
196+
day={
197+
"koboldcpp/Lumimaid-v0.2-8B": 4080,
198+
"koboldcpp/Lumimaid-v0.2-8B-Q8_0": 1500,
199+
"aphrodite/NeverSleep/Lumimaid-v0.2-8B": 2000,
200+
},
201+
month={
202+
"koboldcpp/Lumimaid-v0.2-8B": 40000,
203+
"koboldcpp/Lumimaid-v0.2-8B-Q8_0": 15000,
204+
"aphrodite/NeverSleep/Lumimaid-v0.2-8B": 20000,
205+
},
206+
total={},
207+
)
208+
209+
indexed = IndexedHordeModelStats(stats)
210+
211+
# Query with canonical name that HAS org prefix (like in model reference)
212+
canonical_with_org = "NeverSleep/Lumimaid-v0.2"
213+
day, month, _total = indexed.get_aggregated_stats(canonical_with_org)
214+
215+
# Should aggregate all variants even though canonical has org prefix
216+
assert day == 4080 + 1500 + 2000
217+
assert month == 40000 + 15000 + 20000
218+
219+
def test_get_stats_with_variations_canonical_with_org_prefix(self) -> None:
220+
"""Test variations breakdown with canonical name that has org prefix.
221+
222+
When querying for a specific model like 'NeverSleep/Lumimaid-v0.2-8B',
223+
get_stats_with_variations should find:
224+
- The aphrodite variant: aphrodite/NeverSleep/Lumimaid-v0.2-8B
225+
- The koboldcpp variant: koboldcpp/Lumimaid-v0.2-8B
226+
- Quantization variants: koboldcpp/Lumimaid-v0.2-8B-Q8_0
227+
"""
228+
stats = HordeModelStatsResponse(
229+
day={
230+
"koboldcpp/Lumimaid-v0.2-8B": 300,
231+
"koboldcpp/Lumimaid-v0.2-8B-Q8_0": 400,
232+
"aphrodite/NeverSleep/Lumimaid-v0.2-8B": 200,
233+
},
234+
month={},
235+
total={},
236+
)
237+
238+
indexed = IndexedHordeModelStats(stats)
239+
240+
# Query with canonical name that HAS org prefix (like in model reference)
241+
(day_total, _m, _t), variations = indexed.get_stats_with_variations("NeverSleep/Lumimaid-v0.2-8B")
242+
243+
assert day_total == 300 + 400 + 200
244+
assert "aphrodite" in variations
245+
assert "koboldcpp" in variations
246+
assert variations["aphrodite"][0] == 200
247+
assert variations["koboldcpp"][0] == 300 + 400

0 commit comments

Comments
 (0)