diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3cf83ee7d..39894b727 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,8 @@
 * 24.x.x
 - Bug fixes:
   - Fix bug with 'median' and 'mean' methods in Masker averaging over the wrong axes.
+- Enhancements:
+  - Removed multiple exits from numba implementation of KullbackLeibler divergence (#1901)
 
 * 24.2.0
 - New Features:
@@ -21,7 +23,7 @@
   - Internal refactor: Separate framework into multiple files (#1692)
   - Allow the SIRT algorithm to take `initial=None` (#1906)
   - Add checks on equality method of `AcquisitionData` and `ImageData` for equality of data type and geometry (#1919)
-  - Add check on equality method of `AcquisitionGeometry` for equality of dimension labels (#1919) 
+  - Add check on equality method of `AcquisitionGeometry` for equality of dimension labels (#1919)
 - Testing:
   - New unit tests for operators and functions to check for in place errors and the behaviour of `out` (#1805)
   - Updates in SPDHG vs PDHG unit test to reduce test time and adjustments to parameters (#1898)
diff --git a/Wrappers/Python/cil/optimisation/functions/KullbackLeibler.py b/Wrappers/Python/cil/optimisation/functions/KullbackLeibler.py
index 9fa890a6f..3a84fcc8d 100644
--- a/Wrappers/Python/cil/optimisation/functions/KullbackLeibler.py
+++ b/Wrappers/Python/cil/optimisation/functions/KullbackLeibler.py
@@ -343,25 +343,9 @@ def kl_gradient_mask(x, b, out, eta, mask):
     @njit(parallel=True)
     def kl_div(x, y, eta):
         accumulator = numpy.zeros(get_num_threads(), dtype=numpy.float64)
+        has_inf = 0
         for i in prange(x.size):
-            X = x.flat[i]
-            Y = y.flat[i] + eta.flat[i]
-            if X > 0 and Y > 0:
-                # out.flat[i] = X * numpy.log(X/Y) - X + Y
-                accumulator[get_thread_id()] += X * numpy.log(X/Y) - X + Y
-            elif X == 0 and Y >= 0:
-                # out.flat[i] = Y
-                accumulator[get_thread_id()] += Y
-            else:
-                # out.flat[i] = numpy.inf
-                return numpy.inf
-        return sum(accumulator)
-
-    @njit(parallel=True)
-    def kl_div_mask(x, y, eta, mask):
-        accumulator = numpy.zeros(get_num_threads(), dtype=numpy.float64)
-        for i in prange(x.size):
-            if mask.flat[i] > 0:
+            if has_inf == 0:
                 X = x.flat[i]
                 Y = y.flat[i] + eta.flat[i]
                 if X > 0 and Y > 0:
@@ -372,7 +356,29 @@ def kl_div_mask(x, y, eta, mask):
                     accumulator[get_thread_id()] += Y
                 else:
                     # out.flat[i] = numpy.inf
-                    return numpy.inf
+                    accumulator[get_thread_id()] = numpy.inf
+                    has_inf = 1
+        return sum(accumulator)
+
+    @njit(parallel=True)
+    def kl_div_mask(x, y, eta, mask):
+        accumulator = numpy.zeros(get_num_threads(), dtype=numpy.float64)
+        has_inf = 0
+        for i in prange(x.size):
+            if has_inf == 0:
+                if mask.flat[i] > 0:
+                    X = x.flat[i]
+                    Y = y.flat[i] + eta.flat[i]
+                    if X > 0 and Y > 0:
+                        # out.flat[i] = X * numpy.log(X/Y) - X + Y
+                        accumulator[get_thread_id()] += X * numpy.log(X/Y) - X + Y
+                    elif X == 0 and Y >= 0:
+                        # out.flat[i] = Y
+                        accumulator[get_thread_id()] += Y
+                    else:
+                        # out.flat[i] = numpy.inf
+                        accumulator[get_thread_id()] = numpy.inf
+                        has_inf = 1
         return sum(accumulator)
 
     # convex conjugate
diff --git a/Wrappers/Python/test/test_function_KullbackLeibler.py b/Wrappers/Python/test/test_function_KullbackLeibler.py
index e5e9dc9ed..a71ccfea5 100644
--- a/Wrappers/Python/test/test_function_KullbackLeibler.py
+++ b/Wrappers/Python/test/test_function_KullbackLeibler.py
@@ -173,6 +173,22 @@ def test_KullbackLeibler_numba_call(self):
 
         numpy.testing.assert_allclose(f(u1), f_np(u1), rtol=1e-5)
 
+    def test_KullbackLeibler_numba_kl_div_has_inf(self):
+        x = self.u1 * self.mask
+        x *= -1
+
+        from cil.optimisation.functions.KullbackLeibler import kl_div
+
+        numpy.testing.assert_equal(kl_div(x.array, self.b1.array, self.eta.array), numpy.inf)
+
+    def test_KullbackLeibler_numba_kl_div_mask_has_inf(self):
+        x = self.u1 * self.mask
+        x *= -1
+
+        from cil.optimisation.functions.KullbackLeibler import kl_div_mask
+
+        numpy.testing.assert_equal(kl_div_mask(x.array, self.b1.array, self.eta.array, self.mask.array), numpy.inf)
+
     def test_KullbackLeibler_numba_call_mask(self):
         f = self.f
         f_np = self.f_np
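Note: the patch above replaces the early `return numpy.inf` inside the prange loops with a single exit per function. Each thread accumulates into its own slot of `accumulator`; when an infinite term is found, the thread writes `numpy.inf` into its slot and raises the `has_inf` flag, so remaining iterations fall through and the final sum propagates the infinity. The sketch below is a minimal standalone illustration of this pattern, not CIL code: `positive_log_sum` is a hypothetical toy function, and the import assumes a numba version that exposes `get_num_threads` and `get_thread_id` at the top level, as the patched module already uses them unqualified.

    import numpy
    from numba import njit, prange, get_num_threads, get_thread_id

    @njit(parallel=True)
    def positive_log_sum(x):
        # One accumulator slot per thread: no race on a shared running total.
        accumulator = numpy.zeros(get_num_threads(), dtype=numpy.float64)
        has_inf = 0
        for i in prange(x.size):
            if has_inf == 0:  # once inf is recorded, remaining iterations skip the work
                if x.flat[i] > 0:
                    accumulator[get_thread_id()] += numpy.log(x.flat[i])
                else:
                    # Record inf in this thread's slot instead of returning early,
                    # keeping a single exit; the final sum propagates the inf.
                    accumulator[get_thread_id()] = numpy.inf
                    has_inf = 1
        return accumulator.sum()

    print(positive_log_sum(numpy.array([1.0, 2.0, 3.0])))   # log(1) + log(2) + log(3)
    print(positive_log_sum(numpy.array([1.0, -2.0, 3.0])))  # inf

Keeping the loop body as a single exit avoids numba having to reconcile multiple return paths out of a parallel region; the flag is only an optimisation, since any thread that sees a bad value has already poisoned its own accumulator slot.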