1010
1111import gt4py .next .ffront .field_operator_ast as foast
1212from gt4py .eve import NodeTranslator , NodeVisitor , traits
13- from gt4py .next import errors
14- from gt4py .next .common import DimensionKind
13+ from gt4py .next import errors , utils
14+ from gt4py .next .common import DimensionKind , promote_dims
1515from gt4py .next .ffront import ( # noqa
1616 dialect_ast_enums ,
1717 experimental ,
2020 type_specifications as ts_ffront ,
2121)
2222from gt4py .next .ffront .foast_passes .utils import compute_assign_indices
23+ from gt4py .next .iterator import builtins
2324from gt4py .next .type_system import type_info , type_specifications as ts , type_translation
2425
2526
@@ -428,7 +429,7 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs: Any) -> foast.IfStmt:
428429 if not isinstance (new_node .condition .type , ts .ScalarType ):
429430 raise errors .DSLError (
430431 node .location ,
431- "Condition for 'if' must be scalar, " f" got '{ new_node .condition .type } ' instead." ,
432+ f "Condition for 'if' must be scalar, got '{ new_node .condition .type } ' instead." ,
432433 )
433434
434435 if new_node .condition .type .kind != ts .ScalarKind .BOOL :
@@ -566,16 +567,10 @@ def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> foast.Compare:
566567 op = node .op , left = new_left , right = new_right , location = node .location , type = new_type
567568 )
568569
569- def _deduce_compare_type (
570+ def _deduce_arithmetic_compare_type (
570571 self , node : foast .Compare , * , left : foast .Expr , right : foast .Expr , ** kwargs : Any
571572 ) -> Optional [ts .TypeSpec ]:
572- # check both types compatible
573- for arg in (left , right ):
574- if not type_info .is_arithmetic (arg .type ):
575- raise errors .DSLError (
576- arg .location , f"Type '{ arg .type } ' can not be used in operator '{ node .op } '."
577- )
578-
573+ # e.g. `1 < 2`
579574 self ._check_operand_dtypes_match (node , left = left , right = right )
580575
581576 try :
@@ -592,6 +587,51 @@ def _deduce_compare_type(
592587 f" in call to '{ node .op } '." ,
593588 ) from ex
594589
590+ def _deduce_dimension_compare_type (
591+ self , node : foast .Compare , * , left : foast .Expr , right : foast .Expr , ** kwargs : Any
592+ ) -> Optional [ts .TypeSpec ]:
593+ # e.g. `IDim > 1`
594+ index_type = ts .ScalarType (
595+ kind = getattr (ts .ScalarKind , builtins .INTEGER_INDEX_BUILTIN .upper ())
596+ )
597+
598+ def error_msg (left : ts .TypeSpec , right : ts .TypeSpec ) -> str :
599+ return f"Dimension comparison needs to be between a 'Dimension' and index of type '{ index_type } ', got '{ left } ' and '{ right } '."
600+
601+ if isinstance (left .type , ts .DimensionType ):
602+ if not right .type == index_type :
603+ raise errors .DSLError (
604+ right .location ,
605+ error_msg (left .type , right .type ),
606+ )
607+ return ts .DomainType (dims = [left .type .dim ])
608+ elif isinstance (right .type , ts .DimensionType ):
609+ if not left .type == index_type :
610+ raise errors .DSLError (
611+ left .location ,
612+ error_msg (left .type , right .type ),
613+ )
614+ return ts .DomainType (dims = [right .type .dim ])
615+ else :
616+ raise AssertionError ()
617+
618+ def _deduce_compare_type (
619+ self , node : foast .Compare , * , left : foast .Expr , right : foast .Expr , ** kwargs : Any
620+ ) -> Optional [ts .TypeSpec ]:
621+ # e.g. `1 < 1`
622+ if all (type_info .is_arithmetic (arg ) for arg in (left .type , right .type )):
623+ return self ._deduce_arithmetic_compare_type (node , left = left , right = right )
624+ # e.g. `IDim > 1`
625+ if any (isinstance (arg , ts .DimensionType ) for arg in (left .type , right .type )):
626+ return self ._deduce_dimension_compare_type (node , left = left , right = right )
627+
628+ raise errors .DSLError (
629+ left .location ,
630+ "Comparison operators can only be used between arithmetic types "
631+ "(scalars, fields) or between a dimension and an index type "
632+ "({builtins.INTEGER_INDEX_BUILTIN})." ,
633+ )
634+
595635 def _deduce_binop_type (
596636 self , node : foast .BinOp , * , left : foast .Expr , right : foast .Expr , ** kwargs : Any
597637 ) -> Optional [ts .TypeSpec ]:
@@ -612,37 +652,48 @@ def _deduce_binop_type(
612652 dialect_ast_enums .BinaryOperator .BIT_OR ,
613653 dialect_ast_enums .BinaryOperator .BIT_XOR ,
614654 }
615- is_compatible = type_info .is_logical if node .op in logical_ops else type_info .is_arithmetic
616-
617- # check both types compatible
618- for arg in (left , right ):
619- if not is_compatible (arg .type ):
620- raise errors .DSLError (
621- arg .location , f"Type '{ arg .type } ' can not be used in operator '{ node .op } '."
622- )
623-
624- left_type = cast (ts .FieldType | ts .ScalarType , left .type )
625- right_type = cast (ts .FieldType | ts .ScalarType , right .type )
626655
627- if node .op == dialect_ast_enums .BinaryOperator .POW :
628- return left_type
656+ err_msg = f"Unsupported operand type(s) for { node .op } : '{ left .type } ' and '{ right .type } '."
629657
630- if node . op == dialect_ast_enums . BinaryOperator . MOD and not type_info . is_integral (
631- right_type
658+ if isinstance ( left . type , ( ts . ScalarType , ts . FieldType )) and isinstance (
659+ right . type , ( ts . ScalarType , ts . FieldType )
632660 ):
633- raise errors .DSLError (
634- arg .location ,
635- f"Type '{ right_type } ' can not be used in operator '{ node .op } ', it only accepts 'int'." ,
661+ is_compatible = (
662+ type_info .is_logical if node .op in logical_ops else type_info .is_arithmetic
636663 )
664+ for arg in (left , right ):
665+ if not is_compatible (arg .type ):
666+ raise errors .DSLError (arg .location , err_msg )
637667
638- try :
639- return type_info .promote (left_type , right_type )
640- except ValueError as ex :
641- raise errors .DSLError (
642- node .location ,
643- f"Could not promote '{ left_type } ' and '{ right_type } ' to common type"
644- f" in call to '{ node .op } '." ,
645- ) from ex
668+ if node .op == dialect_ast_enums .BinaryOperator .POW :
669+ return left .type
670+
671+ if node .op == dialect_ast_enums .BinaryOperator .MOD and not type_info .is_integral (
672+ right .type
673+ ):
674+ raise errors .DSLError (
675+ arg .location ,
676+ f"Type '{ right .type } ' can not be used in operator '{ node .op } ', it only accepts 'int'." ,
677+ )
678+
679+ try :
680+ return type_info .promote (left .type , right .type )
681+ except ValueError as ex :
682+ raise errors .DSLError (
683+ node .location ,
684+ f"Could not promote '{ left .type } ' and '{ right .type } ' to common type"
685+ f" in call to '{ node .op } '." ,
686+ ) from ex
687+ elif isinstance (left .type , ts .DomainType ) and isinstance (right .type , ts .DomainType ):
688+ if node .op not in logical_ops :
689+ raise errors .DSLError (
690+ node .location ,
691+ f"{ err_msg } Operator "
692+ f"must be one of { ', ' .join ((str (op ) for op in logical_ops ))} ." ,
693+ )
694+ return ts .DomainType (dims = promote_dims (left .type .dims , right .type .dims ))
695+ else :
696+ raise errors .DSLError (node .location , err_msg )
646697
647698 def _check_operand_dtypes_match (
648699 self , node : foast .BinOp | foast .Compare , left : foast .Expr , right : foast .Expr
@@ -908,6 +959,7 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
908959 )
909960
910961 try :
962+ # TODO(tehrengruber): the construct_tuple_type function doesn't look correct
911963 if isinstance (true_branch_type , ts .TupleType ) and isinstance (
912964 false_branch_type , ts .TupleType
913965 ):
@@ -943,7 +995,43 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
943995 location = node .location ,
944996 )
945997
946- _visit_concat_where = _visit_where
998+ def _visit_concat_where (self , node : foast .Call , ** kwargs : Any ) -> foast .Call :
999+ cond_type , true_branch_type , false_branch_type = (arg .type for arg in node .args )
1000+
1001+ assert isinstance (cond_type , ts .DomainType )
1002+ assert all (
1003+ isinstance (el , (ts .FieldType , ts .ScalarType ))
1004+ for arg in (true_branch_type , false_branch_type )
1005+ for el in type_info .primitive_constituents (arg )
1006+ )
1007+
1008+ @utils .tree_map (
1009+ collection_type = ts .TupleType ,
1010+ result_collection_constructor = lambda el : ts .TupleType (types = list (el )),
1011+ )
1012+ def deduce_return_type (
1013+ tb : ts .FieldType | ts .ScalarType , fb : ts .FieldType | ts .ScalarType
1014+ ) -> ts .FieldType :
1015+ if (t_dtype := type_info .extract_dtype (tb )) != (f_dtype := type_info .extract_dtype (fb )):
1016+ raise errors .DSLError (
1017+ node .location ,
1018+ f"Field arguments must be of same dtype, got '{ t_dtype } ' != '{ f_dtype } '." ,
1019+ )
1020+ return_dims = promote_dims (
1021+ cond_type .dims , type_info .extract_dims (type_info .promote (tb , fb ))
1022+ )
1023+ return_type = ts .FieldType (dims = return_dims , dtype = t_dtype )
1024+ return return_type
1025+
1026+ return_type = deduce_return_type (true_branch_type , false_branch_type )
1027+
1028+ return foast .Call (
1029+ func = node .func ,
1030+ args = node .args ,
1031+ kwargs = node .kwargs ,
1032+ type = return_type ,
1033+ location = node .location ,
1034+ )
9471035
9481036 def _visit_broadcast (self , node : foast .Call , ** kwargs : Any ) -> foast .Call :
9491037 arg_type = cast (ts .FieldType | ts .ScalarType , node .args [0 ].type )
0 commit comments