
Commit 0277d24

add in name remapping for more model support
1 parent 445c0f2 commit 0277d24

File tree

7 files changed (+58 / -26 lines)
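For context, the remapping lets checkpoints whose GGUF metadata uses non-llama key prefixes be loaded through the existing llama model definition. A minimal usage sketch; the GgufFile.from_path constructor and the file name are assumptions for illustration, not part of this commit:

    from gadget.loader import GgufFile
    from gadget.models.llama import LlamaModel
    from gadget.models.names import NAMES

    # hypothetical qwen3 checkpoint; NAMES translates its metadata keys to llama ones
    gguf = GgufFile.from_path('qwen3-embedding.gguf')  # assumed loader constructor
    model = LlamaModel.from_gguf(gguf, names=NAMES['Qwen3ForCausalLM'])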

gadget/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 from . import models
 from .models.bert import BertModel
 from .models.llama import LlamaModel
+from .models.names import NAMES

 from .ggml import GGMLQuantizationType as T
 from .tensor import get_tensor_info

gadget/model.py

Lines changed: 23 additions & 15 deletions
@@ -6,7 +6,7 @@
 from collections import defaultdict

 from .ggml import GGMLQuantizationType
-from .utils import AttrDict
+from .utils import AttrDict, IdentDict
 from .loader import GgufFile
 from .compute import GgmlCompute

@@ -38,11 +38,11 @@ def resolve_field(key, *dicts):
     else:
         return key

-def eval_parameter(expr, gguf):
+def eval_parameter(expr, fields, tensors):
     if type(expr) is str:
-        return gguf.get_field(expr)
+        return fields[expr]
     elif callable(expr):
-        return expr(gguf)
+        return expr(fields, tensors)
     return expr

 ##
@@ -92,42 +92,50 @@ def from_values(cls, values=None, backend=None, framework=None, **params):
         return self

     @classmethod
-    def from_gguf(cls, gguf, backend=None, framework=None, **params):
-        # get metadata from gguf
-        weights = {
+    def from_gguf(cls, gguf, names=None, backend=None, framework=None, **params):
+        # make name mappers
+        names = IdentDict({} if names is None else names)
+        rnames = IdentDict({v: k for k, v in names.items()})
+
+        # map field and tensor names
+        fields0 = {names[k]: v for k, v in gguf.fields.items()}
+        weights0 = {names[k]: v for k, v in gguf.tensors.items()}
+
+        # get weights metadata
+        weights0_meta = {
             key: (ttype, shape)
-            for key, (ttype, shape, array) in gguf.tensors.items()
+            for key, (ttype, shape, array) in weights0.items()
         }

         # get type hints for model
         hints = get_type_hints(cls)

         # get default parameters
         params0 = {
-            k: eval_parameter(v.field, gguf)
+            k: eval_parameter(v.field, fields0, weights0_meta)
             for k, v in hints.items() if type(v) is Parameter
         }

         # get state fields
         states = {
-            k: eval_parameter(v.field, gguf)
+            k: eval_parameter(v.field, fields0, weights0_meta)
             for k, v in hints.items() if type(v) is State
         }

-        # resolve tensor shapes
-        tensors = {
-            k: (t.ttype, [resolve_field(x, params, params0, gguf.fields) for x in t.shape])
+        # resolve input shapes
+        inputs_meta = {
+            k: (t.ttype, [resolve_field(x, params, params0, fields0) for x in t.shape])
             for k, t in hints.items() if type(t) is Tensor
         }

         # create model and graph
         self = cls(
-            gguf.fields | params0 | params, weights | tensors,
+            fields0 | params0 | params, weights0_meta | inputs_meta,
             states, backend=backend, framework=framework
         )

         # assign tensors on backend
-        for name, (ttype, shape, tensor) in gguf.tensors.items():
+        for name, (ttype, shape, tensor) in weights0.items():
             self.set_input(name, tensor)

         # return model
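With the new signature, a parameter field can be a plain key looked up in the remapped fields dict, a callable over (fields, tensors), or a literal. A small sketch of the dispatch; the metadata names and values below are made up for illustration:

    from gadget.model import eval_parameter

    # illustrative remapped metadata (names, types, and values are made up)
    fields = {'llama.block_count': 32}
    tensors = {'token_embd.weight': ('F16', (4096, 32000))}

    eval_parameter('llama.block_count', fields, tensors)                        # 32, string -> field lookup
    eval_parameter(lambda f, t: t['token_embd.weight'][1][0], fields, tensors)  # 4096, callable -> computed
    eval_parameter(1e-5, fields, tensors)                                       # 1e-5, literal passthrough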

gadget/models/bert.py

Lines changed: 1 addition & 4 deletions
@@ -92,10 +92,7 @@ def forward(self):
             ]

             # get attention interactions
-            att = attention_layer(
-                ctx, cur, n_heads, mask, wq, wk, wv, wo, bq=bq, bk=bk, bv=bv, bo=bo,
-                eps=layer_norm_eps, name=f'attn{i}'
-            )
+            att = attention_layer(ctx, cur, n_heads, mask, wq, wk, wv, wo, bq=bq, bk=bk, bv=bv, bo=bo, name=f'attn{i}')

             # add attention output to current then normalize
             att = ggml_add_inplace(ctx, cur, att)

gadget/models/layers.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def rope_extended(

 def attention_layer(
     ctx, x, n_heads, mask, wq, wk, wv, wo, bq=None, bk=None, bv=None, bo=None, n_heads_kv=None,
-    rope_freqs=None, rope_base=None, eps=0.0, positions=None, alibi=0.0, kv_cache=None, name=None
+    rope_freqs=None, rope_base=None, positions=None, alibi=0.0, kv_cache=None, name=None
 ):
     # get n_heads_q and n_heads_kv
     n_heads_q = n_heads

gadget/models/llama.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
## llama model
2525
##
2626

27-
def get_head_dim_kv(gguf):
28-
n_head_kv = gguf.get_field('llama.attention.head_count_kv')
29-
embed_size_kv = gguf.get_tensor_shape('blk.0.attn_k.weight')[1]
27+
def get_head_dim_kv(fields, tensors):
28+
n_head_kv = fields['llama.attention.head_count_kv']
29+
_, (_, embed_size_kv) = tensors['blk.0.attn_k.weight']
3030
assert embed_size_kv % n_head_kv == 0
3131
return embed_size_kv // n_head_kv
3232

@@ -140,8 +140,7 @@ def forward(self):
140140
att = norm_layer(ctx, cur, wan, rms=True, eps=layer_norm_rms_eps, name=f'attn{i}_norm')
141141
att = attention_layer(
142142
ctx, att, n_heads_q, mask, wq, wk, wv, wo, positions=positions, n_heads_kv=n_heads_kv,
143-
rope_freqs=rope_freqs, rope_base=rope_base, eps=layer_norm_rms_eps, kv_cache=cache,
144-
name=f'attn{i}'
143+
rope_freqs=rope_freqs, rope_base=rope_base, kv_cache=cache, name=f'attn{i}'
145144
)
146145

147146
# add layer input to attention
@@ -156,7 +155,7 @@ def forward(self):
156155

157156
# get output tensors
158157
onw = self.tensors['output_norm.weight']
159-
ow = self.tensors.get('output.weight', etok)
158+
ow = self.tensors.get('output.weight', etok) # fall back to tied embeddings
160159

161160
# generate output
162161
cur = norm_layer(ctx, cur, onw, rms=True, eps=layer_norm_rms_eps, name='output_norm')
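In the new calling convention, get_head_dim_kv receives the remapped fields dict and the (ttype, shape) weight metadata rather than the GgufFile itself. A quick check with made-up values (the head count and shape here are illustrative only):

    from gadget.models.llama import get_head_dim_kv

    # illustrative metadata; shape and head count are made up for the example
    fields = {'llama.attention.head_count_kv': 8}
    tensors = {'blk.0.attn_k.weight': ('F16', (4096, 1024))}

    get_head_dim_kv(fields, tensors)   # 1024 // 8 == 128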

gadget/models/names.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+# model name mappings
+
+# llama3.1 is the default
+NAMES_LLAMA31 = {}
+
+# qwen3 is similar
+NAMES_QWEN3_EMBED = {
+    'qwen3.context_length'                  : 'llama.context_length',
+    'qwen3.block_count'                     : 'llama.block_count',
+    'qwen3.attention.head_count'            : 'llama.attention.head_count',
+    'qwen3.attention.head_count_kv'         : 'llama.attention.head_count_kv',
+    'qwen3.rope.freq_base'                  : 'llama.rope.freq_base',
+    'qwen3.attention.layer_norm_rms_epsilon': 'llama.attention.layer_norm_rms_epsilon',
+}
+
+# final name map
+NAMES = {
+    'LlamaForCausalLM': NAMES_LLAMA31,
+    'Qwen3ForCausalLM': NAMES_QWEN3_EMBED,
+}
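Each entry maps an architecture string to a dict that translates that model's GGUF metadata keys to the llama keys the model definition expects; when wrapped in IdentDict, unmapped keys pass through unchanged. A small sketch with made-up field values:

    from gadget.models.names import NAMES
    from gadget.utils import IdentDict

    names = IdentDict(NAMES['Qwen3ForCausalLM'])
    fields = {'qwen3.block_count': 36, 'general.architecture': 'qwen3'}  # illustrative values
    {names[k]: v for k, v in fields.items()}
    # {'llama.block_count': 36, 'general.architecture': 'qwen3'}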

gadget/utils.py

Lines changed: 7 additions & 0 deletions
@@ -56,6 +56,13 @@ def subset(self, keys):
         # get subset dict
         return {k: self[k] for k in keys}

+# dictionary that yields value=key when key is not found
+class IdentDict(UserDict):
+    def __getitem__(self, key):
+        if key not in self:
+            return key
+        return super().__getitem__(key)
+
 # = defaultdict(list)
 # + handles popping off maximal list
 # + handles deletion on empty list
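IdentDict is what makes the name maps optional and partial: any key without an explicit mapping resolves to itself. For example:

    from gadget.utils import IdentDict

    d = IdentDict({'qwen3.rope.freq_base': 'llama.rope.freq_base'})
    d['qwen3.rope.freq_base']   # 'llama.rope.freq_base'  (mapped)
    d['output.weight']          # 'output.weight'         (unmapped key maps to itself)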
