Skip to content
This repository has been archived by the owner on Mar 16, 2024. It is now read-only.

Commit

Permalink
comprehensive cli tests (#184)
Browse files Browse the repository at this point in the history
  • Loading branch information
emrgnt-cmplxty authored Jul 12, 2023
1 parent 6b2fae2 commit c5a2bf2
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 21 deletions.
7 changes: 6 additions & 1 deletion automata/code_parsers/py/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
get_node_without_docstrings,
get_node_without_imports,
)
from .context_retriever import PyContextRetriever, PyContextRetrieverConfig
from .context_retriever import (
ContextComponent,
PyContextRetriever,
PyContextRetrieverConfig,
)
from .dotpath_map import DotPathMap

__all__ = [
Expand All @@ -15,4 +19,5 @@
"DotPathMap",
"PyContextRetriever",
"PyContextRetrieverConfig",
"ContextComponent",
]
6 changes: 3 additions & 3 deletions automata/code_parsers/py/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
get_docstring,
)
from dataclasses import dataclass
from typing import Optional, Union
from typing import Optional

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -73,7 +73,7 @@ class DocstringRemover(NodeTransformer):
def visit(self, node):
# If this node is a function, class, or module, remove its docstring.
if (
isinstance(node, Union[AsyncFunctionDef, ClassDef, FunctionDef, Module])
isinstance(node, (AsyncFunctionDef, ClassDef, FunctionDef, Module))
and isinstance(node.body[0], Expr)
and isinstance(node.body[0].value, Str)
):
Expand All @@ -94,7 +94,7 @@ class ImportRemover(NodeTransformer):
def visit(self, node):
# If this node is a function, class, or module, and its first child is an import statement,
# remove the import statement.
if isinstance(node, Union[AsyncFunctionDef, ClassDef, FunctionDef, Module]) and (
if isinstance(node, (AsyncFunctionDef, ClassDef, FunctionDef, Module)) and (
isinstance(node.body[0], (Import, ImportFrom))
):
node.body.pop(0)
Expand Down
29 changes: 20 additions & 9 deletions automata/code_parsers/py/context_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ast import AST, AsyncFunctionDef, ClassDef, FunctionDef, unparse, walk
from contextlib import contextmanager
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Protocol, Union
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Union

from automata.code_parsers.py import (
get_docstring_from_node,
Expand Down Expand Up @@ -113,18 +113,18 @@ def _source_code(
self,
symbol: Symbol,
ast_object: AST,
remove_imports: bool = True,
remove_docstrings: bool = False,
include_imports: bool = False,
include_docstrings: bool = True,
max_length: Optional[int] = None,
*args,
**kwargs,
) -> str:
"""Convert a symbol into underlying source code."""

if remove_docstrings:
if not include_docstrings:
ast_object = get_node_without_docstrings(ast_object)

if remove_imports:
if not include_imports:
ast_object = get_node_without_imports(ast_object)

source = unparse(ast_object)
Expand All @@ -140,13 +140,23 @@ def _interface(
header: str = "Interface:\n\n",
class_header: str = "class ",
recursion_depth: int = 0,
processed_classes: Optional[Set[int]] = None, # add this argument
*args,
**kwargs,
) -> str:
"""Convert a symbol into an interface, skipping 'private' methods/classes if indicated."""
if recursion_depth > self.MAX_RECURSION_DEPTH:
raise RecursionError(f"Max recursion depth of {self.MAX_RECURSION_DEPTH} exceeded.")

if processed_classes is None:
processed_classes = set()

if id(ast_object) in processed_classes:
return ""

# add the class so that we do not process it twice
processed_classes.add(id(ast_object))

# indent according to indent_level
interface = self.process_entry(header)

Expand All @@ -167,6 +177,7 @@ def _interface(
header,
class_header,
recursion_depth=recursion_depth + 1,
processed_classes=processed_classes,
)

methods = sorted(self._get_all_methods(ast_object), key=lambda x: x.name)
Expand Down Expand Up @@ -213,9 +224,9 @@ def _get_method_arguments(method: Union[AsyncFunctionDef, FunctionDef]) -> str:

for arg in method.args.args:
if arg.arg in defaults:
args.append(f"{arg.arg}={unparse(defaults[arg.arg])}")
args.append(f"{unparse(arg)}={unparse(defaults[arg.arg])}")
else:
args.append(arg.arg)
args.append(unparse(arg))

# Handle keyword-only arguments
if method.args.kwonlyargs:
Expand All @@ -232,8 +243,8 @@ def _get_method_arguments(method: Union[AsyncFunctionDef, FunctionDef]) -> str:
)

if default is not None:
args.append(f"{kwarg.arg}={unparse(default)}")
args.append(f"{unparse(kwarg)}={unparse(default)}")
else:
args.append(kwarg.arg)
args.append(unparse(kwarg))

return ", ".join(args)
2 changes: 0 additions & 2 deletions automata/singletons/py_module_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,6 @@ def fetch_ast_module(self, module_dotpath: str) -> Optional[Module]:
if not self._dotpath_map.contains_dotpath(module_dotpath): # type: ignore
return None

print("module_dotpath = ", module_dotpath)

if module_dotpath not in self._loaded_modules:
module_fpath = self._dotpath_map.get_module_fpath_by_dotpath(module_dotpath) # type: ignore
self._loaded_modules[module_dotpath] = self._load_module_from_fpath(module_fpath)
Expand Down
1 change: 0 additions & 1 deletion automata/symbol/symbol_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def get_rankable_symbols(
do_continue = any(
filter_string in symbol.uri for filter_string in symbols_strings_to_filter
)
print(f"symbol = {symbol}, do_continue={do_continue}")
if do_continue:
continue

Expand Down
7 changes: 6 additions & 1 deletion tests/unit/sample_modules/my_project/core/calculator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
class Calculator:
"""Docstring for Calculator class"""

import math # This is bad code, but it's just for testing purposes

def __init__(self):
pass

def add(self, a, b):
def add(self, a: int, b: int) -> int:
"""Docstring for add method"""
return a + b

def subtract(self, a, b):
Expand Down
188 changes: 184 additions & 4 deletions tests/unit/test_context_retriever.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
import ast
import inspect
import os

import pytest

from automata.code_parsers.py import PyContextRetriever, PyContextRetrieverConfig
from automata.code_parsers.py import (
ContextComponent,
PyContextRetriever,
PyContextRetrieverConfig,
)
from automata.core.utils import get_root_fpath
from automata.singletons.py_module_loader import py_module_loader
from automata.symbol import parse_symbol

from .sample_modules.my_project.core.calculator import Calculator


# TODO - Unify module loader fixture
Expand All @@ -23,6 +32,177 @@ def context_retriever():
return PyContextRetriever(PyContextRetrieverConfig())


# def test_retrieve(context_retriever, local_module_loader):
# print("local_module_loader = ", local_module_loader._dotpath_map.items())
# assert False
def test_process_symbol(context_retriever):
symbol = parse_symbol(
"scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#"
)
components = {
ContextComponent.HEADLINE: {},
ContextComponent.INTERFACE: {"skip_private": True},
}
context = context_retriever.process_symbol(symbol, components)
assert "Building context for symbol -" in context
assert "add(self, a: int, b: int) -> int" in context


def test_process_symbol_error(context_retriever):
symbol = parse_symbol(
"scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#"
)
components = {ContextComponent.SOURCE_CODE: {}, ContextComponent.INTERFACE: {}}
with pytest.raises(ValueError):
context_retriever.process_symbol(symbol, components)


def test_process_symbol_invalid_component(context_retriever, caplog):
symbol = parse_symbol(
"scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#"
)
components = {ContextComponent.HEADLINE: {}, "invalid_component": {}}
context_retriever.process_symbol(symbol, components)
assert "Warning: invalid_component is not a valid context component." in caplog.text


def test_source_code(context_retriever):
symbol = parse_symbol(
"scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#"
)
ast_object = ast.parse(inspect.getsource(Calculator))
source_code = context_retriever._source_code(
symbol, ast_object, include_imports=False, include_docstrings=True
)
assert "class Calculator:" in source_code
assert "def add(self, a: int, b: int) -> int:" in source_code
assert "return a + b" in source_code


def test_source_code_2(context_retriever):
symbol = parse_symbol(
"scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#"
)
ast_object = ast.parse(inspect.getsource(Calculator))
source_code = context_retriever._source_code(
symbol, ast_object, include_imports=True, include_docstrings=True
)
assert "import math" in source_code
assert "class Calculator:" in source_code
assert "def add(self, a: int, b: int) -> int:" in source_code
assert "return a + b" in source_code


def test_interface(context_retriever):
symbol = parse_symbol(
"scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#"
)
ast_object = ast.parse(inspect.getsource(Calculator))
interface = context_retriever._interface(symbol, ast_object, skip_private=True)
print("interface = ", interface)
print("-" * 100)
assert "Interface:" in interface
assert "add(self, a: int, b: int) -> int" in interface


def test_interface_recursion_error(context_retriever):
symbol = parse_symbol(
"scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#"
)
ast_object = ast.parse(inspect.getsource(Calculator))
with pytest.raises(RecursionError):
context_retriever._interface(symbol, ast_object, recursion_depth=3)


def test_process_headline(context_retriever):
symbol = parse_symbol(
"scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#"
)
ast_object = ast.parse(inspect.getsource(Calculator))
headline = context_retriever._process_headline(symbol, ast_object)
assert headline == "Building context for symbol - my_project.core.calculator.Calculator\n"


def test_process_method(context_retriever):
source_code = inspect.getsource(Calculator.add)
# Remove the leading spaces from the source code
source_code = "\n".join(
[
line[len(" ") :] if line.startswith(" ") else line
for line in source_code.split("\n")
]
)
ast_object = ast.parse(source_code).body[0]
processed_method = context_retriever._process_method(ast_object)
assert processed_method == "add(self, a: int, b: int) -> int\n"


def test_get_method_return_annotation(context_retriever):
source_code = inspect.getsource(Calculator.add)
# Remove the leading spaces from the source code
source_code = "\n".join(
[
line[len(" ") :] if line.startswith(" ") else line
for line in source_code.split("\n")
]
)
ast_object = ast.parse(source_code).body[0]
return_annotation = context_retriever._get_method_return_annotation(ast_object)
assert return_annotation == "int"


def test_is_private_method(context_retriever):
source_code = "def _private_method(): pass"
ast_object = ast.parse(source_code).body[0]
assert context_retriever._is_private_method(ast_object)


def test_get_all_methods(context_retriever):
source_code = inspect.getsource(Calculator)
ast_object = ast.parse(source_code)
methods = context_retriever._get_all_methods(ast_object)
assert len(methods) == 3
assert all(isinstance(method, (ast.FunctionDef, ast.AsyncFunctionDef)) for method in methods)


def test_get_all_classes(context_retriever):
import textwrap

source = textwrap.dedent(
"""
class Calculator:
def add(self, a, b):
return a + b
def subtract(self, a, b):
return a - b
class Calculator4:
pass
"""
)
ast_object = ast.parse(source)
classes = context_retriever._get_all_classes(ast_object)
assert len(classes) == 2
assert all(isinstance(cls, ast.ClassDef) for cls in classes)


def test_interface_include_docstrings(context_retriever):
symbol = parse_symbol(
"scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#"
)
ast_object = ast.parse(inspect.getsource(Calculator))
interface = context_retriever._interface(symbol, ast_object, include_docstrings=True)
assert "Interface:" in interface
assert "add(self, a: int, b: int) -> int" in interface
assert "Docstring for Calculator class" in interface
assert "Docstring for add method" in interface


def test_interface_exclude_docstrings(context_retriever):
symbol = parse_symbol(
"scip-python python automata v0.0.0 `my_project.core.calculator`/Calculator#"
)
ast_object = ast.parse(inspect.getsource(Calculator))
interface = context_retriever._interface(symbol, ast_object, include_docstrings=False)
assert "Interface:" in interface
assert "add(self, a: int, b: int) -> int" in interface
assert "Docstring for Calculator class" not in interface
assert "Docstring for add method" not in interface

0 comments on commit c5a2bf2

Please sign in to comment.