Skip to content

Commit

Permalink
✨ Implement find closest
Browse files Browse the repository at this point in the history
  • Loading branch information
Kajiih committed Dec 20, 2024
1 parent 82e267d commit 26b0eb4
Show file tree
Hide file tree
Showing 5 changed files with 408 additions and 59 deletions.
2 changes: 1 addition & 1 deletion src/kajihs_utils/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent

__app_name__ = "kajihs_utils"
__version__ = "0.3.8"
__version__ = "0.4.0"
__authors__ = ["Kajih"]
__author_emails__ = ["[email protected]"]
__repo_url__ = "https://github.com/Kajiih/kajihs_utils"
Expand Down
2 changes: 1 addition & 1 deletion src/kajihs_utils/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def closest_factors(n: int, /) -> tuple[int, int]:
A tuple containing the two closest factors of n, the larger first.
Example:
>>> close_factors(99)
>>> closest_factors(99)
(11, 9)
"""
factor1 = 0
Expand Down
120 changes: 120 additions & 0 deletions src/kajihs_utils/numpy_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Tools for numpy."""

from collections.abc import Iterable
from typing import Any, Literal

import numpy as np

Check failure on line 6 in src/kajihs_utils/numpy_utils.py

View workflow job for this annotation

GitHub Actions / build (3.12, ubuntu-latest)

Import "numpy" could not be resolved (reportMissingImports)
from numpy import dtype, int_, ndarray
from numpy.typing import ArrayLike, NDArray

type Norm = float | Literal["fro", "nuc"]


class IncompatibleShapeError(ValueError):
"""Shapes of input arrays are incompatible for a given function."""

def __init__(self, arr1: NDArray[Any], arr2: NDArray[Any], obj: Any) -> None:

Check warning on line 16 in src/kajihs_utils/numpy_utils.py

View workflow job for this annotation

GitHub Actions / build (3.12, ubuntu-latest)

Type of parameter "arr1" is unknown (reportUnknownParameterType)

Check warning on line 16 in src/kajihs_utils/numpy_utils.py

View workflow job for this annotation

GitHub Actions / build (3.12, ubuntu-latest)

Type of parameter "arr2" is unknown (reportUnknownParameterType)
super().__init__(
f"Shapes of inputs arrays {arr1.shape} and {arr2.shape} are incompatible for {obj.__name__}"
)


# TODO: Add axis parameters
def find_closest[T](

Check warning on line 23 in src/kajihs_utils/numpy_utils.py

View workflow job for this annotation

GitHub Actions / build (3.12, ubuntu-latest)

Return type is unknown (reportUnknownParameterType)
x: Iterable[T] | ArrayLike,

Check warning on line 24 in src/kajihs_utils/numpy_utils.py

View workflow job for this annotation

GitHub Actions / build (3.12, ubuntu-latest)

Type of parameter "x" is partially unknown   Parameter type is "Iterable[T@find_closest] | Unknown" (reportUnknownParameterType)
targets: Iterable[T] | T | ArrayLike,

Check warning on line 25 in src/kajihs_utils/numpy_utils.py

View workflow job for this annotation

GitHub Actions / build (3.12, ubuntu-latest)

Type of parameter "targets" is partially unknown   Parameter type is "Iterable[T@find_closest] | T@find_closest | Unknown" (reportUnknownParameterType)
norm_ord: Norm | None = None,
) -> ndarray[tuple[int], dtype[int_]] | int_:
"""
Find the index of the closest element(s) from `x` for each target in `targets`.
Given one or multiple `targets` (vectors vectors or scalars),
this function computes the distance to each element in `x` and returns the
indices of the closest matches. If `targets` is of the same shape as an
element of `x`, the function returns a single integer index. If `targets`
contains multiple elements, it returns an array of indices corresponding to
each target.
If the dimensionality of the vectors in `x` is greater than 2, the vectors
will be flattened into 1D before computing distances.
Args:
x: An iterable or array-like collection of elements (scalars, vectors,
or higher-dimensional arrays). For example, `x` could be an array of
shape `(N,)` (scalars), `(N, D)` (D-dimensional vectors),
`(N, H, W)` (2D arrays), or higher-dimensional arrays.
targets: One or multiple target elements for which you want to find the
closest match in `x`. Can be a single scalar/vector/array or an
iterable of them.
Must be shape-compatible with the elements of `x`.
norm_ord: The order of the norm used for distance computation.
Uses the same conventions as `numpy.linalg.norm`.
Returns:
An array of indices. If a single target was given, a single index is
returned. If multiple targets were given, an array of shape `(M,)` is
returned, where `M` is the number of target elements. Each value is the
index of the closest element in `x` to the corresponding target.
Raises:
IncompatibleShapeError: If `targets` cannot be broadcast or reshaped to
match the shape structure of the elements in `x`.
Examples:
>>> import numpy as np
>>> x = np.array([0, 10, 20, 30])
>>> int(find_closest(x, 12))
1
>>> # Multiple targets
>>> find_closest(x, [2, 26])
array([0, 3])
>>> # Using vectors
>>> x = np.array([[0, 0], [10, 10], [20, 20]])
>>> int(find_closest(x, [6, 5])) # Single target vector
1
>>> find_closest(x, [[-1, -1], [15, 12]]) # Multiple target vectors
array([0, 1])
>>> # Higher dimensional arrays
>>> x = np.array([[[0, 0], [0, 0]], [[10, 10], [10, 10]], [[20, 20], [20, 20]]])
>>> int(find_closest(x, [[2, 2], [2, 2]]))
0
>>> find_closest(x, [[[0, 0], [1, 1]], [[19, 19], [19, 19]]])
array([0, 2])
"""
x = np.array(x) # (N, vector_shape)
targets = np.array(targets)
vector_shape = x.shape[1:]

