-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathsymbol_analyzer.py
394 lines (305 loc) · 11 KB
/
symbol_analyzer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
import ast
import yapypy.extended_python.extended_ast as ex_ast
import typing
from yapypy.utils.namedlist import INamedList, as_namedlist, trait
from typing import NamedTuple, List, Optional, Union
from pprint import pformat
from enum import Enum, auto as _auto
class ContextType(Enum):
    """Flags describing what kind of scope a symbol table belongs to.

    A symbol table carries a set of these flags in its ``cts`` field; the
    AST tagger in this module adds them while walking the tree.
    """
    # Set on the table created for the module-level (global) scope.
    Module = _auto()
    # Added when a ``yield`` / ``yield from`` expression is visited.
    Generator = _auto()
    # Added for ``async def``, ``await`` and asynchronous comprehensions.
    Coroutine = _auto()
    # Added when an annotated assignment is visited.
    Annotation = _auto()
    # Added to the table created for a class body.
    ClassDef = _auto()
class AnalyzedSymTable(NamedTuple):
    """Result of analysing one scope: three disjoint-by-construction name sets.

    The sets are filled in (and mutated in place) by the ``resolve_*``
    methods of ``SymTable``.
    """
    # Names bound in the scope and not captured by any nested scope.
    bounds: Optional[set]
    # Names the scope (or a nested scope) takes from an enclosing scope.
    freevars: Optional[set]
    # Names bound in the scope that a nested scope refers to.
    cellvars: Optional[set]
