Skip to content

Conversation

@vthumbe1503
Copy link
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@vthumbe1503
Copy link
Collaborator Author

/te-ci L1 pytorch

@vthumbe1503 vthumbe1503 marked this pull request as ready for review January 7, 2026 17:22
@vthumbe1503
Copy link
Collaborator Author

/te-ci L1 pytorch

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 7, 2026

Greptile Overview

Greptile Summary

This PR implements CPU-side performance optimizations for FP8 operations by reducing Python interpreter overhead through strategic caching and direct C API usage.

Key Changes

1. Function Call Caching

  • Caches nvte_is_non_tn_fp8_gemm_supported() results in cublaslt_gemm.cu to avoid repeated expensive capability checks during GEMM setup
  • Makes num_devices static in transformer_engine.cpp to cache device count (note: won't detect GPU hot-plugging, but acceptable for typical scenarios)
  • Caches requires_grad attribute lookups in linear.py to prevent redundant PyObject attribute access

2. Direct C API Usage

  • Replaces pybind11's keyword argument construction with direct PyDict_New() and PyDict_SetItemString() operations in quantizer.cpp across all quantizer types (Float8, Float8CurrentScaling, Float8Block, MXFP8, NVFP4)
  • Reduces object creation overhead by bypassing pybind11's higher-level abstractions

3. Property Caching

  • Adds cached dtype and requires_grad properties in QuantizedTensor base class to avoid expensive PyObject lookups
  • Implements proper setters to maintain consistency with PyTorch's autograd engine
  • Extends caching to shape and is_cuda properties in Float8Tensor and MXFP8Tensor

4. Thread-Safe Initialization

  • Uses std::call_once for extension initialization in pybind.cpp to ensure thread-safe, one-time initialization

5. Attribute Check Reordering

  • Reorders device attribute checks in gemm.py to prioritize internal tensor data attributes before falling back to .device, optimizing for quantized tensors

Issues Identified

The main concern is in quantizer.cpp where the direct C API calls lack proper error checking. Missing validation of return values from PyDict_New(), PyTuple_New(), PyDict_SetItemString(), and improper error handling order for PyObject_Call() could lead to crashes if memory allocation fails or Python exceptions occur.

Confidence Score: 3/5

  • This PR has effective optimizations but critical error handling gaps in C API usage that need addressing before merge
  • Score reflects solid optimization strategy and mostly correct implementation, but penalized for missing error checking in quantizer.cpp C API calls. The lack of validation for PyDict_New, PyTuple_New, and PyDict_SetItemString return values, plus improper error handling order, poses crash risks in low-memory conditions. Most other changes (caching, thread-safe init, property optimization) are well-implemented and safe.
  • transformer_engine/pytorch/csrc/quantizer.cpp requires error checking additions for all Python C API calls (approximately 9 locations throughout the file)

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/transformer_engine.cpp 4/5 Makes num_devices static to cache device count. Minor concern: won't detect GPU hot-plugging, but acceptable for typical training scenarios.
transformer_engine/pytorch/csrc/quantizer.cpp 3/5 Replaces pybind11 keyword construction with direct C API (PyDict_SetItemString, PyObject_Call). Reduces overhead but lacks error checking for PyDict operations which could lead to crashes if memory allocation fails.
transformer_engine/pytorch/module/linear.py 5/5 Caches requires_grad attribute lookups to avoid repeated PyObject attribute access. Logic is correct - properly uses OR operation for fp8 state management.
transformer_engine/pytorch/quantized_tensor.py 4/5 Adds property caching for dtype and requires_grad to avoid expensive PyObject lookups. Includes proper setters to maintain autograd consistency. Minor risk if PyTorch modifies requires_grad through non-standard paths.

Sequence Diagram

sequenceDiagram
    participant User
    participant Linear as Linear.forward()
    participant Quantizer as Float8Quantizer
    participant CAPI as Python C API
    participant Tensor as Float8Tensor
    participant GEMM as cublaslt_gemm
    
    Note over Linear: CPU Optimization: Cache requires_grad
    User->>Linear: forward(input, weight, bias)
    Linear->>Linear: inp_requires_grad = inp.requires_grad<br/>(cached property access)
    Linear->>Linear: weight_requires_grad = weight.requires_grad<br/>(cached property access)
    
    Note over Quantizer: CPU Optimization: Direct C API
    Linear->>Quantizer: create_tensor()
    Quantizer->>Quantizer: Check nvte_is_non_tn_fp8_gemm_supported()<br/>(cached in local var)
    Quantizer->>CAPI: PyDict_New()<br/>⚠️ No error check
    Quantizer->>CAPI: PyDict_SetItemString(kwargs, ...)<br/>⚠️ No error check
    Quantizer->>CAPI: PyObject_Call(Float8TensorClass)<br/>⚠️ Error check after Py_DECREF
    CAPI-->>Quantizer: result
    Quantizer->>Tensor: Float8Tensor instance
    
    Note over GEMM: CPU Optimization: Cache capability check
    Linear->>GEMM: general_gemm(A, B)
    GEMM->>GEMM: is_nvte_non_tn_fp8_gemm_supported<br/>(cached once per function call)
    GEMM->>GEMM: Configure A matrix using cached value
    GEMM->>GEMM: Configure B matrix using cached value
    GEMM-->>Linear: output
    
    Note over Tensor: CPU Optimization: Cached properties
    Linear->>Tensor: output.dtype<br/>(returns _dtype cache)
    Linear->>Tensor: output.shape<br/>(returns cached shape)
    Tensor-->>Linear: cached values
    Linear-->>User: result
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (3)

  1. transformer_engine/pytorch/csrc/util.cpp, line 18-20 (link)

    logic: Critical logical error: || should be &&. This condition will always betruesince a value cannot simultaneously be both scaling modes, causing the function to always return nullopt for valid inputs.

  2. transformer_engine/pytorch/quantized_tensor.py, line 373-393 (link)

    style: commented-out code for requires_grad caching optimization - consider removing dead code entirely. Is this code planned to be implemented later or should it be removed?

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

  3. transformer_engine/pytorch/module/linear.py, line 484 (link)

    logic: Logical error: this condition should use OR (||) not AND (&&). The original logic was checking if ANY tensor requires gradients for FP8 handling, but this now only activates when ALL three require gradients, including bias which may be None.

    Should the FP8 condition check if any tensor requires gradients (OR logic) rather than all tensors (AND logic)?

10 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

Signed-off-by: Varun Thumbe <[email protected]>
…ormerEngine into cpu_fp8_optimizations

Signed-off-by: Varun Thumbe <[email protected]>
@vthumbe1503
Copy link
Collaborator Author

/te-ci L1 pytorch

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This PR implements CPU-side performance optimizations for FP8 operations by caching frequently accessed attributes and reducing redundant function calls. The optimizations target expensive PyObject attribute lookups on custom tensor types and repeated C++ function calls.

Key Changes:

  • Caches requires_grad, dtype, shape, and is_cuda attribute accesses to avoid expensive PyObject lookups on custom tensors
  • Reorders attribute checks in get_tensor_device() to prioritize internal quantized tensor attributes
  • Makes num_devices static in nvte_is_non_tn_fp8_gemm_supported() to cache device count
  • Stores GEMM support check results in local variables to avoid redundant function calls

Critical Issues Found:

  • Variable redeclaration error in cublaslt_gemm.cu (line 224) will prevent compilation
  • Logic bug in linear.py (line 484) changes FP8 state management from OR logic to AND logic, breaking functionality when bias is None or doesn't require grad

Confidence Score: 0/5

  • This PR cannot be merged due to compilation error and critical logic bug
  • Two critical issues prevent merging: (1) C++ compilation will fail due to variable redeclaration at line 224 of cublaslt_gemm.cu, and (2) logic bug at line 484 of linear.py breaks FP8 state management by requiring all three tensors to have requires_grad=True instead of any one of them
  • Pay close attention to transformer_engine/common/gemm/cublaslt_gemm.cu (compilation error) and transformer_engine/pytorch/module/linear.py (logic bug)

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/gemm/cublaslt_gemm.cu 1/5 Caches function call result to reduce overhead, but contains variable redeclaration error that will cause compilation failure
transformer_engine/common/transformer_engine.cpp 5/5 Makes num_devices static to avoid redundant calls to cuda::num_devices() - valid optimization
transformer_engine/pytorch/module/linear.py 0/5 Caches requires_grad checks for performance, but contains critical logic bug at line 484 that changes FP8 state management behavior

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Linear as Linear Module
    participant Quantizer as Quantizer/QuantizedTensor
    participant GEMM as GEMM Operations
    participant CPP as C++ Extensions

    Note over Linear,CPP: Performance Optimization Flow
    
    User->>Linear: forward(input, weight, bias)
    
    Note over Linear: Cache requires_grad checks
    Linear->>Linear: inp_requires_grad = inp.requires_grad<br/>weight_requires_grad = weight.requires_grad<br/>bias_requires_grad = bias.requires_grad
    
    Linear->>Quantizer: Check if quantized tensor
    alt QuantizedTensor
        Note over Quantizer: Use cached dtype property
        Quantizer->>Quantizer: return self._dtype
        Note over Quantizer: Use cached shape/is_cuda
        Quantizer->>Quantizer: return self._data.shape
    else Regular Tensor
        Quantizer->>Linear: Standard attribute access
    end
    
    Linear->>CPP: get_tensor_device(tensor)
    Note over CPP: Reordered attribute checks
    CPP->>CPP: Check _rowwise_data first<br/>Check _columnwise_data<br/>Check device last
    CPP-->>Linear: device_index
    
    Linear->>GEMM: Configure GEMM parameters
    Note over GEMM: Cache nvte_is_non_tn_fp8_gemm_supported
    GEMM->>CPP: nvte_is_non_tn_fp8_gemm_supported()
    Note over CPP: Static num_devices cached
    CPP-->>GEMM: support_flag
    GEMM->>GEMM: Store in local variable
    
    GEMM->>GEMM: Execute optimized GEMM
    GEMM-->>Linear: output
    
    Note over Linear: FP8 State Management
    alt FP8 enabled and requires_grad check
        Linear->>Linear: Update FP8 tensors<br/>based on cached flags
    end
    
    Linear-->>User: output
Loading

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 7, 2026

Additional Comments (2)

transformer_engine/common/gemm/cublaslt_gemm.cu
variable redeclared in same scope - already declared at line 132

    // int is_nvte_non_tn_fp8_gemm_supported already declared at line 132

transformer_engine/pytorch/module/linear.py
logic change from original requires_grad(inp, weight, bias) which returns True if ANY tensor requires grad. New code requires ALL THREE to be True, breaking FP8 state management when bias is None or doesn't require grad

            if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):

