10 changes: 6 additions & 4 deletions src/dnet/core/models/base.py
@@ -226,7 +226,7 @@ def _abskey_to_local_path(self, key: str) -> Optional[str]:
 
     def apply_quantization_from_config(
         self, model_config: Any, model_metadata: Any
-    ) -> bool:
+    ) -> Tuple[bool, bool]:
         """Quantize using a simple MLX-style predicate with optional per-path overrides.
 
         - If config["quantization"][path] exists, use that for this path.
@@ -408,15 +408,17 @@ def _predicate(path: str, module: nn.Module):
                 )
             except Exception:
                 self._converted_to_quantized = False
-                return False
+                if g_bits != 0 and g_group != 0:
+                    return (True, False)
+                return (False, False)
             self._converted_to_quantized = True
-            return True
+            return (True, True)
         except Exception:
             try:
                 self._converted_to_quantized = False
             except Exception:
                 pass
-            return False
+            return (False, False)
 
     @staticmethod
     def _shrink_linear_like(mod) -> None:
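For readers of the hunks above: the single bool return becomes a pair. A minimal sketch of the convention, using the names is_quant and applied from the caller in runtime.py; the helper itself is hypothetical, added only to spell out the mapping.

from typing import Tuple

def describe_quant_result(result: Tuple[bool, bool]) -> str:
    # Spell out the (is_quant, applied) pair now returned by
    # apply_quantization_from_config.
    is_quant, applied = result
    if is_quant and applied:
        return "quantization requested and applied"      # (True, True)
    if is_quant and not applied:
        return "quantization requested but it failed"    # (True, False)
    return "no quantization requested or applied"        # (False, False)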
15 changes: 9 additions & 6 deletions src/dnet/shard/runtime.py
Member comment: Wouldn't it be better to raise an exception, since we don't/can't run a model once it's not correctly quantized?
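A minimal sketch of that suggestion, written as a caller-side wrapper rather than a change to base.py itself; the QuantizationError type and the wrapper name are hypothetical, only apply_quantization_from_config and the (is_quant, applied) pair come from this PR.

from typing import Any

class QuantizationError(RuntimeError):
    # Hypothetical error type: raised when a config declares quantization
    # that could not actually be applied.
    pass

def apply_quantization_or_raise(model: Any, model_config: Any, model_metadata: Any) -> bool:
    # Raise instead of returning a failure flag, per the comment above.
    is_quant, applied = model.apply_quantization_from_config(
        model_config, model_metadata=model_metadata
    )
    if is_quant and not applied:
        raise QuantizationError("quantized model config could not be applied")
    return applied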

@@ -201,20 +201,23 @@ def load_model_core(self, req: ShardLoadModelRequest) -> None:
             is_api_layer=False,
         )
         try:
-            applied = bool(
-                self.model.apply_quantization_from_config(
-                    self.model_metadata.model_config,
-                    model_metadata=self.model_metadata,
-                )
+            is_quant, applied = self.model.apply_quantization_from_config(
+                self.model_metadata.model_config,
+                model_metadata=self.model_metadata,
             )
+            if is_quant and not applied:
+                raise RuntimeError("apply_quantization_from_config failed.")
             logger.info(
                 "[QUANT] runtime=%s applied=%s model=%s",
                 self.shard_id,
                 applied,
                 self.model_metadata.model_type,
             )
         except RuntimeError as e:
-            logger.warning("[QUANT] apply failed: %s", e)
+            logger.error(
+                f"[QUANT] Failed to quantize what appears to be a quantized model: {e}"
+            )
+            raise
 
         self.model.eval()
         self.cache = make_cache(
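With the re-raise in place, a failed quantization now propagates out of load_model_core instead of being logged and ignored. A sketch of what a caller might do with that; the handler, its name, and the response shape are assumptions, only load_model_core and ShardLoadModelRequest appear in the diff.

def handle_load_model(runtime, req) -> dict:
    # Illustration only: surface the RuntimeError raised above to the client
    # rather than serving a model that was supposed to be quantized but is not.
    try:
        runtime.load_model_core(req)
        return {"ok": True}
    except RuntimeError as exc:
        return {"ok": False, "error": str(exc)}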
2 changes: 1 addition & 1 deletion tests/fakes/models.py
@@ -62,7 +62,7 @@ def __init__(
         self.loaded = {}
 
     def apply_quantization_from_config(self, cfg, model_metadata=None):
-        return self._quant_applies
+        return (self._quant_applies, True)
 
     def eval(self):
        self.eval_called = True
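The fake now mirrors the tuple contract, always reporting applied=True. A sketch of a test that reads it; the fake's class name and construction here are assumptions, only _quant_applies and the method signature come from the diff.

def test_fake_quantization_reports_tuple():
    model = FakeShardModel()    # hypothetical name for the fake in tests/fakes/models.py
    model._quant_applies = True  # attribute shown in the diff
    is_quant, applied = model.apply_quantization_from_config({}, model_metadata=None)
    assert is_quant is True
    assert applied is True       # the fake never simulates a failed quantization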