diff --git a/bonsai/models/ConvNext/__init__.py b/bonsai/models/ConvNext/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bonsai/models/ConvNext/modeling.py b/bonsai/models/ConvNext/modeling.py new file mode 100644 index 00000000..04bb98b5 --- /dev/null +++ b/bonsai/models/ConvNext/modeling.py @@ -0,0 +1,181 @@ +from typing import Optional, Sequence + +import jax +import jax.numpy as jnp +from flax import nnx + + +class DropPath(nnx.Module): + """ + Stochastic depth (DropPath) module, compatible with JAX jit. + """ + + def __init__(self, drop_prob: float = 0.0): + self.drop_prob = drop_prob + + def __call__(self, x, rng: Optional[jax.Array] = None, train: bool = True): + train_flag = jnp.asarray(train) + + if rng is None: + rng = jax.random.PRNGKey(0) + + def apply_drop(_): + keep_prob = jnp.asarray(1.0) - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + mask = jax.random.bernoulli(rng, p=keep_prob, shape=shape) + return (x * mask) / keep_prob + + def no_drop(_): + return x + + cond = jnp.logical_or(self.drop_prob == 0.0, jnp.logical_not(train_flag)) + return jax.lax.cond(cond, no_drop, apply_drop, operand=None) + + +class Block(nnx.Module): + def __init__(self, dim: int, drop_path: float = 0.0, layer_scale_init_value=1e-6, *, rngs: nnx.Rngs): + self.dwconv = nnx.Conv( + in_features=dim, + out_features=dim, + kernel_size=(7, 7), + padding=3, + feature_group_count=dim, + rngs=rngs, + ) + self.norm = nnx.LayerNorm(dim, epsilon=1e-6, rngs=rngs) + self.pwconv1 = nnx.Linear(dim, 4 * dim, rngs=rngs) + # Use exact GELU to match PyTorch + self.activation = lambda x: jax.nn.gelu(x, approximate=False) + self.pwconv2 = nnx.Linear(4 * dim, dim, rngs=rngs) + + self.gamma = nnx.Param(layer_scale_init_value * jnp.ones((dim))) if layer_scale_init_value > 0 else None + self.drop_path = DropPath(drop_path) + + def __call__(self, x, rng: Optional[jax.Array] = None, train: bool = True): + if rng is None: + rng = jax.random.PRNGKey(0) + + input_ = x + x = self.dwconv(x) + x = self.norm(x) + x = self.pwconv1(x) + x = self.activation(x) + x = self.pwconv2(x) + + if self.gamma is not None: + x = self.gamma.value * x + + x = input_ + self.drop_path(x, rng=rng, train=train) + return x + + +class EmbeddingsWrapper(nnx.Module): + """ + Thin wrapper so `embeddings` returns NCHW (PyTorch style) like HF ConvNeXt 'embeddings'. + Internally uses a provided `stem` module (which expects NHWC input and returns NHWC). + """ + + def __init__(self, stem: nnx.Module): + self.stem = stem + + def __call__(self, x): + # stem expects NHWC (JAX image layout). If test provides NHWC, OK. + # Convert stem output (NHWC) -> NCHW to match HF torch embeddings. + out = self.stem(x) + return jnp.transpose(out, (0, 3, 1, 2)) + + +class ConvNeXt(nnx.Module): + def __init__( + self, + in_chans: int = 3, + num_classes: int = 1000, + depths: Sequence[int] = (3, 3, 27, 3), + dims: Sequence[int] = (192, 384, 768, 1536), + drop_path_rate: float = 0.0, + layer_scale_init_value: float = 1e-6, + head_init_scale: float = 1.0, + *, + rngs: nnx.Rngs, + ): + self.downsample_layers = nnx.List() + self.depths = depths + + # stem: produces NHWC + stem = nnx.Sequential( + nnx.Conv(in_features=in_chans, out_features=dims[0], kernel_size=(4, 4), strides=(4, 4), rngs=rngs), + nnx.LayerNorm(dims[0], epsilon=1e-6, rngs=rngs), + ) + + self.embeddings = EmbeddingsWrapper(stem) + + self.downsample_layers.append(stem) + + for i in range(3): + downsample_layer = nnx.Sequential( + nnx.LayerNorm(dims[i], epsilon=1e-6, rngs=rngs), + nnx.Conv( + in_features=dims[i], + out_features=dims[i + 1], + kernel_size=(2, 2), + strides=(2, 2), + rngs=rngs, + ), + ) + self.downsample_layers.append(downsample_layer) + + # stages and blocks + self.stages = nnx.List() + dp_rates = list(jnp.linspace(0, drop_path_rate, sum(depths))) + curr = 0 + for i in range(4): + stage_blocks = nnx.List() + for j in range(depths[i]): + stage_blocks.append( + Block( + dim=dims[i], + drop_path=dp_rates[curr + j], + layer_scale_init_value=layer_scale_init_value, + rngs=rngs, + ) + ) + self.stages.append(stage_blocks) + curr += depths[i] + + self.norm = nnx.LayerNorm(dims[-1], epsilon=1e-6, rngs=rngs) + self.head = nnx.Linear(dims[-1], num_classes, rngs=rngs) + + def __call__(self, x, rng: Optional[jax.Array] = None, train: bool = False): + """ + Forward pass. + `x` expected in NHWC (JAX default). This function keeps the model internally NHWC. + """ + if rng is None: + rng = jax.random.PRNGKey(0) + + for i in range(4): + x = self.downsample_layers[i](x) + + for block in self.stages[i]: + rng, block_rng = jax.random.split(rng) + x = block(x, rng=block_rng, train=train) + + # global average pool over H, W (NHWC) + x = jnp.mean(x, axis=(1, 2)) + + x = self.norm(x) + x = self.head(x) + return x + + +@jax.jit +def forward( + graph_def: nnx.GraphDef, + state: nnx.State, + x: jax.Array, + *, + rng: Optional[jax.Array] = None, + train: bool = False, +): + model = nnx.merge(graph_def, state) + return model(x, rng=rng, train=train) diff --git a/bonsai/models/ConvNext/params.py b/bonsai/models/ConvNext/params.py new file mode 100644 index 00000000..2591186a --- /dev/null +++ b/bonsai/models/ConvNext/params.py @@ -0,0 +1,207 @@ +import logging +import re +from typing import Callable + +import h5py +import jax +from etils import epath +from flax import nnx + +from bonsai.models.ConvNext import modeling as model_lib + + +def _get_key_and_transform_mapping(): + """ + Creates the mapping from the TF/Keras .h5 keys to the JAX/NNX keys. + The transform is `None` because TF and JAX have the same weight shapes. + """ + # Prefix for the main 'convnext' model parts + convnext_prefix = r"^convnext/tf_conv_next_for_image_classification/convnext/" + + # A separate prefix for the final 'classifier' head + classifier_prefix = r"^classifier/tf_conv_next_for_image_classification/classifier/" + + mapping = { + # --- Stem (downsample_layers.0) --- + r"" + convnext_prefix + r"embeddings/patch_embeddings/kernel:0$": ("downsample_layers.0.layers.0.kernel", None), + r"" + convnext_prefix + r"embeddings/patch_embeddings/bias:0$": ("downsample_layers.0.layers.0.bias", None), + # This is the 'layernorm' right after patch_embeddings + r"" + convnext_prefix + r"embeddings/layernorm/beta:0$": ( + "downsample_layers.0.layers.1.bias", + None, # Keras 'beta' is 'bias' + ), + r"" + convnext_prefix + r"embeddings/layernorm/gamma:0$": ( + "downsample_layers.0.layers.1.scale", + None, # Keras 'gamma' is 'scale' + ), + # --- Downsampling Layers (Stages 1, 2, 3) --- + r"" + convnext_prefix + r"encoder/stages\.([1-3])/downsampling_layer\.0/beta:0$": ( + r"downsample_layers.\1.layers.0.bias", + None, + ), + r"" + convnext_prefix + r"encoder/stages\.([1-3])/downsampling_layer\.0/gamma:0$": ( + r"downsample_layers.\1.layers.0.scale", + None, + ), + r"" + convnext_prefix + r"encoder/stages\.([1-3])/downsampling_layer\.1/kernel:0$": ( + r"downsample_layers.\1.layers.1.kernel", + None, + ), + r"" + convnext_prefix + r"encoder/stages\.([1-3])/downsampling_layer\.1/bias:0$": ( + r"downsample_layers.\1.layers.1.bias", + None, + ), + # --- Main Blocks (All Stages 0-3, All Layers 0-N) --- + r"" + convnext_prefix + r"encoder/stages\.([0-3])/layers\.([0-9]+)/dwconv/kernel:0$": ( + r"stages.\1.\2.dwconv.kernel", + None, + ), + r"" + convnext_prefix + r"encoder/stages\.([0-3])/layers\.([0-9]+)/dwconv/bias:0$": ( + r"stages.\1.\2.dwconv.bias", + None, + ), + r"" + convnext_prefix + r"encoder/stages\.([0-3])/layers\.([0-9]+)/layernorm/beta:0$": ( + r"stages.\1.\2.norm.bias", + None, + ), + r"" + convnext_prefix + r"encoder/stages\.([0-3])/layers\.([0-9]+)/layernorm/gamma:0$": ( + r"stages.\1.\2.norm.scale", + None, + ), + r"" + convnext_prefix + r"encoder/stages\.([0-3])/layers\.([0-9]+)/pwconv1/kernel:0$": ( + r"stages.\1.\2.pwconv1.kernel", + None, + ), + r"" + convnext_prefix + r"encoder/stages\.([0-3])/layers\.([0-9]+)/pwconv1/bias:0$": ( + r"stages.\1.\2.pwconv1.bias", + None, + ), + r"" + convnext_prefix + r"encoder/stages\.([0-3])/layers\.([0-9]+)/pwconv2/kernel:0$": ( + r"stages.\1.\2.pwconv2.kernel", + None, + ), + r"" + convnext_prefix + r"encoder/stages\.([0-3])/layers\.([0-9]+)/pwconv2/bias:0$": ( + r"stages.\1.\2.pwconv2.bias", + None, + ), + # This is the 'gamma' param in your nnx.Block + r"" + convnext_prefix + r"encoder/stages\.([0-3])/layers\.([0-9]+)/layer_scale_parameter:0$": ( + r"stages.\1.\2.gamma", + None, + ), + # --- Head (Final Norm and Linear Layer) --- + # Final LayerNorm before the classifier + r"" + convnext_prefix + r"layernorm/beta:0$": ("norm.bias", None), + r"" + convnext_prefix + r"layernorm/gamma:0$": ("norm.scale", None), + # Final Linear 'head' layer (note the different prefix) + r"" + classifier_prefix + r"kernel:0$": ("head.kernel", None), + r"" + classifier_prefix + r"bias:0$": ("head.bias", None), + } + return mapping + + +def _h5_key_to_jax_key(mapping, source_key): + """Map a h5 key to exactly one JAX key & transform, else warn/error.""" + subs = [ + (re.sub(pat, repl, source_key), transform) + for pat, (repl, transform) in mapping.items() + if re.match(pat, source_key) + ] + if not subs: + logging.warning(f"No mapping found for key: {source_key!r}") + return None, None + if len(subs) > 1: + keys = [s for s, _ in subs] + raise ValueError(f"Multiple mappings found for {source_key!r}: {keys}") + return subs[0] + + +def _assign_weights(keys, tensor, state_dict, h5_key, transform): + """Recursively descend into state_dict and assign the (possibly permuted/reshaped) tensor.""" + key, *rest = keys + if not rest: + if transform is not None: + permute, reshape = transform + if permute: + tensor = tensor.transpose(permute) + if reshape: + tensor = tensor.reshape(reshape) + + # Ensure shapes match before assigning + if key not in state_dict: + raise KeyError(f"JAX key {key} (from {h5_key}) not found in model state.") + if tensor.shape != state_dict[key].shape: + raise ValueError( + f"Shape mismatch for {h5_key} -> {'.'.join(map(str, keys))}:\n" + f"H5 shape: {tensor.shape} vs Model shape: {state_dict[key].shape}" + ) + + state_dict[key] = tensor + else: + # Recurse into the nested dictionary/list + _assign_weights(rest, tensor, state_dict[key], h5_key, transform) + + +def _stoi(s): + try: + return int(s) + except ValueError: + return s + + +def _create_convnext_from_pretrained( + model_cls: Callable[..., model_lib.ConvNeXt], + file_dir: str, + num_classes: int = 1000, + model_name: str | None = None, + *, + mesh: jax.sharding.Mesh | None = None, +): + """ + Load h5 weights from a file, then convert & merge into a flax.nnx ResNet model. + Returns: + A flax.nnx.Model instance with loaded parameters. + """ + files = list(epath.Path(file_dir).expanduser().glob("*.h5")) + if not files: + raise ValueError(f"No .h5 files found in {file_dir}") + + state_dict = {} + for f in files: + with h5py.File(f, "r") as hf: + # Recursively visit all objects (groups and datasets) + hf.visititems( + lambda name, obj: ( + # If it's a Dataset (a tensor), read it (obj[()]) and add to dict + state_dict.update({name: obj[()]}) if isinstance(obj, h5py.Dataset) else None + ) + ) + + model = model_cls(num_classes=num_classes, rngs=nnx.Rngs(params=0)) + graph_def, abs_state = nnx.split(model) + jax_state = abs_state.to_pure_dict() + + mapping = _get_key_and_transform_mapping() + + # Keep track of unassigned JAX keys to warn the user + assigned_jax_keys = set() + + for h5_key, tensor in state_dict.items(): + jax_key, transform = _h5_key_to_jax_key(mapping, h5_key) + if jax_key is None: + continue + + keys = [_stoi(k) for k in jax_key.split(".")] + try: + _assign_weights(keys, tensor, jax_state, h5_key, transform) + assigned_jax_keys.add(jax_key) + except (KeyError, ValueError) as e: + logging.error(f"Failed to assign weight for {h5_key}:\n{e}") + + if mesh is not None: + sharding = nnx.get_named_sharding(abs_state, mesh).to_pure_dict() + jax_state = jax.device_put(jax_state, sharding) + else: + jax_state = jax.device_put(jax_state, jax.devices()[0]) + + return nnx.merge(graph_def, jax_state) diff --git a/bonsai/models/ConvNext/tests/__init__.py b/bonsai/models/ConvNext/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bonsai/models/ConvNext/tests/run_model.py b/bonsai/models/ConvNext/tests/run_model.py new file mode 100644 index 00000000..940fd72a --- /dev/null +++ b/bonsai/models/ConvNext/tests/run_model.py @@ -0,0 +1,72 @@ +import os +import time + +import jax +import jax.numpy as jnp +from flax import nnx +from huggingface_hub import snapshot_download + +from bonsai.models.ConvNext import modeling as model_lib +from bonsai.models.ConvNext import params + + +def run_model(): + # 1. Download Model Weights + model_name = "facebook/convnext-large-224" + model_ckpt_path = snapshot_download(repo_id=model_name, allow_patterns="*.h5") + print(f"Downloaded Keras weights from: {model_name}") + + # 2. Load Pretrained Model + + model = params._create_convnext_from_pretrained( + model_cls=model_lib.ConvNeXt, + file_dir=model_ckpt_path, + num_classes=1000, # ImageNet-1K + ) + + graphdef, state = nnx.split(model) + + # 3. Prepare dummy input + batch_size, channels, image_size = 8, 3, 224 + dummy_input = jnp.ones((batch_size, image_size, image_size, channels), dtype=jnp.float32) + + key = jax.random.PRNGKey(0) + key, warmup_key, prof_key, time_key = jax.random.split(key, 4) + + # 4. Warmup + profiling + + _ = model_lib.forward(graphdef, state, dummy_input, rng=warmup_key, train=False).block_until_ready() + + # Profile a Few Steps + + prof_keys = jax.random.split(prof_key, 5) + + jax.profiler.start_trace("/tmp/profile-convnext") + for i in range(5): + logits = model_lib.forward(graphdef, state, dummy_input, rng=prof_keys[i], train=False) + jax.block_until_ready(logits) + jax.profiler.stop_trace() + print("Profiling complete. Trace saved to /tmp/profile-convnext") + + # 5. Timed execution + + time_keys = jax.random.split(time_key, 10) + + t0 = time.perf_counter() + for i in range(10): + logits = model_lib.forward(graphdef, state, dummy_input, rng=time_keys[i], train=False).block_until_ready() + + step_time = (time.perf_counter() - t0) / 10 + print(f"Step time: {step_time:.4f} s") + print(f"Throughput: {batch_size / step_time:.2f} images/s") + + # 6. Show Top-1 Predicted Class + + pred = jnp.argmax(logits, axis=-1) + print("Predicted classes (batch):", pred) + + +if __name__ == "__main__": + run_model() + +__all__ = ["run_model"] diff --git a/bonsai/models/ConvNext/tests/test_outputs_ConvNext.py b/bonsai/models/ConvNext/tests/test_outputs_ConvNext.py new file mode 100644 index 00000000..fcb4266f --- /dev/null +++ b/bonsai/models/ConvNext/tests/test_outputs_ConvNext.py @@ -0,0 +1,88 @@ +import jax +import jax.numpy as jnp +import torch +from absl.testing import absltest +from huggingface_hub import snapshot_download +from transformers import ConvNextForImageClassification + +from bonsai.models.ConvNext import modeling as model_lib +from bonsai.models.ConvNext import params + + +class TestModuleForwardPasses(absltest.TestCase): + def setUp(self): + super().setUp() + model_name = "facebook/convnext-large-224" + model_ckpt_path = snapshot_download(model_name) + + self.bonsai_model = params._create_convnext_from_pretrained(model_lib.ConvNeXt, model_ckpt_path) + self.baseline_model = ConvNextForImageClassification.from_pretrained(model_name) + + self.bonsai_model.eval() + self.baseline_model.eval() + + self.batch_size = 1 + self.image_shape = (self.batch_size, 224, 224, 3) + + def test_embeddings(self): + torch_emb = self.baseline_model.convnext.embeddings + nnx_emb = self.bonsai_model.embeddings + + jx = jax.random.normal(jax.random.key(0), self.image_shape, dtype=jnp.float32) + tx = torch.tensor(jx).permute(0, 3, 1, 2) + + with torch.no_grad(): + ty = torch_emb(tx) + jy = nnx_emb(jx) + + torch.testing.assert_close(torch.tensor(jy), ty, rtol=1e-5, atol=1e-5) + + def test_blocks_isolated(self): + """ + Compare every ConvNeXt block between HuggingFace PyTorch and the JAX/NNX model. + Tests each block in ISOLATION with fresh random input to avoid error accumulation. + """ + jax_model = self.bonsai_model + torch_model = self.baseline_model.convnext + torch_model.eval() + + # Dimensions for each stage (NHWC for JAX, NCHW for Torch) + # Stage 0: 56x56, dim=192 + # Stage 1: 28x28, dim=384 + # Stage 2: 14x14, dim=768 + # Stage 3: 7x7, dim=1536 + stage_dims = [(56, 192), (28, 384), (14, 768), (7, 1536)] + + key = jax.random.PRNGKey(42) + + for stage_idx, stage_blocks in enumerate(jax_model.stages): + h, dim = stage_dims[stage_idx] + + for block_idx in range(len(stage_blocks)): + key, sub = jax.random.split(key) + + jx_input = jax.random.normal(sub, (1, h, h, dim), dtype=jnp.float32) + tx_input = torch.tensor(jx_input).permute(0, 3, 1, 2) + + key, sub = jax.random.split(key) + j_out = jax_model.stages[stage_idx][block_idx](jx_input, rng=sub, train=False) + + with torch.no_grad(): + t_out = torch_model.encoder.stages[stage_idx].layers[block_idx](tx_input) + + # Convert Torch output to NHWC for comparison + t_out_nhwc = t_out.permute(0, 2, 3, 1) + + torch.testing.assert_close(torch.tensor(j_out), t_out_nhwc, rtol=5e-4, atol=5e-4) + + def test_full(self): + jx = jax.random.normal(jax.random.key(0), self.image_shape, dtype=jnp.float32) + tx = torch.tensor(jx).permute(0, 3, 1, 2) + with torch.no_grad(): + ty = self.baseline_model(tx).logits + jy = self.bonsai_model(jx) + torch.testing.assert_close(torch.tensor(jy), ty, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + absltest.main() diff --git a/bonsai/models/__init__.py b/bonsai/models/__init__.py deleted file mode 100644 index 49a8b6d6..00000000 --- a/bonsai/models/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "2025.08.09"