Skip to content

Commit

Permalink
Optimized get_unit_normal() and replaced np.cross() with custom `…
Browse files Browse the repository at this point in the history
…cross()` in `manim.utils.space_ops` (ManimCommunity#3494)

* Added cross and optimized get_unit_normal in manim.utils.space_ops

* Added missing border case to new get_unit_normal where one vector is nonzero

* Updated test_threed.py::test_Sphere test data
  • Loading branch information
chopan050 authored Dec 6, 2023
1 parent 7cead84 commit 6949c66
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 19 deletions.
69 changes: 50 additions & 19 deletions manim/utils/space_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from manim.typing import Point3D_Array, Vector
from manim.typing import Point3D_Array, Vector, Vector3

__all__ = [
"quaternion_mult",
Expand Down Expand Up @@ -53,6 +53,16 @@ def norm_squared(v: float) -> float:
return np.dot(v, v)


def cross(v1: Vector3, v2: Vector3) -> Vector3:
return np.array(
[
v1[1] * v2[2] - v1[2] * v2[1],
v1[2] * v2[0] - v1[0] * v2[2],
v1[0] * v2[1] - v1[1] * v2[0],
]
)


# Quaternions
# TODO, implement quaternion type

Expand Down Expand Up @@ -273,12 +283,12 @@ def z_to_vector(vector: np.ndarray) -> np.ndarray:
(normalized) vector provided as an argument
"""
axis_z = normalize(vector)
axis_y = normalize(np.cross(axis_z, RIGHT))
axis_x = np.cross(axis_y, axis_z)
axis_y = normalize(cross(axis_z, RIGHT))
axis_x = cross(axis_y, axis_z)
if np.linalg.norm(axis_y) == 0:
# the vector passed just so happened to be in the x direction.
axis_x = normalize(np.cross(UP, axis_z))
axis_y = -np.cross(axis_x, axis_z)
axis_x = normalize(cross(UP, axis_z))
axis_y = -cross(axis_x, axis_z)

return np.array([axis_x, axis_y, axis_z]).T

Expand Down Expand Up @@ -359,7 +369,7 @@ def normalize_along_axis(array: np.ndarray, axis: np.ndarray) -> np.ndarray:
return array


def get_unit_normal(v1: np.ndarray, v2: np.ndarray, tol: float = 1e-6) -> np.ndarray:
def get_unit_normal(v1: Vector3, v2: Vector3, tol: float = 1e-6) -> Vector3:
"""Gets the unit normal of the vectors.
Parameters
Expand All @@ -376,16 +386,37 @@ def get_unit_normal(v1: np.ndarray, v2: np.ndarray, tol: float = 1e-6) -> np.nda
np.ndarray
The normal of the two vectors.
"""
v1, v2 = (normalize(i) for i in (v1, v2))
cp = np.cross(v1, v2)
cp_norm = np.linalg.norm(cp)
if cp_norm < tol:
# Vectors align, so find a normal to them in the plane shared with the z-axis
cp = np.cross(np.cross(v1, OUT), v1)
cp_norm = np.linalg.norm(cp)
if cp_norm < tol:
# Instead of normalizing v1 and v2, just divide by the greatest
# of all their absolute components, which is just enough
div1, div2 = max(np.abs(v1)), max(np.abs(v2))
if div1 == 0.0:
if div2 == 0.0:
return DOWN
return normalize(cp)
u = v2 / div2
elif div2 == 0.0:
u = v1 / div1
else:
# Normal scenario: v1 and v2 are both non-null
u1, u2 = v1 / div1, v2 / div2
cp = cross(u1, u2)
cp_norm = np.sqrt(norm_squared(cp))
if cp_norm > tol:
return cp / cp_norm
# Otherwise, v1 and v2 were aligned
u = u1

# If you are here, you have an "unique", non-zero, unit-ish vector u
# If it's also too aligned to the Z axis, just return DOWN
if abs(u[0]) < tol and abs(u[1]) < tol:
return DOWN
# Otherwise rotate u in the plane it shares with the Z axis,
# 90° TOWARDS the Z axis. This is done via (u x [0, 0, 1]) x u,
# which gives [-xz, -yz, x²+y²] (slightly scaled as well)
cp = np.array([-u[0] * u[2], -u[1] * u[2], u[0] * u[0] + u[1] * u[1]])
cp_norm = np.sqrt(norm_squared(cp))
# Because the norm(u) == 0 case was filtered in the beginning,
# there is no need to check if the norm of cp is 0
return cp / cp_norm


###
Expand Down Expand Up @@ -529,8 +560,8 @@ def line_intersection(
np.pad(np.array(i)[:, :2], ((0, 0), (0, 1)), constant_values=1)
for i in (line1, line2)
)
line1, line2 = (np.cross(*i) for i in padded)
x, y, z = np.cross(line1, line2)
line1, line2 = (cross(*i) for i in padded)
x, y, z = cross(line1, line2)

if z == 0:
raise ValueError(
Expand Down Expand Up @@ -558,7 +589,7 @@ def find_intersection(
result = []

for p0, v0, p1, v1 in zip(*[p0s, v0s, p1s, v1s]):
normal = np.cross(v1, np.cross(v0, v1))
normal = cross(v1, cross(v0, v1))
denom = max(np.dot(v0, normal), threshold)
result += [p0 + np.dot(p1 - p0, normal) / denom * v0]
return result
Expand Down Expand Up @@ -814,6 +845,6 @@ def perpendicular_bisector(
"""
p1 = line[0]
p2 = line[1]
direction = np.cross(p1 - p2, norm_vector)
direction = cross(p1 - p2, norm_vector)
m = midpoint(p1, p2)
return [m + direction, m - direction]
Binary file modified tests/test_graphical_units/control_data/threed/Sphere.npz
Binary file not shown.

0 comments on commit 6949c66

Please sign in to comment.