diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 14f7dcf4f41ad..e3db937dd2c58 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8135,7 +8135,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( } // V /= S - const float S_inv = 1.0f/S; + const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; ggml_vec_scale_f32(DV, VKQ32, S_inv); // dst indices diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 62d815cc26808..151c01f4dbdd9 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -131,6 +131,50 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } } +// generate an F16 mask where certain blocks are randomly masked with -INF value +static void init_tensor_kq_mask(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { + GGML_ASSERT(tensor->type == GGML_TYPE_F16); + + GGML_TENSOR_LOCALS( int32_t, ne, tensor, ne); + + std::vector data_f32(ne0*ne1*ne2*ne3); + std::vector data_f16(ne0*ne1*ne2*ne3); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dis(min, max); + + for (size_t i = 0; i < data_f32.size(); i++) { + data_f32[i] = dis(gen); + } + + // block size + const int blck0 = 128; + const int blck1 = 64; + + // number of INF blocks + const int n_inf_blocks = 0.1*(ne0*ne1*ne2*ne3)/(blck0*blck1); + + for (int b = 0; b < n_inf_blocks; b++) { + const int p3 = (rd() % ne3); + const int p2 = (rd() % ne2); + const int p1 = (rd() % ne1); + const int p0 = (rd() % ne0); + + for (int i1 = 0; i1 < blck1 && p1 + i1 < ne1; i1++) { + const int idx = p3*ne2*ne1*ne0 + p2*ne1*ne0 + (p1 + i1)*ne0 + p0; + + for (int i0 = 0; i0 < blck0 && p0 + i0 < ne0; i0++) { + data_f32[idx + i0] = -INFINITY; + } + } + } + + ggml_fp32_to_fp16_row(data_f32.data(), data_f16.data(), ne0*ne1*ne2*ne3); + + ggml_backend_tensor_set(tensor, data_f16.data(), 0, data_f16.size()*sizeof(ggml_fp16_t)); +} + static std::vector tensor_to_float(const ggml_tensor * t) { std::vector tv; tv.reserve(ggml_nelements(t)); @@ -5104,6 +5148,8 @@ struct test_flash_attn_ext : public test_case { if (strcmp(t->name, "s") == 0) { // make the sink values more noticable in order to trigger a test failure when the implementation is wrong init_tensor_uniform(t, -10.0f, 10.0f); + } else if (strcmp(t->name, "m") == 0) { + init_tensor_kq_mask(t); } else { init_tensor_uniform(t); }