# model.py
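"""Minimal decoder-only transformer in Flax NNX, with optional FSDP weight sharding."""
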
import jax
import jax.numpy as jnp
from flax import nnx
from jax.sharding import Mesh
from omegaconf.dictconfig import DictConfig


class TransformerDecoder(nnx.Module):
    """Decoder-only transformer language model with learned positional embeddings."""

    def __init__(self, cfg: DictConfig, rngs: nnx.Rngs):
        self.embed = nnx.Embed(num_embeddings=cfg.V, features=cfg.D, embedding_init=fsdp_init('embedding', cfg.fsdp_enabled), rngs=rngs)
        self.pos_embed = nnx.Embed(num_embeddings=cfg.L, features=cfg.D, embedding_init=fsdp_init('embedding', cfg.fsdp_enabled), rngs=rngs)
        self.blocks = [TransformerBlock(cfg, rngs) for _ in range(cfg.N)]
        self.out_ln = nnx.LayerNorm(cfg.D, use_bias=False, dtype=cfg.dtype, rngs=rngs)

    def __call__(self, x):  # [B, S]
        # Token + learned positional embedding
        h = self.embed(x) + self.pos_embed(jnp.arange(x.shape[1])[None, ...])  # [B, S, D]
        # Transformer blocks
        for block in self.blocks:
            h = block(h)
        # Final layernorm, then project back to the vocabulary by multiplying with
        # the transposed token-embedding matrix (weight tying), in float32 for
        # numerically stable logits.
        h = self.out_ln(h)
        return self.embed.attend(h.astype(jnp.float32))  # [B, S, V]


class TransformerBlock(nnx.Module):
    """Pre-layernorm transformer block: causal self-attention followed by an MLP."""

    def __init__(self, cfg: DictConfig, rngs: nnx.Rngs):
        self.ln1 = nnx.LayerNorm(cfg.D, use_bias=False, dtype=cfg.dtype, rngs=rngs)
        self.attn = nnx.MultiHeadAttention(
            num_heads=cfg.H, in_features=cfg.D, qkv_features=cfg.D, out_features=cfg.D,
            kernel_init=fsdp_init('attn_in_proj', cfg.fsdp_enabled),
            out_kernel_init=fsdp_init('attn_out_proj', cfg.fsdp_enabled),
            use_bias=False, dtype=cfg.dtype, rngs=rngs, decode=False,
        )
        self.ln2 = nnx.LayerNorm(cfg.D, use_bias=False, dtype=cfg.dtype, rngs=rngs)
        self.mlp = Mlp(cfg, rngs)

    def __call__(self, x):  # [B, S, D]
        # Pre-layernorm attention block with residual connection; the causal
        # mask has shape [B, 1, S, S]
        h = self.ln1(x)
        mask = nnx.make_causal_mask(jnp.ones((x.shape[0], x.shape[1])), dtype=x.dtype)
        x = x + self.attn(h, mask=mask)
        # Pre-layernorm MLP block with residual connection
        return x + self.mlp(self.ln2(x))


class Mlp(nnx.Module):
    """Two-layer multilayer perceptron with a GELU nonlinearity."""

    def __init__(self, cfg: DictConfig, rngs: nnx.Rngs):
        kernel_init = fsdp_init('mlp_kernel', cfg.fsdp_enabled)
        self.fc1 = nnx.Linear(in_features=cfg.D, out_features=cfg.F, use_bias=False, kernel_init=kernel_init, dtype=cfg.dtype, rngs=rngs)
        self.fc2 = nnx.Linear(in_features=cfg.F, out_features=cfg.D, use_bias=False, kernel_init=kernel_init, dtype=cfg.dtype, rngs=rngs)

    def __call__(self, x):  # [B, S, D]
        h = jax.nn.gelu(self.fc1(x))  # [B, S, F]
        return self.fc2(h)  # [B, S, D]


def fsdp_init(layer_type: str, fsdp_enabled: bool):
    """Build a weight initializer, with FSDP partitioning annotations when enabled."""
    partition_fn = nnx.with_partitioning if fsdp_enabled else lambda x, _: x
    kernel_init = jax.nn.initializers.xavier_uniform()
    embed_init = jax.nn.initializers.variance_scaling(1.0, 'fan_in', 'normal', out_axis=0)
    match layer_type:
        case "embedding":  # [V, D]
            return partition_fn(embed_init, (None, "data"))
        case "attn_in_proj":  # [D, H, D/H]
            return partition_fn(kernel_init, ("data", None, None))
        case "attn_out_proj":  # [H, D/H, D]
            return partition_fn(kernel_init, (None, None, "data"))
        case "mlp_kernel":  # [D, F]
            return partition_fn(kernel_init, ("data", None))
        case _:
            raise ValueError(f"unrecognized layer type: {layer_type}")
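

# Note on the sharding layout above (an informal summary, not from the original
# author): each annotation shards exactly one axis of the weight on the mesh axis
# named "data", and in every case it is the model-width (D) axis -- the embedding's
# feature axis, the attention input projection's first axis, the attention output
# projection's last axis, and the MLP kernel's input axis. Sharding parameters
# across the data-parallel devices like this is the fully-sharded data-parallel
# (FSDP) pattern the fsdp_enabled flag refers to.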


def create_sharded_model(cfg: DictConfig, mesh: Mesh, seed: int):
    """Initialize a sharded model without first materializing it on a single device.

    See https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html.
    TODO: add rng key
    """

    @nnx.jit
    def initialize_sharded_model():
        model = TransformerDecoder(cfg, rngs=nnx.Rngs(seed))  # unsharded at this moment
        state = nnx.state(model)  # the model's state, a pure pytree
        pspecs = nnx.get_partition_spec(state)  # get the partitioning annotations from the state
        sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
        nnx.update(model, sharded_state)  # the model is sharded now
        return model

    with mesh:
        model = initialize_sharded_model()
    return model
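

# Example usage (a minimal sketch, not part of the original module): the config
# values below are illustrative assumptions, and SimpleNamespace stands in for
# the OmegaConf DictConfig, whose annotation is not enforced at runtime. On a
# single host this builds a 1-D "data" mesh over all local devices, initializes
# the model already sharded, and runs one forward pass.
if __name__ == "__main__":
    from types import SimpleNamespace

    import numpy as np

    cfg = SimpleNamespace(
        V=256,   # vocabulary size
        L=128,   # maximum sequence length
        D=64,    # model width
        F=256,   # MLP hidden width
        H=4,     # attention heads
        N=2,     # transformer blocks
        dtype=jnp.float32,
        fsdp_enabled=True,
    )
    mesh = Mesh(np.array(jax.devices()), ("data",))  # 1-D mesh over all devices
    model = create_sharded_model(cfg, mesh, seed=0)

    tokens = jnp.zeros((2, 16), dtype=jnp.int32)  # [B, S]
    with mesh:
        logits = model(tokens)  # [B, S, V]
    print(logits.shape)  # (2, 16, 256)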