-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
408 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent | ||
|
||
__app_name__ = "kajihs_utils" | ||
__version__ = "0.3.8" | ||
__version__ = "0.4.0" | ||
__authors__ = ["Kajih"] | ||
__author_emails__ = ["[email protected]"] | ||
__repo_url__ = "https://github.com/Kajiih/kajihs_utils" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
"""Tools for numpy.""" | ||
|
||
from collections.abc import Iterable | ||
from typing import Any, Literal | ||
|
||
import numpy as np | ||
from numpy import dtype, int_, ndarray | ||
from numpy.typing import ArrayLike, NDArray | ||
|
||
type Norm = float | Literal["fro", "nuc"] | ||
|
||
|
||
class IncompatibleShapeError(ValueError):
    """Shapes of input arrays are incompatible for a given function."""

    def __init__(self, arr1: NDArray[Any], arr2: NDArray[Any], obj: Any) -> None:
        """
        Build the error message from the two arrays and the rejecting object.

        Args:
            arr1: First array; only its `shape` is read.
            arr2: Second array; only its `shape` is read.
            obj: Function (or other object) for which the shapes are
                incompatible. Its `__name__` is used in the message when it
                has one, otherwise its `repr`.
        """
        # Not every callable exposes `__name__` (e.g. functools.partial or
        # callable instances). Fall back to repr so that building this error
        # never raises AttributeError and masks the real shape problem.
        obj_name = obj.__name__ if hasattr(obj, "__name__") else repr(obj)
        super().__init__(
            f"Shapes of inputs arrays {arr1.shape} and {arr2.shape} are incompatible for {obj_name}"
        )