Signed-off-by: Varun Thumbe <[email protected]>
@vthumbe1503
Copy link
Collaborator Author

/te-ci L1 pytorch

@vthumbe1503
Copy link
Collaborator Author

/te-ci L1 pytorch

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

8 files reviewed, 8 comments

Edit Code Review Agent Settings | Greptile

ret.A_scale_inv = A.scale_inv.dptr;
ret.lda = is_A_transposed ? k : m;
if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same scope - it was already declared at line 224 for the B matrix configuration. This causes a compilation error or shadowing issue.

The variable should only be declared once at the beginning of the function scope and reused for both A and B matrix configurations.

Suggested change
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
// Move this declaration before line 125 (Configure A matrix)
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();

Comment on lines 165 to 166
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyTuple_New(0), kwargs);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Memory leak: PyTuple_New(0) creates a new Python tuple object with a reference count of 1, but this reference is never decremented. The tuple is passed directly to PyObject_Call which doesn't steal the reference, so it must be explicitly released.

This same issue occurs at lines 186, 398, 417, 697, 718, 1039, 1058, 1346, and 1367.

Suggested change
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyTuple_New(0), kwargs);
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
args, kwargs);
Py_DECREF(args);
Py_DECREF(kwargs);

PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Invalid parameter: The Float8TensorStorage.__new__ method does not accept a device parameter. Looking at the Python class definition in float8_tensor_storage.py, the __new__ signature only expects: data, fp8_scale_inv, fp8_dtype, data_transpose, and quantizer.

