Skip to content

Commit 419647b

Browse files
committed
fix alias generation / units
1 parent 4376708 commit 419647b

2 files changed

Lines changed: 114 additions & 8 deletions

File tree

src/strip_python3.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -784,7 +784,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> Optional[ast.AST]: # pylint
784784
self.importfrom[modulename][symbol.name] = symbol.asname or symbol.name
785785
origname = modulename + "." + symbol.name
786786
codename = symbol.name if not symbol.asname else symbol.asname
787-
stmt = ast.ImportFrom(imports.module, [ast.alias(symbol.name, symbol.asname)], imports.level)
787+
stmt = ast.ImportFrom(imports.module, [ast.alias(symbol.name, symbol.asname if symbol.asname != symbol.name else None)], imports.level)
788788
self.imported[origname] = stmt
789789
self.asimport[codename] = origname
790790
return self.generic_visit(node)
@@ -793,7 +793,7 @@ def visit_Import(self, node: ast.Import) -> Optional[ast.AST]: # pylint: disabl
793793
for symbol in imports.names:
794794
origname = symbol.name
795795
codename = symbol.name if not symbol.asname else symbol.asname
796-
stmt = ast.Import([ast.alias(symbol.name, symbol.asname)])
796+
stmt = ast.Import([ast.alias(symbol.name, symbol.asname if symbol.asname != symbol.name else None)])
797797
self.imported[origname] = stmt
798798
self.asimport[codename] = origname
799799
return self.generic_visit(node)
@@ -967,13 +967,12 @@ def visit(self, node: ast.AST) -> ast.AST:
967967
body.append(stmt)
968968
else:
969969
if simple:
970-
body.append(ast.Import([ast.alias(mod, simple[mod]) for mod in sorted(simple)]))
970+
body.append(ast.Import([ast.alias(mod, simple[mod] if simple[mod] != mod else None) for mod in sorted(simple)]))
971971
for mod in sorted(dotted):
972972
alias = dotted[mod]
973973
if alias and "." in mod:
974974
libname, sym = mod.rsplit(".", 1)
975-
renamed = alias if sym != alias else None
976-
body.append(ast.ImportFrom(libname, [ast.alias(sym, renamed)], 0))
975+
body.append(ast.ImportFrom(libname, [ast.alias(sym, alias if alias != sym else None)], 0))
977976
else:
978977
body.append(ast.Import([ast.alias(mod, alias)]))
979978
body.append(stmt)
@@ -989,13 +988,12 @@ def visit(self, node: ast.AST) -> ast.AST:
989988
body.append(stmt)
990989
else:
991990
if simple:
992-
body.append(ast.Import([ast.alias(mod, simple[mod]) for mod in sorted(simple)]))
991+
body.append(ast.Import([ast.alias(mod, simple[mod] if simple[mod] != mod else None) for mod in sorted(simple)]))
993992
for mod in sorted(dotted):
994993
alias = dotted[mod]
995994
if alias and "." in mod:
996995
libname, sym = mod.rsplit(".", 1)
997-
renamed = alias if sym != alias else None
998-
body.append(ast.ImportFrom(libname, [ast.alias(sym, renamed)], 0))
996+
body.append(ast.ImportFrom(libname, [ast.alias(sym, alias if alias != sym else None)], 0))
999997
else:
1000998
body.append(ast.Import([ast.alias(mod, alias)]))
1001999
body.append(stmt)

tests/unittests.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
__author__ = "Guido U. Draheim"
88
__version__ = "1.1.1127"
99

10+
from typing import cast
1011
import sys
1112
import unittest
1213
import logging
1314
import os.path
1415
from fnmatch import fnmatchcase as fnmatch
16+
import ast
1517

1618
logg = logging.getLogger(os.path.basename(__file__))
1719

