diff --git a/.gitignore b/.gitignore index df191195..0ec272f6 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ vulture.egg-info/ .pytest_cache/ .tox/ .venv/ +.vscode/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 42d9d406..8e7a0b4e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ * Add `UnicodeEncodeError` exception handling to `core.py` (milanbalazs, #299). * Add whitelist for `Enum` attributes `_name_` and `_value_` (Eugene Toder, #305). +* Fix false positive when iterating over `Enum` (anudaweerasinghe, pm3512, addykan, #304) # 2.7 (2023-01-08) @@ -301,4 +302,4 @@ # 0.1 (2012-03-17) -* First release. +* First release. \ No newline at end of file diff --git a/tests/test_scavenging.py b/tests/test_scavenging.py index 27154cd6..6a287547 100644 --- a/tests/test_scavenging.py +++ b/tests/test_scavenging.py @@ -899,3 +899,38 @@ class Color(Enum): check(v.unused_classes, []) check(v.unused_vars, ["BLUE"]) + + +def test_enum_list(v): + v.scan( + """\ +import enum +class E(enum.Enum): + A = 1 + B = 2 + +print(list(E)) +""" + ) + + check(v.defined_classes, ["E"]) + check(v.defined_vars, ["A", "B"]) + check(v.unused_vars, []) + + +def test_enum_for(v): + v.scan( + """\ +import enum +class E(enum.Enum): + A = 1 + B = 2 + +for e in E: + print(e) +""" + ) + + check(v.defined_classes, ["E"]) + check(v.defined_vars, ["A", "B", "e"]) + check(v.unused_vars, []) diff --git a/vulture/core.py b/vulture/core.py index 16c71948..523a3dbc 100644 --- a/vulture/core.py +++ b/vulture/core.py @@ -219,6 +219,10 @@ def get_list(typ): self.code = [] self.found_dead_code_or_error = False + self.enum_class_vars = ( + dict() + ) # stores variables defined in enum classes + def scan(self, code, filename=""): filename = Path(filename) self.code = code.splitlines() @@ -551,6 +555,18 @@ def visit_Call(self, node): ): self._handle_new_format_string(node.func.value.s) + # handle enum.Enum members + iter_functions = ["list", "tuple", "set"] + if ( + isinstance(node.func, ast.Name) + and node.func.id in iter_functions + and len(node.args) > 0 + and isinstance(node.args[0], ast.Name) + ): + arg = node.args[0].id + if arg in self.enum_class_vars: + self.used_names.update(self.enum_class_vars[arg]) + def _handle_new_format_string(self, s): def is_identifier(name): return bool(re.match(r"[a-zA-Z_][a-zA-Z0-9_]*", name)) @@ -581,6 +597,20 @@ def _is_locals_call(node): and not node.keywords ) + @staticmethod + def _is_subclass(node, class_name): + """Return True if the node is a subclass of the given class.""" + assert isinstance(node, ast.ClassDef) + for superclass in node.bases: + if ( + isinstance(superclass, ast.Name) + and superclass.id == class_name + or isinstance(superclass, ast.Attribute) + and superclass.attr == class_name + ): + return True + return False + def visit_ClassDef(self, node): for decorator in node.decorator_list: if _match( @@ -594,6 +624,15 @@ def visit_ClassDef(self, node): self._define( self.defined_classes, node.name, node, ignore=_ignore_class ) + # if subclasses enum add class variables to enum_class_vars + if self._is_subclass(node, "Enum"): + newKey = node.name + classVariables = [] + for stmt in node.body: + if isinstance(stmt, ast.Assign): + for target in stmt.targets: + classVariables.append(target.id) + self.enum_class_vars[newKey] = classVariables def visit_FunctionDef(self, node): decorator_names = [ @@ -661,6 +700,14 @@ def visit_Assign(self, node): def visit_While(self, node): self._handle_conditional_node(node, "while") + def visit_For(self, node): + # Handle iterating over Enum + if ( + isinstance(node.iter, ast.Name) + and node.iter.id in self.enum_class_vars + ): + self.used_names.update(self.enum_class_vars[node.iter.id]) + def visit_MatchClass(self, node): for kwd_attr in node.kwd_attrs: self.used_names.add(kwd_attr)