class SymTable(INamedList, metaclass=trait(as_namedlist)):
    """Symbol table of a single scope, linked to its parent and child scopes.

    The raw fields (``requires``, ``entered``, ``explicit_*``, ``cts``) are
    filled while the AST is tagged; ``analyze`` then derives the bound, free
    and cell variables and stores them in ``analyzed``.
    """
    requires: set  # names read (``Load``) in this scope
    entered: set  # names bound in this scope
    explicit_nonlocals: set  # names declared with ``nonlocal``
    explicit_globals: set  # names declared with ``global``
    parent: Optional['SymTable']
    children: List['SymTable']
    depth: int  # to test if it is global context
    analyzed: Optional[AnalyzedSymTable]
    cts: typing.Union[typing.Set[ContextType], typing.FrozenSet[ContextType]]

    def update(self,
               requires: set = None,
               entered: set = None,
               explicit_nonlocals=None,
               explicit_globals=None,
               parent=None,
               children=None,
               depth=None,
               analyzed=None,
               cts=None):
        """Return a new ``SymTable`` with the given fields replaced.

        An argument left as ``None`` keeps the current value of that field,
        so this method cannot be used to reset a field to ``None``.
        """
        return SymTable(
            requires if requires is not None else self.requires,
            entered if entered is not None else self.entered,
            explicit_nonlocals
            if explicit_nonlocals is not None else self.explicit_nonlocals,
            explicit_globals if explicit_globals is not None else self.explicit_globals,
            parent if parent is not None else self.parent,
            children if children is not None else self.children,
            depth if depth is not None else self.depth,
            analyzed if analyzed is not None else self.analyzed,
            cts if cts is not None else self.cts,
        )

    @staticmethod
    def global_context():
        """Create the root table (depth 0) flagged as ``ContextType.Module``."""
        return SymTable(
            requires=set(),
            entered=set(),
            explicit_globals=set(),
            explicit_nonlocals=set(),
            parent=None,
            children=[],
            depth=0,
            analyzed=None,
            cts={ContextType.Module},
        )

    def enter_new(self):
        """Create an empty child scope, register it in ``children``, return it."""
        # Built directly instead of through ``update``: ``update`` cannot pass
        # ``analyzed=None``, so a child created after this table was analysed
        # would inherit (and share) the parent's analysis result.
        new = SymTable(
            requires=set(),
            entered=set(),
            explicit_globals=set(),
            explicit_nonlocals=set(),
            parent=self,
            children=[],
            depth=self.depth + 1,
            analyzed=None,
            cts=set(),
        )
        self.children.append(new)
        return new

    def can_resolve_by_parents(self, symbol: str):
        """Tell whether ``symbol`` is bound by this scope or one of its ancestors.

        Follows Python's scoping rules for nested scopes:

        * class scopes are skipped, names bound in a class body are not
          visible to the scopes nested inside it;
        * a scope that declares the name ``global`` ends the search, nested
          scopes then see the global as well.

        Every scope on the path must already have ``analyzed`` set.
        """
        scope = self
        while scope is not None:
            if ContextType.ClassDef not in scope.cts:
                if symbol in scope.explicit_globals:
                    return False
                analyzed = scope.analyzed
                # Captured names are moved from ``bounds`` to ``cellvars`` once
                # a scope is fully analysed; they are still bound there.
                if symbol in analyzed.bounds or symbol in analyzed.cellvars:
                    return True
            scope = scope.parent
        return False

    def resolve_bounds(self):
        """Initialise ``analyzed`` and return the names bound locally.

        A name is local when it is bound here and not declared ``global`` or
        ``nonlocal``.
        """
        bounds = self.entered - self.explicit_nonlocals - self.explicit_globals
        self.analyzed = AnalyzedSymTable(bounds, set(), set())
        return bounds

    def resolve_freevars(self):
        """Collect the names this scope itself takes from enclosing scopes.

        These are the ``nonlocal`` names plus every name that is read but not
        bound here and that an enclosing scope can provide. Names declared
        ``global`` are excluded: they must never become free variables even if
        an enclosing function binds a variable of the same name.
        """
        requires = self.requires - self.entered - self.explicit_globals
        freevars = self.analyzed.freevars
        freevars.update(self.explicit_nonlocals)
        freevars.update(
            each for each in requires
            if self.parent.can_resolve_by_parents(each))
        return freevars

    def resolve_cellvars(self):
        """Analyse the children and split captured names into cell/free vars.

        Names needed by this scope or its children that are bound here become
        cell variables (and are removed from ``bounds``); the remaining ones
        are passed on as free variables of this scope.
        """
        analyzed = self.analyzed
        cellvars = analyzed.cellvars
        bounds = analyzed.bounds

        # Analysing a child here also guarantees that all descendants are
        # analysed before this scope's ``bounds`` are reduced below.
        requires_from_sub_contexts = analyzed.freevars.union(
            *(each.analyze().analyzed.freevars for each in self.children))

        # A class body never provides cells for nested scopes; whatever its
        # children need is forwarded to the enclosing scope untouched.
        if ContextType.ClassDef not in self.cts:
            cellvars.update(requires_from_sub_contexts.intersection(bounds))

        borrowed_freevars = requires_from_sub_contexts - cellvars
        bounds.difference_update(cellvars)
        analyzed.freevars.update(borrowed_freevars)
        return cellvars

    def is_global(self):
        """Return ``True`` for the root (module-level) table."""
        return self.depth == 0

    def analyze(self):
        """Analyse this scope and all nested scopes once; return ``self``."""
        if self.analyzed is not None:
            return self

        if self.is_global():
            # global context: nothing is bound, free or captured here
            self.analyzed = AnalyzedSymTable(set(), set(), set())
            for each in self.children:
                each.analyze()
            return self

        # analyze local table; the order of the three steps matters
        self.resolve_bounds()
        self.resolve_freevars()
        self.resolve_cellvars()
        return self

    def show_resolution(self):
        """Pretty-print the analysis results of this scope and its subtree."""

        def show_resolution(this):
            return [this.analyzed, [show_resolution(each) for each in this.children]]

        return pformat(show_resolution(self))
