From 76162fc6714edbf074afcb2370f3d9b1e4c2ab9d Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 24 Nov 2025 23:58:39 -0800 Subject: [PATCH 1/7] add support for qwen3_moe model --- src/dnet/core/models/__init__.py | 4 + src/dnet/core/models/qwen3_moe.py | 187 ++++++++++++++++++++++++++++++ 2 files changed, 191 insertions(+) create mode 100644 src/dnet/core/models/qwen3_moe.py diff --git a/src/dnet/core/models/__init__.py b/src/dnet/core/models/__init__.py index 82b48384..fc609c8f 100644 --- a/src/dnet/core/models/__init__.py +++ b/src/dnet/core/models/__init__.py @@ -8,6 +8,8 @@ from .llama import LlamaRingModel from .gpt_oss import GptOssRingModel from .qwen3 import Qwen3RingModel +from .qwen3_moe import Qwen3MoERingModel +from .glm4 import Glm4RingModel def get_ring_model( @@ -41,5 +43,7 @@ def get_ring_model( "LlamaRingModel", "GptOssRingModel", "Qwen3RingModel", + "Qwen3MoERingModel" + "Glm4RingModel", "get_ring_model", ] diff --git a/src/dnet/core/models/qwen3_moe.py b/src/dnet/core/models/qwen3_moe.py new file mode 100644 index 00000000..47236545 --- /dev/null +++ b/src/dnet/core/models/qwen3_moe.py @@ -0,0 +1,187 @@ +from typing import Any, Dict, List, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn +from mlx_lm.models.base import create_attention_mask +from mlx_lm.models.qwen3_moe import ModelArgs, Attention, MLP, SwitchGLU + +from .base import BaseRingModel + + +class Qwen3MoEDenseBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_experts = args.num_experts + self.gate = nn.Linear(args.hidden_size, self.num_experts, bias=False) + self.switch_mlp = SwitchGLU( + args.hidden_size, args.intermediate_size, self.num_experts + ) + + def __call__(self, x: mx.array): + indices = mx.arange(self.num_experts, dtype=mx.int32) + indices = mx.broadcast_to( + indices[None, None, :], (x.shape[0], x.shape[1], self.num_experts) + ) + gates = self.gate(x) + scores = mx.softmax(gates, axis=-1, precise=True) + y = self.switch_mlp(x, indices) + y = (y * scores[..., None]).sum(axis=-2) + return y + + +# force dense execution +class Qwen3MoEDecoderLayer(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + self.args = args + self.self_attn = Attention(args, layer_idx) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + + if (layer_idx not in args.mlp_only_layers) and ( + args.num_experts > 0 and (layer_idx + 1) % args.decoder_sparse_step == 0 + ): + self.mlp = Qwen3MoEDenseBlock(args) + else: + self.mlp = MLP(args.hidden_size, args.intermediate_size) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ): + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + return h + r + + +class Qwen3MoERingModel(BaseRingModel): + """Qwen3 MoE model for distributed execution""" + + model_type = "qwen3_moe" + + def __init__( + self, + model_config: Any, + assigned_layers: Optional[List[int]] = None, + is_api_layer: bool = False, + ): + super().__init__() + + if is_api_layer and assigned_layers: + raise RuntimeError("API Node cannot execute layers") + + self.model_config = model_config + self.is_api_layer = is_api_layer + self.config = config = ModelArgs.from_dict(model_config) + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if not config.tie_word_embeddings: + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.layers: List[nn.Module] = [] + self.abs_to_local: Dict[int, int] = {} + + for i, layer in enumerate(sorted(assigned_layers or [])): + self.layers.append(Qwen3MoEDecoderLayer(config, layer_idx=layer)) + self.abs_to_local[layer] = i + + self._converted_to_quantized = False + self._cached_mask_state: Optional[int] = None + self._cached_mask = None + + def embed(self, x: mx.array) -> mx.array: + if hasattr(self, "embed_tokens"): + return self.embed_tokens(x) + return x + + def normalize(self, x: mx.array) -> mx.array: + if hasattr(self, "norm"): + return self.norm(x) + return x + + def lm_project(self, x: mx.array) -> mx.array: + if hasattr(self, "lm_head") or hasattr(self, "embed_tokens"): + use_tied = bool(getattr(self.config, "tie_word_embeddings", False)) + if use_tied or not hasattr(self, "lm_head"): + return self.embed_tokens.as_linear(x) + return self.lm_head(x) + return x + + def forward(self, x: mx.array, cache: Optional[List[Any]] = None) -> mx.array: + mask = create_attention_mask(x, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for i, layer in enumerate(self.layers): + x = layer(x, mask, cache[i] if i < len(cache) else None) + + return x + + def apply_single_layer( + self, layer_idx: int, x: mx.array, cache: Optional[List[Any]] = None + ) -> mx.array: + if layer_idx not in self.abs_to_local: + raise RuntimeError(f"Layer {layer_idx} not hosted on this model instance") + # TODO: Mask reuse should respect concurrent requests + try: + T = int(x.shape[1]) if len(x.shape) > 1 else 1 + except Exception: + T = 1 + # dimension diagnostics removed + mask = None + if T > 1: + if self._cached_mask_state is None or self._cached_mask_state != T: + mask = create_attention_mask(x, cache) + self._cached_mask = mask + self._cached_mask_state = T + else: + mask = self._cached_mask + if mask is None: + mask = create_attention_mask(x, cache) + self._cached_mask = mask + self._cached_mask_state = T + local_idx = self.abs_to_local[layer_idx] + + c = None + if cache is not None and local_idx < len(cache): + c = cache[local_idx] + + return self.layers[local_idx](x, mask, c) + + # qwen stores expert weights in separate buffers + def sanitize(self, weights): + if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: + return weights + for l in range(self.args.num_hidden_layers): + prefix = f"model.layers.{l}" + for n in ["up_proj", "down_proj", "gate_proj"]: + if f"{prefix}.mlp.experts.0.{n}.weight" in weights: + to_join = [ + weights.pop(f"{prefix}.mlx.experts.{e}.{n}.weight") + for e in range(self.args.num_experts) + ] + weights[f"{prefix}.mlp.switch_mlx.{n}.weight"] = mx.stack(to_join) + return weights + + @property + def decoding_layers(self): + return self.layers + + @property + def head_dim(self) -> Tuple[int, int]: + return (self.config.head_dim, self.config.head_dim) + + @property + def n_kv_heads(self) -> int: + return self.config.num_key_value_heads + + @property + def num_layers(self) -> int: + return len(self.layers) From 16bd5403303ad28e886bb6d67f88e93db653cedf Mon Sep 17 00:00:00 2001 From: Octavian Date: Mon, 24 Nov 2025 23:59:15 -0800 Subject: [PATCH 2/7] add support for glm4 model --- src/dnet/core/models/glm4.py | 119 +++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 src/dnet/core/models/glm4.py diff --git a/src/dnet/core/models/glm4.py b/src/dnet/core/models/glm4.py new file mode 100644 index 00000000..e8b143e1 --- /dev/null +++ b/src/dnet/core/models/glm4.py @@ -0,0 +1,119 @@ +from typing import Any, Dict, List, Optional, Tuple + +import mlx.nn as nn +import mlx.core as mx +from mlx_lm.models.base import create_attention_mask +from mlx_lm.models.glm4 import ModelArgs, Glm4DecoderLayer + +from .base import BaseRingModel + + +class Glm4RingModel(BaseRingModel): + model_type = "glm4" + + def __init__( + self, + model_config: Any, + assigned_layers: Optional[List[int]] = None, + is_api_layer: bool = False, + ): + super().__init__() + + if is_api_layer and assigned_layers: + raise RuntimeError("API Node can't handle layer execution") + + self.model_config = model_config + self.is_api_layer = is_api_layer + self.config = config = ModelArgs.from_dict(model_config) + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if getattr(config, "tie_word_embeddings", None) is None: + config.tie_word_embeddings = False + if not config.tie_word_embeddings: + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.layers: List[nn.Module] = [] + self.abs_to_local: Dict[int, int] = {} + + for i, layer in enumerate(sorted(assigned_layers or [])): + self.layers.append(Glm4DecoderLayer(config)) + self.abs_to_local[layer] = i + + self._converted_to_quantized = False + self._cached_mask_state: Optional[int] = None + self._cached_mask = None + + def embed(self, x: mx.array) -> mx.array: + if hasattr(self, "embed_tokens"): + return self.embed_tokens(x) + return x + + def normalize(self, x: mx.array) -> mx.array: + if hasattr(self, "norm"): + return self.norm(x) + return x + + def lm_project(self, x: mx.array) -> mx.array: + if hasattr(self, "lm_head") or hasattr(self, "embed_tokens"): + use_tied = bool(getattr(self.config, "tie_word_embeddings", False)) + if use_tied or not hasattr(self, "lm_head"): + return self.embed_tokens.as_linear(x) + return self.lm_head(x) + return x + + def forward(self, x: mx.array, cache: Optional[List[Any]] = None) -> mx.array: + mask = create_attention_mask(x, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for i, layer in enumerate(self.layers): + x = layer(x, mask, cache[i] if i < len(cache) else None) + + return x + + def apply_single_layer( + self, layer_idx: int, x: mx.array, cache: Optional[List[Any]] = None + ) -> mx.array: + if layer_idx not in self.abs_to_local: + raise RuntimeError(f"Layer {layer_idx} not hosted on this model instance") + try: + T = int(x.shape[1]) if len(x.shape) > 1 else 1 + except Exception: + T = 1 + mask = None + if T > 1: + if self._cached_mask_state is None or self._cached_mask_state != T: + mask = create_attention_mask(x, cache) + self._cached_mask = mask + self._cached_mask_state = T + else: + mask = self._cached_mask + if mask is None: + mask = create_attention_mask(x, cache) + self._cached_mask = mask + self._cached_mask_state = T + local_idx = self.abs_to_local[layer_idx] + + c = None + if cache is not None and local_idx < len(cache): + c = cache[local_idx] + + return self.layers[local_idx](x, mask, c) + + @property + def decoding_layers(self): + return self.layers + + @property + def head_dim(self) -> Tuple[int, int]: + return (self.config.head_dim, self.config.head_dim) + + @property + def n_kv_heads(self) -> int: + return self.config.num_key_value_heads + + @property + def num_layers(self) -> int: + return len(self.layers) From 0adfb07c9940cb9d8d8547f3e888eba94f01229f Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 25 Nov 2025 02:31:16 -0800 Subject: [PATCH 3/7] add olmo3 model --- src/dnet/core/models/__init__.py | 2 + src/dnet/core/models/olmo3.py | 120 +++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+) create mode 100644 src/dnet/core/models/olmo3.py diff --git a/src/dnet/core/models/__init__.py b/src/dnet/core/models/__init__.py index fc609c8f..5353c994 100644 --- a/src/dnet/core/models/__init__.py +++ b/src/dnet/core/models/__init__.py @@ -10,6 +10,7 @@ from .qwen3 import Qwen3RingModel from .qwen3_moe import Qwen3MoERingModel from .glm4 import Glm4RingModel +from .olmo3 import Olmo3RingModel def get_ring_model( @@ -45,5 +46,6 @@ def get_ring_model( "Qwen3RingModel", "Qwen3MoERingModel" "Glm4RingModel", + "Olmo3RingModel", "get_ring_model", ] diff --git a/src/dnet/core/models/olmo3.py b/src/dnet/core/models/olmo3.py new file mode 100644 index 00000000..c3ce96c5 --- /dev/null +++ b/src/dnet/core/models/olmo3.py @@ -0,0 +1,120 @@ +from typing import Any, Dict, List, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn +from mlx_lm.models.base import create_attention_mask +from mlx_lm.models.olmo3 import ModelArgs, Olmo3DecoderLayer + +from .base import BaseRingModel + +# NOTE: mlx_lm handles sliding window attention, check 'initialize_rope' for logic +class Olmo3RingModel(BaseRingModel): + + model_type = "olmo3" + + def __init__( + self, + model_config: Any, + assigned_layers: Optional[List[int]] = None, + is_api_layer: bool = False, + ): + super().__init__() + + if is_api_layer and assigned_layers: + raise RuntimeError("API node cannot handle layers") + + self.model_config = model_config + self.is_api_layer = is_api_layer + self.config = config = ModelArgs.from_dict(model_config) + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if not config.tie_word_embeddings: + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.layers: List[nn.Module] = [] + self.abs_to_local: Dict[int, int] = {} + + for i, layer in enumerate(sorted(assigned_layers or [])): + self.layers.append(Olmo3DecoderLayer(config, layer_idx=layer)) + self.abs_to_local[layer] = i + + self._converted_to_quantized = False + self._cached_mask_state: Optional[int] = None + self._cached_mask = None + + def embed(self, x: mx.array) -> mx.array: + if hasattr(self, "embed_tokens"): + return self.embed_tokens(x) + return x + + def normalize(self, x: mx.array) -> mx.array: + if hasattr(self, "norm"): + return self.norm(x) + return x + + def lm_project(self, x: mx.array) -> mx.array: + if hasattr(self, "lm_head") or hasattr(self, "embed_tokens"): + use_tied = bool(getattr(self.config, "tie_word_embeddings", False)) + if use_tied or not hasattr(self, "lm_head"): + return self.embed_tokens.as_linear(x) + return self.lm_head(x) + return x + + def forward(self, x: mx.array, cache: Optional[List[Any]] = None) -> mx.array: + mask = create_attention_mask(x, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for i, layer in enumerate(self.layers): + x = layer(x, mask, cache[i] if i < len(cache) else None) + + return x + + def apply_single_layer( + self, layer_idx: int, x: mx.array, cache: Optional[List[Any]] = None + ) -> mx.array: + if layer_idx not in self.abs_to_local: + raise RuntimeError(f"Layer {layer_idx} not hosted on this model instance") + try: + T = int(x.shape[1]) if len(x.shape) > 1 else 1 + except Exception: + T = 1 + mask = None + if T > 1: + if self._cached_mask_state is None or self._cached_mask_state != T: + mask = create_attention_mask(x, cache) + self._cached_mask = mask + self._cached_mask_state = T + else: + mask = self._cached_mask + if mask is None: + mask = create_attention_mask(x, cache) + self._cached_mask = mask + self._cached_mask_state = T + local_idx = self.abs_to_local[layer_idx] + + c = None + if cache is not None and local_idx < len(cache): + c = cache[local_idx] + + return self.layers[local_idx](x, mask, c) + + @property + def decoding_layers(self): + return self.layers + + @property + def head_dim(self) -> Tuple[int, int]: + head_dim = (self.config.head_dim or + self.config.hidden_size // self.config.num_attention_heads) + return(head_dim, head_dim) + + @property + def n_kv_heads(self) -> int: + return self.config.num_key_value_heads + + @property + def num_layers(self) -> int: + return len(self.layers) From 37e7ebe0b947b4f70ffe9b50e666f53551624c5a Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 25 Nov 2025 06:47:49 -0800 Subject: [PATCH 4/7] ruff format --- src/dnet/core/models/__init__.py | 2 +- src/dnet/core/models/olmo3.py | 10 ++++++---- src/dnet/core/models/qwen3_moe.py | 10 +++++----- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/dnet/core/models/__init__.py b/src/dnet/core/models/__init__.py index 5353c994..513764a2 100644 --- a/src/dnet/core/models/__init__.py +++ b/src/dnet/core/models/__init__.py @@ -44,7 +44,7 @@ def get_ring_model( "LlamaRingModel", "GptOssRingModel", "Qwen3RingModel", - "Qwen3MoERingModel" + "Qwen3MoERingModel", "Glm4RingModel", "Olmo3RingModel", "get_ring_model", diff --git a/src/dnet/core/models/olmo3.py b/src/dnet/core/models/olmo3.py index c3ce96c5..e707058d 100644 --- a/src/dnet/core/models/olmo3.py +++ b/src/dnet/core/models/olmo3.py @@ -7,9 +7,9 @@ from .base import BaseRingModel + # NOTE: mlx_lm handles sliding window attention, check 'initialize_rope' for logic class Olmo3RingModel(BaseRingModel): - model_type = "olmo3" def __init__( @@ -107,9 +107,11 @@ def decoding_layers(self): @property def head_dim(self) -> Tuple[int, int]: - head_dim = (self.config.head_dim or - self.config.hidden_size // self.config.num_attention_heads) - return(head_dim, head_dim) + head_dim = ( + self.config.head_dim + or self.config.hidden_size // self.config.num_attention_heads + ) + return (head_dim, head_dim) @property def n_kv_heads(self) -> int: diff --git a/src/dnet/core/models/qwen3_moe.py b/src/dnet/core/models/qwen3_moe.py index 47236545..8afc2020 100644 --- a/src/dnet/core/models/qwen3_moe.py +++ b/src/dnet/core/models/qwen3_moe.py @@ -158,14 +158,14 @@ def apply_single_layer( # qwen stores expert weights in separate buffers def sanitize(self, weights): if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: - return weights - for l in range(self.args.num_hidden_layers): - prefix = f"model.layers.{l}" + return weights + for layer in range(self.args.num_hidden_layers): + prefix = f"model.layers.{layer}" for n in ["up_proj", "down_proj", "gate_proj"]: if f"{prefix}.mlp.experts.0.{n}.weight" in weights: to_join = [ - weights.pop(f"{prefix}.mlx.experts.{e}.{n}.weight") - for e in range(self.args.num_experts) + weights.pop(f"{prefix}.mlx.experts.{e}.{n}.weight") + for e in range(self.args.num_experts) ] weights[f"{prefix}.mlp.switch_mlx.{n}.weight"] = mx.stack(to_join) return weights From 069aa5012b7ea6f7a5328a31d313b139e4ad566a Mon Sep 17 00:00:00 2001 From: Octavian Date: Tue, 25 Nov 2025 10:05:41 -0800 Subject: [PATCH 5/7] add working models to catalog --- src/dnet/api/catalog.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/dnet/api/catalog.py b/src/dnet/api/catalog.py index bd36384d..cd97fb7f 100644 --- a/src/dnet/api/catalog.py +++ b/src/dnet/api/catalog.py @@ -168,5 +168,41 @@ "quantization": "4bit", "alias": "hermes-4-405b", }, + { + "id": "mlx-community/olmo-3-7b-think-4bit", + "arch": "olmo3", + "quantization": "4bit", + "alias": "olmo3-7b", + }, + { + "id": "mlx-community/Olmo-3-1125-32B-4bit", + "arch": "olmo3", + "quantization": "4bit", + "alias": "olmo3-32b", + }, + { + "id": "mlx-community/GLM-4-9B-0414-4bit", + "arch": "glm4", + "quantization": "4bit", + "alias": "glm4-9b", + }, + { + "id": "mlx-community/GLM-4-9B-0414-8bit", + "arch": "glm4", + "quantization": "8bit", + "alias": "glm4-9b", + }, + { + "id": "mlx-community/GLM-4-32B-0414-4bit", + "arch": "glm4", + "quantization": "4bit", + "alias": "glm4-32b", + }, + { + "id": "mlx-community/Qwen3-30B-A3B-4bit", + "arch": "qwen3_moe", + "quantization": "4bit", + "alias": "qwen3-moe-30b", + }, ] } From 451190c5207d8182db13ce2747d910ccd54595f9 Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 27 Nov 2025 00:13:08 -0800 Subject: [PATCH 6/7] fix qwen3_moe switch_glu object --- src/dnet/api/catalog.py | 12 +++++++++++ src/dnet/core/models/qwen3_moe.py | 34 +++++++++++++++---------------- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/src/dnet/api/catalog.py b/src/dnet/api/catalog.py index cd97fb7f..7128d7da 100644 --- a/src/dnet/api/catalog.py +++ b/src/dnet/api/catalog.py @@ -204,5 +204,17 @@ "quantization": "4bit", "alias": "qwen3-moe-30b", }, + { + "id": "mlx-community/Qwen3-Coder-30B-A3B-Instruct-4bit", + "arch": "qwen3_moe", + "quantization": "4bit", + "alias": "qwen3-moe-30b", + }, + { + "id": "mlx-community/Qwen3-Coder-30B-A3B-Instruct-8bit", + "arch": "qwen3_moe", + "quantization": "8bit", + "alias": "qwen3-moe-30b", + }, ] } diff --git a/src/dnet/core/models/qwen3_moe.py b/src/dnet/core/models/qwen3_moe.py index 8afc2020..346098ff 100644 --- a/src/dnet/core/models/qwen3_moe.py +++ b/src/dnet/core/models/qwen3_moe.py @@ -9,12 +9,12 @@ class Qwen3MoEDenseBlock(nn.Module): - def __init__(self, args: ModelArgs): + def __init__(self, config: ModelArgs): super().__init__() - self.num_experts = args.num_experts - self.gate = nn.Linear(args.hidden_size, self.num_experts, bias=False) + self.num_experts = config.num_experts + self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False) self.switch_mlp = SwitchGLU( - args.hidden_size, args.intermediate_size, self.num_experts + config.hidden_size, config.moe_intermediate_size, self.num_experts ) def __call__(self, x: mx.array): @@ -31,21 +31,21 @@ def __call__(self, x: mx.array): # force dense execution class Qwen3MoEDecoderLayer(nn.Module): - def __init__(self, args: ModelArgs, layer_idx: int): + def __init__(self, config: ModelArgs, layer_idx: int): super().__init__() - self.args = args - self.self_attn = Attention(args, layer_idx) - self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.config = config + self.self_attn = Attention(config, layer_idx) + self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = nn.RMSNorm( - args.hidden_size, eps=args.rms_norm_eps + config.hidden_size, eps=config.rms_norm_eps ) - if (layer_idx not in args.mlp_only_layers) and ( - args.num_experts > 0 and (layer_idx + 1) % args.decoder_sparse_step == 0 + if (layer_idx not in config.mlp_only_layers) and ( + config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 ): - self.mlp = Qwen3MoEDenseBlock(args) + self.mlp = Qwen3MoEDenseBlock(config) else: - self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.mlp = MLP(config.hidden_size, config.intermediate_size) def __call__( self, @@ -159,15 +159,15 @@ def apply_single_layer( def sanitize(self, weights): if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: return weights - for layer in range(self.args.num_hidden_layers): + for layer in range(self.config.num_hidden_layers): prefix = f"model.layers.{layer}" for n in ["up_proj", "down_proj", "gate_proj"]: if f"{prefix}.mlp.experts.0.{n}.weight" in weights: to_join = [ - weights.pop(f"{prefix}.mlx.experts.{e}.{n}.weight") - for e in range(self.args.num_experts) + weights.pop(f"{prefix}.mlp.experts.{e}.{n}.weight") + for e in range(self.config.num_experts) ] - weights[f"{prefix}.mlp.switch_mlx.{n}.weight"] = mx.stack(to_join) + weights[f"{prefix}.mlp.switch_mlp.{n}.weight"] = mx.stack(to_join) return weights @property From d04d04eca38ea0707ab1e58f0eacdaade1303286 Mon Sep 17 00:00:00 2001 From: Octavian Date: Thu, 27 Nov 2025 00:52:34 -0800 Subject: [PATCH 7/7] update catalogue --- src/dnet/api/catalog.py | 80 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 79 insertions(+), 1 deletion(-) diff --git a/src/dnet/api/catalog.py b/src/dnet/api/catalog.py index 7128d7da..95f948e8 100644 --- a/src/dnet/api/catalog.py +++ b/src/dnet/api/catalog.py @@ -169,35 +169,113 @@ "alias": "hermes-4-405b", }, { - "id": "mlx-community/olmo-3-7b-think-4bit", + "id": "mlx-community/Olmo-3-1025-7B-4bit", "arch": "olmo3", "quantization": "4bit", "alias": "olmo3-7b", }, + { + "id": "mlx-community/Olmo-3-7B-Think-4bit", + "arch": "olmo3", + "quantization": "4bit", + "alias": "olmo3-7b", + }, + { + "id": "mlx-community/Olmo-3-7B-Think-SFT-4bit", + "arch": "olmo3", + "quantization": "4bit", + "alias": "olmo3-7b", + }, + { + "id": "mlx-community/Olmo-3-7B-Instruct-4bit", + "arch": "olmo3", + "quantization": "4bit", + "alias": "olmo3-7b", + }, + { + "id": "mlx-community/Olmo-3-7B-Instruct-SFT-4bit", + "arch": "olmo3", + "quantization": "4bit", + "alias": "olmo3-7b", + }, + { + "id": "mlx-community/Olmo-3-1025-7B-8bit", + "arch": "olmo3", + "quantization": "8bit", + "alias": "olmo3-7b", + }, + { + "id": "mlx-community/Olmo-3-7B-Think-8bit", + "arch": "olmo3", + "quantization": "8bit", + "alias": "olmo3-7b", + }, + { + "id": "mlx-community/Olmo-3-7B-Think-SFT-8bit", + "arch": "olmo3", + "quantization": "8bit", + "alias": "olmo3-7b", + }, + { + "id": "mlx-community/Olmo-3-7B-Instruct-8bit", + "arch": "olmo3", + "quantization": "8bit", + "alias": "olmo3-7b", + }, + { + "id": "mlx-community/Olmo-3-7B-Instruct-SFT-8bit", + "arch": "olmo3", + "quantization": "8bit", + "alias": "olmo3-7b", + }, { "id": "mlx-community/Olmo-3-1125-32B-4bit", "arch": "olmo3", "quantization": "4bit", "alias": "olmo3-32b", }, + { + "id": "mlx-community/Olmo-3-1125-32B-8bit", + "arch": "olmo3", + "quantization": "8bit", + "alias": "olmo3-32b", + }, { "id": "mlx-community/GLM-4-9B-0414-4bit", "arch": "glm4", "quantization": "4bit", "alias": "glm4-9b", }, + { + "id": "mlx-community/GLM-Z1-9B-0414-4bit", + "arch": "glm4", + "quantization": "4bit", + "alias": "glm4-9b", + }, { "id": "mlx-community/GLM-4-9B-0414-8bit", "arch": "glm4", "quantization": "8bit", "alias": "glm4-9b", }, + { + "id": "mlx-community/GLM-Z1-9B-0414-8bit", + "arch": "glm4", + "quantization": "8bit", + "alias": "glm4-9b", + }, { "id": "mlx-community/GLM-4-32B-0414-4bit", "arch": "glm4", "quantization": "4bit", "alias": "glm4-32b", }, + { + "id": "mlx-community/GLM-Z1-32B-0414-4bit", + "arch": "glm4", + "quantization": "4bit", + "alias": "glm4-32b", + }, { "id": "mlx-community/Qwen3-30B-A3B-4bit", "arch": "qwen3_moe",