Refactor Attention tensor allocation
Signed-off-by: ERMAN GURSES <[email protected]>
erman-gurses committed Aug 14, 2024
1 parent a557826 commit 5a119bb
Showing 2 changed files with 15 additions and 13 deletions.
tests/e2e/attention/generate_e2e_attention_tests.py: 2 additions & 2 deletions
@@ -71,15 +71,15 @@ class TestShapeAndScale:
 def get_test_shapes(shapes_id: ShapesId):
     if shapes_id == ShapesId.SMALL:
         return [
-            TestShapeAndScale(batch=2, m=1024, k1=64, k2=512, n=32, scale=1.0),
+            TestShapeAndScale(batch=2, m=1024, k1=64, k2=256, n=32, scale=1.0),
         ]
     if shapes_id == ShapesId.MEDIUM:
         return [
             TestShapeAndScale(batch=2, m=2048, k1=128, k2=512, n=64, scale=1.0),
         ]
     if shapes_id == ShapesId.LARGE:
         return [
-            TestShapeAndScale(batch=2, m=4096, k1=64, k2=1024, n=128, scale=1.0),
+            TestShapeAndScale(batch=4, m=4096, k1=64, k2=1024, n=128, scale=1.0),
         ]
 
     raise ValueError(shapes_id)
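These shapes also bound the reference implementation's scratch buffer: after the refactor below, it holds 1 × M × K2 floats regardless of batch size. As a rough worked example, assuming 4-byte floats and the LARGE shape:

new: 1 x M x K2 x 4 B = 1 x 4096 x 1024 x 4 B = 16 MiB, allocated once
old: B x M x K2 x 4 B = 4 x 4096 x 1024 x 4 B = 64 MiB, allocated and freed in each of the B per-batch calls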
tools/testing/e2e/iree-e2e-attention-test.cc: 13 additions & 11 deletions
@@ -8,6 +8,7 @@
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
+#include <unistd.h>
 
 #include "iree/base/api.h"
 #include "iree/base/internal/cpu.h"
@@ -49,9 +50,8 @@ int index_3d(int i, int j, int k, int dim2, int dim3) {
 static void reference_attention_f32_f32_f32_f32(
     iree_hal_dim_t M, iree_hal_dim_t K1, iree_hal_dim_t K2, iree_hal_dim_t N,
     iree_hal_dim_t B, const float* query_data, const float* key_data,
-    const float* value_data, float* result_data, iree_hal_dim_t b) {
-  float* Attention = allocate_tensor(B, M, K2);
-
+    const float* value_data, float* result_data, iree_hal_dim_t b,
+    float* Attention) {
   // Compute Q * K^T
   for (int m = 0; m < M; ++m) {
     for (int k2 = 0; k2 < K2; ++k2) {
@@ -62,7 +62,7 @@ static void reference_attention_f32_f32_f32_f32(
 
         sum += query_data[q_idx] * key_data[k_idx];
       }
-      int att_idx = index_3d(b, m, k2, M, K2);
+      int att_idx = index_3d(0, m, k2, M, K2);
       Attention[att_idx] = sum / sqrt(K1);  // Scale by sqrt(K1)
     }
   }
@@ -72,12 +72,12 @@ static void reference_attention_f32_f32_f32_f32(
     // Calculate softmax denominator
     float sum = 0.0;
     for (int k2 = 0; k2 < K2; ++k2) {
-      int att_idx = index_3d(b, m, k2, M, K2);
+      int att_idx = index_3d(0, m, k2, M, K2);
       sum += exp(Attention[att_idx]);
     }
     // Apply softmax
     for (int k2 = 0; k2 < K2; ++k2) {
-      int att_idx = index_3d(b, m, k2, M, K2);
+      int att_idx = index_3d(0, m, k2, M, K2);
       Attention[att_idx] = exp(Attention[att_idx]) / sum;
     }
   }
@@ -87,29 +87,29 @@ static void reference_attention_f32_f32_f32_f32(
     for (int n = 0; n < N; ++n) {
       float sum = 0.0;
      for (int k2 = 0; k2 < K2; ++k2) {
-        int att_idx = index_3d(b, m, k2, M, K2);
+        int att_idx = index_3d(0, m, k2, M, K2);
         int v_idx = index_3d(b, k2, n, K2, N);
         sum += Attention[att_idx] * value_data[v_idx];
       }
       int o_idx = index_3d(b, m, n, M, N);
       result_data[o_idx] = sum;
     }
   }
-  free_tensor(Attention);
 }
 
 static iree_status_t reference_attention_element(
     iree_hal_dim_t M, iree_hal_dim_t K1, iree_hal_dim_t K2, iree_hal_dim_t N,
     iree_hal_dim_t B, iree_hal_element_type_t query_elem_type,
     iree_hal_element_type_t key_elem_type,
     iree_hal_element_type_t value_elem_type, void* query_data, void* key_data,
-    void* value_data, void* actual_data, void* result_data, iree_hal_dim_t b) {
+    void* value_data, void* actual_data, void* result_data, iree_hal_dim_t b,
+    float* Attention) {
   if (query_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
       key_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
       value_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
     reference_attention_f32_f32_f32_f32(
         M, K1, K2, N, B, (const float*)query_data, (const float*)key_data,
-        (const float*)value_data, (float*)result_data, b);
+        (const float*)value_data, (float*)result_data, b, Attention);
 
   } else {
     return iree_make_status(
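For reference, with Q, K, and V indexed as in the surrounding code, the loops in reference_attention_f32_f32_f32_f32 compute, for one batch element b:

S[m, k2]   = (1 / sqrt(K1)) * sum_{k1} Q[b, m, k1] * K[b, k2, k1]
A[m, k2]   = exp(S[m, k2]) / sum_{k2'} exp(S[m, k2'])
O[b, m, n] = sum_{k2} A[m, k2] * V[b, k2, n]

That is, O = softmax(Q * K^T / sqrt(K1)) * V per batch. The Attention buffer holds only the current batch element's S/A slice, which is why every index_3d call on it now passes 0 as the batch index.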
@@ -137,6 +137,7 @@ static iree_status_t reference_attention(
   IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, N);
 
   iree_host_size_t count = 0;
+  float* Attention = allocate_tensor(1, M, K2);
   for (iree_hal_dim_t b = 0; b < B; ++b) {
     if (++count < compute_every) continue;
     count = 0;
@@ -145,8 +146,9 @@
         reference_attention_element(
             M, K1, K2, N, B, query_elem_type, key_elem_type, value_elem_type,
             query_contents.data, key_contents.data, value_contents.data,
-            actual_contents.data, result_contents.data, b));
+            actual_contents.data, result_contents.data, b, Attention));
   }
+  free_tensor(Attention);
 
   IREE_TRACE_ZONE_END(z0);
   return iree_ok_status();
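To summarize the change in one place, here is a minimal sketch of the allocation pattern this commit moves to. It is not code from the repository: allocate_tensor and free_tensor exist in this file, but their exact signatures are not shown in the diff and are assumed here to be simple malloc wrappers.

#include <stdlib.h>

// Stand-in for the test file's helper: a (dim1, dim2, dim3) float buffer.
static float* allocate_tensor(int dim1, int dim2, int dim3) {
  return (float*)malloc(sizeof(float) * dim1 * dim2 * dim3);
}

static void free_tensor(float* tensor) { free(tensor); }

// Per-batch reference computation; fills Attention[0, m, k2] for one batch
// element. The buffer is indexed with a leading 0 because it only ever
// holds a single batch slice.
static void reference_attention_one_batch(int M, int K2, float* Attention) {
  // ... Q*K^T, softmax, and A*V loops as in the diff above ...
}

static void reference_attention(int B, int M, int K2) {
  // Before: each per-batch call allocated and freed a (B, M, K2) buffer.
  // After: allocate a (1, M, K2) buffer once and reuse it for every b.
  float* Attention = allocate_tensor(1, M, K2);
  for (int b = 0; b < B; ++b) {
    reference_attention_one_batch(M, K2, Attention);
  }
  free_tensor(Attention);
}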
