@@ -124,54 +124,62 @@ def translate(
124124 return self
125125 if len (shift ) == 2 :
126126 off , val = shift
127- assert isinstance (off , itir .OffsetLiteral ) and isinstance (off .value , str )
128- connectivity_type = common .get_offset_type (offset_provider_type , off .value )
129-
130- if isinstance (connectivity_type , common .Dimension ):
131- if val is trace_shifts .Sentinel .VALUE :
132- raise NotImplementedError ("Dynamic offsets not supported." )
133- assert isinstance (val , itir .OffsetLiteral ) and isinstance (val .value , int )
134- current_dim = connectivity_type
135- # cartesian offset
136- new_ranges [current_dim ] = SymbolicRange .translate (
137- self .ranges [current_dim ], val .value
138- )
139- elif isinstance (connectivity_type , common .NeighborConnectivityType ):
140- # unstructured shift
141- assert (
142- isinstance (val , itir .OffsetLiteral ) and isinstance (val .value , int )
143- ) or val in [
144- trace_shifts .Sentinel .ALL_NEIGHBORS ,
145- trace_shifts .Sentinel .VALUE ,
146- ]
147- horizontal_sizes : dict [str , itir .Expr ]
148- if symbolic_domain_sizes is not None :
149- horizontal_sizes = {
150- k : im .ensure_expr (v ) for k , v in symbolic_domain_sizes .items ()
151- }
152- else :
153- # note: ugly but cheap re-computation, but should disappear
154- assert common .is_offset_provider (offset_provider )
155- horizontal_sizes = {
156- k : im .literal (str (v ), builtins .INTEGER_INDEX_BUILTIN )
157- for k , v in _max_domain_sizes_by_location_type (offset_provider ).items ()
158- }
159-
160- old_dim = connectivity_type .source_dim
161- new_dim = connectivity_type .codomain
162-
163- assert new_dim not in new_ranges or old_dim == new_dim
164-
165- new_range = SymbolicRange (
166- im .literal ("0" , builtins .INTEGER_INDEX_BUILTIN ),
167- horizontal_sizes [new_dim .value ],
168- )
127+ if isinstance (off , itir .CartesianOffset ):
128+ old_dim = common .Dimension (value = off .domain .value , kind = off .domain .kind )
129+ new_dim = common .Dimension (value = off .codomain .value , kind = off .codomain .kind )
130+ new_range = SymbolicRange .translate (self .ranges [old_dim ], val .value )
169131 new_ranges = dict (
170132 (dim , range_ ) if dim != old_dim else (new_dim , new_range )
171133 for dim , range_ in new_ranges .items ()
172134 )
173135 else :
174- raise AssertionError ()
136+ assert isinstance (off , itir .OffsetLiteral ) and isinstance (off .value , str )
137+ connectivity_type = common .get_offset_type (offset_provider_type , off .value )
138+ if isinstance (connectivity_type , common .Dimension ):
139+ if val is trace_shifts .Sentinel .VALUE :
140+ raise NotImplementedError ("Dynamic offsets not supported." )
141+ assert isinstance (val , itir .OffsetLiteral ) and isinstance (val .value , int )
142+ current_dim = connectivity_type
143+ # cartesian offset
144+ new_ranges [current_dim ] = SymbolicRange .translate (
145+ self .ranges [current_dim ], val .value
146+ )
147+ elif isinstance (connectivity_type , common .NeighborConnectivityType ):
148+ # unstructured shift
149+ assert (
150+ isinstance (val , itir .OffsetLiteral ) and isinstance (val .value , int )
151+ ) or val in [
152+ trace_shifts .Sentinel .ALL_NEIGHBORS ,
153+ trace_shifts .Sentinel .VALUE ,
154+ ]
155+ horizontal_sizes : dict [str , itir .Expr ]
156+ if symbolic_domain_sizes is not None :
157+ horizontal_sizes = {
158+ k : im .ensure_expr (v ) for k , v in symbolic_domain_sizes .items ()
159+ }
160+ else :
161+ # note: ugly but cheap re-computation, but should disappear
162+ assert common .is_offset_provider (offset_provider )
163+ horizontal_sizes = {
164+ k : im .literal (str (v ), builtins .INTEGER_INDEX_BUILTIN )
165+ for k , v in _max_domain_sizes_by_location_type (offset_provider ).items ()
166+ }
167+
168+ old_dim = connectivity_type .source_dim
169+ new_dim = connectivity_type .codomain
170+
171+ assert new_dim not in new_ranges or old_dim == new_dim
172+
173+ new_range = SymbolicRange (
174+ im .literal ("0" , builtins .INTEGER_INDEX_BUILTIN ),
175+ horizontal_sizes [new_dim .value ],
176+ )
177+ new_ranges = dict (
178+ (dim , range_ ) if dim != old_dim else (new_dim , new_range )
179+ for dim , range_ in new_ranges .items ()
180+ )
181+ else :
182+ raise AssertionError ()
175183 return SymbolicDomain (self .grid_type , new_ranges )
176184 elif len (shift ) > 2 :
177185 return self .translate (shift [0 :2 ], offset_provider , symbolic_domain_sizes ).translate (
0 commit comments