|
||
|
||
# TODO: Add axis parameters | ||
def find_closest[T]( | ||
x: Iterable[T] | ArrayLike, | ||
targets: Iterable[T] | T | ArrayLike, | ||
norm_ord: Norm | None = None, | ||
) -> ndarray[tuple[int], dtype[int_]] | int_: | ||
""" | ||
Find the index of the closest element(s) from `x` for each target in `targets`. | ||
Given one or multiple `targets` (vectors vectors or scalars), | ||
this function computes the distance to each element in `x` and returns the | ||
indices of the closest matches. If `targets` is of the same shape as an | ||
element of `x`, the function returns a single integer index. If `targets` | ||
contains multiple elements, it returns an array of indices corresponding to | ||
each target. | ||
If the dimensionality of the vectors in `x` is greater than 2, the vectors | ||
will be flattened into 1D before computing distances. | ||
Args: | ||
x: An iterable or array-like collection of elements (scalars, vectors, | ||
or higher-dimensional arrays). For example, `x` could be an array of | ||
shape `(N,)` (scalars), `(N, D)` (D-dimensional vectors), | ||
`(N, H, W)` (2D arrays), or higher-dimensional arrays. | ||
targets: One or multiple target elements for which you want to find the | ||
closest match in `x`. Can be a single scalar/vector/array or an | ||
iterable of them. | ||
Must be shape-compatible with the elements of `x`. | ||
norm_ord: The order of the norm used for distance computation. | ||
Uses the same conventions as `numpy.linalg.norm`. | ||
Returns: | ||
An array of indices. If a single target was given, a single index is | ||
returned. If multiple targets were given, an array of shape `(M,)` is | ||
returned, where `M` is the number of target elements. Each value is the | ||
index of the closest element in `x` to the corresponding target. | ||
Raises: | ||
IncompatibleShapeError: If `targets` cannot be broadcast or reshaped to | ||
match the shape structure of the elements in `x`. | ||
Examples: | ||
>>> import numpy as np | ||
>>> x = np.array([0, 10, 20, 30]) | ||
>>> int(find_closest(x, 12)) | ||
1 | ||
>>> # Multiple targets | ||
>>> find_closest(x, [2, 26]) | ||
array([0, 3]) | ||
>>> # Using vectors | ||
>>> x = np.array([[0, 0], [10, 10], [20, 20]]) | ||
>>> int(find_closest(x, [6, 5])) # Single target vector | ||
1 | ||
>>> find_closest(x, [[-1, -1], [15, 12]]) # Multiple target vectors | ||
array([0, 1]) | ||
>>> # Higher dimensional arrays | ||
>>> x = np.array([[[0, 0], [0, 0]], [[10, 10], [10, 10]], [[20, 20], [20, 20]]]) | ||
>>> int(find_closest(x, [[2, 2], [2, 2]])) | ||
0 | ||
>>> find_closest(x, [[[0, 0], [1, 1]], [[19, 19], [19, 19]]]) | ||
array([0, 2]) | ||
""" | ||
x = np.array(x) # (N, vector_shape) | ||
targets = np.array(targets) | ||
vector_shape = x.shape[1:] | ||
|
||
# Check that shapes are compatible | ||
do_unsqueeze = False | ||
if targets.shape == vector_shape: | ||
targets = np.atleast_1d(targets)[np.newaxis, :] # (M, vector_shape) | ||
do_unsqueeze = True | ||
elif targets.shape[1:] != vector_shape: | ||
raise IncompatibleShapeError(x, targets, find_closest) | ||
|
||
nb_vectors = x.shape[0] # N | ||
nb_targets = targets.shape[0] # M | ||
|
||
diffs = x[:, np.newaxis] - targets | ||
|
||
match vector_shape: | ||
case (): | ||
distances = np.linalg.norm(diffs[:, np.newaxis], ord=norm_ord, axis=1) | ||
case (_,): | ||
distances = np.linalg.norm(diffs, ord=norm_ord, axis=2) | ||
case (_, _): | ||
distances = np.linalg.norm(diffs, ord=norm_ord, axis=(2, 3)) | ||
case _: # Tensors | ||
# Reshape to 1d vectors | ||
diffs = diffs.reshape(nb_vectors, nb_targets, -1) | ||
distances = np.linalg.norm(diffs, ord=norm_ord, axis=2) | ||
|
||
closest_indices = np.argmin(distances, axis=0) | ||
if do_unsqueeze: | ||
closest_indices = closest_indices[0] | ||
|
||
return closest_indices | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,229 @@ | ||
import numpy as np | ||
import pytest | ||
from numpy.testing import assert_array_equal | ||
|
||
from kajihs_utils.numpy_utils import IncompatibleShapeError, find_closest | ||
|
||
|
||
class TestFindClosest:
    """Behavioural tests for `find_closest` across element ranks and norm orders."""

    def test_scalar_inputs_single_target(self) -> None:
        # x: (N,), target: scalar
        x = np.array([0, 10, 20, 30])
        target = 12
        # The closest to 12 is 10 at index 1
        result = find_closest(x, target)
        assert result == 1

    def test_scalar_inputs_multiple_targets(self) -> None:
        # x: (N,), targets: (M,)
        x = np.array([0, 10, 20, 30])
        targets = np.array([2, 26])
        # Closest to 2 is 0 at index 0, closest to 26 is 30 at index 3
        result = find_closest(x, targets)
        assert_array_equal(result, np.array([0, 3]))

    def test_vector_inputs_single_target(self) -> None:
        # x: (N, D), target: (D,)
        x = np.array([[0, 0], [10, 10], [20, 20]])
        target = np.array([6, 5])
        # Distances:
        #   to [0,0]:   sqrt(6^2 + 5^2) = sqrt(61)
        #   to [10,10]: sqrt((10-6)^2 + (10-5)^2) = sqrt(41)  <- closest
        #   to [20,20]: sqrt((20-6)^2 + (20-5)^2) = sqrt(421)
        result = find_closest(x, target)
        assert result == 1

    def test_vector_inputs_multiple_targets(self) -> None:
        # x: (N, D), targets: (M, D)
        x = np.array([[0, 0], [10, 10], [20, 20]])
        targets = np.array([[-1, -1], [15, 12]])
        # Distances to [-1,-1]:
        #   [0,0]:   sqrt(1 + 1) = sqrt(2)  <- closest (index 0)
        #   [10,10]: large
        #   [20,20]: larger
        # Distances to [15,12]:
        #   [0,0]:   sqrt(15^2 + 12^2) = sqrt(369)
        #   [10,10]: sqrt(5^2 + 2^2) = sqrt(29)  <- closest (index 1)
        #   [20,20]: sqrt(5^2 + 8^2) = sqrt(89)
        result = find_closest(x, targets)
        assert_array_equal(result, np.array([0, 1]))

    def test_matrix_inputs_single_target(self) -> None:
        # x: (N, H, W), target: (H, W)
        x = np.array([[[0, 0], [0, 0]], [[10, 10], [10, 10]], [[20, 20], [20, 20]]])
        target = np.array([[2, 2], [2, 2]])
        # Distances (Frobenius norm):
        #   to x[0]: sqrt(4 * 2^2) = 4  <- closest (index 0)
        #   to x[1]: sqrt(4 * 8^2) = 16
        #   to x[2]: sqrt(4 * 18^2) = 36
        result = find_closest(x, target, norm_ord="fro")
        assert result == 0

    def test_matrix_inputs_multiple_targets(self) -> None:
        # x: (N, H, W), targets: (M, H, W)
        x = np.array([[[0, 0], [0, 0]], [[10, 10], [10, 10]], [[20, 20], [20, 20]]])
        targets = np.array([
            [[0, 0], [1, 1]],  # close to first
            [[19, 19], [19, 19]],  # close to last
        ])
        # First target distances (Frobenius norm):
        #   to x[0]: sqrt(0^2 + 0^2 + 1^2 + 1^2) = sqrt(2)  <- closest (index 0)
        #   to x[1]: sqrt(10^2 + 10^2 + 9^2 + 9^2), large
        #   to x[2]: even larger
        # Second target distances:
        #   to x[0]: big
        #   to x[1]: sqrt(4 * 9^2) = 18
        #   to x[2]: sqrt(4 * 1^2) = 2  <- closest (index 2)
        result = find_closest(x, targets, norm_ord="fro")
        assert_array_equal(result, np.array([0, 2]))

    def test_four_dimensional_inputs(self) -> None:
        # x: (N, A, B, C), e.g. (3, 2, 2, 2)
        x = np.array([
            np.zeros((2, 2, 2)),
            np.ones((2, 2, 2)) * 10,
            np.ones((2, 2, 2)) * 20,
        ])
        target = np.ones((2, 2, 2)) * 9
        # Flattened norms:
        #   to x[0]: sqrt(8 * 9^2) = sqrt(648)
        #   to x[1]: sqrt(8 * 1^2) = sqrt(8)  <- closest (index 1)
        #   to x[2]: sqrt(8 * 11^2), large
        result = find_closest(x, target)
        assert result == 1

    def test_multiple_targets_for_four_dimensional(self) -> None:
        x = np.array([
            np.zeros((2, 2, 2)),
            np.ones((2, 2, 2)) * 10,
            np.ones((2, 2, 2)) * 20,
        ])
        targets = np.array([
            np.ones((2, 2, 2)),  # closest to x[0]
            np.ones((2, 2, 2)) * 15,  # equidistant from x[1] and x[2]
        ])
        # For the target of ones:
        #   to x[0]: sqrt(8 * 1^2) = sqrt(8)  <- closest (index 0)
        #   to x[1]: sqrt(8 * 9^2) = sqrt(648)
        #   to x[2]: sqrt(8 * 19^2), even bigger
        # For the target of fifteens:
        #   to x[0]: sqrt(8 * 15^2), large
        #   to x[1]: sqrt(8 * 5^2) = sqrt(200)
        #   to x[2]: sqrt(8 * 5^2) = sqrt(200)
        #   Tie between index 1 and 2; np.argmin returns the first minimum: index 1.
        result = find_closest(x, targets)
        assert_array_equal(result, np.array([0, 1]))

    def test_single_target_same_shape_returns_int(self) -> None:
        x = np.array([[0, 1], [2, 3], [4, 5]])
        target = np.array([3, 4])
        # A target with the same shape as one element yields a single integer index
        result = find_closest(x, target)
        assert isinstance(result, np.integer)

    def test_multiple_targets_return_array(self) -> None:
        x = np.array([[0, 1], [2, 3], [4, 5]])
        targets = np.array([[1, 1], [3, 3]])
        # Multiple targets yield an array with one index per target
        result = find_closest(x, targets)
        assert isinstance(result, np.ndarray)
        assert result.shape == (2,)

    def test_incompatible_shape_error_single_target(self) -> None:
        # Elements of x have shape (2,). A target of shape (3,) is neither one
        # element nor a stack of elements. (With scalar elements, any 1D target
        # is a valid list of scalar targets, so that case cannot raise.)
        x = np.array([[0, 0], [1, 1], [2, 2]])  # shape (3, 2)
        target = np.array([1, 2, 3])  # shape (3,), incompatible with (2,)
        with pytest.raises(IncompatibleShapeError):
            find_closest(x, target)

    def test_incompatible_shape_error_multiple_targets(self) -> None:
        x = np.array([[0, 0], [1, 1], [2, 2]])  # shape (3, 2)
        targets = np.array([[0], [1]])  # shape (2, 1), incompatible with (2,)
        with pytest.raises(IncompatibleShapeError):
            find_closest(x, targets)

    def test_norm_ord_none(self) -> None:
        # norm_ord=None is accepted and matches the 2-norm for vector elements
        x = np.array([[0, 0], [10, 10], [20, 20]])
        target = np.array([6, 5])
        result_default = find_closest(x, target, norm_ord=None)
        result_2 = find_closest(x, target, norm_ord=2)
        assert result_default == result_2

    def test_invalid_norm_ord_for_non_matrix_with_fro(self) -> None:
        # "fro" is a matrix norm. Scalar elements are reduced along a single
        # axis (a vector norm), for which numpy.linalg.norm rejects string
        # orders with a ValueError.
        x = np.array([0, 10, 20, 30])  # shape (4,), scalar elements
        target = 12
        with pytest.raises(ValueError):
            find_closest(x, target, norm_ord="fro")

    def test_invalid_norm_ord_for_non_matrix_with_nuc(self) -> None:
        # "nuc" (nuclear norm) is only defined for matrices; scalar elements
        # must raise a ValueError.
        x = np.array([0, 10, 20, 30])
        target = 12
        with pytest.raises(ValueError):
            find_closest(x, target, norm_ord="nuc")

    def test_fro_norm_for_2d_arrays(self) -> None:
        # Valid usage of the "fro" norm: the elements of x must be matrices.
        # 1D elements would go through the vector-norm path and be rejected.
        x = np.array([[[0, 0]], [[10, 10]], [[20, 20]]])  # shape (3, 1, 2)
        target = np.array([[6, 5]])  # shape (1, 2)
        # Distances (Frobenius norm):
        #   to [[0,0]]:   sqrt(61)
        #   to [[10,10]]: sqrt(41)  <- closest (index 1)
        #   to [[20,20]]: sqrt(421)
        result = find_closest(x, target, norm_ord="fro")
        assert result == 1

    def test_nuc_norm_for_2d_arrays(self) -> None:
        # The nuclear norm is defined for matrices, so use (1, 2) matrix elements.
        x2 = np.array([[[0, 0]], [[10, 10]], [[20, 20]]])  # shape (3, 1, 2)
        target2 = np.array([[6, 5]])  # shape (1, 2)
        # The nuclear norm is the sum of singular values; for a single-row
        # matrix such as [[-6, -5]] that is sqrt(6^2 + 5^2) = sqrt(61).
        # Only check that no error is raised and a single index is returned.
        result = find_closest(x2, target2, norm_ord="nuc")
        assert isinstance(result, np.integer)

    def test_custom_norm_ord_int(self) -> None:
        # Custom integer norm order (1-norm)
        x = np.array([[0, 0], [10, 10], [20, 20]])
        target = np.array([6, 5])
        # With the 1-norm:
        #   to [0,0]:   |6| + |5| = 11
        #   to [10,10]: |10-6| + |10-5| = 9  <- closest (index 1)
        #   to [20,20]: |20-6| + |20-5| = 29
        result = find_closest(x, target, norm_ord=1)
        assert result == 1

    def test_multiple_targets_and_int_norm(self) -> None:
        x = np.array([[0, 0], [10, 10], [20, 20]])
        targets = np.array([[1, 1], [15, 15]])
        # With the 1-norm:
        # For [1,1]:
        #   to [0,0]:   |1| + |1| = 2  <- closest (index 0)
        #   to [10,10]: |9| + |9| = 18
        #   to [20,20]: |19| + |19| = 38
        # For [15,15]:
        #   to [0,0]:   |15| + |15| = 30
        #   to [10,10]: |5| + |5| = 10
        #   to [20,20]: |5| + |5| = 10  (tie; argmin picks the first: index 1)
        result = find_closest(x, targets, norm_ord=1)
        assert_array_equal(result, [0, 1])
Oops, something went wrong.