
Commit ff4ec6a

add docs
1 parent fb144b8 commit ff4ec6a

File tree

13 files changed (+272 additions, -18 deletions)

.github/workflows/docs.yml

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
name: Build and Deploy Documentation

on:
  # Runs on pushes to main branch
  push:
    branches: [ main ]
  # Allows manual trigger from Actions tab
  workflow_dispatch:

# Sets permissions for GITHUB_TOKEN to allow deployment to GitHub Pages
permissions:
  contents: read
  pages: write
  id-token: write

# Allow only one concurrent deployment
concurrency:
  group: "pages"
  cancel-in-progress: false

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e ".[docs]"

      - name: Build documentation
        run: mkdocs build

      - name: Setup Pages
        uses: actions/configure-pages@v5

      - name: Upload artifact
        uses: actions/upload-pages-artifact@v3
        with:
          path: ./site
          retention-days: 1

  deploy:
    environment:
      name: github-pages
      url: ${{ steps.deployment.outputs.page_url }}
    runs-on: ubuntu-latest
    needs: build
    steps:
      - name: Deploy to GitHub Pages
        id: deployment
        uses: actions/deploy-pages@v4

docs/index.md

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
# sparse-transformer-layers documentation

Welcome to the documentation for sparse-transformer-layers.

For basic information, please see the [repository Readme](https://github.com/mawright/sparse-transformer-layers).

This documentation features more detailed usage instructions for all of the Transformer layers in the library.

docs/msdeform_attn.md

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
# Sparse multi-scale deformable attention

## Overview

This implements a version of Multi-scale Deformable Attention (MSDeformAttention) adapted for sparse tensors.

---

::: blocks.ms_deform_attn.SparseMSDeformableAttentionBlock
    options:
      members:
        - forward
        - reset_parameters
      show_root_heading: true
      show_root_toc_entry: true
      show_root_full_path: false

---

::: layers.sparse_ms_deform_attn.layer.SparseMSDeformableAttention
    options:
      members:
        - forward
        - reset_parameters
      show_root_heading: true
      show_root_toc_entry: true
      show_root_full_path: false

---

## Utilities

::: layers.sparse_ms_deform_attn.utils
    options:
      members:
        - sparse_split_heads
        - multilevel_sparse_bilinear_grid_sample
      show_root_heading: false
      show_root_toc_entry: false
      show_root_full_path: false
      heading_level: 3
docs/neigh_attn.md

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# Multi-level sparse neighborhood attention

## Overview

The multi-level sparse neighborhood attention operation lets query points attend to small neighborhoods of nonzero points around their spatial positions, one neighborhood for each feature level.
This is a potentially useful alternative or complement to multi-scale deformable attention, which may try to sample from zero points on sparse tensors. The neighborhood attention operation, in contrast, always attends to all nonzero points within the given neighborhood sizes.

The neighborhood attention implementation uses a custom autograd operator that checkpoints the key and value projections of the neighborhood points and manually computes the backward pass.
This checkpointing is essential for memory management, particularly for operations with many potential query points, such as within a DETR encoder or a DETR decoder with many object queries.
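As a rough sketch of the memory-saving idea, the snippet below uses PyTorch's generic `torch.utils.checkpoint` utility in place of the library's custom operator (which additionally hand-writes the backward pass); the tensor sizes, the function and variable names, and the 164-point neighborhood count (3² + 5² + 7² + 9² with the default sizes) are illustrative placeholders, not the library's API:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

def gather_and_project(values, neighbor_indices, k_proj, v_proj):
    # Gathers neighborhood features and projects them to keys/values.
    # Under checkpointing, the large gathered/projected tensors are not
    # stored for backward; they are recomputed when gradients are needed.
    gathered = values[neighbor_indices]  # [n_queries, n_neighbors, embed_dim]
    return k_proj(gathered), v_proj(gathered)

embed_dim = 256
k_proj = nn.Linear(embed_dim, embed_dim)
v_proj = nn.Linear(embed_dim, embed_dim)
values = torch.randn(10_000, embed_dim, requires_grad=True)  # nonzero points
neighbor_indices = torch.randint(0, 10_000, (500, 164))      # 164 = 3^2+5^2+7^2+9^2

keys, vals = checkpoint(
    gather_and_project, values, neighbor_indices, k_proj, v_proj,
    use_reentrant=False,
)
```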
---

::: blocks.neighborhood_attn.SparseNeighborhoodAttentionBlock
    options:
      members:
        - forward
        - reset_parameters
      show_root_heading: true
      show_root_toc_entry: true
      show_root_full_path: false

---

::: blocks.neighborhood_attn
    options:
      members:
        - get_multilevel_neighborhoods
      show_root_heading: false
      show_root_toc_entry: false

docs/self_attn.md

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
# Multi-level sparse self-attention

## Overview

The self-attention implementation is intended for use with `torch.sparse_coo_tensor` multi-level feature maps. It uses [`RoPEEncodingND`](https://mawright.github.io/nd-rotary-encodings/layer/#position_encoding_layer.rope_encoding_layer.RoPEEncodingND) from [nd-rotary-encodings](https://github.com/mawright/nd-rotary-encodings) to encode the positions and feature levels of all input points.
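As a minimal sketch of what such an input might look like, the snippet below builds a hybrid sparse COO tensor with the [batch, height, width, level, channel] layout documented for `stacked_feature_maps` elsewhere in this commit; the nonzero sites and sizes are random placeholders:

```python
import torch

batch_size, height, width, n_levels, embed_dim = 2, 64, 64, 4, 256
nnz = 1_000  # number of nonzero points across the whole batch

# COO indices over the sparse dims (batch, y, x, level): shape [4, nnz]
indices = torch.stack([
    torch.randint(0, batch_size, (nnz,)),
    torch.randint(0, height, (nnz,)),
    torch.randint(0, width, (nnz,)),
    torch.randint(0, n_levels, (nnz,)),
])
features = torch.randn(nnz, embed_dim)  # trailing feature dim stays dense

stacked_feature_maps = torch.sparse_coo_tensor(
    indices, features, size=(batch_size, height, width, n_levels, embed_dim)
).coalesce()
```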
---

::: blocks.self_attn.MultilevelSelfAttentionBlockWithRoPE
    options:
      members:
        - forward
        - reset_parameters
      show_root_heading: true
      show_root_toc_entry: true
      show_root_full_path: false

mkdocs.yml

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
site_name: sparse-transformer-layers
theme:
  name: readthedocs
repo_url: https://github.com/mawright/sparse-transformer-layers

plugins:
  - search
  - mkdocstrings:
      handlers:
        python:
          options:
            show_source: false
          paths: [sparse_transformer_layers]

markdown_extensions:
  - toc:
      permalink: true

nav:
  - Home: index.md
  - Self-Attention: self_attn.md
  - Neighborhood Attention: neigh_attn.md
  - MSDeform Attention: msdeform_attn.md
sparse_transformer_layers/__init__.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
from .blocks import (
    MultilevelSelfAttentionBlockWithRoPE,
    SparseMSDeformableAttentionBlock,
    SparseNeighborhoodAttentionBlock,
)
from .layers import BatchSparseIndexSubsetAttention, SparseMSDeformableAttention

__all__ = [
    "BatchSparseIndexSubsetAttention",
    "SparseMSDeformableAttention",
    "SparseMSDeformableAttentionBlock",
    "SparseNeighborhoodAttentionBlock",
    "MultilevelSelfAttentionBlockWithRoPE",
]
sparse_transformer_layers/blocks/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -1,3 +1,9 @@
+from .ms_deform_attn import SparseMSDeformableAttentionBlock
+from .neighborhood_attn import SparseNeighborhoodAttentionBlock
 from .self_attn import MultilevelSelfAttentionBlockWithRoPE
 
-__all__ = ["MultilevelSelfAttentionBlockWithRoPE"]
+__all__ = [
+    "MultilevelSelfAttentionBlockWithRoPE",
+    "SparseMSDeformableAttentionBlock",
+    "SparseNeighborhoodAttentionBlock",
+]

sparse_transformer_layers/blocks/ms_deform_attn.py

Lines changed: 52 additions & 0 deletions
@@ -7,6 +7,32 @@
 
 
 class SparseMSDeformableAttentionBlock(nn.Module):
+    """A standard transformer block using Sparse Multi-Scale Deformable Attention.
+
+    This module encapsulates the `SparseMSDeformableAttention` layer within a
+    typical transformer block structure. It includes a query input projection,
+    the attention mechanism itself, an output projection with dropout, a residual
+    connection, and layer normalization. The layer normalization can be applied
+    either before (pre-norm) or after (post-norm) the main block operations.
+
+    This block is designed to be a plug-and-play component in a larger transformer
+    architecture that operates on sparse, multi-scale feature maps, such as the
+    encoder or decoder of a Deformable DETR-like model.
+
+    The current version of this module only supports spatially-2D data.
+
+    Args:
+        embed_dim (int): The embedding dimension for the queries and features.
+        n_heads (int): The number of attention heads.
+        n_levels (int): The number of feature levels to sample from.
+        n_points (int): The number of sampling points per head per level.
+        dropout (float): Dropout probability for the output projection. Defaults to 0.0.
+        bias (bool): Whether to include bias terms in the input and output
+            projection layers. Defaults to False.
+        norm_first (bool): If True, applies layer normalization before the attention
+            and projection (pre-norm). If False, applies it after the residual
+            connection (post-norm). Defaults to True.
+    """
     def __init__(
         self,
         embed_dim: int,
@@ -48,6 +74,31 @@ def forward(
         background_embedding: Optional[Tensor] = None,
         query_level_indices: Optional[Tensor] = None,
     ) -> Tensor:
+        """Forward pass for the SparseMSDeformableAttentionBlock.
+
+        Args:
+            query (Tensor): Batch-flattened query tensor of shape [n_query, embed_dim].
+            query_spatial_positions (Tensor): Spatial positions of queries,
+                shape [n_queries, 2]. The positions must be floating-point
+                values scaled to the feature level in which each query resides.
+            query_batch_offsets (Tensor): Tensor of shape [batch_size+1] indicating
+                the start and end indices for each batch item in the flattened `query`.
+            stacked_feature_maps (Tensor): A sparse tensor containing feature maps
+                from all levels, with shape [batch, height, width, levels, embed_dim].
+                The last dimension is dense; the others are sparse.
+            level_spatial_shapes (Tensor): Spatial dimensions (height, width) of each
+                feature level, shape [n_levels, 2].
+            background_embedding (Optional[Tensor]): An embedding to use for sampling
+                points that fall in unspecified regions of the sparse feature maps.
+                Shape [batch, n_levels, embed_dim].
+            query_level_indices (Optional[Tensor]): The level index for each query,
+                shape [n_queries]. If None, queries are assumed to be at the largest
+                feature level.
+
+        Returns:
+            Tensor: The output tensor after the attention block, with the same shape
+                as the input `query`, [n_query, embed_dim].
+        """
         residual = query
         if self.norm_first:
             query = self.norm(query)

@@ -74,6 +125,7 @@ def forward(
         return x
 
     def reset_parameters(self):
+        """Resets the parameters of all submodules."""
         self.norm.reset_parameters()
         self.q_in_proj.reset_parameters()
         self.msdeform_attn.reset_parameters()
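Putting the documented constructor and forward signatures together, usage might look like the following sketch. Treat it as illustrative rather than a verified API call: the positional argument order follows the docstring above, the top-level import matches the package `__init__.py` shown in this commit, and all shapes and values are made-up placeholders.

```python
import torch
from sparse_transformer_layers import SparseMSDeformableAttentionBlock

block = SparseMSDeformableAttentionBlock(
    embed_dim=256, n_heads=8, n_levels=4, n_points=4, dropout=0.1, norm_first=True
)

# Placeholder sparse feature maps: [batch, height, width, levels, embed_dim],
# with the trailing embedding dimension dense.
nnz = 1_000
indices = torch.stack([
    torch.randint(0, 2, (nnz,)),    # batch
    torch.randint(0, 64, (nnz,)),   # y
    torch.randint(0, 64, (nnz,)),   # x
    torch.randint(0, 4, (nnz,)),    # level
])
stacked_feature_maps = torch.sparse_coo_tensor(
    indices, torch.randn(nnz, 256), size=(2, 64, 64, 4, 256)
).coalesce()

n_query = 100
query = torch.randn(n_query, 256)                      # [n_query, embed_dim]
query_spatial_positions = torch.rand(n_query, 2) * 64  # scaled to each query's level
query_batch_offsets = torch.tensor([0, 50, 100])       # [batch_size + 1]
level_spatial_shapes = torch.tensor(
    [[64, 64], [32, 32], [16, 16], [8, 8]]             # [n_levels, 2]
)

out = block(
    query,
    query_spatial_positions,
    query_batch_offsets,
    stacked_feature_maps,
    level_spatial_shapes,
)  # -> [n_query, 256]
```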

sparse_transformer_layers/blocks/neighborhood_attn.py

Lines changed: 15 additions & 15 deletions
@@ -303,6 +303,7 @@ def forward(
         return x
 
     def reset_parameters(self):
+        """Initializes/resets the weights of all submodules."""
         self.norm.reset_parameters()
         self.q_in_proj.reset_parameters()
         self.subset_attn.reset_parameters()
@@ -333,21 +334,20 @@ def get_multilevel_neighborhoods(
             Default: [3, 5, 7, 9].
 
     Returns:
-        Tuple[Tensor, Tensor, Tensor]: A tuple containing:
-            - multilevel_neighborhood_indices: Tensor of shape
-                [n_queries, sum(neighborhood_sizes^position_dim), position_dim]
-                containing the spatial indices of all neighborhood points for each
-                query across all levels.
-            - out_of_bounds_mask: Boolean tensor of shape
-                [n_queries, sum(neighborhood_sizes^position_dim)] that is True at
-                locations in multilevel_neighborhood_indices that are out of bounds;
-                i.e. negative or >= the spatial shape for that level.
-                If some of the computed neighborhood indices for a query are out of
-                bounds of the level's spatial shape, those indices will instead be
-                filled with mask values of -1.
-            - level_indices: Tensor of shape [sum(neighborhood_sizes^position_dim)]
-                mapping each neighborhood position to its corresponding resolution
-                level.
+        multilevel_neighborhood_indices (Tensor): Tensor of shape
+            [n_queries, sum(neighborhood_sizes^position_dim), position_dim]
+            containing the spatial indices of all neighborhood points for each
+            query across all levels.
+        out_of_bounds_mask (Tensor): Boolean tensor of shape
+            [n_queries, sum(neighborhood_sizes^position_dim)] that is True at
+            locations in multilevel_neighborhood_indices that are out of bounds,
+            i.e., negative or >= the spatial shape for that level. Any such
+            out-of-bounds indices are instead filled with mask values of -1.
+        level_indices (Tensor): Tensor of shape [sum(neighborhood_sizes^position_dim)]
+            mapping each neighborhood position to its corresponding resolution
+            level.
 
     Raises:
         ValueError: If input tensors don't have the expected shape or dimensions, or
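To make the documented shapes concrete: with the default 2D neighborhood sizes [3, 5, 7, 9], each query gets 3² + 5² + 7² + 9² = 164 neighborhood positions. The standalone sketch below (not the library function itself) builds the per-level relative offset grids and the matching level indices with those default sizes:

```python
import torch

def neighborhood_offsets(k: int) -> torch.Tensor:
    """Relative 2D offsets of a k x k neighborhood centered on a query."""
    r = torch.arange(k) - k // 2                 # e.g. k=3 -> [-1, 0, 1]
    dy, dx = torch.meshgrid(r, r, indexing="ij")
    return torch.stack([dy, dx], dim=-1).reshape(-1, 2)  # [k*k, 2]

neighborhood_sizes = [3, 5, 7, 9]                # the documented defaults
offsets = torch.cat([neighborhood_offsets(k) for k in neighborhood_sizes])
level_indices = torch.repeat_interleave(
    torch.arange(len(neighborhood_sizes)),
    torch.tensor([k * k for k in neighborhood_sizes]),
)

print(offsets.shape)        # torch.Size([164, 2]): 9 + 25 + 49 + 81 points
print(level_indices.shape)  # torch.Size([164]), matching the documented return
```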
