@@ -68,7 +68,14 @@ def make_term(cls, *args: Term) -> Self:
6868
6969
7070@dataclass (eq = True , frozen = True )
71- class Immediate (LogicNode ):
71+ class WithFields (LogicNode ):
72+ @abstractmethod
73+ def get_fields (self ) -> tuple [Field , ...]:
74+ """Get this node's fields."""
75+
76+
77+ @dataclass (eq = True , frozen = True )
78+ class Immediate (WithFields ):
7279 """
7380 Represents a logical AST expression for the literal value `val`.
7481
@@ -97,6 +104,10 @@ def fill_value(self):
97104 def children (self ):
98105 raise TypeError (f"`{ type (self ).__name__ } ` doesn't support `.children()`." )
99106
107+ def get_fields (self ) -> tuple [Field , ...]:
108+ """Returns fields of the node."""
109+ return ()
110+
100111
101112@dataclass (eq = True , frozen = True )
102113class Deferred (LogicNode ):
@@ -183,7 +194,7 @@ def children(self):
183194
184195
185196@dataclass (eq = True , frozen = True )
186- class Table (LogicNode ):
197+ class Table (WithFields ):
187198 """
188199 Represents a logical AST expression for a tensor object `tns`, indexed by fields
189200 `idxs...`. A table is a tensor with named dimensions.
@@ -210,9 +221,17 @@ def children(self):
210221 """Returns the children of the node."""
211222 return [self .tns , * self .idxs ]
212223
224+ def get_fields (self ) -> tuple [Field , ...]:
225+ """Returns fields of the node."""
226+ return self .idxs
227+
228+ @classmethod
229+ def make_term (cls , head , tns , * idxs ):
230+ return head (tns , idxs )
231+
213232
214233@dataclass (eq = True , frozen = True )
215- class MapJoin (LogicNode ):
234+ class MapJoin (WithFields ):
216235 """
217236 Represents a logical AST expression for mapping the function `op` across `args...`.
218237 Dimensions which are not present are broadcasted. Dimensions which are
@@ -225,7 +244,7 @@ class MapJoin(LogicNode):
225244 """
226245
227246 op : Immediate
228- args : tuple [LogicNode , ...]
247+ args : tuple [WithFields , ...]
229248
230249 @staticmethod
231250 def is_expr ():
@@ -241,8 +260,19 @@ def children(self):
241260 """Returns the children of the node."""
242261 return [self .op , * self .args ]
243262
263+ def get_fields (self ) -> tuple [Field , ...]:
264+ """Returns fields of the node."""
265+ # (mtsokol) I'm not sure if this comment still applies - the order is preserved.
266+ # TODO: this is wrong here: the overall order should at least be concordant with
267+ # the args if the args are concordant
268+ fs : list [Field ] = []
269+ for arg in self .args :
270+ fs .extend (arg .get_fields ())
271+
272+ return tuple (fs )
273+
244274 @classmethod
245- def make_term (cls , op : Immediate , * args : LogicNode ) -> Self : # type: ignore[override]
275+ def make_term (cls , op : Immediate , * args : WithFields ) -> Self : # type: ignore[override]
246276 return cls (op , tuple (args ))
247277
248278
@@ -261,7 +291,7 @@ class Aggregate(LogicNode):
261291
262292 op : Immediate
263293 init : Immediate
264- arg : LogicNode
294+ arg : WithFields
265295 idxs : tuple [Field , ...]
266296
267297 @staticmethod
@@ -278,9 +308,17 @@ def children(self):
278308 """Returns the children of the node."""
279309 return [self .op , self .init , self .arg , * self .idxs ]
280310
311+ def get_fields (self ) -> tuple [Field , ...]:
312+ """Returns fields of the node."""
313+ return tuple (field for field in self .arg .get_fields () if field not in self .idxs )
314+
315+ @classmethod
316+ def make_term (cls , head , op , init , arg , * idxs ):
317+ return head (op , init , arg , idxs )
318+
281319
282320@dataclass (eq = True , frozen = True )
283- class Reorder (LogicNode ):
321+ class Reorder (WithFields ):
284322 """
285323 Represents a logical AST statement that reorders the dimensions of `arg` to be
286324 `idxs...`. Dimensions known to be length 1 may be dropped. Dimensions that do not
@@ -308,9 +346,17 @@ def children(self):
308346 """Returns the children of the node."""
309347 return [self .arg , * self .idxs ]
310348
349+ def get_fields (self ) -> tuple [Field , ...]:
350+ """Returns fields of the node."""
351+ return self .idxs
352+
353+ @classmethod
354+ def make_term (cls , head , arg , * idxs ):
355+ return head (arg , idxs )
356+
311357
312358@dataclass (eq = True , frozen = True )
313- class Relabel (LogicNode ):
359+ class Relabel (WithFields ):
314360 """
315361 Represents a logical AST statement that relabels the dimensions of `arg` to be
316362 `idxs...`.
@@ -337,9 +383,13 @@ def children(self):
337383 """Returns the children of the node."""
338384 return [self .arg , * self .idxs ]
339385
386+ def get_fields (self ) -> tuple [Field , ...]:
387+ """Returns fields of the node."""
388+ return self .idxs
389+
340390
341391@dataclass (eq = True , frozen = True )
342- class Reformat (LogicNode ):
392+ class Reformat (WithFields ):
343393 """
344394 Represents a logical AST statement that reformats `arg` into the tensor `tns`.
345395
@@ -349,7 +399,7 @@ class Reformat(LogicNode):
349399 """
350400
351401 tns : Immediate
352- arg : LogicNode
402+ arg : WithFields
353403
354404 @staticmethod
355405 def is_expr ():
@@ -365,20 +415,24 @@ def children(self):
365415 """Returns the children of the node."""
366416 return [self .tns , self .arg ]
367417
418+ def get_fields (self ) -> tuple [Field , ...]:
419+ """Returns fields of the node."""
420+ return self .arg .get_fields ()
421+
368422
369423@dataclass (eq = True , frozen = True )
370- class Subquery (LogicNode ):
424+ class Subquery (WithFields ):
371425 """
372426 Represents a logical AST statement that evaluates `rhs`, binding the result to
373427 `lhs`, and returns `rhs`.
374428
375429 Attributes:
376430 lhs: The left-hand side of the binding.
377- rhs : The argument to evaluate.
431+ arg : The argument to evaluate.
378432 """
379433
380434 lhs : LogicNode
381- arg : LogicNode
435+ arg : WithFields
382436
383437 @staticmethod
384438 def is_expr ():
@@ -394,6 +448,10 @@ def children(self):
394448 """Returns the children of the node."""
395449 return [self .lhs , self .arg ]
396450
451+ def get_fields (self ) -> tuple [Field , ...]:
452+ """Returns fields of the node."""
453+ return self .arg .get_fields ()
454+
397455
398456@dataclass (eq = True , frozen = True )
399457class Query (LogicNode ):
0 commit comments