Skip to content

Commit

Permalink
Merge pull request #2563 from stfc/2499_scalarization_trans
Browse files Browse the repository at this point in the history
(Closes #2499) scalarization transformation implementation
  • Loading branch information
sergisiso authored Feb 27, 2025
2 parents 35cbc6a + 612f264 commit 9186a95
Show file tree
Hide file tree
Showing 18 changed files with 1,722 additions and 28 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/nemo_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ jobs:
module load perl/${PERL_VERSION}
make -j 4 passthrough
make -j 4 compile-passthrough
make run-passthrough
# Check for full numerical reproducibility with KGO results
diff <(make -s output-passthrough) KGOs/run.stat.nemo4.splitz12.nvhpc.10steps
# PSyclone, compile and run MetOffice NEMO with OpenMP for GPUs
- name: NEMO MetOffice OpenMP for GPU
Expand Down
4 changes: 4 additions & 0 deletions changelog
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
1) PR #2563 for #2499. Adds ScalarisationTrans.

release 3.1.0 26th of February 2025

1) PR #2827. Update Zenodo with release 3.0.0 and update link in
README.md.

Expand Down
7 changes: 7 additions & 0 deletions doc/user_guide/transformations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,13 @@ can be found in the API-specific sections).

####

.. autoclass:: psyclone.psyir.transformations.ScalarisationTrans
:members: apply
:noindex:

####


Algorithm-layer
---------------

