diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index ecd8ed88ed..643319e99d 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -206,17 +206,27 @@ def _bool_from_literal(node: itir.Node) -> bool: class _CannonicalizeUnstructuredDomain(eve.NodeTranslator): def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: - if node.fun == itir.SymRef(id="unstructured_domain"): + if cpm.is_call_to(node, "unstructured_domain"): # for no good reason, the domain arguments for unstructured need to be in order (horizontal, vertical) assert isinstance(node.args[0], itir.FunCall) first_axis_literal = node.args[0].args[0] assert isinstance(first_axis_literal, itir.AxisLiteral) - if first_axis_literal.kind == itir.DimensionKind.VERTICAL: - assert len(node.args) == 2 - assert isinstance(node.args[1], itir.FunCall) - assert isinstance(node.args[1].args[0], itir.AxisLiteral) - assert node.args[1].args[0].kind == itir.DimensionKind.HORIZONTAL - return itir.FunCall(fun=node.fun, args=[node.args[1], node.args[0]]) + if len(node.args) == 1: + if first_axis_literal.kind == itir.DimensionKind.VERTICAL: + # a horizontal domain is needed in unstructured, so we convert a K-only domain to cartesian + dim = common.Dimension(first_axis_literal.value, first_axis_literal.kind) + return im.domain( + common.GridType.CARTESIAN, + {dim: (node.args[0].args[1], node.args[0].args[2])}, + ) + elif len(node.args) == 2: + if first_axis_literal.kind == itir.DimensionKind.VERTICAL: + assert isinstance(node.args[1], itir.FunCall) + assert isinstance(node.args[1].args[0], itir.AxisLiteral) + assert node.args[1].args[0].kind == itir.DimensionKind.HORIZONTAL + return itir.FunCall(fun=node.fun, args=[node.args[1], node.args[0]]) + else: + raise NotImplementedError("Only up to two dimensional domains are supported.") return node @classmethod diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 8060d5bb36..20ff36ae6d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -64,6 +64,14 @@ def testee(a: cases.IJKField) -> cases.IJKField: cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) +def test_copy_vertical(unstructured_case_3d): + @gtx.field_operator + def testee(a: cases.KField) -> cases.KField: + return a + + cases.verify_with_default_data(unstructured_case_3d, testee, ref=lambda a: a) + + @pytest.mark.uses_tuple_returns def test_multicopy(cartesian_case): @gtx.field_operator