@@ -553,18 +553,21 @@ def _gt_auto_process_top_level_maps(
dace_sdutils.canonicalize_memlet_trees(sdfg)
dace_propagation.propagate_memlets_sdfg(sdfg)

sdfg.apply_transformations_repeated(
[
# TODO(phimuell): The transformation is also active inside Maps.
# Which is against the description of this function, but it should
# not matter that much.
gtx_transformations.SplitAccessNode(
single_use_data=single_use_data,
),
gtx_transformations.GT4PyMapBufferElimination(
assume_pointwise=assume_pointwise,
),
],
# Split the top level AccessNodes.
# NOTE: This function will also update `single_use_data`.
gtx_transformations.gt_split_access_nodes(
sdfg=sdfg,
validate=False,
validate_all=True,
single_use_data=single_use_data,
)

# Perform buffer elimination.
# TODO(phimuell): Implement a faster matching.
sdfg.apply_transformations_once_everywhere(
gtx_transformations.GT4PyMapBufferElimination(
assume_pointwise=assume_pointwise,
),
validate=False,
validate_all=validate_all,
)
@@ -10,7 +10,11 @@
from typing import Any, Iterable, Optional

import dace
from dace import properties as dace_properties, transformation as dace_transformation
from dace import (
properties as dace_properties,
subsets as dace_sbs,
transformation as dace_transformation,
)
from dace.sdfg import graph as dace_graph, nodes as dace_nodes
from dace.transformation.passes import analysis as dace_analysis

@@ -28,14 +32,20 @@ def gt_split_access_nodes(
) -> Optional[int]:
"""Applies the `SplitAccessNode` transformation to the SDFG.

This function should be the preferred way to run the `SplitAccessNode`
transformation, since it ensures that the single-use data is only computed
once. Furthermore, it guarantees that the transformations are applied in
a deterministic order.

The function returns the number of AccessNodes that have been split.

Args:
sdfg: The SDFG to process.
validate: Perform validation after the pass has run.
validate_all: Perform extensive validation.
single_use_data: Which data descriptors are used only once.
If not passed the function will run `FindSingleUseData`.
If not passed, the function will run `FindSingleUseData`; if passed, the
function will update its content in place, adding the newly generated data.
"""

# To ensure that the `{src,dst}_subset` are properly set, run initialization.
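For illustration, a minimal usage sketch (not part of the diff; it assumes an existing `sdfg` and that `gt_split_access_nodes` is importable) showing how a caller can precompute `single_use_data` and rely on the in-place update described in the docstring:

```python
# Sketch only: `sdfg` is assumed to be an existing dace.SDFG.
from dace.transformation.passes import analysis as dace_analysis

# Compute the single-use data once ...
single_use_data = dace_analysis.FindSingleUseData().apply_pass(sdfg, None)

# ... and pass it in. The function updates the passed-in sets in place, so
# the same mapping remains valid for transformations that run afterwards.
n_splits = gt_split_access_nodes(
    sdfg=sdfg,
    validate=False,
    validate_all=False,
    single_use_data=single_use_data,
)
```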
@@ -48,12 +58,88 @@
find_single_use_data = dace_analysis.FindSingleUseData()
single_use_data = find_single_use_data.apply_pass(sdfg, None)

return sdfg.apply_transformations_repeated(
SplitAccessNode(single_use_data=single_use_data),
validate=validate,
validate_all=validate_all,
apply_count = 0
for nsdfg in sdfg.all_sdfgs_recursive():
apply_count += _apply_split_access_node_non_recursive(
sdfg=nsdfg,
validate=validate,
validate_all=validate_all,
single_use_data=single_use_data[nsdfg],
)

return apply_count


def _apply_split_access_node_non_recursive(
sdfg: dace.SDFG,
validate: bool,
validate_all: bool,
single_use_data: set[str],
) -> int:
apply_count = 0
if len(single_use_data) == 0:
return apply_count

# The splitter transformation. Note that we set `assume_single_use_data` to
# `True` because we perform this check outside.
access_node_splitter = gtx_transformations.SplitAccessNode(
assume_single_use_data=True,
)

# Since the transformation only applies to single-use data, the order in which
# the states are processed is irrelevant. Furthermore, the fragments generated
# by splitting a node should never be split again (as long as this function
# runs), because otherwise that split would have been performed by the initial
# split.
for state in sdfg.states():
state_cfg_id = state.parent_graph.cfg_id
state_id = state.block_id
scope_dict = state.scope_dict()

# Now find all single-use data located at the top level of this state.
# Note that single-use data is characterized by having only one node
# referring to it, thus a `set` is safe.
access_nodes_to_process = sorted(
(
dnode
for dnode in state.data_nodes()
if dnode.data in single_use_data and scope_dict[dnode] is None
),
key=lambda dnode: dnode.data,
)
assert len(access_nodes_to_process) == len(set(access_nodes_to_process))

if len(access_nodes_to_process) == 0:
# Nothing to process in this state, continue.
continue

# Now try to split all candidates that we have found.
for access_node_to_process in access_nodes_to_process:
access_node_splitter.setup_match(
sdfg=sdfg,
cfg_id=state_cfg_id,
state_id=state_id,
subgraph={gtx_transformations.SplitAccessNode.access_node: access_node_to_process},
expr_index=0,
override=True,
)
if access_node_splitter.can_be_applied(
graph=state, expr_index=0, sdfg=sdfg, permissive=False
):
splitted_access_nodes = access_node_splitter.apply(graph=state, sdfg=sdfg)
if validate_all:
# Not fully correct, as we do not validate from the top of the SDFG hierarchy.
sdfg.validate()

# We have to update `single_use_data`. By definition, all data that we
# generate through splitting is also single-use data.
single_use_data.update(sac.data for sac in splitted_access_nodes.values())
apply_count += 1

if validate:
sdfg.validate()

return apply_count


@dace_properties.make_properties
class SplitAccessNode(dace_transformation.SingleStateTransformation):
@@ -81,6 +167,11 @@ class SplitAccessNode(dace_transformation.SingleStateTransformation):
be described by two producers.
- Create a version that is able to split over multiple states. This is
mostly useful to enable more state fusion.

Note:
The actual split operation is performed using `splitting_tools.split_node()`.
Furthermore, as a special extension to support certain workflows, the
`apply()` function returns the return value of that function.
"""

access_node = dace_transformation.PatternNode(dace_nodes.AccessNode)
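A short sketch (hypothetical driver code, mirroring `_apply_split_access_node_non_recursive()` above) of how the extended `apply()` return value can be consumed:

```python
# Sketch: `access_node_splitter` is assumed to have been set up via
# `setup_match()` for a given `state` and `sdfg`, as in the helper above.
fragments = access_node_splitter.apply(graph=state, sdfg=sdfg)

# Keys are the split subsets, values the AccessNodes created for them.
for subset, fragment_node in fragments.items():
    print(f"fragment '{fragment_node.data}' covers subset {subset}")
```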
@@ -171,7 +262,7 @@ def apply(
self,
graph: dace.SDFGState,
sdfg: dace.SDFG,
) -> None:
) -> dict[dace_sbs.Subset, dace_nodes.AccessNode]:
access_node: dace_nodes.AccessNode = self.access_node

edge_reassignments = self._find_edge_reassignment(graph)
@@ -194,17 +285,22 @@

# We have to clean up the isolated fragments. This is because we specified
# `allow_to_bypass_nodes` in the call above.
for ac in fragment_access_nodes.values():
for split_sbs in list(fragment_access_nodes.keys()):
ac = fragment_access_nodes[split_sbs]
if graph.degree(ac) == 0:
graph.remove_node(ac)
sdfg.remove_data(ac.data, validate=False)
fragment_access_nodes.pop(split_sbs)

# NOTE: In some situations a producer writes something into `access_node`
# that is never read. This is not an error, but can be a side effect of
# MapFusion or similar transformations. It leads to dead dataflow that we
# will not remove; instead, DDE should be run.

# Special extension to support certain workflows.
return fragment_access_nodes

def _find_edge_reassignment(
self,
state: dace.SDFGState,
@@ -257,6 +257,12 @@ def split_node(
if already_reconfigured_nodes is None:
already_reconfigured_nodes = set()

# Sort the split descriptions such that they are processed in a deterministic way.
# NOTE: Turning them into strings is probably the only way to achieve some
# stability. The only downside is that the order now depends on the
# specialization level that is used, i.e. on whether we have numbers or symbols.
split_description = sorted(split_description, key=lambda split: str(split))

desc_to_split = node_to_split.desc(sdfg)
assert desc_to_split.transient
assert not gtx_transformations.utils.is_view(desc_to_split)
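A small sketch (plain DaCe, assuming `dace` is installed) of why sorting by the string form is stable: `dace.subsets.Range` objects have no natural ordering, but their string representation is deterministic for a given specialization level:

```python
import dace
from dace import subsets as dace_sbs

splits = [dace_sbs.Range.from_string("10:40"), dace_sbs.Range.from_string("0:10")]

# Sorting by `str` yields the same order no matter how the input was arranged;
# note that with symbolic bounds the strings (and thus the order) would differ.
assert [str(s) for s in sorted(splits, key=str)] == ["0:10", "10:40"]
```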
@@ -334,6 +340,13 @@ def split_edge(
# detail that does not limit the applicability of this function.
# TODO(phimuell): Implements some check that nothing is lost.

# Bring the split description into a deterministic order.
# NOTE: See the note in `split_node()` for why the sorting is done in this way.
# NOTE: The main benefit of bringing `split_description` into a deterministic
# order is that the output of this function is deterministic as well. I am
# not sure if there is any benefit besides that.
split_description = sorted(split_description, key=lambda split: str(split))

assert isinstance(edge_to_split.src, dace_nodes.AccessNode)
assert not isinstance(edge_to_split.src.desc(sdfg), dace_data.View)
assert isinstance(edge_to_split.src.desc(sdfg), dace_data.Array)
@@ -366,6 +379,8 @@
new_fully_splitted_subsets.append(consumer)
fully_splitted_subsets = new_fully_splitted_subsets

# Allocate the return `dict`. The order is important: first the reordered split
# descriptions, followed by the `None` key.
new_edges: dict[Union[dace_sbs.Range, None], dace_graph.MultiConnectorEdge] = {
split: set() for split in split_description
}
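As a plain-Python illustration of the ordering guarantee relied upon here (dict insertion order is preserved since Python 3.7; the `None` key is assumed to be added after the comprehension, as the comment above states):

```python
# The `None` key (for edges not matching any split) is inserted last, so
# iterating the result yields the sorted splits first and `None` at the end.
new_edges = {split: set() for split in ["0:10", "10:40"]}
new_edges[None] = set()
assert list(new_edges) == ["0:10", "10:40", None]
```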
@@ -842,12 +857,7 @@ def _perform_node_split_with_bypass_impl(
edges_to_relocate: set[EdgeConnectionSpec],
already_reconfigured_nodes: set[tuple[dace_nodes.Node, str]],
) -> list[dace_graph.MultiConnectorEdge]:
"""Performs the splitting but the edge might go directly to the consumer.

# TODO: Remove the producer edge, run reconfiguration, split operation.
# TODO ADDING PRODUCER TO THE SET OF PROCESSED NODES

"""
"""Performs the splitting but the edge might go directly to the consumer."""
producer_edge_desc = next(
edesc for edesc in edges_to_relocate if describes_incoming_edge(edesc)
)
@@ -13,7 +13,6 @@
import numpy as np

dace = pytest.importorskip("dace")
import dace
from dace.sdfg import nodes as dace_nodes

from gt4py.next.program_processors.runners.dace import (
@@ -409,7 +409,6 @@ def test_subset_merging_stability():
stable_result: Optional[set[dace_sbs.Subset]] = None
for permut in itertools.permutations(subsets):
merged_subset = gtx_dace_split.subset_merger(list(permut))
print(merged_subset)

assert merged_subset is not None
assert len(merged_subset) == 2
@@ -418,3 +417,67 @@
stable_result = {copy.deepcopy(ss) for ss in merged_subset}
else:
assert set(merged_subset) == stable_result


def _make_sdfg_for_deterministic_splitting() -> tuple[
dace.SDFG, dace.SDFGState, dace_nodes.AccessNode
]:
sdfg = dace.SDFG(util.unique_name("deterministic_splitter"))
state = sdfg.add_state(is_start_block=True)

for name in "abtcd":
sdfg.add_array(
name,
shape=(40,),
dtype=dace.float64,
transient=(name == "t"),
)
t = state.add_access("t")

state.add_nedge(state.add_access("a"), t, dace.Memlet("a[1:31] -> [10:40]"))
state.add_nedge(state.add_access("b"), t, dace.Memlet("b[2:12] -> [0:10]"))
state.add_nedge(t, state.add_access("c"), dace.Memlet("t[0:10] -> [2:12]"))
state.add_nedge(t, state.add_access("d"), dace.Memlet("t[10:40] -> [1:31]"))

sdfg.validate()
return sdfg, state, t


@pytest.mark.parametrize("use_first_split_order", [True, False])
def test_deterministic_splitting(use_first_split_order: bool):
import dace

sdfg, state, t = _make_sdfg_for_deterministic_splitting()
assert util.count_nodes(sdfg, dace_nodes.AccessNode) == 5

ref, res = util.make_sdfg_args(sdfg)
util.compile_and_run_sdfg(sdfg, **ref)

split_description = [dace_sbs.Range.from_string("0:10"), dace_sbs.Range.from_string("10:40")]
expected_split_order = [
dace_sbs.Range.from_string(split) for split in sorted(map(str, split_description))
]
if use_first_split_order:
split_description = list(reversed(split_description))

access_node_fragments = gtx_transformations.splitting_tools.split_node(
state=state,
sdfg=sdfg,
node_to_split=t,
split_description=split_description,
allow_to_bypass_nodes=False,
)
assert len(access_node_fragments) == 2
after_ac = util.count_nodes(sdfg, dace_nodes.AccessNode, True)
assert len(after_ac) == 6
assert "t" not in after_ac
assert {"t_split_0", "t_split_1"}.issubset(ac.data for ac in after_ac)

for i, (split, access_node_fragment) in enumerate(access_node_fragments.items()):
assert split == expected_split_order[i]
expected_data_name = f"t_split_{i}"
assert access_node_fragment.data == expected_data_name
assert access_node_fragment.desc(sdfg).shape[0] == split.size()[0]

util.compile_and_run_sdfg(sdfg, **res)
assert util.compare_sdfg_res(ref=ref, res=res)