DOP-5237: Deserialize AST #638

Status: Open. Wants to merge 13 commits into base: main.
15 changes: 15 additions & 0 deletions snooty/diagnostics.py
@@ -149,6 +149,21 @@ def __init__(
         self.name = name


+class UnexpectedNodeType(Diagnostic):
+    severity = Diagnostic.Level.error
+
+    def __init__(
+        self,
+        found_type: Optional[str],
+        expected_type: Optional[str],
+        start: Union[int, Tuple[int, int]],
+    ) -> None:
+        suggestion = f' Expected: "{expected_type}".' if expected_type else ""
+        super().__init__(
+            f'Found unexpected node type "{found_type}".{suggestion}', start
+        )
+
+
 class UnnamedPage(Diagnostic):
     severity = Diagnostic.Level.error
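For reference, a quick standalone sketch of the message this diagnostic renders (it simply mirrors the f-string above):

    found_type, expected_type = "beep", "root"
    suggestion = f' Expected: "{expected_type}".' if expected_type else ""
    message = f'Found unexpected node type "{found_type}".{suggestion}'
    assert message == 'Found unexpected node type "beep". Expected: "root".'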
2 changes: 1 addition & 1 deletion snooty/main.py
@@ -184,7 +184,7 @@ def handle_document(
         fully_qualified_pageid: str,
         document: Dict[str, Any],
     ) -> None:
-        if page_id.suffix != EXT_FOR_PAGE:
+        if page_id.suffix not in [EXT_FOR_PAGE, ".ast"]:
             return
         super().handle_document(
             build_identifiers, page_id, fully_qualified_pageid, document
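A quick illustration of the new guard; note that EXT_FOR_PAGE is assumed here to be ".txt", as elsewhere in snooty:

    from pathlib import PurePosixPath

    EXT_FOR_PAGE = ".txt"  # assumption: snooty's standard page extension
    assert PurePosixPath("docs/index.ast").suffix in [EXT_FOR_PAGE, ".ast"]      # now handled
    assert PurePosixPath("docs/data.json").suffix not in [EXT_FOR_PAGE, ".ast"]  # still skipped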
2 changes: 1 addition & 1 deletion snooty/n.py
@@ -60,7 +60,7 @@
 class FileId(PurePosixPath):
     """An unambiguous file path relative to the local project's root."""

-    PAT_FILE_EXTENSIONS = re.compile(r"\.((txt)|(rst)|(yaml))$")
+    PAT_FILE_EXTENSIONS = re.compile(r"\.((txt)|(rst)|(yaml)|(ast))$")

     def collapse_dots(self) -> "FileId":
         result: List[str] = []
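A standalone check of the updated pattern:

    import re

    PAT_FILE_EXTENSIONS = re.compile(r"\.((txt)|(rst)|(yaml)|(ast))$")
    assert PAT_FILE_EXTENSIONS.search("driver/gridfs.ast")
    assert PAT_FILE_EXTENSIONS.search("driver/index.txt")
    assert not PAT_FILE_EXTENSIONS.search("driver/gridfs.json")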
43 changes: 43 additions & 0 deletions snooty/parser.py
@@ -77,6 +77,7 @@
     TodoInfo,
     UnexpectedDirectiveOrder,
     UnexpectedIndentation,
+    UnexpectedNodeType,
     UnknownOptionId,
     UnknownTabID,
     UnknownTabset,
@@ -1843,6 +1844,48 @@ def build(
         fileids = (self.config.get_fileid(path) for path in paths)
         self.parse_rst_files(fileids, max_workers)

+        # Handle custom AST from API reference docs
+        with util.PerformanceLogger.singleton().start("parse pre-existing AST"):
+            ast_pages = util.get_files(
+                self.config.source_path,
+                {".ast"},
+                self.config.root,
+                nested_projects_diagnostics,
+            )
+
+            for path in ast_pages:
+                fileid = self.config.get_fileid(path)
+                diagnostics: List[Diagnostic] = []
+
+                try:
+                    text, read_diagnostics = self.config.read(fileid)
+                    diagnostics.extend(read_diagnostics)
+                    ast_json = json.loads(text)
+                    is_valid_ast_root = (
+                        isinstance(ast_json, Dict)
+                        and ast_json.get("type") == n.Root.type
+                    )
+
+                    if not is_valid_ast_root:
+                        # Guard .get() so a non-dict payload (e.g. a top-level list) cannot raise
+                        found_type = (
+                            ast_json.get("type") if isinstance(ast_json, Dict) else None
+                        )
+                        diagnostics.append(UnexpectedNodeType(found_type, "root", 0))
+
+                    ast_root = (
+                        util.deserialize_ast(ast_json, n.Root, diagnostics)
+                        if is_valid_ast_root
+                        else None
+                    )
+                    new_page = Page.create(
+                        fileid,
+                        fileid.as_posix().replace(".ast", ".txt"),
+                        "",
+                        ast_root,
+                    )
+                    self._page_updated(new_page, diagnostics)
+                except Exception as e:
+                    logger.error(e)
+
         for nested_path, diagnostics in nested_projects_diagnostics.items():
             with self._backend_lock:
                 self.on_diagnostics(nested_path, diagnostics)
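For context, a minimal sketch of the root-validation step above, run against a hypothetical payload whose top-level node is not a root (n.Root.type is "root"):

    import json

    text = '{"type": "paragraph", "children": []}'
    ast_json = json.loads(text)
    is_valid_ast_root = isinstance(ast_json, dict) and ast_json.get("type") == "root"
    assert not is_valid_ast_root  # build() would record UnexpectedNodeType("paragraph", "root", 0)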
100 changes: 100 additions & 0 deletions snooty/test_parser.py
@@ -30,6 +30,7 @@
     TabMustBeDirective,
     UnexpectedDirectiveOrder,
     UnexpectedIndentation,
+    UnexpectedNodeType,
     UnknownOptionId,
     UnknownTabID,
     UnknownTabset,
@@ -4488,3 +4489,102 @@ def test_video() -> None:
     page.finish(diagnostics)
     # Diagnostic due to invalid upload-date format
     assert [type(x) for x in diagnostics] == [DocUtilsParseError]


+def test_parse_ast() -> None:
+    with make_test(
+        {
+            Path("source/test.ast"): """
+{
+    "type": "root",
+    "children": [
+        {
+            "type": "section",
+            "children": [
+                {
+                    "type": "heading",
+                    "children": [
+                        {
+                            "type": "text",
+                            "value": "Interface GridFSBucket"
+                        }
+                    ],
+                    "id": "interface-gridfsbucket"
+                },
+                {
+                    "type": "paragraph",
+                    "children": [
+                        {
+                            "type": "reference",
+                            "children": [
+                                {
+                                    "type": "text",
+                                    "value": "@ThreadSafe"
+                                }
+                            ],
+                            "refuri": "http://mongodb.github.io/mongo-java-driver/5.2/apidocs/mongodb-driver-core/com/mongodb/annotations/ThreadSafe.html"
+                        }
+                    ]
+                },
+                {
+                    "type": "directive",
+                    "name": "important",
+                    "domain": "",
+                    "argument": [
+                        {
+                            "type": "text",
+                            "value": "Important Callout Heading"
+                        }
+                    ],
+                    "children": [
+                        {
+                            "type": "paragraph",
+                            "children": [
+                                {
+                                    "type": "text",
+                                    "value": "Important Callout Body Text"
+                                }
+                            ]
+                        }
+                    ]
+                }
+            ]
+        }
+    ],
+    "fileid": "test.ast"
+}
+""",
+            Path("source/bad-types.ast"): """
+{
+    "type": "root",
+    "children": [
+        {
+            "type": "section",
+            "children": [
+                {
+                    "type": "beep",
+                    "children": [
+                        {
+                            "type": "text",
+                            "value": "Interface GridFSBucket"
+                        }
+                    ],
+                    "id": "interface-gridfsbucket"
+                }
+            ]
+        }
+    ],
+    "fileid": "bad-types.ast"
+}
+""",
+        }
+    ) as result:
+        diagnostics = result.diagnostics[FileId("test.ast")]
+        assert not diagnostics
+        bad_types_diagnostics = result.diagnostics[FileId("bad-types.ast")]
+        assert [type(d) for d in bad_types_diagnostics] == [UnexpectedNodeType]
91 changes: 90 additions & 1 deletion snooty/util.py
@@ -40,6 +40,7 @@
     Set,
     TextIO,
     Tuple,
+    Type,
     TypeVar,
     Union,
     cast,
@@ -48,7 +49,7 @@
 import requests
 import tomli

-from snooty.diagnostics import Diagnostic, NestedProject
+from snooty.diagnostics import Diagnostic, NestedProject, UnexpectedNodeType
 from snooty.n import FileId

 from . import n, tinydocutils
@@ -648,3 +649,91 @@ def parse_toml_and_add_line_info(text: str) -> Dict[str, Any]:
         raise TOMLDecodeErrorWithSourceInfo(message, text.count("\n") + 1) from err

     raise err


+def deserialize_ast(
+    node: n.SerializedNode, node_type: Type[n._N], diagnostics: List[Diagnostic]
+) -> n._N:
+    filtered_fields: Dict[str, Any] = {}
+    node_classes: List[Type[n.Node]] = [
+        n.BlockSubstitutionReference,
+        n.Code,
+        n.Comment,
+        n.DefinitionList,
+        n.DefinitionListItem,
+        n.Directive,
+        n.DirectiveArgument,
+        n.Emphasis,
+        n.Field,
+        n.FieldList,
+        n.Footnote,
+        n.FootnoteReference,
+        n.Heading,
+        n.InlineTarget,
+        n.Label,
+        n.Line,
+        n.LineBlock,
+        n.ListNode,
+        n.ListNodeItem,
+        n.Literal,
+        n.NamedReference,
+        n.Paragraph,
+        n.Reference,
+        n.RefRole,
+        n.Role,
+        n.Root,
+        n.Section,
+        n.Strong,
+        n.SubstitutionDefinition,
+        n.SubstitutionReference,
+        n.Table,
+        n.Target,
+        n.TargetIdentifier,
+        n.Text,
+        n.TocTreeDirective,
+        n.Transition,
+    ]
+
+    def find_matching_type(node: n.SerializedNode) -> Optional[Type[n.Node]]:
+        for c in node_classes:
+            if c.type == node.get("type"):
+                return c
+        return None
+
+    for field in dataclasses.fields(node_type):
+        # "span" is not read from the serialized node; it is hardcoded as the
+        # first argument when the node is constructed below.
+        if field.name == "span":
+            continue
+
+        node_value = node.get(field.name)
+        has_nested_children = field.name == "children" and issubclass(
+            node_type, n.Parent
+        )
+        has_nested_argument = field.name == "argument" and issubclass(
+            node_type, n.Directive
+        )
+        if isinstance(node_value, List) and (
+            has_nested_children or has_nested_argument
+        ):
+            deserialized_children: List[n.Node] = []
+
+            for child in node_value:
+                if not isinstance(child, dict):
+                    continue
+
+                child_node_type = find_matching_type(child)
+                if child_node_type:
+                    deserialized_children.append(
+                        deserialize_ast(child, child_node_type, diagnostics)
+                    )
+                else:
+                    diagnostics.append(UnexpectedNodeType(child.get("type"), None, 0))
+
+            filtered_fields[field.name] = deserialized_children
+        else:
+            # Ideally we would also validate that the JSON value's type matches the
+            # dataclass field's type, but that requires a more verbose and costly
+            # check. For now, we assume the value types are correct.
+            filtered_fields[field.name] = node_value
+
+    return node_type((0,), **filtered_fields)
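A minimal usage sketch, assuming n.Text's only non-span field is value (consistent with the serialized "text" nodes in the tests above):

    diagnostics: List[Diagnostic] = []
    node = deserialize_ast({"type": "text", "value": "hi"}, n.Text, diagnostics)
    assert isinstance(node, n.Text) and node.value == "hi"
    assert not diagnostics  # no unexpected node types encountered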