diff --git a/bonsai/models/dinov3/Readme.md b/bonsai/models/dinov3/Readme.md
new file mode 100644
index 00000000..e253a1af
--- /dev/null
+++ b/bonsai/models/dinov3/Readme.md
@@ -0,0 +1,20 @@
+# DINOv3 in JAX
+This directory contains a pure JAX implementation of the [DINOv3 collection of ViT models](https://huggingface.co/collections/facebook/dinov3) using the [Flax NNX](https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/index.html) API.
+
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1hWUQ6dKO0rL5__OxaPzLBfl0G7aJ34js#scrollTo=p6xCHoM04byX)
+
+## Model configuration support status
+| Model Name | Size | Config Support Status |
+| :--- | :--- | :--- |
+| **Web (LVD) Models** | | |
+| [ViT-S](https://huggingface.co/facebook/dinov3-vits16-pretrain-lvd1689m) | 21M | **✅ Supported** |
+| [ViT-S+](https://huggingface.co/facebook/dinov3-vits16plus-pretrain-lvd1689m) | 29M | **✅ Supported** |
+| [ViT-B](https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m) | 86M | **✅ Supported** |
+| [ViT-L](https://huggingface.co/facebook/dinov3-vitl16-pretrain-lvd1689m) | 0.3B | **✅ Supported** |
+| [ViT-H+](https://huggingface.co/facebook/dinov3-vith16plus-pretrain-lvd1689m) | 0.84B | **✅ Supported** |
+| [ViT-7B](https://huggingface.co/facebook/dinov3-vit7b16-pretrain-lvd1689m) | 7B | **Needs sharding** |
+| **Satellite (SAT) Models** | | |
+| [ViT-L](https://huggingface.co/facebook/dinov3-vitl16-pretrain-sat493m) | 0.3B | **✅ Supported** |
+| [ViT-7B](https://huggingface.co/facebook/dinov3-vit7b16-pretrain-sat493m) | 7B | **Needs sharding** |
+
+* Note: A Hugging Face login and model access approval are required to download the checkpoints.
\ No newline at end of file
diff --git a/bonsai/models/dinov3/modeling.py b/bonsai/models/dinov3/modeling.py
new file mode 100644
index 00000000..5039573a
--- /dev/null
+++ b/bonsai/models/dinov3/modeling.py
@@ -0,0 +1,333 @@
+import dataclasses
+from typing import Tuple
+
+import jax.numpy as jnp
+from flax import nnx
+from jax import Array
+
+
+@dataclasses.dataclass
+class Dinov3ViTModelOutput:
+    last_hidden_state: Array
+    pooler_output: Array
+
+
+@dataclasses.dataclass(frozen=True)
+class DINOv3ViTFlaxConfig:
+    model_type = "dinov3_ViT"
+    patch_size: Tuple[int, int] = (16, 16)
+    hidden_size: int = 384
+    intermediate_size: int = 1536
+    num_hidden_layers: int = 12
+    num_attention_heads: int = 6
+    hidden_act: str = "gelu"
+    layer_norm_eps: float = 1e-5
+    rope_theta: float = 100.0
+    image_size: int = 224
+    num_channels: int = 3
+    query_bias: bool = True
+    key_bias: bool = False
+    value_bias: bool = True
+    proj_bias: bool = True
+    mlp_bias: bool = True
+    layerscale_value: float = 1.0
+    use_gated_mlp: bool = False
+    num_register_tokens: int = 4
+
+    @classmethod
+    def dinov3_vits16(cls):
+        return cls()
+
+    @classmethod
+    def dinov3_vits16plus(cls):
+        return cls(
+            hidden_size=384,
+            intermediate_size=1536,
+            num_hidden_layers=12,
+            num_attention_heads=6,
+            hidden_act="silu",
+            use_gated_mlp=True,
+        )
+
+    @classmethod
+    def dinov3_vitb16(cls):
+        return cls(
+            hidden_size=768,
+            intermediate_size=3072,
+            num_hidden_layers=12,
+            num_attention_heads=12,
+            hidden_act="gelu",
+            use_gated_mlp=False,
+        )
+
+    @classmethod
+    def dinov3_vitl16(cls):
+        return cls(
+            hidden_size=1024,
+            intermediate_size=4096,
+            num_hidden_layers=24,
+            num_attention_heads=16,
+            hidden_act="gelu",
+            use_gated_mlp=False,
+        )
+
+    @classmethod
+    def dinov3_vith16plus(cls):
+        return cls(
+            hidden_size=1280,
+            intermediate_size=5120,
+            num_hidden_layers=32,
+            num_attention_heads=20,
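+            # The larger variants (ViT-H+ and ViT-7B) use the SiLU-activated gated MLP.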
hidden_act="silu", + use_gated_mlp=True, + ) + + @classmethod + def dinov3_vit7b16(cls): + return cls( + hidden_size=4096, + intermediate_size=8192, + num_hidden_layers=40, + num_attention_heads=32, + hidden_act="silu", + use_gated_mlp=True, + ) + + +class DINOv3ViTEmbeddings(nnx.Module): + def __init__(self, config: DINOv3ViTFlaxConfig, rngs: nnx.Rngs): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.cls_token = nnx.Param(jnp.ones((1, 1, self.hidden_size), dtype=jnp.float32)) + self.mask_token = nnx.Param(jnp.zeros((1, 1, self.hidden_size), dtype=jnp.float32)) + self.register_tokens = nnx.Param( + jnp.zeros((1, config.num_register_tokens, config.hidden_size), dtype=jnp.float32) + ) + self.patch_embeddings = nnx.Conv( + in_features=config.num_channels, + out_features=config.hidden_size, + kernel_size=config.patch_size, + strides=config.patch_size, + rngs=rngs, + ) + + def __call__(self, pixel_values: Array) -> Array: + b, _, _, _ = pixel_values.shape + + # (batch_size, num_channels, height, width) -> (batch_size, num_patches, hidden_size) + pixel_values = pixel_values.transpose(0, 2, 3, 1) + patch_embeddings = self.patch_embeddings(pixel_values) + patch_embeddings = patch_embeddings.reshape(b, -1, self.hidden_size) + + cls_token = jnp.broadcast_to(self.cls_token[...], (b, 1, self.hidden_size)) + register_tokens = jnp.broadcast_to( + self.register_tokens[...], (b, self.config.num_register_tokens, self.hidden_size) + ) + return jnp.concat([cls_token, register_tokens, patch_embeddings], axis=1) + + +class Dinov3ViTRopePositionEmbedding(nnx.Module): + def __init__(self, config: DINOv3ViTFlaxConfig): + super().__init__() + self.config = config + self.base = config.rope_theta + self.head_dim = config.hidden_size // config.num_attention_heads + self.num_patches_h = config.image_size // config.patch_size[0] + self.num_patches_w = config.image_size // config.patch_size[0] + + def __call__(self, pixel_values: Array) -> Tuple[Array, Array]: + _, _, height, width = pixel_values.shape + num_patches_h = height // self.config.patch_size[0] + num_patches_w = width // self.config.patch_size[0] + + coords_h = jnp.arange(0.5, num_patches_h, dtype=jnp.float32) / num_patches_h # [H] + coords_w = jnp.arange(0.5, num_patches_w, dtype=jnp.float32) / num_patches_w # [W] + coords = jnp.stack(jnp.meshgrid(coords_h, coords_w, indexing="ij"), axis=-1) # [H, W, 2] + coords = coords.reshape(-1, 2) + coords = 2 * coords - 1.0 + + inv_freq = 1.0 / self.base ** jnp.arange(0.0, 1.0, 4.0 / self.head_dim, dtype=jnp.float32) # [head_dim // 4] + angles = 2 * jnp.pi * coords[:, :, None] * inv_freq[None, None, :] # (HW, 2, D//4) + angles = angles.reshape(coords.shape[0], -1) # (HW, D//2) + angles = jnp.tile(angles, (1, 2)) # (HW, D) + + cos = jnp.cos(angles) + sin = jnp.sin(angles) + + return (cos, sin) + + +class Dinov3LayerScale(nnx.Module): + def __init__(self, config: DINOv3ViTFlaxConfig): + super().__init__() + self.lambda1 = nnx.Param(jnp.full((config.hidden_size,), config.layerscale_value, dtype=jnp.float32)) + + def __call__(self, x: Array) -> Array: + return x * self.lambda1 + + +def rotate_half(x: Array) -> Array: + d = x.shape[-1] + assert d % 2 == 0 + x1 = x[..., : d // 2] + x2 = x[..., d // 2 :] + return jnp.concatenate((-x2, x1), axis=-1) + + +def apply_rotary_pos_emb(q: Array, k: Array, cos: Array, sin: Array) -> Tuple[Array, Array]: + q = q.astype(jnp.bfloat16) + k = k.astype(jnp.bfloat16) + cos = cos.astype(jnp.bfloat16) + sin = sin.astype(jnp.bfloat16) + num_tokens = 
+    num_tokens = q.shape[-2]
+    num_patches = cos.shape[-2]
+    num_prefix = num_tokens - num_patches
+    q_prefix, q_patches = jnp.split(q, [num_prefix], axis=-2)
+    k_prefix, k_patches = jnp.split(k, [num_prefix], axis=-2)
+    cos_b = cos[None, None, ...]
+    sin_b = sin[None, None, ...]
+    # Rotation
+    q_patches = (q_patches * cos_b) + (rotate_half(q_patches) * sin_b)
+    k_patches = (k_patches * cos_b) + (rotate_half(k_patches) * sin_b)
+    q = jnp.concatenate([q_prefix, q_patches], axis=-2)
+    k = jnp.concatenate([k_prefix, k_patches], axis=-2)
+    q = q.astype(jnp.float32)
+    k = k.astype(jnp.float32)
+    return (q, k)
+
+
+class Dinov3ViTAttention(nnx.Module):
+    def __init__(self, config: DINOv3ViTFlaxConfig, rngs: nnx.Rngs):
+        super().__init__()
+        self.config = config
+
+        self.q_proj = nnx.Linear(
+            in_features=config.hidden_size, out_features=config.hidden_size, use_bias=config.query_bias, rngs=rngs
+        )
+        self.k_proj = nnx.Linear(
+            in_features=config.hidden_size, out_features=config.hidden_size, use_bias=config.key_bias, rngs=rngs
+        )
+        self.v_proj = nnx.Linear(
+            in_features=config.hidden_size, out_features=config.hidden_size, use_bias=config.value_bias, rngs=rngs
+        )
+        self.o_proj = nnx.Linear(
+            in_features=config.hidden_size, out_features=config.hidden_size, use_bias=config.proj_bias, rngs=rngs
+        )
+
+    def __call__(self, hidden_states: Array, position_embeddings: Tuple[Array, Array]) -> Array:
+        batch_size, patches, _ = hidden_states.shape
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        n_heads = self.config.num_attention_heads
+        head_dim = self.config.hidden_size // n_heads
+
+        query_states = query_states.reshape(batch_size, patches, n_heads, head_dim).transpose(0, 2, 1, 3)
+        key_states = key_states.reshape(batch_size, patches, n_heads, head_dim).transpose(0, 2, 1, 3)
+        value_states = value_states.reshape(batch_size, patches, n_heads, head_dim).transpose(0, 2, 1, 3)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        scale = 1.0 / jnp.sqrt(head_dim)
+
+        # (B, H, P, D) @ (B, H, D, P) -> (B, H, P, P)
+        attn_weights = jnp.matmul(query_states, key_states.transpose(0, 1, 3, 2)) * scale
+        attn_weights = nnx.softmax(attn_weights, axis=-1)
+
+        # (B, H, P, P) @ (B, H, P, D) -> (B, H, P, D)
+        hidden_states = jnp.matmul(attn_weights, value_states)
+
+        hidden_states = hidden_states.transpose(0, 2, 1, 3).reshape(batch_size, patches, -1)
+        hidden_states = self.o_proj(hidden_states)
+        return hidden_states
+
+
+class Dinov3MLP(nnx.Module):
+    def __init__(self, config: DINOv3ViTFlaxConfig, rngs: nnx.Rngs):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.up_proj = nnx.Linear(self.hidden_size, self.intermediate_size, use_bias=config.mlp_bias, rngs=rngs)
+        self.down_proj = nnx.Linear(self.intermediate_size, self.hidden_size, use_bias=config.mlp_bias, rngs=rngs)
+        if config.hidden_act == "gelu":
+            self.act_fn = nnx.gelu
+        elif config.hidden_act == "silu":
+            self.act_fn = nnx.silu
+
+    def __call__(self, x):
+        return self.down_proj(self.act_fn(self.up_proj(x)))
+
+
+class Dinov3GatedMLP(nnx.Module):
+    def __init__(self, config: DINOv3ViTFlaxConfig, rngs: nnx.Rngs):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
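+        # Gated variant: down_proj(act_fn(gate_proj(x)) * up_proj(x)).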
+        self.gate_proj = nnx.Linear(self.hidden_size, self.intermediate_size, use_bias=config.mlp_bias, rngs=rngs)
+        self.up_proj = nnx.Linear(self.hidden_size, self.intermediate_size, use_bias=config.mlp_bias, rngs=rngs)
+        self.down_proj = nnx.Linear(self.intermediate_size, self.hidden_size, use_bias=config.mlp_bias, rngs=rngs)
+        if config.hidden_act == "gelu":
+            self.act_fn = nnx.gelu
+        elif config.hidden_act == "silu":
+            self.act_fn = nnx.silu
+
+    def __call__(self, x):
+        x = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return x
+
+
+class Dinov3ViTLayer(nnx.Module):
+    def __init__(self, config: DINOv3ViTFlaxConfig, rngs: nnx.Rngs):
+        super().__init__()
+        self.norm1 = nnx.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps, rngs=rngs)
+        self.attention = Dinov3ViTAttention(config, rngs=rngs)
+        self.layer_scale1 = Dinov3LayerScale(config)
+        self.norm2 = nnx.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps, rngs=rngs)
+        if config.use_gated_mlp:
+            self.mlp = Dinov3GatedMLP(config, rngs=rngs)
+        else:
+            self.mlp = Dinov3MLP(config, rngs=rngs)
+
+        self.layer_scale2 = Dinov3LayerScale(config)
+
+    def __call__(self, hidden_states: Array, position_embeddings: Tuple[Array, Array]) -> Array:
+        residual = hidden_states
+        hidden_states = self.norm1(hidden_states)
+        hidden_states = self.attention(hidden_states, position_embeddings)
+        hidden_states = self.layer_scale1(hidden_states)
+        hidden_states = hidden_states + residual
+
+        residual = hidden_states
+        hidden_states = self.norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.layer_scale2(hidden_states)
+        hidden_states = hidden_states + residual
+        return hidden_states
+
+
+class Dinov3ViTModel(nnx.Module):
+    def __init__(self, config: DINOv3ViTFlaxConfig, rngs: nnx.Rngs):
+        super().__init__()
+        self.config = config
+        self.embeddings = DINOv3ViTEmbeddings(config, rngs=rngs)
+        self.rope_embeddings = Dinov3ViTRopePositionEmbedding(config)
+        self.layer = nnx.List([Dinov3ViTLayer(config, rngs=rngs) for _ in range(config.num_hidden_layers)])
+        self.norm = nnx.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps, rngs=rngs)
+
+    def __call__(self, pixel_values: Array) -> Dinov3ViTModelOutput:
+        hidden_states = self.embeddings(pixel_values)
+        position_embeddings = self.rope_embeddings(pixel_values)
+
+        for layer_module in self.layer:
+            hidden_states = layer_module(hidden_states, position_embeddings)
+
+        sequence_output = self.norm(hidden_states)
+        pooled_output = sequence_output[:, 0, :]
+
+        return Dinov3ViTModelOutput(last_hidden_state=sequence_output, pooler_output=pooled_output)
diff --git a/bonsai/models/dinov3/params.py b/bonsai/models/dinov3/params.py
new file mode 100644
index 00000000..3c9c327c
--- /dev/null
+++ b/bonsai/models/dinov3/params.py
@@ -0,0 +1,144 @@
+import gc
+import re
+from enum import Enum
+
+import jax
+import safetensors
+from etils import epath
+from flax import nnx
+
+from bonsai.models.dinov3.modeling import DINOv3ViTFlaxConfig, Dinov3ViTModel
+
+
+def _get_key_and_transform_mapping():
+    class Transform(Enum):
+        BIAS = (None, None, False)
+        LINEAR = ((1, 0), None, False)
+        CONV2D = ((2, 3, 1, 0), None, False)
+        DEFAULT = (None, None, False)
+
+    # Mapping st_keys -> (nnx_keys, (permute_rule, reshape_rule, reshape_first))
+    return {
+        # Embeddings
+        r"embeddings\.cls_token$": ("embeddings.cls_token", Transform.DEFAULT),
+        r"embeddings\.mask_token$": ("embeddings.mask_token", Transform.DEFAULT),
+        r"embeddings\.register_tokens$": ("embeddings.register_tokens", Transform.DEFAULT),
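+        # PyTorch Conv2d kernels are (O, I, H, W); Flax nnx.Conv expects (H, W, I, O), hence CONV2D's (2, 3, 1, 0) permute.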
("embeddings.patch_embeddings.kernel", Transform.CONV2D), + r"embeddings\.patch_embeddings\.bias$": ("embeddings.patch_embeddings.bias", Transform.BIAS), + # Attention weights and biases + r"layer\.([0-9]+)\.attention\.q_proj\.weight$": (r"layer.\1.attention.q_proj.kernel", Transform.LINEAR), + r"layer\.([0-9]+)\.attention\.k_proj\.weight$": (r"layer.\1.attention.k_proj.kernel", Transform.LINEAR), + r"layer\.([0-9]+)\.attention\.v_proj\.weight$": (r"layer.\1.attention.v_proj.kernel", Transform.LINEAR), + r"layer\.([0-9]+)\.attention\.o_proj\.weight$": (r"layer.\1.attention.o_proj.kernel", Transform.LINEAR), + r"layer\.([0-9]+)\.attention\.q_proj\.bias$": (r"layer.\1.attention.q_proj.bias", Transform.BIAS), + r"layer\.([0-9]+)\.attention\.k_proj\.bias$": (r"layer.\1.attention.k_proj.bias", Transform.BIAS), + r"layer\.([0-9]+)\.attention\.v_proj\.bias$": (r"layer.\1.attention.v_proj.bias", Transform.BIAS), + r"layer\.([0-9]+)\.attention\.o_proj\.bias$": (r"layer.\1.attention.o_proj.bias", Transform.BIAS), + # MLP (gated or not) + r"layer\.([0-9]+)\.mlp\.gate_proj\.weight$": (r"layer.\1.mlp.gate_proj.kernel", Transform.LINEAR), + r"layer\.([0-9]+)\.mlp\.up_proj\.weight$": (r"layer.\1.mlp.up_proj.kernel", Transform.LINEAR), + r"layer\.([0-9]+)\.mlp\.down_proj\.weight$": (r"layer.\1.mlp.down_proj.kernel", Transform.LINEAR), + r"layer\.([0-9]+)\.mlp\.gate_proj\.bias$": (r"layer.\1.mlp.gate_proj.bias", Transform.BIAS), + r"layer\.([0-9]+)\.mlp\.up_proj\.bias$": (r"layer.\1.mlp.up_proj.bias", Transform.BIAS), + r"layer\.([0-9]+)\.mlp\.down_proj\.bias$": (r"layer.\1.mlp.down_proj.bias", Transform.BIAS), + # layer_scale1 / layer_scale2 keys + r"layer\.([0-9]+)\.layer_scale1\.lambda1$": (r"layer.\1.layer_scale1.lambda1", Transform.DEFAULT), + r"layer\.([0-9]+)\.layer_scale2\.lambda1$": (r"layer.\1.layer_scale2.lambda1", Transform.DEFAULT), + # norm1 / norm2 mapping + r"layer\.([0-9]+)\.norm1\.weight$": (r"layer.\1.norm1.scale", Transform.DEFAULT), + r"layer\.([0-9]+)\.norm1\.bias$": (r"layer.\1.norm1.bias", Transform.DEFAULT), + r"layer\.([0-9]+)\.norm2\.weight$": (r"layer.\1.norm2.scale", Transform.DEFAULT), + r"layer\.([0-9]+)\.norm2\.bias$": (r"layer.\1.norm2.bias", Transform.DEFAULT), + # final model norm + r"norm\.weight": ("norm.scale", Transform.DEFAULT), + r"norm\.bias": ("norm.bias", Transform.DEFAULT), + } + + +def _torch_key_to_jax_key(mapping, source_key): + subs = [ + (re.sub(pat, repl, source_key), reshape) + for pat, (repl, reshape) in mapping.items() + if re.match(pat, source_key) + ] + if len(subs) != 1: + raise ValueError(f"Only one key should be found: {subs[0]}") + return subs[0] + + +def _assign_weights(keys, tensor, state_dict, st_key, transform, sharding_dict): + """Recursively descend into state_dict and assign the (possibly permuted/reshaped) tensor.""" + key, *rest = keys + if not rest: + if transform is not None: + permute, reshape, reshape_first = transform + if reshape_first and reshape is not None: + tensor = tensor.reshape(reshape) + if permute: + tensor = tensor.transpose(permute) + if not reshape_first and reshape is not None: + tensor = tensor.reshape(reshape) + if tensor.shape != state_dict[key].shape: + raise ValueError(f"Shape mismatch for {st_key}: {tensor.shape} vs {state_dict[key].shape}") + # Only apply sharding if sharding_dict is provided + if sharding_dict is not None: + state_dict[key] = jax.device_put(tensor, sharding_dict[key]) + else: + state_dict[key] = jax.device_put(tensor) + else: + next_sharding = sharding_dict[key] if sharding_dict is not None 
+        _assign_weights(rest, tensor, state_dict[key], st_key, transform, next_sharding)
+
+
+def _stoi(s):
+    try:
+        return int(s)
+    except ValueError:
+        return s
+
+
+def create_model_from_safe_tensors(
+    file_dir: str,
+    cfg: DINOv3ViTFlaxConfig,
+    mesh: jax.sharding.Mesh | None = None,
+) -> Dinov3ViTModel:
+    """Load tensors from the safetensors file and create a Dinov3 model (memory-optimized)."""
+    files = list(epath.Path(file_dir).expanduser().glob("*.safetensors"))
+    if not files:
+        raise ValueError(f"No safetensors found in {file_dir}")
+
+    dinov3 = nnx.eval_shape(lambda: Dinov3ViTModel(cfg, rngs=nnx.Rngs(0)))
+    graph_def, abs_state = nnx.split(dinov3)
+    state_dict = abs_state.to_pure_dict()
+    # Only use sharding if mesh is provided
+    sharding = nnx.get_named_sharding(abs_state, mesh).to_pure_dict() if mesh is not None else None
+
+    key_mapping = _get_key_and_transform_mapping()
+
+    conversion_errors = []
+    for f in files:
+        with safetensors.safe_open(f, framework="numpy") as sf:
+            for torch_key in sf.keys():
+                tensor = sf.get_tensor(torch_key)
+
+                jax_key, transform = _torch_key_to_jax_key(key_mapping, torch_key)
+                if jax_key is None:
+                    continue
+                keys = [_stoi(k) for k in jax_key.split(".")]
+                try:
+                    _assign_weights(keys, tensor, state_dict, torch_key, transform.value, sharding)
+                except Exception as e:
+                    full_jax_key = ".".join([str(k) for k in keys])
+                    conversion_errors.append(
+                        f"Failed to assign '{torch_key}' to '{full_jax_key}': {type(e).__name__}: {e}"
+                    )
+
+        gc.collect()
+
+    if conversion_errors:
+        full_error_log = "\n".join(conversion_errors)
+        raise RuntimeError(f"Encountered {len(conversion_errors)} weight conversion errors. Log: \n{full_error_log}")
+
+    m = nnx.merge(graph_def, state_dict)
+    m.eval()
+    return m
diff --git a/bonsai/models/dinov3/tests/test_outputs_dinov3.py b/bonsai/models/dinov3/tests/test_outputs_dinov3.py
new file mode 100644
index 00000000..3f4ee7e7
--- /dev/null
+++ b/bonsai/models/dinov3/tests/test_outputs_dinov3.py
@@ -0,0 +1,125 @@
+import os
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import torch
+from absl.testing import absltest
+from huggingface_hub import constants
+from safetensors.torch import save_file
+from transformers import DINOv3ViTConfig, DINOv3ViTModel
+
+from bonsai.models.dinov3 import modeling as model_lib
+from bonsai.models.dinov3 import params
+
+
+class TestForwardPass(absltest.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.save_dir = constants.default_cache_path
+        os.makedirs(self.save_dir, exist_ok=True)
+
+        self.hfconfig = DINOv3ViTConfig(
+            hidden_size=768,
+            intermediate_size=3072,
+            num_hidden_layers=12,
+            num_attention_heads=12,
+            hidden_act="gelu",
+            use_gated_mlp=False,
+            num_register_tokens=4,
+        )
+        self.baseline_model = DINOv3ViTModel(config=self.hfconfig)
+        self.model_ckpt_path = os.path.join(self.save_dir, "model.safetensors")
+        save_file(self.baseline_model.state_dict(), self.model_ckpt_path)
+
+        self.config = model_lib.DINOv3ViTFlaxConfig.dinov3_vitb16()
+        self.bonsai_model = params.create_model_from_safe_tensors(self.save_dir, self.config)
+
+        self.bonsai_model.eval()
+        self.baseline_model.eval()
+
+        self.batch_size = 1
+        self.image_shape = (self.batch_size, 3, 224, 224)
+
+    def test_input_embeddings(self):
+        torch_emb = self.baseline_model.embeddings
+        nnx_emb = self.bonsai_model.embeddings
+
+        key = jax.random.PRNGKey(0)
+        jx = jax.random.normal(key, self.image_shape, dtype=jnp.float32)
+
+        np_x = np.asarray(jax.device_get(jx))
+        tx = torch.tensor(np_x, dtype=torch.float32)
+
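+        # Compare the HF PyTorch embeddings against the Flax NNX embeddings on the same input.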
+        with torch.inference_mode():
+            ty = torch_emb(tx)
+        jy = nnx_emb(jx)
+
+        np_y = np.asarray(jax.device_get(jy))
+        ty_bonsai = torch.tensor(np_y, dtype=torch.float32)
+
+        torch.testing.assert_close(ty_bonsai, ty, rtol=1e-5, atol=1e-5)
+
+    def test_first_layer(self):
+        torch_emb = self.baseline_model.embeddings
+        nnx_emb = self.bonsai_model.embeddings
+        torch_pe = self.baseline_model.rope_embeddings
+        nnx_pe = self.bonsai_model.rope_embeddings
+        torch_layer = self.baseline_model.layer[0]
+        nnx_layer = self.bonsai_model.layer[0]
+
+        key = jax.random.PRNGKey(0)
+        jx = jax.random.normal(key, self.image_shape, dtype=jnp.float32)
+        np_x = np.asarray(jax.device_get(jx))
+        tx = torch.tensor(np_x, dtype=torch.float32)
+
+        jhs = nnx_emb(jx)
+        jpe = nnx_pe(jx)
+
+        ths = torch_emb(tx)
+        tpe = torch_pe(tx)
+
+        with torch.inference_mode():
+            ty = torch_layer(ths, position_embeddings=tpe)
+        jy = nnx_layer(jhs, jpe)
+
+        np_y = np.asarray(jax.device_get(jy))
+        ty_bonsai = torch.tensor(np_y, dtype=torch.float32)
+
+        torch.testing.assert_close(ty_bonsai, ty, rtol=1e-5, atol=3e-3)
+
+    def test_last_hidden_state(self):
+        key = jax.random.PRNGKey(0)
+        jx = jax.random.normal(key, self.image_shape, dtype=jnp.float32)
+
+        np_x = np.asarray(jax.device_get(jx))
+        tx = torch.tensor(np_x, dtype=torch.float32)
+
+        with torch.inference_mode():
+            ty = self.baseline_model(tx).last_hidden_state
+        jy = self.bonsai_model(jx).last_hidden_state
+
+        np_y = np.asarray(jax.device_get(jy))
+        ty_bonsai = torch.tensor(np_y, dtype=torch.float32)
+
+        torch.testing.assert_close(ty_bonsai, ty, rtol=1e-5, atol=3e-2)
+
+    def test_pooled_output_embeddings(self):
+        key = jax.random.PRNGKey(0)
+        jx = jax.random.normal(key, self.image_shape, dtype=jnp.float32)
+
+        np_x = np.asarray(jax.device_get(jx))
+        tx = torch.tensor(np_x, dtype=torch.float32)
+
+        with torch.inference_mode():
+            ty = self.baseline_model(tx).pooler_output
+        jy = self.bonsai_model(jx).pooler_output
+
+        np_y = np.asarray(jax.device_get(jy))
+        ty_bonsai = torch.tensor(np_y, dtype=torch.float32)
+
+        torch.testing.assert_close(ty_bonsai, ty, rtol=1e-5, atol=2e-2)
+
+
+if __name__ == "__main__":
+    absltest.main()