diff --git a/src/horde_model_reference/integrations/horde_api_models.py b/src/horde_model_reference/integrations/horde_api_models.py index db8a55d7..f0aeec8f 100644 --- a/src/horde_model_reference/integrations/horde_api_models.py +++ b/src/horde_model_reference/integrations/horde_api_models.py @@ -236,6 +236,41 @@ class _StatsLookup(BaseModel): total: dict[str, int] = Field(default_factory=dict) +def _strip_quantization_suffix(model_name: str) -> str: + """Strip quantization suffix from a model name, preserving size. + + This is different from get_base_model_name which strips BOTH size and quantization. + This function only strips quantization, keeping the size suffix. + + Args: + model_name: Model name potentially with quantization suffix. + + Returns: + Model name without quantization suffix, but with size preserved. + + Example: + "Lumimaid-v0.2-8B-Q8_0" -> "Lumimaid-v0.2-8B" + "Lumimaid-v0.2-8B" -> "Lumimaid-v0.2-8B" + "koboldcpp/Lumimaid-v0.2-8B-Q4_K_M" -> "koboldcpp/Lumimaid-v0.2-8B" + """ + import re + + # Quantization patterns to strip (same as text_model_parser.QUANT_PATTERNS but as suffix) + quant_suffix_patterns = [ + r"[-_](Q[2-8]_K(?:_[SMLH])?)$", # -Q4_K_M, -Q5_K_S + r"[-_](Q[2-8]_[01])$", # -Q4_0, -Q5_0, -Q8_0 + r"[-_](Q[2-8])$", # -Q4, -Q8 + r"[-_](GGUF|GGML|GPTQ|AWQ|EXL2)$", + r"[-_](fp16|fp32|int8|int4)$", + ] + + result = model_name + for pattern in quant_suffix_patterns: + result = re.sub(pattern, "", result, flags=re.IGNORECASE) + + return result + + def _build_base_name_index(model_names: list[str]) -> dict[str, list[str]]: """Build an index mapping base model names to all matching model names. @@ -283,14 +318,85 @@ def _build_base_name_index(model_names: list[str]) -> dict[str, list[str]]: return base_name_index +def _build_model_with_size_index(model_names: list[str]) -> dict[str, list[str]]: + """Build an index mapping model names (with size, without quant) to all matching names. + + This enables aggregating stats across quantization variants only (e.g., Q4_K_M, Q8_0) + while keeping different sizes separate. + + Unlike _build_base_name_index which groups ALL variants (including different sizes), + this index only groups quantization variants of the SAME sized model. + + The key normalizes: + - Backend prefix (stripped for matching, but preserved in values) + - Org prefix (stripped for matching) + - Quantization suffix (stripped for matching) + + But preserves: + - Size suffix (8B, 12B, etc.) + + Args: + model_names: List of model names from API stats (may include backend prefixes + and quantization suffixes). + + Returns: + Dictionary mapping normalized model names (backend/model-size) to lists of + original model names (lowercase) that match that model. + + Example: + Input: ["koboldcpp/Lumimaid-v0.2-8B", "koboldcpp/Lumimaid-v0.2-8B-Q8_0", + "koboldcpp/Lumimaid-v0.2-12B", "aphrodite/NeverSleep/Lumimaid-v0.2-8B"] + Output: { + "koboldcpp/lumimaid-v0.2-8b": [ + "koboldcpp/lumimaid-v0.2-8b", + "koboldcpp/lumimaid-v0.2-8b-q8_0" + ], + "koboldcpp/lumimaid-v0.2-12b": ["koboldcpp/lumimaid-v0.2-12b"], + "aphrodite/lumimaid-v0.2-8b": ["aphrodite/neversleep/lumimaid-v0.2-8b"] + } + """ + model_with_size_index: dict[str, list[str]] = {} + + for model_name in model_names: + model_name_lower = model_name.lower() + + # Extract backend prefix if present + backend_prefix = "" + stripped = model_name_lower + if stripped.startswith("aphrodite/"): + backend_prefix = "aphrodite/" + stripped = stripped[len("aphrodite/") :] + elif stripped.startswith("koboldcpp/"): + backend_prefix = "koboldcpp/" + stripped = stripped[len("koboldcpp/") :] + + # Strip org prefix (e.g., "neversleep/lumimaid-v0.2-8b" -> "lumimaid-v0.2-8b") + if "/" in stripped: + stripped = stripped.split("/")[-1] + + # Strip quantization suffix + stripped_no_quant = _strip_quantization_suffix(stripped) + + # Build key: backend_prefix + model_name (no org, no quant, but with size) + key = f"{backend_prefix}{stripped_no_quant}" + + if key not in model_with_size_index: + model_with_size_index[key] = [] + if model_name_lower not in model_with_size_index[key]: + model_with_size_index[key].append(model_name_lower) + + return model_with_size_index + + class IndexedHordeModelStats(RootModel[_StatsLookup]): """Indexed model stats for O(1) lookups by model name. This wraps the stats response and provides case-insensitive dictionary access. Time complexity: O(1) for lookups instead of O(n) for dict iteration. - Also builds a base-name index for aggregating stats across quantization variants - and different backend prefixes. + Two indexes are built: + - _base_name_index: Groups ALL variants (including different sizes) for group-level aggregation + - _model_with_size_index: Groups only quantization variants for per-model stats Usage: indexed = IndexedHordeModelStats(stats_response) @@ -300,6 +406,7 @@ class IndexedHordeModelStats(RootModel[_StatsLookup]): root: _StatsLookup _base_name_index: dict[str, list[str]] = {} + _model_with_size_index: dict[str, list[str]] = {} def __init__(self, stats_response: HordeModelStatsResponse) -> None: """Build indexed lookups from stats response. @@ -315,11 +422,13 @@ def __init__(self, stats_response: HordeModelStatsResponse) -> None: ) super().__init__(root=lookups) - # Build base name index from all unique model names across all time periods + # Build indexes from all unique model names across all time periods all_model_names = ( set(stats_response.day.keys()) | set(stats_response.month.keys()) | set(stats_response.total.keys()) ) - self._base_name_index = _build_base_name_index(list(all_model_names)) + model_names_list = list(all_model_names) + self._base_name_index = _build_base_name_index(model_names_list) + self._model_with_size_index = _build_model_with_size_index(model_names_list) def get_day(self, model_name: str) -> int | None: """Get day count for a model (case-insensitive). O(1).""" @@ -371,7 +480,9 @@ def get_aggregated_stats(self, canonical_name: str) -> tuple[int, int, int]: # Then, add all model names that share the same base model name # This catches quantization variants and org-prefixed variants - base_name = get_base_model_name(canonical_name).lower() + # Strip org prefix from canonical name if present (e.g., "NeverSleep/Lumimaid-v0.2" -> "Lumimaid-v0.2") + canonical_without_org = canonical_name.split("/")[-1] if "/" in canonical_name else canonical_name + base_name = get_base_model_name(canonical_without_org).lower() if base_name in self._base_name_index: for api_model_name in self._base_name_index[base_name]: names_to_aggregate.add(api_model_name) @@ -391,37 +502,59 @@ def get_aggregated_stats(self, canonical_name: str) -> tuple[int, int, int]: def get_stats_with_variations( self, canonical_name: str ) -> tuple[tuple[int, int, int], dict[str, tuple[int, int, int]]]: - """Get aggregated stats and individual backend variations. + """Get stats for a specific model broken down by backend. + + Unlike get_aggregated_stats which aggregates across all models with the same + base name (e.g., all Lumimaid-v0.2 sizes), this method returns stats only for + the exact model specified (including its quantization variants), broken down + by backend prefix. - This method returns both the aggregated stats (same as get_aggregated_stats) - and a dictionary of individual backend stats keyed by backend name. - Now includes quantization variants and org-prefixed variants via base name matching. + This enables showing per-model stats in the UI when displaying grouped models, + where each model variant (8B, 12B, etc.) shows its own stats by backend. Args: canonical_name: The canonical model name from the model reference. Returns: Tuple of (aggregated_stats, variations_dict) where: - - aggregated_stats: (day_total, month_total, total_total) aggregated + - aggregated_stats: (day_total, month_total, total_total) for this exact model - variations_dict: Dict of backend_name -> (day, month, total) Keys are 'canonical', 'aphrodite', 'koboldcpp' depending on what's found """ - from horde_model_reference.analytics.text_model_parser import get_base_model_name from horde_model_reference.meta_consts import get_model_name_variants - # Collect all model names to aggregate (use set to avoid double-counting) + # Collect all model names that are variants of this specific model + # Use _model_with_size_index to include quantization variants, but NOT size variants names_to_aggregate: set[str] = set() - # First, add exact variants from get_model_name_variants + # Get exact backend-prefixed variants variants = get_model_name_variants(canonical_name) for variant in variants: - names_to_aggregate.add(variant.lower()) - - # Then, add all model names that share the same base model name - base_name = get_base_model_name(canonical_name).lower() - if base_name in self._base_name_index: - for api_model_name in self._base_name_index[base_name]: - names_to_aggregate.add(api_model_name) + variant_lower = variant.lower() + names_to_aggregate.add(variant_lower) + + # Build the normalized key to look up in _model_with_size_index + # The key format is: [backend_prefix/]model_name (no org, no quant) + backend_prefix = "" + stripped = variant_lower + if stripped.startswith("aphrodite/"): + backend_prefix = "aphrodite/" + stripped = stripped[len("aphrodite/") :] + elif stripped.startswith("koboldcpp/"): + backend_prefix = "koboldcpp/" + stripped = stripped[len("koboldcpp/") :] + + # Strip org prefix if present + if "/" in stripped: + stripped = stripped.split("/")[-1] + + # Strip quantization suffix and build key + stripped_no_quant = _strip_quantization_suffix(stripped) + key = f"{backend_prefix}{stripped_no_quant}" + + if key in self._model_with_size_index: + for api_model_name in self._model_with_size_index[key]: + names_to_aggregate.add(api_model_name) # Track stats by backend for variations dict backend_stats: dict[str, tuple[int, int, int]] = { diff --git a/tests/integrations/test_stats_aggregation.py b/tests/integrations/test_stats_aggregation.py index 4eb55985..235fb948 100644 --- a/tests/integrations/test_stats_aggregation.py +++ b/tests/integrations/test_stats_aggregation.py @@ -184,3 +184,64 @@ def test_aggregate_stats_model_not_found(self) -> None: assert day == 0 assert month == 0 assert total == 0 + + def test_aggregate_stats_canonical_name_with_org_prefix(self) -> None: + """Test that canonical names with org prefix correctly match API stats. + + This is a critical test case because model reference entries often have + org prefixes (e.g., "NeverSleep/Lumimaid-v0.2") but API stats may have + different prefixing patterns. + """ + stats = HordeModelStatsResponse( + day={ + "koboldcpp/Lumimaid-v0.2-8B": 4080, + "koboldcpp/Lumimaid-v0.2-8B-Q8_0": 1500, + "aphrodite/NeverSleep/Lumimaid-v0.2-8B": 2000, + }, + month={ + "koboldcpp/Lumimaid-v0.2-8B": 40000, + "koboldcpp/Lumimaid-v0.2-8B-Q8_0": 15000, + "aphrodite/NeverSleep/Lumimaid-v0.2-8B": 20000, + }, + total={}, + ) + + indexed = IndexedHordeModelStats(stats) + + # Query with canonical name that HAS org prefix (like in model reference) + canonical_with_org = "NeverSleep/Lumimaid-v0.2" + day, month, _total = indexed.get_aggregated_stats(canonical_with_org) + + # Should aggregate all variants even though canonical has org prefix + assert day == 4080 + 1500 + 2000 + assert month == 40000 + 15000 + 20000 + + def test_get_stats_with_variations_canonical_with_org_prefix(self) -> None: + """Test variations breakdown with canonical name that has org prefix. + + When querying for a specific model like 'NeverSleep/Lumimaid-v0.2-8B', + get_stats_with_variations should find: + - The aphrodite variant: aphrodite/NeverSleep/Lumimaid-v0.2-8B + - The koboldcpp variant: koboldcpp/Lumimaid-v0.2-8B + - Quantization variants: koboldcpp/Lumimaid-v0.2-8B-Q8_0 + """ + stats = HordeModelStatsResponse( + day={ + "koboldcpp/Lumimaid-v0.2-8B": 300, + "koboldcpp/Lumimaid-v0.2-8B-Q8_0": 400, + "aphrodite/NeverSleep/Lumimaid-v0.2-8B": 200, + }, + month={}, + total={}, + ) + + indexed = IndexedHordeModelStats(stats) + + # Query with canonical name that HAS org prefix (like in model reference) + (day_total, _m, _t), variations = indexed.get_stats_with_variations("NeverSleep/Lumimaid-v0.2-8B") + + assert day_total == 300 + 400 + 200 + assert "aphrodite" in variations + assert "koboldcpp" in variations + assert variations["aphrodite"][0] == 200 + assert variations["koboldcpp"][0] == 300 + 400