-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'develop' into qshi/Ipatia2
- Loading branch information
Showing
14 changed files
with
571 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
ref-names: $Format:%D$ | ||
node: $Format:%H$ | ||
node-date: $Format:%cI$ | ||
describe-name: $Format:%(describe:tags=true,match=*[0-9]*)$ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
ComPWA | ||
======================= | ||
|
||
`ComPWA <https://compwa.github.io/>`_ is a framework for the coherent amplitude analysis of multi-body decays. It uses a symbolic approach to describe the decay amplitudes and can be used to fit data to extract the decay parameters. ComPWA can be used in combination with zfit to perform the fit by either creating a zfit pdf from the ComPWA model or by using the ComPWA estimator as a loss function for the zfit minimizer. | ||
|
||
Import the module with: | ||
|
||
.. code-block:: python | ||
|
||
import zfit_physics.compwa as zcompwa | ||
|
||
This will enable that :py:function:~` tensorwaves.estimator.Estimator`, can be used as a loss function in zfit minimizers as | ||
|
||
.. code-block:: python | ||
|
||
minimizer.minimize(loss=estimator) | ||
|
||
More explicitly, the loss function can be created with | ||
|
||
.. code-block:: python | ||
|
||
nll = zcompwa.loss.nll_from_estimator(estimator) | ||
|
||
which optionally takes already created :py:class:~`zfit.core.interfaces.ZfitParameter` as arguments. | ||
|
||
A whole ComPWA model can be converted to a zfit pdf with | ||
|
||
.. code-block:: python | ||
|
||
pdf = zcompwa.pdf.ComPWAPDF(compwa_model) | ||
|
||
``pdf`` is a full fledged zfit pdf that can be used in the same way as any other zfit pdf! In a sum, product, convolution and of course to fit data. | ||
|
||
Variables | ||
++++++++++++ | ||
|
||
|
||
.. automodule:: zfit_physics.compwa.variables | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: | ||
|
||
++++++++++++ | ||
|
||
.. automodule:: zfit_physics.compwa.pdf | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: | ||
|
||
Loss | ||
++++++++++++ | ||
|
||
.. automodule:: zfit_physics.compwa.loss | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from . import data, loss, pdf, variables | ||
|
||
__all__ = ["pdf", "variables", "loss"] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from __future__ import annotations | ||
|
||
import warnings | ||
from typing import TYPE_CHECKING | ||
|
||
import zfit | ||
from zfit.util.container import convert_to_container | ||
|
||
from .variables import params_from_intensity | ||
|
||
if TYPE_CHECKING: | ||
from tensorwaves.estimator import Estimator | ||
from zfit.core.interfaces import ZfitLoss | ||
|
||
__all__ = ["nll_from_estimator"] | ||
|
||
|
||
def nll_from_estimator(estimator: Estimator, *, params=None, errordef=None, numgrad=None) -> ZfitLoss: | ||
r"""Create a negative log-likelihood function from a tensorwaves estimator. | ||
|
||
Args: | ||
estimator: An estimator object that computes a scalar loss function. | ||
params: A list of zfit parameters that the loss function depends on. | ||
errordef: The error definition of the loss function. | ||
numgrad: If True, the gradient of the loss function is computed numerically and the ComPWA estimators | ||
gradient method is not used. Can be useful as not all backends in ComPWA support gradients. | ||
|
||
Returns: | ||
A zfit loss function that can be used with zfit. | ||
|
||
""" | ||
from tensorwaves.estimator import ChiSquared, UnbinnedNLL | ||
|
||
if params is None: | ||
classname = estimator.__class__.__name__ | ||
intensity = getattr(estimator, f"_{classname}__function", None) | ||
if intensity is None: | ||
msg = f"Could not find intensity function in {estimator}. Maybe the attribute changed?" | ||
raise ValueError(msg) | ||
params = params_from_intensity(intensity) | ||
else: | ||
params = convert_to_container(params) | ||
|
||
paramnames = [param.name for param in params] | ||
|
||
def func(params): | ||
paramdict = dict(zip(paramnames, params)) | ||
return estimator(paramdict) | ||
|
||
if numgrad: | ||
grad = None | ||
else: | ||
|
||
def grad(params): | ||
paramdict = dict(zip(paramnames, params)) | ||
return estimator.gradient(paramdict) | ||
|
||
if errordef is None: | ||
if hasattr(estimator, "errordef"): | ||
errordef = estimator.errordef | ||
elif isinstance(estimator, ChiSquared): | ||
errordef = 1.0 | ||
elif isinstance(estimator, UnbinnedNLL): | ||
errordef = 0.5 | ||
return zfit.loss.SimpleLoss(func=func, gradient=grad, params=params, errordef=errordef) | ||
|
||
|
||
def _nll_from_estimator_or_false(estimator: Estimator, *, params=None, errordef=None) -> ZfitLoss | bool: | ||
if "tensorwaves" in repr(type(estimator)): | ||
try: | ||
import tensorwaves as tw | ||
except ImportError: | ||
return False | ||
if not isinstance(estimator, (tw.estimator.ChiSquared, tw.estimator.UnbinnedNLL)): | ||
warnings.warn( | ||
"Only ChiSquared and UnbinnedNLL are supported from tensorwaves currently." | ||
f"TensorWaves is in name of {estimator}, this could be a bug.", | ||
stacklevel=2, | ||
) | ||
return False | ||
return nll_from_estimator(estimator, params=params, errordef=errordef) | ||
return None | ||
|
||
|
||
zfit.loss.SimpleLoss.register_convertable_loss(_nll_from_estimator_or_false) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from __future__ import annotations | ||
|
||
import tensorflow as tf | ||
import zfit # suppress tf warnings | ||
import zfit.z.numpy as znp | ||
from zfit import supports, z | ||
|
||
from .variables import obs_from_frame, params_from_intensity | ||
|
||
__all__ = ["ComPWAPDF"] | ||
|
||
|
||
class ComPWAPDF(zfit.pdf.BasePDF): | ||
def __init__(self, intensity, norm, obs=None, params=None, extended=None, name="ComPWA"): | ||
"""ComPWA intensity normalized over the *norm* dataset.""" | ||
if params is None: | ||
params = {p.name: p for p in params_from_intensity(intensity)} | ||
norm = zfit.Data(norm, obs=obs) | ||
if obs is None: | ||
obs = obs_from_frame(norm.to_pandas()) | ||
norm = norm.with_obs(obs) | ||
super().__init__(obs, params=params, name=name, extended=extended, autograd_params=[]) | ||
self.intensity = intensity | ||
norm = {ob: znp.array(ar) for ob, ar in zip(self.obs, z.unstack_x(norm))} | ||
self.norm_sample = norm | ||
|
||
@supports(norm=True) | ||
def _pdf(self, x, norm, params): | ||
paramvalsfloat = [] | ||
paramvalscomplex = [] | ||
iscomplex = [] | ||
# we need to split complex and floats to pass them to the numpy function, as it creates a tensor | ||
for val in params.values(): | ||
if val.dtype == znp.complex128: | ||
iscomplex.append(True) | ||
paramvalscomplex.append(val) | ||
paramvalsfloat.append(znp.zeros_like(val, dtype=znp.float64)) | ||
else: | ||
iscomplex.append(False) | ||
paramvalsfloat.append(val) | ||
paramvalscomplex.append(znp.zeros_like(val, dtype=znp.complex128)) | ||
|
||
def unnormalized_pdf_helper(x, paramvalsfloat, paramvalscomplex): | ||
data = {ob: znp.array(ar) for ob, ar in zip(self.obs, x)} | ||
paramsinternal = { | ||
n: c if isc else f for n, f, c, isc in zip(params.keys(), paramvalsfloat, paramvalscomplex, iscomplex) | ||
} | ||
self.intensity.update_parameters(paramsinternal) | ||
return self.intensity(data) | ||
|
||
xunstacked = z.unstack_x(x) | ||
|
||
probs = tf.numpy_function( | ||
unnormalized_pdf_helper, [xunstacked, paramvalsfloat, paramvalscomplex], Tout=tf.float64 | ||
) | ||
if norm is not False: | ||
normvalues = [znp.asarray(self.norm_sample[ob]) for ob in self.obs] | ||
normval = ( | ||
znp.mean( | ||
tf.numpy_function( | ||
unnormalized_pdf_helper, [normvalues, paramvalsfloat, paramvalscomplex], Tout=tf.float64 | ||
) | ||
) | ||
* znp.array([1.0]) # HACK: ComPWA just uses 1 as the phase space volume, better solution? | ||
# norm.volue is very small, since as it's done now (autoconverting in init), there are variables like | ||
# masses that have a tiny space, so the volume is very small | ||
# * norm.volume | ||
) | ||
normval.set_shape((1,)) | ||
probs /= normval | ||
probs.set_shape([None]) | ||
return probs | ||
|
||
# @z.function(wraps="tensorwaves") | ||
# def _jitted_normalization(self, norm, params): | ||
# return znp.mean(self._jitted_unnormalized_pdf(norm, params=params)) |
Oops, something went wrong.