diff --git a/bonsai/models/vit/modeling.py b/bonsai/models/vit/modeling.py
index b78aeeec..7e203d46 100644
--- a/bonsai/models/vit/modeling.py
+++ b/bonsai/models/vit/modeling.py
@@ -11,6 +11,7 @@ class ModelConfig:
     patch_size: tuple[int, int]
     num_channels: int
     hidden_dim: int
+    attn_dropout_prob: float
     dropout_prob: float
     num_heads: int
     mlp_dim: int
@@ -25,6 +26,7 @@ def vit_p16_224(cls):
             patch_size=(16, 16),
             num_channels=3,
             hidden_dim=768,
+            attn_dropout_prob=0.0,
             dropout_prob=0.0,
             num_heads=12,
             mlp_dim=3072,
@@ -66,9 +68,9 @@ def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
         )
         self.cls_token = nnx.Variable(jax.random.normal(rngs.params(), (1, 1, cfg.hidden_dim)))
         self.pos_embeddings = nnx.Variable(jax.random.normal(rngs.params(), (1, num_patches + 1, cfg.hidden_dim)))
-        self.dropout = nnx.Dropout(cfg.dropout_prob, rngs=rngs)
+        self.dropout = nnx.Dropout(cfg.dropout_prob)
 
-    def __call__(self, pixel_values: jnp.ndarray) -> jnp.ndarray:
+    def __call__(self, pixel_values: jnp.ndarray, *, rngs: nnx.Rngs | None) -> jnp.ndarray:
         embeddings = self.projection(pixel_values)
         b, h, w, c = embeddings.shape
         embeddings = embeddings.reshape(b, h * w, c)
@@ -89,29 +91,33 @@ def __call__(self, pixel_values: jnp.ndarray) -> jnp.ndarray:
         embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)
         embeddings = embeddings + current_pos_embeddings
 
-        embeddings = self.dropout(embeddings)
+        embeddings = self.dropout(embeddings, rngs=rngs)
         return embeddings
 
 
 class TransformerEncoder(nnx.Module):
     def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
         self.attention = nnx.MultiHeadAttention(
-            num_heads=cfg.num_heads, in_features=cfg.hidden_dim, decode=False, rngs=rngs
+            num_heads=cfg.num_heads,
+            in_features=cfg.hidden_dim,
+            dropout_rate=cfg.attn_dropout_prob,
+            decode=False,
+            rngs=rngs,
         )
         self.linear1 = nnx.Linear(cfg.hidden_dim, cfg.mlp_dim, rngs=rngs)
         self.linear2 = nnx.Linear(cfg.mlp_dim, cfg.hidden_dim, rngs=rngs)
-        self.dropout = nnx.Dropout(cfg.dropout_prob, rngs=rngs)
+        self.dropout = nnx.Dropout(cfg.dropout_prob)
         self.layernorm_before = nnx.LayerNorm(cfg.hidden_dim, epsilon=cfg.eps, rngs=rngs)
         self.layernorm_after = nnx.LayerNorm(cfg.hidden_dim, epsilon=cfg.eps, rngs=rngs)
 
-    def __call__(self, hidden_states, head_mask=None):
+    def __call__(self, hidden_states, head_mask=None, *, rngs: nnx.Rngs | None):
         hidden_states_norm = self.layernorm_before(hidden_states)
-        attention_output = self.attention(hidden_states_norm, head_mask)
+        attention_output = self.attention(hidden_states_norm, head_mask, rngs=rngs)
         hidden_states = attention_output + hidden_states
 
         layer_output = self.layernorm_after(hidden_states)
         layer_output = jax.nn.gelu(self.linear1(layer_output))
         layer_output = self.linear2(layer_output)
-        layer_output = self.dropout(layer_output)
+        layer_output = self.dropout(layer_output, rngs=rngs)
         layer_output += hidden_states
         return layer_output
@@ -119,19 +125,20 @@ def __call__(self, hidden_states, head_mask=None):
 class ViTClassificationModel(nnx.Module):
     def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
         self.pos_embeddings = Embeddings(cfg, rngs=rngs)
-        self.layers = nnx.Sequential(*[TransformerEncoder(cfg, rngs=rngs) for _ in range(cfg.num_layers)])
+        self.layers = nnx.List([TransformerEncoder(cfg, rngs=rngs) for _ in range(cfg.num_layers)])
         self.ln = nnx.LayerNorm(cfg.hidden_dim, epsilon=cfg.eps, rngs=rngs)
         self.classifier = nnx.Linear(cfg.hidden_dim, cfg.num_labels, rngs=rngs)
 
-    def __call__(self, x):
-        x = self.pos_embeddings(x)
-        x = self.layers(x)
+    def __call__(self, x, *, rngs: nnx.Rngs | None):
+        x = self.pos_embeddings(x, rngs=rngs)
+        for layer in self.layers:
+            x = layer(x, rngs=rngs)
         x = self.ln(x)
         x = self.classifier(x[:, 0, :])
         return x
 
 
 @jax.jit
-def forward(graphdef: nnx.GraphDef[nnx.Module], state: nnx.State, x: jax.Array) -> jax.Array:
+def forward(graphdef: nnx.GraphDef[nnx.Module], state: nnx.State, x: jax.Array, rngs: nnx.Rngs) -> jax.Array:
     model = nnx.merge(graphdef, state)
-    return model(x)
+    return model(x, rngs=rngs)
diff --git a/bonsai/models/vit/params.py b/bonsai/models/vit/params.py
index 214d9133..1742d98a 100644
--- a/bonsai/models/vit/params.py
+++ b/bonsai/models/vit/params.py
@@ -16,7 +16,6 @@
 import re
 from enum import Enum
 
-import jax
 import jax.numpy as jnp
 import safetensors.flax as safetensors
 from etils import epath
@@ -50,60 +49,42 @@ class Transform(Enum):
     r"^vit.embeddings.patch_embeddings.projection.weight$": (r"pos_embeddings.projection.kernel", Transform.CONV2D),
     r"^vit.embeddings.position_embeddings$": (r"pos_embeddings.pos_embeddings", Transform.EMBED),
     r"^vit.encoder.layer.([0-9]+).attention.attention.key.bias$": (
-        r"layers.layers.\1.attention.key.bias",
+        r"layers.\1.attention.key.bias",
         Transform.ATTN_KQV_BIAS,
     ),
     r"^vit.encoder.layer.([0-9]+).attention.attention.key.weight$": (
-        r"layers.layers.\1.attention.key.kernel",
+        r"layers.\1.attention.key.kernel",
         Transform.ATTN_KQV_KERNEL,
     ),
     r"^vit.encoder.layer.([0-9]+).attention.attention.query.bias$": (
-        r"layers.layers.\1.attention.query.bias",
+        r"layers.\1.attention.query.bias",
         Transform.ATTN_KQV_BIAS,
     ),
     r"^vit.encoder.layer.([0-9]+).attention.attention.query.weight$": (
-        r"layers.layers.\1.attention.query.kernel",
+        r"layers.\1.attention.query.kernel",
         Transform.ATTN_KQV_KERNEL,
     ),
     r"^vit.encoder.layer.([0-9]+).attention.attention.value.bias$": (
-        r"layers.layers.\1.attention.value.bias",
+        r"layers.\1.attention.value.bias",
         Transform.ATTN_KQV_BIAS,
     ),
     r"^vit.encoder.layer.([0-9]+).attention.attention.value.weight$": (
-        r"layers.layers.\1.attention.value.kernel",
+        r"layers.\1.attention.value.kernel",
         Transform.ATTN_KQV_KERNEL,
     ),
-    r"^vit.encoder.layer.([0-9]+).attention.output.dense.bias$": (
-        r"layers.layers.\1.attention.out.bias",
-        Transform.BIAS,
-    ),
+    r"^vit.encoder.layer.([0-9]+).attention.output.dense.bias$": (r"layers.\1.attention.out.bias", Transform.BIAS),
     r"^vit.encoder.layer.([0-9]+).attention.output.dense.weight$": (
-        r"layers.layers.\1.attention.out.kernel",
+        r"layers.\1.attention.out.kernel",
         Transform.ATTN_OUT,
     ),
-    r"^vit.encoder.layer.([0-9]+).intermediate.dense.bias$": (r"layers.layers.\1.linear1.bias", Transform.BIAS),
-    r"^vit.encoder.layer.([0-9]+).intermediate.dense.weight$": (
-        r"layers.layers.\1.linear1.kernel",
-        Transform.LINEAR,
-    ),
-    r"^vit.encoder.layer.([0-9]+).layernorm_after.bias$": (
-        r"layers.layers.\1.layernorm_after.bias",
-        Transform.BIAS,
-    ),
-    r"^vit.encoder.layer.([0-9]+).layernorm_after.weight$": (
-        r"layers.layers.\1.layernorm_after.scale",
-        Transform.SCALE,
-    ),
-    r"^vit.encoder.layer.([0-9]+).layernorm_before.bias$": (
-        r"layers.layers.\1.layernorm_before.bias",
-        Transform.BIAS,
-    ),
-    r"^vit.encoder.layer.([0-9]+).layernorm_before.weight$": (
-        r"layers.layers.\1.layernorm_before.scale",
-        Transform.SCALE,
-    ),
-    r"^vit.encoder.layer.([0-9]+).output.dense.bias$": (r"layers.layers.\1.linear2.bias", Transform.BIAS),
-    r"^vit.encoder.layer.([0-9]+).output.dense.weight$": (r"layers.layers.\1.linear2.kernel", Transform.LINEAR),
+    r"^vit.encoder.layer.([0-9]+).intermediate.dense.bias$": (r"layers.\1.linear1.bias", Transform.BIAS),
+    r"^vit.encoder.layer.([0-9]+).intermediate.dense.weight$": (r"layers.\1.linear1.kernel", Transform.LINEAR),
+    r"^vit.encoder.layer.([0-9]+).layernorm_after.bias$": (r"layers.\1.layernorm_after.bias", Transform.BIAS),
+    r"^vit.encoder.layer.([0-9]+).layernorm_after.weight$": (r"layers.\1.layernorm_after.scale", Transform.SCALE),
+    r"^vit.encoder.layer.([0-9]+).layernorm_before.bias$": (r"layers.\1.layernorm_before.bias", Transform.BIAS),
+    r"^vit.encoder.layer.([0-9]+).layernorm_before.weight$": (r"layers.\1.layernorm_before.scale", Transform.SCALE),
+    r"^vit.encoder.layer.([0-9]+).output.dense.bias$": (r"layers.\1.linear2.bias", Transform.BIAS),
+    r"^vit.encoder.layer.([0-9]+).output.dense.weight$": (r"layers.\1.linear2.kernel", Transform.LINEAR),
     r"^vit.layernorm.bias$": (r"ln.bias", Transform.BIAS),
     r"^vit.layernorm.weight$": (r"ln.scale", Transform.SCALE),
 }
@@ -149,12 +130,7 @@ def _stoi(s):
     return s
 
 
-def create_vit_from_pretrained(
-    file_dir: str,
-    config: model_lib.ModelConfig,
-    *,
-    mesh: jax.sharding.Mesh | None = None,
-):
+def create_vit_from_pretrained(file_dir: str, config: model_lib.ModelConfig):
     """
     Load safetensor weights from a file, then convert & merge into a flax.nnx ViT model.
 
diff --git a/bonsai/models/vit/tests/run_model.py b/bonsai/models/vit/tests/run_model.py
index 5fd7abb2..06feec0a 100644
--- a/bonsai/models/vit/tests/run_model.py
+++ b/bonsai/models/vit/tests/run_model.py
@@ -39,19 +39,19 @@ def run_model():
     dummy_input = jnp.ones((batch_size, image_size, image_size, channels), dtype=jnp.float32)
 
     # Warmup (triggers compilation)
-    _ = model_lib.forward(graphdef, flat_state, dummy_input).block_until_ready()
+    _ = model_lib.forward(graphdef, flat_state, dummy_input, None).block_until_ready()
 
     # Profile a few steps
     jax.profiler.start_trace("/tmp/profile-vit")
     for _ in range(5):
-        logits = model_lib.forward(graphdef, flat_state, dummy_input)
+        logits = model_lib.forward(graphdef, flat_state, dummy_input, None)
         jax.block_until_ready(logits)
     jax.profiler.stop_trace()
 
     # Timed execution
     t0 = time.perf_counter()
     for _ in range(10):
-        logits = model_lib.forward(graphdef, flat_state, dummy_input).block_until_ready()
+        logits = model_lib.forward(graphdef, flat_state, dummy_input, None).block_until_ready()
     print(f"Step time: {(time.perf_counter() - t0) / 10:.4f} s")
 
     # Show top-1 predicted class
diff --git a/bonsai/models/vit/tests/test_outputs_vit.py b/bonsai/models/vit/tests/test_outputs_vit.py
index 0baa8311..0e897dcf 100644
--- a/bonsai/models/vit/tests/test_outputs_vit.py
+++ b/bonsai/models/vit/tests/test_outputs_vit.py
@@ -32,13 +32,13 @@ def test_embeddings(self):
 
         with torch.no_grad():
             ty = torch_emb(tx)
-        jy = nnx_emb(jx)
+        jy = nnx_emb(jx, rngs=None)
 
         torch.testing.assert_close(torch.tensor(jy), ty, rtol=1e-5, atol=1e-5)
 
     def test_first_layer(self):
         torch_layer = self.baseline_model.vit.encoder.layer[0]
-        nnx_layer = self.bonsai_model.layers.layers[0]
+        nnx_layer = self.bonsai_model.layers[0]
 
         hidden_shape = (self.batch_size, 197, 768)
         jx = jax.random.normal(jax.random.key(0), hidden_shape, dtype=jnp.float32)
@@ -46,7 +46,7 @@ def test_first_layer(self):
 
         with torch.no_grad():
             ty = torch_layer(tx)
-        jy = nnx_layer(jx)
+        jy = nnx_layer(jx, rngs=None)
 
         torch.testing.assert_close(torch.tensor(jy), ty, rtol=1e-5, atol=1e-2)
 
@@ -56,7 +56,7 @@ def test_full(self):
 
         with torch.no_grad():
            ty = self.baseline_model(tx).logits
-        jy = self.bonsai_model(jx)
+        jy = self.bonsai_model(jx, rngs=None)
 
         torch.testing.assert_close(torch.tensor(jy), ty, rtol=1e-5, atol=5e-2)
 
@@ -68,7 +68,7 @@ def test_full_interpolation(self):
 
         with torch.no_grad():
             ty = self.baseline_model(tx, interpolate_pos_encoding=True).logits
-        jy = self.bonsai_model(jx)
+        jy = self.bonsai_model(jx, rngs=None)
 
         torch.testing.assert_close(torch.tensor(jy), ty, rtol=1e-5, atol=1e-1)
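
A minimal usage sketch of the reworked API (not taken verbatim from the repo): it assumes the bonsai import path and the vit_p16_224 factory shown in the diff above; the seed values and input shape are illustrative.

import jax.numpy as jnp
from flax import nnx

from bonsai.models.vit import modeling as model_lib

# Build the model; only a 'params' stream is needed at construction time now
# that the Dropout modules no longer capture rngs in __init__.
cfg = model_lib.ModelConfig.vit_p16_224()
model = model_lib.ViTClassificationModel(cfg, rngs=nnx.Rngs(params=0))
graphdef, state = nnx.split(model)

x = jnp.ones((1, 224, 224, 3), dtype=jnp.float32)

# Deterministic inference through the jitted forward: no dropout keys needed,
# so rngs is passed as None (this is what run_model.py and the tests do).
logits = model_lib.forward(graphdef, state, x, None)

# Training-style call on the module itself: supply a 'dropout' stream so
# nnx.Dropout and the attention dropout (attn_dropout_prob) can draw keys.
logits = model(x, rngs=nnx.Rngs(dropout=1))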