22
33from abc import abstractmethod
44from dataclasses import dataclass
5- from typing import TYPE_CHECKING , Any
5+ from typing import Any , Self , TypeVar
66
77from ..symbolic import Term
88
9- if TYPE_CHECKING :
10- pass
11-
129__all__ = [
1310 "LogicNode" ,
1411 "Immediate" ,
2825]
2926
3027
28+ T = TypeVar ("T" , bound = "LogicNode" )
29+
30+
3131@dataclass (eq = True , frozen = True )
3232class LogicNode (Term ):
3333 """
@@ -57,14 +57,14 @@ def head(cls):
5757 """Returns the head of the node."""
5858 return cls
5959
60- def children (self ) -> list [Term ]:
60+ @abstractmethod
61+ def children (self ) -> list [LogicNode ]:
6162 """Returns the children of the node."""
62- raise Exception (f"`children` isn't supported for { self .__class__ } ." )
6363
6464 @classmethod
65- def make_term (cls , head , args ) :
65+ def make_term (cls , * args : Term ) -> Self :
6666 """Creates a term with the given head and arguments."""
67- return head (* args )
67+ return cls (* args )
6868
6969
7070@dataclass (eq = True , frozen = True )
@@ -94,6 +94,9 @@ def fill_value(self):
9494
9595 return fill_value (self )
9696
97+ def children (self ):
98+ raise TypeError (f"`{ type (self ).__name__ } ` doesn't support `.children()`." )
99+
97100
98101@dataclass (eq = True , frozen = True )
99102class Deferred (LogicNode ):
@@ -222,7 +225,7 @@ class MapJoin(LogicNode):
222225 """
223226
224227 op : Immediate
225- args : tuple [Term , ...]
228+ args : tuple [LogicNode , ...]
226229
227230 @staticmethod
228231 def is_expr ():
@@ -239,8 +242,8 @@ def children(self):
239242 return [self .op , * self .args ]
240243
241244 @classmethod
242- def make_term (cls , head , args ):
243- return head ( args [ 0 ] , tuple (args [ 1 :] ))
245+ def make_term (cls , op : Immediate , * args : LogicNode ) -> Self : # type: ignore[override]
246+ return cls ( op , tuple (args ))
244247
245248
246249@dataclass (eq = True , frozen = True )
@@ -258,7 +261,7 @@ class Aggregate(LogicNode):
258261
259262 op : Immediate
260263 init : Immediate
261- arg : Term
264+ arg : LogicNode
262265 idxs : tuple [Field , ...]
263266
264267 @staticmethod
@@ -288,7 +291,7 @@ class Reorder(LogicNode):
288291 idxs: The new order of dimensions.
289292 """
290293
291- arg : Term
294+ arg : LogicNode
292295 idxs : tuple [Field , ...]
293296
294297 @staticmethod
@@ -317,7 +320,7 @@ class Relabel(LogicNode):
317320 idxs: The new labels for dimensions.
318321 """
319322
320- arg : Term
323+ arg : LogicNode
321324 idxs : tuple [Field , ...]
322325
323326 @staticmethod
@@ -346,7 +349,7 @@ class Reformat(LogicNode):
346349 """
347350
348351 tns : Immediate
349- arg : Term
352+ arg : LogicNode
350353
351354 @staticmethod
352355 def is_expr ():
@@ -374,8 +377,8 @@ class Subquery(LogicNode):
374377 rhs: The argument to evaluate.
375378 """
376379
377- lhs : Term
378- arg : Term
380+ lhs : LogicNode
381+ arg : LogicNode
379382
380383 @staticmethod
381384 def is_expr ():
@@ -403,8 +406,8 @@ class Query(LogicNode):
403406 rhs: The right-hand side to evaluate.
404407 """
405408
406- lhs : Term
407- rhs : Term
409+ lhs : LogicNode
410+ rhs : LogicNode
408411
409412 @staticmethod
410413 def is_expr ():
@@ -431,7 +434,7 @@ class Produces(LogicNode):
431434 args: The arguments to return.
432435 """
433436
434- args : tuple [Term , ...]
437+ args : tuple [LogicNode , ...]
435438
436439 @staticmethod
437440 def is_expr ():
@@ -445,11 +448,11 @@ def is_stateful():
445448
446449 def children (self ):
447450 """Returns the children of the node."""
448- return list ( self .args )
451+ return [ * self .args ]
449452
450453 @classmethod
451- def make_term (cls , head , args ):
452- return head (tuple (args ))
454+ def make_term (cls , * args : LogicNode ) -> Self : # type: ignore[override]
455+ return cls (tuple (args ))
453456
454457
455458@dataclass (eq = True , frozen = True )
@@ -462,7 +465,7 @@ class Plan(LogicNode):
462465 bodies: The sequence of statements to execute.
463466 """
464467
465- bodies : tuple [Term , ...] = ()
468+ bodies : tuple [LogicNode , ...] = ()
466469
467470 @staticmethod
468471 def is_expr ():
@@ -479,5 +482,5 @@ def children(self):
479482 return tuple (self .bodies )
480483
481484 @classmethod
482- def make_term (cls , head , val ):
483- return head (tuple (val ))
485+ def make_term (cls , * bodies : LogicNode ) -> Self : # type: ignore[override]
486+ return cls (tuple (bodies ))
0 commit comments