126 changes: 126 additions & 0 deletions src/dnet/api/catalog.py
@@ -168,5 +168,131 @@
"quantization": "4bit",
"alias": "hermes-4-405b",
},
{
"id": "mlx-community/Olmo-3-1025-7B-4bit",
"arch": "olmo3",
"quantization": "4bit",
"alias": "olmo3-7b",
},
{
"id": "mlx-community/Olmo-3-7B-Think-4bit",
"arch": "olmo3",
"quantization": "4bit",
"alias": "olmo3-7b",
},
{
"id": "mlx-community/Olmo-3-7B-Think-SFT-4bit",
"arch": "olmo3",
"quantization": "4bit",
"alias": "olmo3-7b",
},
{
"id": "mlx-community/Olmo-3-7B-Instruct-4bit",
"arch": "olmo3",
"quantization": "4bit",
"alias": "olmo3-7b",
},
{
"id": "mlx-community/Olmo-3-7B-Instruct-SFT-4bit",
"arch": "olmo3",
"quantization": "4bit",
"alias": "olmo3-7b",
},
{
"id": "mlx-community/Olmo-3-1025-7B-8bit",
"arch": "olmo3",
"quantization": "8bit",
"alias": "olmo3-7b",
},
{
"id": "mlx-community/Olmo-3-7B-Think-8bit",
"arch": "olmo3",
"quantization": "8bit",
"alias": "olmo3-7b",
},
{
"id": "mlx-community/Olmo-3-7B-Think-SFT-8bit",
"arch": "olmo3",
"quantization": "8bit",
"alias": "olmo3-7b",
},
{
"id": "mlx-community/Olmo-3-7B-Instruct-8bit",
"arch": "olmo3",
"quantization": "8bit",
"alias": "olmo3-7b",
},
{
"id": "mlx-community/Olmo-3-7B-Instruct-SFT-8bit",
"arch": "olmo3",
"quantization": "8bit",
"alias": "olmo3-7b",
},
{
"id": "mlx-community/Olmo-3-1125-32B-4bit",
"arch": "olmo3",
"quantization": "4bit",
"alias": "olmo3-32b",
},
{
"id": "mlx-community/Olmo-3-1125-32B-8bit",
"arch": "olmo3",
"quantization": "8bit",
"alias": "olmo3-32b",
},
{
"id": "mlx-community/GLM-4-9B-0414-4bit",
"arch": "glm4",
"quantization": "4bit",
"alias": "glm4-9b",
},
{
"id": "mlx-community/GLM-Z1-9B-0414-4bit",
"arch": "glm4",
"quantization": "4bit",
"alias": "glm4-9b",
},
{
"id": "mlx-community/GLM-4-9B-0414-8bit",
"arch": "glm4",
"quantization": "8bit",
"alias": "glm4-9b",
},
{
"id": "mlx-community/GLM-Z1-9B-0414-8bit",
"arch": "glm4",
"quantization": "8bit",
"alias": "glm4-9b",
},
{
"id": "mlx-community/GLM-4-32B-0414-4bit",
"arch": "glm4",
"quantization": "4bit",
"alias": "glm4-32b",
},
{
"id": "mlx-community/GLM-Z1-32B-0414-4bit",
"arch": "glm4",
"quantization": "4bit",
"alias": "glm4-32b",
},
{
"id": "mlx-community/Qwen3-30B-A3B-4bit",
"arch": "qwen3_moe",
"quantization": "4bit",
"alias": "qwen3-moe-30b",
},
{
"id": "mlx-community/Qwen3-Coder-30B-A3B-Instruct-4bit",
"arch": "qwen3_moe",
"quantization": "4bit",
"alias": "qwen3-moe-30b",
},
{
"id": "mlx-community/Qwen3-Coder-30B-A3B-Instruct-8bit",
"arch": "qwen3_moe",
"quantization": "8bit",
"alias": "qwen3-moe-30b",
},
]
}
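
Each new catalog entry follows the same id/arch/quantization/alias shape, and several checkpoints share one alias (every 7B Olmo 3 variant maps to olmo3-7b). Below is a minimal sketch of resolving an alias and quantization to a concrete repo id; SUPPORTED_MODELS and resolve_model are hypothetical names used for illustration, not the actual symbols defined in catalog.py.

# Illustrative sketch only -- SUPPORTED_MODELS and resolve_model are assumed
# names, not the real API of src/dnet/api/catalog.py.
from typing import Dict, List, Optional

SUPPORTED_MODELS: List[Dict[str, str]] = [
    {
        "id": "mlx-community/Olmo-3-7B-Instruct-4bit",
        "arch": "olmo3",
        "quantization": "4bit",
        "alias": "olmo3-7b",
    },
    # ... remaining entries as listed in the diff above ...
]


def resolve_model(alias: str, quantization: str = "4bit") -> Optional[str]:
    """Return the first repo id registered under the given alias and quantization."""
    for entry in SUPPORTED_MODELS:
        if entry["alias"] == alias and entry["quantization"] == quantization:
            return entry["id"]
    return None


# e.g. resolve_model("olmo3-7b", "8bit") returns the first 8-bit checkpoint
# registered under that alias.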
6 changes: 6 additions & 0 deletions src/dnet/core/models/__init__.py
@@ -8,6 +8,9 @@
from .llama import LlamaRingModel
from .gpt_oss import GptOssRingModel
from .qwen3 import Qwen3RingModel
from .qwen3_moe import Qwen3MoERingModel
from .glm4 import Glm4RingModel
from .olmo3 import Olmo3RingModel


def get_ring_model(
@@ -41,5 +44,8 @@ def get_ring_model(
"LlamaRingModel",
"GptOssRingModel",
"Qwen3RingModel",
"Qwen3MoERingModel",
"Glm4RingModel",
"Olmo3RingModel",
"get_ring_model",
]
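
The body of get_ring_model is collapsed in the hunk above, so the exact dispatch is not visible here. The sketch below assumes a simple model_type-keyed mapping over the exported classes and is illustrative only; the real function may be structured differently.

# Hypothetical dispatch sketch -- get_ring_model's actual body is collapsed in
# the hunk above and may differ.
_RING_MODELS = {
    "llama": LlamaRingModel,
    "gpt_oss": GptOssRingModel,
    "qwen3": Qwen3RingModel,
    "qwen3_moe": Qwen3MoERingModel,
    "glm4": Glm4RingModel,
    "olmo3": Olmo3RingModel,
}


def get_ring_model_sketch(model_type, model_config, assigned_layers=None, is_api_layer=False):
    try:
        cls = _RING_MODELS[model_type]
    except KeyError:
        raise ValueError(f"Unsupported model_type: {model_type}") from None
    return cls(model_config, assigned_layers=assigned_layers, is_api_layer=is_api_layer)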
119 changes: 119 additions & 0 deletions src/dnet/core/models/glm4.py
@@ -0,0 +1,119 @@
from typing import Any, Dict, List, Optional, Tuple

import mlx.nn as nn
import mlx.core as mx
from mlx_lm.models.base import create_attention_mask
from mlx_lm.models.glm4 import ModelArgs, Glm4DecoderLayer

from .base import BaseRingModel


class Glm4RingModel(BaseRingModel):
model_type = "glm4"

def __init__(
self,
model_config: Any,
assigned_layers: Optional[List[int]] = None,
is_api_layer: bool = False,
):
super().__init__()

if is_api_layer and assigned_layers:
raise RuntimeError("API Node can't handle layer execution")

self.model_config = model_config
self.is_api_layer = is_api_layer
self.config = config = ModelArgs.from_dict(model_config)

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if getattr(config, "tie_word_embeddings", None) is None:
config.tie_word_embeddings = False
if not config.tie_word_embeddings:
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

self.layers: List[nn.Module] = []
self.abs_to_local: Dict[int, int] = {}

for i, layer in enumerate(sorted(assigned_layers or [])):
self.layers.append(Glm4DecoderLayer(config))
self.abs_to_local[layer] = i

self._converted_to_quantized = False
self._cached_mask_state: Optional[int] = None
self._cached_mask = None

def embed(self, x: mx.array) -> mx.array:
if hasattr(self, "embed_tokens"):
return self.embed_tokens(x)
return x

def normalize(self, x: mx.array) -> mx.array:
if hasattr(self, "norm"):
return self.norm(x)
return x

def lm_project(self, x: mx.array) -> mx.array:
if hasattr(self, "lm_head") or hasattr(self, "embed_tokens"):
            # With tied embeddings (or when no lm_head was created) the input
            # embedding matrix doubles as the output projection.
            use_tied = bool(getattr(self.config, "tie_word_embeddings", False))
if use_tied or not hasattr(self, "lm_head"):
return self.embed_tokens.as_linear(x)
return self.lm_head(x)
return x

def forward(self, x: mx.array, cache: Optional[List[Any]] = None) -> mx.array:
mask = create_attention_mask(x, cache)

if cache is None:
cache = [None] * len(self.layers)

for i, layer in enumerate(self.layers):
x = layer(x, mask, cache[i] if i < len(cache) else None)

return x

def apply_single_layer(
self, layer_idx: int, x: mx.array, cache: Optional[List[Any]] = None
) -> mx.array:
if layer_idx not in self.abs_to_local:
raise RuntimeError(f"Layer {layer_idx} not hosted on this model instance")
try:
T = int(x.shape[1]) if len(x.shape) > 1 else 1
except Exception:
T = 1
        # A causal mask is only needed during prefill (T > 1); it is cached per
        # sequence length so it is built once per token batch rather than once
        # per hosted layer. Decode steps (T == 1) leave mask as None.
        mask = None
if T > 1:
if self._cached_mask_state is None or self._cached_mask_state != T:
mask = create_attention_mask(x, cache)
self._cached_mask = mask
self._cached_mask_state = T
else:
mask = self._cached_mask
if mask is None:
mask = create_attention_mask(x, cache)
self._cached_mask = mask
self._cached_mask_state = T
local_idx = self.abs_to_local[layer_idx]

c = None
if cache is not None and local_idx < len(cache):
c = cache[local_idx]

return self.layers[local_idx](x, mask, c)

@property
def decoding_layers(self):
return self.layers

@property
def head_dim(self) -> Tuple[int, int]:
return (self.config.head_dim, self.config.head_dim)

@property
def n_kv_heads(self) -> int:
return self.config.num_key_value_heads

@property
def num_layers(self) -> int:
return len(self.layers)
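
To make the layer-slicing contract of these ring models concrete, here is a hedged usage sketch: model_config stands in for the checkpoint's parsed config.json (not loaded here), the layer range is arbitrary, and the cache wiring assumes one mlx_lm KVCache per hosted layer, indexed by local position as apply_single_layer expects.

# Hedged usage sketch -- model_config and the hidden-state tensor x are assumed
# to come from the surrounding runtime, not from this snippet.
from mlx_lm.models.cache import KVCache

model = Glm4RingModel(model_config, assigned_layers=[20, 21, 22, 23])
cache = [KVCache() for _ in range(model.num_layers)]  # one cache per hosted layer

# x: hidden states received from the previous node in the ring
for abs_idx in sorted(model.abs_to_local):
    x = model.apply_single_layer(abs_idx, x, cache=cache)
# x is then forwarded to the node hosting the next layer range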
122 changes: 122 additions & 0 deletions src/dnet/core/models/olmo3.py
@@ -0,0 +1,122 @@
from typing import Any, Dict, List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import create_attention_mask
from mlx_lm.models.olmo3 import ModelArgs, Olmo3DecoderLayer

from .base import BaseRingModel


# NOTE: sliding-window attention is handled inside mlx_lm; see 'initialize_rope' there for the logic
class Olmo3RingModel(BaseRingModel):
model_type = "olmo3"

def __init__(
self,
model_config: Any,
assigned_layers: Optional[List[int]] = None,
is_api_layer: bool = False,
):
super().__init__()

if is_api_layer and assigned_layers:
raise RuntimeError("API node cannot handle layers")

self.model_config = model_config
self.is_api_layer = is_api_layer
self.config = config = ModelArgs.from_dict(model_config)

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if not config.tie_word_embeddings:
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

self.layers: List[nn.Module] = []
self.abs_to_local: Dict[int, int] = {}

for i, layer in enumerate(sorted(assigned_layers or [])):
self.layers.append(Olmo3DecoderLayer(config, layer_idx=layer))
self.abs_to_local[layer] = i

self._converted_to_quantized = False
self._cached_mask_state: Optional[int] = None
self._cached_mask = None

def embed(self, x: mx.array) -> mx.array:
if hasattr(self, "embed_tokens"):
return self.embed_tokens(x)
return x

def normalize(self, x: mx.array) -> mx.array:
if hasattr(self, "norm"):
return self.norm(x)
return x

def lm_project(self, x: mx.array) -> mx.array:
if hasattr(self, "lm_head") or hasattr(self, "embed_tokens"):
            # With tied embeddings (or when no lm_head was created) the input
            # embedding matrix doubles as the output projection.
            use_tied = bool(getattr(self.config, "tie_word_embeddings", False))
if use_tied or not hasattr(self, "lm_head"):
return self.embed_tokens.as_linear(x)
return self.lm_head(x)
return x

def forward(self, x: mx.array, cache: Optional[List[Any]] = None) -> mx.array:
mask = create_attention_mask(x, cache)

if cache is None:
cache = [None] * len(self.layers)

for i, layer in enumerate(self.layers):
x = layer(x, mask, cache[i] if i < len(cache) else None)

return x

def apply_single_layer(
self, layer_idx: int, x: mx.array, cache: Optional[List[Any]] = None
) -> mx.array:
if layer_idx not in self.abs_to_local:
raise RuntimeError(f"Layer {layer_idx} not hosted on this model instance")
try:
T = int(x.shape[1]) if len(x.shape) > 1 else 1
except Exception:
T = 1
        # A causal mask is only needed during prefill (T > 1); it is cached per
        # sequence length so it is built once per token batch rather than once
        # per hosted layer. Decode steps (T == 1) leave mask as None.
        mask = None
if T > 1:
if self._cached_mask_state is None or self._cached_mask_state != T:
mask = create_attention_mask(x, cache)
self._cached_mask = mask
self._cached_mask_state = T
else:
mask = self._cached_mask
if mask is None:
mask = create_attention_mask(x, cache)
self._cached_mask = mask
self._cached_mask_state = T
local_idx = self.abs_to_local[layer_idx]

c = None
if cache is not None and local_idx < len(cache):
c = cache[local_idx]

return self.layers[local_idx](x, mask, c)

@property
def decoding_layers(self):
return self.layers

@property
def head_dim(self) -> Tuple[int, int]:
head_dim = (
self.config.head_dim
or self.config.hidden_size // self.config.num_attention_heads
)
return (head_dim, head_dim)

@property
def n_kv_heads(self) -> int:
return self.config.num_key_value_heads

@property
def num_layers(self) -> int:
return len(self.layers)