1 change: 1 addition & 0 deletions README.md
@@ -40,6 +40,7 @@ Beyond normal ViT (e.g., dinov2 or siglip), equimo proposes other SotA architect
| PartialFormer | [Efficient Vision Transformers with Partial Attention](https://eccv.ecva.net/virtual/2024/poster/1877) | 2024 | ✅ |
| SHViT | [SHViT: Single-Head Vision Transformer with Memory Efficient Macro Design](https://arxiv.org/abs/2401.16456) | 2024 | ✅ |
| VSSD | [VSSD: Vision Mamba with Non-Causal State Space Duality](https://arxiv.org/abs/2407.18559) | 2024 | ✅ |
| ReduceFormer | [ReduceFormer: Attention with Tensor Reduction by Summation](https://arxiv.org/abs/2406.07488) | 2024 | ✅ |

\*: Only contains the [Linear Angular Attention](https://github.com/clementpoiret/Equimo/blob/f8fcc79e45ca65e9deb1d970c4286c0b8562f9c2/equimo/layers/attention.py#L1407) module. It is straightforward to build a ViT around it, but may require an additional `__call__` kwarg to control the `sparse_reg` bool.

8 changes: 4 additions & 4 deletions devenv.lock
@@ -3,10 +3,10 @@
"devenv": {
"locked": {
"dir": "src/modules",
"lastModified": 1743783972,
"lastModified": 1744725539,
"owner": "cachix",
"repo": "devenv",
"rev": "2f53e2f867e0c2ba18b880e66169366e5f8ca554",
"rev": "e35cb7bb6e6424b83560b5ae0896f75263942191",
"type": "github"
},
"original": {
@@ -74,10 +74,10 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1744096231,
"lastModified": 1744536153,
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "b2b0718004cc9a5bca610326de0a82e6ea75920b",
"rev": "18dd725c29603f582cf1900e0d25f9f1063dbf11",
"type": "github"
},
"original": {
175 changes: 174 additions & 1 deletion src/equimo/layers/attention.py
@@ -1,11 +1,13 @@
from typing import Callable, List, Optional, Tuple
from typing import Callable, List, Optional, Sequence, Tuple

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
from einops import rearrange, reduce
from jaxtyping import Array, Float, PRNGKeyArray

from equimo.layers.convolution import SingleConvBlock, MBConv
from equimo.layers.dropout import DropPathAdd
from equimo.layers.ffn import Mlp
from equimo.layers.mamba import Mamba2Mixer
@@ -1565,6 +1567,177 @@ def __call__(
return x


class RFAttention(eqx.Module):
"""Attention with Tensor Reduction by Summation[1].

A ReLU Linear Attention mechanism replacing matmuls with global
summation and element-wise multiplications.

Attributes:
total_dim: Total q/k/v dimension across all heads and scales
kernel_func: Non-linearity applied to queries and keys (ReLU by default)
eps: Small constant added to the denominator for numerical stability
qkv: Pointwise convolution producing queries, keys, and values
aggreg: Per-scale depthwise convolutions aggregating the q/k/v tensor
proj: Output projection block

References:
[1]. Yang, J., An, L., & Park, S. I. (2024). ReduceFormer: Attention
with Tensor Reduction by Summation (No. arXiv:2406.07488). arXiv.
https://doi.org/10.48550/arXiv.2406.07488
"""

total_dim: int = eqx.field(static=True)
kernel_func: Callable = eqx.field(static=True)
eps: float = eqx.field(static=True)

qkv: eqx.nn.Conv2d
aggreg: list[eqx.nn.Conv2d]
proj: SingleConvBlock

def __init__(
self,
in_channels: int,
out_channels: int,
*,
key: PRNGKeyArray,
num_heads: int | None = None,
head_dim: int = 8,
heads_ratio: float = 1.0,
scales: Sequence[int] = (5,),
use_bias: bool = False,
kernel_func: Callable = jax.nn.relu,
# TODO: Benchmark against LN, RMSN, NsLN
norm_layer: eqx.Module = eqx.nn.GroupNorm,
norm_kwargs: dict = {},
eps: float = 1e-15,
**kwargs,
):
key_qkv, key_aggreg, key_proj = jr.split(key, 3)

self.kernel_func = kernel_func
self.eps = eps
num_heads = num_heads or int(in_channels // head_dim * heads_ratio)
total_dim = num_heads * head_dim
self.total_dim = total_dim * (1 + len(scales))

self.qkv = eqx.nn.Conv2d(
in_channels=in_channels,
out_channels=3 * total_dim,
kernel_size=1,
padding="SAME",
use_bias=use_bias,
key=key_qkv,
)
self.aggreg = [
eqx.nn.Conv2d(
in_channels=3 * total_dim,
out_channels=3 * total_dim,
kernel_size=scale,
padding="SAME",
groups=3 * total_dim,
key=key_aggreg,
use_bias=use_bias,
)
for scale in scales
]
# TODO: test different normalizations
self.proj = SingleConvBlock(
in_channels=self.total_dim,
out_channels=out_channels,
kernel_size=1,
use_bias=use_bias,
norm_layer=norm_layer,
norm_kwargs=norm_kwargs,
key=key_proj,
)

def __call__(
self,
x: Float[Array, "dim height width"],
key: PRNGKeyArray,
inference: Optional[bool] = None,
) -> Float[Array, "dim height width"]:
qkv_base = self.qkv(x)

aggregated_qkvs = [op(qkv_base) for op in self.aggreg]
all_qkvs = [qkv_base] + aggregated_qkvs

rearranged_qkvs = [
rearrange(qkv, "(n d) h w -> n d h w", n=3) for qkv in all_qkvs
]
multiscale_qkv = jnp.concatenate(rearranged_qkvs, axis=1)

q, k, v = multiscale_qkv

q = self.kernel_func(q)
k = self.kernel_func(k)

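# Tensor reduction by summation: global sums over the spatial axes (and, for
# the queries, the channel axis) replace the matmuls of standard linear
# attention, leaving only element-wise products.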
sum_k = jnp.sum(k, axis=(-1, -2), keepdims=True)
sum_v = jnp.sum(v, axis=(-1, -2), keepdims=True)
sum_kv = jnp.sum(k * sum_v, axis=(-1, -2), keepdims=True)
sum_q = jnp.sum(q, axis=0, keepdims=True)

out = (q * sum_kv) / (sum_q * sum_k + self.eps)
out = self.proj(out)

return out


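For reference, a minimal usage sketch of the new `RFAttention` module (not part of this diff); the channel count, feature-map size, and key handling below are illustrative assumptions:

```python
import jax.numpy as jnp
import jax.random as jr

from equimo.layers.attention import RFAttention

key_init, key_call = jr.split(jr.PRNGKey(0))

# 64 input/output channels on a 32x32 feature map (arbitrary illustrative sizes).
attn = RFAttention(in_channels=64, out_channels=64, head_dim=8, key=key_init)
x = jnp.ones((64, 32, 32))  # channels-first, no batch axis (vmap over a batch)
y = attn(x, key=key_call)   # expected shape: (64, 32, 32)
```
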
class RFAttentionBlock(eqx.Module):
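"""ReduceFormer block pairing a global RFAttention context module with a
local MBConv module, each added back to its input through a residual
connection."""
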
context_module: RFAttention
local_module: MBConv

def __init__(
self,
in_channels: int,
*,
key: PRNGKeyArray,
head_dim: int = 32,
heads_ratio: float = 1.0,
scales: Sequence[int] = (5,),
rfattn_norm_layer: eqx.Module = eqx.nn.GroupNorm,
norm_kwargs: dict = {},
expand_ratio: float = 4.0,
mbconv_norm_layers: tuple = (None, None, eqx.nn.GroupNorm),
mbconv_act_layers: tuple = (jax.nn.hard_swish, jax.nn.hard_swish, None),
fuse_mbconv: bool = False,
**kwargs,
):
key_context, key_local = jr.split(key, 2)

self.context_module = RFAttention(
in_channels=in_channels,
out_channels=in_channels,
head_dim=head_dim,
heads_ratio=heads_ratio,
scales=scales,
norm_layer=rfattn_norm_layer,
norm_kwargs=norm_kwargs,
key=key_context,
)
self.local_module = MBConv(
in_channels=in_channels,
out_channels=in_channels,
expand_ratio=expand_ratio,
norm_layers=mbconv_norm_layers,
act_layers=mbconv_act_layers,
use_bias=(True, True, False),
fuse=fuse_mbconv,
key=key_local,
)

def __call__(
self,
x: Float[Array, "dim height width"],
key: PRNGKeyArray,
inference: Optional[bool] = None,
) -> Float[Array, "dim height width"]:
key_context, key_local = jr.split(key, 2)

x += self.context_module(x, inference=inference, key=key_context)
x += self.local_module(x, inference=inference, key=key_local)

return x

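As a companion sketch (also not part of the diff, with assumed sizes; the MBConv defaults come from the constructor above), `RFAttentionBlock` can be exercised the same way:

```python
import jax.numpy as jnp
import jax.random as jr

from equimo.layers.attention import RFAttentionBlock

key_init, key_call = jr.split(jr.PRNGKey(1))

block = RFAttentionBlock(in_channels=64, key=key_init)  # head_dim defaults to 32
x = jnp.ones((64, 32, 32))
y = block(x, key=key_call)  # residual RFAttention followed by a residual MBConv
```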

def get_attention(module: str | eqx.Module) -> eqx.Module:
"""Get an `eqx.Module` from its common name.
