[Wave] Add self_index, predicate, and selectOp to implement causal attention (#452)

- Extracted the core pieces of self_index, predicate, and selectOp, plus the LIT tests for predicate and select written by @nicolasvasilache and @ftynse, which are required for the causal mask; removed the pieces unrelated to the causal mask.
- Implemented a numerically correct causal attention kernel based on the original from @nicolasvasilache (see the sketch after this list).
- Added GPR_NUM partitioning support for SelfIndex so causal attention works on more MMA intrinsics (e.g. 32x32x8, which has GPR_NUMs).
- Refactored tkw.slt/sgt/sge/sle into operator.lt/gt/ge/le to keep the number of tkw ops down and for user ergonomics.
- Refactored the vanilla attention kernel to support both causal and non-causal attention in a single kernel, controlled by an is_causal flag.
- Added support in handle_op for taking multiple ops that map to the same function.
- Added a bunch of LIT tests.
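
For illustration, a minimal sketch of how these ops compose into a causal mask inside a Wave kernel body. This is not the shipped kernel: the symbols M and K2, the ZEROF/NEG_INFF constant registers, the x_j scores register, and the exact tkw spellings of each helper are assumptions for this sketch.

    # Compare each key position against the query position and bias the scores.
    m_index = tkw.self_index(M, tkl.i64, elements_per_thread=1)
    m_index = tkw.broadcast(m_index, target_shape=[M, K2])
    k2_index = tkw.self_index(K2, tkl.i64)
    mask = k2_index <= m_index                 # ComparisonPyOp, yields i1
    bias = tkw.select(mask, ZEROF, NEG_INFF)   # 0 where visible, -inf where masked
    x_j = x_j + bias                           # masked scores feed the softmax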

---------

Signed-off-by: Alex Zinenko <[email protected]>
Signed-off-by: Nicolas Vasilache <[email protected]>
Signed-off-by: Stanley Winata <[email protected]>
Co-authored-by: Alex Zinenko <[email protected]>
Co-authored-by: Nicolas Vasilache <[email protected]>
3 people authored Feb 5, 2025
1 parent 915d24c commit 7038127
Showing 9 changed files with 628 additions and 31 deletions.
4 changes: 4 additions & 0 deletions iree/turbine/aot/support/ir_utils.py
@@ -489,6 +489,10 @@ def _is_float_type(type):
     return isinstance(type, (BF16Type, F16Type, F32Type, F64Type, Float8E4M3FNUZType))
 
 
+def _is_index_type(type):
+    return isinstance(type, (IndexType))
+
+
 def _is_integer_like_type(type):
     return isinstance(type, (IntegerType, IndexType))
1 change: 1 addition & 0 deletions iree/turbine/kernel/_support/dtype.py
@@ -72,6 +72,7 @@ def bitwidth(self):
 
 bf16 = DataType("bf16")
 bool = DataType("bool", "i1")
+i1 = bool
 i4 = DataType("i4")
 i8 = DataType("i8")
 i16 = DataType("i16")
147 changes: 135 additions & 12 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -20,7 +20,7 @@
 from ..lang.wave_types import Memory, Register, IndexMapping
 from ..lang.global_symbols import *
 from .._support.indexing import IndexExpr, IndexSymbol, IndexSequence
-from .._support.dtype import DataType
+from .._support.dtype import DataType, i1
 from .._support.regions import RegionGraph
 from .base import OpDispatcher
 import numpy as np
@@ -45,6 +45,14 @@ def allocate(
     ...
 
 
+def self_index(
+    idx: IndexExpr,
+    dtype: DataType,
+    elements_per_thread: Optional[IndexExpr | int] = None,
+) -> "Register":
+    ...
+
+
 def extract(
     register: "Register",
     offsets: tuple[IndexExpr],
@@ -166,6 +174,22 @@ def shuffle(src: "Register", offset: int, width: int) -> "Register":
     ...
 
 
+def gt(lhs: "Register", rhs: "Register") -> "Register":
+    ...
+
+
+def ge(lhs: "Register", rhs: "Register") -> "Register":
+    ...
+
+
+def lt(lhs: "Register", rhs: "Register") -> "Register":
+    ...
+
+
+def le(lhs: "Register", rhs: "Register") -> "Register":
+    ...
+
+
 def cast(src: "Register", dtype: DataType) -> "Register":
     ...
 
@@ -178,6 +202,10 @@ def reshape(inputs: Sequence["Register"]) -> "Register":
     ...
 
 
+def select(cond: "Register", if_true: "Register", if_false: "Register") -> "Register":
+    ...
+
+
 def define_op(op_name: str) -> Callable[[T], T]:
     def decorator(cls: T) -> T:
         cls.tkw_op_name = op_name
 
Expand Down Expand Up @@ -680,14 +708,8 @@ def transform_index(
return index


@define_py_op(operator.add)
@define_py_op(operator.sub)
@define_py_op(operator.mul)
@define_py_op(operator.truediv)
@define_interface_op("maximum")
@define_interface_op("minimum")
@dataclass
class BinaryPyOp(CustomOp, ABC):
class BinaryOpBase(CustomOp, ABC):
"""
Represents an elementwise binary python operator.
@@ -715,21 +737,51 @@ def indexing_dims(self) -> list[IndexSymbol]:
     def py_operator(self) -> str:
         return self.tkw_op_name
 
-    def infer_type(self):
+    def infer_shape(self) -> Any:
         lhs_type = get_custom(self.lhs).type
         rhs_type = get_custom(self.rhs).type
         has_same_type = has_same_custom_type(lhs_type, rhs_type)
         if has_same_type:
-            self.type = lhs_type
-            return
+            return lhs_type.symbolic_shape
 
         lhs_dim_set = set(lhs_type.symbolic_shape)
         rhs_dim_set = set(rhs_type.symbolic_shape)
         if lhs_dim_set.isdisjoint(rhs_dim_set):
             raise ValueError(
                 "BinaryPyOp requires lhs and rhs shape to be at least broadcastable."
                 f" got {lhs_type.symbolic_shape} vs {rhs_type.symbolic_shape}"
             )
 
+        # TODO: this logic looks suspicious. Specifically, there's no check that
+        # rhs_dim_set subsumes lhs_dim_set, they may partially overlap.
         broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhs_type
-        self.type = broadcasted_type
+        return broadcasted_type.symbolic_shape
+
+
+@define_py_op(operator.add)
+@define_py_op(operator.sub)
+@define_py_op(operator.mul)
+@define_py_op(operator.truediv)
+@define_interface_op("maximum")
+@define_interface_op("minimum")
+@dataclass
+class BinaryPyOp(BinaryOpBase, ABC):
+    def infer_type(self):
+        self.type = Register[(*self.infer_shape(), get_custom(self.lhs).type.dtype)]
+
+
+@define_py_op(operator.gt)
+@define_py_op(operator.ge)
+@define_py_op(operator.lt)
+@define_py_op(operator.le)
+@define_interface_op("gt")
+@define_interface_op("ge")
+@define_interface_op("lt")
+@define_interface_op("le")
+@dataclass
+class ComparisonPyOp(BinaryOpBase, ABC):
+    def infer_type(self):
+        self.type = Register[(*self.infer_shape(), i1)]

@define_interface_op("log2")
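
An aside on the shape choice in infer_shape above: Python's set comparison > is a proper-superset test, which is what the new TODO is flagging. A quick illustration in plain Python (the symbol names are made up):

    # Proper superset: broadcast resolves to the larger shape.
    lhs, rhs = {"M", "K2"}, {"K2"}
    assert lhs > rhs                      # -> broadcast to lhs's shape
    # Partially overlapping shapes pass the disjointness check but satisfy
    # neither lhs > rhs nor rhs > lhs, so rhs_type wins by default.
    lhs2, rhs2 = {"M", "K2"}, {"K2", "N"}
    assert not lhs2.isdisjoint(rhs2) and not lhs2 > rhs2 and not rhs2 > lhs2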
@@ -759,6 +811,42 @@ def infer_type(self):
         self.type = src_type
 
 
+@define_op("select")
+@dataclass
+class SelectOp(CustomOp):
+    cond: fx.Node
+    if_true: fx.Node
+    if_false: fx.Node
+
+    @property
+    def indexing_dims(self) -> list[IndexSymbol]:
+        combined_dims = []
+        combined_dims += get_custom(self.cond).indexing_dims
+        combined_dims += get_custom(self.if_true).indexing_dims
+        combined_dims += get_custom(self.if_false).indexing_dims
+        return list(dict.fromkeys(combined_dims))
+
+    def infer_type(self):
+        cond_type = get_custom(self.cond).type
+        if_true_type = get_custom(self.if_true).type
+        if_false_type = get_custom(self.if_false).type
+
+        if cond_type.dtype != i1:
+            raise ValueError("SelectOp expects condition type to be i1.")
+
+        if if_true_type.dtype != if_false_type.dtype:
+            raise ValueError("SelectOp expects lhs and rhs dtype to match.")
+
+        # TODO: support broadcasting behavior.
+        if (
+            cond_type.symbolic_shape != if_true_type.symbolic_shape
+            or cond_type.symbolic_shape != if_false_type.symbolic_shape
+        ):
+            raise ValueError("SelectOp doesn't support broadcasting. (yet?)")
+
+        self.type = if_true_type

@final
@dataclass
class Unknown(CustomOp):
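
For intuition, SelectOp's semantics (modulo register types) resemble numpy's where under the same strict rules enforced above: a boolean/i1 condition and matching shapes and dtypes. A rough analogue, purely for illustration:

    import numpy as np

    cond = np.array([True, False, True])                    # i1 condition
    if_true = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    if_false = np.zeros(3, dtype=np.float32)                # same shape and dtype
    out = np.where(cond, if_true, if_false)                 # -> [1.0, 0.0, 3.0]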
@@ -940,6 +1028,22 @@ def type(self) -> "Memory":
         return Memory[(*self.shape, self.address_space, self.dtype)]
 
 
+@define_op("self_index")
+@dataclass
+class SelfIndex(CustomOp):
+    dim: IndexExpr
+    dtype: DataType
+    elements_per_thread: Optional[IndexExpr | int] = None
+
+    @property
+    def indexing_dims(self) -> list[IndexSymbol]:
+        return [self.dim]
+
+    @property
+    def type(self) -> "Register":
+        return Register[(self.dim, self.dtype)]

@define_op("shared_memory_barrier")
@dataclass
class SharedMemoryBarrier(CustomOp):
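
SelfIndex materializes the iteration index along dim as a register of the given dtype. As a rough analogue of what the op produces along a dimension (the size is illustrative, and the distribution across threads is elided):

    import numpy as np

    K2_SIZE = 8                                    # illustrative dim size
    k2_index = np.arange(K2_SIZE, dtype=np.int64)  # ~ self_index(K2, i64)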
@@ -1657,6 +1761,25 @@ class Broadcast(CustomOp, ABC):
     arg: fx.Node
     target_shape: Sequence[IndexSymbol] = None
 
+    def __post_init__(self):
+        # Required for setting up hash.
+        super().__post_init__()
+        # Verify for valid src type.
+        if isinstance(self.arg, fx.Node):
+            src = self.arg
+        elif isinstance(self.arg, fx.Proxy):
+            src = self.arg.node
+        else:
+            raise ValueError(f"Unexpected broadcast src type of {type(self.arg)}")
+
+        # Verifies target broadcast shape is valid.
+        src_type = get_custom(src).type
+        src_shape = set(getattr(src_type, "symbolic_shape", []))
+        dst_shape = set(self.target_shape)
+        assert src_shape.issubset(
+            dst_shape
+        ), "Fail to initialize broadcast because of invalid target_shape."
+
     @property
     def indexing_dims(self) -> list[IndexSymbol]:
         return self.target_shape
