Skip to content

Commit

Permalink
✨ Add module for pyplot
Browse files Browse the repository at this point in the history
  • Loading branch information
Kajiih committed Dec 16, 2024
1 parent 91e00e9 commit eecffe0
Show file tree
Hide file tree
Showing 6 changed files with 353 additions and 44 deletions.
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,11 @@ dependencies = [
]

[project.optional-dependencies]
jax = ["jax>=0.4.36"]
loguru = ["loguru>=0.7.3", "rich>=13.9.4"]
jax = ["jax>=0.4.36"]
loguru = ["loguru>=0.7.3", "rich>=13.9.4"]
pyplot = [
"matplotlib>=3.10.0",
]


[tool.uv]
Expand Down
2 changes: 1 addition & 1 deletion src/kajihs_utils/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent

__app_name__ = "kajihs_utils"
__version__ = "0.2.0"
__version__ = "0.3.0"
__authors__ = ["Kajih"]
__author_emails__ = ["[email protected]"]
__repo_url__ = "https://github.com/Kajiih/kajihs_utils"
Expand Down
54 changes: 54 additions & 0 deletions src/kajihs_utils/arithmetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
Utils for arithmetic.
close_factors and almost_factors taken from:
https://code.visualstudio.com/api/language-extensions/semantic-highlight-guide
"""


def closest_factors(n: int, /) -> tuple[int, int]:
"""
Find the closest pair of factors.
Args:
n: The number to find factors for.
Returns:
A tuple containing the two closest factors of n, the larger first.
Example:
>>> close_factors(99)
(11, 9)
"""
factor1 = 0
factor2 = n
while factor1 + 1 <= factor2:
factor1 += 1
if n % factor1 == 0:
factor2 = n // factor1

return factor1, factor2


def almost_factors(n: int, /, ratio: float = 0.5) -> tuple[int, int]:
"""
Find a pair of factors that are close enough.
Args:
n: The number to almost-factorize.
ratio: The threshold ratio between both factors.
Returns:
A tuple containing the first two numbers factoring to n or more such
that factor 1 is at most 1/ratio times larger than factor 2.
Example:
>>> almost_factors(10, ratio=0.5)
(4, 3)
"""
while True:
factor1, factor2 = closest_factors(n)
if ratio * factor1 <= factor2:
break
n += 1
return factor1, factor2
6 changes: 4 additions & 2 deletions src/kajihs_utils/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Core module."""
"""General utils without dependencies."""

from collections.abc import Iterable, Iterator, Mapping, Sequence
from typing import Any, Literal, overload
Expand All @@ -24,7 +24,7 @@ def get_first[K, V, D](
d: Mapping[K, V], /, keys: Iterable[K], default: D = None, no_default: bool = False
) -> V | D:
"""
Return the first value for the first key that exists in the mapping.
Return the value for the first key that exists in the mapping.
Args:
d: The dictionary to search in.
Expand Down Expand Up @@ -63,6 +63,8 @@ def batch[S: Sequence[Any]](seq: S, /, size: int) -> Iterator[S]:
"""
Generate batches of the sequence.
Maybe you better use the itertools.batched, it works with any iterable!
Args:
seq: The sequence to batch.
size: Size of the batches. Last batch may be shorter.
Expand Down
36 changes: 36 additions & 0 deletions src/kajihs_utils/pyplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Utils for matplotlib.pyplot."""

from typing import Any

import matplotlib.pyplot as plt

Check failure on line 5 in src/kajihs_utils/pyplot.py

View workflow job for this annotation

GitHub Actions / build (3.12, ubuntu-latest)

Import "matplotlib.pyplot" could not be resolved (reportMissingImports)
from matplotlib.figure import Figure
from numpy import ndarray

from .arithmetic import (
almost_factors,
)


def auto_subplot(

Check warning on line 14 in src/kajihs_utils/pyplot.py

View workflow job for this annotation

GitHub Actions / build (3.12, ubuntu-latest)

Return type, "tuple[Unknown, Unknown]", is partially unknown (reportUnknownParameterType)
size: int, /, ratio: float = 9 / 16, **subplot_params: Any
) -> tuple[Figure, ndarray[tuple[int], Any]]:
"""
Automatically creates a subplot grid with an adequate number of rows and columns.
Args:
size: The total number of subplots.
ratio: The threshold aspect ratio between rows and columns.
**subplot_params: Additional keyword parameters for subplot.
Returns:
Tuple containing the figure and the flatten axes.
"""
rows, cols = almost_factors(size, ratio)

fig, axes = plt.subplots(rows, cols, **subplot_params)

# if isinstance(axes, np.ndarray):
# axes = axes.flatten()
axes = axes.flatten()

return fig, axes
Loading

0 comments on commit eecffe0

Please sign in to comment.