# Check that shapes are compatible
do_unsqueeze = False
if targets.shape == vector_shape:
targets = np.atleast_1d(targets)[np.newaxis, :] # (M, vector_shape)
do_unsqueeze = True
elif targets.shape[1:] != vector_shape:
raise IncompatibleShapeError(x, targets, find_closest)

nb_vectors = x.shape[0] # N
nb_targets = targets.shape[0] # M

diffs = x[:, np.newaxis] - targets

match vector_shape:
case ():
distances = np.linalg.norm(diffs[:, np.newaxis], ord=norm_ord, axis=1)
case (_,):
distances = np.linalg.norm(diffs, ord=norm_ord, axis=2)
case (_, _):
distances = np.linalg.norm(diffs, ord=norm_ord, axis=(2, 3))
case _: # Tensors
# Reshape to 1d vectors
diffs = diffs.reshape(nb_vectors, nb_targets, -1)
distances = np.linalg.norm(diffs, ord=norm_ord, axis=2)

closest_indices = np.argmin(distances, axis=0)
if do_unsqueeze:
closest_indices = closest_indices[0]

return closest_indices

229 changes: 229 additions & 0 deletions tests/test_numpy_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
import numpy as np
import pytest
from numpy.testing import assert_array_equal

from kajihs_utils.numpy_utils import IncompatibleShapeError, find_closest


class TestFindClosest:
def test_scalar_inputs_single_target(self):
# x: (N,), target: scalar
x = np.array([0, 10, 20, 30])
target = 12
# The closest to 12 is 10 at index 1
result = find_closest(x, target)
assert result == 1

def test_scalar_inputs_multiple_targets(self):
# x: (N,), targets: (M,)
x = np.array([0, 10, 20, 30])
targets = np.array([2, 26])
# Closest to 2 is 0 at index 0, closest to 26 is 30 at index 3
result = find_closest(x, targets)
assert_array_equal(result, np.array([0, 3]))

def test_vector_inputs_single_target(self):
# x: (N, D), target: (D,)
x = np.array([[0, 0], [10, 10], [20, 20]])
target = np.array([6, 5])
# Distances:
# to [0,0] ~ sqrt(6^2+5^2)=sqrt(61)
# to [10,10] ~ sqrt((10-6)^2+(10-5)^2)=sqrt(16+25)=sqrt(41)=closest
# to [20,20] ~ sqrt((20-6)^2+(20-5)^2) large number
result = find_closest(x, target)
assert result == 1

def test_vector_inputs_multiple_targets(self):
# x: (N, D), targets: (M, D)
x = np.array([[0, 0], [10, 10], [20, 20]])
targets = np.array([[-1, -1], [15, 12]])
# Distances to [-1,-1]:
# [0,0]: sqrt(1+1)=sqrt(2)
# [10,10]: large
# [20,20]: larger
# Closest: index 0
# Distances to [15,12]:
# [0,0]: sqrt(15^2+12^2)=sqrt(225+144)=sqrt(369)
# [10,10]: sqrt(5^2+2^2)=sqrt(29)=closest
# [20,20]: sqrt((20-15)^2+(20-12)^2)=sqrt(25+64)=sqrt(89)
# Closest: index 1
result = find_closest(x, targets)
assert_array_equal(result, np.array([0, 1]))

