
Commit 67b6743

Update FE to 1.5.2 and miscellaneous fixes (#975)
* update FE to 1.5.2
* enable unfused attn for cross attn
* unify logging info
* omit cudnn 9.1.1 and 9.2.1 due to bugs
* set cu_seqlens_padded to cu_seqlens by default
* replace variable name with ctx.variable
* Revert "enable unfused attn for cross attn" (reverts commit bc49f14)
* restrict cudnn version for fp8 tests
* remove mha_fill for FP8
* Revert "remove mha_fill for FP8" (reverts commit 83ffc44114dc6eb3d426d742b6c5a4d34805ec04)
* lower cudnn version to >=9.2.1
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7326af9 commit 67b6743

File tree

8 files changed: +66 -32 lines changed

3rdparty/cudnn-frontend

Submodule cudnn-frontend updated 113 files

tests/pytorch/fused_attn/test_fused_attn.py

Lines changed: 10 additions & 3 deletions
@@ -1270,7 +1270,7 @@ def _rmse(a, b):
     return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())


-@pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
 @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)

@@ -1445,7 +1445,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
     return out, param_names, tuple(x.grad for x in params)


-@pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
 @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)

@@ -1654,7 +1654,14 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
 models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]


-@pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.")
+@pytest.mark.skipif(
+    (
+        get_cudnn_version() < (8, 9, 3)
+        if cudnn_frontend_version == 0
+        else get_cudnn_version() < (9, 2, 1)
+    ),
+    reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
+)
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
 @pytest.mark.parametrize("dtype", param_types_fp8)

tests/pytorch/test_sanity.py

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@
     init_method_normal,
     scaled_init_method_normal,
     is_bf16_compatible,
+    get_cudnn_version,
 )
 from transformer_engine.pytorch import (
     LayerNormLinear,

@@ -1004,6 +1005,7 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):

 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
+@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
 @pytest.mark.parametrize("model", ["large"])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 def test_sanity_attention_extra_state(model, dtype):

transformer_engine/common/fused_attn/fused_attn.cpp

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
         (((cudnn_runtime_version >= 8900) && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) &&
           (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim == 64) &&
           (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) ||
-         ((cudnn_runtime_version >= 90100) && (max_seqlen_q % 128 == 0) &&
+         ((cudnn_runtime_version >= 90201) && (max_seqlen_q % 128 == 0) &&
           (max_seqlen_kv % 128 == 0) && (head_dim == 128) &&
           ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) ||
            (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) &&
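The raised integer threshold corresponds to the 9.2.1 requirement in the tests above. A small illustrative decoder, assuming cuDNN 9.x encodes its runtime version as major*10000 + minor*100 + patch (an assumption for illustration, not taken from this diff):

def decode_cudnn_runtime_version(version: int):
    # e.g. 90201 -> (9, 2, 1) and 90100 -> (9, 1, 0) under the assumed encoding
    major, rest = divmod(version, 10000)
    minor, patch = divmod(rest, 100)
    return (major, minor, patch)

assert decode_cudnn_runtime_version(90201) == (9, 2, 1)  # new threshold
assert decode_cudnn_runtime_version(90100) == (9, 1, 0)  # previous threshold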

transformer_engine/pytorch/attention.py

Lines changed: 4 additions & 3 deletions
@@ -4179,9 +4179,10 @@ def forward(
                 and cu_seqlens_q is not None
                 and cu_seqlens_kv is not None
             ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"
-            if cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None:
-                cu_seqlens_q_padded = cu_seqlens_q
-                cu_seqlens_kv_padded = cu_seqlens_kv
+
+        if cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None:
+            cu_seqlens_q_padded = cu_seqlens_q
+            cu_seqlens_kv_padded = cu_seqlens_kv

         qkv_dtype = TE_DType[query_layer.dtype]
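The moved block defaults the padded cumulative offsets to the regular cumulative sequence lengths whenever the caller does not supply them. A minimal sketch of what that fallback means when there is no padding between sequences; the tensors below are made up for illustration and are not part of the change:

import torch
import torch.nn.functional as F

seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)   # per-sequence lengths
cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0))          # cumulative offsets: [0, 3, 8, 10]

cu_seqlens_padded = None                               # caller passed no padded offsets
if cu_seqlens_padded is None:
    cu_seqlens_padded = cu_seqlens                     # same fallback as the change above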

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 16 additions & 8 deletions
@@ -4,6 +4,7 @@

 """GroupedLinear API"""
 import os
+import logging
 from typing import Union, Optional, Callable, Tuple, List, Dict, Any

 import torch

@@ -44,7 +45,16 @@
 from ..graph import is_graph_capturing
 from ..float8_tensor import Float8Tensor

+# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
+# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
+_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
+log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
+log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
+logging.basicConfig(
+    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
+    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
+)

 __all__ = ["GroupedLinear"]

@@ -95,6 +105,7 @@ def forward(
         is_grad_enabled: bool,
         *weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
     ) -> torch.Tensor:
+        logger = logging.getLogger("GroupedLinear")
         num_gemms = len(m_splits)
         weights = weights_and_biases[:num_gemms]
         weights_fp8 = weights_and_biases[num_gemms : 2 * num_gemms]

@@ -149,8 +160,7 @@
             inputmats = inputmats_no_fp8

         if fp8:
-            if _NVTE_DEBUG:
-                print("[GroupedLinear]: using FP8 forward")
+            logger.debug("Running forward in FP8")

             bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
             biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases

@@ -188,8 +198,7 @@
             # unpad the output
             out = torch.cat([o[: m_splits[i]] for i, o in enumerate(out_list)], dim=0)
         else:
-            if _NVTE_DEBUG:
-                print("[GroupedLinear]: using non-FP8 forward")
+            logger.debug("Running forward in %s", activation_dtype)

             # Cast for native AMP
             weights = [cast_if_needed(w, activation_dtype) for w in weights]

@@ -294,6 +303,7 @@

     @staticmethod
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
+        logger = logging.getLogger("GroupedLinear")

         with torch.cuda.nvtx.range("_GroupedLinear_backward"):
             (

@@ -361,8 +371,7 @@

         if ctx.requires_dgrad:
             if ctx.fp8:
-                if _NVTE_DEBUG:
-                    print("[GroupedLinear]: using FP8 backward")
+                logger.debug("Running backward in FP8")
                 dgrad_list = [
                     torch.empty(
                         (grad_output_c[i].size(0), weights_fp8[i].size(1)),

@@ -392,8 +401,7 @@
                     [d[: ctx.m_splits[i]] for i, d in enumerate(dgrad_list)], dim=0
                 )
             else:
-                if _NVTE_DEBUG:
-                    print("[GroupedLinear]: using non-FP8 backward")
+                logger.debug("Running backward in %s", ctx.activation_dtype)

                 dgrad = torch.empty(
                     (sum(ctx.m_splits), weights[0].size(1)),
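The two modules below receive the same logging change, so the environment-driven setup is shown once here as a standalone sketch that mirrors the added lines; the command in the comment is only an example invocation.

import logging
import os

# NVTE_DEBUG=1 enables debug mode; NVTE_DEBUG_LEVEL=0/1/2 raises the verbosity.
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL          # stays 0 unless both are set
log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
logging.basicConfig(
    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
)

# e.g. NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 python your_script.py emits lines such as
# [DEBUG    | GroupedLinear      ]: Running forward in FP8
logging.getLogger("GroupedLinear").debug("Running forward in FP8")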

transformer_engine/pytorch/module/layernorm_linear.py

Lines changed: 16 additions & 8 deletions
@@ -5,6 +5,7 @@
 """LayerNormLinear API"""
 import os
 import warnings
+import logging
 from typing import Any, Callable, Dict, Optional, Tuple, Union

 import torch

@@ -47,7 +48,16 @@
 from ._common import _apply_normalization, _noop_cat
 from ..float8_tensor import Float8Tensor

+# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
+# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
+_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
+log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
+log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
+logging.basicConfig(
+    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
+    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
+)

 __all__ = ["LayerNormLinear"]

@@ -94,6 +104,7 @@ def forward(
         ub_name: str,
         fsdp_group: Union[dist_group_type, None],
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
+        logger = logging.getLogger("LayerNormLinear")
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
         assert inp.shape[-1] == in_features, "GEMM not possible"

@@ -190,8 +201,7 @@
             ln_out = ln_out_total

         if fp8:
-            if _NVTE_DEBUG:
-                print("[LayerNormLinear]: using FP8 forward")
+            logger.debug("Running forward in FP8")

             bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
             bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

@@ -247,8 +257,7 @@
                 dtype=activation_dtype,
             )
         else:
-            if _NVTE_DEBUG:
-                print("[LayerNormLinear]: using non-FP8 forward")
+            logger.debug("Running forward in %s", activation_dtype)

             # Cast for native AMP
             weight = cast_if_needed(weight, activation_dtype)

@@ -370,6 +379,7 @@
     def backward(
         ctx, *grad_outputs: Tuple[torch.Tensor, ...]
     ) -> Tuple[Union[torch.Tensor, None], ...]:
+        logger = logging.getLogger("LayerNormLinear")
         if isinstance(grad_outputs[0], Float8Tensor):
             ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[
                 0

@@ -490,8 +500,7 @@
             ub_obj = None

         if ctx.fp8:
-            if _NVTE_DEBUG:
-                print("[LayerNormLinear]: using FP8 backward")
+            logger.debug("Running backward in FP8")

             fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
             fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)

@@ -535,8 +544,7 @@
             )
             clear_tensor_data(grad_output_c)
         else:
-            if _NVTE_DEBUG:
-                print("[LayerNormLinear]: using non-FP8 backward")
+            logger.debug("Running backward in %s", ctx.activation_dtype)

             # DGRAD: Evaluated unconditionally to feed into Linear backward
             _, _, _ = tex.gemm(

transformer_engine/pytorch/module/linear.py

Lines changed: 16 additions & 8 deletions
@@ -4,6 +4,7 @@

 """Linear API"""
 import os
+import logging
 from typing import Any, Callable, Dict, Optional, Tuple, Union

 import torch

@@ -50,7 +51,16 @@
 from ..graph import is_graph_capturing
 from ..float8_tensor import Float8Tensor

+# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
+# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
+_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
+log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
+log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
+logging.basicConfig(
+    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
+    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
+)

 __all__ = ["Linear"]

@@ -87,6 +97,7 @@ def forward(
         is_first_module_in_mha: bool,
         fsdp_group: Union[dist_group_type, None],
     ) -> torch.Tensor:
+        logger = logging.getLogger("Linear")
         is_input_fp8 = isinstance(inp, Float8Tensor)
         if is_input_fp8:
             fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT] = inp._scale_inv[0]

@@ -147,8 +158,7 @@
         else:
             inputmat_total = inputmat
         if fp8:
-            if _NVTE_DEBUG:
-                print("[Linear]: using FP8 forward")
+            logger.debug("Running forward in FP8")

             bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
             bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

@@ -238,8 +248,7 @@
                 dtype=activation_dtype,
             )
         else:
-            if _NVTE_DEBUG:
-                print("[Linear]: using non-FP8 forward")
+            logger.debug("Running forward in %s", activation_dtype)

             # Cast for native AMP
             weight = cast_if_needed(weight, activation_dtype)

@@ -366,6 +375,7 @@

     @staticmethod
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
+        logger = logging.getLogger("Linear")
         if isinstance(grad_output, Float8Tensor):
             ctx.fp8_meta["scaling_bwd"].scale_inv[
                 tex.FP8BwdTensors.GRAD_OUTPUT1

@@ -442,8 +452,7 @@

         if ctx.requires_dgrad:
             if ctx.fp8:
-                if _NVTE_DEBUG:
-                    print("[Linear]: using FP8 backward")
+                logger.debug("Running backward in FP8")

                 if ctx.is_input_fp8:
                     out_index, meta_tensor, output_te_dtype, output_dtype = (

@@ -487,8 +496,7 @@
                     dtype=ctx.activation_dtype,
                 )
             else:
-                if _NVTE_DEBUG:
-                    print("[Linear]: using non-FP8 backward")
+                logger.debug("Running backward in %s", ctx.activation_dtype)

                 dgrad, _, _ = gemm(
                     weight,