class Tag(ast.AST):
    """AST wrapper node pairing an AST node with the symbol table of its scope."""
    it: ast.AST
    tag: SymTable

    # Only ``it`` is listed as a child field, so generic AST traversal
    # descends into the wrapped node and ignores the symbol table.
    _fields = ('it', )

    def __init__(self, it, tag):
        super().__init__()
        self.tag = tag
        self.it = it
def _visit_name(self, node: ast.Name):
symtable = self.symtable
name = node.id
if isinstance(node.ctx, ast.Store):
symtable.entered.add(name)
elif isinstance(node.ctx, ast.Load):
symtable.requires.add(name)
return node
def _visit_import(self, node: ast.ImportFrom):
for each in node.names:
name = each.asname or each.name
self.symtable.entered.add(name)
return node
def _visit_global(self, node: ast.Global):
self.symtable.explicit_globals.update(node.names)
return node
def _visit_nonlocal(self, node: ast.Nonlocal):
self.symtable.explicit_nonlocals.update(node.names)
return node
def visit_suite(visit_fn, suite: list):
    """Apply ``visit_fn`` to each element of ``suite``, in order.

    Returns a new list holding the results.
    """
    return list(map(visit_fn, suite))
def _visit_cls(self: 'ASTTagger', node: ast.ClassDef):
    """Tag a class definition and open a new scope for its body.

    Bases, keywords and decorators are visited with the current tagger,
    i.e. in the enclosing scope; the body is visited with a fresh tagger
    bound to the new class scope. Returns the class node wrapped in a
    ``Tag`` carrying that new scope's symbol table.
    """
    outer_table = self.symtable

    # Everything in the class header belongs to the enclosing scope.
    visited_bases = visit_suite(self.visit, node.bases)
    visited_keywords = visit_suite(self.visit, node.keywords)
    visited_decorators = visit_suite(self.visit, node.decorator_list)

    # The class name itself is bound in the enclosing scope.
    outer_table.entered.add(node.name)

    class_table = outer_table.enter_new()
    # '__qualname__' is the nested name introduced by pep-3155.
    for implicit_name in ('__module__', '__qualname__'):
        class_table.entered.add(implicit_name)

    body_tagger = ASTTagger(class_table)
    class_table.cts.add(ContextType.ClassDef)
    visited_body = visit_suite(body_tagger.visit, node.body)

    node.bases = visited_bases
    node.keywords = visited_keywords
    node.decorator_list = visited_decorators
    node.body = visited_body
    return Tag(node, class_table)
def _visit_await(self: 'ASTTagger', node: ast.Await):
    """Flag the current scope as a coroutine, then visit the awaited expression."""
    contexts = self.symtable.cts
    already_flagged = ContextType.Coroutine in contexts
    if not already_flagged:
        contexts.add(ContextType.Coroutine)
    return self.generic_visit(node)
def _visit_list_set_gen_comp(self: 'ASTTagger', node: ast.ListComp):
    """Tag a list/set comprehension or generator expression.

    A new scope is opened for the comprehension. Only the iterable of the
    first ``for`` clause is visited in the enclosing scope; the element, all
    targets, conditions and the remaining clauses belong to the new scope.
    Returns the node wrapped in a ``Tag`` with the new symbol table.
    """
    comp_table = self.symtable.enter_new()
    # '.0' is registered as a bound name of the comprehension scope.
    comp_table.entered.add('.0')
    inner = ASTTagger(comp_table)

    node.elt = inner.visit(node.elt)

    first, *rest = node.generators
    first.iter = self.visit(first.iter)
    first.target = inner.visit(first.target)
    if first.ifs:
        first.ifs = list(map(inner.visit, first.ifs))

    for clause in node.generators:
        if clause.is_async:
            comp_table.cts.add(ContextType.Coroutine)
            break

    visited_rest = [inner.visit(clause) for clause in rest]
    node.generators = [first] + visited_rest
    return Tag(node, comp_table)
