diff --git a/tests/test_ast.py b/tests/test_ast.py index 388b3ae..b90c9d5 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -260,7 +260,7 @@ def bar(): """ ast.build_ast(src) functiondef_node = ast.get_internal_function_nodes()[0] - fn_ast = AST.create_new_instance(functiondef_node) + fn_ast = AST.from_node(functiondef_node) references = fn_ast.find_nodes_referencing_symbol("x") assert len(references) == 1 assert ( diff --git a/tests/test_navigation.py b/tests/test_navigation.py index a7ffd9b..0cc8b37 100644 --- a/tests/test_navigation.py +++ b/tests/test_navigation.py @@ -8,16 +8,21 @@ @pytest.fixture -def doc(): - doc = Document(uri="examples/Foo.vy") +def ast(): ast = AST() + return ast + + +@pytest.fixture +def doc(ast): + doc = Document(uri="examples/Foo.vy") ast.build_ast(doc.source) return doc @pytest.fixture -def navigator(): - return ASTNavigator() +def navigator(ast): + return ASTNavigator(ast) def test_find_references_event_name(doc, navigator): diff --git a/vyper_lsp/analyzer/AstAnalyzer.py b/vyper_lsp/analyzer/AstAnalyzer.py index 819f949..5efe763 100644 --- a/vyper_lsp/analyzer/AstAnalyzer.py +++ b/vyper_lsp/analyzer/AstAnalyzer.py @@ -7,7 +7,6 @@ from vyper.compiler import CompilerData from vyper.exceptions import VyperException from vyper_lsp.analyzer.BaseAnalyzer import Analyzer -from vyper_lsp.ast import AST from vyper_lsp.utils import ( get_expression_at_cursor, get_word_at_cursor, @@ -39,9 +38,9 @@ class AstAnalyzer(Analyzer): - def __init__(self, ast=None) -> None: + def __init__(self, ast) -> None: super().__init__() - self.ast = ast or AST() + self.ast = ast if get_installed_vyper_version() < min_vyper_version: self.diagnostics_enabled = False else: diff --git a/vyper_lsp/ast.py b/vyper_lsp/ast.py index 5ff0c04..c7eb6d6 100644 --- a/vyper_lsp/ast.py +++ b/vyper_lsp/ast.py @@ -5,22 +5,21 @@ from vyper.ast import VyperNode, nodes from vyper.compiler import CompilerData -ast = None - class AST: - _instance = None ast_data = None ast_data_folded = None ast_data_unfolded = None custom_type_node_types = (nodes.StructDef, nodes.EnumDef, nodes.EventDef) - def __new__(cls): - if cls._instance is None: - cls._instance = super(AST, cls).__new__(cls) - cls._instance.ast_data = None - return cls._instance + @classmethod + def from_node(cls, node: VyperNode): + ast = cls() + ast.ast_data = node + ast.ast_data_unfolded = node + ast.ast_data_folded = node + return ast def update_ast(self, document): self.build_ast(document.source) @@ -45,47 +44,38 @@ def build_ast(self, src: str): print(f"Error generating folded AST, {e}") pass - def get_descendants_from_best_ast(self, *args, **kwargs): + @property + def best_ast(self): if self.ast_data_unfolded: - return self.ast_data_unfolded.get_descendants(*args, **kwargs) + return self.ast_data_unfolded elif self.ast_data: - return self.ast_data.get_descendants(*args, **kwargs) + return self.ast_data elif self.ast_data_folded: - return self.ast_data_folded.get_descendants(*args, **kwargs) + return self.ast_data_folded else: + return None + + def get_descendants(self, *args, **kwargs): + if self.best_ast is None: return [] + return self.best_ast.get_descendants(*args, **kwargs) - def get_children_from_best_ast(self, *args, **kwargs): - if self.ast_data_unfolded: - return self.ast_data_unfolded.get_children(*args, **kwargs) - elif self.ast_data: - return self.ast_data.get_children(*args, **kwargs) - elif self.ast_data_folded: - return self.ast_data_folded.get_children(*args, **kwargs) - else: + def get_top_level_nodes(self, *args, **kwargs): + if self.best_ast is None: return [] + return self.best_ast.get_children(*args, **kwargs) def get_enums(self) -> List[str]: - return [node.name for node in self.get_descendants_from_best_ast(nodes.EnumDef)] + return [node.name for node in self.get_descendants(nodes.EnumDef)] def get_structs(self) -> List[str]: - if self.ast_data_unfolded is None: - return [] - - return [ - node.name for node in self.get_descendants_from_best_ast(nodes.StructDef) - ] + return [node.name for node in self.get_descendants(nodes.StructDef)] def get_events(self) -> List[str]: - return [ - node.name for node in self.get_descendants_from_best_ast(nodes.EventDef) - ] + return [node.name for node in self.get_descendants(nodes.EventDef)] def get_user_defined_types(self): - return [ - node.name - for node in self.get_descendants_from_best_ast(self.custom_type_node_types) - ] + return [node.name for node in self.get_descendants(self.custom_type_node_types)] def get_constants(self): # NOTE: Constants should be fetched from self.ast_data, they are missing @@ -123,38 +113,31 @@ def get_state_variables(self): ] def get_internal_function_nodes(self): - function_nodes = self.get_descendants_from_best_ast(nodes.FunctionDef) - inernal_nodes = [] + function_nodes = self.get_descendants(nodes.FunctionDef) + internal_nodes = [] for node in function_nodes: for decorator in node.decorator_list: if decorator.id == "internal": - inernal_nodes.append(node) + internal_nodes.append(node) - return inernal_nodes + return internal_nodes def get_internal_functions(self): return [node.name for node in self.get_internal_function_nodes()] def find_nodes_referencing_internal_function(self, function: str): - return self.get_descendants_from_best_ast( + return self.get_descendants( nodes.Call, {"func.attr": function, "func.value.id": "self"} ) def find_nodes_referencing_state_variable(self, variable: str): - return self.get_descendants_from_best_ast( + return self.get_descendants( nodes.Attribute, {"value.id": "self", "attr": variable} ) def find_nodes_referencing_constant(self, constant: str): - # NOTE: Constants should be fetched from self.ast_data, they are missing - # from self.ast_data_unfolded and self.ast_data_folded - if self.ast_data_unfolded is None: - return [] - - name_nodes = self.ast_data_unfolded.get_descendants( - nodes.Name, {"id": constant} - ) + name_nodes = self.get_descendants(nodes.Name, {"id": constant}) return [ node for node in name_nodes @@ -174,7 +157,7 @@ def get_attributes_for_symbol(self, symbol: str): return [] def find_function_declaration_node_for_name(self, function: str): - for node in self.get_descendants_from_best_ast(nodes.FunctionDef): + for node in self.get_descendants(nodes.FunctionDef): name_match = node.name == function not_interface_declaration = not isinstance( node.get_ancestor(), nodes.InterfaceDef @@ -197,7 +180,7 @@ def find_state_variable_declaration_node_for_name(self, variable: str): return None def find_type_declaration_node_for_name(self, symbol: str): - for node in self.get_descendants_from_best_ast(self.custom_type_node_types): + for node in self.get_descendants(self.custom_type_node_types): if node.name == symbol: return node if isinstance(node, nodes.EnumDef): @@ -210,61 +193,54 @@ def find_type_declaration_node_for_name(self, symbol: str): def find_nodes_referencing_enum(self, enum: str): return_nodes = [] - for node in self.get_descendants_from_best_ast( - nodes.AnnAssign, {"annotation.id": enum} - ): + for node in self.get_descendants(nodes.AnnAssign, {"annotation.id": enum}): return_nodes.append(node) - for node in self.get_descendants_from_best_ast( - nodes.Attribute, {"value.id": enum} - ): + for node in self.get_descendants(nodes.Attribute, {"value.id": enum}): return_nodes.append(node) - for node in self.get_descendants_from_best_ast( - nodes.VariableDecl, {"annotation.id": enum} - ): + for node in self.get_descendants(nodes.VariableDecl, {"annotation.id": enum}): + return_nodes.append(node) + for node in self.get_descendants(nodes.FunctionDef, {"returns.id": enum}): return_nodes.append(node) return return_nodes def find_nodes_referencing_enum_variant(self, enum: str, variant: str): - return self.get_descendants_from_best_ast( + return self.get_descendants( nodes.Attribute, {"attr": variant, "value.id": enum} ) def find_nodes_referencing_struct(self, struct: str): return_nodes = [] - for node in self.get_descendants_from_best_ast( - nodes.AnnAssign, {"annotation.id": struct} - ): + for node in self.get_descendants(nodes.AnnAssign, {"annotation.id": struct}): return_nodes.append(node) - for node in self.get_descendants_from_best_ast(nodes.Call, {"func.id": struct}): + for node in self.get_descendants(nodes.Call, {"func.id": struct}): return_nodes.append(node) - for node in self.get_descendants_from_best_ast( - nodes.VariableDecl, {"annotation.id": struct} - ): + for node in self.get_descendants(nodes.VariableDecl, {"annotation.id": struct}): return_nodes.append(node) - for node in self.get_descendants_from_best_ast( - nodes.FunctionDef, {"returns.id": struct} - ): + for node in self.get_descendants(nodes.FunctionDef, {"returns.id": struct}): return_nodes.append(node) return return_nodes def find_top_level_node_at_pos(self, pos: Position) -> Optional[VyperNode]: - for node in self.get_children_from_best_ast(): - if node.lineno <= pos.line and node.end_lineno >= pos.line: + for node in self.get_top_level_nodes(): + if node.lineno <= pos.line and pos.line <= node.end_lineno: return node def find_nodes_referencing_symbol(self, symbol: str): + # this only runs on subtrees return_nodes = [] - for node in self.get_descendants_from_best_ast(nodes.Name, {"id": symbol}): + for node in self.get_descendants(nodes.Name, {"id": symbol}): parent = node.get_ancestor() if isinstance(parent, nodes.Dict): + # skip struct key names if symbol not in [key.id for key in parent.keys]: return_nodes.append(node) - elif isinstance(node.get_ancestor(), nodes.AnnAssign): - if node.id == node.get_ancestor().target.id: + elif isinstance(parent, nodes.AnnAssign): + if node.id == parent.target.id: + # lhs of variable declaration continue else: return_nodes.append(node) @@ -274,18 +250,6 @@ def find_nodes_referencing_symbol(self, symbol: str): return return_nodes def find_node_declaring_symbol(self, symbol: str): - for node in self.get_descendants_from_best_ast( - (nodes.AnnAssign, nodes.VariableDecl) - ): + for node in self.get_descendants((nodes.AnnAssign, nodes.VariableDecl)): if node.target.id == symbol: return node - - @classmethod - def create_new_instance(cls, ast): - # Create a new instance - new_instance = super(AST, cls).__new__(cls) - # Optionally, initialize the new instance - new_instance.ast_data = ast - new_instance.ast_data_unfolded = ast - new_instance.ast_data_folded = ast - return new_instance diff --git a/vyper_lsp/main.py b/vyper_lsp/main.py index c6b5ca2..5722a36 100755 --- a/vyper_lsp/main.py +++ b/vyper_lsp/main.py @@ -37,18 +37,19 @@ from .ast import AST +ast = AST() + server = LanguageServer("vyper", "v0.0.1") -navigator = ASTNavigator() +navigator = ASTNavigator(ast) # AstAnalyzer is faster and better, but depends on the locally installed vyper version # we should keep it around for now and use it when the contract version pragma is missing # or if the version pragma matches the system version. its much faster so we can run it # on every keystroke, with sourceanalyzer we should only run it on save -ast_analyzer = AstAnalyzer() +ast_analyzer = AstAnalyzer(ast) completer = ast_analyzer source_analyzer = SourceAnalyzer() -ast = AST() debouncer = Debouncer(wait=0.5) diff --git a/vyper_lsp/navigation.py b/vyper_lsp/navigation.py index 625630c..f2202e6 100644 --- a/vyper_lsp/navigation.py +++ b/vyper_lsp/navigation.py @@ -15,8 +15,8 @@ # # the navigator should mainly return Ranges class ASTNavigator: - def __init__(self, ast=None): - self.ast = ast or AST() + def __init__(self, ast): + self.ast = ast def find_state_variable_declaration(self, word: str) -> Optional[Range]: node = self.ast.find_state_variable_declaration_node_for_name(word) @@ -30,7 +30,7 @@ def find_state_variable_declaration(self, word: str) -> Optional[Range]: def find_variable_declaration_under_node( self, node: VyperNode, symbol: str ) -> Optional[Range]: - decl_node = AST.create_new_instance(node).find_node_declaring_symbol(symbol) + decl_node = AST.from_node(node).find_node_declaring_symbol(symbol) if decl_node: range = Range( start=Position( @@ -120,9 +120,7 @@ def find_references(self, doc: Document, pos: Position) -> List[Range]: ) references.append(range) elif isinstance(top_level_node, FunctionDef): - refs = AST.create_new_instance( - top_level_node - ).find_nodes_referencing_symbol(word) + refs = AST.from_node(top_level_node).find_nodes_referencing_symbol(word) for ref in refs: range = Range( start=Position(line=ref.lineno - 1, character=ref.col_offset),