This will cause a runtime error when trying to create Float8TensorStorage instances. This issue also occurs at lines 184, 396, 415, and in similar patterns for MXFP8 and other quantizers.

Suggested change
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
// Remove this line - device parameter is not accepted
// PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

Comment on lines 374 to 378
def requires_grad(self) -> bool:
# Attribute access of custom tensors goes through an
# expensive Pyobject lookup. Since requires_grad is set during
# initialization and may be updated, we cache it in a member variable.
return self._requires_grad
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical bug: The requires_grad property getter accesses self._requires_grad at line 378, but this member variable is never initialized in the __new__ method (lines 353-364). This will cause an AttributeError when the property is first accessed.

The _requires_grad attribute must be initialized in __new__:

Suggested change
def requires_grad(self) -> bool:
# Attribute access of custom tensors goes through an
# expensive Pyobject lookup. Since requires_grad is set during
# initialization and may be updated, we cache it in a member variable.
return self._requires_grad
@property
def requires_grad(self) -> bool:
# Attribute access of custom tensors goes through an
# expensive Pyobject lookup. Since requires_grad is set during
# initialization and may be updated, we cache it in a member variable.
# Fallback to parent if not cached yet
if not hasattr(self, '_requires_grad'):
self._requires_grad = torch.Tensor.requires_grad.fget(self)
return self._requires_grad

Comment on lines +363 to 364
instance._dtype = dtype
return instance
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing initialization: The _requires_grad member variable needs to be initialized here to support the cached property added later in the PR. Without this, accessing the requires_grad property will fail.

Suggested change
instance._dtype = dtype
return instance
instance._dtype = dtype
instance._requires_grad = requires_grad
return instance

PyTypeObject *NVFP4TensorPythonClass = nullptr;
PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
PyTypeObject *NVFP4QuantizerClass = nullptr;
bool is_extension_initialized = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Race condition: The is_extension_initialized flag is not thread-safe. If init_extension() is called concurrently from multiple threads, both threads could pass the check at line 105 before either sets the flag at line 110, causing the initialization code to run multiple times.

The original implementation using null-checks on the individual class pointers was safer because Python's module import system provides thread safety. Consider using std::once_flag and std::call_once for proper thread-safe initialization:

Suggested change
bool is_extension_initialized = false;
bool is_extension_initialized = false;
std::once_flag extension_init_flag;

And in init_extension():

void init_extension() {
  std::call_once(extension_init_flag, []() {
    init_float8_extension();
    init_mxfp8_extension();
    init_float8blockwise_extension();
    init_nvfp4_extensions();
  });
}

Comment on lines +964 to 966
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential edge case: The num_devices is captured once as a static variable, so if the number of devices changes at runtime (e.g., with CUDA MPS or device enumeration changes), the cache and flags vectors could be undersized, leading to out-of-bounds access at line 968.

While dynamic device changes are rare, consider adding a bounds check:

Suggested change
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
int device_id = transformer_engine::cuda::current_device();
NVTE_CHECK(device_id < num_devices, "Device ID exceeds cached device count");

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4 files reviewed, 4 comments

Edit Code Review Agent Settings | Greptile

ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same function scope (already declared at line 132). In C++, this is a compilation error. The variable should be declared once at the function scope level before both the A matrix and B matrix configuration blocks.

Suggested change
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
// Move declaration to line 104 (after GemmParam ret;) to avoid redeclaration

