Skip to content

Commit c1c354e

Browse files
CANN: Refactor ND to NZ workspace to be per-device (#15763)
* CANN:Refactor ND to NZ workspace to be per-device in Ascend backend - Replaced the previous single global ND→NZ workspace with a per-device cache using unordered_map keyed by device ID. - Functions `release_nz_workspace`, `relloc_nz_workspace`, and `get_nz_workspace` now manage workspace independently for each device, preventing memory conflicts in multi-device / pipeline parallel scenarios. - This change fixes potential precision issues caused by workspace overwrites when multiple devices perform ND→NZ conversions concurrently. Co-authored-by: hipudding <[email protected]> * refactor Signed-off-by: noemotiovon <[email protected]> * rename Signed-off-by: noemotiovon <[email protected]> * fix review comments Signed-off-by: noemotiovon <[email protected]> --------- Signed-off-by: noemotiovon <[email protected]> Co-authored-by: hipudding <[email protected]>
1 parent a68d914 commit c1c354e

File tree

1 file changed

+61
-24
lines changed

1 file changed

+61
-24
lines changed

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 61 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,30 +1116,65 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
11161116
return GGML_STATUS_SUCCESS;
11171117
}
11181118

1119-
// ND to NZ Workspace Cache Management. Thread-safety: Not guaranteed
1120-
namespace {
1121-
void* g_nz_workspace = nullptr;
1122-
size_t g_nz_workspace_allocated = 0;
1123-
1124-
void release_nz_workspace() {
1125-
if (g_nz_workspace) {
1126-
aclrtFree(g_nz_workspace);
1127-
g_nz_workspace = nullptr;
1128-
g_nz_workspace_allocated = 0;
1119+
/**
1120+
* @brief Workspace for caching NZ buffers per device.
1121+
*
1122+
* This struct manages a device buffer used in NZ computations. It supports
1123+
* allocation, reallocation, and clearing of cached memory. The struct is
1124+
* designed to be used with a global array, one per device.
1125+
*/
1126+
struct ggml_cann_nz_workspace {
1127+
void* ptr; // Pointer to allocated device buffer
1128+
size_t allocated; // Size of currently allocated buffer in bytes
1129+
1130+
/**
1131+
* @brief Constructor. Initializes the workspace with no allocated memory.
1132+
*/
1133+
ggml_cann_nz_workspace() : ptr(nullptr), allocated(0) {}
1134+
1135+
/**
1136+
* @brief Free cached memory and reset the workspace.
1137+
*
1138+
* If a buffer has been allocated, this function releases it using
1139+
* aclrtFree and resets internal state.
1140+
*/
1141+
void clear() {
1142+
if (ptr) {
1143+
ACL_CHECK(aclrtFree(ptr));
1144+
ptr = nullptr;
1145+
allocated = 0;
11291146
}
11301147
}
11311148

1132-
void relloc_nz_workspace(size_t new_size) {
1133-
if (new_size > g_nz_workspace_allocated) {
1134-
if (g_nz_workspace) {
1135-
aclrtFree(g_nz_workspace);
1136-
g_nz_workspace = nullptr;
1149+
/**
1150+
* @brief Allocate or reallocate the workspace buffer.
1151+
*
1152+
* If the requested size is larger than the currently allocated size,
1153+
* the old buffer will be freed and a new buffer of the requested size
1154+
* will be allocated on the device.
1155+
*
1156+
* @param new_size Size in bytes to allocate for the workspace.
1157+
*/
1158+
void realloc(size_t new_size) {
1159+
if (new_size > allocated) {
1160+
clear();
1161+
ACL_CHECK(aclrtMalloc(&ptr, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
1162+
allocated = new_size;
11371163
}
1138-
ACL_CHECK(aclrtMalloc(&g_nz_workspace, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
1139-
g_nz_workspace_allocated = new_size;
1140-
}
11411164
}
1142-
}
1165+
1166+
/**
1167+
* @brief Get the device buffer pointer.
1168+
*
1169+
* @return Pointer to the allocated buffer, or nullptr if not allocated.
1170+
*/
1171+
void* get() const { return ptr; }
1172+
};
1173+
1174+
/**
1175+
* @brief Global array of NZ workspaces, one per device.
1176+
*/
1177+
static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
11431178

11441179
/**
11451180
* @brief Convert tensor weights to NZ format using Ascend CANN API.
@@ -1149,13 +1184,13 @@ namespace {
11491184
* improve performance on certain hardware.
11501185
*
11511186
* @param tensor Pointer to the input ggml_tensor containing the weights.
1152-
* @param data Pointer to the raw data buffer for the tensor weights.
11531187
* @param offset Byte offset within the tensor data buffer where weights start.
1188+
* @param device device id.
11541189
*
11551190
* @note The workspace buffer used in this function is managed globally and reused
11561191
* across calls. This reduces overhead from repeated memory allocation and deallocation.
11571192
*/
1158-
static void weight_format_to_nz(ggml_tensor *tensor, size_t offset) {
1193+
static void weight_format_to_nz(ggml_tensor *tensor, size_t offset, int device) {
11591194
aclTensor* weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne,
11601195
tensor->nb, 2, ACL_FORMAT_ND, offset);
11611196
uint64_t workspaceSize = 0;
@@ -1165,7 +1200,9 @@ static void weight_format_to_nz(ggml_tensor *tensor, size_t offset) {
11651200
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed,
11661201
&workspaceSize, &executor));
11671202
// Avoid frequent malloc/free of the workspace.
1168-
relloc_nz_workspace(workspaceSize);
1203+
g_nz_workspaces[device].realloc(workspaceSize);
1204+
1205+
void* g_nz_workspace = g_nz_workspaces[device].get();
11691206

11701207
ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
11711208
ACL_CHECK(aclDestroyTensor(weightTransposed));
@@ -1203,7 +1240,7 @@ static void ggml_backend_cann_buffer_set_tensor(
12031240
if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) {
12041241
GGML_ASSERT(tensor->ne[2] == 1);
12051242
GGML_ASSERT(tensor->ne[3] == 1);
1206-
weight_format_to_nz(tensor, offset);
1243+
weight_format_to_nz(tensor, offset, ctx->device);
12071244
}
12081245
} else {
12091246
void *transform_buffer = malloc(size);
@@ -2262,7 +2299,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
22622299
ggml_backend_cann_context* cann_ctx =
22632300
(ggml_backend_cann_context*)backend->context;
22642301
ggml_cann_set_device(cann_ctx->device);
2265-
release_nz_workspace();
2302+
g_nz_workspaces[cann_ctx->device].clear();
22662303

22672304
#ifdef USE_ACL_GRAPH
22682305
bool use_cann_graph = true;

0 commit comments

Comments
 (0)