Skip to content

Commit

Permalink
Improved performance of colorsynth._piecewise_gaussian() using Numb…
Browse files Browse the repository at this point in the history
…a. (#6)
  • Loading branch information
byrdie authored Aug 2, 2024
1 parent 9feb3b9 commit bd723a9
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
39 changes: 33 additions & 6 deletions colorsynth/_colorsynth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Callable
import pathlib
import math
import numpy as np
import numba
import astropy.units as u

__all__ = [
Expand Down Expand Up @@ -87,15 +89,40 @@ def _piecewise_gaussian(
mean: u.Quantity,
stddev_1: u.Quantity,
stddev_2: u.Quantity,
):
where = x < mean
not_where = ~where
result = np.empty(x.shape)
result[where] = np.exp(-np.square((x[where] - mean) / stddev_1) / 2)
result[not_where] = np.exp(-np.square((x[not_where] - mean) / stddev_2) / 2)
) -> np.ndarray:

unit = x.unit
x = x.value
mean = mean.to_value(unit)
stddev_1 = stddev_1.to_value(unit)
stddev_2 = stddev_2.to_value(unit)

result = _piecewise_guassian_ufunc(x, mean, stddev_1, stddev_2)

return result


@numba.vectorize(
[numba.float64(numba.float64, numba.float64, numba.float64, numba.float64)],
target="parallel",
)
def _piecewise_guassian_ufunc(
x: float,
mean: float,
stddev_1: float,
stddev_2: float,
) -> float: # pragma: nocover

if x < mean:
stddev = stddev_1
else:
stddev = stddev_2

a = (x - mean) / stddev

return math.exp(-a * a / 2)


def color_matching_x(wavelength: u.Quantity) -> u.Quantity:
r"""
The CIE 1931 :math:`\overline{x}(\lambda)` color matching function.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ classifiers = [
]
dependencies = [
"numpy",
"numba",
"matplotlib",
"astropy",
]
Expand Down

0 comments on commit bd723a9

Please sign in to comment.