@@ -90,6 +92,112 @@ def test_1112(self) -> None:
9092
self.assertEqual(c, "a\nb\n")
9193
self.assertEqual(d, "a\n b\n")
9294
self.assertEqual(e, "a\n b\n")
95+
def test_1201(self) -> None:
96+
other = ast.Constant(1) # unknown element
97+
have: ast.Module = app.pyi_module([other])
98+
have0 = cast(ast.Constant, have.body[0])
99+
self.assertEqual(other.value, have0.value)
100+
def test_1210(self) -> None:
101+
pyi = ast.parse(app.text4("""
102+
def foo(a: A) -> B:
103+
pass
104+
"""))
105+
py1 = ast.parse(app.text4("""
106+
from x import A
107+
"""))
108+
py2 = ast.parse(app.text4("""
109+
from y import B
110+
"""))
111+
want = app.text4("""
112+
from x import A
113+
from y import B
114+
115+
def foo(a: A) -> B:
116+
pass""")
117+
pyi2 = app.pyi_copy_imports(pyi, py1, py2)
118+
have = ast.unparse(pyi2) + "\n"
119+
self.assertEqual(want, have)
120+
def test_1211(self) -> None:
121+
pyi = ast.parse(app.text4("""
122+
def foo(a: A) -> B:
123+
pass
124+
"""))
125+
py1 = ast.parse(app.text4("""
126+
from x.z import A
127+
"""))
128+
py2 = ast.parse(app.text4("""
129+
from y.z import B
130+
"""))
131+
want = app.text4("""
132+
from x.z import A
133+
from y.z import B
134+
135+
def foo(a: A) -> B:
136+
pass""")
137+
pyi2 = app.pyi_copy_imports(pyi, py1, py2)
138+
have = ast.unparse(pyi2) + "\n"
139+
self.assertEqual(want, have)
140+
def test_1212(self) -> None:
141+
pyi = ast.parse(app.text4("""
142+
def foo(a: x.A) -> y.B:
143+
pass
144+
"""))
145+
py1 = ast.parse(app.text4("""
146+
import x
147+
"""))
148+
py2 = ast.parse(app.text4("""
149+
import y
150+
"""))
151+
want = app.text4("""
152+
import x, y
153+
154+
def foo(a: x.A) -> y.B:
155+
pass""")
156+
pyi2 = app.pyi_copy_imports(pyi, py1, py2)
157+
have = ast.unparse(pyi2) + "\n"
158+
self.assertEqual(want, have)
159+
def test_1213(self) -> None:
160+
pyi = ast.parse(app.text4("""
161+
def foo(a: x.A) -> y.B:
162+
pass
163+
"""))
164+
py1 = ast.parse(app.text4("""
165+
import app1.x as x
166+
"""))
167+
py2 = ast.parse(app.text4("""
168+
import app2.y as y
169+
"""))
170+
want = app.text4("""
171+
from app1 import x
172+
from app2 import y
173+
174+
def foo(a: x.A) -> y.B:
175+
pass""")
176+
pyi2 = app.pyi_copy_imports(pyi, py1, py2)
177+
have = ast.unparse(pyi2) + "\n"
178+
self.assertEqual(want, have)
179+
@unittest.expectedFailure
180+
def test_1215(self) -> None:
181+
pyi = ast.parse(app.text4("""
182+
def foo(a: app1.x.A) -> app2.y.B:
183+
pass
184+
"""))
185+
py1 = ast.parse(app.text4("""
186+
import app1.x
187+
"""))
188+
py2 = ast.parse(app.text4("""
189+
import app2.y
190+
"""))
191+
want = app.text4("""
192+
import app1.x
193+
import app2.y
194+
195+
def foo(a: app1.x.A) -> app2.y.B:
196+
pass""")
197+
pyi2 = app.pyi_copy_imports(pyi, py1, py2)
198+
have = ast.unparse(pyi2) + "\n"
199+
self.assertEqual(want, have)
200+
93201

94202
if __name__ == "__main__":
95203
# unittest.main()

0 commit comments

Comments
 (0)