Comment on lines 146 to 150
py::object scale_inv_py = py::cast(scale_inv);
// Initialize scale-inverse tensor
if (!scale_inv) {
scale_inv = at::reciprocal(scale);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical bug: scale_inv_py is created on line 146 by dereferencing scale_inv (which is std::optional<at::Tensor>), but scale_inv is only initialized on lines 148-150 if it's empty. When scale_inv is not provided to this function, line 146 will dereference an empty optional, causing undefined behavior or a crash.

The fix is to move line 146 after the initialization block:

Suggested change
py::object scale_inv_py = py::cast(scale_inv);
// Initialize scale-inverse tensor
if (!scale_inv) {
scale_inv = at::reciprocal(scale);
}
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
// Initialize scale-inverse tensor
if (!scale_inv) {
scale_inv = at::reciprocal(scale);
}
py::object scale_inv_py = py::cast(*scale_inv);

PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyTuple_New(0), kwargs);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Memory leak: PyTuple_New(0) creates a new tuple with reference count 1, but PyObject_Call does not steal the reference. The tuple must be stored and explicitly DECREF'd after the call. This pattern is repeated 12 times throughout this file (lines 166, 186, 398, 417, 697, 718, 1039, 1058, 1346, 1367).

Suggested change
PyTuple_New(0), kwargs);
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
args, kwargs);
Py_DECREF(args);

Comment on lines +363 to 364
instance._dtype = dtype
return instance
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical bug: _dtype is cached in __new__ but _requires_grad is not initialized. The requires_grad property (line 378) tries to return self._requires_grad, which will raise AttributeError when accessed before the setter is called. This breaks the optimization for caching requires_grad.

Suggested change
instance._dtype = dtype
return instance
instance._dtype = dtype
instance._requires_grad = requires_grad
return instance

vthumbe1503 and others added 3 commits January 11, 2026 19:12
Signed-off-by: Varun Thumbe <[email protected]>
…ormerEngine into cpu_fp8_optimizations

Signed-off-by: Varun Thumbe <[email protected]>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <[email protected]>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

13 files reviewed, 13 comments

Edit Code Review Agent Settings | Greptile

ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Variable is_nvte_non_tn_fp8_gemm_supported is already declared at line 132. Declaring it again here causes a compilation error. Remove this declaration and reuse the variable from line 132.

Suggested change
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
// Reuse is_nvte_non_tn_fp8_gemm_supported from line 132

}
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();

py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Use-after-free bug: scale_inv_py is created from scale_inv before scale_inv is initialized. When scale_inv parameter is empty (lines 148-150), this dereferences an empty std::optional, causing undefined behavior. Move this line after the scale_inv initialization.

Suggested change
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
// Move scale_inv_py creation after initialization (line 150)

PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Memory leak: PyTuple_New(0) creates a new Python object with refcount 1, but it's never decremented. This leaks memory on every tensor creation. Add Py_DECREF for the tuple:

Suggested change
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);
Py_DECREF(args);

PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Memory leak: PyTuple_New(0) is not decremented (same issue as line 166).

Suggested change
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorPythonClass), args, kwargs);
Py_DECREF(args);


ctx.owns_input = saved_inputmat is not inp
if ctx.fp8 and requires_grad(inp, weight, bias):
if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Logic error: The original requires_grad(inp, weight, bias) returns True if ANY tensor requires grad. This change requires ALL THREE to be True with and, breaking FP8 state management when bias is None or doesn't require grad.

Suggested change
if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):
if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):

PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr());
PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).

PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).

PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).

PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).

PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Variable is_nvte_non_tn_fp8_gemm_supported is declared twice in the same function scope (first at line 132). This will cause a compilation error. Remove this redeclaration and reuse the variable from line 132.

}
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();

py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Duplicate line. transpose_py is assigned twice (lines 145 and 146). Remove this duplicate line.

PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyTuple_New(0), kwargs);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Memory leak: PyTuple_New(0) creates a new Python object that is never decremented. Store the result and call Py_DECREF(tuple) after PyObject_Call. This pattern repeats in all 12 calls to PyObject_Call throughout this file (lines 167, 186, 331, 350, 397, 416, 706, 725, 1042, 1061, 1347, 1366).

@vthumbe1503
Copy link
Collaborator Author

/te-ci L1 pytorch

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5 files reviewed, 5 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +156 to +171
// Use direct C API call bypassing pybind11 overhead
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

potential memory leak if py::cast() throws exception before Py_DECREF

If any of the py::cast(this->dtype), py::cast(shape_int64), or py::cast(device) calls throw an exception, the kwargs and args PyObjects will leak because Py_DECREF on lines 168-169 will never be reached.

Suggested change
// Use direct C API call bypassing pybind11 overhead
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);
Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
// Use RAII for Python object lifetime management
py::dict kwargs_py;
kwargs_py["data"] = data_py;
kwargs_py["fp8_scale_inv"] = scale_inv_py;
kwargs_py["fp8_dtype"] = py::cast(this->dtype);
kwargs_py["data_transpose"] = transpose_py;
kwargs_py["quantizer"] = this->quantizer;
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
args, kwargs_py.ptr());
Py_DECREF(args);

This same pattern repeats throughout the file at lines 176-193, 382-400, 411-428, 657-678, 686-707, 1013-1034, 1042-1063, 1304-1325, and 1333-1354. All instances should use pybind11's py::dict for exception safety.

Comment on lines +151 to +152
at::Device device =
with_data ? data->device() : (with_transpose ? transpose->device() : torch::kCUDA);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

device defaults to cuda:0 when both data and transpose are None, may not be current device