def test_matrix_inputs_single_target(self):
# x: (N, H, W), target: (H, W)
x = np.array([[[0, 0], [0, 0]], [[10, 10], [10, 10]], [[20, 20], [20, 20]]])
target = np.array([[2, 2], [2, 2]])
# Distances (fro norm):
# to [[[0,0],[0,0]]] = sqrt(2^2+2^2+2^2+2^2)=sqrt(16)=4
# to [[[10,10],[10,10]]] = large
# to [[[20,20],[20,20]]] = larger
# Closest: index 0
result = find_closest(x, target, norm_ord="fro")
assert result == 0

def test_matrix_inputs_multiple_targets(self):
# x: (N, H, W)
x = np.array([[[0, 0], [0, 0]], [[10, 10], [10, 10]], [[20, 20], [20, 20]]])
targets = np.array([
[[0, 0], [1, 1]], # close to first
[[19, 19], [19, 19]], # close to last
])
# First target distances:
# to x[0]: fro sqrt(0^2+0^2+(1^2)+(1^2)=sqrt(2)=1.414..
# to x[1]: fro sqrt((10^2+10^2)+(9^2+9^2)) large
# to x[2]: even larger
# Closest: index 0
# Second target distances:
# to x[0]: big
# to x[1]: sqrt((9^2 four times)= sqrt(81*4)=sqrt(324)=18
# to x[2]: sqrt((1^2 four times)= sqrt(4)=2 closest
# Closest: index 2
result = find_closest(x, targets, norm_ord="fro")
assert_array_equal(result, np.array([0, 2]))

def test_four_dimensional_inputs(self):
# x: (N, A, B, C) e.g. (3,2,2,2)
x = np.array([
np.zeros((2, 2, 2)),
np.ones((2, 2, 2)) * 10,
np.ones((2, 2, 2)) * 20,
])
target = np.ones((2, 2, 2)) * 9
# Flattened norms:
# Dist to x[0]: sqrt(9^2 * 8 elements)= sqrt(81*8)= sqrt(648)
# Dist to x[1]: sqrt((10-9)^2 * 8)= sqrt(1*8)= sqrt(8)=2.828...
# Dist to x[2]: sqrt((20-9)^2 * 8)= large
# Closest: index 1
result = find_closest(x, target)
assert result == 1

def test_multiple_targets_for_four_dimensional(self):
x = np.array([
np.zeros((2, 2, 2)),
np.ones((2, 2, 2)) * 10,
np.ones((2, 2, 2)) * 20,
])
targets = np.array([
np.ones((2, 2, 2)), # closer to 0 (distance large) or 10?
np.ones((2, 2, 2)) * 15, # closer to 10 or 20?
])
# For target ~1:
# Dist to x[0]: sqrt((1-0)^2 *8)= sqrt(1*8)= sqrt(8)
# Dist to x[1]: sqrt((10-1)^2 *8)= sqrt(81*8)=sqrt(648)
# Dist to x[2]: sqrt((20-1)^2 *8)= even bigger
# Closest: index 0
# For target ~15:
# Dist to x[0]: sqrt(15^2*8) large
# Dist to x[1]: sqrt((10-15)^2*8)= sqrt(25*8)= sqrt(200)
# Dist to x[2]: sqrt((20-15)^2*8)= sqrt(25*8)= sqrt(200)
# Ties between index 1 and 2, np.argmin returns the first min, so index 1.
result = find_closest(x, targets)
assert_array_equal(result, np.array([0, 1]))

def test_single_target_same_shape_returns_int(self):
x = np.array([[0, 1], [2, 3], [4, 5]])
target = np.array([3, 4])
# This should return a single integer because target is the same shape as an element
result = find_closest(x, target)
assert isinstance(result, np.integer)

def test_multiple_targets_return_array(self):
x = np.array([[0, 1], [2, 3], [4, 5]])
targets = np.array([[1, 1], [3, 3]])
# This should return an array because we have multiple targets
result = find_closest(x, targets)
assert isinstance(result, np.ndarray)
assert result.shape == (2,)

