diff --git a/src/dnet/core/models/base.py b/src/dnet/core/models/base.py
index 5597cbd0..b1e4ad83 100644
--- a/src/dnet/core/models/base.py
+++ b/src/dnet/core/models/base.py
@@ -226,7 +226,7 @@ def _abskey_to_local_path(self, key: str) -> Optional[str]:
 
     def apply_quantization_from_config(
         self, model_config: Any, model_metadata: Any
-    ) -> bool:
+    ) -> Tuple[bool, bool]:
         """Quantize using a simple MLX-style predicate with optional per-path overrides.
 
         - If config["quantization"][path] exists, use that for this path.
@@ -408,15 +408,17 @@ def _predicate(path: str, module: nn.Module):
                 )
             except Exception:
                 self._converted_to_quantized = False
-                return False
+                if g_bits != 0 and g_group != 0:
+                    return (True, False)
+                return (False, False)
             self._converted_to_quantized = True
-            return True
+            return (True, True)
         except Exception:
             try:
                 self._converted_to_quantized = False
             except Exception:
                 pass
-            return False
+            return (False, False)
 
     @staticmethod
     def _shrink_linear_like(mod) -> None:
diff --git a/src/dnet/shard/runtime.py b/src/dnet/shard/runtime.py
index ff3ea285..652030c9 100644
--- a/src/dnet/shard/runtime.py
+++ b/src/dnet/shard/runtime.py
@@ -201,12 +201,12 @@ def load_model_core(self, req: ShardLoadModelRequest) -> None:
             is_api_layer=False,
         )
         try:
-            applied = bool(
-                self.model.apply_quantization_from_config(
-                    self.model_metadata.model_config,
-                    model_metadata=self.model_metadata,
-                )
+            is_quant, applied = self.model.apply_quantization_from_config(
+                self.model_metadata.model_config,
+                model_metadata=self.model_metadata,
             )
+            if is_quant and not applied:
+                raise RuntimeError("apply_quantization_from_config failed.")
             logger.info(
                 "[QUANT] runtime=%s applied=%s model=%s",
                 self.shard_id,
@@ -214,7 +214,10 @@ def load_model_core(self, req: ShardLoadModelRequest) -> None:
                 self.model_metadata.model_type,
             )
         except RuntimeError as e:
-            logger.warning("[QUANT] apply failed: %s", e)
+            logger.error(
+                f"[QUANT] Failed to quantize what appears to be a quantized model: {e}"
+            )
+            raise
 
         self.model.eval()
         self.cache = make_cache(
diff --git a/tests/fakes/models.py b/tests/fakes/models.py
index 650f844d..2563be17 100644
--- a/tests/fakes/models.py
+++ b/tests/fakes/models.py
@@ -62,7 +62,7 @@ def __init__(
         self.loaded = {}
 
     def apply_quantization_from_config(self, cfg, model_metadata=None):
-        return self._quant_applies
+        return (self._quant_applies, True)
 
     def eval(self):
         self.eval_called = True