Skip to content

Commit

Permalink
Make vector package optional (#110)
Browse files Browse the repository at this point in the history
* Make `vector` package optional

* FIX: import `vector` inline
  • Loading branch information
redeboer authored Apr 23, 2024
1 parent f6db1e0 commit 7774693
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion phasespace/phasespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,20 @@
import inspect
from collections.abc import Callable
from math import pi
from typing import TYPE_CHECKING, NoReturn

import numpy as np
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
import vector

from . import kinematics as kin
from .backend import function, function_jit_fixedshape
from .random import SeedLike, get_rng

if TYPE_CHECKING:
import vector


RELAX_SHAPES = False


Expand Down Expand Up @@ -669,6 +673,10 @@ def generate(
"""
rng = get_rng(seed)
if boost_to is not None:
try:
import vector
except ImportError as error:
_raise_missing_vector_package(error)
if isinstance(boost_to, vector.Vector):
if not (
isinstance(boost_to, vector.Momentum)
Expand Down Expand Up @@ -815,8 +823,18 @@ def to_vectors(particles: dict[str, tf.Tensor]) -> dict[str, vector.Momentum]:
Return:
dict: Dictionary of `vector.Momentum` instances with numpy arrays
"""
try:
import vector
except ImportError as error:
_raise_missing_vector_package(error)
newparticles = {}
for name, particle in particles.items():
px, py, pz, e = np.moveaxis(particle, -1, 0) # numpy "unstack"
newparticles[name] = vector.array(dict(px=px, py=py, pz=pz, energy=e))
return newparticles


def _raise_missing_vector_package(exception: ImportError) -> NoReturn:
raise ImportError(
"To use `boost_to`, the `vector` package has to be installed."
) from exception

0 comments on commit 7774693

Please sign in to comment.