def _visit_dict_comp(self: 'ASTTagger', node: ast.DictComp):
    """Tag a dict comprehension.

    A new scope is opened for the comprehension. Only the iterable of the
    first ``for`` clause is visited in the enclosing scope; key, value, all
    targets, conditions and the remaining clauses belong to the new scope.
    Returns the node wrapped in a ``Tag`` with the new symbol table.
    """
    comp_table = self.symtable.enter_new()
    # '.0' is registered as a bound name of the comprehension scope.
    comp_table.entered.add('.0')
    inner = ASTTagger(comp_table)

    node.key = inner.visit(node.key)
    node.value = inner.visit(node.value)

    first, *rest = node.generators
    first.iter = self.visit(first.iter)
    first.target = inner.visit(first.target)
    if first.ifs:
        first.ifs = list(map(inner.visit, first.ifs))

    for clause in node.generators:
        if clause.is_async:
            comp_table.cts.add(ContextType.Coroutine)
            break

    visited_rest = [inner.visit(clause) for clause in rest]
    node.generators = [first] + visited_rest
    return Tag(node, comp_table)
def _visit_yield(self: 'ASTTagger', node: ast.Yield):
    """Flag the current scope as a generator and visit the yielded value.

    The children must be visited explicitly: a ``visit_*`` handler that
    returns without calling ``generic_visit`` stops the traversal, so names
    used inside ``yield <expr>`` would never be recorded and nested scopes
    in the expression would never be tagged.
    """
    self.symtable.cts.add(ContextType.Generator)
    return self.generic_visit(node)
def _visit_yield_from(self: 'ASTTagger', node: ast.YieldFrom):
    """Flag the current scope as a generator and visit the delegated expression.

    The children must be visited explicitly: a ``visit_*`` handler that
    returns without calling ``generic_visit`` stops the traversal, so names
    used inside ``yield from <expr>`` would never be recorded and nested
    scopes in the expression would never be tagged.
    """
    self.symtable.cts.add(ContextType.Generator)
    return self.generic_visit(node)
def _visit_ann_assign(self: 'ASTTagger', node: ast.AnnAssign):
    """Flag the current scope as containing annotations and visit the statement.

    The target, annotation and value must be visited explicitly. Otherwise
    the assigned name is never recorded as bound (``x: int = 1`` would not
    make ``x`` local) and names read in the value or annotation are never
    recorded as required.
    """
    self.symtable.cts.add(ContextType.Annotation)
    return self.generic_visit(node)
