Skip to content

Commit f56e761

Browse files
authored
Merge pull request #171 from bcaller/chains
Chained function calls separated into multiple assignments
2 parents 11567c4 + 2e91ce7 commit f56e761

File tree

6 files changed

+108
-7
lines changed

6 files changed

+108
-7
lines changed

Diff for: pyt/core/ast_helper.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import subprocess
77
from functools import lru_cache
88

9-
from .transformer import AsyncTransformer
9+
from .transformer import PytTransformer
1010

1111

1212
BLACK_LISTED_CALL_NAMES = ['self']
@@ -35,7 +35,7 @@ def generate_ast(path):
3535
with open(path, 'r') as f:
3636
try:
3737
tree = ast.parse(f.read())
38-
return AsyncTransformer().visit(tree)
38+
return PytTransformer().visit(tree)
3939
except SyntaxError: # pragma: no cover
4040
global recursive
4141
if not recursive:

Diff for: pyt/core/transformer.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import ast
22

33

4-
class AsyncTransformer(ast.NodeTransformer):
4+
class AsyncTransformer():
55
"""Converts all async nodes into their synchronous counterparts."""
66

77
def visit_Await(self, node):
@@ -16,3 +16,53 @@ def visit_AsyncFor(self, node):
1616

1717
def visit_AsyncWith(self, node):
1818
return self.visit(ast.With(**node.__dict__))
19+
20+
21+
class ChainedFunctionTransformer():
22+
def visit_chain(self, node, depth=1):
23+
if (
24+
isinstance(node.value, ast.Call) and
25+
isinstance(node.value.func, ast.Attribute) and
26+
isinstance(node.value.func.value, ast.Call)
27+
):
28+
# Node is assignment or return with value like `b.c().d()`
29+
call_node = node.value
30+
# If we want to handle nested functions in future, depth needs fixing
31+
temp_var_id = '__chain_tmp_{}'.format(depth)
32+
# AST tree is from right to left, so d() is the outer Call and b.c() is the inner Call
33+
unvisited_inner_call = ast.Assign(
34+
targets=[ast.Name(id=temp_var_id, ctx=ast.Store())],
35+
value=call_node.func.value,
36+
)
37+
ast.copy_location(unvisited_inner_call, node)
38+
inner_calls = self.visit_chain(unvisited_inner_call, depth + 1)
39+
for inner_call_node in inner_calls:
40+
ast.copy_location(inner_call_node, node)
41+
outer_call = self.generic_visit(type(node)(
42+
value=ast.Call(
43+
func=ast.Attribute(
44+
value=ast.Name(id=temp_var_id, ctx=ast.Load()),
45+
attr=call_node.func.attr,
46+
ctx=ast.Load(),
47+
),
48+
args=call_node.args,
49+
keywords=call_node.keywords,
50+
),
51+
**{field: value for field, value in ast.iter_fields(node) if field != 'value'} # e.g. targets
52+
))
53+
ast.copy_location(outer_call, node)
54+
ast.copy_location(outer_call.value, node)
55+
ast.copy_location(outer_call.value.func, node)
56+
return [*inner_calls, outer_call]
57+
else:
58+
return [self.generic_visit(node)]
59+
60+
def visit_Assign(self, node):
61+
return self.visit_chain(node)
62+
63+
def visit_Return(self, node):
64+
return self.visit_chain(node)
65+
66+
67+
class PytTransformer(AsyncTransformer, ChainedFunctionTransformer, ast.NodeTransformer):
68+
pass

Diff for: tests/base_test_case.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pyt.cfg import make_cfg
55
from pyt.core.ast_helper import generate_ast
66
from pyt.core.module_definitions import project_definitions
7+
from pyt.core.transformer import PytTransformer
78

89

