Skip to content

Commit

Permalink
TNode defaults to Node (not Any)
Browse files Browse the repository at this point in the history
TNode = TypeVar("TNode", bound="Node", default="Node")
  • Loading branch information
mar10 committed Nov 2, 2024
1 parent ebdb1a4 commit 8f9b701
Show file tree
Hide file tree
Showing 12 changed files with 128 additions and 98 deletions.
10 changes: 5 additions & 5 deletions nutree/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __init__(self, value=None):
DataIdType = Union[str, int]

#: Type of ``Tree(..., calc_data_id)```
CalcIdCallbackType = Callable[["Tree", Any], DataIdType]
CalcIdCallbackType = Callable[["Tree[Any]", Any], DataIdType]

#: Type of ``format(..., repr=)```
ReprArgType = Union[str, Callable[["Node"], str]]
Expand Down Expand Up @@ -227,19 +227,18 @@ class DictWrapper:
__slots__ = ("_dict",)

def __init__(self, dict_inst: dict | None = None, **values) -> None:
self._dict: dict = {}
if dict_inst is not None:
# A dictionary was passed: store a reference to that instance
if not isinstance(dict_inst, dict):
self._dict = None # type: ignore
raise TypeError("dict_inst must be a dictionary or None")
if values:
self._dict = None # type: ignore
raise ValueError("Cannot pass both dict_inst and **values")
self._dict: dict = dict_inst
self._dict = dict_inst
else:
# Single keyword arguments are passed (probably from unpacked dict):
# store them in a new dictionary
self._dict: dict = values
self._dict = values

