diff --git a/automata/code_parsers/py/__init__.py b/automata/code_parsers/py/__init__.py index 5b70dfa8..fae17859 100644 --- a/automata/code_parsers/py/__init__.py +++ b/automata/code_parsers/py/__init__.py @@ -4,7 +4,11 @@ get_node_without_docstrings, get_node_without_imports, ) -from .context_retriever import PyContextRetriever, PyContextRetrieverConfig +from .context_retriever import ( + ContextComponent, + PyContextRetriever, + PyContextRetrieverConfig, +) from .dotpath_map import DotPathMap __all__ = [ @@ -15,4 +19,5 @@ "DotPathMap", "PyContextRetriever", "PyContextRetrieverConfig", + "ContextComponent", ] diff --git a/automata/code_parsers/py/ast_utils.py b/automata/code_parsers/py/ast_utils.py index f16b1ac6..1e555c4d 100644 --- a/automata/code_parsers/py/ast_utils.py +++ b/automata/code_parsers/py/ast_utils.py @@ -15,7 +15,7 @@ get_docstring, ) from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional logger = logging.getLogger(__name__) @@ -73,7 +73,7 @@ class DocstringRemover(NodeTransformer): def visit(self, node): # If this node is a function, class, or module, remove its docstring. if ( - isinstance(node, Union[AsyncFunctionDef, ClassDef, FunctionDef, Module]) + isinstance(node, (AsyncFunctionDef, ClassDef, FunctionDef, Module)) and isinstance(node.body[0], Expr) and isinstance(node.body[0].value, Str) ): @@ -94,7 +94,7 @@ class ImportRemover(NodeTransformer): def visit(self, node): # If this node is a function, class, or module, and its first child is an import statement, # remove the import statement. 
- if isinstance(node, Union[AsyncFunctionDef, ClassDef, FunctionDef, Module]) and ( + if isinstance(node, (AsyncFunctionDef, ClassDef, FunctionDef, Module)) and ( isinstance(node.body[0], (Import, ImportFrom)) ): node.body.pop(0) diff --git a/automata/code_parsers/py/context_retriever.py b/automata/code_parsers/py/context_retriever.py index 6b147986..c2667935 100644 --- a/automata/code_parsers/py/context_retriever.py +++ b/automata/code_parsers/py/context_retriever.py @@ -2,7 +2,7 @@ from ast import AST, AsyncFunctionDef, ClassDef, FunctionDef, unparse, walk from contextlib import contextmanager from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Protocol, Union +from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Union from automata.code_parsers.py import ( get_docstring_from_node, @@ -113,18 +113,18 @@ def _source_code( self, symbol: Symbol, ast_object: AST, - remove_imports: bool = True, - remove_docstrings: bool = False, + include_imports: bool = False, + include_docstrings: bool = True, max_length: Optional[int] = None, *args, **kwargs, ) -> str: """Convert a symbol into underlying source code.""" - if remove_docstrings: + if not include_docstrings: ast_object = get_node_without_docstrings(ast_object) - if remove_imports: + if not include_imports: ast_object = get_node_without_imports(ast_object) source = unparse(ast_object) @@ -140,6 +140,7 @@ def _interface( header: str = "Interface:\n\n", class_header: str = "class ", recursion_depth: int = 0, + processed_classes: Optional[Set[int]] = None, # ids of AST objects already processed *args, **kwargs, ) -> str: @@ -147,6 +148,15 @@ def _interface( if recursion_depth > self.MAX_RECURSION_DEPTH: raise RecursionError(f"Max recursion depth of {self.MAX_RECURSION_DEPTH} exceeded.") + if processed_classes is None: + processed_classes = set() + + if id(ast_object) in processed_classes: + return "" + + # add the class so that we do not process it twice + processed_classes.add(id(ast_object)) + 
# indent according to indent_level interface = self.process_entry(header) @@ -167,6 +177,7 @@ def _interface( header, class_header, recursion_depth=recursion_depth + 1, + processed_classes=processed_classes, ) methods = sorted(self._get_all_methods(ast_object), key=lambda x: x.name) @@ -213,9 +224,9 @@ def _get_method_arguments(method: Union[AsyncFunctionDef, FunctionDef]) -> str: for arg in method.args.args: if arg.arg in defaults: - args.append(f"{arg.arg}={unparse(defaults[arg.arg])}") + args.append(f"{unparse(arg)}={unparse(defaults[arg.arg])}") else: - args.append(arg.arg) + args.append(unparse(arg)) # Handle keyword-only arguments if method.args.kwonlyargs: @@ -232,8 +243,8 @@ def _get_method_arguments(method: Union[AsyncFunctionDef, FunctionDef]) -> str: ) if default is not None: - args.append(f"{kwarg.arg}={unparse(default)}") + args.append(f"{unparse(kwarg)}={unparse(default)}") else: - args.append(kwarg.arg) + args.append(unparse(kwarg)) return ", ".join(args) diff --git a/automata/singletons/py_module_loader.py b/automata/singletons/py_module_loader.py index 25fadc5e..be4fe048 100644 --- a/automata/singletons/py_module_loader.py +++ b/automata/singletons/py_module_loader.py @@ -102,8 +102,6 @@ def fetch_ast_module(self, module_dotpath: str) -> Optional[Module]: if not self._dotpath_map.contains_dotpath(module_dotpath): # type: ignore return None - print("module_dotpath = ", module_dotpath) - if module_dotpath not in self._loaded_modules: module_fpath = self._dotpath_map.get_module_fpath_by_dotpath(module_dotpath) # type: ignore self._loaded_modules[module_dotpath] = self._load_module_from_fpath(module_fpath) diff --git a/automata/symbol/symbol_utils.py b/automata/symbol/symbol_utils.py index c4bdf65d..a26bae27 100644 --- a/automata/symbol/symbol_utils.py +++ b/automata/symbol/symbol_utils.py @@ -84,7 +84,6 @@ def get_rankable_symbols( do_continue = any( filter_string in symbol.uri for filter_string in symbols_strings_to_filter ) - print(f"symbol = 
{symbol}, do_continue={do_continue}") if do_continue: continue diff --git a/tests/unit/sample_modules/my_project/core/calculator.py b/tests/unit/sample_modules/my_project/core/calculator.py index 2f3d79ba..0f470a0b 100644 --- a/tests/unit/sample_modules/my_project/core/calculator.py +++ b/tests/unit/sample_modules/my_project/core/calculator.py @@ -1,8 +1,13 @@ class Calculator: + """Docstring for Calculator class""" + + import math # This is bad code, but it's just for testing purposes + def __init__(self): pass - def add(self, a, b): + def add(self, a: int, b: int) -> int: + """Docstring for add method""" return a + b def subtract(self, a, b): diff --git a/tests/unit/test_context_retriever.py b/tests/unit/test_context_retriever.py index 74fabbfa..b7f1ecee 100644 --- a/tests/unit/test_context_retriever.py +++ b/tests/unit/test_context_retriever.py @@ -1,10 +1,19 @@ +import ast +import inspect import os import pytest -from automata.code_parsers.py import PyContextRetriever, PyContextRetrieverConfig +from automata.code_parsers.py import ( + ContextComponent, + PyContextRetriever, + PyContextRetrieverConfig, +) from automata.core.utils import get_root_fpath from automata.singletons.py_module_loader import py_module_loader +from automata.symbol import parse_symbol + +from .sample_modules.my_project.core.calculator import Calculator # TODO - Unify module loader fixture @@ -23,6 +32,177 @@ def context_retriever(): return PyContextRetriever(PyContextRetrieverConfig()) -# def test_retrieve(context_retriever, local_module_loader): -# print("local_module_loader = ", local_module_loader._dotpath_map.items()) -# assert False +def test_process_symbol(context_retriever): + symbol = parse_symbol( + "scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#" + ) + components = { + ContextComponent.HEADLINE: {}, + ContextComponent.INTERFACE: {"skip_private": True}, + } + context = context_retriever.process_symbol(symbol, components) + assert "Building context for 
symbol -" in context + assert "add(self, a: int, b: int) -> int" in context + + +def test_process_symbol_error(context_retriever): + symbol = parse_symbol( + "scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#" + ) + components = {ContextComponent.SOURCE_CODE: {}, ContextComponent.INTERFACE: {}} + with pytest.raises(ValueError): + context_retriever.process_symbol(symbol, components) + + +def test_process_symbol_invalid_component(context_retriever, caplog): + symbol = parse_symbol( + "scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#" + ) + components = {ContextComponent.HEADLINE: {}, "invalid_component": {}} + context_retriever.process_symbol(symbol, components) + assert "Warning: invalid_component is not a valid context component." in caplog.text + + +def test_source_code(context_retriever): + symbol = parse_symbol( + "scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#" + ) + ast_object = ast.parse(inspect.getsource(Calculator)) + source_code = context_retriever._source_code( + symbol, ast_object, include_imports=False, include_docstrings=True + ) + assert "class Calculator:" in source_code + assert "def add(self, a: int, b: int) -> int:" in source_code + assert "return a + b" in source_code + + +def test_source_code_2(context_retriever): + symbol = parse_symbol( + "scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#" + ) + ast_object = ast.parse(inspect.getsource(Calculator)) + source_code = context_retriever._source_code( + symbol, ast_object, include_imports=True, include_docstrings=True + ) + assert "import math" in source_code + assert "class Calculator:" in source_code + assert "def add(self, a: int, b: int) -> int:" in source_code + assert "return a + b" in source_code + + +def test_interface(context_retriever): + symbol = parse_symbol( + "scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#" + ) + ast_object = 
ast.parse(inspect.getsource(Calculator)) + interface = context_retriever._interface(symbol, ast_object, skip_private=True) + print("interface = ", interface) + print("-" * 100) + assert "Interface:" in interface + assert "add(self, a: int, b: int) -> int" in interface + + +def test_interface_recursion_error(context_retriever): + symbol = parse_symbol( + "scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#" + ) + ast_object = ast.parse(inspect.getsource(Calculator)) + with pytest.raises(RecursionError): + context_retriever._interface(symbol, ast_object, recursion_depth=3) + + +def test_process_headline(context_retriever): + symbol = parse_symbol( + "scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#" + ) + ast_object = ast.parse(inspect.getsource(Calculator)) + headline = context_retriever._process_headline(symbol, ast_object) + assert headline == "Building context for symbol - my_project.core.calculator.Calculator\n" + + +def test_process_method(context_retriever): + source_code = inspect.getsource(Calculator.add) + # Remove the leading spaces from the source code + source_code = "\n".join( + [ + line[len(" ") :] if line.startswith(" ") else line + for line in source_code.split("\n") + ] + ) + ast_object = ast.parse(source_code).body[0] + processed_method = context_retriever._process_method(ast_object) + assert processed_method == "add(self, a: int, b: int) -> int\n" + + +def test_get_method_return_annotation(context_retriever): + source_code = inspect.getsource(Calculator.add) + # Remove the leading spaces from the source code + source_code = "\n".join( + [ + line[len(" ") :] if line.startswith(" ") else line + for line in source_code.split("\n") + ] + ) + ast_object = ast.parse(source_code).body[0] + return_annotation = context_retriever._get_method_return_annotation(ast_object) + assert return_annotation == "int" + + +def test_is_private_method(context_retriever): + source_code = "def _private_method(): pass" + 
ast_object = ast.parse(source_code).body[0] + assert context_retriever._is_private_method(ast_object) + + +def test_get_all_methods(context_retriever): + source_code = inspect.getsource(Calculator) + ast_object = ast.parse(source_code) + methods = context_retriever._get_all_methods(ast_object) + assert len(methods) == 3 + assert all(isinstance(method, (ast.FunctionDef, ast.AsyncFunctionDef)) for method in methods) + + +def test_get_all_classes(context_retriever): + import textwrap + + source = textwrap.dedent( + """ + class Calculator: + def add(self, a, b): + return a + b + + def subtract(self, a, b): + return a - b + + class Calculator4: + pass + """ + ) + ast_object = ast.parse(source) + classes = context_retriever._get_all_classes(ast_object) + assert len(classes) == 2 + assert all(isinstance(cls, ast.ClassDef) for cls in classes) + + +def test_interface_include_docstrings(context_retriever): + symbol = parse_symbol( + "scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#" + ) + ast_object = ast.parse(inspect.getsource(Calculator)) + interface = context_retriever._interface(symbol, ast_object, include_docstrings=True) + assert "Interface:" in interface + assert "add(self, a: int, b: int) -> int" in interface + assert "Docstring for Calculator class" in interface + assert "Docstring for add method" in interface + + +def test_interface_exclude_docstrings(context_retriever): + symbol = parse_symbol( + "scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#" + ) + ast_object = ast.parse(inspect.getsource(Calculator)) + interface = context_retriever._interface(symbol, ast_object, include_docstrings=False) + assert "Interface:" in interface + assert "add(self, a: int, b: int) -> int" in interface + assert "Docstring for Calculator class" not in interface + assert "Docstring for add method" not in interface