910
class BaseTestCase(unittest.TestCase):
@@ -36,7 +37,7 @@ def cfg_create_from_ast(
3637
):
3738
project_definitions.clear()
3839
self.cfg = make_cfg(
39-
ast_tree,
40+
PytTransformer().visit(ast_tree),
4041
project_modules,
4142
local_modules,
4243
filename='?'

Diff for: tests/cfg/cfg_test.py

+32
Original file line numberDiff line numberDiff line change
@@ -1497,3 +1497,35 @@ def test_name_for(self):
14971497

14981498
self.assert_length(self.cfg.nodes, expected_length=4)
14991499
self.assertEqual(self.cfg.nodes[1].label, 'for x in l:')
1500+
1501+
1502+
class CFGFunctionChain(CFGBaseTestCase):
1503+
def test_simple(self):
1504+
self.cfg_create_from_ast(
1505+
ast.parse('a = b.c(z)')
1506+
)
1507+
middle_nodes = self.cfg.nodes[1:-1]
1508+
self.assert_length(middle_nodes, expected_length=2)
1509+
self.assertEqual(middle_nodes[0].label, '~call_1 = ret_b.c(z)')
1510+
self.assertEqual(middle_nodes[0].func_name, 'b.c')
1511+
self.assertCountEqual(middle_nodes[0].right_hand_side_variables, ['z', 'b'])
1512+
1513+
def test_chain(self):
1514+
self.cfg_create_from_ast(
1515+
ast.parse('a = b.xxx.c(z).d(y)')
1516+
)
1517+
middle_nodes = self.cfg.nodes[1:-1]
1518+
self.assert_length(middle_nodes, expected_length=4)
1519+
1520+
self.assertEqual(middle_nodes[0].left_hand_side, '~call_1')
1521+
self.assertCountEqual(middle_nodes[0].right_hand_side_variables, ['b', 'z'])
1522+
self.assertEqual(middle_nodes[0].label, '~call_1 = ret_b.xxx.c(z)')
1523+
1524+
self.assertEqual(middle_nodes[1].left_hand_side, '__chain_tmp_1')
1525+
self.assertCountEqual(middle_nodes[1].right_hand_side_variables, ['~call_1'])
1526+
1527+
self.assertEqual(middle_nodes[2].left_hand_side, '~call_2')
1528+
self.assertCountEqual(middle_nodes[2].right_hand_side_variables, ['__chain_tmp_1', 'y'])
1529+
1530+
self.assertEqual(middle_nodes[3].left_hand_side, 'a')
1531+
self.assertCountEqual(middle_nodes[3].right_hand_side_variables, ['~call_2'])

Diff for: tests/core/transformer_test.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import ast
22
import unittest
33

4-
from pyt.core.transformer import AsyncTransformer
4+
from pyt.core.transformer import PytTransformer
55

66

77
class TransformerTest(unittest.TestCase):
88
"""Tests for the AsyncTransformer."""
99

1010
def test_async_removed_by_transformer(self):
11+
self.maxDiff = 99999
1112
async_tree = ast.parse("\n".join([
1213
"async def a():",
1314
" async for b in c():",
@@ -30,7 +31,24 @@ def test_async_removed_by_transformer(self):
3031
]))
3132
self.assertIsInstance(sync_tree.body[0], ast.FunctionDef)
3233

33-
transformed = AsyncTransformer().visit(async_tree)
34+
transformed = PytTransformer().visit(async_tree)
3435
self.assertIsInstance(transformed.body[0], ast.FunctionDef)
3536

3637
self.assertEqual(ast.dump(transformed), ast.dump(sync_tree))
38+
39+
def test_chained_function(self):
40+
chained_tree = ast.parse("\n".join([
41+
"def a():",
42+
" b = c.d(e).f(g).h(i).j(k)",
43+
]))
44+
45+
separated_tree = ast.parse("\n".join([
46+
"def a():",
47+
" __chain_tmp_3 = c.d(e)",
48+
" __chain_tmp_2 = __chain_tmp_3.f(g)",
49+
" __chain_tmp_1 = __chain_tmp_2.h(i)",
50+
" b = __chain_tmp_1.j(k)",
51+
]))
52+
53+
transformed = PytTransformer().visit(chained_tree)
54+
self.assertEqual(ast.dump(transformed), ast.dump(separated_tree))

Diff for: tests/vulnerabilities/vulnerabilities_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def test_path_traversal_sanitised_2_result(self):
282282

283283
def test_sql_result(self):
284284
vulnerabilities = self.run_analysis('examples/vulnerable_code/sql/sqli.py')
285-
self.assert_length(vulnerabilities, expected_length=2)
285+
self.assert_length(vulnerabilities, expected_length=3)
286286
vulnerability_description = str(vulnerabilities[0])
287287
EXPECTED_VULNERABILITY_DESCRIPTION = """
288288
File: examples/vulnerable_code/sql/sqli.py

0 commit comments

Comments
 (0)