@@ -1116,30 +1116,65 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
1116
1116
return GGML_STATUS_SUCCESS;
1117
1117
}
1118
1118
1119
- // ND to NZ Workspace Cache Management. Thread-safety: Not guaranteed
1120
- namespace {
1121
- void * g_nz_workspace = nullptr ;
1122
- size_t g_nz_workspace_allocated = 0 ;
1123
-
1124
- void release_nz_workspace () {
1125
- if (g_nz_workspace) {
1126
- aclrtFree (g_nz_workspace);
1127
- g_nz_workspace = nullptr ;
1128
- g_nz_workspace_allocated = 0 ;
1119
+ /* *
1120
+ * @brief Workspace for caching NZ buffers per device.
1121
+ *
1122
+ * This struct manages a device buffer used in NZ computations. It supports
1123
+ * allocation, reallocation, and clearing of cached memory. The struct is
1124
+ * designed to be used with a global array, one per device.
1125
+ */
1126
+ struct ggml_cann_nz_workspace {
1127
+ void * ptr; // Pointer to allocated device buffer
1128
+ size_t allocated; // Size of currently allocated buffer in bytes
1129
+
1130
+ /* *
1131
+ * @brief Constructor. Initializes the workspace with no allocated memory.
1132
+ */
1133
+ ggml_cann_nz_workspace () : ptr(nullptr ), allocated(0 ) {}
1134
+
1135
+ /* *
1136
+ * @brief Free cached memory and reset the workspace.
1137
+ *
1138
+ * If a buffer has been allocated, this function releases it using
1139
+ * aclrtFree and resets internal state.
1140
+ */
1141
+ void clear () {
1142
+ if (ptr) {
1143
+ ACL_CHECK (aclrtFree (ptr));
1144
+ ptr = nullptr ;
1145
+ allocated = 0 ;
1129
1146
}
1130
1147
}
1131
1148
1132
- void relloc_nz_workspace (size_t new_size) {
1133
- if (new_size > g_nz_workspace_allocated) {
1134
- if (g_nz_workspace) {
1135
- aclrtFree (g_nz_workspace);
1136
- g_nz_workspace = nullptr ;
1149
+ /* *
1150
+ * @brief Allocate or reallocate the workspace buffer.
1151
+ *
1152
+ * If the requested size is larger than the currently allocated size,
1153
+ * the old buffer will be freed and a new buffer of the requested size
1154
+ * will be allocated on the device.
1155
+ *
1156
+ * @param new_size Size in bytes to allocate for the workspace.
1157
+ */
1158
+ void realloc (size_t new_size) {
1159
+ if (new_size > allocated) {
1160
+ clear ();
1161
+ ACL_CHECK (aclrtMalloc (&ptr, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
1162
+ allocated = new_size;
1137
1163
}
1138
- ACL_CHECK (aclrtMalloc (&g_nz_workspace, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
1139
- g_nz_workspace_allocated = new_size;
1140
- }
1141
1164
}
1142
- }
1165
+
1166
+ /* *
1167
+ * @brief Get the device buffer pointer.
1168
+ *
1169
+ * @return Pointer to the allocated buffer, or nullptr if not allocated.
1170
+ */
1171
+ void * get () const { return ptr; }
1172
+ };
1173
+
1174
+ /* *
1175
+ * @brief Global array of NZ workspaces, one per device.
1176
+ */
1177
+ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
1143
1178
1144
1179
/* *
1145
1180
* @brief Convert tensor weights to NZ format using Ascend CANN API.
@@ -1149,13 +1184,13 @@ namespace {
1149
1184
* improve performance on certain hardware.
1150
1185
*
1151
1186
* @param tensor Pointer to the input ggml_tensor containing the weights.
1152
- * @param data Pointer to the raw data buffer for the tensor weights.
1153
1187
* @param offset Byte offset within the tensor data buffer where weights start.
1188
+ * @param device device id.
1154
1189
*
1155
1190
* @note The workspace buffer used in this function is managed globally and reused
1156
1191
* across calls. This reduces overhead from repeated memory allocation and deallocation.
1157
1192
*/
1158
- static void weight_format_to_nz (ggml_tensor *tensor, size_t offset) {
1193
+ static void weight_format_to_nz (ggml_tensor *tensor, size_t offset, int device ) {
1159
1194
aclTensor* weightTransposed = ggml_cann_create_tensor (tensor, tensor->ne ,
1160
1195
tensor->nb , 2 , ACL_FORMAT_ND, offset);
1161
1196
uint64_t workspaceSize = 0 ;
@@ -1165,7 +1200,9 @@ static void weight_format_to_nz(ggml_tensor *tensor, size_t offset) {
1165
1200
ACL_CHECK (aclnnTransMatmulWeightGetWorkspaceSize (weightTransposed,
1166
1201
&workspaceSize, &executor));
1167
1202
// Avoid frequent malloc/free of the workspace.
1168
- relloc_nz_workspace (workspaceSize);
1203
+ g_nz_workspaces[device].realloc (workspaceSize);
1204
+
1205
+ void * g_nz_workspace = g_nz_workspaces[device].get ();
1169
1206
1170
1207
ACL_CHECK (aclnnTransMatmulWeight (g_nz_workspace, workspaceSize, executor, nullptr ));
1171
1208
ACL_CHECK (aclDestroyTensor (weightTransposed));
@@ -1203,7 +1240,7 @@ static void ggml_backend_cann_buffer_set_tensor(
1203
1240
if (weight_to_nz && is_matmul_weight ((const ggml_tensor*)tensor)) {
1204
1241
GGML_ASSERT (tensor->ne [2 ] == 1 );
1205
1242
GGML_ASSERT (tensor->ne [3 ] == 1 );
1206
- weight_format_to_nz (tensor, offset);
1243
+ weight_format_to_nz (tensor, offset, ctx-> device );
1207
1244
}
1208
1245
} else {
1209
1246
void *transform_buffer = malloc (size);
@@ -2262,7 +2299,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
2262
2299
ggml_backend_cann_context* cann_ctx =
2263
2300
(ggml_backend_cann_context*)backend->context ;
2264
2301
ggml_cann_set_device (cann_ctx->device );
2265
- release_nz_workspace ();
2302
+ g_nz_workspaces[cann_ctx-> device ]. clear ();
2266
2303
2267
2304
#ifdef USE_ACL_GRAPH
2268
2305
bool use_cann_graph = true ;
0 commit comments