
Commit 4fad6eb

chapman20j authored and Aatman09 committed
vit explicit rngs in call (jax-ml#82)
1 parent 7a08245 commit 4fad6eb

File tree

4 files changed (+46 -63 lines)


bonsai/models/vit/modeling.py

Lines changed: 21 additions & 14 deletions
@@ -11,6 +11,7 @@ class ModelConfig:
     patch_size: tuple[int, int]
     num_channels: int
     hidden_dim: int
+    attn_dropout_prob: float
     dropout_prob: float
     num_heads: int
     mlp_dim: int
@@ -25,6 +26,7 @@ def vit_p16_224(cls):
             patch_size=(16, 16),
             num_channels=3,
             hidden_dim=768,
+            attn_dropout_prob=0.0,
             dropout_prob=0.0,
             num_heads=12,
             mlp_dim=3072,
@@ -66,9 +68,9 @@ def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
         )
         self.cls_token = nnx.Variable(jax.random.normal(rngs.params(), (1, 1, cfg.hidden_dim)))
         self.pos_embeddings = nnx.Variable(jax.random.normal(rngs.params(), (1, num_patches + 1, cfg.hidden_dim)))
-        self.dropout = nnx.Dropout(cfg.dropout_prob, rngs=rngs)
+        self.dropout = nnx.Dropout(cfg.dropout_prob)

-    def __call__(self, pixel_values: jnp.ndarray) -> jnp.ndarray:
+    def __call__(self, pixel_values: jnp.ndarray, *, rngs: nnx.Rngs | None) -> jnp.ndarray:
         embeddings = self.projection(pixel_values)
         b, h, w, c = embeddings.shape
         embeddings = embeddings.reshape(b, h * w, c)
@@ -89,49 +91,54 @@ def __call__(self, pixel_values: jnp.ndarray) -> jnp.ndarray:
         embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)

         embeddings = embeddings + current_pos_embeddings
-        embeddings = self.dropout(embeddings)
+        embeddings = self.dropout(embeddings, rngs=rngs)
         return embeddings


 class TransformerEncoder(nnx.Module):
     def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
         self.attention = nnx.MultiHeadAttention(
-            num_heads=cfg.num_heads, in_features=cfg.hidden_dim, decode=False, rngs=rngs
+            num_heads=cfg.num_heads,
+            in_features=cfg.hidden_dim,
+            dropout_rate=cfg.attn_dropout_prob,
+            decode=False,
+            rngs=rngs,
         )
         self.linear1 = nnx.Linear(cfg.hidden_dim, cfg.mlp_dim, rngs=rngs)
         self.linear2 = nnx.Linear(cfg.mlp_dim, cfg.hidden_dim, rngs=rngs)
-        self.dropout = nnx.Dropout(cfg.dropout_prob, rngs=rngs)
+        self.dropout = nnx.Dropout(cfg.dropout_prob)
         self.layernorm_before = nnx.LayerNorm(cfg.hidden_dim, epsilon=cfg.eps, rngs=rngs)
         self.layernorm_after = nnx.LayerNorm(cfg.hidden_dim, epsilon=cfg.eps, rngs=rngs)

-    def __call__(self, hidden_states, head_mask=None):
+    def __call__(self, hidden_states, head_mask=None, *, rngs: nnx.Rngs | None):
         hidden_states_norm = self.layernorm_before(hidden_states)
-        attention_output = self.attention(hidden_states_norm, head_mask)
+        attention_output = self.attention(hidden_states_norm, head_mask, rngs=rngs)
         hidden_states = attention_output + hidden_states
         layer_output = self.layernorm_after(hidden_states)
         layer_output = jax.nn.gelu(self.linear1(layer_output))
         layer_output = self.linear2(layer_output)
-        layer_output = self.dropout(layer_output)
+        layer_output = self.dropout(layer_output, rngs=rngs)
         layer_output += hidden_states
         return layer_output


 class ViTClassificationModel(nnx.Module):
     def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
         self.pos_embeddings = Embeddings(cfg, rngs=rngs)
-        self.layers = nnx.Sequential(*[TransformerEncoder(cfg, rngs=rngs) for _ in range(cfg.num_layers)])
+        self.layers = nnx.List([TransformerEncoder(cfg, rngs=rngs) for _ in range(cfg.num_layers)])
         self.ln = nnx.LayerNorm(cfg.hidden_dim, epsilon=cfg.eps, rngs=rngs)
         self.classifier = nnx.Linear(cfg.hidden_dim, cfg.num_labels, rngs=rngs)

-    def __call__(self, x):
-        x = self.pos_embeddings(x)
-        x = self.layers(x)
+    def __call__(self, x, *, rngs: nnx.Rngs | None):
+        x = self.pos_embeddings(x, rngs=rngs)
+        for layer in self.layers:
+            x = layer(x, rngs=rngs)
         x = self.ln(x)
         x = self.classifier(x[:, 0, :])
         return x


 @jax.jit
-def forward(graphdef: nnx.GraphDef[nnx.Module], state: nnx.State, x: jax.Array) -> jax.Array:
+def forward(graphdef: nnx.GraphDef[nnx.Module], state: nnx.State, x: jax.Array, rngs: nnx.Rngs) -> jax.Array:
     model = nnx.merge(graphdef, state)
-    return model(x)
+    return model(x, rngs=rngs)

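With this commit, the dropout layers no longer hold RNG state from construction; the caller threads an nnx.Rngs through every __call__, or passes None when no randomness is needed. A minimal usage sketch, not part of the commit — the import path, the `dropout` stream name, and the assumption that vit_p16_224() also fills in num_layers, num_labels, and eps are illustrative:

import jax
import jax.numpy as jnp
from flax import nnx

from bonsai.models.vit import modeling as model_lib  # import path assumed from the file layout

cfg = model_lib.ModelConfig.vit_p16_224()
model = model_lib.ViTClassificationModel(cfg, rngs=nnx.Rngs(params=0))
x = jnp.ones((1, 224, 224, 3), dtype=jnp.float32)

# Deterministic call: rngs=None is fine here because both dropout rates in this
# config are 0.0, so nnx.Dropout returns its input without drawing a key.
logits = model(x, rngs=None)

# Stochastic call: with nonzero dropout rates, supply an explicit dropout stream.
logits = model(x, rngs=nnx.Rngs(dropout=jax.random.key(0)))
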
bonsai/models/vit/params.py

Lines changed: 17 additions & 41 deletions
@@ -16,7 +16,6 @@
 import re
 from enum import Enum

-import jax
 import jax.numpy as jnp
 import safetensors.flax as safetensors
 from etils import epath
@@ -50,60 +49,42 @@ class Transform(Enum):
     r"^vit.embeddings.patch_embeddings.projection.weight$": (r"pos_embeddings.projection.kernel", Transform.CONV2D),
     r"^vit.embeddings.position_embeddings$": (r"pos_embeddings.pos_embeddings", Transform.EMBED),
     r"^vit.encoder.layer.([0-9]+).attention.attention.key.bias$": (
-        r"layers.layers.\1.attention.key.bias",
+        r"layers.\1.attention.key.bias",
         Transform.ATTN_KQV_BIAS,
     ),
     r"^vit.encoder.layer.([0-9]+).attention.attention.key.weight$": (
-        r"layers.layers.\1.attention.key.kernel",
+        r"layers.\1.attention.key.kernel",
         Transform.ATTN_KQV_KERNEL,
     ),
     r"^vit.encoder.layer.([0-9]+).attention.attention.query.bias$": (
-        r"layers.layers.\1.attention.query.bias",
+        r"layers.\1.attention.query.bias",
         Transform.ATTN_KQV_BIAS,
     ),
     r"^vit.encoder.layer.([0-9]+).attention.attention.query.weight$": (
-        r"layers.layers.\1.attention.query.kernel",
+        r"layers.\1.attention.query.kernel",
         Transform.ATTN_KQV_KERNEL,
     ),
     r"^vit.encoder.layer.([0-9]+).attention.attention.value.bias$": (
-        r"layers.layers.\1.attention.value.bias",
+        r"layers.\1.attention.value.bias",
         Transform.ATTN_KQV_BIAS,
     ),
     r"^vit.encoder.layer.([0-9]+).attention.attention.value.weight$": (
-        r"layers.layers.\1.attention.value.kernel",
+        r"layers.\1.attention.value.kernel",
         Transform.ATTN_KQV_KERNEL,
     ),
-    r"^vit.encoder.layer.([0-9]+).attention.output.dense.bias$": (
-        r"layers.layers.\1.attention.out.bias",
-        Transform.BIAS,
-    ),
+    r"^vit.encoder.layer.([0-9]+).attention.output.dense.bias$": (r"layers.\1.attention.out.bias", Transform.BIAS),
     r"^vit.encoder.layer.([0-9]+).attention.output.dense.weight$": (
-        r"layers.layers.\1.attention.out.kernel",
+        r"layers.\1.attention.out.kernel",
         Transform.ATTN_OUT,
     ),
-    r"^vit.encoder.layer.([0-9]+).intermediate.dense.bias$": (r"layers.layers.\1.linear1.bias", Transform.BIAS),
-    r"^vit.encoder.layer.([0-9]+).intermediate.dense.weight$": (
-        r"layers.layers.\1.linear1.kernel",
-        Transform.LINEAR,
-    ),
-    r"^vit.encoder.layer.([0-9]+).layernorm_after.bias$": (
-        r"layers.layers.\1.layernorm_after.bias",
-        Transform.BIAS,
-    ),
-    r"^vit.encoder.layer.([0-9]+).layernorm_after.weight$": (
-        r"layers.layers.\1.layernorm_after.scale",
-        Transform.SCALE,
-    ),
-    r"^vit.encoder.layer.([0-9]+).layernorm_before.bias$": (
-        r"layers.layers.\1.layernorm_before.bias",
-        Transform.BIAS,
-    ),
-    r"^vit.encoder.layer.([0-9]+).layernorm_before.weight$": (
-        r"layers.layers.\1.layernorm_before.scale",
-        Transform.SCALE,
-    ),
-    r"^vit.encoder.layer.([0-9]+).output.dense.bias$": (r"layers.layers.\1.linear2.bias", Transform.BIAS),
-    r"^vit.encoder.layer.([0-9]+).output.dense.weight$": (r"layers.layers.\1.linear2.kernel", Transform.LINEAR),
+    r"^vit.encoder.layer.([0-9]+).intermediate.dense.bias$": (r"layers.\1.linear1.bias", Transform.BIAS),
+    r"^vit.encoder.layer.([0-9]+).intermediate.dense.weight$": (r"layers.\1.linear1.kernel", Transform.LINEAR),
+    r"^vit.encoder.layer.([0-9]+).layernorm_after.bias$": (r"layers.\1.layernorm_after.bias", Transform.BIAS),
+    r"^vit.encoder.layer.([0-9]+).layernorm_after.weight$": (r"layers.\1.layernorm_after.scale", Transform.SCALE),
+    r"^vit.encoder.layer.([0-9]+).layernorm_before.bias$": (r"layers.\1.layernorm_before.bias", Transform.BIAS),
+    r"^vit.encoder.layer.([0-9]+).layernorm_before.weight$": (r"layers.\1.layernorm_before.scale", Transform.SCALE),
+    r"^vit.encoder.layer.([0-9]+).output.dense.bias$": (r"layers.\1.linear2.bias", Transform.BIAS),
+    r"^vit.encoder.layer.([0-9]+).output.dense.weight$": (r"layers.\1.linear2.kernel", Transform.LINEAR),
     r"^vit.layernorm.bias$": (r"ln.bias", Transform.BIAS),
     r"^vit.layernorm.weight$": (r"ln.scale", Transform.SCALE),
 }
@@ -149,12 +130,7 @@ def _stoi(s):
     return s


-def create_vit_from_pretrained(
-    file_dir: str,
-    config: model_lib.ModelConfig,
-    *,
-    mesh: jax.sharding.Mesh | None = None,
-):
+def create_vit_from_pretrained(file_dir: str, config: model_lib.ModelConfig):
     """
     Load safetensor weights from a file, then convert & merge into a flax.nnx ViT model.

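The substantive change in this file is that every target path drops one level of nesting (layers.\1… instead of layers.layers.\1…), matching the switch from nnx.Sequential to a plain list of encoder blocks in modeling.py; the now-unused jax import and mesh parameter go with it. A short sketch of how one of these regex rules rewrites a checkpoint key — the standalone snippet is illustrative, not the module's actual loader:

import re

# One entry from the table above: HF safetensors key pattern -> bonsai nnx state path.
pattern = r"^vit.encoder.layer.([0-9]+).attention.attention.query.weight$"
target = r"layers.\1.attention.query.kernel"

hf_key = "vit.encoder.layer.3.attention.attention.query.weight"
print(re.sub(pattern, target, hf_key))  # layers.3.attention.query.kernel
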
bonsai/models/vit/tests/run_model.py

Lines changed: 3 additions & 3 deletions
@@ -39,19 +39,19 @@ def run_model():
     dummy_input = jnp.ones((batch_size, image_size, image_size, channels), dtype=jnp.float32)

     # Warmup (triggers compilation)
-    _ = model_lib.forward(graphdef, flat_state, dummy_input).block_until_ready()
+    _ = model_lib.forward(graphdef, flat_state, dummy_input, None).block_until_ready()

     # Profile a few steps
     jax.profiler.start_trace("/tmp/profile-vit")
     for _ in range(5):
-        logits = model_lib.forward(graphdef, flat_state, dummy_input)
+        logits = model_lib.forward(graphdef, flat_state, dummy_input, None)
         jax.block_until_ready(logits)
     jax.profiler.stop_trace()

     # Timed execution
     t0 = time.perf_counter()
     for _ in range(10):
-        logits = model_lib.forward(graphdef, flat_state, dummy_input).block_until_ready()
+        logits = model_lib.forward(graphdef, flat_state, dummy_input, None).block_until_ready()
     print(f"Step time: {(time.perf_counter() - t0) / 10:.4f} s")

     # Show top-1 predicted class

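run_model only exercises inference, so it passes None for the new rngs argument; with dropout_prob=0.0 in the pretrained config the dropout layers are no-ops either way. A sketch of the setup these calls assume — the checkpoint path and the return value of create_vit_from_pretrained are assumptions, not taken from this diff:

import jax.numpy as jnp
from flax import nnx

from bonsai.models.vit import modeling as model_lib, params  # import paths assumed from the file layout

config = model_lib.ModelConfig.vit_p16_224()
# Assumed to return a fully-initialized nnx module; the path is a placeholder.
model = params.create_vit_from_pretrained("/path/to/vit-safetensors", config)
graphdef, flat_state = nnx.split(model)

dummy_input = jnp.ones((1, 224, 224, 3), dtype=jnp.float32)
logits = model_lib.forward(graphdef, flat_state, dummy_input, None)
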
bonsai/models/vit/tests/test_outputs_vit.py

Lines changed: 5 additions & 5 deletions
@@ -32,21 +32,21 @@ def test_embeddings(self):

         with torch.no_grad():
             ty = torch_emb(tx)
-            jy = nnx_emb(jx)
+            jy = nnx_emb(jx, rngs=None)

         torch.testing.assert_close(torch.tensor(jy), ty, rtol=1e-5, atol=1e-5)

     def test_first_layer(self):
         torch_layer = self.baseline_model.vit.encoder.layer[0]
-        nnx_layer = self.bonsai_model.layers.layers[0]
+        nnx_layer = self.bonsai_model.layers[0]

         hidden_shape = (self.batch_size, 197, 768)
         jx = jax.random.normal(jax.random.key(0), hidden_shape, dtype=jnp.float32)
         tx = torch.tensor(jx)

         with torch.no_grad():
             ty = torch_layer(tx)
-            jy = nnx_layer(jx)
+            jy = nnx_layer(jx, rngs=None)

         torch.testing.assert_close(torch.tensor(jy), ty, rtol=1e-5, atol=1e-2)

@@ -56,7 +56,7 @@ def test_full(self):

         with torch.no_grad():
             ty = self.baseline_model(tx).logits
-            jy = self.bonsai_model(jx)
+            jy = self.bonsai_model(jx, rngs=None)

         torch.testing.assert_close(torch.tensor(jy), ty, rtol=1e-5, atol=5e-2)

@@ -68,7 +68,7 @@ def test_full_interpolation(self):

         with torch.no_grad():
             ty = self.baseline_model(tx, interpolate_pos_encoding=True).logits
-            jy = self.bonsai_model(jx)
+            jy = self.bonsai_model(jx, rngs=None)

         torch.testing.assert_close(torch.tensor(jy), ty, rtol=1e-5, atol=1e-1)

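The parity tests now call the Flax side with rngs=None, which is safe because the reference configuration disables dropout. A standalone smoke test in the same spirit — randomly initialized weights, shape check only, not one of the repo's tests; import path and config fields assumed as above:

import jax
import jax.numpy as jnp
from flax import nnx

from bonsai.models.vit import modeling as model_lib


def test_call_signature_smoke():
    cfg = model_lib.ModelConfig.vit_p16_224()
    model = model_lib.ViTClassificationModel(cfg, rngs=nnx.Rngs(params=0))
    x = jax.random.normal(jax.random.key(0), (2, 224, 224, 3), dtype=jnp.float32)
    logits = model(x, rngs=None)
    assert logits.shape == (2, cfg.num_labels)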