Skip to content

Commit 7da8248

Browse files
committed
test enum class
1 parent 9c88c45 commit 7da8248

3 files changed

Lines changed: 44 additions & 16 deletions

File tree

tests/exectests3.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,18 @@ def q_str(part: Union[str, int, None]) -> str:
6060
if isinstance(part, int):
6161
return str(part)
6262
return "'%s'" % part # pylint: disable=consider-using-f-string
63-
def sh____(cmd: Union[str, List[str]], shell: bool = True) -> int:
63+
def sh____(cmd: Union[str, List[str]], shell: bool = True, env=None) -> int:
6464
if isinstance(cmd, string_types):
6565
logg.info(": %s", cmd)
6666
else:
6767
logg.info(": %s", " ".join([q_str(item) for item in cmd]))
68-
return subprocess.check_call(cmd, shell=shell)
69-
def sx____(cmd: Union[str, List[str]], shell: bool = True) -> int:
68+
return subprocess.check_call(cmd, shell=shell, env=env)
69+
def sx____(cmd: Union[str, List[str]], shell: bool = True, env=None) -> int:
7070
if isinstance(cmd, string_types):
7171
logg.info(": %s", cmd)
7272
else:
7373
logg.info(": %s", " ".join([q_str(item) for item in cmd]))
74-
return subprocess.call(cmd, shell=shell)
74+
return subprocess.call(cmd, shell=shell, env=env)
7575

7676
class CalledProcessError(subprocess.SubprocessError):
7777
def __init__(self, args: Union[str, List[str]], returncode: int = 0, stdout: Union[str,bytes] = NIX, stderr: Union[str,bytes] = NIX) -> None:
@@ -889,6 +889,36 @@ class X(TypedDict):
889889
self.assertEqual(x2.out, "Success: no issues found in 1 source file")
890890
self.rm_testdir()
891891
self.end()
892+
def test_3541(self) -> None:
893+
""" check Enum classes are replaced by local def"""
894+
vv = self.begin()
895+
python = PYTHON
896+
tmp = self.testdir()
897+
text_file(F"{tmp}/test3.py", """
898+
from enum import Enum
899+
class A(Enum):
900+
B = 2
901+
C = 3
902+
def func1() -> A:
903+
return A(3)
904+
f = func1()
905+
if f is A.C:
906+
print("OK:", f)
907+
else:
908+
print("NO:", f)
909+
""")
910+
sh____(F"{PYTHON3} {STRIP} -3 {tmp}/test3.py {vv}", env={"PYTHON3_ENUM_CLASS_ATLEAST":"3.99"})
911+
self.assertTrue(os.path.exists(F"{tmp}/test.py"))
912+
self.assertTrue(os.path.exists(F"{tmp}/test.pyi"))
913+
script = lines4(open(F"{tmp}/test.py").read())
914+
logg.info("script = %s", script)
915+
self.assertTrue(greps(script, "(3, 99)"))
916+
self.assertTrue(greps(script, "class Enum"))
917+
x1 = X(F"{python} {tmp}/test.py")
918+
logg.info("%s -> %s\n%s", x1.args, x1.out, x1.err)
919+
self.assertTrue(greps(x1.out, "OK: <C: 3>"))
920+
self.rm_testdir()
921+
self.end()
892922

893923

894924
if __name__ == "__main__":

tests/transformertests2.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4522,6 +4522,7 @@ def func1() -> int:
45224522
self.coverage()
45234523
self.rm_testdir()
45244524
def test_2541(self) -> None:
4525+
""" check Enum classes are replaced by local def"""
45254526
vv = self.begin()
45264527
strip = coverage(STRIP)
45274528
tmp = self.testdir()
@@ -4541,6 +4542,7 @@ def func1() -> int:
45414542
py, pyi = file_text4(F"{tmp}/test.py"), file_text4(F"{tmp}/test.pyi")
45424543
logg.debug("py:\n%s", py)
45434544
self.assertEqual(lines4(py), lines4(text4("""
4545+
import sys
45444546
if sys.version_info >= (3, 3):
45454547
from enum import Enum
45464548
else:

tool/strip_python3.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,8 +1191,7 @@ def visit(self, node: ast.AST) -> ast.AST:
11911191

11921192
class EnumClassTransformer(DetectImportsTransformer):
11931193
typedefs: List[ast.stmt]
1194-
requiresfrom: Set[str]
1195-
only: Set[str]
1194+
requires: List[str]
11961195
_Enum = """class Enum:
11971196
11981197
def __new__(cls, *values):
@@ -1215,21 +1214,15 @@ def __iter__(self):
12151214
elem = getattr(cls, name)
12161215
if isinstance(elem, Enum):
12171216
yield elem
1217+
12181218
def __str__(self):
12191219
return "<%s: %s>" % (self.name, self.value)
12201220
"""
12211221
def __init__(self) -> None:
12221222
DetectImportsTransformer.__init__(self)
1223-
self.only = set()
1224-
def visit(self, node: ast.AST) -> ast.AST:
1225-
if isinstance(node, ast.Module):
1226-
module = cast(ast.Module, node) # type: ignore[redundant-cast]
1227-
for stmt in module.body:
1228-
if isinstance(stmt, ast.ClassDef):
1229-
self.only.add(stmt.name) # only top-level class names
1230-
return cast(ast.AST, DetectImportsTransformer.visit(self, node))
1223+
self.requires = []
12311224
def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST: # pylint: disable=invalid-name
1232-
atleast = (3, 3)
1225+
atleast = [int(val) for val in os.environ.get("PYTHON3_ENUM_CLASS_ATLEAST", "3.3").split(".")]
12331226
imports: ast.ImportFrom = node
12341227
if imports.module and imports.module == "enum":
12351228
orelse: List[ast.stmt] = []
@@ -1245,10 +1238,12 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST: # pylint: disable=i
12451238
testcompare = testbody.value
12461239
python2 = ast.If(test=testcompare, body=[imports], orelse=orelse)
12471240
python2 = copy_location(python2, imports)
1241+
if "sys" not in self.requires:
1242+
self.requires += ["sys"]
12481243
return python2
12491244
return node
12501245
def visit_ClassDef(self, node: ast.ClassDef) -> ast.AST: # pylint: disable=invalid-name
1251-
atleast = (3, 3)
1246+
atleast = [int(val) for val in os.environ.get("PYTHON3_ENUM_CLASS_ATLEAST", "3.3").split(".")]
12521247
classname = node.name
12531248
for base in node.bases:
12541249
if isinstance(base, ast.Name):
@@ -2528,6 +2523,7 @@ def visit(self, tree: ast.AST) -> ast.AST:
25282523
if want.replace_enum_class:
25292524
typedenum = EnumClassTransformer()
25302525
tree = typedenum.visit(tree)
2526+
importrequires.append(typedenum.requires)
25312527
extracted = ExtractTypeHints()
25322528
tree = extracted.visit(tree)
25332529
self.typedefs.extend(extracted.typedefs)

0 commit comments

Comments
 (0)