From 7f8a62864fb2f4587bfc9b3bc5a0bf8bf3766bee Mon Sep 17 00:00:00 2001 From: Alessandro Candido Date: Tue, 25 Apr 2023 19:35:46 +0200 Subject: [PATCH] Improve xgrid rotations structure, fix EKO raw representation --- src/eko/io/manipulate.py | 176 ++++++++++++++++++++++----------------- src/eko/io/struct.py | 6 +- tests/conftest.py | 5 +- 3 files changed, 104 insertions(+), 83 deletions(-) diff --git a/src/eko/io/manipulate.py b/src/eko/io/manipulate.py index 15173e800..9b32da337 100644 --- a/src/eko/io/manipulate.py +++ b/src/eko/io/manipulate.py @@ -1,12 +1,14 @@ """Manipulate output generate by EKO.""" import logging import warnings -from typing import Optional +from typing import Callable, Optional, Union import numpy as np +import numpy.typing as npt from .. import basis_rotation as br from .. import interpolation +from ..interpolation import XGrid from .struct import EKO logger = logging.getLogger(__name__) @@ -16,24 +18,58 @@ SIMGRID_ROTATION = "ij,ajbk,kl->aibl" """Simultaneous grid rotation contraction indices.""" +Basis = Union[XGrid, npt.NDArray] + + +def rotation(new: Optional[Basis], old: Basis, check: Callable, compute: Callable): + """Define grid rotation. + + This function returns the new grid to be assigned and the rotation computed, + if the checks for a non-trivial new grid are passed. + + However, the check and the computation are delegated respectively to the + callables `check` and `compute`. + + """ + if new is None: + return old, None + + if check(new, old): + warnings.warn("The new grid is close to the current one") + return old, None + + return new, compute(new, old) + + +def xgrid_check(new: Optional[XGrid], old: XGrid): + """Check validity of new xgrid.""" + return new is not None and len(new) == len(old) and np.allclose(new.raw, old.raw) + + +def xgrid_compute_rotation(new: XGrid, old: XGrid, interpdeg: int, swap: bool = False): + """Compute rotation from old to new xgrid. + + By default, the roation is computed for a target xgrid. Whether the function + should be used for an input xgrid, the `swap` argument should be set to + `True`, in order to compute it in the other direction (i.e. the transposed). + + """ + if swap: + new, old = old, new + b = interpolation.InterpolatorDispatcher(old, interpdeg, False) + return b.get_interpolation(new.raw) + def xgrid_reshape( eko: EKO, - targetgrid: Optional[interpolation.XGrid] = None, - inputgrid: Optional[interpolation.XGrid] = None, + targetgrid: Optional[XGrid] = None, + inputgrid: Optional[XGrid] = None, ): """Reinterpolate operators on output and/or input grids. - The operation is in-place. + Target corresponds to the output PDF. - Parameters - ---------- - eko : - the operator to be rotated - targetgrid : - xgrid for the target (output PDF) - inputgrid : - xgrid for the input (input PDF) + The operation is in-place. """ eko.assert_permissions(write=True) @@ -41,77 +77,61 @@ def xgrid_reshape( # calling with no arguments is an error if targetgrid is None and inputgrid is None: raise ValueError("Nor inputgrid nor targetgrid was given") - # now check to the current status - if ( - targetgrid is not None - and len(targetgrid) == len(eko.rotations.targetgrid) - and np.allclose(targetgrid.raw, eko.rotations.targetgrid.raw) - ): - targetgrid = None - warnings.warn("The new targetgrid is close to the current targetgrid") - if ( - inputgrid is not None - and len(inputgrid) == len(eko.rotations.inputgrid) - and np.allclose(inputgrid.raw, eko.rotations.inputgrid.raw) - ): - inputgrid = None - warnings.warn("The new inputgrid is close to the current inputgrid") + + interpdeg = eko.operator_card.configs.interpolation_polynomial_degree + check = xgrid_check + crot = xgrid_compute_rotation + + # construct matrices + newtarget, targetrot = rotation( + targetgrid, + eko.rotations.targetgrid, + check, + lambda new, old: crot(new, old, interpdeg), + ) + newinput, inputrot = rotation( + inputgrid, + eko.rotations.inputgrid, + check, + lambda new, old: crot(new, old, interpdeg, swap=True), + ) + # after the checks: if there is still nothing to do, skip - if targetgrid is None and inputgrid is None: + if targetrot is None and inputrot is None: logger.debug("Nothing done.") return - - # construct matrices - if targetgrid is not None: - b = interpolation.InterpolatorDispatcher( - eko.rotations.targetgrid, - eko.operator_card.configs.interpolation_polynomial_degree, - False, - ) - target_rot = b.get_interpolation(targetgrid.raw) - eko.rotations.targetgrid = targetgrid - if inputgrid is not None: - b = interpolation.InterpolatorDispatcher( - inputgrid, - eko.operator_card.configs.interpolation_polynomial_degree, - False, - ) - input_rot = b.get_interpolation(eko.rotations.inputgrid.raw) - eko.rotations.inputgrid = inputgrid + # if no rotation is done, the grids are not modified + if targetrot is not None: + eko.rotations.targetgrid = newtarget + if targetrot is not None: + eko.rotations.targetgrid = newinput # build new grid - for q2, elem in eko.items(): - ops = elem.operator - errs = elem.error - if targetgrid is not None and inputgrid is None: - ops = np.einsum(TARGETGRID_ROTATION, target_rot, ops, optimize="optimal") - errs = ( - np.einsum(TARGETGRID_ROTATION, target_rot, errs, optimize="optimal") - if errs is not None - else None - ) - elif inputgrid is not None and targetgrid is None: - ops = np.einsum(INPUTGRID_ROTATION, ops, input_rot, optimize="optimal") - errs = ( - np.einsum(INPUTGRID_ROTATION, errs, input_rot, optimize="optimal") - if errs is not None - else None - ) + for ep, elem in eko.items(): + assert elem is not None + + operands = [elem.operator] + operands_errs = [elem.error] + + if targetrot is not None and inputrot is None: + contraction = TARGETGRID_ROTATION + elif inputrot is not None and targetrot is None: + contraction = INPUTGRID_ROTATION else: - ops = np.einsum( - SIMGRID_ROTATION, target_rot, ops, input_rot, optimize="optimal" - ) - errs = ( - np.einsum( - SIMGRID_ROTATION, target_rot, errs, input_rot, optimize="optimal" - ) - if errs is not None - else None - ) - elem.operator = ops - elem.error = errs + contraction = SIMGRID_ROTATION - eko[q2] = elem + if targetrot is not None: + operands.insert(0, targetrot) + operands_errs.insert(0, targetrot) + if inputrot is not None: + operands.append(inputrot) + operands_errs.append(inputrot) + + elem.operator = np.einsum(contraction, *operands, optimize="optimal") + if elem.error is not None: + elem.error = np.einsum(contraction, *operands_errs, optimize="optimal") + + eko[ep] = elem eko.update() @@ -124,8 +144,8 @@ def xgrid_reshape( def flavor_reshape( eko: EKO, - targetpids: Optional[np.ndarray] = None, - inputpids: Optional[np.ndarray] = None, + targetpids: Optional[npt.NDArray] = None, + inputpids: Optional[npt.NDArray] = None, update: bool = True, ): """Change the operators to have in the output targetpids and/or in the input inputpids. diff --git a/src/eko/io/struct.py b/src/eko/io/struct.py index 9dbaff131..ca364e22d 100644 --- a/src/eko/io/struct.py +++ b/src/eko/io/struct.py @@ -773,8 +773,8 @@ def items(self): immediately after """ - for ep, op in self._operators.items(): - yield ep, op + for ep in self._operators: + yield ep, self[ep] del self[ep] def __contains__(self, q2: float) -> bool: @@ -1015,7 +1015,7 @@ def raw(self) -> dict: operators themselves """ - return dict(mu2grid=self.mu2grid.tolist(), metadata=self.metadata.raw) + return dict(mu2grid=self.mu2grid, metadata=self.metadata.raw) @dataclass diff --git a/tests/conftest.py b/tests/conftest.py index ce91ee8cb..2e54dc171 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ from eko import interpolation from eko.io.runcards import OperatorCard, TheoryCard from eko.io.struct import EKO, Operator +from eko.io.types import EvolutionPoint from ekobox import cards @@ -80,7 +81,7 @@ def __init__(self, theory: TheoryCard, operator: OperatorCard, path: os.PathLike self.cache: Optional[EKO] = None @staticmethod - def _operators(mugrid: Iterable[float], shape: Tuple[int, int]): + def _operators(mugrid: Iterable[EvolutionPoint], shape: Tuple[int, int]): ops = {} for mu in mugrid: ops[mu] = Operator(np.random.rand(*shape, *shape)) @@ -94,7 +95,7 @@ def _create(self): lx = len(self.operator.xgrid) lpids = len(self.operator.pids) for mu2, op in self._operators( - mugrid=self.operator.mu2grid, shape=(lpids, lx) + mugrid=self.operator.evolgrid, shape=(lpids, lx) ).items(): self.cache[mu2] = op