def __repr__(self):
return f"{self.__class__.__name__}<{self._dict}>"
Expand Down Expand Up @@ -302,6 +301,7 @@ def serialize_mapper(cls, nutree_node: Node, data: dict) -> Union[None, dict]:
tree.save(file_path, mapper=DictWrapper.serialize_mapper)
"""
assert isinstance(nutree_node.data, DictWrapper)
return nutree_node.data._dict.copy()

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions nutree/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING: # Imported by type checkers, but prevent circular includes
from nutree.tree import Node, Tree
Expand Down Expand Up @@ -82,7 +82,7 @@ def diff_node_formatter(node):
return s


def diff_tree(t0: Tree, t1: Tree, *, ordered=False, reduce=False) -> Tree:
def diff_tree(t0: Tree[Any], t1: Tree[Any], *, ordered=False, reduce=False) -> Tree:
from nutree import Tree

t2 = Tree(f"diff({t0.name!r}, {t1.name!r})")
Expand Down
4 changes: 2 additions & 2 deletions nutree/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from __future__ import annotations

from pathlib import Path
from typing import IO, TYPE_CHECKING, Iterator
from typing import IO, TYPE_CHECKING, Any, Iterator

from nutree.common import MapperCallbackType, call_mapper

Expand Down Expand Up @@ -111,7 +111,7 @@ def _attr_str(attr_def: dict, mapper=None, node=None):


def tree_to_dotfile(
tree: Tree,
tree: Tree[Any],
target: IO[str] | str | Path,
*,
format=None,
Expand Down
51 changes: 34 additions & 17 deletions nutree/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from nutree.rdf import RDFMapperCallbackType, node_to_rdf

TNode = TypeVar("TNode", bound="Node", default="Node")
# TNode = TypeVar("TNode", bound="Node", default="Node", covariant=True)


# ------------------------------------------------------------------------------
Expand Down Expand Up @@ -410,7 +411,7 @@ def get_siblings(self, *, add_self=False) -> list[Self]:

def first_sibling(self) -> Self:
"""Return first sibling (may be self)."""
return self._parent._children[0] # type: ignore[reportOptionalSubscript]
return self._parent._children[0] # type: ignore

def prev_sibling(self) -> Self | None:
"""Predecessor or None, if node is first sibling."""
Expand Down Expand Up @@ -629,16 +630,18 @@ def add_child(
self.add_child(n, before=before, deep=deep)
return cast(Self, n) # need to return a node

source_node = None
source_node: Self = None # type: ignore
new_node: Self = None # type: ignore

if isinstance(child, Node):
assert isinstance(child, self.tree.node_factory)
# Adding an existing node means that we create a clone
if deep is None:
deep = False
if deep and data_id is not None or node_id is not None:
raise ValueError("Cannot set ID for deep copies.")

source_node = child
source_node = cast(Self, child)
if source_node._tree is self._tree:
if source_node._parent is self:
raise UniqueConstraintError(
Expand All @@ -654,12 +657,18 @@ def add_child(
# If creating an inherited node, use the parent class as constructor
# child_class = child.__class__

node = self.tree.node_factory(
source_node.data, parent=self, data_id=data_id, node_id=node_id
new_node = cast(
Self,
self.tree.node_factory(
source_node.data, parent=self, data_id=data_id, node_id=node_id
),
)
else:
node = self.tree.node_factory(
child, parent=self, data_id=data_id, node_id=node_id
new_node = cast(
Self,
self.tree.node_factory(
child, parent=self, data_id=data_id, node_id=node_id
),
)

if before is True:
Expand All @@ -668,24 +677,24 @@ def add_child(
children = self._children
if children is None:
assert before in (None, True, int, False)
self._children = [node]
self._children = [new_node]
elif isinstance(before, int):
children.insert(before, node)
children.insert(before, new_node)
elif before:
if before._parent is not self:
raise ValueError(
f"`before=node` ({before._parent}) "
f"must be a child of target node ({self})"
)
idx = children.index(before) # raises ValueError
children.insert(idx, node)
children.insert(idx, new_node)
else:
children.append(node)
children.append(new_node)

if deep and source_node:
node._add_from(source_node)
new_node._add_from(source_node)

return node
return new_node

#: Alias for :meth:`add_child`
add = add_child
Expand Down Expand Up @@ -854,7 +863,7 @@ def copy(

def copy_to(
self,
target: Self | Tree,
target: Self | Tree[Self],
*,
add_self=True,
before: Self | bool | int | None = None,
Expand All @@ -873,7 +882,9 @@ def copy_to(
If `deep` is set, all descendants are copied recursively.
"""
if add_self:
return target.add_child(self, before=before, deep=deep)
res = target.add_child(self, before=before, deep=deep)
return cast(Self, res) # if target is Tree, type is not inferred?
# return target.add_child(self, before=before, deep=deep)
assert before is None
if not self._children:
raise ValueError("Need child nodes when `add_self=False`")
Expand Down Expand Up @@ -968,7 +979,7 @@ def filtered(self, predicate: PredicateCallbackType) -> Tree[Self]:
See also :ref:`iteration-callbacks`.
"""
if not predicate:
if not predicate: # mypy: ignore
raise ValueError("Predicate is required (use copy() instead)")
return self.copy(add_self=True, predicate=predicate)

Expand Down Expand Up @@ -1143,6 +1154,7 @@ def visit(
return
except StopTraversal as e:
return e.value
return

def _iter_pre(self) -> Iterator[Self]:
"""Depth-first, pre-order traversal."""
Expand Down Expand Up @@ -1397,7 +1409,12 @@ def format_iter(
yield from self._render_lines(repr=repr, style=style, add_self=add_self)

def format(
self, *, repr: ReprArgType | None = None, style=None, add_self=True, join="\n"
self,
*,
repr: ReprArgType | None = None,
style=None,
add_self=True,
join: str = "\n",
) -> str:
r"""Return a pretty string representation of the node hierarchy.
Expand Down
4 changes: 2 additions & 2 deletions nutree/rdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Union
from typing import TYPE_CHECKING, Any, Callable, Union

from nutree.common import IterationControl

