[TKW] RPE fixes (#506)
* Fix rpe mapping condition
* Make `max_context_length` a test param
* Enable perf testing

---------

Signed-off-by: Ivan Butygin <[email protected]>
Hardcode84 authored Feb 14, 2025
1 parent df50964 commit 36d74e9
Showing 3 changed files with 20 additions and 13 deletions.
2 changes: 1 addition & 1 deletion iree/turbine/kernel/wave/templates/t5_rpe_attention.py
@@ -78,7 +78,7 @@ def get_t5_rpe_attention_kernel(

     d0, d1 = [tkw.IndexMapping.dynamic_val(i) for i in range(2)]
     clip = sympy.Piecewise(
-        (d0 - d1, (d0 - d1 <= max_context_length) & (d0 - d1 > 0)), (0, True)
+        (d0 - d1, (d0 - d1 < max_context_length) & (d0 - d1 >= 0)), (0, True)
     )
     offset_mapping = tkw.IndexMapping(
         num_iterators=2,
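The fix above changes behaviour only at the window boundaries: a relative position `d0 - d1` now keeps its own offset when it lies in `[0, max_context_length)`, while a distance of exactly `max_context_length` (which the old `<=` condition let through) falls back to offset 0. A minimal sketch in plain sympy, outside the kernel, with illustrative values:

    import sympy

    d0, d1 = sympy.symbols("d0 d1", integer=True)
    max_context_length = 10
    clip = sympy.Piecewise(
        (d0 - d1, (d0 - d1 < max_context_length) & (d0 - d1 >= 0)), (0, True)
    )

    print(clip.subs({d0: 14, d1: 5}))  # 9: largest distance still in-window
    print(clip.subs({d0: 15, d1: 5}))  # 0: distance 10 now clips to offset 0
    print(clip.subs({d0: 5, d1: 15}))  # 0: negative distances stay at offset 0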
14 changes: 9 additions & 5 deletions tests/kernel/wave/attention/t5_rpe_attention_test.py
@@ -24,17 +24,18 @@
 from iree.turbine.kernel.wave.templates.t5_rpe_attention import (
     get_t5_rpe_attention_kernel,
 )
+from ..common.shapes import make_shape_param
 from ..common.utils import (
     require_e2e,
     require_cdna3,
     enable_scheduling_barriers,
 )
 from typing import Tuple
 
-shapes = [(128, 128, 128, 128, 128, 128)]
-
-# T5 RPE parameter
-max_context_length = 10
+shapes = [
+    make_shape_param((128, 128, 128, 128, 128, 128), is_perf=False),
+    make_shape_param((128, 128, 128, 128, 128, 128), is_perf=True),
+]
 
 
 def t5_rpe_masked_cond(
@@ -54,6 +54,7 @@ def validate_accuracy(
     value: torch.Tensor,
     rpe: torch.Tensor,
     output: torch.Tensor,
+    max_context_length: int,
 ) -> torch.Tensor:
     # Precompute values.
     dk_sqrt = math.sqrt(1.0 / query.shape[-1])
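The body of `t5_rpe_masked_cond` is collapsed in this view, so here is a hedged torch sketch of a bias gather that mirrors the kernel's Piecewise clip; this is an assumption for illustration, not the repository's exact reference code:

    import torch

    def rpe_bias_sketch(rpe: torch.Tensor, seq_len: int, max_context_length: int):
        # rpe is assumed 1-D with at least max_context_length entries.
        pos = torch.arange(seq_len)
        d = pos[:, None] - pos[None, :]  # d0 - d1 for every (query, key) pair
        in_window = (d >= 0) & (d < max_context_length)
        idx = torch.where(in_window, d, torch.zeros_like(d))
        # Out-of-window pairs gather rpe[0], matching a clipped offset of 0.
        return rpe[idx]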
@@ -86,13 +88,15 @@ def create_inputs(
 @require_e2e
 @require_cdna3
 @pytest.mark.parametrize("shape", shapes)
+@pytest.mark.parametrize("max_context_length", [10, 128])  # T5 RPE parameter
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize(
     "mfma_variant",
     [(MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)],
 )
 def test_t5_rpe_attention(
     shape: Tuple[int],
+    max_context_length: int,
     dtype: torch.dtype,
     mfma_variant: MMAType,
     request,
@@ -160,4 +164,4 @@ def test_t5_rpe_attention(
         output,
     )
 
-    validate_accuracy(query, key, value, rpe, output)
+    validate_accuracy(query, key, value, rpe, output, max_context_length)
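Because `max_context_length` is now a parametrized test argument, the stacked `@pytest.mark.parametrize` decorators cross it with every shape param, so each shape runs once per context length. A self-contained sketch of that stacking behaviour (test and names are illustrative):

    import pytest

    @pytest.mark.parametrize("shape", [(128,) * 6])
    @pytest.mark.parametrize("max_context_length", [10, 128])
    def test_matrix(shape, max_context_length):
        # pytest collects the cross-product: two tests here, one per context length.
        assert len(shape) == 6
        assert max_context_length in (10, 128)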
17 changes: 10 additions & 7 deletions tests/kernel/wave/common/shapes.py
@@ -88,13 +88,16 @@ def construct_test_name(
     return test_name + ".json"
 
 
+def make_shape_param(shape: Sequence[int], is_perf: bool):
+    name = "x".join(map(str, shape))
+    if is_perf:
+        return pytest.param(shape, id=name + "-perf", marks=pytest.mark.perf_only)
+    else:
+        return pytest.param(shape, id=name)
+
+
 def get_test_shapes(test_name: str):
     assert test_name in _e2e_test_shapes, f"Unknown test name: {test_name}"
-    shapes = [
-        pytest.param(s, id="x".join(map(str, s))) for s in _e2e_test_shapes[test_name]
-    ]
-    shapes += [
-        pytest.param(s, id="x".join(map(str, s)) + "-perf", marks=pytest.mark.perf_only)
-        for s in _perf_test_shapes[test_name]
-    ]
+    shapes = [make_shape_param(s, False) for s in _e2e_test_shapes[test_name]]
+    shapes += [make_shape_param(s, True) for s in _perf_test_shapes[test_name]]
     return shapes
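For reference, a short usage sketch of the new helper (redefined here so the snippet is self-contained); whether the `-perf` variants actually execute depends on how the `perf_only` marker is wired up in the repo's pytest configuration, which is outside this diff:

    import pytest

    def make_shape_param(shape, is_perf: bool):
        name = "x".join(map(str, shape))
        if is_perf:
            return pytest.param(shape, id=name + "-perf", marks=pytest.mark.perf_only)
        return pytest.param(shape, id=name)

    params = [make_shape_param((64, 64), False), make_shape_param((64, 64), True)]
    # ids: "64x64" and "64x64-perf"; select perf runs with `pytest -m perf_only`
    # or exclude them with `-m "not perf_only"`.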
