Skip to content

Commit 687b63e

Browse files
committed
Merge remote-tracking branch 'origin/main' into feat/swiglu_mlp
# Conflicts:
#   models/src/anemoi/models/layers/mapper.py
#   models/src/anemoi/models/layers/processor.py
#   models/src/anemoi/models/schemas/common_components.py
#   models/tests/layers/block/test_block_graphtransformer.py
2 parents 46014b0 + 1c28a92 commit 687b63e

16 files changed

Lines changed: 519 additions & 164 deletions

File tree

.github/workflows/inactivity-bot.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,19 @@ on:
55
- cron: "0 23 * * *" # every day at 23:00 on the default (main) branch
66
workflow_dispatch: # Allows manual trigger
77

8+
permissions:
9+
actions: write
10+
contents: write # only for delete-branch option
11+
issues: write
12+
pull-requests: write
13+
814
jobs:
915
stale:
1016
runs-on: ubuntu-latest
1117
steps:
1218
- uses: actions/stale@v9
1319
with:
14-
repo-token: $
20+
repo-token: ${{ secrets.GITHUB_TOKEN }}
1521

1622
# Issue settings
1723
days-before-issue-stale: 90 # (~3 months)

graphs/src/anemoi/graphs/nodes/attributes/area_weights.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def compute_latitude_weight(self, latitudes: np.ndarray) -> np.ndarray:
354354

355355

356356
class IsolatitudeAreaWeights(BaseLatWeightedAttribute):
357-
"""Latitude-weighted area weights for rectilinear grids.
357+
r"""Latitude-weighted area weights for rectilinear grids.
358358
359359
Attributes
360360
----------