Expand Down Expand Up @@ -163,7 +163,7 @@ def node_to_rdf(


def tree_to_rdf(
tree: Tree,
tree: Tree[Any],
*,
node_mapper: RDFMapperCallbackType | None = None,
) -> Graph:
Expand Down
45 changes: 23 additions & 22 deletions nutree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,22 @@
check_python_version(MIN_PYTHON_VERSION_INFO)


# ------------------------------------------------------------------------------
# - _SystemRootNode
# ------------------------------------------------------------------------------
class _SystemRootNode(Node):
"""Invisible system root node."""

def __init__(self, tree: Tree) -> None:
self._tree: Tree = tree # type: ignore
self._parent = None # type: ignore
self._node_id = ROOT_NODE_ID
self._data_id = ROOT_DATA_ID
self._data = tree.name
self._children = []
self._meta = None


# ------------------------------------------------------------------------------
# - Tree
# ------------------------------------------------------------------------------
Expand All @@ -89,14 +105,15 @@ class Tree(Generic[TNode]):
**Note:** Use with care, see also :ref:`forward-attributes`.
"""

node_factory: Type[TNode] = cast(Type[TNode], Node)
node_factory: Type[Node] = Node
root_node_factory = _SystemRootNode

#: Default connector prefixes ``format(style=...)`` argument.
DEFAULT_CONNECTOR_STYLE = "round43"
#: Default value for ``save(..., key_map=...)`` argument.
DEFAULT_KEY_MAP = {"data_id": "i", "str": "s"}
DEFAULT_KEY_MAP: dict[str, str] = {"data_id": "i", "str": "s"}
#: Default value for ``save(..., value_map=...)`` argument.
DEFAULT_VALUE_MAP = {}
DEFAULT_VALUE_MAP: dict[str, list[str]] = {}
# #: Default value for ``save(..., mapper=...)`` argument.
# DEFAULT_SERIALZATION_MAPPER = None
# #: Default value for ``load(..., mapper=...)`` argument.
Expand All @@ -112,7 +129,7 @@ def __init__(
self._lock = threading.RLock()
#: Tree name used for logging
self.name: str = str(id(self) if name is None else name)
self._root: TNode = cast(TNode, _SystemRootNode(self))
self._root: TNode = self.root_node_factory(self) # type: ignore
self._node_by_id: dict[int, TNode] = {}
self._nodes_by_data_id: dict[DataIdType, list[TNode]] = {}
# Optional callback that calculates data_ids from data objects
Expand Down Expand Up @@ -904,7 +921,7 @@ def _self_check(self) -> Literal[True]:
return True

@classmethod
def build_random_tree(cls: type[Self], structure_def: dict) -> Self:
def build_random_tree(cls, structure_def: dict) -> Self:
"""Build a random tree for .
Returns a new :class:`Tree` instance with random nodes, as defined by
Expand All @@ -917,20 +934,4 @@ def build_random_tree(cls: type[Self], structure_def: dict) -> Self:
from nutree.tree_generator import build_random_tree

tt = build_random_tree(tree_class=cls, structure_def=structure_def)
return tt


# ------------------------------------------------------------------------------
# - _SystemRootNode
# ------------------------------------------------------------------------------
class _SystemRootNode(Node):
"""Invisible system root node."""

def __init__(self, tree: Tree) -> None:
self._tree: Tree = tree # type: ignore
self._parent = None # type: ignore
self._node_id = ROOT_NODE_ID
self._data_id = ROOT_DATA_ID
self._data = tree.name
self._children = []
self._meta = None
return cast(Self, tt)
9 changes: 4 additions & 5 deletions nutree/tree_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from datetime import date, datetime, timedelta, timezone
from typing import Any, Sequence, Union

from typing_extensions import TypeVar

# from typing_extensions import TypeVar
from nutree.common import DictWrapper
from nutree.node import Node
from nutree.tree import Tree
Expand All @@ -28,7 +27,7 @@
Fabulist = None
fab = None

TTree = TypeVar("TTree", bound=Tree)
# TTree = TypeVar("TTree", bound=Tree)


# ------------------------------------------------------------------------------
Expand Down Expand Up @@ -392,7 +391,7 @@ def _make_tree(
return


def build_random_tree(*, tree_class: type[TTree], structure_def: dict) -> TTree:
def build_random_tree(*, tree_class: type[Tree[Any]], structure_def: dict) -> Tree:
"""
Return a nutree.TypedTree with random data from a specification.
See :ref:`randomize` for details.
Expand All @@ -405,7 +404,7 @@ def build_random_tree(*, tree_class: type[TTree], structure_def: dict) -> TTree:
assert not structure_def, f"found extra data: {structure_def}"
assert "__root__" in relations, "missing '__root__' relation"

tree: TTree = tree_class(
tree: Tree = tree_class(
name=name,
forward_attrs=True,
)
Expand Down
Loading

0 comments on commit 8f9b701

Please sign in to comment.