Expand Down
4 changes: 2 additions & 2 deletions examples/nemo/scripts/omp_cpu_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def trans(psyir):
:type psyir: :py:class:`psyclone.psyir.nodes.FileContainer`
'''

# If the environment has ONLY_FILE defined, only process that one file and
# nothing else. This is useful for file-by-file exhaustive tests.
only_do_file = os.environ.get('ONLY_FILE', False)
Expand All @@ -96,7 +95,8 @@ def trans(psyir):
hoist_local_arrays=False,
convert_array_notation=True,
convert_range_loops=True,
hoist_expressions=False
hoist_expressions=False,
scalarise_loops=False
)

if psyir.name not in PARALLELISATION_ISSUES:
Expand Down
15 changes: 14 additions & 1 deletion examples/nemo/scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from psyclone.psyir.transformations import (
ArrayAssignment2LoopsTrans, HoistLoopBoundExprTrans, HoistLocalArraysTrans,
HoistTrans, InlineTrans, Maxval2LoopTrans, ProfileTrans,
Reference2ArrayRangeTrans)
Reference2ArrayRangeTrans, ScalarisationTrans)
from psyclone.transformations import TransformationError


Expand Down Expand Up @@ -283,6 +283,7 @@ def normalise_loops(
loopify_array_intrinsics: bool = True,
convert_range_loops: bool = True,
hoist_expressions: bool = True,
scalarise_loops: bool = False,
):
''' Normalise all loops in the given schedule so that they are in an
appropriate form for the Parallelisation transformations to analyse
Expand All @@ -299,6 +300,8 @@ def normalise_loops(
loops.
:param bool hoist_expressions: whether to hoist bounds and loop invariant
statements out of the loop nest.
:param scalarise_loops: whether to attempt to convert arrays to scalars
where possible, default is False.
'''
if hoist_local_arrays and schedule.name not in CONTAINS_STMT_FUNCTIONS:
# Apply the HoistLocalArraysTrans when possible, it cannot be applied
Expand Down Expand Up @@ -339,6 +342,16 @@ def normalise_loops(
except TransformationError:
pass

if scalarise_loops:
# Apply scalarisation to every loop. Execute this in reverse order
# as sometimes we can scalarise earlier loops if following loops
# have already been scalarised.
loops = schedule.walk(Loop)
loops.reverse()
scalartrans = ScalarisationTrans()
for loop in loops:
scalartrans.apply(loop)

if hoist_expressions:
# First hoist all possible expressions
for loop in schedule.walk(Loop):
Expand Down
33 changes: 33 additions & 0 deletions src/psyclone/core/symbolic_maths.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,39 @@ def greater_than(exp1, exp2, all_variables_positive=None):
return SymbolicMaths.Fuzzy.TRUE
return SymbolicMaths.Fuzzy.FALSE

@staticmethod
def less_than(exp1, exp2, all_variables_positive=None):
    '''
    Determines whether exp1 is, or might be, numerically less than exp2.

    The decision is based on the sign of the value returned by
    `SymbolicMaths._subtract` for the two expressions.

    :param exp1: the first expression for the comparison.
    :type exp1: :py:class:`psyclone.psyir.nodes.Node`
    :param exp2: the second expression for the comparison.
    :type exp2: :py:class:`psyclone.psyir.nodes.Node`
    :param Optional[bool] all_variables_positive: whether or not to assume
        that all variables appearing in either expression are positive
        definite. Default is not to make this assumption.

    :returns: TRUE if exp1 is less than exp2, FALSE if it is not and
        MAYBE if the sign of the difference cannot be determined.
    :rtype: :py:class:`psyclone.core.symbolic_maths.SymbolicMaths.Fuzzy`

    '''
    difference = SymbolicMaths._subtract(
        exp1, exp2, all_variables_positive=all_variables_positive)
    fuzzy = SymbolicMaths.Fuzzy

    if isinstance(difference, core.numbers.Integer):
        # A concrete integer difference: exp1 is only less than exp2 if
        # the difference is neither zero nor positive.
        not_less = difference.is_zero or difference.is_positive
        return fuzzy.FALSE if not_less else fuzzy.TRUE

    # Otherwise we have some sort of symbolic result. Its sign may be
    # unknown, in which case `is_negative` is None.
    negative = difference.is_negative
    if negative is None:
        return fuzzy.MAYBE
    return fuzzy.TRUE if negative else fuzzy.FALSE

# -------------------------------------------------------------------------
@staticmethod
def solve_equal_for(exp1, exp2, symbol):
Expand Down
17 changes: 17 additions & 0 deletions src/psyclone/psyir/nodes/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1782,6 +1782,23 @@ def update_parent_symbol_table(self, new_parent):
'''

def is_descendent_of(self, potential_ancestor) -> bool:
    '''
    Checks if this node is a descendant of the `potential_ancestor` node
    by walking up the chain of `parent` links, starting at this node.

    Note that the walk starts at this node itself, so a node is reported
    as being a descendant of itself.

    :param potential_ancestor: the Node to check whether it is an
        ancestor of self.
    :type potential_ancestor: :py:class:`psyclone.psyir.nodes.Node`

    :returns: whether potential_ancestor is an ancestor of this node.

    '''
    node = self
    while node is not potential_ancestor:
        if node.parent is None:
            # Reached the root of the tree without finding the candidate.
            return False
        node = node.parent
    return True


# For automatic documentation generation
# TODO #913 the 'colored' routine shouldn't be in this module.
Expand Down
8 changes: 8 additions & 0 deletions src/psyclone/psyir/nodes/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,18 @@ def is_read(self):
'''
# pylint: disable=import-outside-toplevel
from psyclone.psyir.nodes.assignment import Assignment
from psyclone.psyir.nodes.intrinsic_call import IntrinsicCall
parent = self.parent
if isinstance(parent, Assignment):
if parent.lhs is self:
return False

# If we have an intrinsic call parent then we need to check if it's
# an inquiry. Inquiry functions don't read from their first argument.
if isinstance(parent, IntrinsicCall):
if parent.arguments[0] is self and parent.is_inquiry:
return False

# All references other than LHS of assignments represent a read. This
# can be improved in the future by looking at Call intents.
return True
Expand Down
74 changes: 50 additions & 24 deletions src/psyclone/psyir/tools/definition_use_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,16 +250,20 @@ def find_forward_accesses(self):
.abs_position
+ 1
)
# We make a copy of the reference to have a detached
# node to avoid handling the special cases based on
# the parents of the reference.
chain = DefinitionUseChain(
self._reference.copy(),
body,
start_point=ancestor.abs_position,
stop_point=sub_stop_point,
)
chains.insert(0, chain)
# If we have a basic block with no children then skip it,
# e.g. for an if block with no code before the else
# statement.
if len(body) > 0:
# We make a copy of the reference to have a detached
# node to avoid handling the special cases based on
# the parents of the reference.
chain = DefinitionUseChain(
self._reference.copy(),
body,
start_point=ancestor.abs_position,
stop_point=sub_stop_point,
)
chains.insert(0, chain)
# If it's a while loop, create a basic block for the while
# condition.
if isinstance(ancestor, WhileLoop):
Expand Down Expand Up @@ -300,6 +304,11 @@ def find_forward_accesses(self):
# Now add all the other standardly handled basic_blocks to the
# list of chains.
for block in basic_blocks:
# If we have a basic block with no children then skip it,
# e.g. for an if block with no code before the else
# statement.
if len(block) == 0:
continue
chain = DefinitionUseChain(
self._reference,
block,
Expand Down Expand Up @@ -449,6 +458,12 @@ def _compute_forward_uses(self, basic_block_list):
if defs_out is not None:
self._defsout.append(defs_out)
return
# If its parent is an inquiry function then it's neither
# a read nor a write if it's the first argument.
if (isinstance(reference.parent, IntrinsicCall) and
reference.parent.is_inquiry and
reference.parent.arguments[0] is reference):
continue
if isinstance(reference, CodeBlock):
# CodeBlocks only find symbols, so we can only do as good
# as checking the symbol - this means we can get false
Expand Down Expand Up @@ -525,9 +540,7 @@ def _compute_forward_uses(self, basic_block_list):
if defs_out is None:
self._uses.append(reference)
elif reference.ancestor(Call):
# It has a Call ancestor so assume read/write access
# for now.
# We can do better for IntrinsicCalls realistically.
# Otherwise we assume read/write access for now.
if defs_out is not None:
self._killed.append(defs_out)
defs_out = reference
Expand Down Expand Up @@ -699,6 +712,12 @@ def _compute_backward_uses(self, basic_block_list):
abs_pos = reference.abs_position
if abs_pos < self._start_point or abs_pos >= stop_position:
continue
# If its parent is an inquiry function then it's neither
# a read nor a write if it's the first argument.
if (isinstance(reference.parent, IntrinsicCall) and
reference.parent.is_inquiry and
reference.parent.arguments[0] is reference):
continue
if isinstance(reference, CodeBlock):
# CodeBlocks only find symbols, so we can only do as good
# as checking the symbol - this means we can get false
Expand Down Expand Up @@ -784,9 +803,7 @@ def _compute_backward_uses(self, basic_block_list):
if defs_out is None:
self._uses.append(reference)
elif reference.ancestor(Call):
# It has a Call ancestor so assume read/write access
# for now.
# We can do better for IntrinsicCalls realistically.
# Otherwise we assume read/write access for now.
if defs_out is not None:
self._killed.append(defs_out)
defs_out = reference
Expand Down Expand Up @@ -835,6 +852,11 @@ def find_backward_accesses(self):
# Now add all the other standardly handled basic_blocks to the
# list of chains.
for block in basic_blocks:
# If we have a basic block with no children then skip it,
# e.g. for an if block with no code before the else
# statement.
if len(block) == 0:
continue
chain = DefinitionUseChain(
self._reference,
block,
Expand Down Expand Up @@ -874,14 +896,18 @@ def find_backward_accesses(self):
).abs_position
else:
sub_start_point = self._reference.abs_position
chain = DefinitionUseChain(
self._reference.copy(),
body,
start_point=sub_start_point,
stop_point=sub_stop_point,
)
chains.append(chain)
control_flow_nodes.append(ancestor)
# If we have a basic block with no children then skip it,
# e.g. for an if block with no code before the else
# statement.
if len(body) > 0:
chain = DefinitionUseChain(
self._reference.copy(),
body,
start_point=sub_start_point,
stop_point=sub_stop_point,
)
chains.append(chain)
control_flow_nodes.append(ancestor)
# If it's a while loop, create a basic block for the while
# condition.
if isinstance(ancestor, WhileLoop):
Expand Down
3 changes: 3 additions & 0 deletions src/psyclone/psyir/transformations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@
ReplaceInductionVariablesTrans
from psyclone.psyir.transformations.reference2arrayrange_trans import \
Reference2ArrayRangeTrans
from psyclone.psyir.transformations.scalarisation_trans import \
ScalarisationTrans


# For AutoAPI documentation generation
Expand Down Expand Up @@ -145,5 +147,6 @@
'Reference2ArrayRangeTrans',
'RegionTrans',
'ReplaceInductionVariablesTrans',
'ScalarisationTrans',
'TransformationError',
'ValueRangeCheckTrans']
Loading

0 comments on commit 9186a95

Please sign in to comment.