In multi-GPU setups, if neither with_data nor with_transpose are true, the device defaults to torch::kCUDA (which implicitly converts to cuda:0). This may not match the current CUDA device, potentially causing cross-device operation issues.

Suggested change
at::Device device =
with_data ? data->device() : (with_transpose ? transpose->device() : torch::kCUDA);
at::Device device =
with_data ? data->device() : (with_transpose ? transpose->device() : at::Device(torch::kCUDA, torch::cuda::current_device()));

This same issue exists at line 383 in Float8CurrentScalingQuantizer::create_tensor.

Comment on lines 79 to 89
if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
return tensor._rowwise_data.device.index
if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None:
return tensor._columnwise_data.device.index
if hasattr(tensor, "_data") and tensor._data is not None:
return tensor._data.device.index
if hasattr(tensor, "_transpose") and tensor._transpose is not None:
return tensor._transpose.device.index
if hasattr(tensor, "device"):
return tensor.device.index
return torch.cuda.current_device()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reordering attribute checks changes behavior when devices don't match

The new code checks quantized tensor attributes (_rowwise_data, _columnwise_data, etc.) before checking the standard .device attribute. If a tensor has both and they refer to different devices, this changes which device is returned.

For example, if tensor.device == cuda:1 but tensor._rowwise_data.device == cuda:0, the old code would return 1 but the new code returns 0. This could break multi-GPU code that relies on the .device attribute being authoritative.

Consider documenting this behavior change or adding a check that all device attributes match if multiple are present.

Comment on lines +367 to +383
@property
def dtype(self) -> torch.dtype:
"""
Return the high precision data type of the tensor
Attribute access of custom tensors goes through an
expensive Pyobject lookup. Since dtype for a tensor is never
change after creation, we cache it in a member variable and return
"""
return self._dtype

@property
def requires_grad(self) -> bool:
"""
Return whether or not the tensor requires gradient.
Attribute access of custom tensors goes through an
expensive Pyobject lookup. Since requires_grad is set during
initialization and may be updated, we cache it in a member variable.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_dtype and _requires_grad only initialized in __new__, will cause AttributeError in other creation paths

The new dtype and requires_grad properties rely on cached _dtype and _requires_grad attributes, but these are only initialized in the __new__ method (lines 363-364). If a QuantizedTensor is created through other paths such as:

  • Unpickling (__reduce_ex__, __setstate__)
  • torch.Tensor subclass operations that bypass __new__
  • FSDP state dict loading

Accessing .dtype or .requires_grad will raise AttributeError: '_dtype' not found.

Suggested change
@property
def dtype(self) -> torch.dtype:
"""
Return the high precision data type of the tensor
Attribute access of custom tensors goes through an
expensive Pyobject lookup. Since dtype for a tensor is never
change after creation, we cache it in a member variable and return
"""
return self._dtype
@property
def requires_grad(self) -> bool:
"""
Return whether or not the tensor requires gradient.
Attribute access of custom tensors goes through an
expensive Pyobject lookup. Since requires_grad is set during
initialization and may be updated, we cache it in a member variable.
@property
def dtype(self) -> torch.dtype:
"""
Return the high precision data type of the tensor
"""
# Lazy initialization for tensors created via alternate paths
if not hasattr(self, '_dtype'):
self._dtype = super().__getattribute__('dtype')
return self._dtype

Apply similar fix to requires_grad property.

Comment on lines +913 to +916
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
return self._data.shape if self._data is not None else self._transpose.shape
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shape property will raise AttributeError if both _data and _transpose are None

The property assumes at least one of _data or _transpose is not None, but doesn't validate the _transpose attribute before accessing it. If both are None (edge case during initialization or after certain operations), accessing .shape will raise AttributeError: 'NoneType' object has no attribute 'shape'.

Suggested change
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
return self._data.shape if self._data is not None else self._transpose.shape
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
if self._data is not None:
return self._data.shape
elif self._transpose is not None:
return self._transpose.shape
else:
# Fallback to getting shape from alternate data sources or raise descriptive error
raise RuntimeError("Float8Tensor has no data buffers initialized")

Same issue exists in the is_cuda property at lines 919-921.

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

6 files reviewed, 6 comments

Edit Code Review Agent Settings | Greptile

ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
ctx.requires_wgrad = weight.requires_grad
ctx.requires_dgrad = inp_requires_grad
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic change from requires_grad(inp, weight, bias) to (inp_requires_grad or weight_requires_grad or bias_requires_grad) is correct for the optimization, but this breaks the FP8 state management semantic.

The original requires_grad() function returns True if ANY tensor requires grad (see utils.py line 23-28). The new code uses and instead of or:

if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):

However, looking more carefully at the actual change, the new code uses or, which is semantically equivalent to the old requires_grad(inp, weight, bias) call. This is actually correct!

But wait - I need to double-check the previous thread comment. The previous thread says "logic change from original requires_grad(inp, weight, bias) which returns True if ANY tensor requires grad. New code requires ALL THREE to be True". Let me verify this is indeed using or not and...

Comment on lines +367 to +375
@property
def dtype(self) -> torch.dtype:
"""
Return the high precision data type of the tensor
Attribute access of custom tensors goes through an
expensive Pyobject lookup. Since dtype for a tensor is never
change after creation, we cache it in a member variable and return
"""
return self._dtype
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dtype property accesses self._dtype without checking if it exists. While QuantizedTensor.__new__() initializes this at line 364, subclasses like Float8Tensor can be instantiated through paths that may not properly initialize these cached attributes.

