forked from zachetienne/nrpytutorial
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSIMDExprTree.py
More file actions
142 lines (120 loc) · 4.91 KB
/
Copy pathSIMDExprTree.py
File metadata and controls
142 lines (120 loc) · 4.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
""" SymPy (N-Ary) Expression Tree
The following script will extend the expression tree from SymPy,
allowing direct node manipulation for subexpression replacement.
The expression tree structure within SymPy expressions stores
subexpressions inside immutable tuples, preventing the client from
modifying the expression tree. Therefore, the client must depend on
built-in functions, such as xreplace, for subexpression replacement,
which might be suboptimal for their specific purpose. The ExprTree class
is implemented as an n-ary tree data structure for SymPy expressions,
equipped with a build method for constructing the expression tree,
a reconstruct method for reconstructing the root expression, a replace
method for subexpression replacement, and preorder/postorder traversal
iterators (or generators). The __repr__ representation of the expression
tree will return a string of the expressions using the preorder traversal,
while the __str__ representation will return a string of the class name
and root expression. The nested Node class has a field for an expression,
a field for the func of its parent expression (None for the root), and a
field for subexpression children (implemented as a mutable list).
"""
# Author: Ken Sible
# Email: ksible *at* outlook *dot* com
# pylint: disable=too-few-public-methods
class ExprTree:
    """ SymPy (N-Ary) Expression Tree

    Mirrors a SymPy expression as a tree of mutable ``Node`` objects so that
    subexpressions can be replaced in place and the root expression rebuilt
    with ``reconstruct``. Building and traversal use explicit stacks rather
    than recursion, so deeply nested expressions do not hit Python's
    recursion limit.

    >>> from sympy.abc import a, b, x
    >>> from sympy import cos
    >>> tree = ExprTree(cos(a + b)**2)
    >>> print(tree)
    ExprTree(cos(a + b)**2)
    >>> repr(tree)
    '[cos(a + b)**2, cos(a + b), a + b, a, b, 2]'
    """
    def __init__(self, expr):
        # The root node has no parent expression, hence func=None.
        self.root = self.Node(expr, None)
        self.build(self.root)

    def build(self, node, clear=False):
        """ Build expression (sub)tree.

        Appends one child node per element of ``node.expr.args``, and repeats
        this for every newly created child, mirroring the argument structure
        of the expression.

        :arg node: root node of the (sub)tree to build
        :arg clear: discard the existing children of ``node`` first
                    (default: False); without it, new children are appended
                    after any existing ones

        >>> from sympy.abc import a, b
        >>> from sympy import cos, sin
        >>> tree = ExprTree(cos(a + b)**2)
        >>> tree.root.expr = sin(a*b)**2
        >>> tree.build(tree.root, clear=True)
        >>> repr(tree)
        '[sin(a*b)**2, sin(a*b), a*b, a, b, 2]'
        """
        if clear:
            # In-place deletion keeps the identity of the children list.
            del node.children[:]
        # Explicit work stack instead of recursion (no recursion-depth limit).
        # Each parent's children are still appended in argument order.
        pending = [node]
        while pending:
            parent = pending.pop()
            for arg in parent.expr.args:
                # Every child records the func of its parent expression.
                subtree = self.Node(arg, parent.expr.func)
                parent.append(subtree)
                pending.append(subtree)

    def preorder(self, node=None):
        """ Generate iterator for preorder traversal.

        A node's children are read only after the node itself has been
        yielded. A caller may therefore replace a yielded node's subtree
        (e.g. via ``build(node, clear=True)``), and the traversal continues
        into the new children.

        :arg node: root node of (sub)tree (default: the tree root)
        :return: iterator over nodes, parent before children

        >>> from sympy.abc import a, b
        >>> from sympy import cos, Mul
        >>> tree = ExprTree(cos(a*b)**2)
        >>> for i, subtree in enumerate(tree.preorder()):
        ...     if subtree.expr.func == Mul:
        ...         print((i, subtree.expr))
        (2, a*b)
        """
        if node is None:
            node = self.root
        yield node
        # Stack of live iterators over children lists (one per open level).
        stack = [iter(node.children)]
        while stack:
            # Children are always Node objects, so None marks exhaustion.
            child = next(stack[-1], None)
            if child is None:
                stack.pop()
                continue
            yield child
            # Read child.children only after the caller has seen the child.
            stack.append(iter(child.children))

    def postorder(self, node=None):
        """ Generate iterator for postorder traversal.

        :arg node: root node of (sub)tree (default: the tree root)
        :return: iterator over nodes, children before parent

        >>> from sympy.abc import a, b
        >>> from sympy import cos, Mul
        >>> tree = ExprTree(cos(a*b)**2)
        >>> for i, subtree in enumerate(tree.postorder()):
        ...     if subtree.expr.func == Mul:
        ...         print((i, subtree.expr))
        (2, a*b)
        """
        if node is None:
            node = self.root
        # Each entry pairs a node with a live iterator over its children.
        stack = [(node, iter(node.children))]
        while stack:
            parent, children = stack[-1]
            child = next(children, None)
            if child is None:
                # Every child has been visited: emit the parent itself.
                stack.pop()
                yield parent
            else:
                stack.append((child, iter(child.children)))

    def reconstruct(self, evaluate=False):
        """
        Reconstruct root expression from expression tree.

        Every node that has children is rebuilt bottom-up (postorder) as
        ``expr.func(*child_exprs)``. Each node's ``expr`` is updated in place.

        :arg evaluate: evaluate rebuilt expressions (default: False)
        :return: root expression

        >>> from sympy.abc import a, b
        >>> from sympy import cos, sin
        >>> tree = ExprTree(cos(a + b)**2)
        >>> tree.root.children[0].expr = sin(a + b)
        >>> tree.reconstruct()
        sin(a + b)**2
        """
        for subtree in self.postorder():
            if not subtree.children:
                continue
            expr_list = [node.expr for node in subtree.children]
            func = subtree.expr.func
            try:
                subtree.expr = func(*expr_list, evaluate=evaluate)
            except TypeError:
                # Some constructors (e.g. SymPy's ExprCondPair inside
                # Piecewise) take no 'evaluate' keyword; rebuild without it.
                # Genuine argument errors are raised again by this call.
                subtree.expr = func(*expr_list)
        return self.root.expr

    class Node:
        """ Expression Tree Node; a node cannot exist outside the tree """
        def __init__(self, expr, func):
            # (Sub)expression held by this node.
            self.expr = expr
            # func of the parent expression (None for the root node).
            self.func = func
            # Child nodes, built from expr.args by ExprTree.build.
            self.children = []

        def append(self, node):
            """ Append ``node`` as the last child of this node. """
            self.children.append(node)

    def __repr__(self):
        """ List of node expressions in preorder, rendered as a string. """
        return str([node.expr for node in self.preorder()])

    def __str__(self):
        """ Class name wrapped around the root expression. """
        return 'ExprTree(%s)' % str(self.root.expr)
if __name__ == "__main__":
    # Executed as a script: run the doctest examples embedded in the
    # docstrings of this module.
    from doctest import testmod
    testmod()