From 70381271189132cb937e6eb0a47096f14142eadd Mon Sep 17 00:00:00 2001 From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com> Date: Tue, 4 Feb 2025 17:52:48 -0800 Subject: [PATCH] [Wave] Add self_index, predicate, and selectOp to implement causal attention (#452) - Extracted core pieces of self_index, predicate, and selectOp, and LIT for predicate and select written by @nicolasvasilache and @ftynse which is required for causal mask and remove causal mask unrelated pieces. - Implemented a numerically correct causal attention kernel based on original from @nicolasvasilache - Added GPR_NUM partitioning support for SelfIndex to allow causal to work on more MMA intrinsics(i.e 32x32x8 which has GPR_NUMs) - Refactored tkw.slt/sgt/sge/sle to be operator.lt/gt/ge/le to preserve number of tkw ops and for user ergonomics - Refactored vanilla kernel to support both in single kernel, controlled by is_causal flag - Add support on handle_op to take in multiple Ops that map to same function. - Added a bunch of LIT tests --------- Signed-off-by: Alex Zinenko Signed-off-by: Nicolas Vasilache Signed-off-by: Stanley Winata Co-authored-by: Alex Zinenko Co-authored-by: Nicolas Vasilache --- iree/turbine/aot/support/ir_utils.py | 4 + iree/turbine/kernel/_support/dtype.py | 1 + iree/turbine/kernel/ops/wave_ops.py | 147 +++++++++++++-- iree/turbine/kernel/wave/codegen.py | 171 ++++++++++++++++-- .../kernel/wave/index_sequence_analysis.py | 22 ++- .../wave/templates/vanilla_attention.py | 19 +- lit_tests/kernel/wave/attention.py | 59 ++++++ lit_tests/kernel/wave/codegen.py | 153 ++++++++++++++++ .../wave/attention/vanilla_attention_test.py | 83 +++++++++ 9 files changed, 628 insertions(+), 31 deletions(-) diff --git a/iree/turbine/aot/support/ir_utils.py b/iree/turbine/aot/support/ir_utils.py index 21c10b277..a2b075209 100644 --- a/iree/turbine/aot/support/ir_utils.py +++ b/iree/turbine/aot/support/ir_utils.py @@ -489,6 +489,10 @@ def _is_float_type(type): return isinstance(type, (BF16Type, F16Type, F32Type, F64Type, Float8E4M3FNUZType)) +def _is_index_type(type): + return isinstance(type, (IndexType)) + + def _is_integer_like_type(type): return isinstance(type, (IntegerType, IndexType)) diff --git a/iree/turbine/kernel/_support/dtype.py b/iree/turbine/kernel/_support/dtype.py index f850c4da1..62c8590ce 100644 --- a/iree/turbine/kernel/_support/dtype.py +++ b/iree/turbine/kernel/_support/dtype.py @@ -72,6 +72,7 @@ def bitwidth(self): bf16 = DataType("bf16") bool = DataType("bool", "i1") +i1 = bool i4 = DataType("i4") i8 = DataType("i8") i16 = DataType("i16") diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 404ef13bd..8f85384e4 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -20,7 +20,7 @@ from ..lang.wave_types import Memory, Register, IndexMapping from ..lang.global_symbols import * from .._support.indexing import IndexExpr, IndexSymbol, IndexSequence -from .._support.dtype import DataType +from .._support.dtype import DataType, i1 from .._support.regions import RegionGraph from .base import OpDispatcher import numpy as np @@ -45,6 +45,14 @@ def allocate( ... +def self_index( + idx: IndexExpr, + dtype: DataType, + elements_per_thread: Optional[IndexExpr | int] = None, +) -> "Register": + ... + + def extract( register: "Register", offsets: tuple[IndexExpr], @@ -166,6 +174,22 @@ def shuffle(src: "Register", offset: int, width: int) -> "Register": ... +def gt(lhs: "Register", rhs: "Register") -> "Register": + ... 
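+# Like the other declarations in this file, the comparison stubs (gt/ge/lt/le) are
+# tracing-time placeholders: calling them inside a @tkw.wave kernel records a
+# ComparisonPyOp node rather than performing a Python comparison. A hedged sketch of
+# the two equivalent spellings (operand names here are illustrative only):
+#
+#     mask = tkw.ge(m_index, k2_index)   # explicit interface op
+#     mask = m_index >= k2_index         # operator sugar via define_py_op(operator.ge)
+#
+# Either form yields an i1 register that can drive tkw.select.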
+ + +def ge(lhs: "Register", rhs: "Register") -> "Register": + ... + + +def lt(lhs: "Register", rhs: "Register") -> "Register": + ... + + +def le(lhs: "Register", rhs: "Register") -> "Register": + ... + + def cast(src: "Register", dtype: DataType) -> "Register": ... @@ -178,6 +202,10 @@ def reshape(inputs: Sequence["Register"]) -> "Register": ... +def select(cond: "Register", if_true: "Register", if_false: "Register") -> "Register": + ... + + def define_op(op_name: str) -> Callable[[T], T]: def decorator(cls: T) -> T: cls.tkw_op_name = op_name @@ -680,14 +708,8 @@ def transform_index( return index -@define_py_op(operator.add) -@define_py_op(operator.sub) -@define_py_op(operator.mul) -@define_py_op(operator.truediv) -@define_interface_op("maximum") -@define_interface_op("minimum") @dataclass -class BinaryPyOp(CustomOp, ABC): +class BinaryOpBase(CustomOp, ABC): """ Represents an elementwise binary python operator. @@ -715,21 +737,51 @@ def indexing_dims(self) -> list[IndexSymbol]: def py_operator(self) -> str: return self.tkw_op_name - def infer_type(self): + def infer_shape(self) -> Any: lhs_type = get_custom(self.lhs).type rhs_type = get_custom(self.rhs).type has_same_type = has_same_custom_type(lhs_type, rhs_type) if has_same_type: - self.type = lhs_type - return + return lhs_type.symbolic_shape + lhs_dim_set = set(lhs_type.symbolic_shape) rhs_dim_set = set(rhs_type.symbolic_shape) if lhs_dim_set.isdisjoint(rhs_dim_set): raise ValueError( "BinaryPyOp requires lhs and rhs shape to be at least broadcastable." + f" got {lhs_type.symbolic_shape} vs {rhs_type.symbolic_shape}" ) + + # TODO: this logic looks suspicious. Specifically, there's no check that + # rhs_dim_set subsumes lhs_dim_set, they may partially overlap. broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhs_type - self.type = broadcasted_type + return broadcasted_type.symbolic_shape + + +@define_py_op(operator.add) +@define_py_op(operator.sub) +@define_py_op(operator.mul) +@define_py_op(operator.truediv) +@define_interface_op("maximum") +@define_interface_op("minimum") +@dataclass +class BinaryPyOp(BinaryOpBase, ABC): + def infer_type(self): + self.type = Register[(*self.infer_shape(), get_custom(self.lhs).type.dtype)] + + +@define_py_op(operator.gt) +@define_py_op(operator.ge) +@define_py_op(operator.lt) +@define_py_op(operator.le) +@define_interface_op("gt") +@define_interface_op("ge") +@define_interface_op("lt") +@define_interface_op("le") +@dataclass +class ComparisonPyOp(BinaryOpBase, ABC): + def infer_type(self): + self.type = Register[(*self.infer_shape(), i1)] @define_interface_op("log2") @@ -759,6 +811,42 @@ def infer_type(self): self.type = src_type +@define_op("select") +@dataclass +class SelectOp(CustomOp): + cond: fx.Node + if_true: fx.Node + if_false: fx.Node + + @property + def indexing_dims(self) -> list[IndexSymbol]: + combined_dims = [] + combined_dims += get_custom(self.cond).indexing_dims + combined_dims += get_custom(self.if_true).indexing_dims + combined_dims += get_custom(self.if_false).indexing_dims + return list(dict.fromkeys(combined_dims)) + + def infer_type(self): + cond_type = get_custom(self.cond).type + if_true_type = get_custom(self.if_true).type + if_false_type = get_custom(self.if_false).type + + if cond_type.dtype != i1: + raise ValueError("SelectOp expects condition type to be i1.") + + if if_true_type.dtype != if_false_type.dtype: + raise ValueError("SelectOp expects lhs and rhs dtype to match.") + + # TODO: support broadcasting behavior. 
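+        # For now all three operands must carry an identical symbolic shape; callers
+        # that need mixed shapes (e.g. an [M] condition against [M, K2] values) are
+        # expected to tkw.broadcast the condition explicitly before the select, as
+        # the causal-attention template in this change does.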
+ if ( + cond_type.symbolic_shape != if_true_type.symbolic_shape + or cond_type.symbolic_shape != if_false_type.symbolic_shape + ): + raise ValueError("SelectOp doesn't support broadcasting. (yet?)") + + self.type = if_true_type + + @final @dataclass class Unknown(CustomOp): @@ -940,6 +1028,22 @@ def type(self) -> "Memory": return Memory[(*self.shape, self.address_space, self.dtype)] +@define_op("self_index") +@dataclass +class SelfIndex(CustomOp): + dim: IndexExpr + dtype: DataType + elements_per_thread: Optional[IndexExpr | int] = None + + @property + def indexing_dims(self) -> list[IndexSymbol]: + return [self.dim] + + @property + def type(self) -> "Register": + return Register[(self.dim, self.dtype)] + + @define_op("shared_memory_barrier") @dataclass class SharedMemoryBarrier(CustomOp): @@ -1657,6 +1761,25 @@ class Broadcast(CustomOp, ABC): arg: fx.Node target_shape: Sequence[IndexSymbol] = None + def __post_init__(self): + # Required for setting up hash. + super().__post_init__() + # Verify for valid src type. + if isinstance(self.arg, fx.Node): + src = self.arg + elif isinstance(self.arg, fx.Proxy): + src = self.arg.node + else: + raise ValueError(f"Unexpected broadcast src type of {type(self.arg)}") + + # Verifies target broadcast shape is valid. + src_type = get_custom(src).type + src_shape = set(getattr(src_type, "symbolic_shape", [])) + dst_shape = set(self.target_shape) + assert src_shape.issubset( + dst_shape + ), "Fail to initialize broadcast because of invalid target_shape." + @property def indexing_dims(self) -> list[IndexSymbol]: return self.target_shape diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py index 7da24d7e2..6192f7cc1 100644 --- a/iree/turbine/kernel/wave/codegen.py +++ b/iree/turbine/kernel/wave/codegen.py @@ -46,7 +46,11 @@ vector_d, llvm_d, ) -from iree.turbine.aot.support.ir_utils import _is_float_type, _is_integer_like_type +from iree.turbine.aot.support.ir_utils import ( + _is_float_type, + _is_index_type, + _is_integer_like_type, +) # TK infrastructure imports. 
from iree.turbine.kernel.lang.global_symbols import * @@ -61,9 +65,13 @@ exp2, extract, extract_slice, + ge, get_custom, get_result, + gt, + le, log2, + lt, maximum, minimum, mma, @@ -75,6 +83,8 @@ reshape, scheduling_barrier, scheduling_group_barrier, + self_index, + select, set_symbol, shared_memory_barrier, shuffle, @@ -576,11 +586,17 @@ def get_constant_attr(value: Any, element_type: IrType) -> Attribute: raise CodegenError(f"Cannot create a constant attribute for type `{element_type}`") -def handle_op(op: Callable[..., Any]): +def handle_op(op: Callable[..., Any] | list[Callable[..., Any]]): def decorator( f: Callable[[WaveEmitter, fx.Node], None] ) -> Callable[[WaveEmitter, fx.Node], None]: - WaveEmitter.OP_HANDLERS[op.__name__] = f + if isinstance(op, Callable): + WaveEmitter.OP_HANDLERS[op.__name__] = f + elif isinstance(op, list): + for op_iter in op: + WaveEmitter.OP_HANDLERS[op_iter.__name__] = f + else: + raise ValueError("handle_op only handle Callable or list of Callable") return f return decorator @@ -975,6 +991,46 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): vector_d.scatter(kb_dest, start_indices, offsets_vec, mask, insert_vector) +############################################################################### +# Expressions, Dims and Indexing related ops +############################################################################### + + +@handle_op(self_index) +def handle_self_index(emitter: WaveEmitter, node: fx.Node): + try: + iterator, dtype, elements_per_thread = node.args + except ValueError as e: + raise ValidationError("Malformed arguments") from e + + index = get_custom(node).index + var = index[iterator] + offset = subs_idxc(var.start) + size = elements_per_thread or subs_idxc(var.size) + stride = subs_idxc(var.stride) + + start = _build_start_indices(emitter, {iterator: var})[0] + + element_type = IrType.parse(dtype.ir_type_asm()) + index_type = IrType.parse("index") + vector_shape = cast_py_literal(emitter, [size]) + + vector_index_type = VectorType.get(vector_shape, index_type) + vector_type = VectorType.get(vector_shape, element_type) + + step = vector_d.step(vector_index_type) + stride_cst = arith_d.ConstantOp( + index_type, get_constant_attr(cast_py_literal(emitter, stride), index_type) + ) + stride_vec = vector_d.splat(vector_index_type, stride_cst) + scaled = arith_d.muli(step, stride_vec) + offset = vector_d.splat(vector_index_type, start) + shifted = arith_d.addi(scaled, offset) + casted_i = arith_d.index_cast(vector_type, shifted) + + emitter.bind_node_proxy(node, IRProxyValue(casted_i)) + + @handle_op(apply_expr) def handle_apply_expr(emitter: WaveEmitter, node: fx.Node): try: @@ -1142,7 +1198,10 @@ def handle_generic_binary(emitter: WaveEmitter, node: fx.Node): rhs = cast_py_value(emitter, rhs) if lhs.ir_value.type != rhs.ir_value.type: - raise ValidationError("Expected lhs and rhs to have same type.") + raise ValidationError( + "Expected lhs and rhs to have same type." 
+            f" Got: {lhs.ir_value.type} vs {rhs.ir_value.type}"
+        )

     lhs = lhs.ir_value
     rhs = rhs.ir_value
@@ -1195,27 +1254,79 @@ def handle_div(lhs: Value, rhs: Value) -> OpResult:
     if _is_float_type(element_type):
         result = arith_d.divf(lhs, rhs)
     elif _is_integer_like_type(element_type) and (
-        element_type.is_signed() or element_type.is_signless()
+        element_type.is_signed or element_type.is_signless
     ):
         result = arith_d.divsi(lhs, rhs)
-    elif _is_integer_like_type(element_type) and element_type.is_unsigned():
-        result = arith_d.divui(lhs, rhs)
     else:
         raise ValidationError(f"Found unhandled operand type for div: {element_type}")
     return result


+@handle_binary_op([operator.gt, gt])
+def handle_gt(lhs: Value, rhs: Value) -> OpResult:
+    element_type = get_type_or_element_type(lhs.type)
+    if _is_float_type(element_type):
+        result = arith_d.cmpf(arith_d.CmpFPredicate.OGT, lhs, rhs)
+    elif _is_integer_like_type(element_type) and (
+        element_type.is_signed or element_type.is_signless
+    ):
+        result = arith_d.cmpi(arith_d.CmpIPredicate.sgt, lhs, rhs)
+    else:
+        raise ValidationError(f"Found unhandled operand type for gt: {element_type}")
+    return result
+
+
+@handle_binary_op([ge, operator.ge])
+def handle_ge(lhs: Value, rhs: Value) -> OpResult:
+    element_type = get_type_or_element_type(lhs.type)
+    if _is_float_type(element_type):
+        result = arith_d.cmpf(arith_d.CmpFPredicate.OGE, lhs, rhs)
+    elif _is_integer_like_type(element_type) and (
+        element_type.is_signed or element_type.is_signless
+    ):
+        result = arith_d.cmpi(arith_d.CmpIPredicate.sge, lhs, rhs)
+    else:
+        raise ValidationError(f"Found unhandled operand type for ge: {element_type}")
+    return result
+
+
+@handle_binary_op([operator.lt, lt])
+def handle_lt(lhs: Value, rhs: Value) -> OpResult:
+    element_type = get_type_or_element_type(lhs.type)
+    if _is_float_type(element_type):
+        result = arith_d.cmpf(arith_d.CmpFPredicate.OLT, lhs, rhs)
+    elif _is_integer_like_type(element_type) and (
+        element_type.is_signed or element_type.is_signless
+    ):
+        result = arith_d.cmpi(arith_d.CmpIPredicate.slt, lhs, rhs)
+    else:
+        raise ValidationError(f"Found unhandled operand type for lt: {element_type}")
+    return result
+
+
+@handle_binary_op([operator.le, le])
+def handle_le(lhs: Value, rhs: Value) -> OpResult:
+    element_type = get_type_or_element_type(lhs.type)
+    if _is_float_type(element_type):
+        result = arith_d.cmpf(arith_d.CmpFPredicate.OLE, lhs, rhs)
+    elif _is_integer_like_type(element_type) and (
+        element_type.is_signed or element_type.is_signless
+    ):
+        result = arith_d.cmpi(arith_d.CmpIPredicate.sle, lhs, rhs)
+    else:
+        raise ValidationError(f"Found unhandled operand type for le: {element_type}")
+    return result
+
+
 @handle_binary_op(maximum)
 def handle_maximum(lhs: Value, rhs: Value) -> OpResult:
     element_type = get_type_or_element_type(lhs.type)
     if _is_float_type(element_type):
         result = arith_d.maximumf(lhs, rhs)
     elif _is_integer_like_type(element_type) and (
-        element_type.is_signed() or element_type.is_signless()
+        element_type.is_signed or element_type.is_signless
     ):
         result = arith_d.maxsi(lhs, rhs)
-    elif _is_integer_like_type(element_type) and element_type.is_unsigned():
-        result = arith_d.maxui(lhs, rhs)
     else:
         raise ValidationError(
             f"Found unhandled operand type for maximum: {element_type}"
@@ -1232,8 +1343,6 @@ def handle_minimum(lhs: Value, rhs: Value) -> OpResult:
         element_type.is_signed() or element_type.is_signless()
     ):
         result = arith_d.minsi(lhs, rhs)
-    elif _is_integer_like_type(element_type) and element_type.is_unsigned():
-        result = 
arith_d.minui(lhs, rhs) else: raise ValidationError( f"Found unhandled operand type for minimum: {element_type}" @@ -1543,8 +1652,22 @@ def handle_broadcast(emitter: WaveEmitter, node: fx.Node): # Only support broadcasting vector<1xdtype> for now. if not VectorType.isinstance(vector_type): raise NotImplementedError("Scalar src is not implemented yet for shuffleOp.") - assert vector_type.rank == 1 - assert vector_type.shape[0] == 1 + assert ( + vector_type.rank == 0 or vector_type.rank == 1 + ), f"expected vector_type.rank == 1 but got {vector_type}" + + if vector_type.rank == 0: + result_type = VectorType.get( + [bcast_dim_lane_dim_size], vector_type.element_type + ) + element = vector_d.extract(vector_src, static_position=[], dynamic_position=[]) + splat = vector_d.splat(result_type, element) + emitter.bind_node_proxy(node, IRProxyValue(splat)) + return + + assert ( + vector_type.shape[0] == 1 + ), f"expected vector_type.shape[0] == 1 but got {vector_type}" # Extract and Splat # If by chance broadcast size matches current size, we can return src. @@ -1562,6 +1685,18 @@ def handle_broadcast(emitter: WaveEmitter, node: fx.Node): ############################################################################### +@handle_op(select) +def handle_select(emitter: WaveEmitter, node: fx.Node): + try: + cond, if_true, if_false = node.args + except ValueError as e: + raise ValidationError("Malformed arguments") from e + + unwrap = lambda x: cast_py_value(emitter, x).ir_value + selected = arith_d.select(unwrap(cond), unwrap(if_true), unwrap(if_false)) + emitter.bind_node_proxy(node, IRProxyValue(selected)) + + @handle_op(get_result) def handle_get_result(emitter: WaveEmitter, node: fx.Node): try: @@ -1610,6 +1745,14 @@ def handle_cast(emitter: WaveEmitter, node: fx.Node): is_dst_float = _is_float_type(dst_elem_type) is_src_int = _is_integer_like_type(src_elem_type) is_dst_int = _is_integer_like_type(dst_elem_type) + if ( + is_src_int + and is_dst_int + and (_is_index_type(src_elem_type) or _is_index_type(dst_elem_type)) + ): + casted_vector = arith_d.index_cast(dst_vector_type, vector_src) + emitter.bind_node_proxy(node, IRProxyValue(casted_vector)) + return conversion_ops = { (True, False): arith_d.fptosi, diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py index 1a83e734f..63d8bf6c6 100644 --- a/iree/turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -20,6 +20,7 @@ ReduceOp, Reduction, Reshape, + SelfIndex, Write, ) from .constraints import ( @@ -179,7 +180,7 @@ def has_gpr_offsets(node: fx.Node) -> bool: read more than a single element. """ custom = get_custom(node) - if not isinstance(custom, (Read, Write)): + if not isinstance(custom, (Read, Write, SelfIndex)): return False num_dims_with_gpr = sum( 1 for v in custom.index.values() if sympy.sympify(v.start).has(GPR_NUM) @@ -197,7 +198,13 @@ def has_gpr_offsets(node: fx.Node) -> bool: dim: simplify_index(custom.index.get(dim, custom.index[dim])) for dim in custom.index } - elements_per_thread = subs_idxc(custom.elements_per_thread) + if isinstance(custom, SelfIndex): + # If specified use element_per_thread instead of IndexExpr size. 
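+                # (i.e. keep the user-provided elements_per_thread when tkw.self_index
+                # was given one; otherwise fall back to the size recorded for this dim
+                # in the op's index sequence.)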
+ elements_per_thread = ( + custom.elements_per_thread or custom.index[custom.dim].size + ) + else: + elements_per_thread = subs_idxc(custom.elements_per_thread) dim_with_gpr_offsets = [ (k, v.start) for k, v in simplified_index.items() if v.start.has(GPR_NUM) ] @@ -233,7 +240,7 @@ def has_gpr_offsets(node: fx.Node) -> bool: cur_gpr_start_id = chunk_id * gpr_size # Get updated index with VGPR offset. output_mapping = list(custom.index) - if custom.mapping is not None: + if hasattr(custom, "mapping") and custom.mapping is not None: output_mapping = list(custom.mapping.output_mapping.keys()) # Modify stride to 1 S.T we can have vectorized read/write # iff gpr_offset_dim is or will be (after mapping) fastest dim. @@ -273,6 +280,13 @@ def has_gpr_offsets(node: fx.Node) -> bool: mapping=custom.mapping, _write_dependency=custom._write_dependency, ).add_to_graph(custom.graph) + elif isinstance(custom, SelfIndex): + # iff elements_per_thread is specified, we update + # elements_per_thread to chunk size, else return None. + self_index_size = gpr_size if custom.elements_per_thread else None + new_node = SelfIndex( + custom.dim, custom.dtype, self_index_size + ).add_to_graph(custom.graph) # Update new_node information new_node.index = updated_index_with_gpr_offset @@ -283,7 +297,7 @@ def has_gpr_offsets(node: fx.Node) -> bool: if isinstance(custom, Write): # Useful to handle write/read dependency custom.replace_all_uses_with(ops_to_combine) - elif isinstance(custom, Read): + elif isinstance(custom, (Read, SelfIndex)): reshape = Reshape(ops_to_combine, custom.vector_shapes).add_to_graph( custom.graph ) diff --git a/iree/turbine/kernel/wave/templates/vanilla_attention.py b/iree/turbine/kernel/wave/templates/vanilla_attention.py index f1ca31720..2eb72f3e6 100644 --- a/iree/turbine/kernel/wave/templates/vanilla_attention.py +++ b/iree/turbine/kernel/wave/templates/vanilla_attention.py @@ -17,7 +17,10 @@ def get_vanilla_attention_kernel( - shape: AttentionShape, mfma_variant: MMAType, dynamic_dims: bool + shape: AttentionShape, + mfma_variant: MMAType, + dynamic_dims: bool, + is_causal: bool = False, ): # Input sizes B = tkl.sym.B @@ -81,6 +84,9 @@ def base_attention( c_reg = tkl.Register[B, N, M, tkl.f32](0.0) init_sum = tkl.Register[B, M, tkl.f32](0.0) init_max = tkl.Register[B, M, tkl.f32](-1e6) + if is_causal: + ZEROF = tkl.Register[M, K2, tkl.f32](0.0) + MIN_INF = tkl.Register[M, K2, tkl.f32](-1e6) # This microkernel encodes the fact that if the reduction # dimension were tiled, then we would need to materialize a loop. @@ -95,6 +101,17 @@ def repeat( k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK) inner_acc = tkw.mma(k_reg, q_reg, imm_reg, mfma_variant[0]) x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) + if is_causal: + # Indices i and j broadcasted along K2 with a twist: + # here we use *static* information that is *implicitly* encoded + # in the *transformation*: under the distribution constraints + # specified we know that the shape [M] will eventually resolve + # to [1] and can thus be "cast + broadcast" to [K2]. 
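+                # In effect the ops below compute, per (m, k2) position:
+                #     bias[m, k2] = 0.0   if m >= k2   (ZEROF)
+                #     bias[m, k2] = -1e6  otherwise    (MIN_INF)
+                # Adding this bias before the running-max / exp2 softmax update drives
+                # the weight of keys that lie after the current query row to ~0, which
+                # is exactly the causal mask.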
+ m_index = tkw.self_index(M, tkl.i64) + m_index = tkw.broadcast(m_index, target_shape=[M, K2]) + k2_index = tkw.self_index(K2, tkl.i64) + bias = tkw.select(m_index >= k2_index, ZEROF, MIN_INF) + x_j = x_j + bias m_j = tkw.max(x_j, partial_max, dim=K2) e_delta_max = tkw.exp2(partial_max - m_j) e_delta = tkw.exp2(x_j - m_j) diff --git a/lit_tests/kernel/wave/attention.py b/lit_tests/kernel/wave/attention.py index 21996924a..4633cbeed 100644 --- a/lit_tests/kernel/wave/attention.py +++ b/lit_tests/kernel/wave/attention.py @@ -913,6 +913,65 @@ def test_attention(): # CHECK-COUNT-8: {{.*}} = amdgpu.mfma +@run_test +def test_attention_causal(): + shape = AttentionShape( + num_query_heads=8, + num_kv_heads=8, + query_seq_len=128, + head_size_kv=128, + head_size=64, + kv_seq_len=256, + ) + mfma_variant = (tkw.MMAType.F32_16x16x16_F16,) * 2 + base_attention, hyperparams, _, _ = get_vanilla_attention_kernel( + shape, mfma_variant, False, is_causal=True + ) + + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=False, + run_bench=False, + schedule=False, + use_scheduling_barriers=False, + ): + torch.manual_seed(0) + q = torch.randn( + shape.num_query_heads, + shape.query_seq_len, + shape.head_size, + dtype=torch.float16, + ) + k = torch.randn( + shape.num_kv_heads, shape.kv_seq_len, shape.head_size, dtype=torch.float16 + ) + v = torch.randn( + shape.num_kv_heads, + shape.kv_seq_len, + shape.head_size_kv, + dtype=torch.float16, + ) + output = torch.zeros( + shape.num_query_heads, + shape.query_seq_len, + shape.head_size_kv, + dtype=torch.float32, + ) + print(base_attention(q, k, v, output).module_op) + + # CHECK-LABEL: func.func @base_attention + # CHECK: %[[NEG_INF:.+]] = arith.constant dense<-1.000000e+06> : vector<4xf32> + # CHECK: %[[ZERO:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> + # CHECK: {{.*}} = scf.for + # CHECK-COUNT-16: {{.*}} = amdgpu.mfma + # CHECK-COUNT-8: {{.*}} = arith.cmpi sge, {{.*}} : vector<4xi64> + # CHECK-COUNT-8: {{.*}} = arith.select %{{.*}}, %[[ZERO]], %[[NEG_INF]] : vector<4xi1>, vector<4xf32> + # CHECK-COUNT-8: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<4xf32> + # CHECK-COUNT-8: {{.*}} = gpu.shuffle xor {{.*}} + # CHECK-COUNT-8: {{.*}} = amdgpu.mfma + + @run_test def test_attention_bias(): shape = (8, 128, 128, 64, 256) diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 85439ea0f..9f8a0c6ce 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -215,6 +215,73 @@ def read_write( # CHECK-SAME: vector<16xf16> +@run_test +def test_read_write_diagonal(): + # This test, tests for functionality of tkw.self_index, by + # generating code that generate a triangular matrix if M > N. 
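+    # Concretely, each output element c[m, n] is written as 0.0 when m >= n and 1.0
+    # otherwise, so the ones occupy the strictly upper-triangular part of the 16x16
+    # output.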
+ constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16} + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def read_write_diagonal( + c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + ): + ZEROF = tkl.Register[M, N, tkl.f16](0.0) + ONEF = tkl.Register[M, N, tkl.f16](1.0) + m_index = tkw.self_index(M, tkl.i64) + m_index = tkw.broadcast(m_index, target_shape=[M, N]) + n_index = tkw.self_index(N, tkl.i64) + res = tkw.select(m_index >= n_index, ZEROF, ONEF) + tkw.write(res, c, elements_per_thread=16) + + with codegen_test_context(canonicalize=True): + c = torch.zeros(16, 16, dtype=torch.float16) + print(read_write_diagonal(c).module_op) + + # CHECK-LABEL: func.func @read_write_diagonal + # CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding) + # CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index + # CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index + # CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index + # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + # CHECK-DAG: %[[ONE:.+]] = arith.constant dense<1.000000e+00> : vector<16xf16> + # CHECK-DAG: %[[ZERO:.+]] = arith.constant dense<0.000000e+00> : vector<16xf16> + # CHECK: %[[WORKGROUP_ID_0:.+]] = stream.dispatch.workgroup.id[0] : index + # CHECK: %[[WORKGROUP_ID_1:.+]] = stream.dispatch.workgroup.id[1] : index + # CHECK-DAG: %[[THREAD_ID_X:.+]] = gpu.thread_id x + # CHECK-DAG: %[[THREAD_ID_Y:.+]] = gpu.thread_id y + # CHECK: %[[D1:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C16]] overflow : index + # CHECK: %[[D2:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D3:.+]] = arith.muli %[[D2]], %[[C16]] overflow : index + # CHECK: %[[D4:.+]] = arith.addi %[[D3]], %[[D1]] overflow : index + # CHECK: %[[BASE_INDEX_X:.+]] = arith.addi %[[D4]], %[[THREAD_ID_X]] overflow : index + # CHECK: %[[D5:.+]] = vector.step : vector<1xindex> + # CHECK: %[[D6:.+]] = arith.muli %[[D5]], %{{.*}} : vector<1xindex> + # CHECK: %[[D7:.+]] = vector.splat %[[BASE_INDEX_X]] : vector<1xindex> + # CHECK: %[[D8:.+]] = arith.addi %[[D6]], %[[D7]] : vector<1xindex> + # CHECK: %[[INDEX_X:.+]] = arith.index_cast %[[D8]] : vector<1xindex> to vector<1xi64> + # CHECK: %[[D10:.+]] = vector.extract %[[INDEX_X]][0] : i64 from vector<1xi64> + # CHECK: %[[BCAST_INDEX_X:.+]] = vector.splat %[[D10]] : vector<16xi64> + # CHECK: %[[D12:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C16]] overflow : index + # CHECK: %[[D13:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C32]] overflow : index + # CHECK: %[[BASE_INDEX_Y:.+]] = arith.addi %[[D13]], %[[D12]] overflow : index + # CHECK: %[[D15:.+]] = vector.step : vector<16xindex> + # CHECK: %[[D16:.+]] = vector.splat %[[BASE_INDEX_Y]] : vector<16xindex> + # CHECK: %[[D17:.+]] = arith.addi %[[D15]], %[[D16]] : vector<16xindex> + # CHECK: %[[INDEX_Y:.+]] = arith.index_cast %[[D17]] : vector<16xindex> to vector<16xi64> + # CHECK: %[[MASK:.+]] = arith.cmpi sge, %[[BCAST_INDEX_X]], %[[INDEX_Y]] : vector<16xi64> + # CHECK: %[[MASK_VAL:.+]] = arith.select %[[MASK]], %[[ZERO]], %[[ONE]] : vector<16xi1>, vector<16xf16> + # CHECK: %[[OUTPUT:.+]] = stream.binding.subspan %arg0[%c0] : !stream.binding -> memref<16x16xf16, strided<[16, 1], offset: ?>> + # CHECK: vector.store %[[MASK_VAL]], %[[OUTPUT]][%[[BASE_INDEX_X]], %[[BASE_INDEX_Y]]] : memref<16x16xf16, 
strided<[16, 1], offset: ?>>, vector<16xf16> + + @run_test def test_read_write_masked(): constraints: list[tkw.Constraint] = [ @@ -1630,6 +1697,92 @@ def binary_lowerings( # CHECK: %[[MINIMUM:.+]] = arith.minimumf +@run_test +def test_int_comparisons(): + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16} + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + @tkw.wave(constraints) + def cmp_lowerings( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32], + b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32], + ): + a_reg = tkw.read(a, elements_per_thread=4) + b_reg = tkw.read(b, elements_per_thread=4) + sgt = a_reg > b_reg + s1 = tkw.select(sgt, a_reg, b_reg) + slt = a_reg < b_reg + s2 = tkw.select(slt, a_reg, b_reg) + sge = s1 >= s2 + s3 = tkw.select(sge, s1, s2) + sle = s1 <= s2 + s4 = tkw.select(sle, s1, s2) + res = s1 + s2 + s3 + s4 + tkw.write(res, a, elements_per_thread=4) + + a = torch.randint(42, (16, 16), dtype=torch.int32) + b = torch.randint(42, (16, 16), dtype=torch.int32) + with codegen_test_context(): + print(cmp_lowerings(a, b).module_op) + # CHECK-LABEL: @cmp_lowerings + # CHECK: arith.cmpi sgt + # CHECK: arith.select + # CHECK: arith.cmpi slt + # CHECK: arith.select + # CHECK: arith.cmpi sge + # CHECK: arith.select + + +@run_test +def test_verbose_int_comparisons(): + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16} + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + @tkw.wave(constraints) + def verbose_cmp_lowerings( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32], + b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32], + ): + a_reg = tkw.read(a, elements_per_thread=4) + b_reg = tkw.read(b, elements_per_thread=4) + sgt = tkw.gt(a_reg, b_reg) + s1 = tkw.select(sgt, a_reg, b_reg) + slt = tkw.lt(a_reg, b_reg) + s2 = tkw.select(slt, a_reg, b_reg) + sge = tkw.ge(s1, s2) + s3 = tkw.select(sge, s1, s2) + sle = tkw.le(s1, s2) + s4 = tkw.select(sle, s1, s2) + res = s1 + s2 + s3 + s4 + tkw.write(res, a, elements_per_thread=4) + + a = torch.randint(42, (16, 16), dtype=torch.int32) + b = torch.randint(42, (16, 16), dtype=torch.int32) + with codegen_test_context(): + print(verbose_cmp_lowerings(a, b).module_op) + # CHECK-LABEL: @verbose_cmp_lowerings + # CHECK: arith.cmpi sgt + # CHECK: arith.select + # CHECK: arith.cmpi slt + # CHECK: arith.select + # CHECK: arith.cmpi sge + # CHECK: arith.select + + # TODO: Something is broken in codegen and we are getting int in place of fx.Node # @launch @pytest.mark.skip(reason="getitem: Currently only stub implementation") diff --git a/tests/kernel/wave/attention/vanilla_attention_test.py b/tests/kernel/wave/attention/vanilla_attention_test.py index ace220180..1b32b04b1 100644 --- a/tests/kernel/wave/attention/vanilla_attention_test.py +++ b/tests/kernel/wave/attention/vanilla_attention_test.py @@ -121,6 +121,89 @@ def testAttention( assert_close(output, torch_ref, check_dtype=False, atol=1e-3, rtol=1e-3) +@require_e2e +@pytest.mark.parametrize("shape", get_test_shapes("attention")) +@pytest.mark.parametrize("enable_scheduling", 
[False]) +@pytest.mark.parametrize("dynamic_dims", [False]) +@pytest.mark.parametrize( + "mfma_variant", + [ + (MMAType.F32_32x32x8_F16, MMAType.F32_32x32x8_F16), + (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16), + ], +) +def testAttentionCausal( + shape: tuple[int], + enable_scheduling: bool, + dynamic_dims: bool, + mfma_variant: tuple[MMAType], + request, +): + run_bench = request.config.getoption("--runperf") + dump_perf = request.config.getoption("--dump-perf-files-path") + shape = AttentionShape( + num_query_heads=shape[0], + num_kv_heads=shape[0], + query_seq_len=shape[1], + head_size_kv=shape[2], + head_size=shape[3], + kv_seq_len=shape[4], + ) + ( + base_attention, + hyperparams, + dynamic_symbols, + dynamic_symbols_map, + ) = get_vanilla_attention_kernel(shape, mfma_variant, dynamic_dims, is_causal=True) + q_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size) + k_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size) + v_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size_kv) + o_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size_kv) + hyperparams.update(get_default_scheduling_params()) + config = get_default_run_config() + if run_bench: + config["benchmark_batch_size"] = 10 + config["benchmark_repetitions"] = 3 + if dump_perf is not None: + perf_filename = request.node.name + ".json" + config["benchmark_results_file"] = os.path.join( + dump_perf, "tk_" + perf_filename + ) + compile_config = {"waves_per_eu": 2, "denorm_fp_math_f32": "preserve-sign"} + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=True, + run_bench=run_bench, + run_config=config, + compile_config=compile_config, + schedule=enable_scheduling, + use_scheduling_barriers=enable_scheduling_barriers, + dynamic_symbols=dynamic_symbols, + dynamic_symbols_map=dynamic_symbols_map, + ): + torch.manual_seed(1) + q = device_randn(q_shape, dtype=torch.float16) + k = device_randn(k_shape, dtype=torch.float16) + v = device_randn(v_shape, dtype=torch.float16) + output = device_zeros(o_shape, dtype=torch.float32) + log2e = 1.44269504089 + dk_sqrt = math.sqrt(1.0 / shape.head_size) + # TODO: Add scaling of QK as part of kernel. + # TODO: Add variant of non-transposed V attention kernel. + mb = base_attention(q * dk_sqrt * log2e, k, v.permute([0, 2, 1]), output) + torch_ref = torch.nn.functional.scaled_dot_product_attention( + q, k, v, is_causal=True + ) + + if dump_generated_mlir: + filename = f"wave_attention_{'x'.join(map(str, shape))}.mlir" + with open(filename, "w") as f: + f.write(mb.module_op.get_asm()) + + assert_close(output, torch_ref, check_dtype=False, atol=1e-3, rtol=1e-3) + + @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("attention")) @pytest.mark.parametrize("enable_scheduling", [False])