@@ -150,6 +150,7 @@ def __init__(
150150 num_heads : int ,
151151 mlp_hidden_ratio : float ,
152152 edge_dim : int ,
153+ attn_channels : Optional [int ] = None ,
153154 qk_norm : bool = False ,
154155 mlp_implementation : MLPImplementation = "mlp" ,
155156 cpu_offload : bool = False ,
@@ -179,6 +180,11 @@ def __init__(
179180 ratio of mlp hidden dimension to embedding dimension
180181 edge_dim : int
181182 Edge feature dimension
183+ attn_channels : int, optional
184+ Internal attention width used for q/k/v and edge projections. If
185+ None, defaults to the hidden dimension. This allows reducing the
186+ number of channels used for the attention computation without
187+ changing the width of the surrounding MLPs.
182188 qk_norm : bool, optional
183189 Whether to use query and key normalization, default False
184190 mlp_implementation: MLPImplementation
@@ -213,6 +219,7 @@ def __init__(
213219 in_channels = hidden_dim ,
214220 hidden_dim = compute_mlp_hidden_dim (hidden_dim , mlp_hidden_ratio ),
215221 out_channels = hidden_dim ,
222+ attn_channels = attn_channels ,
216223 num_heads = num_heads ,
217224 edge_dim = edge_dim ,
218225 qk_norm = qk_norm ,
@@ -507,6 +514,7 @@ def __init__(
507514 num_heads : int ,
508515 mlp_hidden_ratio : float ,
509516 edge_dim : int ,
517+ attn_channels : Optional [int ] = None ,
510518 qk_norm : bool = False ,
511519 mlp_implementation : MLPImplementation = "mlp" ,
512520 cpu_offload : bool = False ,
@@ -534,6 +542,11 @@ def __init__(
534542 ratio of mlp hidden dimension to embedding dimension
535543 edge_dim : int
536544 Edge feature dimension
545+ attn_channels : int, optional
546+ Internal attention width used for q/k/v and edge projections. If
547+ None, defaults to the hidden dimension. This allows reducing the
548+ number of channels used for the attention computation without
549+ changing the width of the surrounding MLPs.
537550 qk_norm : bool, optional
538551 Whether to use query and key normalization, default False
539552 mlp_implementation: MLPImplementation
@@ -561,6 +574,7 @@ def __init__(
561574 mlp_hidden_ratio = mlp_hidden_ratio ,
562575 edge_dim = edge_dim ,
563576 mlp_implementation = mlp_implementation ,
577+ attn_channels = attn_channels ,
564578 layer_kernels = layer_kernels ,
565579 shard_strategy = shard_strategy ,
566580 graph_attention_backend = graph_attention_backend ,
@@ -629,6 +643,7 @@ def __init__(
629643 num_heads : int ,
630644 mlp_hidden_ratio : float ,
631645 edge_dim : int ,
646+ attn_channels : Optional [int ] = None ,
632647 qk_norm : bool = False ,
633648 mlp_implementation : MLPImplementation = "mlp" ,
634649 initialise_data_extractor_zero : bool = False ,
@@ -659,6 +674,11 @@ def __init__(
659674 Ratio of mlp hidden dimension to embedding dimension
660675 edge_dim : int
661676 Edge feature dimension
677+ attn_channels : int, optional
678+ Internal attention width used for q/k/v and edge projections. If
679+ None, defaults to the hidden dimension. This allows reducing the
680+ number of channels used for the attention computation without
681+ changing the width of the surrounding MLPs.
662682 qk_norm : bool, optional
663683 Whether to use query and key normalization, default False
664684 mlp_implementation: MLPImplementation
@@ -689,6 +709,7 @@ def __init__(
689709 mlp_hidden_ratio = mlp_hidden_ratio ,
690710 edge_dim = edge_dim ,
691711 mlp_implementation = mlp_implementation ,
712+ attn_channels = attn_channels ,
692713 layer_kernels = layer_kernels ,
693714 shard_strategy = shard_strategy ,
694715 graph_attention_backend = graph_attention_backend ,
@@ -1108,6 +1129,7 @@ def __init__(
11081129 num_chunks : int ,
11091130 num_heads : int ,
11101131 mlp_hidden_ratio : float ,
1132+ attn_channels : Optional [int ] = None ,
11111133 window_size : Optional [int ] = None ,
11121134 dropout_p : float = 0.0 ,
11131135 qk_norm : bool = False ,
@@ -1133,6 +1155,11 @@ def __init__(
11331155 Output channels of the destination node, by default None
11341156 mlp_hidden_ratio: float
11351157 Ratio of mlp hidden dimension to embedding dimension
1158+ attn_channels : int, optional
1159+ Internal attention width used for q/k/v projections. If None,
1160+ defaults to the hidden dimension. This allows reducing the number
1161+ of channels used for the attention computation without changing
1162+ the width of the surrounding MLPs.
11361163 qk_norm: bool, optional
11371164 Normalize query and key, by default False
11381165 dropout_p: float, optional
@@ -1167,6 +1194,7 @@ def __init__(
11671194 self .proc = TransformerMapperBlock (
11681195 num_channels = hidden_dim ,
11691196 hidden_dim = compute_mlp_hidden_dim (hidden_dim , mlp_hidden_ratio ),
1197+ attn_channels = attn_channels ,
11701198 num_heads = num_heads ,
11711199 window_size = window_size ,
11721200 layer_kernels = self .layer_factory ,
@@ -1256,6 +1284,7 @@ def __init__(
12561284 num_chunks : int ,
12571285 num_heads : int ,
12581286 mlp_hidden_ratio : float ,
1287+ attn_channels : Optional [int ] = None ,
12591288 qk_norm : bool = False ,
12601289 dropout_p : float = 0.0 ,
12611290 mlp_implementation : MLPImplementation = "mlp" ,
@@ -1282,6 +1311,11 @@ def __init__(
12821311 Output channels of the destination node, by default None
12831312 mlp_hidden_ratio: float
12841313 Ratio of mlp hidden dimension to embedding dimension
1314+ attn_channels : int, optional
1315+ Internal attention width used for q/k/v projections. If None,
1316+ defaults to the hidden dimension. This allows reducing the number
1317+ of channels used for the attention computation without changing
1318+ the width of the surrounding MLPs.
12851319 qk_norm: bool, optional
12861320 Normalize query and key, by default False
12871321 dropout_p: float, optional
@@ -1313,6 +1347,7 @@ def __init__(
13131347 cpu_offload = cpu_offload ,
13141348 num_heads = num_heads ,
13151349 mlp_hidden_ratio = mlp_hidden_ratio ,
1350+ attn_channels = attn_channels ,
13161351 window_size = window_size ,
13171352 dropout_p = dropout_p ,
13181353 qk_norm = qk_norm ,
@@ -1384,6 +1419,7 @@ def __init__(
13841419 num_chunks : int ,
13851420 num_heads : int ,
13861421 mlp_hidden_ratio : float ,
1422+ attn_channels : Optional [int ] = None ,
13871423 qk_norm : bool = False ,
13881424 dropout_p : float = 0.0 ,
13891425 mlp_implementation : MLPImplementation = "mlp" ,
@@ -1410,6 +1446,11 @@ def __init__(
14101446 Output channels of the destination node, by default None
14111447 mlp_hidden_ratio: float
14121448 Ratio of mlp hidden dimension to embedding dimension
1449+ attn_channels : int, optional
1450+ Internal attention width used for q/k/v projections. If None,
1451+ defaults to the hidden dimension. This allows reducing the number
1452+ of channels used for the attention computation without changing
1453+ the width of the surrounding MLPs.
14131454 qk_norm: bool, optional
14141455 Normalize query and key, by default False
14151456 dropout_p: float, optional
@@ -1441,6 +1482,7 @@ def __init__(
14411482 cpu_offload = cpu_offload ,
14421483 num_heads = num_heads ,
14431484 mlp_hidden_ratio = mlp_hidden_ratio ,
1485+ attn_channels = attn_channels ,
14441486 window_size = window_size ,
14451487 dropout_p = dropout_p ,
14461488 qk_norm = qk_norm ,
0 commit comments