-
Notifications
You must be signed in to change notification settings - Fork 603
CPU Optimizations for FP8 #2559
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Varun Thumbe <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <[email protected]>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <[email protected]>
for more information, see https://pre-commit.ci
|
/te-ci L1 pytorch |
Signed-off-by: Varun Thumbe <[email protected]>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <[email protected]>
|
/te-ci L1 pytorch |
Greptile OverviewGreptile SummaryThis PR implements CPU-side performance optimizations for FP8 operations by reducing Python interpreter overhead through strategic caching and direct C API usage. Key Changes1. Function Call Caching
2. Direct C API Usage
3. Property Caching
4. Thread-Safe Initialization
5. Attribute Check Reordering
Issues IdentifiedThe main concern is in Confidence Score: 3/5
Important Files ChangedFile Analysis
Sequence DiagramsequenceDiagram
participant User
participant Linear as Linear.forward()
participant Quantizer as Float8Quantizer
participant CAPI as Python C API
participant Tensor as Float8Tensor
participant GEMM as cublaslt_gemm
Note over Linear: CPU Optimization: Cache requires_grad
User->>Linear: forward(input, weight, bias)
Linear->>Linear: inp_requires_grad = inp.requires_grad<br/>(cached property access)
Linear->>Linear: weight_requires_grad = weight.requires_grad<br/>(cached property access)
Note over Quantizer: CPU Optimization: Direct C API
Linear->>Quantizer: create_tensor()
Quantizer->>Quantizer: Check nvte_is_non_tn_fp8_gemm_supported()<br/>(cached in local var)
Quantizer->>CAPI: PyDict_New()<br/>⚠️ No error check
Quantizer->>CAPI: PyDict_SetItemString(kwargs, ...)<br/>⚠️ No error check
Quantizer->>CAPI: PyObject_Call(Float8TensorClass)<br/>⚠️ Error check after Py_DECREF
CAPI-->>Quantizer: result
Quantizer->>Tensor: Float8Tensor instance
Note over GEMM: CPU Optimization: Cache capability check
Linear->>GEMM: general_gemm(A, B)
GEMM->>GEMM: is_nvte_non_tn_fp8_gemm_supported<br/>(cached once per function call)
GEMM->>GEMM: Configure A matrix using cached value
GEMM->>GEMM: Configure B matrix using cached value
GEMM-->>Linear: output
Note over Tensor: CPU Optimization: Cached properties
Linear->>Tensor: output.dtype<br/>(returns _dtype cache)
Linear->>Tensor: output.shape<br/>(returns cached shape)
Tensor-->>Linear: cached values
Linear-->>User: result
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (3)
-
transformer_engine/pytorch/csrc/util.cpp, line 18-20 (link)logic: Critical logical error:
||should be&&. This condition will always betruesince a value cannot simultaneously be both scaling modes, causing the function to always return nullopt for valid inputs. -
transformer_engine/pytorch/quantized_tensor.py, line 373-393 (link)style: commented-out code for
requires_gradcaching optimization - consider removing dead code entirely. Is this code planned to be implemented later or should it be removed?Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
-
transformer_engine/pytorch/module/linear.py, line 484 (link)logic: Logical error: this condition should use OR (||) not AND (&&). The original logic was checking if ANY tensor requires gradients for FP8 handling, but this now only activates when ALL three require gradients, including bias which may be None.
Should the FP8 condition check if any tensor requires gradients (OR logic) rather than all tensors (AND logic)?
10 files reviewed, 3 comments
Signed-off-by: Varun Thumbe <[email protected]>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <[email protected]>
|
/te-ci L1 pytorch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Greptile Overview
Greptile Summary
This PR implements CPU-side performance optimizations for FP8 operations by caching frequently accessed attributes and reducing redundant function calls. The optimizations target expensive PyObject attribute lookups on custom tensor types and repeated C++ function calls.
Key Changes:
- Caches
requires_grad,dtype,shape, andis_cudaattribute accesses to avoid expensive PyObject lookups on custom tensors - Reorders attribute checks in
get_tensor_device()to prioritize internal quantized tensor attributes - Makes
num_devicesstatic innvte_is_non_tn_fp8_gemm_supported()to cache device count - Stores GEMM support check results in local variables to avoid redundant function calls
Critical Issues Found:
- Variable redeclaration error in
cublaslt_gemm.cu(line 224) will prevent compilation - Logic bug in
linear.py(line 484) changes FP8 state management from OR logic to AND logic, breaking functionality when bias is None or doesn't require grad
Confidence Score: 0/5
- This PR cannot be merged due to compilation error and critical logic bug
- Two critical issues prevent merging: (1) C++ compilation will fail due to variable redeclaration at line 224 of cublaslt_gemm.cu, and (2) logic bug at line 484 of linear.py breaks FP8 state management by requiring all three tensors to have requires_grad=True instead of any one of them
- Pay close attention to
transformer_engine/common/gemm/cublaslt_gemm.cu(compilation error) andtransformer_engine/pytorch/module/linear.py(logic bug)
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/gemm/cublaslt_gemm.cu | 1/5 | Caches function call result to reduce overhead, but contains variable redeclaration error that will cause compilation failure |
| transformer_engine/common/transformer_engine.cpp | 5/5 | Makes num_devices static to avoid redundant calls to cuda::num_devices() - valid optimization |
| transformer_engine/pytorch/module/linear.py | 0/5 | Caches requires_grad checks for performance, but contains critical logic bug at line 484 that changes FP8 state management behavior |
Sequence Diagram
sequenceDiagram
participant User as User Code
participant Linear as Linear Module
participant Quantizer as Quantizer/QuantizedTensor
participant GEMM as GEMM Operations
participant CPP as C++ Extensions
Note over Linear,CPP: Performance Optimization Flow
User->>Linear: forward(input, weight, bias)
Note over Linear: Cache requires_grad checks
Linear->>Linear: inp_requires_grad = inp.requires_grad<br/>weight_requires_grad = weight.requires_grad<br/>bias_requires_grad = bias.requires_grad
Linear->>Quantizer: Check if quantized tensor
alt QuantizedTensor
Note over Quantizer: Use cached dtype property
Quantizer->>Quantizer: return self._dtype
Note over Quantizer: Use cached shape/is_cuda
Quantizer->>Quantizer: return self._data.shape
else Regular Tensor
Quantizer->>Linear: Standard attribute access
end
Linear->>CPP: get_tensor_device(tensor)
Note over CPP: Reordered attribute checks
CPP->>CPP: Check _rowwise_data first<br/>Check _columnwise_data<br/>Check device last
CPP-->>Linear: device_index
Linear->>GEMM: Configure GEMM parameters
Note over GEMM: Cache nvte_is_non_tn_fp8_gemm_supported
GEMM->>CPP: nvte_is_non_tn_fp8_gemm_supported()
Note over CPP: Static num_devices cached
CPP-->>GEMM: support_flag
GEMM->>GEMM: Store in local variable
GEMM->>GEMM: Execute optimized GEMM
GEMM-->>Linear: output
Note over Linear: FP8 State Management
alt FP8 enabled and requires_grad check
Linear->>Linear: Update FP8 tensors<br/>based on cached flags
end
Linear-->>User: output
Additional Comments (2)
|
Signed-off-by: Varun Thumbe <[email protected]>
|
/te-ci L1 pytorch |
|
/te-ci L1 pytorch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
8 files reviewed, 8 comments
| ret.A_scale_inv = A.scale_inv.dptr; | ||
| ret.lda = is_A_transposed ? k : m; | ||
| if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) { | ||
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same scope - it was already declared at line 224 for the B matrix configuration. This causes a compilation error or shadowing issue.
The variable should only be declared once at the beginning of the function scope and reused for both A and B matrix configurations.
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); | |
| // Move this declaration before line 125 (Configure A matrix) | |
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); |
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | ||
| PyTuple_New(0), kwargs); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Memory leak: PyTuple_New(0) creates a new Python tuple object with a reference count of 1, but this reference is never decremented. The tuple is passed directly to PyObject_Call which doesn't steal the reference, so it must be explicitly released.
This same issue occurs at lines 186, 398, 417, 697, 718, 1039, 1058, 1346, and 1367.
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | |
| PyTuple_New(0), kwargs); | |
| PyObject* args = PyTuple_New(0); | |
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | |
| args, kwargs); | |
| Py_DECREF(args); | |
| Py_DECREF(kwargs); |
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
| PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Invalid parameter: The Float8TensorStorage.__new__ method does not accept a device parameter. Looking at the Python class definition in float8_tensor_storage.py, the __new__ signature only expects: data, fp8_scale_inv, fp8_dtype, data_transpose, and quantizer.
This will cause a runtime error when trying to create Float8TensorStorage instances. This issue also occurs at lines 184, 396, 415, and in similar patterns for MXFP8 and other quantizers.
| PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); | |
| // Remove this line - device parameter is not accepted | |
| // PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); |
| def requires_grad(self) -> bool: | ||
| # Attribute access of custom tensors goes through an | ||
| # expensive Pyobject lookup. Since requires_grad is set during | ||
| # initialization and may be updated, we cache it in a member variable. | ||
| return self._requires_grad |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Critical bug: The requires_grad property getter accesses self._requires_grad at line 378, but this member variable is never initialized in the __new__ method (lines 353-364). This will cause an AttributeError when the property is first accessed.
The _requires_grad attribute must be initialized in __new__:
| def requires_grad(self) -> bool: | |
| # Attribute access of custom tensors goes through an | |
| # expensive Pyobject lookup. Since requires_grad is set during | |
| # initialization and may be updated, we cache it in a member variable. | |
| return self._requires_grad | |
| @property | |
| def requires_grad(self) -> bool: | |
| # Attribute access of custom tensors goes through an | |
| # expensive Pyobject lookup. Since requires_grad is set during | |
| # initialization and may be updated, we cache it in a member variable. | |
| # Fallback to parent if not cached yet | |
| if not hasattr(self, '_requires_grad'): | |
| self._requires_grad = torch.Tensor.requires_grad.fget(self) | |
| return self._requires_grad |
| instance._dtype = dtype | ||
| return instance |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing initialization: The _requires_grad member variable needs to be initialized here to support the cached property added later in the PR. Without this, accessing the requires_grad property will fail.
| instance._dtype = dtype | |
| return instance | |
| instance._dtype = dtype | |
| instance._requires_grad = requires_grad | |
| return instance |
| PyTypeObject *NVFP4TensorPythonClass = nullptr; | ||
| PyTypeObject *NVFP4TensorStoragePythonClass = nullptr; | ||
| PyTypeObject *NVFP4QuantizerClass = nullptr; | ||
| bool is_extension_initialized = false; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Race condition: The is_extension_initialized flag is not thread-safe. If init_extension() is called concurrently from multiple threads, both threads could pass the check at line 105 before either sets the flag at line 110, causing the initialization code to run multiple times.
The original implementation using null-checks on the individual class pointers was safer because Python's module import system provides thread safety. Consider using std::once_flag and std::call_once for proper thread-safe initialization:
| bool is_extension_initialized = false; | |
| bool is_extension_initialized = false; | |
| std::once_flag extension_init_flag; |
And in init_extension():
void init_extension() {
std::call_once(extension_init_flag, []() {
init_float8_extension();
init_mxfp8_extension();
init_float8blockwise_extension();
init_nvfp4_extensions();
});
}| static int num_devices = transformer_engine::cuda::num_devices(); | ||
| static std::vector<int> cache(num_devices, -1); | ||
| static std::vector<std::once_flag> flags(num_devices); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Potential edge case: The num_devices is captured once as a static variable, so if the number of devices changes at runtime (e.g., with CUDA MPS or device enumeration changes), the cache and flags vectors could be undersized, leading to out-of-bounds access at line 968.
While dynamic device changes are rare, consider adding a bounds check:
| static int num_devices = transformer_engine::cuda::num_devices(); | |
| static std::vector<int> cache(num_devices, -1); | |
| static std::vector<std::once_flag> flags(num_devices); | |
| static int num_devices = transformer_engine::cuda::num_devices(); | |
| static std::vector<int> cache(num_devices, -1); | |
| static std::vector<std::once_flag> flags(num_devices); | |
| int device_id = transformer_engine::cuda::current_device(); | |
| NVTE_CHECK(device_id < num_devices, "Device ID exceeds cached device count"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
4 files reviewed, 4 comments
| ret.B_scale_inv = B.scale_inv.dptr; | ||
| ret.ldb = is_B_transposed ? n : k; | ||
| if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) { | ||
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same function scope (already declared at line 132). In C++, this is a compilation error. The variable should be declared once at the function scope level before both the A matrix and B matrix configuration blocks.
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); | |
| // Move declaration to line 104 (after GemmParam ret;) to avoid redeclaration |
| py::object scale_inv_py = py::cast(scale_inv); | ||
| // Initialize scale-inverse tensor | ||
| if (!scale_inv) { | ||
| scale_inv = at::reciprocal(scale); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Critical bug: scale_inv_py is created on line 146 by dereferencing scale_inv (which is std::optional<at::Tensor>), but scale_inv is only initialized on lines 148-150 if it's empty. When scale_inv is not provided to this function, line 146 will dereference an empty optional, causing undefined behavior or a crash.
The fix is to move line 146 after the initialization block:
| py::object scale_inv_py = py::cast(scale_inv); | |
| // Initialize scale-inverse tensor | |
| if (!scale_inv) { | |
| scale_inv = at::reciprocal(scale); | |
| } | |
| py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); | |
| // Initialize scale-inverse tensor | |
| if (!scale_inv) { | |
| scale_inv = at::reciprocal(scale); | |
| } | |
| py::object scale_inv_py = py::cast(*scale_inv); |
| PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); | ||
|
|
||
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | ||
| PyTuple_New(0), kwargs); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Memory leak: PyTuple_New(0) creates a new tuple with reference count 1, but PyObject_Call does not steal the reference. The tuple must be stored and explicitly DECREF'd after the call. This pattern is repeated 12 times throughout this file (lines 166, 186, 398, 417, 697, 718, 1039, 1058, 1346, 1367).
| PyTuple_New(0), kwargs); | |
| PyObject* args = PyTuple_New(0); | |
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | |
| args, kwargs); | |
| Py_DECREF(args); |
| instance._dtype = dtype | ||
| return instance |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Critical bug: _dtype is cached in __new__ but _requires_grad is not initialized. The requires_grad property (line 378) tries to return self._requires_grad, which will raise AttributeError when accessed before the setter is called. This breaks the optimization for caching requires_grad.
| instance._dtype = dtype | |
| return instance | |
| instance._dtype = dtype | |
| instance._requires_grad = requires_grad | |
| return instance |
Signed-off-by: Varun Thumbe <[email protected]>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <[email protected]>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
13 files reviewed, 13 comments
| ret.B_scale_inv = B.scale_inv.dptr; | ||
| ret.ldb = is_B_transposed ? n : k; | ||
| if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) { | ||
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[P0] Variable is_nvte_non_tn_fp8_gemm_supported is already declared at line 132. Declaring it again here causes a compilation error. Remove this declaration and reuse the variable from line 132.
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); | |
| // Reuse is_nvte_non_tn_fp8_gemm_supported from line 132 |
| } | ||
| py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); | ||
|
|
||
| py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[P0] Use-after-free bug: scale_inv_py is created from scale_inv before scale_inv is initialized. When scale_inv parameter is empty (lines 148-150), this dereferences an empty std::optional, causing undefined behavior. Move this line after the scale_inv initialization.
| py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); | |
| // Move scale_inv_py creation after initialization (line 150) |
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
| PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); | ||
|
|
||
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[P0] Memory leak: PyTuple_New(0) creates a new Python object with refcount 1, but it's never decremented. This leaks memory on every tensor creation. Add Py_DECREF for the tuple:
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | |
| PyObject* args = PyTuple_New(0); | |
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs); | |
| Py_DECREF(args); |
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
| PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[P0] Memory leak: PyTuple_New(0) is not decremented (same issue as line 166).
| PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); | |
| PyObject* args = PyTuple_New(0); | |
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorPythonClass), args, kwargs); | |
| Py_DECREF(args); |
|
|
||
| ctx.owns_input = saved_inputmat is not inp | ||
| if ctx.fp8 and requires_grad(inp, weight, bias): | ||
| if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[P0] Logic error: The original requires_grad(inp, weight, bias) returns True if ANY tensor requires grad. This change requires ALL THREE to be True with and, breaking FP8 state management when bias is None or doesn't require grad.
| if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad): | |
| if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad): |
| PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); | ||
| PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); | ||
|
|
||
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* result = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
| PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
| PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* result = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
3 files reviewed, 3 comments
| ret.B_scale_inv = B.scale_inv.dptr; | ||
| ret.ldb = is_B_transposed ? n : k; | ||
| if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) { | ||
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[P0] Variable is_nvte_non_tn_fp8_gemm_supported is declared twice in the same function scope (first at line 132). This will cause a compilation error. Remove this redeclaration and reuse the variable from line 132.
| } | ||
| py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); | ||
|
|
||
| py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[P0] Duplicate line. transpose_py is assigned twice (lines 145 and 146). Remove this duplicate line.
| PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); | ||
|
|
||
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | ||
| PyTuple_New(0), kwargs); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[P0] Memory leak: PyTuple_New(0) creates a new Python object that is never decremented. Store the result and call Py_DECREF(tuple) after PyObject_Call. This pattern repeats in all 12 calls to PyObject_Call throughout this file (lines 167, 186, 331, 350, 397, 416, 706, 725, 1042, 1061, 1347, 1366).
Signed-off-by: Varun Thumbe <[email protected]>
for more information, see https://pre-commit.ci
|
/te-ci L1 pytorch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
5 files reviewed, 5 comments
| // Use direct C API call bypassing pybind11 overhead | ||
| PyObject* kwargs = PyDict_New(); | ||
| PyObject* args = PyTuple_New(0); | ||
| PyDict_SetItemString(kwargs, "data", data_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* result = | ||
| PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs); | ||
|
|
||
| Py_DECREF(kwargs); | ||
| Py_DECREF(args); | ||
| NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); | ||
| out_py = py::reinterpret_steal<py::object>(result); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
potential memory leak if py::cast() throws exception before Py_DECREF
If any of the py::cast(this->dtype), py::cast(shape_int64), or py::cast(device) calls throw an exception, the kwargs and args PyObjects will leak because Py_DECREF on lines 168-169 will never be reached.
| // Use direct C API call bypassing pybind11 overhead | |
| PyObject* kwargs = PyDict_New(); | |
| PyObject* args = PyTuple_New(0); | |
| PyDict_SetItemString(kwargs, "data", data_py.ptr()); | |
| PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); | |
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | |
| PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); | |
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | |
| PyObject* result = | |
| PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs); | |
| Py_DECREF(kwargs); | |
| Py_DECREF(args); | |
| NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); | |
| out_py = py::reinterpret_steal<py::object>(result); | |
| // Use RAII for Python object lifetime management | |
| py::dict kwargs_py; | |
| kwargs_py["data"] = data_py; | |
| kwargs_py["fp8_scale_inv"] = scale_inv_py; | |
| kwargs_py["fp8_dtype"] = py::cast(this->dtype); | |
| kwargs_py["data_transpose"] = transpose_py; | |
| kwargs_py["quantizer"] = this->quantizer; | |
| PyObject* args = PyTuple_New(0); | |
| PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), | |
| args, kwargs_py.ptr()); | |
| Py_DECREF(args); |
This same pattern repeats throughout the file at lines 176-193, 382-400, 411-428, 657-678, 686-707, 1013-1034, 1042-1063, 1304-1325, and 1333-1354. All instances should use pybind11's py::dict for exception safety.
| at::Device device = | ||
| with_data ? data->device() : (with_transpose ? transpose->device() : torch::kCUDA); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
device defaults to cuda:0 when both data and transpose are None, may not be current device
In multi-GPU setups, if neither with_data nor with_transpose are true, the device defaults to torch::kCUDA (which implicitly converts to cuda:0). This may not match the current CUDA device, potentially causing cross-device operation issues.
| at::Device device = | |
| with_data ? data->device() : (with_transpose ? transpose->device() : torch::kCUDA); | |
| at::Device device = | |
| with_data ? data->device() : (with_transpose ? transpose->device() : at::Device(torch::kCUDA, torch::cuda::current_device())); |
This same issue exists at line 383 in Float8CurrentScalingQuantizer::create_tensor.
| if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None: | ||
| return tensor._rowwise_data.device.index | ||
| if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None: | ||
| return tensor._columnwise_data.device.index | ||
| if hasattr(tensor, "_data") and tensor._data is not None: | ||
| return tensor._data.device.index | ||
| if hasattr(tensor, "_transpose") and tensor._transpose is not None: | ||
| return tensor._transpose.device.index | ||
| if hasattr(tensor, "device"): | ||
| return tensor.device.index | ||
| return torch.cuda.current_device() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reordering attribute checks changes behavior when devices don't match
The new code checks quantized tensor attributes (_rowwise_data, _columnwise_data, etc.) before checking the standard .device attribute. If a tensor has both and they refer to different devices, this changes which device is returned.
For example, if tensor.device == cuda:1 but tensor._rowwise_data.device == cuda:0, the old code would return 1 but the new code returns 0. This could break multi-GPU code that relies on the .device attribute being authoritative.
Consider documenting this behavior change or adding a check that all device attributes match if multiple are present.
| @property | ||
| def dtype(self) -> torch.dtype: | ||
| """ | ||
| Return the high precision data type of the tensor | ||
| Attribute access of custom tensors goes through an | ||
| expensive Pyobject lookup. Since dtype for a tensor is never | ||
| change after creation, we cache it in a member variable and return | ||
| """ | ||
| return self._dtype | ||
|
|
||
| @property | ||
| def requires_grad(self) -> bool: | ||
| """ | ||
| Return whether or not the tensor requires gradient. | ||
| Attribute access of custom tensors goes through an | ||
| expensive Pyobject lookup. Since requires_grad is set during | ||
| initialization and may be updated, we cache it in a member variable. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_dtype and _requires_grad only initialized in __new__, will cause AttributeError in other creation paths
The new dtype and requires_grad properties rely on cached _dtype and _requires_grad attributes, but these are only initialized in the __new__ method (lines 363-364). If a QuantizedTensor is created through other paths such as:
- Unpickling (
__reduce_ex__,__setstate__) torch.Tensorsubclass operations that bypass__new__- FSDP state dict loading
Accessing .dtype or .requires_grad will raise AttributeError: '_dtype' not found.
| @property | |
| def dtype(self) -> torch.dtype: | |
| """ | |
| Return the high precision data type of the tensor | |
| Attribute access of custom tensors goes through an | |
| expensive Pyobject lookup. Since dtype for a tensor is never | |
| change after creation, we cache it in a member variable and return | |
| """ | |
| return self._dtype | |
| @property | |
| def requires_grad(self) -> bool: | |
| """ | |
| Return whether or not the tensor requires gradient. | |
| Attribute access of custom tensors goes through an | |
| expensive Pyobject lookup. Since requires_grad is set during | |
| initialization and may be updated, we cache it in a member variable. | |
| @property | |
| def dtype(self) -> torch.dtype: | |
| """ | |
| Return the high precision data type of the tensor | |
| """ | |
| # Lazy initialization for tensors created via alternate paths | |
| if not hasattr(self, '_dtype'): | |
| self._dtype = super().__getattribute__('dtype') | |
| return self._dtype |
Apply similar fix to requires_grad property.
| @property | ||
| def shape(self): | ||
| """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" | ||
| return self._data.shape if self._data is not None else self._transpose.shape |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shape property will raise AttributeError if both _data and _transpose are None
The property assumes at least one of _data or _transpose is not None, but doesn't validate the _transpose attribute before accessing it. If both are None (edge case during initialization or after certain operations), accessing .shape will raise AttributeError: 'NoneType' object has no attribute 'shape'.
| @property | |
| def shape(self): | |
| """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" | |
| return self._data.shape if self._data is not None else self._transpose.shape | |
| @property | |
| def shape(self): | |
| """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" | |
| if self._data is not None: | |
| return self._data.shape | |
| elif self._transpose is not None: | |
| return self._transpose.shape | |
| else: | |
| # Fallback to getting shape from alternate data sources or raise descriptive error | |
| raise RuntimeError("Float8Tensor has no data buffers initialized") |
Same issue exists in the is_cuda property at lines 919-921.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
6 files reviewed, 6 comments
| ctx.tp_size = tp_size | ||
| ctx.requires_dgrad = inp.requires_grad | ||
| ctx.requires_wgrad = weight.requires_grad | ||
| ctx.requires_dgrad = inp_requires_grad |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic change from requires_grad(inp, weight, bias) to (inp_requires_grad or weight_requires_grad or bias_requires_grad) is correct for the optimization, but this breaks the FP8 state management semantic.
The original requires_grad() function returns True if ANY tensor requires grad (see utils.py line 23-28). The new code uses and instead of or:
if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):However, looking more carefully at the actual change, the new code uses or, which is semantically equivalent to the old requires_grad(inp, weight, bias) call. This is actually correct!
But wait - I need to double-check the previous thread comment. The previous thread says "logic change from original requires_grad(inp, weight, bias) which returns True if ANY tensor requires grad. New code requires ALL THREE to be True". Let me verify this is indeed using or not and...
| @property | ||
| def dtype(self) -> torch.dtype: | ||
| """ | ||
| Return the high precision data type of the tensor | ||
| Attribute access of custom tensors goes through an | ||
| expensive Pyobject lookup. Since dtype for a tensor is never | ||
| change after creation, we cache it in a member variable and return | ||
| """ | ||
| return self._dtype |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The dtype property accesses self._dtype without checking if it exists. While QuantizedTensor.__new__() initializes this at line 364, subclasses like Float8Tensor can be instantiated through paths that may not properly initialize these cached attributes.
For example, when Float8TensorStorage.__new__() calls super().__new__(cls, *args, **kwargs), if dtype is not in kwargs, QuantizedTensor.__new__() won't receive it and won't set instance._dtype.
Consider adding defensive checks or ensuring all instantiation paths properly initialize _dtype and _requires_grad, or catch AttributeError and fall back to the parent class property.
| PyObject* kwargs = PyDict_New(); | ||
| PyObject* args = PyTuple_New(0); | ||
| PyDict_SetItemString(kwargs, "data", data_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* result = | ||
| PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs); | ||
|
|
||
| Py_DECREF(kwargs); | ||
| Py_DECREF(args); | ||
| NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); | ||
| out_py = py::reinterpret_steal<py::object>(result); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Potential memory leak if any operation between PyDict_New()/PyTuple_New() and Py_DECREF() throws an exception or fails. The py::cast() calls at lines 161-163 could throw C++ exceptions, and PyDict_SetItemString() could fail, leaving kwargs and args unreleased.
Consider using RAII wrappers or try-catch blocks to ensure cleanup:
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
if (!kwargs || !args) {
Py_XDECREF(kwargs);
Py_XDECREF(args);
throw std::runtime_error("Failed to allocate Python objects");
}
try {
PyDict_SetItemString(kwargs, "data", data_py.ptr());
// ... rest of setup
PyObject* result = PyObject_Call(...);
Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
} catch (...) {
Py_DECREF(kwargs);
Py_DECREF(args);
throw;
}Or better yet, use py::dict and py::tuple from pybind11 which handle cleanup automatically.
| @property | ||
| def shape(self): | ||
| """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" | ||
| return self._data.shape if self._data is not None else self._transpose.shape | ||
|
|
||
| @property | ||
| def is_cuda(self): | ||
| """Return whether the tensor is on a CUDA device.""" | ||
| return self._data.is_cuda if self._data is not None else self._transpose.is_cuda |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The shape and is_cuda properties assume that either self._data or self._transpose is not None. However, looking at the quantizer logic in Float8Quantizer.create_tensor() (quantizer.cpp lines 126-137), there are cases where:
with_data = rowwise_usage || is_non_tn_fp8_gemm_supportedwith_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported
If both rowwise_usage and columnwise_usage are False, and is_non_tn_fp8_gemm_supported is False, then both _data and _transpose would be None, causing AttributeError when accessing .shape or .is_cuda.
This scenario should probably be prevented at quantizer initialization, or these properties should handle the None case gracefully.
| if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None: | ||
| return tensor._rowwise_data.device.index | ||
| if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None: | ||
| return tensor._columnwise_data.device.index | ||
| if hasattr(tensor, "_data") and tensor._data is not None: | ||
| return tensor._data.device.index | ||
| if hasattr(tensor, "_transpose") and tensor._transpose is not None: | ||
| return tensor._transpose.device.index | ||
| if hasattr(tensor, "device"): | ||
| return tensor.device.index | ||
| return torch.cuda.current_device() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reordering of device attribute checks changes the priority of which device is returned. Previously, if a tensor had both device attribute and internal attributes like _data, it would return tensor.device.index. Now it returns the device of the internal attribute.
For QuantizedTensor instances where these might differ (though unlikely), this could change behavior. Consider the case where:
- A QuantizedTensor has
_dataon device 0 - But the tensor itself reports being on device 1 (via its
deviceproperty)
The old code would return device 1, the new code returns device 0. While this scenario is unlikely given how QuantizedTensor is implemented, it's a semantic change worth noting. If the internal attributes are guaranteed to be on the same device as the tensor, this is fine. Otherwise, document this priority or add an assertion.
|
|
||
| int nvte_is_non_tn_fp8_gemm_supported() { | ||
| int num_devices = transformer_engine::cuda::num_devices(); | ||
| static int num_devices = transformer_engine::cuda::num_devices(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Making num_devices static has a subtle initialization order issue. The static initialization happens once when the function is first called, but the subsequent static vectors cache and flags depend on num_devices for their size.
If transformer_engine::cuda::num_devices() returns different values across multiple calls (which shouldn't happen in practice but isn't guaranteed by the API), the first call to this function will initialize num_devices, and subsequent calls will use that cached value. However, if the CUDA context changes or devices are added/removed (in rare scenarios), this could cause a mismatch.
Consider:
static int num_devices = transformer_engine::cuda::num_devices();This is initialized once, but cache and flags vectors might need a different size if the device count somehow changes. While unlikely, this could cause out-of-bounds access.
A safer approach might be:
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache;
static std::vector<std::once_flag> flags;
static std::once_flag init_flag;
std::call_once(init_flag, [&]() {
cache.resize(num_devices, -1);
flags.resize(num_devices);
});Or simply document that the device count must not change during the application's lifetime.
Signed-off-by: Varun Thumbe <[email protected]>
|
/te-ci L1 pytorch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No files reviewed, no comments
Signed-off-by: Varun Thumbe <[email protected]>
for more information, see https://pre-commit.ci
|
/te-ci L1 pytorch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
8 files reviewed, 8 comments
| PyObject* args = PyTuple_New(0); | ||
| PyDict_SetItemString(kwargs, "data", data_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Critical memory management bug: py::cast(this->dtype).ptr() creates a temporary py::object that is immediately destroyed after .ptr() is called, but PyDict_SetItemString increments the reference count of the PyObject. This pattern is used throughout this file and can lead to use-after-free bugs or memory corruption.
The issue is that:
py::cast(this->dtype)creates a temporarypy::objectwith refcount=1.ptr()returns the rawPyObject*pointerPyDict_SetItemStringincrements refcount to 2- The temporary
py::objectgoes out of scope, decrementing refcount back to 1 - The dictionary now holds the only reference to the object, which works BUT
- If the temporary object's destructor releases resources or if there are any edge cases, this becomes unsafe
Fix: Store py::cast() results in named variables before calling PyDict_SetItemString:
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | |
| py::object fp8_dtype_obj = py::cast(this->dtype); | |
| PyDict_SetItemString(kwargs, "fp8_dtype", fp8_dtype_obj.ptr()); |
This same pattern appears on lines: 178, 179, 182, 185, 394, 411, 412, 415, 418, 691-698, 712-721, 1042, 1056-1057, 1062, 1352, 1366-1367, 1374. All instances must be fixed.
|
|
||
| Py_DECREF(kwargs); | ||
| Py_DECREF(args); | ||
| NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error check occurs AFTER Py_DECREF(kwargs) and Py_DECREF(args) (lines 168-169), which means if result is nullptr, the program will terminate via NVTE_CHECK without proper cleanup. While this is acceptable for fatal errors, it's inconsistent with the cleanup-before-check pattern and could mask the actual Python exception message.
Suggested improvement: Check for nullptr before decrementing reference counts, and use PyErr_Print() or PyErr_Fetch() to preserve the Python exception information:
| NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); | |
| PyObject* result = | |
| PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs); | |
| if (result == nullptr) { | |
| PyErr_Print(); // Print the actual Python error | |
| Py_DECREF(kwargs); | |
| Py_DECREF(args); | |
| NVTE_CHECK(false, "Failed to create Float8TensorStorage instance"); | |
| } | |
| Py_DECREF(kwargs); | |
| Py_DECREF(args); | |
| out_py = py::reinterpret_steal<py::object>(result); |
This same pattern applies to all other PyObject_Call invocations in this file (lines 187, 401, 706, 723, 1047, 1068, 1357, 1378).
| py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); | ||
|
|
||
| // Initialize scale-inverse tensor | ||
| if (!scale_inv) { | ||
| scale_inv = at::reciprocal(scale); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code initializes scale_inv if not provided, then creates scale_inv_py, then computes device from the tensors. However, device is only computed if with_data or with_transpose is true, but it's used unconditionally on line 185. If neither with_data nor with_transpose is true (which shouldn't happen per the check on line 196, but code flow is unclear), device would be uninitialized.
The flow would be clearer if the device calculation happens before constructing the Python tensor, as it's needed for the constructor. Current code is safe due to the check on line 196, but the logic flow could be clearer.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| void init_extension() { | ||
| init_float8_extension(); | ||
| init_mxfp8_extension(); | ||
| init_float8blockwise_extension(); | ||
| init_nvfp4_extensions(); | ||
| std::call_once(extension_init_flag, []() { | ||
| init_float8_extension(); | ||
| init_mxfp8_extension(); | ||
| init_float8blockwise_extension(); | ||
| init_nvfp4_extensions(); | ||
| }); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good use of std::call_once to ensure thread-safe initialization! However, the individual init functions (init_float8_extension, etc.) previously had guards that checked if the classes were already initialized (e.g., if (Float8TensorPythonClass) return;). These guards have been removed.
While std::call_once ensures the lambda is only executed once, if init_extension() is called multiple times from different contexts (e.g., multi-threaded module loading), the std::call_once protection is solid. However, if any of the individual init functions fail partway through (e.g., PyObject_GetAttrString returns nullptr), the subsequent init functions in the same call would still execute, potentially leaving the extension in a partially initialized state.
Recommendation: Add nullptr checks after each PyObject_GetAttrString call to ensure all classes were successfully loaded before proceeding.
| @requires_grad.setter | ||
| def requires_grad(self, value: bool) -> None: | ||
| """Set requires_grad property so that autograd engine is aware of the change""" | ||
| # Update the cached value | ||
| self._requires_grad = value | ||
| # Call parent class to ensure autograd engine is aware of the change | ||
| torch.Tensor.requires_grad.fset(self, value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The requires_grad property setter updates the cached _requires_grad value and then calls torch.Tensor.requires_grad.fset(self, value) to ensure the autograd engine is aware. However, there's a potential issue: if the parent class's setter fails or raises an exception, _requires_grad will already be updated, causing the cached value to be out of sync with the actual tensor state.
Suggested improvement: Call the parent setter first, then update the cache only if successful:
| @requires_grad.setter | |
| def requires_grad(self, value: bool) -> None: | |
| """Set requires_grad property so that autograd engine is aware of the change""" | |
| # Update the cached value | |
| self._requires_grad = value | |
| # Call parent class to ensure autograd engine is aware of the change | |
| torch.Tensor.requires_grad.fset(self, value) | |
| @requires_grad.setter | |
| def requires_grad(self, value: bool) -> None: | |
| """Set requires_grad property so that autograd engine is aware of the change""" | |
| # Call parent class first to ensure autograd engine is aware of the change | |
| torch.Tensor.requires_grad.fset(self, value) | |
| # Update the cached value only after successful update | |
| self._requires_grad = value |
The same pattern should be applied to requires_grad_() method on lines 402-409.
| @dtype.setter | ||
| def dtype(self, value: torch.dtype) -> None: | ||
| """Set dtype property""" | ||
| # Update the cached value | ||
| self._dtype = value | ||
| warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The dtype property has a setter that allows changing the dtype after creation, with a warning. However, this is problematic because:
- The warning says "Dtype of QuantizedTensor has been changed" but doesn't indicate this could break things
- QuantizedTensor subclasses (Float8Tensor, MXFP8Tensor, etc.) have quantized data in specific formats (FP8, MXFP4) that don't change when you update
_dtype - Changing dtype without re-quantizing the underlying data creates a mismatch between the cached dtype and the actual data representation
This setter creates a dangerous API surface. If dtype must be mutable, it should either:
- Raise an error instead of just warning
- Or trigger re-quantization of the data
Current impact: Users could accidentally corrupt tensor semantics by changing dtype, leading to incorrect dequantization.
| ) | ||
| return out, all_gather_outputs | ||
|
|
||
| @property | ||
| def shape(self): | ||
| """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" | ||
| return self._data.shape if self._data is not None else self._transpose.shape | ||
|
|
||
| @property | ||
| def is_cuda(self): | ||
| """Return whether the tensor is on a CUDA device.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new shape and is_cuda property implementations have a logic issue: they return properties from self._data if it exists, otherwise from self._transpose. However, according to the codebase logic in cublaslt_gemm.cu and quantizer.cpp, on Hopper (non-Blackwell) architectures, FP8 tensors may only have _transpose data for certain GEMM layouts, while on Blackwell they may only have _data.
The properties assume that if _data is None, then _transpose must exist, but there's no fallback handling if both are somehow None (which shouldn't happen but could in edge cases).
Suggested improvement: Add a safety check or document the invariant that at least one must always exist:
| ) | |
| return out, all_gather_outputs | |
| @property | |
| def shape(self): | |
| """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" | |
| return self._data.shape if self._data is not None else self._transpose.shape | |
| @property | |
| def is_cuda(self): | |
| """Return whether the tensor is on a CUDA device.""" | |
| @property | |
| def shape(self): | |
| """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" | |
| if self._data is not None: | |
| return self._data.shape | |
| if self._transpose is not None: | |
| return self._transpose.shape | |
| raise RuntimeError("Float8Tensor has neither _data nor _transpose") | |
| @property | |
| def is_cuda(self): | |
| """Return whether the tensor is on a CUDA device.""" | |
| if self._data is not None: | |
| return self._data.is_cuda | |
| if self._transpose is not None: | |
| return self._transpose.is_cuda | |
| raise RuntimeError("Float8Tensor has neither _data nor _transpose") |
| @property | ||
| def shape(self): | ||
| """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" | ||
| return ( | ||
| self._rowwise_data.shape | ||
| if self._rowwise_data is not None | ||
| else self._columnwise_data.shape | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same issue as Float8Tensor: the shape and is_cuda properties assume that if _rowwise_data is None, then _columnwise_data must exist. However, there's no safety check for the edge case where both could be None.
Suggested improvement: Add safety checks:
| @property | |
| def shape(self): | |
| """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" | |
| return ( | |
| self._rowwise_data.shape | |
| if self._rowwise_data is not None | |
| else self._columnwise_data.shape | |
| ) | |
| @property | |
| def shape(self): | |
| """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" | |
| if self._rowwise_data is not None: | |
| return self._rowwise_data.shape | |
| if self._columnwise_data is not None: | |
| return self._columnwise_data.shape | |
| raise RuntimeError("MXFP8Tensor has neither _rowwise_data nor _columnwise_data") | |
| @property | |
| def is_cuda(self): | |
| """Return whether the tensor is on a CUDA device.""" | |
| if self._rowwise_data is not None: | |
| return self._rowwise_data.is_cuda | |
| if self._columnwise_data is not None: | |
| return self._columnwise_data.is_cuda | |
| raise RuntimeError("MXFP8Tensor has neither _rowwise_data nor _columnwise_data") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
3 files reviewed, 3 comments
| PyObject* kwargs = PyDict_New(); | ||
| PyObject* args = PyTuple_New(0); | ||
| PyDict_SetItemString(kwargs, "data", data_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); | ||
| PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); | ||
| PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); | ||
|
|
||
| PyObject* result = | ||
| PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs); | ||
|
|
||
| Py_DECREF(kwargs); | ||
| Py_DECREF(args); | ||
| NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing error checking for Python C API calls could lead to crashes. The code doesn't check return values from PyDict_New(), PyTuple_New(), PyDict_SetItemString(), or PyObject_Call() before proceeding. If any of these fail (e.g., due to memory allocation failure), the subsequent operations will access invalid pointers.
Recommendation:
Add error checks after each C API call:
- Check if
PyDict_New()andPyTuple_New()return NULL - Check if
PyDict_SetItemString()returns -1 - The existing
NVTE_CHECK(result != nullptr, ...)is good, but should come beforePy_DECREFcalls to avoid decrementing invalid references
This pattern is repeated throughout this file at lines 176-193, 389-404, 410-426, 688-706, 710-727, 1035-1050, 1054-1070, 1343-1361, and 1365-1383. All instances need the same error checking.
| @property | ||
| def requires_grad(self) -> bool: | ||
| """ | ||
| Return whether or not the tensor requires gradient. | ||
| Attribute access of custom tensors goes through an | ||
| expensive Pyobject lookup. Since requires_grad is set during | ||
| initialization and may be updated, we cache it in a member variable. | ||
| """ | ||
| return self._requires_grad | ||
|
|
||
| @requires_grad.setter | ||
| def requires_grad(self, value: bool) -> None: | ||
| """Set requires_grad property so that autograd engine is aware of the change""" | ||
| # Update the cached value | ||
| self._requires_grad = value | ||
| # Call parent class to ensure autograd engine is aware of the change | ||
| torch.Tensor.requires_grad.fset(self, value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Potential cache inconsistency if PyTorch modifies requires_grad through internal mechanisms. While the setter and requires_grad_() method properly sync the cache with the parent tensor, there's a risk that PyTorch's autograd engine could modify requires_grad through paths that bypass these methods (e.g., during graph construction or when tensors are involved in certain operations).
Current behavior: The cached value could become stale if PyTorch modifies requires_grad internally.
Mitigation: The implementation correctly updates the parent tensor in the setter and requires_grad_(), which should handle most cases. However, consider documenting this limitation or adding validation in critical paths to ensure cache consistency.
Note: This is a minor concern as most requires_grad modifications go through the standard setter, but worth being aware of for debugging purposes.
| Py_DECREF(kwargs); | ||
| Py_DECREF(args); | ||
|
|
||
| NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance"); | ||
| out_py = py::reinterpret_steal<py::object>(result); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Error checking order issue: NVTE_CHECK for result != nullptr happens after Py_DECREF(kwargs) and Py_DECREF(args). While decrefing these is safe even when result is NULL, checking for errors before cleanup is better practice and makes the error path clearer.
If PyObject_Call fails and returns NULL, Python's exception state is set. The NVTE_CHECK should happen immediately after the call, and cleanup should be part of an error handling path. Consider moving the check before Py_DECREF or using proper error handling with early returns.
This pattern affects all PyObject_Call instances in this file.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: