From a8a0a8840ab562c203afc759e9a490a1998e5611 Mon Sep 17 00:00:00 2001 From: Quek Ching Yee Date: Sun, 25 Aug 2024 16:39:05 +0800 Subject: [PATCH 1/3] feat: plot tree WIP + standardize kwargs param --- bigtree/node/basenode.py | 20 ++++++++++++++++++++ bigtree/tree/export.py | 14 ++++++-------- bigtree/utils/exceptions.py | 22 ++++++++++++++++++++++ bigtree/utils/plot.py | 37 ++++++++++++++++++++++++++++++++++++- pyproject.toml | 2 ++ tests/tree/test_export.py | 4 ++-- 6 files changed, 88 insertions(+), 11 deletions(-) diff --git a/bigtree/node/basenode.py b/bigtree/node/basenode.py index f1969060..e8baf2d3 100644 --- a/bigtree/node/basenode.py +++ b/bigtree/node/basenode.py @@ -8,6 +8,11 @@ from bigtree.utils.exceptions import CorruptedTreeError, LoopError, TreeError from bigtree.utils.iterators import preorder_iter +try: + import matplotlib.pyplot as plt +except ImportError: # pragma: no cover + plt = None + class BaseNode: """ @@ -115,6 +120,7 @@ class BaseNode: 6. ``extend(nodes: List[Self])``: Add multiple children to node 7. ``copy()``: Deep copy self 8. ``sort()``: Sort child nodes + 9. ``plot()``: Plot tree in line form ---- @@ -727,6 +733,7 @@ def copy(self: T) -> T: def sort(self: T, **kwargs: Any) -> None: """Sort children, possible keyword arguments include ``key=lambda node: node.name``, ``reverse=True`` + Accepts kwargs for sort() function. Examples: >>> from bigtree import Node, print_tree @@ -747,6 +754,19 @@ def sort(self: T, **kwargs: Any) -> None: children.sort(**kwargs) self.__children = children + def plot(self, save_path: str = "", *args: Any, **kwargs: Any) -> "plt.Figure": + """Plot tree in line form. + Accepts args and kwargs for matplotlib.pyplot.plot() function. + + Args: + save_path (str): save path of plot + """ + from bigtree.utils.plot import plot_tree, reingold_tilford + + if not self.get_attr("x") or self.get_attr("y"): + reingold_tilford(self) + return plot_tree(self, save_path, *args, **kwargs) + def __copy__(self: T) -> T: """Shallow copy self diff --git a/bigtree/tree/export.py b/bigtree/tree/export.py index 06bdc6ad..a08d3b13 100644 --- a/bigtree/tree/export.py +++ b/bigtree/tree/export.py @@ -73,9 +73,10 @@ def print_tree( attr_omit_null: bool = False, attr_bracket: List[str] = ["[", "]"], style: Union[str, Iterable[str], BasePrintStyle] = "const", - **print_kwargs: Any, + **kwargs: Any, ) -> None: """Print tree to console, starting from `tree`. + Accepts kwargs for print() function. - Able to select which node to print from, resulting in a subtree, using `node_name_or_path` - Able to customize for maximum depth to print, using `max_depth` @@ -91,8 +92,6 @@ def print_tree( - (BasePrintStyle): `ANSIPrintStyle`, `ASCIIPrintStyle`, `ConstPrintStyle`, `ConstBoldPrintStyle`, `RoundedPrintStyle`, `DoublePrintStyle` style or inherit from `BasePrintStyle` - Remaining kwargs are passed without modification to python's `print` function. - Examples: **Printing tree** @@ -249,7 +248,7 @@ def print_tree( if attr_str: attr_str = f" {attr_bracket_open}{attr_str}{attr_bracket_close}" node_str = f"{_node.node_name}{attr_str}" - print(f"{pre_str}{fill_str}{node_str}", **print_kwargs) + print(f"{pre_str}{fill_str}{node_str}", **kwargs) def yield_tree( @@ -433,9 +432,10 @@ def hprint_tree( max_depth: int = 0, intermediate_node_name: bool = True, style: Union[str, Iterable[str], BaseHPrintStyle] = "const", - **print_kwargs: Any, + **kwargs: Any, ) -> None: """Print tree in horizontal orientation to console, starting from `tree`. + Accepts kwargs for print() function. - Able to select which node to print from, resulting in a subtree, using `node_name_or_path` - Able to customize for maximum depth to print, using `max_depth` @@ -448,8 +448,6 @@ def hprint_tree( - (BaseHPrintStyle): `ANSIHPrintStyle`, `ASCIIHPrintStyle`, `ConstHPrintStyle`, `ConstBoldHPrintStyle`, `RoundedHPrintStyle`, `DoubleHPrintStyle` style or inherit from BaseHPrintStyle - Remaining kwargs are passed without modification to python's `print` function. - Examples: **Printing tree** @@ -549,7 +547,7 @@ def hprint_tree( max_depth=max_depth, style=style, ) - print("\n".join(result), **print_kwargs) + print("\n".join(result), **kwargs) def hyield_tree( diff --git a/bigtree/utils/exceptions.py b/bigtree/utils/exceptions.py index 9087e934..dbdb0c40 100644 --- a/bigtree/utils/exceptions.py +++ b/bigtree/utils/exceptions.py @@ -114,6 +114,28 @@ def wrapper(*args: Any, **kwargs: Any) -> T: return wrapper +def optional_dependencies_matplotlib( + func: Callable[..., T] +) -> Callable[..., T]: # pragma: no cover + """ + This is a decorator which can be used to import optional matplotlib dependency. + It will raise a ImportError if the module is not found. + """ + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> T: + try: + import matplotlib.pyplot as plt # noqa: F401 + except ImportError: + raise ImportError( + "matplotlib not available. Please perform a\n\n" + "pip install 'bigtree[matplotlib]'\n\nto install required dependencies" + ) from None + return func(*args, **kwargs) + + return wrapper + + def optional_dependencies_image( package_name: str = "", ) -> Callable[[Callable[..., T]], Callable[..., T]]: diff --git a/bigtree/utils/plot.py b/bigtree/utils/plot.py index 0cc42fb9..8c4b3065 100644 --- a/bigtree/utils/plot.py +++ b/bigtree/utils/plot.py @@ -1,9 +1,18 @@ -from typing import Optional, TypeVar +from typing import Any, Optional, TypeVar from bigtree.node.basenode import BaseNode +from bigtree.utils.exceptions import optional_dependencies_matplotlib +from bigtree.utils.iterators import preorder_iter + +try: + import matplotlib.pyplot as plt +except ImportError: # pragma: no cover + plt = None + __all__ = [ "reingold_tilford", + "plot_tree", ] T = TypeVar("T", bound=BaseNode) @@ -73,6 +82,32 @@ def reingold_tilford( _third_pass(tree_node, x_adjustment) +@optional_dependencies_matplotlib +def plot_tree(tree_node: T, save_path: str = "", **kwargs: Any) -> plt.Figure: + """Plot tree in line form. Tree should have `x` and `y` attribute. + Accepts args and kwargs for matplotlib.pyplot.plot() function. + + Examples: + >>> from bigtree import Node, list_to_tree, plot_tree, reingold_tilford + >>> path_list = ["a/b/d", "a/b/e/g", "a/b/e/h", "a/c/f"] + >>> root = list_to_tree(path_list) + >>> reingold_tilford(root) + >>> plot_tree(root) + + Args: + tree_node (BaseNode): tree to plot + save_path (str): save path of plot + """ + + for node in preorder_iter(tree_node): + if node.is_root: + pass + plt.plot(node.get_attr("x"), node.get_attr("y"), **kwargs) + if save_path: + plt.savefig(save_path) + return plt.figure() + + def _first_pass( tree_node: T, sibling_separation: float, subtree_separation: float ) -> None: diff --git a/pyproject.toml b/pyproject.toml index 01c2c443..db7fe60a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ Source = "https://github.com/kayjan/bigtree" [project.optional-dependencies] all = [ + "matplotlib", "pandas", "polars", "pydot", @@ -46,6 +47,7 @@ image = [ "pydot", "Pillow", ] +matplotlib = ["matplotlib"] pandas = ["pandas"] polars = ["polars"] diff --git a/tests/tree/test_export.py b/tests/tree/test_export.py index fcf56294..1b7c2592 100644 --- a/tests/tree/test_export.py +++ b/tests/tree/test_export.py @@ -402,7 +402,7 @@ def test_print_tree_custom_style_missing_style_error(tree_node): assert str(exc_info.value) == Constants.ERROR_NODE_EXPORT_PRINT_STYLE_SELECT @staticmethod - def test_print_tree_print_kwargs(tree_node): + def test_print_tree_kwargs(tree_node): output = io.StringIO() print_tree(tree_node, file=output) assert output.getvalue() == tree_node_no_attr_str @@ -792,7 +792,7 @@ def test_hprint_tree_custom_style_missing_style_error(tree_node): assert str(exc_info.value) == Constants.ERROR_NODE_EXPORT_HPRINT_STYLE_SELECT @staticmethod - def test_hprint_tree_print_kwargs(tree_node): + def test_hprint_tree_kwargs(tree_node): output = io.StringIO() hprint_tree(tree_node, file=output) assert output.getvalue() == tree_node_hstr From 886e82fd998c996ba2199cbafe1e29fa0372d3fb Mon Sep 17 00:00:00 2001 From: Quek Ching Yee Date: Sun, 25 Aug 2024 17:26:26 +0800 Subject: [PATCH 2/3] feat: plot tree in line form --- bigtree/__init__.py | 2 +- bigtree/node/basenode.py | 11 +++++++++-- bigtree/utils/plot.py | 19 +++++++++++++------ pyproject.toml | 1 + 4 files changed, 24 insertions(+), 9 deletions(-) diff --git a/bigtree/__init__.py b/bigtree/__init__.py index ff9087f4..cffa4ea7 100644 --- a/bigtree/__init__.py +++ b/bigtree/__init__.py @@ -93,6 +93,6 @@ zigzag_iter, zigzaggroup_iter, ) -from bigtree.utils.plot import reingold_tilford +from bigtree.utils.plot import plot_tree, reingold_tilford from bigtree.workflows.app_calendar import Calendar from bigtree.workflows.app_todo import AppToDo diff --git a/bigtree/node/basenode.py b/bigtree/node/basenode.py index e8baf2d3..0a1de376 100644 --- a/bigtree/node/basenode.py +++ b/bigtree/node/basenode.py @@ -754,10 +754,17 @@ def sort(self: T, **kwargs: Any) -> None: children.sort(**kwargs) self.__children = children - def plot(self, save_path: str = "", *args: Any, **kwargs: Any) -> "plt.Figure": + def plot(self, *args: Any, save_path: str = "", **kwargs: Any) -> plt.Figure: """Plot tree in line form. Accepts args and kwargs for matplotlib.pyplot.plot() function. + Examples: + >>> from bigtree import list_to_tree + >>> path_list = ["a/b/d", "a/b/e/g", "a/b/e/h", "a/c/f"] + >>> root = list_to_tree(path_list) + >>> root.plot("-ok", save_path="tree.png") +
+ Args: save_path (str): save path of plot """ @@ -765,7 +772,7 @@ def plot(self, save_path: str = "", *args: Any, **kwargs: Any) -> "plt.Figure": if not self.get_attr("x") or self.get_attr("y"): reingold_tilford(self) - return plot_tree(self, save_path, *args, **kwargs) + return plot_tree(self, *args, save_path, **kwargs) def __copy__(self: T) -> T: """Shallow copy self diff --git a/bigtree/utils/plot.py b/bigtree/utils/plot.py index 8c4b3065..0d5447a1 100644 --- a/bigtree/utils/plot.py +++ b/bigtree/utils/plot.py @@ -83,16 +83,19 @@ def reingold_tilford( @optional_dependencies_matplotlib -def plot_tree(tree_node: T, save_path: str = "", **kwargs: Any) -> plt.Figure: +def plot_tree( + tree_node: T, *args: Any, save_path: str = "", **kwargs: Any +) -> plt.Figure: """Plot tree in line form. Tree should have `x` and `y` attribute. Accepts args and kwargs for matplotlib.pyplot.plot() function. Examples: - >>> from bigtree import Node, list_to_tree, plot_tree, reingold_tilford + >>> from bigtree import list_to_tree, plot_tree, reingold_tilford >>> path_list = ["a/b/d", "a/b/e/g", "a/b/e/h", "a/c/f"] >>> root = list_to_tree(path_list) >>> reingold_tilford(root) - >>> plot_tree(root) + >>> plot_tree(root, "-ok", save_path="tree.png") +
Args: tree_node (BaseNode): tree to plot @@ -100,9 +103,13 @@ def plot_tree(tree_node: T, save_path: str = "", **kwargs: Any) -> plt.Figure: """ for node in preorder_iter(tree_node): - if node.is_root: - pass - plt.plot(node.get_attr("x"), node.get_attr("y"), **kwargs) + if not node.is_root: + plt.plot( + [node.get_attr("x"), node.parent.get_attr("x")], + [node.get_attr("y"), node.parent.get_attr("y")], + *args, + **kwargs, + ) if save_path: plt.savefig(save_path) return plt.figure() diff --git a/pyproject.toml b/pyproject.toml index db7fe60a..170d0098 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ path = "bigtree/__init__.py" dependencies = [ "black", "coverage", + "matplotlib", "mypy", "pandas", "Pillow", From caa9731f92a40792c15c64f03cc01f9e5061c4ad Mon Sep 17 00:00:00 2001 From: Kay Date: Mon, 26 Aug 2024 00:55:30 +0800 Subject: [PATCH 3/3] feat: add test cases for plot + optional dependencies imported as MagicMock --- CHANGELOG.md | 10 ++++++- README.md | 1 + bigtree/dag/construct.py | 4 ++- bigtree/dag/export.py | 8 ++++-- bigtree/node/basenode.py | 14 +++++----- bigtree/tree/construct.py | 8 ++++-- bigtree/tree/export.py | 16 ++++++++--- bigtree/utils/assertions.py | 8 ++++-- bigtree/utils/plot.py | 44 ++++++++++++++++++++----------- bigtree/workflows/app_calendar.py | 4 ++- docs/home/tree.md | 1 + tests/node/test_basenode.py | 14 ++++++++++ tests/test_constants.py | 6 +++++ tests/utils/test_plot.py | 36 ++++++++++++++++++++++++- 14 files changed, 136 insertions(+), 38 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9fa4d9bf..87a4f8d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.21.0] - TBD +### Added: +- Tree Plot: Plot tree using matplotlib library, added matplotlib as optional dependency. +- BaseNode: Add plot method. +### Changed: +- Misc: Optional dependencies imported as MagicMock + ## [0.20.1] - 2024-08-24 ### Changed: - Misc: Documentation update contributing instructions. @@ -638,7 +645,8 @@ ignore null attribute columns. - Utility Iterator: Tree traversal methods. - Workflow To Do App: Tree use case with to-do list implementation. -[Unreleased]: https://github.com/kayjan/bigtree/compare/0.20.1...HEAD +[Unreleased]: https://github.com/kayjan/bigtree/compare/0.21.0...HEAD +[0.21.0]: https://github.com/kayjan/bigtree/compare/0.20.1...0.21.0 [0.20.1]: https://github.com/kayjan/bigtree/compare/0.20.0...0.20.1 [0.20.0]: https://github.com/kayjan/bigtree/compare/0.19.4...0.20.0 [0.19.4]: https://github.com/kayjan/bigtree/compare/0.19.3...0.19.4 diff --git a/README.md b/README.md index 7b2c1407..fbd68f74 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ For **Tree** implementation, there are 9 main components. 4. Get difference between two trees 7. [**📊 Plotting Tree**](https://bigtree.readthedocs.io/en/stable/bigtree/utils/plot/) 1. Enhanced Reingold Tilford Algorithm to retrieve (x, y) coordinates for a tree structure + 2. Plot tree using matplotlib (optional dependency) 8. [**🔨 Exporting Tree**](https://bigtree.readthedocs.io/en/stable/bigtree/tree/export/) 1. Print to console, in vertical or horizontal orientation 2. Export to *Newick string notation*, *dictionary*, *nested dictionary*, *pandas DataFrame*, or *polars DataFrame* diff --git a/bigtree/dag/construct.py b/bigtree/dag/construct.py index 6a0f8952..060fc089 100644 --- a/bigtree/dag/construct.py +++ b/bigtree/dag/construct.py @@ -16,7 +16,9 @@ try: import pandas as pd except ImportError: # pragma: no cover - pd = None + from unittest.mock import MagicMock + + pd = MagicMock() __all__ = ["list_to_dag", "dict_to_dag", "dataframe_to_dag"] diff --git a/bigtree/dag/export.py b/bigtree/dag/export.py index e9ce9b81..5c6b8e0a 100644 --- a/bigtree/dag/export.py +++ b/bigtree/dag/export.py @@ -13,12 +13,16 @@ try: import pandas as pd except ImportError: # pragma: no cover - pd = None + from unittest.mock import MagicMock + + pd = MagicMock() try: import pydot except ImportError: # pragma: no cover - pydot = None + from unittest.mock import MagicMock + + pydot = MagicMock() __all__ = ["dag_to_list", "dag_to_dict", "dag_to_dataframe", "dag_to_dot"] diff --git a/bigtree/node/basenode.py b/bigtree/node/basenode.py index 0a1de376..23977e2f 100644 --- a/bigtree/node/basenode.py +++ b/bigtree/node/basenode.py @@ -754,25 +754,23 @@ def sort(self: T, **kwargs: Any) -> None: children.sort(**kwargs) self.__children = children - def plot(self, *args: Any, save_path: str = "", **kwargs: Any) -> plt.Figure: + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Plot tree in line form. Accepts args and kwargs for matplotlib.pyplot.plot() function. Examples: + >>> import matplotlib.pyplot as plt >>> from bigtree import list_to_tree >>> path_list = ["a/b/d", "a/b/e/g", "a/b/e/h", "a/c/f"] >>> root = list_to_tree(path_list) - >>> root.plot("-ok", save_path="tree.png") -
- - Args: - save_path (str): save path of plot + >>> root.plot("-ok") +
""" from bigtree.utils.plot import plot_tree, reingold_tilford - if not self.get_attr("x") or self.get_attr("y"): + if self.get_attr("x") is None or self.get_attr("y") is None: reingold_tilford(self) - return plot_tree(self, *args, save_path, **kwargs) + return plot_tree(self, *args, **kwargs) def __copy__(self: T) -> T: """Shallow copy self diff --git a/bigtree/tree/construct.py b/bigtree/tree/construct.py index 51e57f89..a27eca0a 100644 --- a/bigtree/tree/construct.py +++ b/bigtree/tree/construct.py @@ -24,12 +24,16 @@ try: import pandas as pd except ImportError: # pragma: no cover - pd = None + from unittest.mock import MagicMock + + pd = MagicMock() try: import polars as pl except ImportError: # pragma: no cover - pl = None + from unittest.mock import MagicMock + + pl = MagicMock() __all__ = [ "add_path_to_tree", diff --git a/bigtree/tree/export.py b/bigtree/tree/export.py index a08d3b13..7899ed38 100644 --- a/bigtree/tree/export.py +++ b/bigtree/tree/export.py @@ -28,22 +28,30 @@ try: import pandas as pd except ImportError: # pragma: no cover - pd = None + from unittest.mock import MagicMock + + pd = MagicMock() try: import polars as pl except ImportError: # pragma: no cover - pl = None + from unittest.mock import MagicMock + + pl = MagicMock() try: import pydot except ImportError: # pragma: no cover - pydot = None + from unittest.mock import MagicMock + + pydot = MagicMock() try: from PIL import Image, ImageDraw, ImageFont except ImportError: # pragma: no cover - Image = ImageDraw = ImageFont = None + from unittest.mock import MagicMock + + Image = ImageDraw = ImageFont = MagicMock() __all__ = [ diff --git a/bigtree/utils/assertions.py b/bigtree/utils/assertions.py index 185490d3..a6ca31f2 100644 --- a/bigtree/utils/assertions.py +++ b/bigtree/utils/assertions.py @@ -5,12 +5,16 @@ try: import pandas as pd except ImportError: # pragma: no cover - pd = None + from unittest.mock import MagicMock + + pd = MagicMock() try: import polars as pl except ImportError: # pragma: no cover - pl = None + from unittest.mock import MagicMock + + pl = MagicMock() if TYPE_CHECKING: diff --git a/bigtree/utils/plot.py b/bigtree/utils/plot.py index 0d5447a1..c50f8014 100644 --- a/bigtree/utils/plot.py +++ b/bigtree/utils/plot.py @@ -7,7 +7,9 @@ try: import matplotlib.pyplot as plt except ImportError: # pragma: no cover - plt = None + from unittest.mock import MagicMock + + plt = MagicMock() __all__ = [ @@ -84,35 +86,45 @@ def reingold_tilford( @optional_dependencies_matplotlib def plot_tree( - tree_node: T, *args: Any, save_path: str = "", **kwargs: Any + tree_node: T, *args: Any, ax: Optional[plt.Axes] = None, **kwargs: Any ) -> plt.Figure: - """Plot tree in line form. Tree should have `x` and `y` attribute. - Accepts args and kwargs for matplotlib.pyplot.plot() function. + """Plot tree in line form. Tree should have `x` and `y` attribute from Reingold Tilford. + Accepts existing matplotlib Axes. Accepts args and kwargs for matplotlib.pyplot.plot() function. Examples: + >>> import matplotlib.pyplot as plt >>> from bigtree import list_to_tree, plot_tree, reingold_tilford >>> path_list = ["a/b/d", "a/b/e/g", "a/b/e/h", "a/c/f"] >>> root = list_to_tree(path_list) >>> reingold_tilford(root) - >>> plot_tree(root, "-ok", save_path="tree.png") -
+ >>> plot_tree(root, "-ok") +
Args: tree_node (BaseNode): tree to plot - save_path (str): save path of plot + ax (plt.Axes): axes to add Figure to """ + if ax: + fig = ax.get_figure() + else: + fig = plt.figure() + ax = fig.add_subplot(111) for node in preorder_iter(tree_node): if not node.is_root: - plt.plot( - [node.get_attr("x"), node.parent.get_attr("x")], - [node.get_attr("y"), node.parent.get_attr("y")], - *args, - **kwargs, - ) - if save_path: - plt.savefig(save_path) - return plt.figure() + try: + ax.plot( + [node.x, node.parent.x], # type: ignore + [node.y, node.parent.y], # type: ignore + *args, + **kwargs, + ) + except AttributeError: + raise RuntimeError( + "No x or y coordinates detected. " + "Please run reingold_tilford algorithm to retrieve coordinates." + ) + return fig def _first_pass( diff --git a/bigtree/workflows/app_calendar.py b/bigtree/workflows/app_calendar.py index d6482c1d..4c2ae180 100644 --- a/bigtree/workflows/app_calendar.py +++ b/bigtree/workflows/app_calendar.py @@ -11,7 +11,9 @@ try: import pandas as pd except ImportError: # pragma: no cover - pd = None + from unittest.mock import MagicMock + + pydot = MagicMock() class Calendar: diff --git a/docs/home/tree.md b/docs/home/tree.md index 53ba308e..50f989fd 100644 --- a/docs/home/tree.md +++ b/docs/home/tree.md @@ -50,6 +50,7 @@ For **Tree** implementation, there are 9 main components. ## [**📊 Plotting Tree**](../bigtree/utils/plot.md) - Enhanced Reingold Tilford Algorithm to retrieve (x, y) coordinates for a tree structure +- Plot tree using matplotlib (optional dependency) ## [**🔨 Exporting Tree**](../bigtree/tree/export.md) - Print to console, in vertical or horizontal orientation diff --git a/tests/node/test_basenode.py b/tests/node/test_basenode.py index 5621377a..0e61f32d 100644 --- a/tests/node/test_basenode.py +++ b/tests/node/test_basenode.py @@ -1,6 +1,7 @@ import copy import unittest +import matplotlib.pyplot as plt import pytest from bigtree.node.basenode import BaseNode @@ -668,6 +669,19 @@ def test_rollback_set_children_reassign(self): child.parent == parent ), f"Node {child} parent, expected {parent}, received {child.parent}" + def test_plot(self): + self.a.children = [self.b, self.c] + fig = self.a.plot() + assert isinstance(fig, plt.Figure) + + def test_plot_with_reingold_tilford(self): + from bigtree.utils.plot import reingold_tilford + + self.a.children = [self.b, self.c] + reingold_tilford(self.a) + fig = self.a.plot() + assert isinstance(fig, plt.Figure) + def assert_tree_structure_basenode_root(root): """Test tree structure (i.e., ancestors, descendants, leaves, siblings, etc.)""" diff --git a/tests/test_constants.py b/tests/test_constants.py index 957bf2f1..c2adaa59 100644 --- a/tests/test_constants.py +++ b/tests/test_constants.py @@ -221,5 +221,11 @@ class Constants: "Expected more than or equal to {count} element(s), found " ) + # tree/utils + ERROR_PLOT = ( + "No x or y coordinates detected. " + "Please run reingold_tilford algorithm to retrieve coordinates." + ) + # workflow/todo ERROR_WORKFLOW_TODO_TYPE = "Invalid data type for item" diff --git a/tests/utils/test_plot.py b/tests/utils/test_plot.py index 1cb19e00..d1ccd183 100644 --- a/tests/utils/test_plot.py +++ b/tests/utils/test_plot.py @@ -1,11 +1,45 @@ import unittest +import matplotlib.pyplot as plt import pytest from bigtree.node.node import Node from bigtree.tree.construct import list_to_tree from bigtree.utils.iterators import postorder_iter -from bigtree.utils.plot import _first_pass, reingold_tilford +from bigtree.utils.plot import _first_pass, plot_tree, reingold_tilford +from tests.test_constants import Constants + +LOCAL = Constants.LOCAL + + +class TestPlotTree(unittest.TestCase): + def test_plot_tree_runtime_error(self): + root = Node("a", children=[Node("b")]) + with pytest.raises(RuntimeError) as exc_info: + plot_tree(root) + assert str(exc_info.value) == Constants.ERROR_PLOT + + def test_plot_tree_with_fig(self): + root = Node("a", children=[Node("b"), Node("c")]) + reingold_tilford(root) + + fig = plt.figure() + ax = fig.add_subplot(111) + fig = plot_tree(root, ax=ax) + if LOCAL: + fig.savefig("tests/plot_tree_fig.png") + assert isinstance(fig, plt.Figure) + + def test_plot_tree_with_fig_and_args(self): + root = Node("a", children=[Node("b"), Node("c")]) + reingold_tilford(root) + + fig = plt.figure() + ax = fig.add_subplot(111) + fig = plot_tree(root, "-ok", ax=ax) + if LOCAL: + fig.savefig("tests/plot_tree_fig_and_args.png") + assert isinstance(fig, plt.Figure) class TestPlotNoChildren(unittest.TestCase):