diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu
index 4b7d8179b0..0699b1876d 100644
--- a/transformer_engine/common/gemm/cublaslt_gemm.cu
+++ b/transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -120,6 +120,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
   // Set conditions for MXFP8 and NVFP4 gemm execution.
   const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode);
   const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode);
+  int is_nvte_non_tn_fp8_gemm_supported = 0;  // needed only for per-tensor scaling
+  if (is_tensor_scaling(A.scaling_mode) || is_tensor_scaling(B.scaling_mode)) {
+    is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
+  }
 
   // Configure A matrix
   if (is_tensor_scaling(A.scaling_mode)) {
@@ -129,7 +133,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
     ret.Atype = A.data.dtype;
     ret.A_scale_inv = A.scale_inv.dptr;
     ret.lda = is_A_transposed ? k : m;
-    if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
+    if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) {
       // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
       if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
         ret.A = A.columnwise_data.dptr;
@@ -140,7 +144,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
       } else {
         NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
       }
-    } else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) {
+    } else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) {
       // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
       // data with the mirrored transpose-flag if we don't have row-wise data.
       NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype),
@@ -220,7 +224,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
     ret.Btype = B.data.dtype;
     ret.B_scale_inv = B.scale_inv.dptr;
     ret.ldb = is_B_transposed ? n : k;
-    if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
+    if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) {
       // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
       if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
         ret.B = B.columnwise_data.dptr;
@@ -231,7 +235,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
       } else {
         NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
       }
-    } else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) {
+    } else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) {
      // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
       // data with the mirrored transpose-flag if we don't have row-wise data.
       NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype),
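The hunk above hoists the repeated nvte_is_non_tn_fp8_gemm_supported() calls out of the per-operand branches: the query now runs at most once per CanonicalizeGemmInput() call, and only when A or B actually uses per-tensor scaling, since the MXFP8 and NVFP4 paths never consult it. Below is a minimal sketch of the same hoisting pattern in isolation; query_non_tn_support() is a hypothetical stand-in for the NVTE query, not Transformer Engine code.

    #include <cstdio>

    // Hypothetical stand-in for nvte_is_non_tn_fp8_gemm_supported(); assume the
    // real query inspects GPU properties and is not free to call repeatedly.
    int query_non_tn_support() {
      std::printf("querying device properties\n");
      return 0;
    }

    // Same hoisting pattern as the patch: evaluate the query at most once per
    // canonicalization, and only when an operand actually uses per-tensor scaling.
    void canonicalize(bool per_tensor_scaling, bool a_transposed, bool b_transposed) {
      int non_tn_supported = 0;  // only meaningful for per-tensor scaling
      if (per_tensor_scaling) {
        non_tn_supported = query_non_tn_support();  // single query, reused below
      }
      if (per_tensor_scaling && !a_transposed && !non_tn_supported) {
        // Hopper-style fallback: use column-wise (transposed) data for A.
      }
      if (per_tensor_scaling && b_transposed && !non_tn_supported) {
        // Hopper-style fallback: use column-wise (transposed) data for B.
      }
    }

    int main() {
      canonicalize(/*per_tensor_scaling=*/true, /*a_transposed=*/false, /*b_transposed=*/true);
    }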
diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h
index 2715d8e4e4..8de3d9ba5b 100644
--- a/transformer_engine/common/util/cuda_driver.h
+++ b/transformer_engine/common/util/cuda_driver.h
@@ -10,6 +10,7 @@
 #include <cuda.h>
 
 #include <string>
+#include <unordered_map>
 
 #include "../common.h"
 #include "../util/string.h"
@@ -29,13 +30,31 @@ void *get_symbol(const char *symbol, int cuda_version = 12010);
  * without GPUs. Indirect function calls into a lazily-initialized
  * library ensures we are accessing the correct version.
  *
+ * Symbol pointers are cached to avoid repeated lookups.
+ *
  * \param[in] symbol Function name
  * \param[in] args Function arguments
  */
 template <typename... ArgTs>
 inline CUresult call(const char *symbol, ArgTs... args) {
   using FuncT = CUresult(ArgTs...);
-  FuncT *func = reinterpret_cast<FuncT *>(get_symbol(symbol));
+
+  // Cache for symbol pointers
+  static std::unordered_map<std::string, void *> symbol_cache;
+
+  // Check if symbol is already cached
+  auto it = symbol_cache.find(symbol);
+  FuncT *func;
+
+  if (it != symbol_cache.end()) {
+    func = reinterpret_cast<FuncT *>(it->second);
+  } else {
+    // Symbol not in cache, look it up and cache the result
+    void *ptr = get_symbol(symbol);
+    symbol_cache[symbol] = ptr;
+    func = reinterpret_cast<FuncT *>(ptr);
+  }
+
   return (*func)(args...);
 }
diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py
index 2a97e2ac71..7fe37b5f54 100644
--- a/transformer_engine/pytorch/cpp_extensions/gemm.py
+++ b/transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -76,8 +76,6 @@ def get_tensor_device(tensor: torch.Tensor) -> int:
 
     The order of attributes checked is important to also minimize overhead.
     """
-    if hasattr(tensor, "device"):
-        return tensor.device.index
     if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
         return tensor._rowwise_data.device.index
     if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None:
@@ -86,6 +84,8 @@ def get_tensor_device(tensor: torch.Tensor) -> int:
         return tensor._data.device.index
     if hasattr(tensor, "_transpose") and tensor._transpose is not None:
         return tensor._transpose.device.index
+    if hasattr(tensor, "device"):
+        return tensor.device.index
     return torch.cuda.current_device()
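The cuda_driver.h change memoizes driver-symbol lookups: the first call() for a given symbol still goes through get_symbol(), and later calls hit a function-local cache. Because the cache is a static local inside a function template, each distinct ArgTs... instantiation keeps its own map. The gemm.py hunk is related housekeeping: the generic tensor.device check moves behind the cheap _rowwise_data/_columnwise_data/_data/_transpose checks, so quantized tensors resolve their device without the slower subclass attribute path. A minimal sketch of the memoization, assuming slow_lookup() as a stand-in for the real get_symbol():

    #include <cstdio>
    #include <string>
    #include <unordered_map>

    // Stand-in for the real get_symbol(), which resolves a driver entry point;
    // assume the lookup is comparatively slow.
    void *slow_lookup(const char *symbol) {
      std::printf("resolving %s\n", symbol);
      static int dummy = 0;
      return &dummy;  // dummy pointer, for illustration only
    }

    // Memoized wrapper: each distinct symbol is resolved once, later calls hit
    // the map. In the actual header the cache is a static local inside a
    // template, so every call<ArgTs...>() instantiation keeps its own map.
    void *cached_lookup(const char *symbol) {
      static std::unordered_map<std::string, void *> cache;
      auto it = cache.find(symbol);
      if (it != cache.end()) {
        return it->second;
      }
      void *ptr = slow_lookup(symbol);
      cache.emplace(symbol, ptr);
      return ptr;
    }

    int main() {
      cached_lookup("cuMemAlloc");  // resolved via slow_lookup()
      cached_lookup("cuMemAlloc");  // cache hit, nothing printed
      cached_lookup("cuMemFree");   // resolved via slow_lookup()
    }

One design point worth keeping in mind: initialization of the static local is thread-safe, but concurrent first-time insertions into the same map are not synchronized, so a mutex or per-thread cache would be needed if call() can race across threads; the sketch above ignores that concern.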
void init_nvfp4_extensions() {
 }
 
 void init_extension() {
-  init_float8_extension();
-  init_mxfp8_extension();
-  init_float8blockwise_extension();
-  init_nvfp4_extensions();
+  std::call_once(extension_init_flag, []() {
+    init_float8_extension();
+    init_mxfp8_extension();
+    init_float8blockwise_extension();
+    init_nvfp4_extensions();
+  });
 }
 
 }  // namespace transformer_engine::pytorch
diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp
index a73efc008a..0e4a0ca355 100644
--- a/transformer_engine/pytorch/csrc/quantizer.cpp
+++ b/transformer_engine/pytorch/csrc/quantizer.cpp
@@ -121,9 +121,9 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
     const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> data,
     std::optional<at::Tensor> transpose, std::optional<at::Tensor> scale_inv) const {
   using namespace pybind11::literals;
-
+  int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
   // Initialize data tensor
-  const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
+  const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported;
   if (with_data && !data) {
     const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
     const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
@@ -134,7 +134,7 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
   py::object data_py = with_data ? py::cast(*data) : py::none();
 
   // Initialize transpose tensor
-  const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
+  const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported;
   if (with_transpose && !transpose) {
     const auto transpose_shape = make_transpose_shape(shape);
     const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
@@ -143,26 +143,60 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
     transpose.reset();
   }
   py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
-
   // Initialize scale-inverse tensor
   if (!scale_inv) {
     scale_inv = at::reciprocal(scale);
   }
-
+  py::object scale_inv_py = py::cast(*scale_inv);
+  at::Device device =
+      with_data ? data->device()
+                : (with_transpose ?
transpose->device() + : at::Device(torch::kCUDA, c10::cuda::current_device())); // Construct Python FP8 tensor py::object out_py; if (internal) { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorStoragePythonClass)); - out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyObject* args = PyTuple_New(0); + PyDict_SetItemString(kwargs, "data", data_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + + PyObject* result = + PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), args, kwargs); + if (result == nullptr) { + PyErr_Print(); + } + Py_DECREF(kwargs); + Py_DECREF(args); + NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); const std::vector shape_int64(shape.begin(), shape.end()); - out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyObject* args = PyTuple_New(0); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "data", data_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "device", py::cast(device).inc_ref().ptr()); + PyObject* result = + PyObject_Call(reinterpret_cast(Float8TensorPythonClass), args, kwargs); + if (result == nullptr) { + PyErr_Print(); + } + Py_DECREF(kwargs); + Py_DECREF(args); + + NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ FP8 tensor @@ -185,10 +219,10 @@ std::pair Float8Quantizer::create_tensor( std::pair Float8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Expected buffers - const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); - const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; + const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid usages for Float8Quantizer."); // Extract buffers from Python tensor @@ -328,7 +362,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso // Initialize data tensor at::Tensor data_tensor; - const bool 
with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; if (with_data) { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -337,13 +372,12 @@ std::pair Float8CurrentScalingQuantizer::create_tenso // Initialize transpose tensor at::Tensor transpose_tensor; - const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose) { const auto transpose_shape = make_transpose_shape(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); transpose_tensor = at::empty(transpose_shape, opts); } - // Initialize scale-inverse tensor at::Tensor scale_inv_tensor; { @@ -351,23 +385,57 @@ std::pair Float8CurrentScalingQuantizer::create_tenso const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); scale_inv_tensor = at::empty(scale_inv_shape, opts); } - + at::Device device = + with_data ? data_tensor.device() + : (with_transpose ? transpose_tensor.device() + : at::Device(torch::kCUDA, c10::cuda::current_device())); // Construct Python FP8 tensor py::object out_py; + py::object scale_inv_py = py::cast(scale_inv_tensor); py::object data_py = with_data ? py::cast(data_tensor) : py::none(); py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none(); if (internal) { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorStoragePythonClass)); - out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "data", data_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + + PyObject* args = PyTuple_New(0); + PyObject* result = + PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), args, kwargs); + if (result == nullptr) { + PyErr_Print(); + } + Py_DECREF(args); + Py_DECREF(kwargs); + NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); const std::vector shape_int64(shape.begin(), shape.end()); - out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "data", data_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); + 
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "device", py::cast(device).inc_ref().ptr()); + PyObject* args = PyTuple_New(0); + PyObject* result = + PyObject_Call(reinterpret_cast(Float8TensorPythonClass), args, kwargs); + if (result == nullptr) { + PyErr_Print(); + } + Py_DECREF(args); + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ FP8 tensor @@ -406,10 +474,10 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8CurrentScalingQuantizer must output to Float8Tensor."); - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Expected buffers - const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); - const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; + const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages."); // Extract buffers from Python tensor @@ -629,22 +697,53 @@ std::pair Float8BlockQuantizer::create_tensor( py::object ret; if (internal) { - py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass)); - ret = Float8BlockwiseQTensorClass( - "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, - "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer, - "is_2D_scaled"_a = (block_scaling_dim == 2), "data_format"_a = data_format); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", + py::cast(scale_inv_colwise).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.inc_ref().ptr()); + PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).inc_ref().ptr()); + + PyObject* args = PyTuple_New(0); + PyObject* result = PyObject_Call( + reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass), args, kwargs); + if (result == nullptr) { + PyErr_Print(); + } + Py_DECREF(args); + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensorStorage instance"); + ret = py::reinterpret_steal(result); } else { - py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorPythonClass)); - ret = Float8BlockwiseQTensorClass( - "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, - "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, - "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2), - "data_format"_a = 
data_format); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "shape", py::cast(torch_shape).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", + py::cast(scale_inv_colwise).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).inc_ref().ptr()); + PyObject* args = PyTuple_New(0); + PyObject* result = + PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorPythonClass), args, kwargs); + if (result == nullptr) { + PyErr_Print(); + } + Py_DECREF(args); + Py_DECREF(kwargs); + NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensor instance"); + ret = py::reinterpret_steal(result); } return {std::move(tensor), std::move(ret)}; @@ -950,20 +1049,49 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve // Construct Python MXFP8 tensor py::object out_py; if (internal) { - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorStoragePythonClass)); - out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py, - "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyObject* args = PyTuple_New(0); + PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + + PyObject* result = + PyObject_Call(reinterpret_cast(MXFP8TensorStoragePythonClass), args, kwargs); + if (result == nullptr) { + PyErr_Print(); + } + Py_DECREF(args); + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create MXFP8TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); - out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "rowwise_data"_a = rowwise_data_py, - "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); + 
PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + + PyObject* args = PyTuple_New(0); + PyObject* result = + PyObject_Call(reinterpret_cast(MXFP8TensorPythonClass), args, kwargs); + if (result == nullptr) { + PyErr_Print(); + } + Py_DECREF(args); + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create MXFP8Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ MXFP8 tensor @@ -1234,22 +1362,54 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve // Construct Python NVFP4 tensor py::object out_py; if (internal) { - py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorStoragePythonClass)); - out_py = NVFP4TensorClass( - "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, - "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "amax_rowwise", amax_rowwise_py.ptr()); + PyDict_SetItemString(kwargs, "amax_columnwise", amax_columnwise_py.ptr()); + PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + + PyObject* args = PyTuple_New(0); + + PyObject* result = + PyObject_Call(reinterpret_cast(NVFP4TensorStoragePythonClass), args, kwargs); + if (result == nullptr) { + PyErr_Print(); + } + Py_DECREF(args); + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create NVFP4TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorPythonClass)); - out_py = NVFP4TensorClass( - "shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, - "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "amax_rowwise", amax_rowwise_py.ptr()); + 
PyDict_SetItemString(kwargs, "amax_columnwise", amax_columnwise_py.ptr()); + PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + + PyObject* args = PyTuple_New(0); + PyObject* result = + PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), args, kwargs); + if (result == nullptr) { + PyErr_Print(); + } + Py_DECREF(args); + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create NVFP4Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ tensor diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index ad5cd04341..368b61b382 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -929,12 +929,11 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: if torch.is_autocast_enabled(): self.activation_dtype = torch_get_autocast_gpu_dtype() return - + dtype = inp.dtype # All checks after this have already been performed once, thus skip - if self.activation_dtype == inp.dtype: + if self.activation_dtype == dtype: return - dtype = inp.dtype if not self.allow_different_data_and_param_types: for name, param in self.named_parameters(): if param is not None: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b8349f84a0..d11e492f92 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -31,7 +31,6 @@ clear_tensor_data, divide, init_method_constant, - requires_grad, needs_quantized_gemm, assert_dim_for_fp8_exec, assert_dim_for_all_gather, @@ -93,7 +92,6 @@ def forward( non_tensor_args: Tuple, ) -> torch.Tensor: # pylint: disable=missing-function-docstring - ( is_first_microbatch, fp8, @@ -130,6 +128,10 @@ def forward( debug, ) = non_tensor_args + inp_requires_grad = inp.requires_grad + weight_requires_grad = weight.requires_grad + bias_requires_grad = bias.requires_grad if bias is not None else False + # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" if ub_name is not None: @@ -141,7 +143,7 @@ def forward( # Configure tensor-parallel communication tp_world_size = get_distributed_world_size(tp_group) - backward_needs_input = is_grad_enabled and weight.requires_grad + backward_needs_input = is_grad_enabled and weight_requires_grad with_input_all_gather_nccl = ( parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop ) @@ -254,7 +256,7 @@ def forward( # Configure quantizer # No need to set the quantizer states if weight is already quantized if weight_quantizer is not None and not isinstance(weight, QuantizedTensor): - columnwise_usage = is_grad_enabled and inp.requires_grad + columnwise_usage = is_grad_enabled and inp_requires_grad if not columnwise_usage: columnwise_usage = ( is_fp8_activation_recompute_enabled() @@ -379,7 +381,7 @@ def forward( ctx.weight_quantizer = weight_quantizer ctx.backward_input_needs_gather = ( - weight.requires_grad and parallel_mode == "column" and sequence_parallel + weight_requires_grad and parallel_mode == "column" and sequence_parallel ) # Discard unneeded data in input tensor @@ -448,7 +450,7 @@ def forward( ctx.grad_weight_quantizer = grad_weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - if fuse_wgrad_accumulation and weight.requires_grad: + if fuse_wgrad_accumulation and weight_requires_grad: # This 
check is needed to ensure that main_grad is not created
                # during the forward pass when using MCore FSDP as it creates
                # the main_grad buffer lazily before backprop
@@ -474,12 +476,12 @@ def forward(
             ctx.ub_bulk_wgrad = ub_bulk_wgrad
             ctx.ub_name = ub_name
             ctx.tp_size = tp_size
-            ctx.requires_dgrad = inp.requires_grad
-            ctx.requires_wgrad = weight.requires_grad
+            ctx.requires_dgrad = inp_requires_grad
+            ctx.requires_wgrad = weight_requires_grad
             ctx.reduce_and_update_bwd_fp8_tensors = False
             ctx.owns_input = saved_inputmat is not inp
-            if ctx.fp8 and requires_grad(inp, weight, bias):
+            if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):
                 _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
                 ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
                 if in_fp8_activation_recompute_phase():
diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py
index ac827e794a..e6c2f92dff 100644
--- a/transformer_engine/pytorch/quantized_tensor.py
+++ b/transformer_engine/pytorch/quantized_tensor.py
@@ -355,9 +355,58 @@ def __new__(
             requires_grad=requires_grad,
             device=torch.cuda.current_device() if device is None else device,
         )
-
+        # instance._requires_grad = requires_grad
+        # instance._dtype = dtype
         return instance
 
+    @property
+    def dtype(self) -> torch.dtype:
+        """
+        Return the high-precision data type of the tensor.
+        Attribute access of custom tensors goes through an
+        expensive PyObject lookup. Since the dtype of a tensor never
+        changes after creation, we cache it in a member variable and return it.
+        """
+        # Lazy initialization for tensors created via alternate paths
+        if not hasattr(self, "_dtype"):
+            self._dtype = torch._C.TensorBase.dtype.__get__(self, type(self))
+        return self._dtype
+
+    @dtype.setter
+    def dtype(self, value: torch.dtype) -> None:
+        """Set dtype property"""
+        # Update the cached value
+        self._dtype = value
+        warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
+
+    @property
+    def requires_grad(self) -> bool:
+        """
+        Return whether or not the tensor requires gradient.
+        Attribute access of custom tensors goes through an
+        expensive PyObject lookup. Since requires_grad is set during
+        initialization and may be updated, we cache it in a member variable.
+        """
+        # Fallback to parent if not cached yet
+        if not hasattr(self, "_requires_grad"):
+            self._requires_grad = torch._C.TensorBase.requires_grad.__get__(self, type(self))
+        return self._requires_grad
+
+    @requires_grad.setter
+    def requires_grad(self, value: bool) -> None:
+        """Set requires_grad property so that autograd engine is aware of the change"""
+        # Update the cached value and call parent class method to ensure autograd engine is aware
+        self.requires_grad_(value)
+
+    def requires_grad_(self, requires_grad: bool = True) -> QuantizedTensor:
+        """Cache requires_grad property and call parent class method"""
+        # pylint: disable=missing-function-docstring
+        # Update the cached value
+        self._requires_grad = requires_grad
+        # Call parent class method to ensure autograd engine is aware
+        super().requires_grad_(requires_grad)
+        return self
+
     def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """Convert quantized data to standard PyTorch tensor"""
         raise NotImplementedError(
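The quantized_tensor.py hunk above shadows dtype and requires_grad with Python properties that cache the value in a plain instance attribute, fall back to the torch._C.TensorBase accessors for instances created through paths that never populated the cache, and route writes through requires_grad_() so the cache and the autograd engine stay in sync. A self-contained sketch of that pattern on a toy subclass, assuming nothing beyond stock PyTorch (names are illustrative, not Transformer Engine code):

    import torch

    class CachedFlagTensor(torch.Tensor):
        """Toy subclass showing the cache-and-delegate pattern; illustrative only."""

        @property
        def requires_grad(self) -> bool:
            # Lazy fallback for instances created without running our setters
            if not hasattr(self, "_requires_grad"):
                self._requires_grad = torch._C.TensorBase.requires_grad.__get__(self, type(self))
            return self._requires_grad

        @requires_grad.setter
        def requires_grad(self, value: bool) -> None:
            # Route writes through requires_grad_() so autograd sees the change
            self.requires_grad_(value)

        def requires_grad_(self, requires_grad: bool = True) -> "CachedFlagTensor":
            self._requires_grad = requires_grad    # update the cache
            super().requires_grad_(requires_grad)  # update the autograd engine
            return self

    t = torch.zeros(4).as_subclass(CachedFlagTensor)
    t.requires_grad = True
    assert t.requires_grad and torch._C.TensorBase.requires_grad.__get__(t, type(t))

The lazy hasattr fallback is what keeps instances created via alternate construction paths (which may skip the caching code entirely) returning correct values, at the cost of one slow lookup on first access.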
diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
index 03c16ebbed..005aa8d8b6 100644
--- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
@@ -602,6 +602,24 @@ def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor):
     # Cast to FP8 when setting Float8BlockwiseQTensor.data
     data = property(_get_data, _set_data)
+
+    @property
+    def shape(self):
+        """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
+        if self._rowwise_data is not None:
+            return self._rowwise_data.shape
+        if self._columnwise_data is not None:
+            return self._columnwise_data.shape
+        raise RuntimeError("Float8BlockwiseQTensor has no data!")
+
+    @property
+    def is_cuda(self):
+        """Return whether the tensor is on a CUDA device."""
+        if self._rowwise_data is not None:
+            return self._rowwise_data.is_cuda
+        if self._columnwise_data is not None:
+            return self._columnwise_data.is_cuda
+        raise RuntimeError("Float8BlockwiseQTensor has no data!")
 
 
 class _ViewFunc(torch.autograd.Function):
     """View function
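The float8_blockwise_tensor.py hunk resolves shape and is_cuda from whichever backing buffer is allocated instead of falling through to the slower torch.Tensor attribute machinery. The float8_tensor.py change below does the same with one extra step: its column-wise buffer stores the transpose, so the logical shape is recovered by rotating the leading dimension to the back. A small sketch of that delegation, with illustrative names (not Transformer Engine code):

    import torch

    class ShapeFromBuffers:
        """Illustration only: delegate shape to whichever buffer exists."""

        def __init__(self, rowwise=None, transpose=None):
            self._data = rowwise         # row-wise buffer, logical layout
            self._transpose = transpose  # column-wise buffer, stored transposed

        @property
        def shape(self):
            if self._data is not None:
                return self._data.shape
            if self._transpose is not None:
                s = self._transpose.shape
                return tuple(s[1:]) + (s[0],)  # rotate back to the logical layout
            raise RuntimeError("no backing buffer allocated")

    # Only the transposed buffer exists (e.g. column-wise-only usage on Hopper):
    t = ShapeFromBuffers(transpose=torch.empty(8, 4, dtype=torch.uint8))
    assert t.shape == (4, 8)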
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index 43cbdcf9e6..6bc3e42a0a 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -910,6 +910,25 @@ def fsdp_post_all_gather(
         )
         return out, all_gather_outputs
 
+    @property
+    def shape(self):
+        """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
+        if self._data is not None:
+            return self._data.shape
+        if self._transpose is not None:
+            transpose_shape = self._transpose.shape
+            return tuple(transpose_shape[1:]) + (transpose_shape[0],)
+        raise RuntimeError("Both data and transpose are None")
+
+    @property
+    def is_cuda(self):
+        """Return whether the tensor is on a CUDA device."""
+        if self._data is not None:
+            return self._data.is_cuda
+        if self._transpose is not None:
+            return self._transpose.is_cuda
+        raise RuntimeError("Both data and transpose are None")
+
     @classmethod
     def _make_in_reduce_ex(
         cls,
@@ -982,6 +1001,7 @@ def _set_data(self, tensor: torch.Tensor) -> None:
         )
         # pylint: disable=unnecessary-dunder-call
         super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor)
+        self.dtype = tensor.dtype
 
         # Float8Tensor attributes
         self._data = tensor._data
diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py
index 88081f51bf..98fb59f387 100644
--- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py
+++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py
@@ -785,6 +785,9 @@ def _set_data(self, tensor: torch.Tensor) -> None:
         )
         # pylint: disable=unnecessary-dunder-call
         super(MXFP8Tensor, type(self)).data.__set__(self, dummy_tensor)
+        # Cache the attributes
+        self.dtype = tensor.dtype
+
         self._rowwise_data = tensor._rowwise_data
         self._columnwise_data = tensor._columnwise_data
         self._quantizer = tensor._quantizer.copy()
@@ -803,6 +806,24 @@ def _set_data(self, tensor: torch.Tensor) -> None:
     # Cast to FP8 when setting MXFP8Tensor.data
     data = property(_get_data, _set_data)
+
+    @property
+    def shape(self):
+        """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
+        if self._rowwise_data is not None:
+            return self._rowwise_data.shape
+        if self._columnwise_data is not None:
+            return self._columnwise_data.shape
+        raise RuntimeError("MXFP8Tensor has no data!")
+
+    @property
+    def is_cuda(self):
+        """Return whether the tensor is on a CUDA device."""
+        if self._rowwise_data is not None:
+            return self._rowwise_data.is_cuda
+        if self._columnwise_data is not None:
+            return self._columnwise_data.is_cuda
+        raise RuntimeError("MXFP8Tensor has no data!")
 
 
 class _ViewFunc(torch.autograd.Function):
     """View function
diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py
index 8b707af3b2..cc8c348aa2 100644
--- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py
+++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py
@@ -689,6 +689,9 @@ def _set_data(self, tensor: torch.Tensor) -> None:
         )
         # pylint: disable=unnecessary-dunder-call
         super(NVFP4Tensor, type(self)).data.__set__(self, dummy_tensor)
+        # Cache the attributes
+        self.dtype = tensor.dtype
+
         self._rowwise_data = tensor._rowwise_data
         self._columnwise_data = tensor._columnwise_data
         self._quantizer = tensor._quantizer
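A simple way to sanity-check the attribute-access overhead these property definitions target is to time repeated reads of the affected attributes. The harness below is a sketch: as written it runs against a plain torch.Tensor, and a Float8Tensor or MXFP8Tensor produced by your own quantizer setup can be passed in instead on a machine with a Transformer Engine build; no numbers are claimed here.

    import timeit

    import torch

    def time_attribute_reads(t, attrs=("shape", "dtype", "requires_grad", "is_cuda"), n=100_000):
        """Time plain attribute reads on any tensor-like object."""
        for attr in attrs:
            dt = timeit.timeit(lambda: getattr(t, attr), number=n)
            print(f"{attr:14s} {dt * 1e9 / n:8.1f} ns per read")

    # Baseline on a vanilla tensor; substitute a quantized tensor from this patch
    # to compare what the cached properties buy on your setup.
    time_attribute_reads(torch.zeros(16, 16))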