models/src/anemoi/models/layers/attention.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(
5757
num_heads: int,
5858
embed_dim: int,
5959
layer_kernels: DotDict,
60+
attn_channels: Optional[int] = None,
6061
qkv_bias: bool = False,
6162
qk_norm: bool = False,
6263
is_causal: bool = False,
@@ -81,7 +82,10 @@ def __init__(
8182
num_heads : int
8283
number of heads
8384
embed_dim : int
84-
embedding dimension
85+
Input and output embedding dimension
86+
attn_channels : int, optional
87+
Internal attention width used for q/k/v projections. If None,
88+
defaults to embed_dim.
8589
qkv_bias : bool, optional
8690
bias for queries, keys and values, by default False
8791
qk_norm : bool, optional
@@ -102,16 +106,17 @@ def __init__(
102106
"""
103107
super().__init__()
104108

105-
assert (
106-
embed_dim % num_heads == 0
107-
), f"Embedding dimension ({embed_dim}) must be divisible by number of heads ({num_heads})"
109+
self.attn_channels = embed_dim if attn_channels is None else attn_channels
110+
if self.attn_channels <= 0:
111+
raise ValueError(f"attn_channels must be > 0, got {self.attn_channels}")
112+
if self.attn_channels % num_heads != 0:
113+
raise ValueError(f"attn_channels ({self.attn_channels}) must be divisible by number of heads ({num_heads})")
108114

109115
self.attention_implementation = attention_implementation
110116
self.use_alibi_slopes = use_alibi_slopes
111117

112118
self.num_heads = num_heads
113-
self.embed_dim = embed_dim
114-
self.head_dim = embed_dim // num_heads # q k v
119+
self.head_dim = self.attn_channels // num_heads # q k v
115120
self.window_size = window_size
116121
self.dropout_p = dropout_p
117122
self.is_causal = is_causal
@@ -128,11 +133,11 @@ def __init__(
128133
self.alibi_slopes = None
129134

130135
linear = layer_kernels.Linear
131-
self.lin_q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
132-
self.lin_k = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
133-
self.lin_v = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
136+
self.lin_q = nn.Linear(embed_dim, self.attn_channels, bias=qkv_bias)
137+
self.lin_k = nn.Linear(embed_dim, self.attn_channels, bias=qkv_bias)
138+
self.lin_v = nn.Linear(embed_dim, self.attn_channels, bias=qkv_bias)
134139

135-
self.projection = linear(embed_dim, embed_dim, bias=True)
140+
self.projection = linear(self.attn_channels, embed_dim, bias=True)
136141

137142
if self.qk_norm:
138143
self.q_norm = layer_kernels["QueryNorm"](self.head_dim)

models/src/anemoi/models/layers/block.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def __init__(
112112
num_heads: int,
113113
window_size: Optional[int],
114114
layer_kernels: DotDict,
115+
attn_channels: Optional[int] = None,
115116
dropout_p: float = 0.0,
116117
qk_norm: bool = False,
117118
attention_implementation: str = "flash_attention",
@@ -128,6 +129,7 @@ def __init__(
128129
self.attention = MultiHeadSelfAttention(
129130
num_heads=num_heads,
130131
embed_dim=num_channels,
132+
attn_channels=attn_channels,
131133
window_size=window_size,
132134
qkv_bias=False,
133135
is_causal=False,
@@ -186,6 +188,7 @@ def __init__(
186188
num_heads: int,
187189
window_size: Optional[int],
188190
layer_kernels: DotDict,
191+
attn_channels: Optional[int] = None,
189192
dropout_p: float = 0.0,
190193
qk_norm: bool = False,
191194
attention_implementation: str = "flash_attention",
@@ -197,6 +200,7 @@ def __init__(
197200
super().__init__(
198201
num_channels=num_channels,
199202
hidden_dim=hidden_dim,
203+
attn_channels=attn_channels,
200204
num_heads=num_heads,
201205
window_size=window_size,
202206
layer_kernels=layer_kernels,
@@ -212,6 +216,7 @@ def __init__(
212216
self.attention = MultiHeadCrossAttention(
213217
num_heads=num_heads,
214218
embed_dim=num_channels,
219+
attn_channels=attn_channels,
215220
window_size=window_size,
216221
qkv_bias=False,
217222
qk_norm=qk_norm,
@@ -462,6 +467,7 @@ def __init__(
462467
mlp_implementation: MLPImplementation = "mlp",
463468
update_src_nodes: bool = False,
464469
layer_kernels: DotDict,
470+
attn_channels: Optional[int] = None,
465471
graph_attention_backend: str = "triton",
466472
edge_pre_mlp: bool = False,
467473
**kwargs,
@@ -474,6 +480,9 @@ def __init__(
474480
Number of input channels.
475481
out_channels : int
476482
Number of output channels.
483+
attn_channels : int, optional
484+
Internal attention width used for q/k/v and edge projections. If
485+
None, defaults to out_channels.
477486
num_heads : int
478487
Number of heads
479488
edge_dim : int
@@ -496,7 +505,15 @@ def __init__(
496505

497506
self.update_src_nodes = update_src_nodes
498507

499-
self.out_channels_conv = out_channels // num_heads
508+
self.attn_channels = out_channels if attn_channels is None else attn_channels
509+
if self.attn_channels <= 0:
510+
raise ValueError(f"attn_channels must be > 0, got {self.attn_channels}")
511+
if self.attn_channels % num_heads != 0:
512+
raise ValueError(
513+
f"attn_channels ({self.attn_channels}) must be divisible by num_heads ({num_heads}) in {self.__class__.__name__}."
514+
)
515+
516+
self.out_channels_conv = self.attn_channels // num_heads
500517
self.num_heads = num_heads
501518
self.qk_norm = qk_norm
502519

@@ -508,7 +525,7 @@ def __init__(
508525
self.lin_self = Linear(in_channels, num_heads * self.out_channels_conv, bias=bias)
509526
self.lin_edge = Linear(edge_dim, num_heads * self.out_channels_conv) # , bias=False)
510527

511-
self.projection = Linear(out_channels, out_channels)
528+
self.projection = Linear(self.attn_channels, out_channels)
512529

513530
if self.qk_norm:
514531
self.q_norm = layer_kernels.QueryNorm(self.out_channels_conv)

models/src/anemoi/models/layers/mapper.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def __init__(
150150
num_heads: int,
151151
mlp_hidden_ratio: float,
152152
edge_dim: int,
153+
attn_channels: Optional[int] = None,
153154
qk_norm: bool = False,
154155
mlp_implementation: MLPImplementation = "mlp",
155156
cpu_offload: bool = False,
@@ -179,6 +180,11 @@ def __init__(
179180
ratio of mlp hidden dimension to embedding dimension
180181
edge_dim : int
181182
Edge feature dimension
183+
attn_channels : int, optional
184+
Internal attention width used for q/k/v and edge projections. If
185+
None, defaults to the hidden dimension. This allows reducing the
186+
number of channels used for the attention computation without
187+
changing the width of the surrounding MLPs.
182188
qk_norm : bool, optional
183189
Whether to use query and key normalization, default False
184190
mlp_implementation: MLPImplementation
@@ -213,6 +219,7 @@ def __init__(
213219
in_channels=hidden_dim,
214220
hidden_dim=compute_mlp_hidden_dim(hidden_dim, mlp_hidden_ratio),
215221
out_channels=hidden_dim,
222+
attn_channels=attn_channels,
216223
num_heads=num_heads,
217224
edge_dim=edge_dim,
218225
qk_norm=qk_norm,
@@ -507,6 +514,7 @@ def __init__(
507514
num_heads: int,
508515
mlp_hidden_ratio: float,
509516
edge_dim: int,
517+
attn_channels: Optional[int] = None,
510518
qk_norm: bool = False,
511519
mlp_implementation: MLPImplementation = "mlp",
512520
cpu_offload: bool = False,
@@ -534,6 +542,11 @@ def __init__(
534542
ratio of mlp hidden dimension to embedding dimension
535543
edge_dim : int
536544
Edge feature dimension
545+
attn_channels : int, optional
546+
Internal attention width used for q/k/v and edge projections. If
547+
None, defaults to the hidden dimension. This allows reducing the
548+
number of channels used for the attention computation without
549+
changing the width of the surrounding MLPs.
537550
qk_norm : bool, optional
538551
Whether to use query and key normalization, default False
539552
mlp_implementation: MLPImplementation
@@ -561,6 +574,7 @@ def __init__(
561574
mlp_hidden_ratio=mlp_hidden_ratio,
562575
edge_dim=edge_dim,
563576
mlp_implementation=mlp_implementation,
577+
attn_channels=attn_channels,
564578
layer_kernels=layer_kernels,
565579
shard_strategy=shard_strategy,
566580
graph_attention_backend=graph_attention_backend,
@@ -629,6 +643,7 @@ def __init__(
629643
num_heads: int,
630644
mlp_hidden_ratio: float,
631645
edge_dim: int,
646+
attn_channels: Optional[int] = None,
632647
qk_norm: bool = False,
633648
mlp_implementation: MLPImplementation = "mlp",
634649
initialise_data_extractor_zero: bool = False,
@@ -659,6 +674,11 @@ def __init__(
659674
Ratio of mlp hidden dimension to embedding dimension
660675
edge_dim : int
661676
Edge feature dimension
677+
attn_channels : int, optional
678+
Internal attention width used for q/k/v and edge projections. If
679+
None, defaults to the hidden dimension. This allows reducing the
680+
number of channels used for the attention computation without
681+
changing the width of the surrounding MLPs.
662682
qk_norm : bool, optional
663683
Whether to use query and key normalization, default False
664684
mlp_implementation: MLPImplementation
@@ -689,6 +709,7 @@ def __init__(
689709
mlp_hidden_ratio=mlp_hidden_ratio,
690710
edge_dim=edge_dim,
691711
mlp_implementation=mlp_implementation,
712+
attn_channels=attn_channels,
692713
layer_kernels=layer_kernels,
693714
shard_strategy=shard_strategy,
694715
graph_attention_backend=graph_attention_backend,
@@ -1108,6 +1129,7 @@ def __init__(
11081129
num_chunks: int,
11091130
num_heads: int,
11101131
mlp_hidden_ratio: float,
1132+
attn_channels: Optional[int] = None,
11111133
window_size: Optional[int] = None,
11121134
dropout_p: float = 0.0,
11131135
qk_norm: bool = False,
@@ -1133,6 +1155,11 @@ def __init__(
11331155
Output channels of the destination node, by default None
11341156
mlp_hidden_ratio: float
11351157
Ratio of mlp hidden dimension to embedding dimension
1158+
attn_channels : int, optional
1159+
Internal attention width used for q/k/v projections. If None,
1160+
defaults to the hidden dimension. This allows reducing the number
1161+
of channels used for the attention computation without changing
1162+
the width of the surrounding MLPs.
11361163
qk_norm: bool, optional
11371164
Normalize query and key, by default False
11381165
dropout_p: float, optional
@@ -1167,6 +1194,7 @@ def __init__(
11671194
self.proc = TransformerMapperBlock(
11681195
num_channels=hidden_dim,
11691196
hidden_dim=compute_mlp_hidden_dim(hidden_dim, mlp_hidden_ratio),
1197+
attn_channels=attn_channels,
11701198
num_heads=num_heads,
11711199
window_size=window_size,
11721200
layer_kernels=self.layer_factory,
@@ -1256,6 +1284,7 @@ def __init__(
12561284
num_chunks: int,
12571285
num_heads: int,
12581286
mlp_hidden_ratio: float,
1287+
attn_channels: Optional[int] = None,
12591288
qk_norm: bool = False,
12601289
dropout_p: float = 0.0,
12611290
mlp_implementation: MLPImplementation = "mlp",
@@ -1282,6 +1311,11 @@ def __init__(
12821311
Output channels of the destination node, by default None
12831312
mlp_hidden_ratio: float
12841313
Ratio of mlp hidden dimension to embedding dimension
1314+
attn_channels : int, optional
1315+
Internal attention width used for q/k/v projections. If None,
1316+
defaults to the hidden dimension. This allows reducing the number
1317+
of channels used for the attention computation without changing
1318+
the width of the surrounding MLPs.
12851319
qk_norm: bool, optional
12861320
Normalize query and key, by default False
12871321
dropout_p: float, optional
@@ -1313,6 +1347,7 @@ def __init__(
13131347
cpu_offload=cpu_offload,
13141348
num_heads=num_heads,
13151349
mlp_hidden_ratio=mlp_hidden_ratio,
1350+
attn_channels=attn_channels,
13161351
window_size=window_size,
13171352
dropout_p=dropout_p,
13181353
qk_norm=qk_norm,
@@ -1384,6 +1419,7 @@ def __init__(
13841419
num_chunks: int,
13851420
num_heads: int,
13861421
mlp_hidden_ratio: float,
1422+
attn_channels: Optional[int] = None,
13871423
qk_norm: bool = False,
13881424
dropout_p: float = 0.0,
13891425
mlp_implementation: MLPImplementation = "mlp",
@@ -1410,6 +1446,11 @@ def __init__(
14101446
Output channels of the destination node, by default None
14111447
mlp_hidden_ratio: float
14121448
Ratio of mlp hidden dimension to embedding dimension
1449+
attn_channels : int, optional
1450+
Internal attention width used for q/k/v projections. If None,
1451+
defaults to the hidden dimension. This allows reducing the number
1452+
of channels used for the attention computation without changing
1453+
the width of the surrounding MLPs.
14131454
qk_norm: bool, optional
14141455
Normalize query and key, by default False
14151456
dropout_p: float, optional
@@ -1441,6 +1482,7 @@ def __init__(
14411482
cpu_offload=cpu_offload,
14421483
num_heads=num_heads,
14431484
mlp_hidden_ratio=mlp_hidden_ratio,
1485+
attn_channels=attn_channels,
14441486
window_size=window_size,
14451487
dropout_p=dropout_p,
14461488
qk_norm=qk_norm,

models/src/anemoi/models/layers/processor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def __init__(
213213
num_chunks: int,
214214
num_heads: int,
215215
mlp_hidden_ratio: float,
216+
attn_channels: Optional[int] = None,
216217
qk_norm=False,
217218
dropout_p: float = 0.0,
218219
attention_implementation: str = "flash_attention",
@@ -238,6 +239,11 @@ def __init__(
238239
Number of heads in transformer
239240
mlp_hidden_ratio: float
240241
Ratio of mlp hidden dimension to embedding dimension
242+
attn_channels : int, optional
243+
Internal attention width used for q/k/v projections. If None,
244+
defaults to num_channels. This allows reducing the number of
245+
channels used for the attention computation without changing the
246+
width of the surrounding MLPs.
241247
qk_norm: bool, optional
242248
Normalize query and key, by default False
243249
dropout_p: float, optional
@@ -275,6 +281,7 @@ def __init__(
275281
TransformerProcessorBlock,
276282
num_channels=num_channels,
277283
hidden_dim=compute_mlp_hidden_dim(num_channels, mlp_hidden_ratio),
284+
attn_channels=attn_channels,
278285
num_heads=num_heads,
279286
qk_norm=qk_norm,
280287
window_size=window_size,
@@ -423,6 +430,7 @@ def __init__(
423430
num_heads: int,
424431
mlp_hidden_ratio: float,
425432
edge_dim: int,
433+
attn_channels: Optional[int] = None,
426434
qk_norm: bool = False,
427435
mlp_implementation: MLPImplementation = "mlp",
428436
cpu_offload: bool = False,
@@ -447,6 +455,11 @@ def __init__(
447455
Ratio of mlp hidden dimension to embedding dimension
448456
edge_dim : int
449457
Edge feature dimension
458+
attn_channels : int, optional
459+
Internal attention width used for q/k/v and edge projections. If
460+
None, defaults to num_channels. This allows reducing the number
461+
of channels used for the attention computation without changing
462+
the width of the surrounding MLPs.
450463
qk_norm: bool, optional
451464
Normalize query and key, by default False
452465
mlp_implementation: MLPImplementation
@@ -476,6 +489,7 @@ def __init__(
476489
in_channels=num_channels,
477490
hidden_dim=compute_mlp_hidden_dim(num_channels, mlp_hidden_ratio),
478491
out_channels=num_channels,
492+
attn_channels=attn_channels,
479493
num_heads=num_heads,
480494
layer_kernels=self.layer_factory,
481495
qk_norm=qk_norm,

0 commit comments

Comments
 (0)