diff --git a/colorsynth/_colorsynth.py b/colorsynth/_colorsynth.py index 9378352..dcbf87f 100644 --- a/colorsynth/_colorsynth.py +++ b/colorsynth/_colorsynth.py @@ -1,6 +1,8 @@ from typing import Callable import pathlib +import math import numpy as np +import numba import astropy.units as u __all__ = [ @@ -87,15 +89,40 @@ def _piecewise_gaussian( mean: u.Quantity, stddev_1: u.Quantity, stddev_2: u.Quantity, -): - where = x < mean - not_where = ~where - result = np.empty(x.shape) - result[where] = np.exp(-np.square((x[where] - mean) / stddev_1) / 2) - result[not_where] = np.exp(-np.square((x[not_where] - mean) / stddev_2) / 2) +) -> np.ndarray: + + unit = x.unit + x = x.value + mean = mean.to_value(unit) + stddev_1 = stddev_1.to_value(unit) + stddev_2 = stddev_2.to_value(unit) + + result = _piecewise_guassian_ufunc(x, mean, stddev_1, stddev_2) + return result +@numba.vectorize( + [numba.float64(numba.float64, numba.float64, numba.float64, numba.float64)], + target="parallel", +) +def _piecewise_guassian_ufunc( + x: float, + mean: float, + stddev_1: float, + stddev_2: float, +) -> float: # pragma: nocover + + if x < mean: + stddev = stddev_1 + else: + stddev = stddev_2 + + a = (x - mean) / stddev + + return math.exp(-a * a / 2) + + def color_matching_x(wavelength: u.Quantity) -> u.Quantity: r""" The CIE 1931 :math:`\overline{x}(\lambda)` color matching function. diff --git a/pyproject.toml b/pyproject.toml index 58889f8..520c8b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ classifiers = [ ] dependencies = [ "numpy", + "numba", "matplotlib", "astropy", ]