For example, when Float8TensorStorage.__new__() calls super().__new__(cls, *args, **kwargs), if dtype is not in kwargs, QuantizedTensor.__new__() won't receive it and won't set instance._dtype.

Consider adding defensive checks or ensuring all instantiation paths properly initialize _dtype and _requires_grad, or catch AttributeError and fall back to the parent class property.

Comment on lines +157 to +171
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential memory leak if any operation between PyDict_New()/PyTuple_New() and Py_DECREF() throws an exception or fails. The py::cast() calls at lines 161-163 could throw C++ exceptions, and PyDict_SetItemString() could fail, leaving kwargs and args unreleased.

Consider using RAII wrappers or try-catch blocks to ensure cleanup:

PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
if (!kwargs || !args) {
    Py_XDECREF(kwargs);
    Py_XDECREF(args);
    throw std::runtime_error("Failed to allocate Python objects");
}

try {
    PyDict_SetItemString(kwargs, "data", data_py.ptr());
    // ... rest of setup
    PyObject* result = PyObject_Call(...);
    Py_DECREF(kwargs);
    Py_DECREF(args);
    NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
    out_py = py::reinterpret_steal<py::object>(result);
} catch (...) {
    Py_DECREF(kwargs);
    Py_DECREF(args);
    throw;
}

Or better yet, use py::dict and py::tuple from pybind11 which handle cleanup automatically.

Comment on lines +913 to +921
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
return self._data.shape if self._data is not None else self._transpose.shape

@property
def is_cuda(self):
"""Return whether the tensor is on a CUDA device."""
return self._data.is_cuda if self._data is not None else self._transpose.is_cuda
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The shape and is_cuda properties assume that either self._data or self._transpose is not None. However, looking at the quantizer logic in Float8Quantizer.create_tensor() (quantizer.cpp lines 126-137), there are cases where:

  • with_data = rowwise_usage || is_non_tn_fp8_gemm_supported
  • with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported

If both rowwise_usage and columnwise_usage are False, and is_non_tn_fp8_gemm_supported is False, then both _data and _transpose would be None, causing AttributeError when accessing .shape or .is_cuda.

This scenario should probably be prevented at quantizer initialization, or these properties should handle the None case gracefully.

Comment on lines 79 to 89
if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
return tensor._rowwise_data.device.index
if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None:
return tensor._columnwise_data.device.index
if hasattr(tensor, "_data") and tensor._data is not None:
return tensor._data.device.index
if hasattr(tensor, "_transpose") and tensor._transpose is not None:
return tensor._transpose.device.index
if hasattr(tensor, "device"):
return tensor.device.index
return torch.cuda.current_device()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reordering of device attribute checks changes the priority of which device is returned. Previously, if a tensor had both device attribute and internal attributes like _data, it would return tensor.device.index. Now it returns the device of the internal attribute.

For QuantizedTensor instances where these might differ (though unlikely), this could change behavior. Consider the case where:

  • A QuantizedTensor has _data on device 0
  • But the tensor itself reports being on device 1 (via its device property)

The old code would return device 1, the new code returns device 0. While this scenario is unlikely given how QuantizedTensor is implemented, it's a semantic change worth noting. If the internal attributes are guaranteed to be on the same device as the tensor, this is fine. Otherwise, document this priority or add an assertion.


int nvte_is_non_tn_fp8_gemm_supported() {
int num_devices = transformer_engine::cuda::num_devices();
static int num_devices = transformer_engine::cuda::num_devices();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making num_devices static has a subtle initialization order issue. The static initialization happens once when the function is first called, but the subsequent static vectors cache and flags depend on num_devices for their size.

If transformer_engine::cuda::num_devices() returns different values across multiple calls (which shouldn't happen in practice but isn't guaranteed by the API), the first call to this function will initialize num_devices, and subsequent calls will use that cached value. However, if the CUDA context changes or devices are added/removed (in rare scenarios), this could cause a mismatch.

Consider:

static int num_devices = transformer_engine::cuda::num_devices();

This is initialized once, but cache and flags vectors might need a different size if the device count somehow changes. While unlikely, this could cause out-of-bounds access.

A safer approach might be:

static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache;
static std::vector<std::once_flag> flags;
static std::once_flag init_flag;
std::call_once(init_flag, [&]() {
  cache.resize(num_devices, -1);
  flags.resize(num_devices);
});

Or simply document that the device count must not change during the application's lifetime.

Signed-off-by: Varun Thumbe <[email protected]>
@vthumbe1503
Copy link
Collaborator Author

/te-ci L1 pytorch

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No files reviewed, no comments

Edit Code Review Agent Settings | Greptile

@vthumbe1503
Copy link
Collaborator Author

/te-ci L1 pytorch

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

8 files reviewed, 8 comments

Edit Code Review Agent Settings | Greptile

PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical memory management bug: py::cast(this->dtype).ptr() creates a temporary py::object that is immediately destroyed after .ptr() is called, but PyDict_SetItemString increments the reference count of the PyObject. This pattern is used throughout this file and can lead to use-after-free bugs or memory corruption.

The issue is that:

  1. py::cast(this->dtype) creates a temporary py::object with refcount=1
  2. .ptr() returns the raw PyObject* pointer
  3. PyDict_SetItemString increments refcount to 2
  4. The temporary py::object goes out of scope, decrementing refcount back to 1
  5. The dictionary now holds the only reference to the object, which works BUT
  6. If the temporary object's destructor releases resources or if there are any edge cases, this becomes unsafe

Fix: Store py::cast() results in named variables before calling PyDict_SetItemString:

Suggested change
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
py::object fp8_dtype_obj = py::cast(this->dtype);
PyDict_SetItemString(kwargs, "fp8_dtype", fp8_dtype_obj.ptr());

This same pattern appears on lines: 178, 179, 182, 185, 394, 411, 412, 415, 418, 691-698, 712-721, 1042, 1056-1057, 1062, 1352, 1366-1367, 1374. All instances must be fixed.


Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error check occurs AFTER Py_DECREF(kwargs) and Py_DECREF(args) (lines 168-169), which means if result is nullptr, the program will terminate via NVTE_CHECK without proper cleanup. While this is acceptable for fatal errors, it's inconsistent with the cleanup-before-check pattern and could mask the actual Python exception message.

Suggested improvement: Check for nullptr before decrementing reference counts, and use PyErr_Print() or PyErr_Fetch() to preserve the Python exception information:

Suggested change
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);
if (result == nullptr) {
PyErr_Print(); // Print the actual Python error
Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(false, "Failed to create Float8TensorStorage instance");
}
Py_DECREF(kwargs);
Py_DECREF(args);
out_py = py::reinterpret_steal<py::object>(result);

