|
4 | 4 |
|
5 | 5 | """GroupedLinear API""" |
6 | 6 | import os |
| 7 | +import logging |
7 | 8 | from typing import Union, Optional, Callable, Tuple, List, Dict, Any |
8 | 9 |
|
9 | 10 | import torch |
|
44 | 45 | from ..graph import is_graph_capturing |
45 | 46 | from ..float8_tensor import Float8Tensor |
46 | 47 |
|
| 48 | +# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 |
47 | 49 | _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) |
| 50 | +# NVTE_DEBUG_LEVEL = 0/1/2 # controls verbosity of debug output, default = 0 |
| 51 | +_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) |
| 52 | +log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL |
| 53 | +log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} |
| 54 | +logging.basicConfig( |
| 55 | + format="[%(levelname)-8s | %(name)-19s]: %(message)s", |
| 56 | + level=log_levels[log_level if log_level in [0, 1, 2] else 2], |
| 57 | +) |
48 | 58 |
|
49 | 59 | __all__ = ["GroupedLinear"] |
50 | 60 |
|
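The two environment variables introduced above are read once, at import time, so they have to be set before `transformer_engine` is imported. A minimal sketch of how a user could turn on the most verbose output (variable names come from the diff; the sample log line is illustrative only):

```python
import os

# Must be set before transformer_engine.pytorch is imported, because the
# logging level is resolved at module import time (see the hunk above).
os.environ["NVTE_DEBUG"] = "1"        # 1 = debug mode on
os.environ["NVTE_DEBUG_LEVEL"] = "2"  # 2 = most verbose (DEBUG-level messages)

import transformer_engine.pytorch as te  # noqa: E402

# With the format string configured above, GroupedLinear forward/backward
# passes emit lines roughly like (padding approximate):
#   [DEBUG    | GroupedLinear      ]: Running forward in FP8
```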
@@ -95,6 +105,7 @@ def forward( |
95 | 105 | is_grad_enabled: bool, |
96 | 106 | *weights_and_biases: Union[Float8Tensor, torch.Tensor, None], |
97 | 107 | ) -> torch.Tensor: |
| 108 | + logger = logging.getLogger("GroupedLinear") |
98 | 109 | num_gemms = len(m_splits) |
99 | 110 | weights = weights_and_biases[:num_gemms] |
100 | 111 | weights_fp8 = weights_and_biases[num_gemms : 2 * num_gemms] |
@@ -149,8 +160,7 @@ def forward( |
149 | 160 | inputmats = inputmats_no_fp8 |
150 | 161 |
|
151 | 162 | if fp8: |
152 | | - if _NVTE_DEBUG: |
153 | | - print("[GroupedLinear]: using FP8 forward") |
| 163 | + logger.debug("Running forward in FP8") |
154 | 164 |
|
155 | 165 | bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype |
156 | 166 | biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases |
@@ -188,8 +198,7 @@ def forward( |
188 | 198 | # unpad the output |
189 | 199 | out = torch.cat([o[: m_splits[i]] for i, o in enumerate(out_list)], dim=0) |
190 | 200 | else: |
191 | | - if _NVTE_DEBUG: |
192 | | - print("[GroupedLinear]: using non-FP8 forward") |
| 201 | + logger.debug("Running forward in %s", activation_dtype) |
193 | 202 |
|
194 | 203 | # Cast for native AMP |
195 | 204 | weights = [cast_if_needed(w, activation_dtype) for w in weights] |
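Because the messages now go through a named logger instead of `print`, callers can raise or lower the verbosity of just this module after import. A small sketch, assuming only the logger name `"GroupedLinear"` used in this file:

```python
import logging

# Silence GroupedLinear debug chatter without touching other loggers; the name
# matches the logging.getLogger("GroupedLinear") calls added in this diff.
logging.getLogger("GroupedLinear").setLevel(logging.WARNING)
```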
@@ -294,6 +303,7 @@ def forward( |
294 | 303 |
|
295 | 304 | @staticmethod |
296 | 305 | def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: |
| 306 | + logger = logging.getLogger("GroupedLinear") |
297 | 307 |
|
298 | 308 | with torch.cuda.nvtx.range("_GroupedLinear_backward"): |
299 | 309 | ( |
@@ -361,8 +371,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], |
361 | 371 |
|
362 | 372 | if ctx.requires_dgrad: |
363 | 373 | if ctx.fp8: |
364 | | - if _NVTE_DEBUG: |
365 | | - print("[GroupedLinear]: using FP8 backward") |
| 374 | + logger.debug("Running backward in FP8") |
366 | 375 | dgrad_list = [ |
367 | 376 | torch.empty( |
368 | 377 | (grad_output_c[i].size(0), weights_fp8[i].size(1)), |
@@ -392,8 +401,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], |
392 | 401 | [d[: ctx.m_splits[i]] for i, d in enumerate(dgrad_list)], dim=0 |
393 | 402 | ) |
394 | 403 | else: |
395 | | - if _NVTE_DEBUG: |
396 | | - print("[GroupedLinear]: using non-FP8 backward") |
| 404 | + logger.debug("Running backward in %s", ctx.activation_dtype) |
397 | 405 |
|
398 | 406 | dgrad = torch.empty( |
399 | 407 | (sum(ctx.m_splits), weights[0].size(1)), |
|
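For context on the argument packing the messages above refer to (`num_gemms = len(m_splits)`, with weights and biases flattened into `*weights_and_biases`), here is a hypothetical end-to-end usage sketch. The constructor and call signatures are assumptions for illustration, not taken from this diff:

```python
import torch
import transformer_engine.pytorch as te

# Hypothetical sketch: one GroupedLinear standing in for num_gemms separate
# Linear layers that share input/output feature sizes. Signatures are assumed.
num_gemms, in_features, out_features = 4, 256, 512
layer = te.GroupedLinear(num_gemms, in_features, out_features, bias=True)

# m_splits gives the number of rows routed to each of the num_gemms GEMMs;
# the input is the row-wise concatenation of all groups (sum(m_splits) rows).
m_splits = [128, 32, 64, 96]
inp = torch.randn(sum(m_splits), in_features, device="cuda")

out = layer(inp, m_splits)  # one GEMM per split; hits the non-FP8 debug path
print(out.shape)            # torch.Size([320, 512])

# Wrapping the call in fp8_autocast instead would exercise the
# "Running forward in FP8" branch shown in the diff:
# with te.fp8_autocast(enabled=True):
#     out = layer(inp, m_splits)
```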