diff --git a/README.md b/README.md
index 5222a78..fac8a78 100644
--- a/README.md
+++ b/README.md
@@ -40,6 +40,7 @@ Beyond normal ViT (e.g., dinov2 or siglip), equimo proposes other SotA architect
 | PartialFormer | [Efficient Vision Transformers with Partial Attention](https://eccv.ecva.net/virtual/2024/poster/1877) | 2024 | ✅ |
 | SHViT | [SHViT: Single-Head Vision Transformer with Memory Efficient Macro Design](https://arxiv.org/abs/2401.16456) | 2024 | ✅ |
 | VSSD | [VSSD: Vision Mamba with Non-Causal State Space Duality](https://arxiv.org/abs/2407.18559) | 2024 | ✅ |
+| ReduceFormer | [ReduceFormer: Attention with Tensor Reduction by Summation](https://arxiv.org/abs/2406.07488) | 2024 | ✅ |
 
 \*: Only contains the [Linear Angular Attention](https://github.com/clementpoiret/Equimo/blob/f8fcc79e45ca65e9deb1d970c4286c0b8562f9c2/equimo/layers/attention.py#L1407) module. It is straight forward to build a ViT around it, but may require an additional `__call__` kwarg to control the `sparse_reg` bool.
 
diff --git a/devenv.lock b/devenv.lock
index 86704d5..e5d8127 100644
--- a/devenv.lock
+++ b/devenv.lock
@@ -3,10 +3,10 @@
     "devenv": {
       "locked": {
         "dir": "src/modules",
-        "lastModified": 1743783972,
+        "lastModified": 1744725539,
         "owner": "cachix",
         "repo": "devenv",
-        "rev": "2f53e2f867e0c2ba18b880e66169366e5f8ca554",
+        "rev": "e35cb7bb6e6424b83560b5ae0896f75263942191",
         "type": "github"
       },
       "original": {
@@ -74,10 +74,10 @@
     },
     "nixpkgs": {
       "locked": {
-        "lastModified": 1744096231,
+        "lastModified": 1744536153,
         "owner": "NixOS",
         "repo": "nixpkgs",
-        "rev": "b2b0718004cc9a5bca610326de0a82e6ea75920b",
+        "rev": "18dd725c29603f582cf1900e0d25f9f1063dbf11",
         "type": "github"
       },
       "original": {
diff --git a/src/equimo/layers/attention.py b/src/equimo/layers/attention.py
index 894cb51..9dc2855 100644
--- a/src/equimo/layers/attention.py
+++ b/src/equimo/layers/attention.py
@@ -1,4 +1,5 @@
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Sequence, Tuple
+
 import equinox as eqx
 import jax
 import jax.numpy as jnp
@@ -6,6 +7,7 @@
 from einops import rearrange, reduce
 from jaxtyping import Array, Float, PRNGKeyArray
 
+from equimo.layers.convolution import SingleConvBlock, MBConv
 from equimo.layers.dropout import DropPathAdd
 from equimo.layers.ffn import Mlp
 from equimo.layers.mamba import Mamba2Mixer
@@ -1565,6 +1567,177 @@ def __call__(
         return x
+
+
+class RFAttention(eqx.Module):
+    """Attention with Tensor Reduction by Summation[1].
+
+    A ReLU Linear Attention mechanism replacing matmuls with global
+    summation and element-wise multiplications.
+
+    Attributes:
+        total_dim: Total channel dimension of the concatenated multi-scale q/k/v
+        kernel_func: Non-linearity applied to queries and keys (ReLU by default)
+        eps: Small constant added to the denominator for numerical stability
+
+    References:
+        [1]. Yang, J., An, L., & Park, S. I. (2024). ReduceFormer: Attention
+            with Tensor Reduction by Summation (No. arXiv:2406.07488). arXiv.
+ https://doi.org/10.48550/arXiv.2406.07488 + """ + + total_dim: int = eqx.field(static=True) + kernel_func: Callable = eqx.field(static=True) + eps: float = eqx.field(static=True) + + qkv: eqx.nn.Conv2d + aggreg: list[eqx.nn.Conv2d] + proj: SingleConvBlock + + def __init__( + self, + in_channels: int, + out_channels: int, + *, + key: PRNGKeyArray, + num_heads: int | None = None, + head_dim: int = 8, + heads_ratio: float = 1.0, + scales: Sequence[int] = (5,), + use_bias: bool = False, + kernel_func: Callable = jax.nn.relu, + # TODO: Benchmark against LN, RMSN, NsLN + norm_layer: eqx.Module = eqx.nn.GroupNorm, + norm_kwargs: dict = {}, + eps: float = 1e-15, + **kwargs, + ): + key_qkv, key_aggreg, key_proj = jr.split(key, 3) + + self.kernel_func = kernel_func + self.eps = eps + num_heads = num_heads or int(in_channels // head_dim * heads_ratio) + total_dim = num_heads * head_dim + self.total_dim = total_dim * (1 + len(scales)) + + self.qkv = eqx.nn.Conv2d( + in_channels=in_channels, + out_channels=3 * total_dim, + kernel_size=1, + padding="SAME", + use_bias=use_bias, + key=key_qkv, + ) + self.aggreg = [ + eqx.nn.Conv2d( + in_channels=3 * total_dim, + out_channels=3 * total_dim, + kernel_size=scale, + padding="SAME", + groups=3 * total_dim, + key=key_aggreg, + use_bias=use_bias, + ) + for scale in scales + ] + # TODO: test different normalizations + self.proj = SingleConvBlock( + in_channels=self.total_dim, + out_channels=out_channels, + kernel_size=1, + use_bias=use_bias, + norm_layer=norm_layer, + norm_kwargs=norm_kwargs, + key=key_proj, + ) + + def __call__( + self, + x: Float[Array, "seqlen height width"], + key: PRNGKeyArray, + inference: Optional[bool] = None, + ) -> Float[Array, "seqlen height width"]: + qkv_base = self.qkv(x) + + aggregated_qkvs = [op(qkv_base) for op in self.aggreg] + all_qkvs = [qkv_base] + aggregated_qkvs + + rearranged_qkvs = [ + rearrange(qkv, "(n d) h w -> n d h w", n=3) for qkv in all_qkvs + ] + multiscale_qkv = jnp.concatenate(rearranged_qkvs, axis=1) + + q, k, v = multiscale_qkv + + q = self.kernel_func(q) + k = self.kernel_func(k) + + sum_k = jnp.sum(k, axis=(-1, -2), keepdims=True) + sum_v = jnp.sum(v, axis=(-1, -2), keepdims=True) + sum_kv = jnp.sum(k * sum_v, axis=(-1, -2), keepdims=True) + sum_q = jnp.sum(q, axis=0, keepdims=True) + + out = (q * sum_kv) / (sum_q * sum_k + self.eps) + out = self.proj(out) + + return out + + +class RFAttentionBlock(eqx.Module): + context_module: RFAttention + local_module: MBConv + + def __init__( + self, + in_channels: int, + *, + key, + head_dim: int = 32, + heads_ratio: float = 1.0, + scales: Sequence[int] = (5,), + rfattn_norm_layer: eqx.Module = eqx.nn.GroupNorm, + norm_kwargs: dict = {}, + expand_ratio: float = 4.0, + mbconv_norm_layers: tuple = (None, None, eqx.nn.GroupNorm), + mbconv_act_layers: tuple = (jax.nn.hard_swish, jax.nn.hard_swish, None), + fuse_mbconv: bool = False, + **kwargs, + ): + key_context, key_local = jr.split(key, 2) + + self.context_module = RFAttention( + in_channels=in_channels, + out_channels=in_channels, + head_dim=head_dim, + heads_ratio=heads_ratio, + scales=scales, + norm_layer=rfattn_norm_layer, + norm_kwargs=norm_kwargs, + key=key_context, + ) + self.local_module = MBConv( + in_channels=in_channels, + out_channels=in_channels, + expand_ratio=expand_ratio, + norm_layers=mbconv_norm_layers, + act_layers=mbconv_act_layers, + use_bias=(True, True, False), + fuse=fuse_mbconv, + key=key_local, + ) + + def __call__( + self, + x: Float[Array, "dim height width"], + key: PRNGKeyArray, + 
inference: Optional[bool] = None, + ): + key_context, key_local = jr.split(key, 2) + + x += self.context_module(x, inference=inference, key=key_context) + x += self.local_module(x, inference=inference, key=key_local) + + return x + + def get_attention(module: str | eqx.Module) -> eqx.Module: """Get an `eqx.Module` from its common name. diff --git a/src/equimo/layers/convolution.py b/src/equimo/layers/convolution.py index bde5411..ffac2b7 100644 --- a/src/equimo/layers/convolution.py +++ b/src/equimo/layers/convolution.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Sequence +from typing import Callable, Optional, Sequence, Tuple import equinox as eqx import jax @@ -153,9 +153,13 @@ def __init__( out_channels: int, *, key: PRNGKeyArray, - use_norm: bool = True, + kernel_size: int = 3, + stride: int = 1, + padding: str | int = "SAME", + norm_layer: eqx.Module | None = eqx.nn.GroupNorm, norm_max_group: int = 32, act_layer: Callable | None = None, + norm_kwargs: dict = {}, **kwargs, ): """Initialize the SingleConvBlock. @@ -164,25 +168,33 @@ def __init__( in_channels: Number of input channels out_channels: Number of output channels key: PRNG key for initialization - use_norm: Whether to use group normalization (default: True) norm_max_group: Maximum number of groups for GroupNorm (default: 32) act_layer: Optional activation function (default: None) + norm_kwargs: Args passed to the norm layer. This allows disabling + weights of LayerNorm, which do not work well with conv layers **kwargs: Additional arguments passed to Conv layer """ - num_groups = nearest_power_of_2_divisor(out_channels, norm_max_group) - self.conv = eqx.nn.Conv( - num_spatial_dims=2, + self.conv = eqx.nn.Conv2d( in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, key=key, **kwargs, ) - self.norm = ( - eqx.nn.GroupNorm(num_groups, out_channels) - if use_norm - else eqx.nn.Identity() - ) + + # TODO: test + if norm_layer is not None: + if norm_layer == eqx.nn.GroupNorm: + num_groups = nearest_power_of_2_divisor(out_channels, norm_max_group) + self.norm = eqx.nn.GroupNorm(num_groups, out_channels, **norm_kwargs) + else: + self.norm = norm_layer(out_channels, **norm_kwargs) + else: + self.norm = eqx.nn.Identity() + self.act = eqx.nn.Lambda(act_layer) if act_layer else eqx.nn.Identity() def __call__( @@ -258,7 +270,6 @@ def __init__( stride=2, padding=1, use_bias=False, - use_norm=True, act_layer=jax.nn.relu, key=key_conv1, ) @@ -272,7 +283,6 @@ def __init__( stride=1, padding=1, use_bias=False, - use_norm=True, act_layer=jax.nn.relu, key=key_conv2, ), @@ -283,7 +293,6 @@ def __init__( stride=1, padding=1, use_bias=False, - use_norm=True, act_layer=None, key=key_conv3, ), @@ -299,7 +308,6 @@ def __init__( stride=2, padding=1, use_bias=False, - use_norm=True, act_layer=jax.nn.relu, key=key_conv4, ), @@ -310,7 +318,6 @@ def __init__( stride=1, padding=0, use_bias=False, - use_norm=True, act_layer=None, key=key_conv5, ), @@ -645,3 +652,221 @@ def __call__( y = jnp.split(self.conv1(x), [self.hidden_channels]) y.extend(blk(y[-1]) for blk in self.blocks) return self.conv2(jnp.concatenate(y, axis=0)) + + +class MBConv(eqx.Module): + """MobileNet Conv Block with optional fusing from [1]. + + References: + [1]: Nottebaum, M., Dunnhofer, M., & Micheloni, C. (2024). LowFormer: + Hardware Efficient Design for Convolutional Transformer Backbones (No. + arXiv:2409.03460). arXiv. 
https://doi.org/10.48550/arXiv.2409.03460 + """ + + fused: bool = eqx.field(static=True) + + inverted_conv: SingleConvBlock | None + depth_conv: SingleConvBlock | None + spatial_conv: SingleConvBlock | None + point_conv: SingleConvBlock + + def __init__( + self, + in_channels: int, + out_channels: int, + *, + key: PRNGKeyArray, + mid_channels: int | None = None, + kernel_size: int = 3, + stride: int = 1, + use_bias: Tuple[bool, ...] | bool = False, + expand_ratio: float = 6.0, + norm_layers: Tuple[eqx.Module | None, ...] + | eqx.Module + | None = eqx.nn.GroupNorm, + act_layers: Tuple[Callable | None, ...] | Callable | None = jax.nn.relu6, + fuse: bool = False, + fuse_threshold: int = 256, + fuse_group: bool = False, + fused_conv_groups: int = 1, + **kwargs, + ): + key_inverted, key_depth, key_point = jr.split(key, 3) + + if not isinstance(norm_layers, Tuple): + norm_layers = (norm_layers,) * 3 + if not isinstance(act_layers, Tuple): + act_layers = (act_layers,) * 3 + if isinstance(use_bias, bool): + use_bias: Tuple = (use_bias,) * 3 + if len(use_bias) != 3: + raise ValueError( + f"`use_bias` should be a Tuple of length 3, got: {len(use_bias)}" + ) + if len(norm_layers) != 3: + raise ValueError( + f"`norm_layers` should be a Tuple of length 3, got: {len(norm_layers)}" + ) + if len(act_layers) != 3: + raise ValueError( + f"`act_layers` should be a Tuple of length 3, got: {len(act_layers)}" + ) + + mid_channels = ( + mid_channels + if mid_channels is not None + else round(in_channels * expand_ratio) + ) + self.fused = fuse and in_channels <= fuse_threshold + + self.inverted_conv = ( + SingleConvBlock( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + norm_layer=norm_layers[0], + act_layer=act_layers[0], + use_bias=use_bias[0], + padding="SAME", + key=key_inverted, + ) + if not self.fused + else None + ) + self.depth_conv = ( + SingleConvBlock( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + groups=mid_channels, + norm_layer=norm_layers[1], + act_layer=act_layers[1], + use_bias=use_bias[1], + padding="SAME", + key=key_depth, + ) + if not self.fused + else None + ) + self.spatial_conv = ( + SingleConvBlock( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + groups=2 + if fuse_group and fused_conv_groups == 1 + else fused_conv_groups, + norm_layer=norm_layers[0], + act_layer=act_layers[0], + use_bias=use_bias[0], + padding="SAME", + key=key_depth, + ) + if self.fused + else None + ) + self.point_conv = SingleConvBlock( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + norm_layer=norm_layers[2], + act_layer=act_layers[2], + use_bias=use_bias[2], + padding="SAME", + key=key_point, + ) + + def __call__( + self, + x: Float[Array, "channels height width"], + key: PRNGKeyArray, + inference: Optional[bool] = None, + ): + if self.fused: + x = self.spatial_conv(x) + else: + x = self.inverted_conv(x) + x = self.depth_conv(x) + x = self.point_conv(x) + + return x + + +class DSConv(eqx.Module): + depth_conv: SingleConvBlock + point_conv: SingleConvBlock + + def __init__( + self, + in_channels: int, + out_channels: int, + *, + key: PRNGKeyArray, + kernel_size: int = 3, + stride: int = 1, + use_bias: Tuple[bool, ...] | bool = False, + norm_layers: Tuple[eqx.Module | None, ...] + | eqx.Module + | None = eqx.nn.GroupNorm, + act_layers: Tuple[Callable | None, ...] 
| Callable | None = jax.nn.relu6, + **kwargs, + ): + key_depth, key_point = jr.split(key, 2) + + if not isinstance(norm_layers, Tuple): + norm_layers = (norm_layers,) * 2 + if not isinstance(act_layers, Tuple): + act_layers = (act_layers,) * 2 + if isinstance(use_bias, bool): + use_bias: Tuple = (use_bias,) * 2 + if len(use_bias) != 2: + raise ValueError( + f"`use_bias` should be a Tuple of length 2, got: {len(use_bias)}" + ) + if len(norm_layers) != 2: + raise ValueError( + f"`norm_layers` should be a Tuple of length 2, got: {len(norm_layers)}" + ) + if len(act_layers) != 2: + raise ValueError( + f"`act_layers` should be a Tuple of length 2, got: {len(act_layers)}" + ) + + self.depth_conv = SingleConvBlock( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=stride, + groups=in_channels, + norm_layer=norm_layers[0], + act_layer=act_layers[0], + use_bias=use_bias[0], + padding="SAME", + key=key_depth, + ) + self.point_conv = SingleConvBlock( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + norm_layer=norm_layers[1], + act_layer=act_layers[1], + use_bias=use_bias[1], + padding="SAME", + key=key_point, + ) + + def __call__( + self, + x: Float[Array, "channels height width"], + key: PRNGKeyArray, + inference: Optional[bool] = None, + ): + x = self.depth_conv(x) + x = self.point_conv(x) + + return x diff --git a/src/equimo/layers/patch.py b/src/equimo/layers/patch.py index b33b6b7..eb5b25c 100644 --- a/src/equimo/layers/patch.py +++ b/src/equimo/layers/patch.py @@ -278,7 +278,7 @@ def __init__( kernel_size=1, stride=1, padding=0, - use_norm=False, + norm_layer=None, act_layer=jax.nn.relu, key=key_conv1, ) @@ -289,7 +289,7 @@ def __init__( stride=2, padding=1, groups=hidden_dim, - use_norm=False, + norm_layer=None, act_layer=jax.nn.relu, key=key_conv2, ) @@ -299,7 +299,7 @@ def __init__( kernel_size=1, stride=1, padding=0, - use_norm=False, + norm_layer=None, act_layer=jax.nn.relu, key=key_conv3, ) diff --git a/src/equimo/models/__init__.py b/src/equimo/models/__init__.py index 07b7c12..e101fd4 100644 --- a/src/equimo/models/__init__.py +++ b/src/equimo/models/__init__.py @@ -1,6 +1,12 @@ from .fastervit import FasterViT from .mlla import Mlla from .partialformer import PartialFormer +from .reduceformer import ( + ReduceFormer, + reduceformer_backbone_b1, + reduceformer_backbone_b2, + reduceformer_backbone_b3, +) from .shvit import SHViT from .vit import VisionTransformer from .vssd import Vssd diff --git a/src/equimo/models/reduceformer.py b/src/equimo/models/reduceformer.py new file mode 100644 index 0000000..3239dd4 --- /dev/null +++ b/src/equimo/models/reduceformer.py @@ -0,0 +1,330 @@ +from ssl import DefaultVerifyPaths +from typing import Callable, Literal, Optional, Tuple + +import equinox as eqx +import jax +import jax.random as jr +from einops import reduce +from jaxtyping import Array, Float, PRNGKeyArray + +from equimo.layers.activation import get_act +from equimo.layers.attention import RFAttentionBlock +from equimo.layers.convolution import DSConv, MBConv, SingleConvBlock +from equimo.layers.norm import get_norm + + +class BlockChunk(eqx.Module): + residuals: list[bool] = eqx.field(static=True) + blocks: list[DSConv | MBConv | RFAttentionBlock] + + def __init__( + self, + in_channels: int, + out_channels: int, + depth: int, + *, + key: PRNGKeyArray, + block_type: Literal["conv", "attention"] = "conv", + stride: int = 1, + expand_ratio: float = 1.0, + scales: Tuple[int, ...] 
= (5,), + head_dim: int = 32, + heads_ratio: float = 1.0, + norm_layer: eqx.Module = eqx.nn.GroupNorm, + act_layer: Callable = jax.nn.hard_swish, + fewer_norm: bool = False, + fuse_mbconv: bool = False, + **kwargs, + ): + key, *block_subkeys = jr.split(key, depth + 1) + + keys_to_spread = [ + k for k, v in kwargs.items() if isinstance(v, list) and len(v) == depth + ] + + blocks = [] + residuals = [] + + # TODO: simplify logic + match block_type: + case "conv": + block = DSConv if expand_ratio == 1.0 else MBConv + if fewer_norm: + use_bias: Tuple[bool, ...] | bool = ( + (True, False) if block == DSConv else (True, True, False) + ) + norm_layer = ( + (None, norm_layer) + if block == DSConv + else (None, None, norm_layer) + ) + else: + use_bias = False + + for i in range(depth): + config = kwargs | {k: kwargs[k][i] for k in keys_to_spread} + + if block == MBConv: + config["expand_ratio"] = expand_ratio + config["fuse"] = fuse_mbconv + + blocks.append( + block( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + stride=stride if i == 0 else 1, + use_bias=use_bias, + norm_layers=norm_layer, + act_layers=(act_layer, None) + if block == DSConv + else (act_layer, act_layer, None), + **config, + key=block_subkeys[i], + ) + ) + residuals.append( + (in_channels == out_channels and stride == 1) or i > 0 + ) + + case "attention": + blocks.append( + MBConv( + in_channels, + out_channels, + stride=2, # TODO: make downsampling optional + expand_ratio=expand_ratio, + norm_layers=(None, None, norm_layer), + act_layers=(act_layer, act_layer, None), + use_bias=(True, True, False), + fuse=fuse_mbconv, + key=key, + ) + ) + for i in range(depth): + blocks.append( + RFAttentionBlock( + in_channels=out_channels, + head_dim=head_dim, + heads_ratio=heads_ratio, + scales=scales, + rfattn_norm_layer=norm_layer, + expand_ratio=expand_ratio, + mbconv_norm_layers=(None, None, norm_layer), + mbconv_act_layers=(act_layer, act_layer, None), + fuse_mbconv=fuse_mbconv, + key=block_subkeys[i], + ) + ) + residuals.append(False) + + self.blocks = blocks + self.residuals = residuals + + def __call__( + self, + x: Float[Array, "..."], + *, + key: PRNGKeyArray, + inference: Optional[bool] = None, + **kwargs, + ) -> Float[Array, "..."]: + keys = jr.split(key, len(self.blocks)) + + # TODO: Dropout and Stochastic Path Add + for blk, residual, key_block in zip(self.blocks, self.residuals, keys): + res = blk(x, inference=inference, key=key_block, **kwargs) + x = x + res if residual else res + + return x + + +class ReduceFormer(eqx.Module): + input_stem: eqx.nn.Sequential + blocks: list[BlockChunk] + head: eqx.nn.Linear | eqx.nn.Identity + + def __init__( + self, + in_channels: int, + widths: list[int], + depths: list[int], + block_types: list[Literal["conv", "attention"]], + *, + key: PRNGKeyArray, + heads_dim: int = 32, + expand_ratio: float = 4.0, + norm_layer: eqx.Module | str = eqx.nn.GroupNorm, + act_layer: Callable | str = jax.nn.hard_swish, + fuse_mbconv: bool = False, + num_classes: int | None = 1000, + **kwargs, + ): + if not len(widths) == len(depths) == len(block_types): + raise ValueError( + "`widths`, `depths`, `strides`, and `expand_ratios` and `block_types` must have the same lengths." 
+ ) + + key_stem, key_head, *key_blocks = jr.split(key, 3 + len(depths)) + + act_layer = get_act(act_layer) + norm_layer = get_norm(norm_layer) + + width_stem = widths.pop(0) + depth_stem = depths.pop(0) + block_type_stem = block_types.pop(0) + key_block_stem = key_blocks.pop(0) + + self.input_stem = eqx.nn.Sequential( + [ + SingleConvBlock( + in_channels=in_channels, + out_channels=width_stem, + kernel_size=3, + stride=2, + padding="SAME", + use_bias=False, + norm_layer=norm_layer, + act_layer=act_layer, + key=key_stem, + ), + BlockChunk( + in_channels=width_stem, + out_channels=width_stem, + depth=depth_stem, + block_type=block_type_stem, + stride=1, + expand_ratio=1.0, + norm_layer=norm_layer, + act_layer=act_layer, + key=key_block_stem, + ), + ] + ) + + self.blocks = [ + BlockChunk( + in_channels=widths[i - 1] if i > 0 else width_stem, + out_channels=widths[i], + depth=depth, + block_type=block_type, + stride=2, + expand_ratio=expand_ratio, + norm_layer=norm_layer, + act_layer=act_layer, + fuse_mbconv=fuse_mbconv, + key=key_block, + ) + for i, (depth, block_type, key_block) in enumerate( + zip(depths, block_types, key_blocks) + ) + ] + + self.head = ( + eqx.nn.Linear( + in_features=widths[-1], out_features=num_classes, key=key_head + ) + if num_classes and num_classes > 0 + else eqx.nn.Identity() + ) + + def features( + self, + x: Float[Array, "channels height width"], + key: PRNGKeyArray, + inference: Optional[bool] = None, + **kwargs, + ) -> Float[Array, "seqlen dim"]: + """Extract features from input image. + + Args: + x: Input image tensor + inference: Whether to enable dropout during inference + key: PRNG key for random operations + + Returns: + Processed feature tensor + """ + key_stem, *key_blocks = jr.split(key, len(self.blocks) + 1) + + x = self.input_stem(x, key=key_stem) + + for i, blk in enumerate(self.blocks): + x = blk(x, inference=inference, key=key_blocks[i]) + + return x + + def __call__( + self, + x: Float[Array, "channels height width"], + key: PRNGKeyArray = jr.PRNGKey(42), + inference: Optional[bool] = None, + **kwargs, + ) -> Float[Array, "num_classes"]: + """Process input image through the full network. 
+
+        Args:
+            x: Input image tensor
+            inference: Whether to run in inference mode (disables stochastic layers)
+            key: PRNG key for random operations
+
+        Returns:
+            Classification logits
+        """
+        x = self.features(x, inference=inference, key=key, **kwargs)
+
+        x = reduce(x, "c h w -> c", "mean")
+
+        x = self.head(x)
+
+        return x
+
+
+def reduceformer_backbone_b1(**kwargs) -> ReduceFormer:
+    backbone = ReduceFormer(
+        widths=[16, 32, 64, 128, 256],
+        depths=[1, 2, 3, 3, 4],
+        block_types=[
+            "conv",
+            "conv",
+            "conv",
+            "attention",
+            "attention",
+        ],
+        heads_dim=16,
+        **kwargs,
+    )
+    return backbone
+
+
+def reduceformer_backbone_b2(**kwargs) -> ReduceFormer:
+    backbone = ReduceFormer(
+        widths=[24, 48, 96, 192, 384],
+        depths=[1, 3, 4, 4, 6],
+        block_types=[
+            "conv",
+            "conv",
+            "conv",
+            "attention",
+            "attention",
+        ],
+        heads_dim=32,
+        **kwargs,
+    )
+    return backbone
+
+
+def reduceformer_backbone_b3(**kwargs) -> ReduceFormer:
+    backbone = ReduceFormer(
+        widths=[32, 64, 128, 256, 512],
+        depths=[1, 4, 6, 6, 9],
+        block_types=[
+            "conv",
+            "conv",
+            "conv",
+            "attention",
+            "attention",
+        ],
+        heads_dim=32,
+        **kwargs,
+    )
+    return backbone
diff --git a/src/equimo/utils.py b/src/equimo/utils.py
index 88e882b..3d79d67 100644
--- a/src/equimo/utils.py
+++ b/src/equimo/utils.py
@@ -1,12 +1,12 @@
+import typing as t
 from functools import partial
 
+import equinox as eqx
 import jax
 import jax.numpy as jnp
-from jaxtyping import Array, Float
-
-import typing as t
+import jax.random as jr
 import numpy as np
-
+from jaxtyping import Array, Float
 
 _ArrayLike = t.Union[np.ndarray, jnp.ndarray]
@@ -203,3 +203,50 @@ def pool_sd(
         raise ValueError(f"Unknown pool type {pool_type}")
 
     return x
+
+
+def count_params(model: eqx.Module):
+    num_params = sum(
+        x.size for x in jax.tree_util.tree_leaves(eqx.filter(model, eqx.is_array))
+    )
+    return num_params / 1_000_000
+
+
+def cost_analysis(model: eqx.Module, input_example: Float[Array, "..."]):
+    """Estimate the memory usage, FLOPs, and #params of a model's forward pass.
+
+    This function JIT-compiles the model's forward pass for a given input,
+    retrieves the compiled cost analysis, and reports the estimated bytes
+    accessed in mebibytes (MiB), the FLOPs in GigaFLOPs (GFLOPs), and the
+    number of parameters in millions.
+
+    Args:
+        model: The Equinox model.
+        input_example: An example input tensor for the model.
+
+    Returns:
+        A dict containing the relevant information.
+ """ + key = jr.PRNGKey(42) + + @jax.jit + def fpass(x): + return model(x, inference=True, key=key) + + analysis: dict | list[dict] = fpass.lower(input_example).compile().cost_analysis() + cost_dict: dict = analysis[0] if isinstance(analysis, list) else analysis + + # Memory + memory_mib = cost_dict.get("bytes accessed", 0.0) / (1024 * 1024) + + # Flops + gflops = cost_dict.get("flops", 0.0) / 1_000_000_000 + + # Params + mparams = count_params(model) + + return { + "memory_mib": memory_mib, + "gflops": gflops, + "mparams": mparams, + } diff --git a/tests/test_models.py b/tests/test_models.py index d81310a..b2618be 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -135,3 +135,27 @@ def test_load_pretrained_model(): assert features.shape[-1] == 384 # DINOv2-S has embedding dimension of 384 assert jnp.all(jnp.isfinite(features)) # Check for NaN/Inf values + + +def test_reduceformer(): + """Test creation and inference of a ReduceFormer model.""" + key = jr.PRNGKey(42) + + x = jr.normal(key, (3, 64, 64)) + model = em.reduceformer_backbone_b1(in_channels=3, num_classes=10, key=key) + y_hat = model(x, key=key) + + assert len(y_hat) == 10 + + +def test_fused_reduceformer(): + """Test creation and inference of a ReduceFormer model with fused mbconv.""" + key = jr.PRNGKey(42) + + x = jr.normal(key, (3, 64, 64)) + model = em.reduceformer_backbone_b1( + in_channels=3, num_classes=10, fuse_mbconv=True, key=key + ) + y_hat = model(x, key=key) + + assert len(y_hat) == 10