Global to shared gathers

Signed-off-by: Harsh Menon <[email protected]>
harsh-nod committed Feb 14, 2025
1 parent 3ccd679 commit df7f3c7

Showing 10 changed files with 767 additions and 86 deletions.
9 changes: 5 additions & 4 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -1023,6 +1023,7 @@ class Allocate(CustomOp):
distributed_shape: tuple[IndexExpr]
dtype: DataType
address_space: AddressSpace
+ padding: int = 0

@property
def indexing_dims(self) -> list[IndexSymbol]:
@@ -1288,7 +1289,7 @@ def has_identity_mapping(self) -> bool:
if mapping is None:
return True

- mem_shape = self.memory.type.symbolic_shape
+ mem_shape = get_custom(self.memory).type.symbolic_shape
if mapping.is_identity() and mapping.input_shape == mem_shape:
return True

@@ -1303,7 +1304,7 @@ def is_contiguous_vec(self) -> bool:

mapping = self.mapping

- mem_shape = self.memory.type.symbolic_shape
+ mem_shape = get_custom(self.memory).type.symbolic_shape

from ..wave.utils import check_is_mapping_contiguous

@@ -1588,7 +1589,7 @@ def has_identity_mapping(self) -> bool:
if mapping is None:
return True

- mem_shape = self.memory.type.symbolic_shape
+ mem_shape = get_custom(self.memory).type.symbolic_shape
if mapping.is_identity() and mapping.output_shape == mem_shape:
return True

@@ -1602,7 +1603,7 @@ def is_contiguous_vec(self) -> bool:
return True
mapping = self.mapping

- mem_shape = self.memory.type.symbolic_shape
+ mem_shape = get_custom(self.memory).type.symbolic_shape

from ..wave.utils import check_is_mapping_contiguous

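The wave_ops.py changes above add a `padding` field to `Allocate` and resolve the memory type through `get_custom(self.memory)`. Below is a minimal sketch, not part of iree.turbine, of how a padding value could affect a shared tile's row-major strides, assuming the padding widens the fastest-varying dimension of the distributed shape (the commit itself only introduces the field):

```python
# Hypothetical helper, not the library's API: row-major strides of a tile whose
# innermost dimension is assumed to be widened by `padding` elements.
def padded_strides(distributed_shape: tuple[int, ...], padding: int) -> tuple[int, ...]:
    padded = list(distributed_shape)
    padded[-1] += padding  # assumption: padding applies to the fastest dimension
    strides = [1] * len(padded)
    for i in reversed(range(len(padded) - 1)):
        strides[i] = strides[i + 1] * padded[i + 1]
    return tuple(strides)

# A 32x64 tile with 4 elements of row padding gets a row stride of 68 instead of 64.
assert padded_strides((32, 64), 4) == (68, 1)
assert padded_strides((32, 64), 0) == (64, 1)
```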
27 changes: 17 additions & 10 deletions iree/turbine/kernel/wave/codegen/handlers.py
@@ -144,7 +144,7 @@ def handle_register(emitter: WaveEmitter, node: fx.Node):
@handle_op(allocate)
def handle_allocate(emitter: WaveEmitter, node: fx.Node):
try:
- shape, distributed_shape, dtype, address_space = node.args
+ shape, distributed_shape, dtype, address_space, padding = node.args
except ValueError as e:
raise ValidationError("Malformed arguments") from e
memref_shape = cast_py_literal(emitter, distributed_shape)
@@ -800,16 +800,23 @@ def handle_extract(emitter: WaveEmitter, node: fx.Node):
register, offset = node.args
except ValueError as e:
raise ValidationError("Malformed arguments") from e
- assert isinstance(offset, list) and len(offset) == 1
- extract_vector = cast_vector(emitter, register)
- result_type = VectorType.get([1], extract_vector.type.element_type)
- element = vector_d.extract_strided_slice(
-     result_type,
-     extract_vector,
-     offset,
-     [1],
-     [1],
+ assert (
+     isinstance(offset, list) and len(offset) == 1 or isinstance(offset, IndexExpr)
)
+ extract_vector = cast_vector(emitter, register)
+ if isinstance(offset, IndexExpr):
+     # For dynamic offsets, we have to use vector.extractelement.
+     offset = gen_sympy_index(add_emitter_subs(emitter), offset)
+     element = vector_d.extractelement(extract_vector, position=offset)
+ else:
+     result_type = VectorType.get([1], extract_vector.type.element_type)
+     element = vector_d.extract_strided_slice(
+         result_type,
+         extract_vector,
+         offset,
+         [1],
+         [1],
+     )

emitter.bind_node_proxy(node, IRProxyValue(element))

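The `handle_extract` change above keeps `vector.extract_strided_slice` for static list offsets and switches to `vector.extractelement` when the offset is a symbolic `IndexExpr` only known at runtime. As a rough semantic analogy only (NumPy stand-ins, not the MLIR ops themselves):

```python
import numpy as np

vec = np.arange(8, dtype=np.float32)

# Static offset: a width-1 slice at a compile-time position, analogous to
# vector.extract_strided_slice with sizes=[1], strides=[1].
static_offset = [3]
one_element = vec[static_offset[0] : static_offset[0] + 1]  # array([3.], dtype=float32)

# Dynamic offset: the position is a runtime value, analogous to
# vector.extractelement with a computed index operand.
dynamic_offset = np.int32(5)
scalar = vec[dynamic_offset]  # 5.0
```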
41 changes: 26 additions & 15 deletions iree/turbine/kernel/wave/codegen/read_write.py
@@ -36,11 +36,7 @@
cast_vector,
)

- from ...ops.wave_ops import (
-     get_custom,
-     read,
-     write,
- )
+ from ...ops.wave_ops import get_custom, read, write, CustomOp

from ..utils import safe_subs, subs_idxc, find_index_bounds

@@ -173,27 +169,28 @@ def _constant_mask(vec_type: IrType) -> Value:

def _construct_gather_scatter_indices(
emitter: WaveEmitter,
- symbolc_shape: tuple[IndexExpr],
+ symbolic_shape: tuple[IndexExpr],
index: tuple[IndexExpr],
mapping: IndexMapping,
elements_per_thread: int,
is_read: bool,
dynamic_vals: tuple[Any, ...],
is_contiguous: bool,
+ memory: CustomOp,
) -> tuple[list[OpResult], list[OpResult], list[OpResult], OpResult, OpResult]:
- # Apply symbolc_shape order to indices, e.g. if original mapping is
- # {M: iter(0), N: iter(1)} and symbolc_shape is (N, M), result will
+ # Apply symbolic_shape order to indices, e.g. if original mapping is
+ # {M: iter(0), N: iter(1)} and symbolic_shape is (N, M), result will
# be (iter(1), iter(0))
if is_read:
assert (
mapping.is_output_identity()
), "non-identity output mapping is not supported yet"
- index_mapping = mapping.map_input_indices(symbolc_shape)
+ index_mapping = mapping.map_input_indices(symbolic_shape)
else:
assert (
mapping.is_input_identity()
), "non-identity input mapping is not supported yet"
- index_mapping = mapping.map_output_indices(symbolc_shape)
+ index_mapping = mapping.map_output_indices(symbolic_shape)

idxc = IndexingContext.current()
index_mapping = tuple(i.subs(idxc.subs) for i in index_mapping)
@@ -208,7 +205,7 @@

# Contruct input/output index, substituting iterators in input mapping with
# expanded index.
- result_index = {key: m.subs(subs) for key, m in zip(symbolc_shape, index_mapping)}
+ result_index = {key: m.subs(subs) for key, m in zip(symbolic_shape, index_mapping)}

mask = _build_mask(emitter, index, elements_per_thread)
if mask is None:
@@ -245,7 +242,9 @@ def extract0(src):
need_dynamic_offsets = True

offsets = []
- strides = strides_from_symbolic_shape(idxc, symbolc_shape, allow_mixed_shapes=True)
+ if memory.type.address_space == SHARED_ADDRESS_SPACE:
+     symbolic_shape = memory.distributed_shape
+ strides = strides_from_symbolic_shape(idxc, symbolic_shape, allow_mixed_shapes=True)
start_indices_offset = _compute_offset(start_indices, strides)
for i in range(elements_per_thread):
# Update fastest dim, i.e. in case of identity mapping it will
@@ -280,7 +279,7 @@ def extract0(src):
if need_dynamic_offsets:
# In case we need dynamic `offsets_vec`, set all `start_indices` to 0
# and encode entire index info in `offsets_vec`.
- result_index = {key: 0 for key in symbolc_shape}
+ result_index = {key: 0 for key in symbolic_shape}
start_indices, start_indices_wg, start_indices_th = _build_start_indices(
emitter, result_index, dynamic_vals_map_start
)
@@ -401,6 +400,7 @@ def _create_vec_read(
start_indices_wg: tuple[Value],
start_indices_th: tuple[Value],
elements_per_thread: int,
+ memory: CustomOp,
mask: Optional[Value],
offsets_vec: Optional[Value],
) -> Value:
@@ -421,6 +421,8 @@
zero = get_constant_attr(0, element_type)
zero = arith_d.constant(element_type, zero)

+ if memory.type.address_space == SHARED_ADDRESS_SPACE:
+     symbolic_shape = memory.distributed_shape
strides = strides_from_symbolic_shape(
IndexingContext.current(), symbolic_shape, allow_mixed_shapes=True
)
@@ -512,6 +514,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
start_indices_wg,
start_indices_th,
elements_per_thread,
+ get_custom(memory),
mask,
offsets_vec=None,
)
@@ -527,13 +530,14 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
mask,
) = _construct_gather_scatter_indices(
emitter=emitter,
- symbolc_shape=input_shape,
+ symbolic_shape=input_shape,
index=index,
mapping=mapping,
elements_per_thread=elements_per_thread,
is_read=True,
dynamic_vals=dyn_vals,
is_contiguous=get_custom(node).is_contiguous_vec(),
+ memory=get_custom(memory),
)
result = _create_vec_read(
emitter,
@@ -544,6 +548,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
start_indices_wg,
start_indices_th,
elements_per_thread,
+ get_custom(memory),
mask,
offsets_vec,
)
@@ -560,6 +565,7 @@ def _create_vec_write(
start_indices_wg: tuple[Value],
start_indices_th: tuple[Value],
elements_per_thread: int,
+ memory: CustomOp,
mask: Optional[Value],
offsets_vec: Optional[Value],
):
@@ -579,6 +585,8 @@
offsets_vec_type, DenseElementsAttr.get(vals, offsets_vec_type)
)

+ if memory.type.address_space == SHARED_ADDRESS_SPACE:
+     symbolic_shape = memory.distributed_shape
strides = strides_from_symbolic_shape(
IndexingContext.current(), symbolic_shape, allow_mixed_shapes=True
)
@@ -661,6 +669,7 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
start_indices_wg,
start_indices_th,
elements_per_thread,
+ get_custom(memory),
mask,
offsets_vec=None,
)
@@ -680,13 +689,14 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
mask,
) = _construct_gather_scatter_indices(
emitter=emitter,
- symbolc_shape=output_shape,
+ symbolic_shape=output_shape,
index=index,
mapping=mapping,
elements_per_thread=elements_per_thread,
is_read=False,
dynamic_vals=dyn_vals,
is_contiguous=get_custom(node).is_contiguous_vec(),
+ memory=get_custom(memory),
)

_create_vec_write(
Expand All @@ -698,6 +708,7 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
start_indices_wg,
start_indices_th,
elements_per_thread,
+ get_custom(memory),
mask,
offsets_vec,
)
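Throughout read_write.py the vector read/write lowering now receives the `memory` op, and when the access targets shared memory it derives strides from the allocation's distributed (per-workgroup) shape rather than the global symbolic shape. A minimal sketch of that selection, using hypothetical stand-in types rather than the library's own classes:

```python
from dataclasses import dataclass

SHARED = "shared"  # stand-in for SHARED_ADDRESS_SPACE
GLOBAL = "global"  # stand-in for the global address space

@dataclass
class MemoryLike:
    address_space: str
    distributed_shape: tuple[int, ...]

def shape_for_strides(memory: MemoryLike, symbolic_shape: tuple[int, ...]) -> tuple[int, ...]:
    # Shared tiles are laid out with their (possibly padded) per-workgroup shape;
    # global accesses keep using the kernel's symbolic shape.
    if memory.address_space == SHARED:
        return memory.distributed_shape
    return symbolic_shape

# Example: a (1024, 1024) problem gathered into (64, 68) shared tiles (64 + 4 padding).
assert shape_for_strides(MemoryLike(SHARED, (64, 68)), (1024, 1024)) == (64, 68)
assert shape_for_strides(MemoryLike(GLOBAL, (64, 68)), (1024, 1024)) == (1024, 1024)
```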