def test_incompatible_shape_error_single_target(self):
x = np.array([0, 10, 20, 30]) # shape (4,)
target = np.array([1, 2]) # shape (2,) incompatible with (N,) elements
with pytest.raises(IncompatibleShapeError):
find_closest(x, target)

def test_incompatible_shape_error_multiple_targets(self):
x = np.array([[0, 0], [1, 1], [2, 2]]) # shape (3,2)
targets = np.array([[0], [1]]) # shape (2,1), incompatible with (2,)
with pytest.raises(IncompatibleShapeError):
find_closest(x, targets)

def test_norm_ord_none(self):
# Check that passing norm_ord=None is acceptable and uses default norm (2-norm)
x = np.array([[0, 0], [10, 10], [20, 20]])
target = np.array([6, 5])
result_default = find_closest(x, target, norm_ord=None)
result_2 = find_closest(x, target, norm_ord=2)
assert result_default == result_2

def test_invalid_norm_ord_for_non_matrix_with_fro(self):
# "fro" norm is defined for matrices (2D), but let's see if it raises an error for vectors.
x = np.array([0, 10, 20, 30]) # shape (4,), scalar dimension
target = 12
# numpy.linalg.norm with 'fro' and 1D array works but treats it as a 2D array with shape (4,1)
# which leads to a ValueError because 'fro' is only defined for 2D matrices.
# We can test if this raises a ValueError.
with pytest.raises(ValueError):
find_closest(x, target, norm_ord="fro")

def test_invalid_norm_ord_for_non_matrix_with_nuc(self):
# "nuc" norm (nuclear norm) is only defined for 2D arrays.
# Here we test with a 1D array and expect a ValueError.
x = np.array([0, 10, 20, 30])
target = 12
with pytest.raises(ValueError):
find_closest(x, target, norm_ord="nuc")

def test_fro_norm_for_2d_arrays(self):
# Valid usage of 'fro' norm for 2D arrays
x = np.array([[0, 0], [10, 10], [20, 20]])
target = np.array([6, 5])
# Should not raise an error
result = find_closest(x, target, norm_ord="fro")
assert result == 1

def test_nuc_norm_for_2d_arrays(self):
# The nuclear norm is defined for matrices (2D).
x = np.array([[0, 0], [10, 10], [20, 20]])
target = np.array([6, 5])
# This might still raise a ValueError internally if the nuclear norm isn't defined for single rows.
# For a single vector (1D considered as 2D?), let's reshape target to ensure it's 2D.
x2 = np.array([[[0, 0]], [[10, 10]], [[20, 20]]]) # shape (3, 1, 2)
target2 = np.array([[6, 5]]) # shape (1,2)
# Here x2 elements are (1,2) matrices. Nuclear norm is defined.
# Distances:
# x2[0]: [[0,0]] vs [[6,5]] difference [[-6,-5]], norm = nuclear_norm = sum of singular values
# Singular values of [[-6, -5]] is sqrt( (-6)^2 + (-5)^2 ) = sqrt(61)
# similarly for others...
# Just check no error is raised:
result = find_closest(x2, target2, norm_ord="nuc")
# Should return a single index
assert isinstance(result, np.integer)

def test_custom_norm_ord_int(self):
# Test a custom integer norm (like 1-norm)
x = np.array([[0, 0], [10, 10], [20, 20]])
target = np.array([6, 5])
# With 1-norm:
# Distances:
# to [0,0]: |6|+|5|=11
# to [10,10]: |10-6|+|10-5|=4+5=9 closest
# to [20,20]: big
result = find_closest(x, target, norm_ord=1)
assert result == 1

def test_multiple_targets_and_int_norm(self):
x = np.array([[0, 0], [10, 10], [20, 20]])
targets = np.array([[1, 1], [15, 15]])
# With 1-norm:
# For [1,1]:
# to [0,0]: sum(|1|+|1|)=2
# to [10,10]: sum(|9|+|9|)=18
# to [20,20]: sum(|19|+|19|)=38
# closest: 0
# For [15,15]:
# to [0,0]: sum(|15|+|15|)=30
# to [10,10]: sum(|5|+|5|)=10
# to [20,20]: sum(|5|+|5|)=10 tie, argmin picks first: index 1
result = find_closest(x, targets, norm_ord=1)
assert_array_equal(result, [0, 1])
Loading

0 comments on commit 26b0eb4

Please sign in to comment.