Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ def can_be_applied(
return False

# This avoids that we have to modify the subsets in a fancy way.
# TODO(phimuell): Lift this limitation.
if len(a1_desc.shape) != len(a2_desc.shape):
return False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import Any, Container, Optional, Sequence, TypeVar, Union

import dace
from dace import data as dace_data, subsets as dace_sbs, symbolic as dace_sym
from dace import data as dace_data, libraries as dace_lib, subsets as dace_sbs, symbolic as dace_sym
from dace.sdfg import graph as dace_graph, nodes as dace_nodes
from dace.transformation import pass_pipeline as dace_ppl
from dace.transformation.passes import analysis as dace_analysis
Expand Down Expand Up @@ -405,9 +405,9 @@ def reroute_edge(
"""
current_memlet: dace.Memlet = current_edge.data
if is_producer_edge:
# NOTE: See the note in `_reconfigure_dataflow()` why it is not save to
# use the `get_{dst, src}_subset()` function, although it would be more
# appropriate.
# NOTE: See the note in `reconfigure_dataflow_after_rerouting()` why it is not
# safe to use the `get_{dst, src}_subset()` function, although it would be
# more appropriate.
assert current_edge.dst is old_node
current_subset: dace_sbs.Range = current_memlet.dst_subset
new_src = current_edge.src
Expand Down Expand Up @@ -503,6 +503,10 @@ def reconfigure_dataflow_after_rerouting(
old_node: The old node that was involved in the old, rerouted, edge.
new_node: The new node that should be used instead of `old_node`.
"""

# NOTE: The base assumption of this function is that the subset on the side of
# `new_node` is already correct and we have to adjust the subset on the side
# of `other_node`.
other_node = new_edge.src if is_producer_edge else new_edge.dst

if isinstance(other_node, dace_nodes.AccessNode):
Expand Down Expand Up @@ -565,6 +569,21 @@ def reconfigure_dataflow_after_rerouting(
# the full array, but essentially slice a bit.
pass

elif isinstance(other_node, dace_lib.standard.Reduce):
# For now we only handle the case that the reduction node is writing into
# `new_node`, before the data was written into `old_node`. In that case
# there is nothing to do, we just do some checks.
# TODO(phimuell): Think about how to handle the other case or how to extend
# to other library nodes.

if not is_producer_edge:
raise ValueError("Reduction nodes are only supported as output.")
assert isinstance(new_node, dace_nodes.AccessNode)

# The subset at the reduction node needs to be `None`, which means undefined.
other_subset = new_edge.data.src_subset if is_producer_edge else new_edge.data.dst_subset
assert other_subset is None

else:
# As we encounter them we should handle them case by case.
raise NotImplementedError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import numpy as np

dace = pytest.importorskip("dace")

from dace import libraries as dace_libnode
from dace.sdfg import nodes as dace_nodes

from gt4py.next.program_processors.runners.dace import (
Expand Down Expand Up @@ -355,6 +357,105 @@ def _make_a1_has_output_sdfg() -> dace.SDFG:
return sdfg


def _make_copy_chain_with_reduction_node(
    output_an_array: bool,
) -> tuple[
    dace.SDFG,
    dace.SDFGState,
    dace_libnode.standard.Reduce,
    dace_nodes.AccessNode,
    dace_libnode.standard.Reduce,
    dace_nodes.AccessNode,
    dace_nodes.AccessNode,
]:
    """Creates an SDFG in which two reduction nodes write through transient accumulators.

    Two `Reduce` library nodes each reduce a non-transient input (`data0`/`data1`)
    into a transient accumulator (`acc0`/`acc1`), which is then copied into a
    distinct slice of the non-transient `output`. The copies through the
    accumulators form the chains the `CopyChainRemover` is expected to remove.

    Args:
        output_an_array: If `True` the accumulators are (degenerate) 3D arrays
            and the reduction runs over axis 1; if `False` they are scalars and
            the full input is reduced.

    Returns:
        The tuple `(sdfg, state, red0, acc0, red1, acc1, output)`: the SDFG, its
        single state, both reduction nodes, both accumulator access nodes and
        the output access node.
    """
    sdfg = dace.SDFG(util.unique_name("copy_chain_remover_with_reduction_sdfg"))
    state = sdfg.add_state(is_start_block=True)

    if output_an_array:
        input_shape = (10, 3, 20)
        output_shape = (10, 2, 20)
        reduce_axes = [1]
        # Actually we should use `(10, 20)` as shape, but then the transformation
        # does not apply; this is a limitation in the transformation.
        acc_shape = (10, 1, 20)
    else:
        input_shape = (3,)
        acc_shape = ()  # Is a scalar.
        output_shape = (2,)
        reduce_axes = None  # Reduce over all axes.

    for i in range(2):
        sdfg.add_array(
            f"data{i}",
            shape=input_shape,
            dtype=dace.float64,
            transient=False,
        )
        if output_an_array:
            sdfg.add_array(
                f"acc{i}",
                shape=acc_shape,
                dtype=dace.float64,
                transient=True,
            )
        else:
            sdfg.add_scalar(
                f"acc{i}",
                dtype=dace.float64,
                transient=True,
            )
    sdfg.add_array(
        "output",
        shape=output_shape,
        dtype=dace.float64,
        transient=False,
    )

    # Build the two `data{i} -> Reduce -> acc{i}` sub-chains.
    accumulators: list[dace_nodes.AccessNode] = []
    reducers: list[dace_libnode.standard.Reduce] = []
    for i in range(2):
        data_ac = state.add_access(f"data{i}")
        acc_ac = state.add_access(f"acc{i}")
        reduce_node = state.add_reduce(
            wcr="lambda x, y: x + y",
            axes=reduce_axes,
            identity=0.0,
        )
        state.add_nedge(
            data_ac,
            reduce_node,
            dace.Memlet(f"{data_ac.data}[" + ", ".join(f"0:{s}" for s in input_shape) + "]"),
        )
        if output_an_array:
            state.add_nedge(
                reduce_node,
                acc_ac,
                dace.Memlet(f"{acc_ac.data}[0:{acc_shape[0]}, 0, 0:{acc_shape[-1]}]"),
            )
        else:
            state.add_nedge(reduce_node, acc_ac, dace.Memlet(f"{acc_ac.data}[0]"))
        accumulators.append(acc_ac)
        reducers.append(reduce_node)

    red0, red1 = reducers
    output_ac = state.add_access("output")

    # Copy each accumulator into its slice (index `i` along the middle/only axis)
    # of `output`; these are the copies the transformation should eliminate.
    for i, acc in enumerate(accumulators):
        if output_an_array:
            state.add_nedge(
                acc,
                output_ac,
                dace.Memlet(
                    f"{acc.data}[0:{output_shape[0]}, 0, 0:{output_shape[-1]}] -> [0:{output_shape[0]}, {i}, 0:{output_shape[-1]}]"
                ),
            )
        else:
            state.add_nedge(acc, output_ac, dace.Memlet(f"{acc.data}[0] -> [{i}]"))
    sdfg.validate()

    return sdfg, state, red0, accumulators[0], red1, accumulators[1], output_ac


def test_simple_linear_chain():
sdfg = _make_simple_linear_chain_sdfg()

Expand Down Expand Up @@ -567,3 +668,102 @@ def inner_ref(i0, o0):
# Now run the transformed SDFG to see if the same output is generated.
util.compile_and_run_sdfg(sdfg, **res)
assert all(np.allclose(ref[name], res[name]) for name in ref.keys())


@pytest.mark.parametrize("output_an_array", [False, True])
def test_copy_chain_remover_with_reduction(output_an_array: bool):
    """Tests `CopyChainRemover` on copies whose producers are `Reduce` library nodes.

    The transformation is applied twice, first removing the `acc0` accumulator
    and then `acc1`. After each application the test checks that the respective
    reduction node writes directly into `output` with the expected destination
    subset and an undefined (`None`) source subset.
    """
    sdfg, state, red0, acc0, red1, acc1, output = _make_copy_chain_with_reduction_node(
        output_an_array
    )

    def apply_to(a1, a2):
        # Applies `CopyChainRemover` to the copy `a1 -> a2`; `a1` is the node
        # that gets removed. Asserts that the transformation can be applied.
        candidate = {
            gtx_transformations.CopyChainRemover.node_a1: a1,
            gtx_transformations.CopyChainRemover.node_a2: a2,
        }
        copy_chain_remover = gtx_transformations.CopyChainRemover(
            single_use_data={sdfg: {acc0.data, acc1.data}},
        )
        copy_chain_remover.setup_match(
            sdfg=sdfg,
            cfg_id=state.parent_graph.cfg_id,
            state_id=state.block_id,
            subgraph=candidate,
            expr_index=0,
            override=True,
        )
        assert copy_chain_remover.can_be_applied(state, 0, sdfg, permissive=False)
        copy_chain_remover.apply(state, sdfg)

    # Initial structure: 2x (data -> Reduce -> acc) plus `output`, i.e. 7 nodes.
    assert sdfg.number_of_nodes() == 1
    assert state.number_of_nodes() == 7

    assert all(e.dst is acc0 for e in state.out_edges(red0))
    assert state.in_degree(acc0) == 1
    assert state.out_degree(acc0) == 1

    assert all(e.dst is acc1 for e in state.out_edges(red1))
    assert state.in_degree(acc1) == 1
    assert state.out_degree(acc1) == 1

    assert all(e.src in [acc0, acc1] for e in state.in_edges(output))
    assert state.out_degree(output) == 0

    # Now remove the `acc0` intermediate.
    apply_to(a1=acc0, a2=output)

    # The accumulator `acc0` has been removed.
    sdfg.validate()
    assert state.number_of_nodes() == 6

    access_nodes: list[dace_nodes.AccessNode] = util.count_nodes(sdfg, dace_nodes.AccessNode, True)
    assert len(access_nodes) == 4
    assert acc0 not in access_nodes
    assert acc1 in access_nodes
    assert output in access_nodes

    # `red0` now writes directly into `output`.
    assert state.out_degree(red0) == 1
    red0_oedge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet] = next(
        iter(state.out_edges(red0))
    )
    assert red0_oedge.dst is output
    red0_oedge_mlet: dace.Memlet = red0_oedge.data
    # The subset at the reduction node side must stay undefined (`None`).
    assert red0_oedge_mlet.src_subset is None

    if output_an_array:
        assert len(red0_oedge_mlet.dst_subset) == 3
        assert red0_oedge_mlet.dst_subset == dace.subsets.Range.from_string("0:10, 0, 0:20")

    else:
        assert len(red0_oedge_mlet.dst_subset) == 1
        assert red0_oedge_mlet.dst_subset == dace.subsets.Range.from_string("0")

    # The second chain is still untouched.
    assert state.out_degree(red1) == 1
    assert all(e.dst is acc1 for e in state.out_edges(red1))

    # Now the accumulator `acc1` will be removed.
    apply_to(a1=acc1, a2=output)

    sdfg.validate()
    assert state.number_of_nodes() == 5

    access_nodes = util.count_nodes(sdfg, dace_nodes.AccessNode, True)
    assert len(access_nodes) == 3
    assert acc0 not in access_nodes
    assert acc1 not in access_nodes
    assert output in access_nodes

    # `red1` now also writes directly into `output`, into the `i == 1` slice.
    assert state.out_degree(red1) == 1
    red1_oedge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet] = next(
        iter(state.out_edges(red1))
    )
    assert red1_oedge.dst is output
    red1_oedge_mlet: dace.Memlet = red1_oedge.data
    assert red1_oedge_mlet.src_subset is None

    if output_an_array:
        assert len(red1_oedge_mlet.dst_subset) == 3
        assert red1_oedge_mlet.dst_subset == dace.subsets.Range.from_string("0:10, 1, 0:20")
    else:
        assert len(red1_oedge_mlet.dst_subset) == 1
        assert red1_oedge_mlet.dst_subset == dace.subsets.Range.from_string("1")
Loading
Loading