35 changes: 21 additions & 14 deletions bonsai/models/vit/modeling.py
@@ -11,6 +11,7 @@ class ModelConfig:
     patch_size: tuple[int, int]
     num_channels: int
     hidden_dim: int
+    attn_dropout_prob: float
     dropout_prob: float
     num_heads: int
     mlp_dim: int
@@ -25,6 +26,7 @@ def vit_p16_224(cls):
             patch_size=(16, 16),
             num_channels=3,
             hidden_dim=768,
+            attn_dropout_prob=0.0,
             dropout_prob=0.0,
             num_heads=12,
             mlp_dim=3072,
@@ -66,9 +68,9 @@ def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
         )
         self.cls_token = nnx.Variable(jax.random.normal(rngs.params(), (1, 1, cfg.hidden_dim)))
         self.pos_embeddings = nnx.Variable(jax.random.normal(rngs.params(), (1, num_patches + 1, cfg.hidden_dim)))
-        self.dropout = nnx.Dropout(cfg.dropout_prob, rngs=rngs)
+        self.dropout = nnx.Dropout(cfg.dropout_prob)

-    def __call__(self, pixel_values: jnp.ndarray) -> jnp.ndarray:
+    def __call__(self, pixel_values: jnp.ndarray, *, rngs: nnx.Rngs | None) -> jnp.ndarray:
         embeddings = self.projection(pixel_values)
         b, h, w, c = embeddings.shape
         embeddings = embeddings.reshape(b, h * w, c)
@@ -89,49 +91,54 @@ def __call__(self, pixel_values: jnp.ndarray) -> jnp.ndarray:
         embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)

         embeddings = embeddings + current_pos_embeddings
-        embeddings = self.dropout(embeddings)
+        embeddings = self.dropout(embeddings, rngs=rngs)
         return embeddings


 class TransformerEncoder(nnx.Module):
     def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
         self.attention = nnx.MultiHeadAttention(
-            num_heads=cfg.num_heads, in_features=cfg.hidden_dim, decode=False, rngs=rngs
+            num_heads=cfg.num_heads,
+            in_features=cfg.hidden_dim,
+            dropout_rate=cfg.attn_dropout_prob,
+            decode=False,
+            rngs=rngs,
         )
         self.linear1 = nnx.Linear(cfg.hidden_dim, cfg.mlp_dim, rngs=rngs)
         self.linear2 = nnx.Linear(cfg.mlp_dim, cfg.hidden_dim, rngs=rngs)
-        self.dropout = nnx.Dropout(cfg.dropout_prob, rngs=rngs)
+        self.dropout = nnx.Dropout(cfg.dropout_prob)
         self.layernorm_before = nnx.LayerNorm(cfg.hidden_dim, epsilon=cfg.eps, rngs=rngs)
         self.layernorm_after = nnx.LayerNorm(cfg.hidden_dim, epsilon=cfg.eps, rngs=rngs)

-    def __call__(self, hidden_states, head_mask=None):
+    def __call__(self, hidden_states, head_mask=None, *, rngs: nnx.Rngs | None):
         hidden_states_norm = self.layernorm_before(hidden_states)
-        attention_output = self.attention(hidden_states_norm, head_mask)
+        attention_output = self.attention(hidden_states_norm, head_mask, rngs=rngs)
         hidden_states = attention_output + hidden_states
         layer_output = self.layernorm_after(hidden_states)
         layer_output = jax.nn.gelu(self.linear1(layer_output))
         layer_output = self.linear2(layer_output)
-        layer_output = self.dropout(layer_output)
+        layer_output = self.dropout(layer_output, rngs=rngs)
         layer_output += hidden_states
         return layer_output


 class ViTClassificationModel(nnx.Module):
     def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
         self.pos_embeddings = Embeddings(cfg, rngs=rngs)
-        self.layers = nnx.Sequential(*[TransformerEncoder(cfg, rngs=rngs) for _ in range(cfg.num_layers)])
+        self.layers = nnx.List([TransformerEncoder(cfg, rngs=rngs) for _ in range(cfg.num_layers)])
         self.ln = nnx.LayerNorm(cfg.hidden_dim, epsilon=cfg.eps, rngs=rngs)
         self.classifier = nnx.Linear(cfg.hidden_dim, cfg.num_labels, rngs=rngs)

-    def __call__(self, x):
-        x = self.pos_embeddings(x)
-        x = self.layers(x)
+    def __call__(self, x, *, rngs: nnx.Rngs | None):
+        x = self.pos_embeddings(x, rngs=rngs)
+        for layer in self.layers:
+            x = layer(x, rngs=rngs)
         x = self.ln(x)
         x = self.classifier(x[:, 0, :])
         return x


 @jax.jit
-def forward(graphdef: nnx.GraphDef[nnx.Module], state: nnx.State, x: jax.Array) -> jax.Array:
+def forward(graphdef: nnx.GraphDef[nnx.Module], state: nnx.State, x: jax.Array, rngs: nnx.Rngs) -> jax.Array:
     model = nnx.merge(graphdef, state)
-    return model(x)
+    return model(x, rngs=rngs)
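Note on modeling.py: dropout randomness is now threaded explicitly through each __call__ via an rngs argument (None for deterministic inference) instead of being stored on the nnx.Dropout modules at construction time. A minimal usage sketch under the new signature; the setup code here is illustrative, not part of this PR:

    import jax.numpy as jnp
    from flax import nnx
    from bonsai.models.vit import modeling as model_lib

    cfg = model_lib.ModelConfig.vit_p16_224()
    model = model_lib.ViTClassificationModel(cfg, rngs=nnx.Rngs(params=0))
    graphdef, state = nnx.split(model)

    x = jnp.ones((1, 224, 224, 3), dtype=jnp.float32)
    # The preset sets dropout_prob=0.0, so rngs=None is safe for inference.
    logits = model_lib.forward(graphdef, state, x, None)
    print(logits.shape)  # (1, cfg.num_labels)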
58 changes: 17 additions & 41 deletions bonsai/models/vit/params.py
@@ -16,7 +16,6 @@
 import re
 from enum import Enum

-import jax
 import jax.numpy as jnp
 import safetensors.flax as safetensors
 from etils import epath
@@ -50,60 +49,42 @@ class Transform(Enum):
     r"^vit.embeddings.patch_embeddings.projection.weight$": (r"pos_embeddings.projection.kernel", Transform.CONV2D),
     r"^vit.embeddings.position_embeddings$": (r"pos_embeddings.pos_embeddings", Transform.EMBED),
     r"^vit.encoder.layer.([0-9]+).attention.attention.key.bias$": (
-        r"layers.layers.\1.attention.key.bias",
+        r"layers.\1.attention.key.bias",
         Transform.ATTN_KQV_BIAS,
     ),
     r"^vit.encoder.layer.([0-9]+).attention.attention.key.weight$": (
-        r"layers.layers.\1.attention.key.kernel",
+        r"layers.\1.attention.key.kernel",
         Transform.ATTN_KQV_KERNEL,
     ),
     r"^vit.encoder.layer.([0-9]+).attention.attention.query.bias$": (
-        r"layers.layers.\1.attention.query.bias",
+        r"layers.\1.attention.query.bias",
         Transform.ATTN_KQV_BIAS,
     ),
     r"^vit.encoder.layer.([0-9]+).attention.attention.query.weight$": (
-        r"layers.layers.\1.attention.query.kernel",
+        r"layers.\1.attention.query.kernel",
         Transform.ATTN_KQV_KERNEL,
     ),
     r"^vit.encoder.layer.([0-9]+).attention.attention.value.bias$": (
-        r"layers.layers.\1.attention.value.bias",
+        r"layers.\1.attention.value.bias",
         Transform.ATTN_KQV_BIAS,
     ),
     r"^vit.encoder.layer.([0-9]+).attention.attention.value.weight$": (
-        r"layers.layers.\1.attention.value.kernel",
+        r"layers.\1.attention.value.kernel",
         Transform.ATTN_KQV_KERNEL,
     ),
-    r"^vit.encoder.layer.([0-9]+).attention.output.dense.bias$": (
-        r"layers.layers.\1.attention.out.bias",
-        Transform.BIAS,
-    ),
+    r"^vit.encoder.layer.([0-9]+).attention.output.dense.bias$": (r"layers.\1.attention.out.bias", Transform.BIAS),
     r"^vit.encoder.layer.([0-9]+).attention.output.dense.weight$": (
-        r"layers.layers.\1.attention.out.kernel",
+        r"layers.\1.attention.out.kernel",
         Transform.ATTN_OUT,
     ),
-    r"^vit.encoder.layer.([0-9]+).intermediate.dense.bias$": (r"layers.layers.\1.linear1.bias", Transform.BIAS),
-    r"^vit.encoder.layer.([0-9]+).intermediate.dense.weight$": (
-        r"layers.layers.\1.linear1.kernel",
-        Transform.LINEAR,
-    ),
-    r"^vit.encoder.layer.([0-9]+).layernorm_after.bias$": (
-        r"layers.layers.\1.layernorm_after.bias",
-        Transform.BIAS,
-    ),
-    r"^vit.encoder.layer.([0-9]+).layernorm_after.weight$": (
-        r"layers.layers.\1.layernorm_after.scale",
-        Transform.SCALE,
-    ),
-    r"^vit.encoder.layer.([0-9]+).layernorm_before.bias$": (
-        r"layers.layers.\1.layernorm_before.bias",
-        Transform.BIAS,
-    ),
-    r"^vit.encoder.layer.([0-9]+).layernorm_before.weight$": (
-        r"layers.layers.\1.layernorm_before.scale",
-        Transform.SCALE,
-    ),
-    r"^vit.encoder.layer.([0-9]+).output.dense.bias$": (r"layers.layers.\1.linear2.bias", Transform.BIAS),
-    r"^vit.encoder.layer.([0-9]+).output.dense.weight$": (r"layers.layers.\1.linear2.kernel", Transform.LINEAR),
+    r"^vit.encoder.layer.([0-9]+).intermediate.dense.bias$": (r"layers.\1.linear1.bias", Transform.BIAS),
+    r"^vit.encoder.layer.([0-9]+).intermediate.dense.weight$": (r"layers.\1.linear1.kernel", Transform.LINEAR),
+    r"^vit.encoder.layer.([0-9]+).layernorm_after.bias$": (r"layers.\1.layernorm_after.bias", Transform.BIAS),
+    r"^vit.encoder.layer.([0-9]+).layernorm_after.weight$": (r"layers.\1.layernorm_after.scale", Transform.SCALE),
+    r"^vit.encoder.layer.([0-9]+).layernorm_before.bias$": (r"layers.\1.layernorm_before.bias", Transform.BIAS),
+    r"^vit.encoder.layer.([0-9]+).layernorm_before.weight$": (r"layers.\1.layernorm_before.scale", Transform.SCALE),
+    r"^vit.encoder.layer.([0-9]+).output.dense.bias$": (r"layers.\1.linear2.bias", Transform.BIAS),
+    r"^vit.encoder.layer.([0-9]+).output.dense.weight$": (r"layers.\1.linear2.kernel", Transform.LINEAR),
     r"^vit.layernorm.bias$": (r"ln.bias", Transform.BIAS),
     r"^vit.layernorm.weight$": (r"ln.scale", Transform.SCALE),
 }
@@ -149,12 +130,7 @@ def _stoi(s):
     return s


-def create_vit_from_pretrained(
-    file_dir: str,
-    config: model_lib.ModelConfig,
-    *,
-    mesh: jax.sharding.Mesh | None = None,
-):
+def create_vit_from_pretrained(file_dir: str, config: model_lib.ModelConfig):
     """
     Load safetensor weights from a file, then convert & merge into a flax.nnx ViT model.

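Note on params.py: nnx.Sequential keeps its children under an inner layers attribute, so the old checkpoint paths were layers.layers.N.*; with nnx.List the encoder blocks sit directly at layers.N.*, which is all these mapping entries rename. A small sketch of how one HuggingFace key travels through the regex table above (the pattern/replacement pair is copied from the table; the surrounding driver code is illustrative):

    import re

    pattern = r"^vit.encoder.layer.([0-9]+).attention.attention.query.weight$"
    replacement = r"layers.\1.attention.query.kernel"

    hf_key = "vit.encoder.layer.3.attention.attention.query.weight"
    if re.match(pattern, hf_key):
        print(re.sub(pattern, replacement, hf_key))  # layers.3.attention.query.kernel

The unused mesh parameter was also dropped from create_vit_from_pretrained, which is presumably why the module-level import jax could be removed.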
6 changes: 3 additions & 3 deletions bonsai/models/vit/tests/run_model.py
@@ -39,19 +39,19 @@ def run_model():
     dummy_input = jnp.ones((batch_size, image_size, image_size, channels), dtype=jnp.float32)

     # Warmup (triggers compilation)
-    _ = model_lib.forward(graphdef, flat_state, dummy_input).block_until_ready()
+    _ = model_lib.forward(graphdef, flat_state, dummy_input, None).block_until_ready()

     # Profile a few steps
     jax.profiler.start_trace("/tmp/profile-vit")
     for _ in range(5):
-        logits = model_lib.forward(graphdef, flat_state, dummy_input)
+        logits = model_lib.forward(graphdef, flat_state, dummy_input, None)
         jax.block_until_ready(logits)
     jax.profiler.stop_trace()

     # Timed execution
     t0 = time.perf_counter()
     for _ in range(10):
-        logits = model_lib.forward(graphdef, flat_state, dummy_input).block_until_ready()
+        logits = model_lib.forward(graphdef, flat_state, dummy_input, None).block_until_ready()
     print(f"Step time: {(time.perf_counter() - t0) / 10:.4f} s")

     # Show top-1 predicted class
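run_model.py exercises inference only, which is why it passes None for the new rngs argument of model_lib.forward. For a training-style call where dropout needs real keys, one would call the merged model directly; a hedged sketch, assuming a config with nonzero dropout_prob or attn_dropout_prob:

    from flax import nnx

    model = nnx.merge(graphdef, flat_state)  # rebuild the module from the split state
    rngs = nnx.Rngs(dropout=42)              # supplies the dropout key stream
    logits = model(dummy_input, rngs=rngs)   # stochastic forward pass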
10 changes: 5 additions & 5 deletions bonsai/models/vit/tests/test_outputs_vit.py
@@ -32,21 +32,21 @@ def test_embeddings(self):

         with torch.no_grad():
             ty = torch_emb(tx)
-        jy = nnx_emb(jx)
+        jy = nnx_emb(jx, rngs=None)

         torch.testing.assert_close(torch.tensor(jy), ty, rtol=1e-5, atol=1e-5)

     def test_first_layer(self):
         torch_layer = self.baseline_model.vit.encoder.layer[0]
-        nnx_layer = self.bonsai_model.layers.layers[0]
+        nnx_layer = self.bonsai_model.layers[0]

         hidden_shape = (self.batch_size, 197, 768)
         jx = jax.random.normal(jax.random.key(0), hidden_shape, dtype=jnp.float32)
         tx = torch.tensor(jx)

         with torch.no_grad():
             ty = torch_layer(tx)
-        jy = nnx_layer(jx)
+        jy = nnx_layer(jx, rngs=None)

         torch.testing.assert_close(torch.tensor(jy), ty, rtol=1e-5, atol=1e-2)

@@ -56,7 +56,7 @@ def test_full(self):

         with torch.no_grad():
             ty = self.baseline_model(tx).logits
-        jy = self.bonsai_model(jx)
+        jy = self.bonsai_model(jx, rngs=None)

         torch.testing.assert_close(torch.tensor(jy), ty, rtol=1e-5, atol=5e-2)

@@ -68,7 +68,7 @@ def test_full_interpolation(self):

         with torch.no_grad():
             ty = self.baseline_model(tx, interpolate_pos_encoding=True).logits
-        jy = self.bonsai_model(jx)
+        jy = self.bonsai_model(jx, rngs=None)

         torch.testing.assert_close(torch.tensor(jy), ty, rtol=1e-5, atol=1e-1)