def _visit_fn_def(self: 'ASTTagger', node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
    """Tag a (possibly async) function definition and open its scope.

    The function name is bound in the enclosing scope. Decorators, default
    values, the return annotation and parameter annotations are visited in
    the enclosing scope and written back to the node, so scope-creating
    expressions among them (e.g. a lambda default) keep their ``Tag``
    wrapper, as is already done for class definitions. Parameters are bound
    in the new scope, in which the body is visited.

    Returns the node wrapped in a ``Tag`` with the new symbol table.
    """
    self.symtable.entered.add(node.name)
    args = node.args

    node.decorator_list = visit_suite(self.visit, node.decorator_list)
    args.defaults = visit_suite(self.visit, args.defaults)
    # ``kw_defaults`` holds ``None`` for a keyword-only parameter without a
    # default; ``None`` is not an AST node and must not be visited.
    args.kw_defaults = [
        None if each is None else self.visit(each)
        for each in args.kw_defaults
    ]

    if node.returns:
        node.returns = self.visit(node.returns)

    new = self.symtable.enter_new()
    if isinstance(node, ast.AsyncFunctionDef):
        new.cts.add(ContextType.Coroutine)

    # ``posonlyargs`` only exists on Python 3.8+, hence the ``getattr``.
    arguments = [*getattr(args, 'posonlyargs', ()), *args.args, *args.kwonlyargs]
    if args.vararg:
        arguments.append(args.vararg)
    if args.kwarg:
        arguments.append(args.kwarg)

    for arg in arguments:
        if arg.annotation:
            # annotations are evaluated in the enclosing scope
            arg.annotation = self.visit(arg.annotation)
        new.entered.add(arg.arg)

    new_tagger = ASTTagger(new)
    node.body = [new_tagger.visit(each) for each in node.body]
    return Tag(node, new)
def _visit_lam(self: 'ASTTagger', node: ast.Lambda):
    """Tag a lambda expression and open its scope.

    Default values are visited in the enclosing scope, so names they read
    are recorded there and nested scopes inside them are tagged. Parameters
    are bound in the new scope, in which the body expression is visited.

    Returns the node wrapped in a ``Tag`` with the new symbol table.
    """
    args = node.args

    args.defaults = visit_suite(self.visit, args.defaults)
    # ``kw_defaults`` holds ``None`` for a keyword-only parameter without a
    # default; ``None`` is not an AST node and must not be visited.
    args.kw_defaults = [
        None if each is None else self.visit(each)
        for each in args.kw_defaults
    ]

    new = self.symtable.enter_new()

    # ``posonlyargs`` only exists on Python 3.8+, hence the ``getattr``.
    arguments = [*getattr(args, 'posonlyargs', ()), *args.args, *args.kwonlyargs]
    if args.vararg:
        arguments.append(args.vararg)
    if args.kwarg:
        arguments.append(args.kwarg)

    for arg in arguments:
        # lambda parameters cannot be annotated today; kept in case the
        # grammar allows it in the future.
        if arg.annotation:
            arg.annotation = self.visit(arg.annotation)
        new.entered.add(arg.arg)

    new_tagger = ASTTagger(new)
    node.body = new_tagger.visit(node.body)
    return Tag(node, new)
class ASTTagger(ast.NodeTransformer):
    """Node transformer that fills one symbol table while walking an AST.

    Each instance is bound to the symbol table of a single scope. The
    handlers below are the module-level ``_visit_*`` functions; those that
    open a new scope create a new tagger for it and wrap the node in ``Tag``.
    """

    def __init__(self, symtable: SymTable):
        self.symtable = symtable

    # names and name declarations
    visit_Name = _visit_name
    visit_Import = _visit_import
    visit_ImportFrom = _visit_import
    visit_Global = _visit_global
    visit_Nonlocal = _visit_nonlocal

    # nodes that open a new scope
    visit_FunctionDef = _visit_fn_def
    visit_AsyncFunctionDef = _visit_fn_def
    visit_Lambda = _visit_lam
    visit_ListComp = _visit_list_set_gen_comp
    visit_SetComp = _visit_list_set_gen_comp
    visit_GeneratorExp = _visit_list_set_gen_comp
    visit_DictComp = _visit_dict_comp

    # nodes that flag the current scope
    visit_Yield = _visit_yield
    visit_YieldFrom = _visit_yield_from
    visit_AnnAssign = _visit_ann_assign

    visit_ClassDef = _visit_cls
    visit_Await = _visit_await
def to_tagged_ast(node: ast.Module):
    """Tag a module AST with symbol tables and run the scope analysis.

    Returns the transformed module wrapped in a ``Tag`` that carries the
    global symbol table.
    """
    module_table = SymTable.global_context()
    # ``visit`` is a proxy that dispatches to the specific ``visit_*`` method.
    tagged_module = ASTTagger(module_table).visit(node)
    root = Tag(tagged_module, module_table)
    module_table.analyze()
    return root
if __name__ == '__main__':
    # Small manual demo: analyse a class whose method contains a lambda that
    # captures ``self``, then print the resolved scopes.
    # The sample source must keep its own indentation, otherwise
    # ``ast.parse`` rejects it.
    mod = ("""
class h():
    def docmodule(self, object, name=None, mod=None, *ignored):
        lambda t: self.modulelink(t[1])
""")
    print(mod)
    mod = ast.parse(mod)

    g = SymTable.global_context()
    ASTTagger(g).visit(mod)
    g.analyze()
    print(g.show_resolution())