Changes from all commits
28 commits
93ee022  add all the optimizations (vthumbe1503, Jan 5, 2026)
06338bc  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 5, 2026)
50de9cd  requires_grad optimization (vthumbe1503, Jan 6, 2026)
5fee841  Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf… (vthumbe1503, Jan 6, 2026)
4c79ac7  Merge branch 'main' into cpu_fp8_optimizations (vthumbe1503, Jan 6, 2026)
62b88e1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 6, 2026)
99494d7  test if commenting out requires_grad works (vthumbe1503, Jan 7, 2026)
b157f85  Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf… (vthumbe1503, Jan 7, 2026)
2a7b627  Merge branch 'main' into cpu_fp8_optimizations (vthumbe1503, Jan 7, 2026)
b61a6a8  fix minor bug (vthumbe1503, Jan 7, 2026)
938651e  Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf… (vthumbe1503, Jan 7, 2026)
88dfdbd  fix ci (vthumbe1503, Jan 11, 2026)
1526eea  Merge branch 'main' into cpu_fp8_optimizations (vthumbe1503, Jan 11, 2026)
5809dcc  missed a bug (vthumbe1503, Jan 11, 2026)
b3bd748  Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf… (vthumbe1503, Jan 11, 2026)
30fecf2  Update transformer_engine/pytorch/csrc/quantizer.cpp (vthumbe1503, Jan 11, 2026)
1b0d497  fix some bugs pointed to by copilot (vthumbe1503, Jan 11, 2026)
138b7bf  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 11, 2026)
eec1e86  linting error (vthumbe1503, Jan 11, 2026)
8169d9c  fix the error (vthumbe1503, Jan 12, 2026)
6fefaf2  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 12, 2026)
a5feaf9  fix the bug (vthumbe1503, Jan 13, 2026)
285dbff  Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf… (vthumbe1503, Jan 13, 2026)
afb2f23  get rid of the change (vthumbe1503, Jan 13, 2026)
3919cb8  fix the transpose shape bug (vthumbe1503, Jan 13, 2026)
fd36424  Merge branch 'main' into cpu_fp8_optimizations (vthumbe1503, Jan 13, 2026)
4668133  minor linter fix (vthumbe1503, Jan 13, 2026)
5a00652  fix lint (vthumbe1503, Jan 13, 2026)
12 changes: 8 additions & 4 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -120,6 +120,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
   // Set conditions for MXFP8 and NVFP4 gemm execution.
   const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode);
   const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode);
+  int is_nvte_non_tn_fp8_gemm_supported = 0; // needed only for per tensor scaling
+  if (is_tensor_scaling(A.scaling_mode) || is_tensor_scaling(B.scaling_mode)) {
+    is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
+  }
 
   // Configure A matrix
   if (is_tensor_scaling(A.scaling_mode)) {
@@ -129,7 +133,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
     ret.Atype = A.data.dtype;
     ret.A_scale_inv = A.scale_inv.dptr;
     ret.lda = is_A_transposed ? k : m;
-    if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
+    if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) {
       // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
       if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
         ret.A = A.columnwise_data.dptr;
@@ -140,7 +144,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
       } else {
         NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
       }
-    } else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) {
+    } else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) {
       // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
       // data with the mirrored transpose-flag if we don't have row-wise data.
       NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype),
@@ -220,7 +224,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
     ret.Btype = B.data.dtype;
     ret.B_scale_inv = B.scale_inv.dptr;
     ret.ldb = is_B_transposed ? n : k;
-    if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
+    if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) {
      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
       if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
         ret.B = B.columnwise_data.dptr;
@@ -231,7 +235,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
       } else {
         NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
       }
-    } else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) {
+    } else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) {
      // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
      // data with the mirrored transpose-flag if we don't have row-wise data.
       NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype),
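The change above hoists a repeated capability query out of the per-matrix branches: nvte_is_non_tn_fp8_gemm_supported() is now called at most once per CanonicalizeGemmInput invocation, and only when per-tensor FP8 scaling is actually in play. Below is a minimal standalone sketch of the same pattern, with a hypothetical is_feature_supported() standing in for the real NVTE query; everything in it is illustrative, not Transformer Engine API.

#include <cstdio>

// Hypothetical stand-in for an expensive capability query such as
// nvte_is_non_tn_fp8_gemm_supported(); imagine it touching the CUDA runtime.
static int is_feature_supported() {
  std::puts("querying device capability...");
  return 1;
}

struct Operand {
  bool uses_tensor_scaling;
  bool transposed;
};

static void canonicalize(const Operand &a, const Operand &b) {
  // Hoist: query at most once per call, and only if a branch below can need it.
  int feature_supported = 0;
  if (a.uses_tensor_scaling || b.uses_tensor_scaling) {
    feature_supported = is_feature_supported();
  }
  // Reuse the cached result in every branch instead of re-querying.
  if (a.uses_tensor_scaling && !feature_supported && !a.transposed) {
    std::puts("A: fall back to column-wise (transposed) data");
  }
  if (b.uses_tensor_scaling && !feature_supported && b.transposed) {
    std::puts("B: fall back to column-wise (transposed) data");
  }
}

int main() {
  canonicalize({true, false}, {true, true});  // capability queried once, used in two branches
  return 0;
}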
2 changes: 1 addition & 1 deletion transformer_engine/common/transformer_engine.cpp
@@ -961,7 +961,7 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
 }
 
 int nvte_is_non_tn_fp8_gemm_supported() {
-  int num_devices = transformer_engine::cuda::num_devices();
+  static int num_devices = transformer_engine::cuda::num_devices();
Contributor commented:
Making num_devices static has a subtle initialization order issue. The static initialization happens once when the function is first called, but the subsequent static vectors cache and flags depend on num_devices for their size.

If transformer_engine::cuda::num_devices() returns different values across multiple calls (which shouldn't happen in practice but isn't guaranteed by the API), the first call to this function will initialize num_devices, and subsequent calls will use that cached value. However, if the CUDA context changes or devices are added/removed (in rare scenarios), this could cause a mismatch.

Consider:

static int num_devices = transformer_engine::cuda::num_devices();

This is initialized once, but cache and flags vectors might need a different size if the device count somehow changes. While unlikely, this could cause out-of-bounds access.

A safer approach might be:

static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache;
static std::vector<std::once_flag> flags;
static std::once_flag init_flag;
std::call_once(init_flag, [&]() {
  cache.resize(num_devices, -1);
  // std::once_flag is neither copyable nor movable, so size the vector by
  // move-assigning a freshly constructed one instead of calling resize().
  flags = std::vector<std::once_flag>(num_devices);
});

Or simply document that the device count must not change during the application's lifetime.

Contributor commented:
Critical: Potential out-of-bounds access with static num_devices

Making num_devices static creates a serious bug when GPUs are hot-plugged after the first call to this function. The cache and flags vectors are sized based on the initial device count, but device_id from current_device() could exceed num_devices if GPUs are added later.

This will cause out-of-bounds access on lines 968 and 975:

std::call_once(flags[device_id], ...);  // OOB if device_id >= num_devices
return cache[device_id];                 // OOB if device_id >= num_devices

Impact: Undefined behavior, potential crashes, memory corruption.

Fix: Query device count each time, or add bounds checking:

int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache;
static std::vector<std::once_flag> flags;
static std::mutex resize_mutex;

int device_id = transformer_engine::cuda::current_device();
{
  std::lock_guard<std::mutex> lock(resize_mutex);
  if (device_id >= cache.size()) {
    cache.resize(device_id + 1, -1);
    flags.resize(device_id + 1);
  }
}
std::call_once(flags[device_id], [&]() { ... });

Contributor commented:
Potential race condition in static variable initialization

Making num_devices static improves performance by caching the device count, but this introduces a subtle race condition. While the cache and flags vectors are protected by std::call_once, num_devices itself is initialized on first function call without synchronization.

Scenario:

  1. Thread A enters function first time, starts calling transformer_engine::cuda::num_devices()
  2. Thread B enters function while A is still in num_devices(), sees uninitialized static
  3. Both threads may initialize num_devices, causing undefined behavior

However, this is likely safe in practice because:

  • The function is typically called during module initialization (single-threaded)
  • The value doesn't change after initialization
  • C++11 guarantees thread-safe initialization of function-local statics ("magic statics"), and compilers implement this unless it is explicitly disabled

But for correctness, consider wrapping in the existing std::call_once:

static int num_devices = []() {
  return transformer_engine::cuda::num_devices();
}();

Or initialize it within the std::call_once block if device_id doesn't need num_devices.

static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
Comment on lines +964 to 966
Contributor commented:
Potential edge case: num_devices is captured once as a static variable, so if the number of devices changes at runtime (e.g., with CUDA MPS or device enumeration changes), the cache and flags vectors could be undersized, leading to out-of-bounds access at line 968.

While dynamic device changes are rare, consider adding a bounds check:

Suggested change
-  static int num_devices = transformer_engine::cuda::num_devices();
-  static std::vector<int> cache(num_devices, -1);
-  static std::vector<std::once_flag> flags(num_devices);
+  static int num_devices = transformer_engine::cuda::num_devices();
+  static std::vector<int> cache(num_devices, -1);
+  static std::vector<std::once_flag> flags(num_devices);
+  int device_id = transformer_engine::cuda::current_device();
+  NVTE_CHECK(device_id < num_devices, "Device ID exceeds cached device count");

Comment on lines +964 to 966
Contributor commented:
Critical bug: Potential out-of-bounds access when device_id >= num_devices

Making num_devices static causes it to be evaluated only once at first call. This creates a serious bug if the device ID is not within the range [0, num_devices):

  1. Problem: CUDA device IDs are not necessarily contiguous. For example, with CUDA_VISIBLE_DEVICES=0,2, you have 2 visible devices but device IDs are 0 and 2.
  2. Bug scenario: If num_devices is initialized to 2, but later code calls current_device() which returns 2, then cache[device_id] at line 975 will access cache[2] which is out of bounds (cache size is 2, valid indices are 0-1).
  3. Impact: This will cause undefined behavior, likely a segmentation fault or memory corruption.

Fix: Either (a sketch of the bounds-check option appears after this list):

  • Use current_device() as the cache key and bounds-check it
  • Map actual device IDs to cache indices
  • Remove the static keyword and re-evaluate num_devices on each call (sacrifices some of the optimization but ensures correctness)
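
A minimal, self-contained sketch of the bounds-check option above. The fake_num_devices, fake_current_device, and fake_query_support helpers are hypothetical stand-ins for the real transformer_engine::cuda calls and the actual architecture check, which this page does not show.

#include <mutex>
#include <vector>

// Hypothetical stand-ins for transformer_engine::cuda::num_devices(),
// transformer_engine::cuda::current_device(), and the per-device capability
// check; the real check is device-architecture based.
static int fake_num_devices() { return 2; }
static int fake_current_device() { return 0; }
static int fake_query_support(int /*device_id*/) { return 1; }

int is_non_tn_fp8_gemm_supported_sketch() {
  static int num_devices = fake_num_devices();
  static std::vector<int> cache(num_devices, -1);
  static std::vector<std::once_flag> flags(num_devices);

  const int device_id = fake_current_device();
  // Bounds check: if the device ID falls outside the range cached at first
  // call (e.g. the visible device count changed), bypass the cache entirely.
  if (device_id < 0 || device_id >= num_devices) {
    return fake_query_support(device_id);
  }
  std::call_once(flags[device_id],
                 [&]() { cache[device_id] = fake_query_support(device_id); });
  return cache[device_id];
}

int main() { return is_non_tn_fp8_gemm_supported_sketch() == 1 ? 0 : 1; }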

int device_id = transformer_engine::cuda::current_device();
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -76,8 +76,6 @@ def get_tensor_device(tensor: torch.Tensor) -> int:
     The order of attributes checked is important to also
     minimize overhead.
     """
-    if hasattr(tensor, "device"):
-        return tensor.device.index
     if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
         return tensor._rowwise_data.device.index
     if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None:
@@ -86,6 +84,8 @@ def get_tensor_device(tensor: torch.Tensor) -> int:
         return tensor._data.device.index
     if hasattr(tensor, "_transpose") and tensor._transpose is not None:
         return tensor._transpose.device.index
+    if hasattr(tensor, "device"):
+        return tensor.device.index
     return torch.cuda.current_device()
15 changes: 7 additions & 8 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -35,9 +35,9 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
 PyTypeObject *NVFP4TensorPythonClass = nullptr;
 PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
 PyTypeObject *NVFP4QuantizerClass = nullptr;
+std::once_flag extension_init_flag;
 
 void init_float8_extension() {
-  if (Float8TensorPythonClass) return;
   auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor");
   Float8QuantizerClass =
       reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer"));
@@ -54,7 +54,6 @@ void init_float8_extension() {
 }
 
 void init_mxfp8_extension() {
-  if (MXFP8TensorPythonClass) return;
   auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp8_tensor");
   MXFP8QuantizerClass =
       reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Quantizer"));
@@ -69,7 +68,6 @@ void init_mxfp8_extension() {
 }
 
 void init_float8blockwise_extension() {
-  if (Float8BlockwiseQTensorStoragePythonClass) return;
   auto fp8_module =
       py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor");
   auto fp8_base_module = py::module_::import(
@@ -90,7 +88,6 @@ void init_float8blockwise_extension() {
 }
 
 void init_nvfp4_extensions() {
-  if (NVFP4TensorPythonClass) return;
   auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor");
   NVFP4QuantizerClass = reinterpret_cast<PyTypeObject *>(
       PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer"));
@@ -105,10 +102,12 @@ void init_nvfp4_extensions() {
 }
 
 void init_extension() {
-  init_float8_extension();
-  init_mxfp8_extension();
-  init_float8blockwise_extension();
-  init_nvfp4_extensions();
+  std::call_once(extension_init_flag, []() {
+    init_float8_extension();
+    init_mxfp8_extension();
+    init_float8blockwise_extension();
+    init_nvfp4_extensions();
+  });
 }
 
 } // namespace transformer_engine::pytorch
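The refactor above replaces four separate "check a class pointer, return early" guards with a single std::call_once, so the whole extension initialization runs exactly once even under concurrent first calls; if an initializer throws, the once_flag stays unset and a later call retries. A small self-contained sketch of the idiom, with placeholder init_a()/init_b() standing in for the real init_*_extension() functions:

#include <iostream>
#include <mutex>

static std::once_flag extension_init_flag;

// Placeholder initializers standing in for init_float8_extension() and friends.
static void init_a() { std::cout << "init_a\n"; }
static void init_b() { std::cout << "init_b\n"; }

static void init_extension_sketch() {
  // The lambda body runs exactly once across all callers and threads;
  // subsequent calls return immediately once the flag is set.
  std::call_once(extension_init_flag, []() {
    init_a();
    init_b();
  });
}

int main() {
  init_extension_sketch();
  init_extension_sketch();  // no-op: the once_flag is already set
  return 0;
}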