From 96826a3de6ce87dac04424e88e6ab1d1eb44f5b3 Mon Sep 17 00:00:00 2001 From: rayangler <27821750+rayangler@users.noreply.github.com> Date: Tue, 10 Dec 2024 15:44:14 -0500 Subject: [PATCH 01/13] Deserialize ASTs --- snooty/main.py | 3 +- snooty/n.py | 72 ++++++++++++++++++++++++++++++++++++++++++- snooty/parser.py | 22 +++++++++++++ snooty/postprocess.py | 8 +++++ snooty/util.py | 57 +++++++++++++++++++++++++++++++++- 5 files changed, 159 insertions(+), 3 deletions(-) diff --git a/snooty/main.py b/snooty/main.py index 5a00431e..bf80e287 100644 --- a/snooty/main.py +++ b/snooty/main.py @@ -184,7 +184,8 @@ def handle_document( fully_qualified_pageid: str, document: Dict[str, Any], ) -> None: - if page_id.suffix != EXT_FOR_PAGE: + # if page_id.suffix != EXT_FOR_PAGE: + if page_id.suffix not in [".txt", ".ast"]: return super().handle_document( build_identifiers, page_id, fully_qualified_pageid, document diff --git a/snooty/n.py b/snooty/n.py index 734fd355..112be356 100644 --- a/snooty/n.py +++ b/snooty/n.py @@ -22,6 +22,8 @@ Union, ) +from typing_extensions import Self + __all__ = ( "Node", "InlineNode", @@ -60,7 +62,7 @@ class FileId(PurePosixPath): """An unambiguous file path relative to the local project's root.""" - PAT_FILE_EXTENSIONS = re.compile(r"\.((txt)|(rst)|(yaml))$") + PAT_FILE_EXTENSIONS = re.compile(r"\.((txt)|(rst)|(yaml)|(ast))$") def collapse_dots(self) -> "FileId": result: List[str] = [] @@ -135,6 +137,74 @@ def serialize(self) -> SerializedNode: del result["span"] return result + @classmethod + def deserialize(cls, node: Dict[str, SerializableType]) -> Self: + fields = [field.name for field in dataclasses.fields(cls)] + filtered_fields = {k: node.get(k) for k in fields if k in fields} + + if not filtered_fields["span"]: + filtered_fields["span"] = (0,) + + node_classes: List[Type[Node]] = [ + Code, + Comment, + Label, + Section, + Paragraph, + Footnote, + FootnoteReference, + SubstitutionDefinition, + SubstitutionReference, + BlockSubstitutionReference, + Root, + Heading, + DefinitionListItem, + DefinitionList, + ListNodeItem, + ListNode, + Line, + LineBlock, + Directive, + TocTreeDirective, + DirectiveArgument, + Target, + TargetIdentifier, + InlineTarget, + NamedReference, + Role, + RefRole, + Text, + Literal, + Emphasis, + Field, + FieldList, + Strong, + Transition, + Table, + ] + + def find_matching_type( + node: Dict[str, SerializableType] + ) -> Optional[Type[Node]]: + for c in node_classes: + if c.type == node["type"]: + return c + return None + + deserialized_children = [] + if "children" in filtered_fields and isinstance( + filtered_fields["children"], List + ): + for child in filtered_fields["children"]: + node_type = find_matching_type(child) + if node_type: + deserialized_children.append(node_type.deserialize(child)) + + # if "children" in filtered_fields: + filtered_fields["children"] = deserialized_children + + return cls(**filtered_fields) + def get_text(self) -> str: """Return pure textual content from a given AST node. Most nodes will return an empty string.""" return "" diff --git a/snooty/parser.py b/snooty/parser.py index 89d69476..10fe77f1 100644 --- a/snooty/parser.py +++ b/snooty/parser.py @@ -1843,6 +1843,28 @@ def build( fileids = (self.config.get_fileid(path) for path in paths) self.parse_rst_files(fileids, max_workers) + # Handle custom AST from API reference docs + with util.PerformanceLogger.singleton().start("parse pre-existing AST"): + ast_pages = util.get_files( + self.config.source_path, + {".ast"}, + self.config.root, + nested_projects_diagnostics, + ) + + for path in ast_pages: + fileid = self.config.get_fileid(path) + text, _ = self.config.read(fileid) + ast = json.loads(text) + util.deserialize_ast(ast) + new_page = Page.create( + fileid, + fileid.as_posix().replace(".ast", ".txt"), + "", + util.deserialize_ast(ast), + ) + self._page_updated(new_page, []) + for nested_path, diagnostics in nested_projects_diagnostics.items(): with self._backend_lock: self.on_diagnostics(nested_path, diagnostics) diff --git a/snooty/postprocess.py b/snooty/postprocess.py index 7bedfb9e..09ccde09 100644 --- a/snooty/postprocess.py +++ b/snooty/postprocess.py @@ -1,5 +1,6 @@ import collections import errno +import json import logging import os.path import sys @@ -2228,6 +2229,13 @@ def run( return PostprocessorResult({}, {}, {}, self.targets) self.pages = pages + # print(self.pages) + + # For debugging purposes + # print("Foofoo") + # data = self.pages.get(FileId("api/example.txt")).ast.serialize() + # print(json.dumps(data, indent=2)) + self.cancellation_token = cancellation_token context = Context(pages) context.add(self.project_config) diff --git a/snooty/util.py b/snooty/util.py index a55a1449..d6343bed 100644 --- a/snooty/util.py +++ b/snooty/util.py @@ -49,7 +49,7 @@ import tomli from snooty.diagnostics import Diagnostic, NestedProject -from snooty.n import FileId +from snooty.n import FileId, SerializableType from . import n, tinydocutils @@ -620,6 +620,61 @@ def structural_hash(obj: object) -> bytes: return hasher.digest() +def deserialize_ast(node: SerializableType) -> n.Node | None: + if not isinstance(node, dict): + return None + + node_classes = [ + n.Code, + n.Comment, + n.Label, + n.Section, + n.Paragraph, + n.Footnote, + n.FootnoteReference, + n.SubstitutionDefinition, + n.SubstitutionReference, + n.BlockSubstitutionReference, + n.Root, + n.Heading, + n.DefinitionListItem, + n.DefinitionList, + n.ListNodeItem, + n.ListNode, + n.Line, + n.LineBlock, + n.Directive, + n.TocTreeDirective, + n.DirectiveArgument, + n.Target, + n.TargetIdentifier, + n.InlineTarget, + n.NamedReference, + n.Role, + n.RefRole, + n.Text, + n.Literal, + n.Emphasis, + n.Field, + n.FieldList, + n.Strong, + n.Transition, + n.Table, + ] + + def find_matching_type(): + for c in node_classes: + if c.type == node["type"]: + return c + return None + + node_type = find_matching_type() + if node_type: + return node_type.deserialize(node) + + return None + + class TOMLDecodeErrorWithSourceInfo(tomli.TOMLDecodeError): def __init__(self, message: str, lineno: int) -> None: super().__init__(message) From a79e771680483cccc7abd398e4a87ba40fdf66a6 Mon Sep 17 00:00:00 2001 From: rayangler <27821750+rayangler@users.noreply.github.com> Date: Tue, 10 Dec 2024 17:29:26 -0500 Subject: [PATCH 02/13] Fix typing --- snooty/n.py | 22 +++++++++++++--------- snooty/parser.py | 14 +++++++++++--- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/snooty/n.py b/snooty/n.py index 112be356..ca21c3e7 100644 --- a/snooty/n.py +++ b/snooty/n.py @@ -1,6 +1,6 @@ import dataclasses import re -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime from enum import Enum from pathlib import PurePosixPath @@ -140,10 +140,10 @@ def serialize(self) -> SerializedNode: @classmethod def deserialize(cls, node: Dict[str, SerializableType]) -> Self: fields = [field.name for field in dataclasses.fields(cls)] - filtered_fields = {k: node.get(k) for k in fields if k in fields} + filtered_fields = {k: node.get(k) for k in fields} - if not filtered_fields["span"]: - filtered_fields["span"] = (0,) + if "span" in filtered_fields: + del filtered_fields["span"] node_classes: List[Type[Node]] = [ Code, @@ -191,19 +191,23 @@ def find_matching_type( return c return None + # "span" is expected to be passed in first + deserialized_node = cls((0,), **filtered_fields) deserialized_children = [] - if "children" in filtered_fields and isinstance( - filtered_fields["children"], List + + if ( + "children" in filtered_fields + and isinstance(filtered_fields["children"], List) + and isinstance(deserialized_node, Parent) ): for child in filtered_fields["children"]: node_type = find_matching_type(child) if node_type: deserialized_children.append(node_type.deserialize(child)) - # if "children" in filtered_fields: - filtered_fields["children"] = deserialized_children + deserialized_node.children = deserialized_children - return cls(**filtered_fields) + return deserialized_node def get_text(self) -> str: """Return pure textual content from a given AST node. Most nodes will return an empty string.""" diff --git a/snooty/parser.py b/snooty/parser.py index 10fe77f1..8b5d1e3d 100644 --- a/snooty/parser.py +++ b/snooty/parser.py @@ -1855,13 +1855,21 @@ def build( for path in ast_pages: fileid = self.config.get_fileid(path) text, _ = self.config.read(fileid) - ast = json.loads(text) - util.deserialize_ast(ast) + ast_json = json.loads(text) + + if not ( + isinstance(ast_json, Dict) + and ast_json.get("type", "") == n.Root.type + ): + # TODO-5237: Add diagnostic + continue + + ast_root = n.Root.deserialize(ast_json) new_page = Page.create( fileid, fileid.as_posix().replace(".ast", ".txt"), "", - util.deserialize_ast(ast), + ast_root, ) self._page_updated(new_page, []) From 94f361eeab5c19a54bb8050ae5489f7a0499c719 Mon Sep 17 00:00:00 2001 From: rayangler <27821750+rayangler@users.noreply.github.com> Date: Wed, 11 Dec 2024 17:10:05 -0500 Subject: [PATCH 03/13] Remove unused code --- snooty/n.py | 2 +- snooty/postprocess.py | 1 - snooty/util.py | 57 +------------------------------------------ 3 files changed, 2 insertions(+), 58 deletions(-) diff --git a/snooty/n.py b/snooty/n.py index ca21c3e7..5036429c 100644 --- a/snooty/n.py +++ b/snooty/n.py @@ -1,6 +1,6 @@ import dataclasses import re -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import datetime from enum import Enum from pathlib import PurePosixPath diff --git a/snooty/postprocess.py b/snooty/postprocess.py index 09ccde09..8d17daf1 100644 --- a/snooty/postprocess.py +++ b/snooty/postprocess.py @@ -1,6 +1,5 @@ import collections import errno -import json import logging import os.path import sys diff --git a/snooty/util.py b/snooty/util.py index d6343bed..a55a1449 100644 --- a/snooty/util.py +++ b/snooty/util.py @@ -49,7 +49,7 @@ import tomli from snooty.diagnostics import Diagnostic, NestedProject -from snooty.n import FileId, SerializableType +from snooty.n import FileId from . import n, tinydocutils @@ -620,61 +620,6 @@ def structural_hash(obj: object) -> bytes: return hasher.digest() -def deserialize_ast(node: SerializableType) -> n.Node | None: - if not isinstance(node, dict): - return None - - node_classes = [ - n.Code, - n.Comment, - n.Label, - n.Section, - n.Paragraph, - n.Footnote, - n.FootnoteReference, - n.SubstitutionDefinition, - n.SubstitutionReference, - n.BlockSubstitutionReference, - n.Root, - n.Heading, - n.DefinitionListItem, - n.DefinitionList, - n.ListNodeItem, - n.ListNode, - n.Line, - n.LineBlock, - n.Directive, - n.TocTreeDirective, - n.DirectiveArgument, - n.Target, - n.TargetIdentifier, - n.InlineTarget, - n.NamedReference, - n.Role, - n.RefRole, - n.Text, - n.Literal, - n.Emphasis, - n.Field, - n.FieldList, - n.Strong, - n.Transition, - n.Table, - ] - - def find_matching_type(): - for c in node_classes: - if c.type == node["type"]: - return c - return None - - node_type = find_matching_type() - if node_type: - return node_type.deserialize(node) - - return None - - class TOMLDecodeErrorWithSourceInfo(tomli.TOMLDecodeError): def __init__(self, message: str, lineno: int) -> None: super().__init__(message) From 67f8307fc497ca822a4e7976875adb0f9fc2643e Mon Sep 17 00:00:00 2001 From: rayangler <27821750+rayangler@users.noreply.github.com> Date: Wed, 11 Dec 2024 17:17:38 -0500 Subject: [PATCH 04/13] Reorder --- snooty/n.py | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/snooty/n.py b/snooty/n.py index 5036429c..3deb4e88 100644 --- a/snooty/n.py +++ b/snooty/n.py @@ -146,41 +146,41 @@ def deserialize(cls, node: Dict[str, SerializableType]) -> Self: del filtered_fields["span"] node_classes: List[Type[Node]] = [ + BlockSubstitutionReference, Code, Comment, - Label, - Section, - Paragraph, + DefinitionList, + DefinitionListItem, + Directive, + DirectiveArgument, + Emphasis, + Field, + FieldList, Footnote, FootnoteReference, - SubstitutionDefinition, - SubstitutionReference, - BlockSubstitutionReference, - Root, Heading, - DefinitionListItem, - DefinitionList, - ListNodeItem, - ListNode, + InlineTarget, + Label, Line, LineBlock, - Directive, - TocTreeDirective, - DirectiveArgument, - Target, - TargetIdentifier, - InlineTarget, + ListNode, + ListNodeItem, + Literal, NamedReference, - Role, + Paragraph, RefRole, - Text, - Literal, - Emphasis, - Field, - FieldList, + Role, + Root, + Section, Strong, - Transition, + SubstitutionDefinition, + SubstitutionReference, Table, + Target, + TargetIdentifier, + Text, + TocTreeDirective, + Transition, ] def find_matching_type( From 236250c149bdbd96aff5e93c0028edc1ae23819f Mon Sep 17 00:00:00 2001 From: rayangler <27821750+rayangler@users.noreply.github.com> Date: Wed, 11 Dec 2024 17:22:56 -0500 Subject: [PATCH 05/13] Clean up --- snooty/main.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/snooty/main.py b/snooty/main.py index bf80e287..850b36ce 100644 --- a/snooty/main.py +++ b/snooty/main.py @@ -184,8 +184,7 @@ def handle_document( fully_qualified_pageid: str, document: Dict[str, Any], ) -> None: - # if page_id.suffix != EXT_FOR_PAGE: - if page_id.suffix not in [".txt", ".ast"]: + if page_id.suffix not in [EXT_FOR_PAGE, ".ast"]: return super().handle_document( build_identifiers, page_id, fully_qualified_pageid, document From e6542edbbff9e778d8bece22ecb12b1f213cec72 Mon Sep 17 00:00:00 2001 From: rayangler <27821750+rayangler@users.noreply.github.com> Date: Wed, 11 Dec 2024 18:25:58 -0500 Subject: [PATCH 06/13] Add diagnostics --- snooty/diagnostics.py | 15 +++++++++++++ snooty/n.py | 8 ++++++- snooty/parser.py | 50 ++++++++++++++++++++++++++++--------------- snooty/test_parser.py | 4 ++++ 4 files changed, 59 insertions(+), 18 deletions(-) diff --git a/snooty/diagnostics.py b/snooty/diagnostics.py index c1d43652..219c88bf 100644 --- a/snooty/diagnostics.py +++ b/snooty/diagnostics.py @@ -149,6 +149,21 @@ def __init__( self.name = name +class UnexpectedNodeType(Diagnostic): + severity = Diagnostic.Level.error + + def __init__( + self, + found_type: Union[str, None], + expected_type: Optional[str], + start: Union[int, Tuple[int, int]], + ) -> None: + suggestion = f' Expected: "{expected_type}".' if expected_type else "" + super().__init__( + f'Found unexpected node type "{found_type}".{suggestion}', start + ) + + class UnnamedPage(Diagnostic): severity = Diagnostic.Level.error diff --git a/snooty/n.py b/snooty/n.py index 3deb4e88..cda9fbcd 100644 --- a/snooty/n.py +++ b/snooty/n.py @@ -168,6 +168,7 @@ def deserialize(cls, node: Dict[str, SerializableType]) -> Self: Literal, NamedReference, Paragraph, + Reference, RefRole, Role, Root, @@ -185,7 +186,7 @@ def deserialize(cls, node: Dict[str, SerializableType]) -> Self: def find_matching_type( node: Dict[str, SerializableType] - ) -> Optional[Type[Node]]: + ) -> Union[Type[Node], None]: for c in node_classes: if c.type == node["type"]: return c @@ -201,9 +202,14 @@ def find_matching_type( and isinstance(deserialized_node, Parent) ): for child in filtered_fields["children"]: + if not isinstance(child, dict): + continue + node_type = find_matching_type(child) if node_type: deserialized_children.append(node_type.deserialize(child)) + else: + raise NotImplementedError(child.get("type")) deserialized_node.children = deserialized_children diff --git a/snooty/parser.py b/snooty/parser.py index 8b5d1e3d..62b7f19f 100644 --- a/snooty/parser.py +++ b/snooty/parser.py @@ -77,6 +77,7 @@ TodoInfo, UnexpectedDirectiveOrder, UnexpectedIndentation, + UnexpectedNodeType, UnknownOptionId, UnknownTabID, UnknownTabset, @@ -1854,24 +1855,39 @@ def build( for path in ast_pages: fileid = self.config.get_fileid(path) - text, _ = self.config.read(fileid) - ast_json = json.loads(text) - - if not ( - isinstance(ast_json, Dict) - and ast_json.get("type", "") == n.Root.type - ): - # TODO-5237: Add diagnostic - continue + diagnostics: List[Diagnostic] = [] - ast_root = n.Root.deserialize(ast_json) - new_page = Page.create( - fileid, - fileid.as_posix().replace(".ast", ".txt"), - "", - ast_root, - ) - self._page_updated(new_page, []) + try: + text, read_diagnostics = self.config.read(fileid) + diagnostics.extend(read_diagnostics) + ast_json = json.loads(text) + is_valid_ast_root = ( + isinstance(ast_json, Dict) + and ast_json.get("type") == n.Root.type + ) + + if not is_valid_ast_root: + diagnostics.append(UnexpectedNodeType(ast_json.get("type"), "root", 0)) + + ast_root = ( + n.Root.deserialize(ast_json) if is_valid_ast_root else None + ) + new_page = Page.create( + fileid, + fileid.as_posix().replace(".ast", ".txt"), + "", + ast_root, + ) + self._page_updated(new_page, diagnostics) + except NotImplementedError as e: + if e.args: + invalid_node_type = e.args[0] + diagnostics.append(UnexpectedNodeType(invalid_node_type, None, 0)) + self.pages.set_orphan_diagnostics(fileid, diagnostics) + with self._backend_lock: + self.on_diagnostics(fileid, diagnostics) + except Exception as e: + logger.error(e) for nested_path, diagnostics in nested_projects_diagnostics.items(): with self._backend_lock: diff --git a/snooty/test_parser.py b/snooty/test_parser.py index 84e8e37a..f98d489e 100644 --- a/snooty/test_parser.py +++ b/snooty/test_parser.py @@ -4488,3 +4488,7 @@ def test_video() -> None: page.finish(diagnostics) # Diagnostic due to invalid upload-date format assert [type(x) for x in diagnostics] == [DocUtilsParseError] + + +def test_parse_ast() -> None: + pass From b915eb9010abbf04801f6081dffac7a3abcf0aea Mon Sep 17 00:00:00 2001 From: rayangler <27821750+rayangler@users.noreply.github.com> Date: Thu, 12 Dec 2024 17:16:41 -0500 Subject: [PATCH 07/13] Refactor and test --- snooty/n.py | 59 +++++++++++++------------- snooty/test_parser.py | 98 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 127 insertions(+), 30 deletions(-) diff --git a/snooty/n.py b/snooty/n.py index cda9fbcd..ea486093 100644 --- a/snooty/n.py +++ b/snooty/n.py @@ -138,13 +138,8 @@ def serialize(self) -> SerializedNode: return result @classmethod - def deserialize(cls, node: Dict[str, SerializableType]) -> Self: - fields = [field.name for field in dataclasses.fields(cls)] - filtered_fields = {k: node.get(k) for k in fields} - - if "span" in filtered_fields: - del filtered_fields["span"] - + def deserialize(cls, node: Dict[str, SerializedNode]) -> Self: + filtered_fields = {} node_classes: List[Type[Node]] = [ BlockSubstitutionReference, Code, @@ -185,35 +180,41 @@ def deserialize(cls, node: Dict[str, SerializableType]) -> Self: ] def find_matching_type( - node: Dict[str, SerializableType] - ) -> Union[Type[Node], None]: + node: SerializedNode + ) -> Optional[Type[Node]]: for c in node_classes: if c.type == node["type"]: return c return None - # "span" is expected to be passed in first - deserialized_node = cls((0,), **filtered_fields) - deserialized_children = [] - - if ( - "children" in filtered_fields - and isinstance(filtered_fields["children"], List) - and isinstance(deserialized_node, Parent) - ): - for child in filtered_fields["children"]: - if not isinstance(child, dict): - continue - - node_type = find_matching_type(child) - if node_type: - deserialized_children.append(node_type.deserialize(child)) - else: - raise NotImplementedError(child.get("type")) + for field in dataclasses.fields(cls): + # We don't need "span" to be present here since we need to hardcode it as the first argument of Node + if field.name == "span": + continue - deserialized_node.children = deserialized_children + node_value = node.get(field.name) + has_nested_children = field.name == "children" and issubclass(cls, Parent) + has_nested_argument = field.name == "argument" and issubclass(cls, Directive) + if isinstance(node_value, List) and (has_nested_children or has_nested_argument): + deserialized_children = [] + + for child in node_value: + if not isinstance(child, dict): + continue + + child_node_type = find_matching_type(child) + if child_node_type: + deserialized_children.append(child_node_type.deserialize(child)) + else: + raise NotImplementedError(child.get("type")) + + filtered_fields[field.name] = deserialized_children + else: + # Ideally, we validate that the data types of the fields match the data types of the JSON node, + # but that requires a more verbose and time-consuming process. For now, we assume data types are correct. + filtered_fields[field.name] = node_value - return deserialized_node + return cls((0,), **filtered_fields) def get_text(self) -> str: """Return pure textual content from a given AST node. Most nodes will return an empty string.""" diff --git a/snooty/test_parser.py b/snooty/test_parser.py index f98d489e..83e99f09 100644 --- a/snooty/test_parser.py +++ b/snooty/test_parser.py @@ -30,6 +30,7 @@ TabMustBeDirective, UnexpectedDirectiveOrder, UnexpectedIndentation, + UnexpectedNodeType, UnknownOptionId, UnknownTabID, UnknownTabset, @@ -4491,4 +4492,99 @@ def test_video() -> None: def test_parse_ast() -> None: - pass + with make_test( + { + Path( + "source/test.ast" + ): """ +{ + "type": "root", + "children": [ + { + "type": "section", + "children": [ + { + "type": "heading", + "children": [ + { + "type": "text", + "value": "Interface GridFSBucket" + } + ], + "id": "interface-gridfsbucket" + }, + { + "type": "paragraph", + "children": [ + { + "type": "reference", + "children": [ + { + "type": "text", + "value": "@ThreadSafe" + } + ], + "refuri": "http://mongodb.github.io/mongo-java-driver/5.2/apidocs/mongodb-driver-core/com/mongodb/annotations/ThreadSafe.html" + } + ] + }, + { + "type": "directive", + "name": "important", + "domain": "", + "argument": [ + { + "type": "text", + "value": "Important Callout Heading" + } + ], + "children": [ + { + "type": "paragraph", + "children": [ + { + "type": "text", + "value": "Important Callout Body Text" + } + ] + } + ] + } + ] + } + ], + "fileid": "test.ast" +} +""", + Path( + "source/bad-types.ast" + ): """ +{ + "type": "root", + "children": [ + { + "type": "section", + "children": [ + { + "type": "beep", + "children": [ + { + "type": "text", + "value": "Interface GridFSBucket" + } + ], + "id": "interface-gridfsbucket" + } + ] + } + ], + "fileid": "bad-types.ast" +} +""", + } + ) as result: + diagnostics = result.diagnostics[FileId("test.ast")] + assert not diagnostics + bad_types_diagnostics = result.diagnostics[FileId("bad-types.ast")] + result.pages + assert [type(d) for d in bad_types_diagnostics] == [UnexpectedNodeType] From 6b399cf719cccc97703952b15bf1cdaa2c7e7cf0 Mon Sep 17 00:00:00 2001 From: rayangler <27821750+rayangler@users.noreply.github.com> Date: Thu, 12 Dec 2024 17:17:53 -0500 Subject: [PATCH 08/13] Format --- snooty/n.py | 14 ++++++++------ snooty/parser.py | 8 ++++++-- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/snooty/n.py b/snooty/n.py index ea486093..843753d4 100644 --- a/snooty/n.py +++ b/snooty/n.py @@ -179,9 +179,7 @@ def deserialize(cls, node: Dict[str, SerializedNode]) -> Self: Transition, ] - def find_matching_type( - node: SerializedNode - ) -> Optional[Type[Node]]: + def find_matching_type(node: SerializedNode) -> Optional[Type[Node]]: for c in node_classes: if c.type == node["type"]: return c @@ -194,8 +192,12 @@ def find_matching_type( node_value = node.get(field.name) has_nested_children = field.name == "children" and issubclass(cls, Parent) - has_nested_argument = field.name == "argument" and issubclass(cls, Directive) - if isinstance(node_value, List) and (has_nested_children or has_nested_argument): + has_nested_argument = field.name == "argument" and issubclass( + cls, Directive + ) + if isinstance(node_value, List) and ( + has_nested_children or has_nested_argument + ): deserialized_children = [] for child in node_value: @@ -207,7 +209,7 @@ def find_matching_type( deserialized_children.append(child_node_type.deserialize(child)) else: raise NotImplementedError(child.get("type")) - + filtered_fields[field.name] = deserialized_children else: # Ideally, we validate that the data types of the fields match the data types of the JSON node, diff --git a/snooty/parser.py b/snooty/parser.py index 62b7f19f..330ae3c2 100644 --- a/snooty/parser.py +++ b/snooty/parser.py @@ -1867,7 +1867,9 @@ def build( ) if not is_valid_ast_root: - diagnostics.append(UnexpectedNodeType(ast_json.get("type"), "root", 0)) + diagnostics.append( + UnexpectedNodeType(ast_json.get("type"), "root", 0) + ) ast_root = ( n.Root.deserialize(ast_json) if is_valid_ast_root else None @@ -1882,7 +1884,9 @@ def build( except NotImplementedError as e: if e.args: invalid_node_type = e.args[0] - diagnostics.append(UnexpectedNodeType(invalid_node_type, None, 0)) + diagnostics.append( + UnexpectedNodeType(invalid_node_type, None, 0) + ) self.pages.set_orphan_diagnostics(fileid, diagnostics) with self._backend_lock: self.on_diagnostics(fileid, diagnostics) From d024d2a1809dae58e01979ad964265feff202d4d Mon Sep 17 00:00:00 2001 From: rayangler <27821750+rayangler@users.noreply.github.com> Date: Thu, 12 Dec 2024 18:29:26 -0500 Subject: [PATCH 09/13] Refactor to util function --- snooty/n.py | 83 ------------------------------------------- snooty/parser.py | 4 ++- snooty/util.py | 91 +++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 93 insertions(+), 85 deletions(-) diff --git a/snooty/n.py b/snooty/n.py index 843753d4..904befa2 100644 --- a/snooty/n.py +++ b/snooty/n.py @@ -22,8 +22,6 @@ Union, ) -from typing_extensions import Self - __all__ = ( "Node", "InlineNode", @@ -137,87 +135,6 @@ def serialize(self) -> SerializedNode: del result["span"] return result - @classmethod - def deserialize(cls, node: Dict[str, SerializedNode]) -> Self: - filtered_fields = {} - node_classes: List[Type[Node]] = [ - BlockSubstitutionReference, - Code, - Comment, - DefinitionList, - DefinitionListItem, - Directive, - DirectiveArgument, - Emphasis, - Field, - FieldList, - Footnote, - FootnoteReference, - Heading, - InlineTarget, - Label, - Line, - LineBlock, - ListNode, - ListNodeItem, - Literal, - NamedReference, - Paragraph, - Reference, - RefRole, - Role, - Root, - Section, - Strong, - SubstitutionDefinition, - SubstitutionReference, - Table, - Target, - TargetIdentifier, - Text, - TocTreeDirective, - Transition, - ] - - def find_matching_type(node: SerializedNode) -> Optional[Type[Node]]: - for c in node_classes: - if c.type == node["type"]: - return c - return None - - for field in dataclasses.fields(cls): - # We don't need "span" to be present here since we need to hardcode it as the first argument of Node - if field.name == "span": - continue - - node_value = node.get(field.name) - has_nested_children = field.name == "children" and issubclass(cls, Parent) - has_nested_argument = field.name == "argument" and issubclass( - cls, Directive - ) - if isinstance(node_value, List) and ( - has_nested_children or has_nested_argument - ): - deserialized_children = [] - - for child in node_value: - if not isinstance(child, dict): - continue - - child_node_type = find_matching_type(child) - if child_node_type: - deserialized_children.append(child_node_type.deserialize(child)) - else: - raise NotImplementedError(child.get("type")) - - filtered_fields[field.name] = deserialized_children - else: - # Ideally, we validate that the data types of the fields match the data types of the JSON node, - # but that requires a more verbose and time-consuming process. For now, we assume data types are correct. - filtered_fields[field.name] = node_value - - return cls((0,), **filtered_fields) - def get_text(self) -> str: """Return pure textual content from a given AST node. Most nodes will return an empty string.""" return "" diff --git a/snooty/parser.py b/snooty/parser.py index 330ae3c2..079d3cba 100644 --- a/snooty/parser.py +++ b/snooty/parser.py @@ -1872,7 +1872,9 @@ def build( ) ast_root = ( - n.Root.deserialize(ast_json) if is_valid_ast_root else None + util.deserialize_ast(ast_json, n.Root, diagnostics) + if is_valid_ast_root + else None ) new_page = Page.create( fileid, diff --git a/snooty/util.py b/snooty/util.py index a55a1449..9440373d 100644 --- a/snooty/util.py +++ b/snooty/util.py @@ -40,6 +40,7 @@ Set, TextIO, Tuple, + Type, TypeVar, Union, cast, @@ -48,7 +49,7 @@ import requests import tomli -from snooty.diagnostics import Diagnostic, NestedProject +from snooty.diagnostics import Diagnostic, NestedProject, UnexpectedNodeType from snooty.n import FileId from . import n, tinydocutils @@ -648,3 +649,91 @@ def parse_toml_and_add_line_info(text: str) -> Dict[str, Any]: raise TOMLDecodeErrorWithSourceInfo(message, text.count("\n") + 1) from err raise err + + +def deserialize_ast( + node: n.SerializedNode, node_type: Type[n._N], diagnostics: List[Diagnostic] +) -> n._N: + filtered_fields: Dict[str, Any] = {} + node_classes: List[Type[n.Node]] = [ + n.BlockSubstitutionReference, + n.Code, + n.Comment, + n.DefinitionList, + n.DefinitionListItem, + n.Directive, + n.DirectiveArgument, + n.Emphasis, + n.Field, + n.FieldList, + n.Footnote, + n.FootnoteReference, + n.Heading, + n.InlineTarget, + n.Label, + n.Line, + n.LineBlock, + n.ListNode, + n.ListNodeItem, + n.Literal, + n.NamedReference, + n.Paragraph, + n.Reference, + n.RefRole, + n.Role, + n.Root, + n.Section, + n.Strong, + n.SubstitutionDefinition, + n.SubstitutionReference, + n.Table, + n.Target, + n.TargetIdentifier, + n.Text, + n.TocTreeDirective, + n.Transition, + ] + + def find_matching_type(node: n.SerializedNode) -> Optional[Type[n.Node]]: + for c in node_classes: + if c.type == node.get("type"): + return c + return None + + for field in dataclasses.fields(node_type): + # We don't need "span" to be present here since we need to hardcode it as the first argument of Node + if field.name == "span": + continue + + node_value = node.get(field.name) + has_nested_children = field.name == "children" and issubclass( + node_type, n.Parent + ) + has_nested_argument = field.name == "argument" and issubclass( + node_type, n.Directive + ) + if isinstance(node_value, List) and ( + has_nested_children or has_nested_argument + ): + deserialized_children: List[n.Node] = [] + + for child in node_value: + if not isinstance(child, dict): + continue + + child_node_type = find_matching_type(child) + if child_node_type: + deserialized_children.append( + deserialize_ast(child, child_node_type, diagnostics) + ) + else: + diagnostics.append(UnexpectedNodeType(child.get("type"), None, 0)) + continue + + filtered_fields[field.name] = deserialized_children + else: + # Ideally, we validate that the data types of the fields match the data types of the JSON node, + # but that requires a more verbose and time-consuming process. For now, we assume data types are correct. + filtered_fields[field.name] = node_value + + return node_type((0,), **filtered_fields) From 007a04bad623ccfe9f3bc74ae6ba0d6a4ec90fc9 Mon Sep 17 00:00:00 2001 From: rayangler <27821750+rayangler@users.noreply.github.com> Date: Thu, 12 Dec 2024 18:29:48 -0500 Subject: [PATCH 10/13] Remove exception --- snooty/parser.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/snooty/parser.py b/snooty/parser.py index 079d3cba..1f589494 100644 --- a/snooty/parser.py +++ b/snooty/parser.py @@ -1883,15 +1883,6 @@ def build( ast_root, ) self._page_updated(new_page, diagnostics) - except NotImplementedError as e: - if e.args: - invalid_node_type = e.args[0] - diagnostics.append( - UnexpectedNodeType(invalid_node_type, None, 0) - ) - self.pages.set_orphan_diagnostics(fileid, diagnostics) - with self._backend_lock: - self.on_diagnostics(fileid, diagnostics) except Exception as e: logger.error(e) From b0b840ab81c589644d82da1bf2dd3a54cacea2cd Mon Sep 17 00:00:00 2001 From: rayangler <27821750+rayangler@users.noreply.github.com> Date: Fri, 13 Dec 2024 09:48:56 -0500 Subject: [PATCH 11/13] Remove prints --- snooty/postprocess.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/snooty/postprocess.py b/snooty/postprocess.py index 4bec601d..a8dcc956 100644 --- a/snooty/postprocess.py +++ b/snooty/postprocess.py @@ -2229,13 +2229,6 @@ def run( return PostprocessorResult({}, {}, {}, self.targets) self.pages = pages - # print(self.pages) - - # For debugging purposes - # print("Foofoo") - # data = self.pages.get(FileId("api/example.txt")).ast.serialize() - # print(json.dumps(data, indent=2)) - self.cancellation_token = cancellation_token context = Context(pages) context.add(self.project_config) From e80c006f8a0d043cb5f3979352c283f114a94d12 Mon Sep 17 00:00:00 2001 From: rayangler <27821750+rayangler@users.noreply.github.com> Date: Fri, 13 Dec 2024 12:16:17 -0500 Subject: [PATCH 12/13] Address feedback --- snooty/diagnostics.py | 13 +-- snooty/parser.py | 2 +- snooty/util.py | 180 +++++++++++++++++++++--------------------- 3 files changed, 101 insertions(+), 94 deletions(-) diff --git a/snooty/diagnostics.py b/snooty/diagnostics.py index 219c88bf..8b3d7ed5 100644 --- a/snooty/diagnostics.py +++ b/snooty/diagnostics.py @@ -154,14 +154,17 @@ class UnexpectedNodeType(Diagnostic): def __init__( self, - found_type: Union[str, None], + found_type: Optional[str], expected_type: Optional[str], start: Union[int, Tuple[int, int]], ) -> None: - suggestion = f' Expected: "{expected_type}".' if expected_type else "" - super().__init__( - f'Found unexpected node type "{found_type}".{suggestion}', start - ) + msg = f'Found unexpected node type "{found_type}".' + + if expected_type: + suggestion = f'Expected: "{expected_type}".' + msg += " " + suggestion + + super().__init__(msg, start) class UnnamedPage(Diagnostic): diff --git a/snooty/parser.py b/snooty/parser.py index 1f589494..aa165595 100644 --- a/snooty/parser.py +++ b/snooty/parser.py @@ -1872,7 +1872,7 @@ def build( ) ast_root = ( - util.deserialize_ast(ast_json, n.Root, diagnostics) + util.NodeDeserializer.deserialize(ast_json, n.Root, diagnostics) if is_valid_ast_root else None ) diff --git a/snooty/util.py b/snooty/util.py index 9440373d..76c7f4e6 100644 --- a/snooty/util.py +++ b/snooty/util.py @@ -455,6 +455,98 @@ def cancel(self) -> None: self.__cancel.clear() +class NodeDeserializer: + node_types: List[Type[n.Node]] = [ + n.BlockSubstitutionReference, + n.Code, + n.Comment, + n.DefinitionList, + n.DefinitionListItem, + n.Directive, + n.DirectiveArgument, + n.Emphasis, + n.Field, + n.FieldList, + n.Footnote, + n.FootnoteReference, + n.Heading, + n.InlineTarget, + n.Label, + n.Line, + n.LineBlock, + n.ListNode, + n.ListNodeItem, + n.Literal, + n.NamedReference, + n.Paragraph, + n.Reference, + n.RefRole, + n.Role, + n.Root, + n.Section, + n.Strong, + n.SubstitutionDefinition, + n.SubstitutionReference, + n.Table, + n.Target, + n.TargetIdentifier, + n.Text, + n.TocTreeDirective, + n.Transition, + ] + node_classes: Dict[str, Type[n.Node]] = { + node_class.type: node_class for node_class in node_types + } + + @classmethod + def deserialize( + cls, + node: n.SerializedNode, + node_type: Type[n._N], + diagnostics: List[Diagnostic], + ) -> n._N: + filtered_fields: Dict[str, Any] = {} + + for field in dataclasses.fields(node_type): + # We don't need "span" to be present here since we need to hardcode it as the first argument of Node + if field.name == "span": + continue + + node_value = node.get(field.name) + has_nested_children = field.name == "children" and issubclass( + node_type, n.Parent + ) + has_nested_argument = field.name == "argument" and issubclass( + node_type, n.Directive + ) + if isinstance(node_value, List) and ( + has_nested_children or has_nested_argument + ): + deserialized_children: List[n.Node] = [] + + for child in node_value: + if not isinstance(child, dict): + continue + + child_type: str = child.get("type", "") + child_node_type = cls.node_classes.get(child_type) + if child_node_type: + deserialized_children.append( + cls.deserialize(child, child_node_type, diagnostics) + ) + else: + diagnostics.append(UnexpectedNodeType(child_type, None, 0)) + continue + + filtered_fields[field.name] = deserialized_children + else: + # Ideally, we validate that the data types of the fields match the data types of the JSON node, + # but that requires a more verbose and time-consuming process. For now, we assume data types are correct. + filtered_fields[field.name] = node_value + + return node_type((0,), **filtered_fields) + + def bundle( filename: PurePath, members: Iterable[Tuple[str, Union[str, bytes]]] ) -> bytes: @@ -649,91 +741,3 @@ def parse_toml_and_add_line_info(text: str) -> Dict[str, Any]: raise TOMLDecodeErrorWithSourceInfo(message, text.count("\n") + 1) from err raise err - - -def deserialize_ast( - node: n.SerializedNode, node_type: Type[n._N], diagnostics: List[Diagnostic] -) -> n._N: - filtered_fields: Dict[str, Any] = {} - node_classes: List[Type[n.Node]] = [ - n.BlockSubstitutionReference, - n.Code, - n.Comment, - n.DefinitionList, - n.DefinitionListItem, - n.Directive, - n.DirectiveArgument, - n.Emphasis, - n.Field, - n.FieldList, - n.Footnote, - n.FootnoteReference, - n.Heading, - n.InlineTarget, - n.Label, - n.Line, - n.LineBlock, - n.ListNode, - n.ListNodeItem, - n.Literal, - n.NamedReference, - n.Paragraph, - n.Reference, - n.RefRole, - n.Role, - n.Root, - n.Section, - n.Strong, - n.SubstitutionDefinition, - n.SubstitutionReference, - n.Table, - n.Target, - n.TargetIdentifier, - n.Text, - n.TocTreeDirective, - n.Transition, - ] - - def find_matching_type(node: n.SerializedNode) -> Optional[Type[n.Node]]: - for c in node_classes: - if c.type == node.get("type"): - return c - return None - - for field in dataclasses.fields(node_type): - # We don't need "span" to be present here since we need to hardcode it as the first argument of Node - if field.name == "span": - continue - - node_value = node.get(field.name) - has_nested_children = field.name == "children" and issubclass( - node_type, n.Parent - ) - has_nested_argument = field.name == "argument" and issubclass( - node_type, n.Directive - ) - if isinstance(node_value, List) and ( - has_nested_children or has_nested_argument - ): - deserialized_children: List[n.Node] = [] - - for child in node_value: - if not isinstance(child, dict): - continue - - child_node_type = find_matching_type(child) - if child_node_type: - deserialized_children.append( - deserialize_ast(child, child_node_type, diagnostics) - ) - else: - diagnostics.append(UnexpectedNodeType(child.get("type"), None, 0)) - continue - - filtered_fields[field.name] = deserialized_children - else: - # Ideally, we validate that the data types of the fields match the data types of the JSON node, - # but that requires a more verbose and time-consuming process. For now, we assume data types are correct. - filtered_fields[field.name] = node_value - - return node_type((0,), **filtered_fields) From 213394b0814a00e182693a369e626a39362309b3 Mon Sep 17 00:00:00 2001 From: rayangler <27821750+rayangler@users.noreply.github.com> Date: Mon, 16 Dec 2024 10:05:40 -0500 Subject: [PATCH 13/13] Whoops --- snooty/test_parser.py | 1 - 1 file changed, 1 deletion(-) diff --git a/snooty/test_parser.py b/snooty/test_parser.py index 83e99f09..83b11754 100644 --- a/snooty/test_parser.py +++ b/snooty/test_parser.py @@ -4586,5 +4586,4 @@ def test_parse_ast() -> None: diagnostics = result.diagnostics[FileId("test.ast")] assert not diagnostics bad_types_diagnostics = result.diagnostics[FileId("bad-types.ast")] - result.pages assert [type(d) for d in bad_types_diagnostics] == [UnexpectedNodeType]