This same pattern applies to all other PyObject_Call invocations in this file (lines 187, 401, 706, 723, 1047, 1068, 1357, 1378).

Comment on lines 145 to 148
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();

// Initialize scale-inverse tensor
if (!scale_inv) {
scale_inv = at::reciprocal(scale);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code initializes scale_inv if not provided, then creates scale_inv_py, then computes device from the tensors. However, device is only computed if with_data or with_transpose is true, but it's used unconditionally on line 185. If neither with_data nor with_transpose is true (which shouldn't happen per the check on line 196, but code flow is unclear), device would be uninitialized.

The flow would be clearer if the device calculation happens before constructing the Python tensor, as it's needed for the constructor. Current code is safe due to the check on line 196, but the logic flow could be clearer.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines 104 to +110
void init_extension() {
init_float8_extension();
init_mxfp8_extension();
init_float8blockwise_extension();
init_nvfp4_extensions();
std::call_once(extension_init_flag, []() {
init_float8_extension();
init_mxfp8_extension();
init_float8blockwise_extension();
init_nvfp4_extensions();
});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good use of std::call_once to ensure thread-safe initialization! However, the individual init functions (init_float8_extension, etc.) previously had guards that checked if the classes were already initialized (e.g., if (Float8TensorPythonClass) return;). These guards have been removed.

While std::call_once ensures the lambda is only executed once, if init_extension() is called multiple times from different contexts (e.g., multi-threaded module loading), the std::call_once protection is solid. However, if any of the individual init functions fail partway through (e.g., PyObject_GetAttrString returns nullptr), the subsequent init functions in the same call would still execute, potentially leaving the extension in a partially initialized state.

Recommendation: Add nullptr checks after each PyObject_GetAttrString call to ensure all classes were successfully loaded before proceeding.

Comment on lines +387 to +393
@requires_grad.setter
def requires_grad(self, value: bool) -> None:
"""Set requires_grad property so that autograd engine is aware of the change"""
# Update the cached value
self._requires_grad = value
# Call parent class to ensure autograd engine is aware of the change
torch.Tensor.requires_grad.fset(self, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The requires_grad property setter updates the cached _requires_grad value and then calls torch.Tensor.requires_grad.fset(self, value) to ensure the autograd engine is aware. However, there's a potential issue: if the parent class's setter fails or raises an exception, _requires_grad will already be updated, causing the cached value to be out of sync with the actual tensor state.

Suggested improvement: Call the parent setter first, then update the cache only if successful:

Suggested change
@requires_grad.setter
def requires_grad(self, value: bool) -> None:
"""Set requires_grad property so that autograd engine is aware of the change"""
# Update the cached value
self._requires_grad = value
# Call parent class to ensure autograd engine is aware of the change
torch.Tensor.requires_grad.fset(self, value)
@requires_grad.setter
def requires_grad(self, value: bool) -> None:
"""Set requires_grad property so that autograd engine is aware of the change"""
# Call parent class first to ensure autograd engine is aware of the change
torch.Tensor.requires_grad.fset(self, value)
# Update the cached value only after successful update
self._requires_grad = value

The same pattern should be applied to requires_grad_() method on lines 402-409.

Comment on lines +395 to +400
@dtype.setter
def dtype(self, value: torch.dtype) -> None:
"""Set dtype property"""
# Update the cached value
self._dtype = value
warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dtype property has a setter that allows changing the dtype after creation, with a warning. However, this is problematic because:

  1. The warning says "Dtype of QuantizedTensor has been changed" but doesn't indicate this could break things
  2. QuantizedTensor subclasses (Float8Tensor, MXFP8Tensor, etc.) have quantized data in specific formats (FP8, MXFP4) that don't change when you update _dtype
  3. Changing dtype without re-quantizing the underlying data creates a mismatch between the cached dtype and the actual data representation

This setter creates a dangerous API surface. If dtype must be mutable, it should either:

  • Raise an error instead of just warning
  • Or trigger re-quantization of the data

Current impact: Users could accidentally corrupt tensor semantics by changing dtype, leading to incorrect dequantization.

Comment on lines 910 to +920
)
return out, all_gather_outputs

@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
return self._data.shape if self._data is not None else self._transpose.shape

@property
def is_cuda(self):
"""Return whether the tensor is on a CUDA device."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new shape and is_cuda property implementations have a logic issue: they return properties from self._data if it exists, otherwise from self._transpose. However, according to the codebase logic in cublaslt_gemm.cu and quantizer.cpp, on Hopper (non-Blackwell) architectures, FP8 tensors may only have _transpose data for certain GEMM layouts, while on Blackwell they may only have _data.

The properties assume that if _data is None, then _transpose must exist, but there's no fallback handling if both are somehow None (which shouldn't happen but could in edge cases).

Suggested improvement: Add a safety check or document the invariant that at least one must always exist:

Suggested change
)
return out, all_gather_outputs
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
return self._data.shape if self._data is not None else self._transpose.shape
@property
def is_cuda(self):
"""Return whether the tensor is on a CUDA device."""
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
if self._data is not None:
return self._data.shape
if self._transpose is not None:
return self._transpose.shape
raise RuntimeError("Float8Tensor has neither _data nor _transpose")
@property
def is_cuda(self):
"""Return whether the tensor is on a CUDA device."""
if self._data is not None:
return self._data.is_cuda
if self._transpose is not None:
return self._transpose.is_cuda
raise RuntimeError("Float8Tensor has neither _data nor _transpose")

Comment on lines +810 to +817
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
return (
self._rowwise_data.shape
if self._rowwise_data is not None
else self._columnwise_data.shape
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue as Float8Tensor: the shape and is_cuda properties assume that if _rowwise_data is None, then _columnwise_data must exist. However, there's no safety check for the edge case where both could be None.

Suggested improvement: Add safety checks:

Suggested change
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
return (
self._rowwise_data.shape
if self._rowwise_data is not None
else self._columnwise_data.shape
)
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
if self._rowwise_data is not None:
return self._rowwise_data.shape
if self._columnwise_data is not None:
return self._columnwise_data.shape
raise RuntimeError("MXFP8Tensor has neither _rowwise_data nor _columnwise_data")
@property
def is_cuda(self):
"""Return whether the tensor is on a CUDA device."""
if self._rowwise_data is not None:
return self._rowwise_data.is_cuda
if self._columnwise_data is not None:
return self._columnwise_data.is_cuda
raise RuntimeError("MXFP8Tensor has neither _rowwise_data nor _columnwise_data")

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +157 to +170
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing error checking for Python C API calls could lead to crashes. The code doesn't check return values from PyDict_New(), PyTuple_New(), PyDict_SetItemString(), or PyObject_Call() before proceeding. If any of these fail (e.g., due to memory allocation failure), the subsequent operations will access invalid pointers.

Recommendation:
Add error checks after each C API call:

  • Check if PyDict_New() and PyTuple_New() return NULL
  • Check if PyDict_SetItemString() returns -1
  • The existing NVTE_CHECK(result != nullptr, ...) is good, but should come before Py_DECREF calls to avoid decrementing invalid references

This pattern is repeated throughout this file at lines 176-193, 389-404, 410-426, 688-706, 710-727, 1035-1050, 1054-1070, 1343-1361, and 1365-1383. All instances need the same error checking.

Comment on lines +377 to +393
@property
def requires_grad(self) -> bool:
"""
Return whether or not the tensor requires gradient.
Attribute access of custom tensors goes through an
expensive Pyobject lookup. Since requires_grad is set during
initialization and may be updated, we cache it in a member variable.
"""
return self._requires_grad

@requires_grad.setter
def requires_grad(self, value: bool) -> None:
"""Set requires_grad property so that autograd engine is aware of the change"""
# Update the cached value
self._requires_grad = value
# Call parent class to ensure autograd engine is aware of the change
torch.Tensor.requires_grad.fset(self, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential cache inconsistency if PyTorch modifies requires_grad through internal mechanisms. While the setter and requires_grad_() method properly sync the cache with the parent tensor, there's a risk that PyTorch's autograd engine could modify requires_grad through paths that bypass these methods (e.g., during graph construction or when tensors are involved in certain operations).

Current behavior: The cached value could become stale if PyTorch modifies requires_grad internally.

Mitigation: The implementation correctly updates the parent tensor in the setter and requires_grad_(), which should handle most cases. However, consider documenting this limitation or adding validation in critical paths to ensure cache consistency.

Note: This is a minor concern as most requires_grad modifications go through the standard setter, but worth being aware of for debugging purposes.

Comment on lines +189 to +193
Py_DECREF(kwargs);
Py_DECREF(args);

NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance");
out_py = py::reinterpret_steal<py::object>(result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error checking order issue: NVTE_CHECK for result != nullptr happens after Py_DECREF(kwargs) and Py_DECREF(args). While decrefing these is safe even when result is NULL, checking for errors before cleanup is better practice and makes the error path clearer.

If PyObject_Call fails and returns NULL, Python's exception state is set. The NVTE_CHECK should happen immediately after the call, and cleanup should be part of an error handling path. Consider moving the check before Py_DECREF or using proper error handling with early returns.

This pattern affects all PyObject_Call instances in this file.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant