@@ -3,6 +3,7 @@
 from dataclasses import dataclass
 from typing import Optional, Any
 import math
+import logging
 
 from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.model_management
@@ -28,6 +29,9 @@ class Llama2Config:
     mlp_activation = "silu"
     qkv_bias = False
     rope_dims = None
+    q_norm = None
+    k_norm = None
+    rope_scale = None
 
 @dataclass
 class Qwen25_3BConfig:
@@ -46,6 +50,9 @@ class Qwen25_3BConfig:
     mlp_activation = "silu"
     qkv_bias = True
     rope_dims = None
+    q_norm = None
+    k_norm = None
+    rope_scale = None
 
 @dataclass
 class Qwen25_7BVLI_Config:
@@ -64,6 +71,9 @@ class Qwen25_7BVLI_Config:
     mlp_activation = "silu"
     qkv_bias = True
     rope_dims = [16, 24, 24]
+    q_norm = None
+    k_norm = None
+    rope_scale = None
 
 @dataclass
 class Gemma2_2B_Config:
@@ -82,6 +92,32 @@ class Gemma2_2B_Config:
     mlp_activation = "gelu_pytorch_tanh"
     qkv_bias = False
     rope_dims = None
+    q_norm = None
+    k_norm = None
+    sliding_attention = None
+    rope_scale = None
+
+@dataclass
+class Gemma3_4B_Config:
+    vocab_size: int = 262208
+    hidden_size: int = 2560
+    intermediate_size: int = 10240
+    num_hidden_layers: int = 34
+    num_attention_heads: int = 8
+    num_key_value_heads: int = 4
+    max_position_embeddings: int = 131072
+    rms_norm_eps: float = 1e-6
+    rope_theta = [10000.0, 1000000.0]
+    transformer_type: str = "gemma3"
+    head_dim = 256
+    rms_norm_add = True
+    mlp_activation = "gelu_pytorch_tanh"
+    qkv_bias = False
+    rope_dims = None
+    q_norm = "gemma3"
+    k_norm = "gemma3"
+    sliding_attention = [False, False, False, False, False, 1024]
+    rope_scale = [1.0, 8.0]
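+    # Reading of the three list fields above (an interpretation, not upstream docs):
+    # rope_theta and rope_scale appear to pair per entry, giving two RoPE variants
+    # (10000.0 unscaled and 1000000.0 down-scaled by 8.0), while sliding_attention
+    # cycles per layer, False meaning full attention and an int a window in tokens.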
 
 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -106,25 +142,40 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
-    theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
-    inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
+def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
+    if not isinstance(theta, list):
+        theta = [theta]
+
+    out = []
+    for index, t in enumerate(theta):
+        theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
+        inv_freq = 1.0 / (t ** (theta_numerator / head_dim))
+
+        if rope_scale is not None:
+            if isinstance(rope_scale, list):
+                inv_freq /= rope_scale[index]
+            else:
+                inv_freq /= rope_scale
+
+        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+        emb = torch.cat((freqs, freqs), dim=-1)
+        cos = emb.cos()
+        sin = emb.sin()
+        if rope_dims is not None and position_ids.shape[0] > 1:
+            mrope_section = rope_dims * 2
+            cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
+            sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
+        else:
+            cos = cos.unsqueeze(1)
+            sin = sin.unsqueeze(1)
+        out.append((cos, sin))
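+        # one (cos, sin) pair is collected per theta; with a single theta the
+        # function still returns a bare tuple below, preserving the old call sites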
 
-    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-    position_ids_expanded = position_ids[:, None, :].float()
-    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-    emb = torch.cat((freqs, freqs), dim=-1)
-    cos = emb.cos()
-    sin = emb.sin()
-    if rope_dims is not None and position_ids.shape[0] > 1:
-        mrope_section = rope_dims * 2
-        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
-        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
-    else:
-        cos = cos.unsqueeze(1)
-        sin = sin.unsqueeze(1)
+    if len(out) == 1:
+        return out[0]
 
-    return (cos, sin)
+    return out
 
 
 def apply_rope(xq, xk, freqs_cis):
@@ -152,6 +203,14 @@ def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
         self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
         self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
 
+        self.q_norm = None
+        self.k_norm = None
+
+        if config.q_norm == "gemma3":
+            self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+        if config.k_norm == "gemma3":
+            self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
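+        # per-head RMSNorm over head_dim, applied to q/k in forward() before RoPE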
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -168,6 +227,11 @@ def forward(
         xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
         xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
 
+        if self.q_norm is not None:
+            xq = self.q_norm(xq)
+        if self.k_norm is not None:
+            xk = self.k_norm(xk)
+
         xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
 
         xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
@@ -192,7 +256,7 @@ def forward(self, x):
         return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
 
 class TransformerBlock(nn.Module):
-    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
+    def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
         super().__init__()
         self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
         self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@@ -226,7 +290,7 @@ def forward(
         return x
 
 class TransformerBlockGemma2(nn.Module):
-    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
+    def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
         super().__init__()
         self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
         self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@@ -235,13 +299,28 @@ def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
         self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
         self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
 
+        if config.sliding_attention is not None:  # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
+            self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
+        else:
+            self.sliding_attention = False
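+        # index walks the per-layer pattern cyclically; for gemma3 that is five
+        # full-attention layers (False) followed by one 1024-token window entry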
+
+        self.transformer_type = config.transformer_type
+
     def forward(
         self,
         x: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[torch.Tensor] = None,
         optimized_attention=None,
    ):
+        if self.transformer_type == 'gemma3':
+            if self.sliding_attention:
+                if x.shape[1] > self.sliding_attention:
+                    logging.warning("sliding attention not implemented, results may be incorrect")
+                freqs_cis = freqs_cis[1]
+            else:
+                freqs_cis = freqs_cis[0]
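+            # freqs_cis comes in as a two-entry list for gemma3; windowed layers
+            # take the second entry (theta 1000000.0, scale 8.0) and full-attention
+            # layers the first (theta 10000.0), matching the config lists above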
+
         # Self Attention
         residual = x
         x = self.input_layernorm(x)
@@ -276,16 +355,16 @@ def __init__(self, config, device=None, dtype=None, ops=None):
             device=device,
             dtype=dtype
         )
-        if self.config.transformer_type == "gemma2":
+        if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
             transformer = TransformerBlockGemma2
             self.normalize_in = True
         else:
             transformer = TransformerBlock
             self.normalize_in = False
 
         self.layers = nn.ModuleList([
-            transformer(config, device=device, dtype=dtype, ops=ops)
-            for _ in range(config.num_hidden_layers)
+            transformer(config, index=i, device=device, dtype=dtype, ops=ops)
+            for i in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
         # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
@@ -305,6 +384,7 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
         freqs_cis = precompute_freqs_cis(self.config.head_dim,
                                          position_ids,
                                          self.config.rope_theta,
+                                         self.config.rope_scale,
                                          self.config.rope_dims,
                                          device=x.device)
 
@@ -433,3 +513,12 @@ def __init__(self, config_dict, dtype, device, operations):
 
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
+
+class Gemma3_4B(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Gemma3_4B_Config(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype