-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* first attempt * update changelog * forgot to add the file * fixes * add testing * typo * add parameter check at class instantiation * Update zfit_physics/models/pdf_tsallis.py Co-authored-by: Jonas Eschle <[email protected]> * Update zfit_physics/models/pdf_tsallis.py Co-authored-by: Jonas Eschle <[email protected]> * add formula ref * fix --------- Co-authored-by: Jonas Eschle <[email protected]>
- Loading branch information
1 parent
b2f5127
commit b260af3
Showing
4 changed files
with
237 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
"""Tests for CMSShape PDF.""" | ||
import numpy as np | ||
import pytest | ||
import tensorflow as tf | ||
import zfit | ||
from numba_stats import tsallis as tsallis_numba | ||
|
||
# Important, do the imports below | ||
from zfit.core.testing import tester | ||
|
||
import zfit_physics as zphys | ||
|
||
# specify globals here. Do NOT add any TensorFlow but just pure python | ||
m_true = 90.0 | ||
t_true = 10.0 | ||
n_true = 3.0 | ||
|
||
|
||
def create_tsallis(m, t, n, limits): | ||
obs = zfit.Space("obs1", limits) | ||
tsallis = zphys.pdf.Tsallis(m=m, t=t, n=n, obs=obs) | ||
return tsallis, obs | ||
|
||
|
||
def test_tsallis_pdf(): | ||
# Test PDF here | ||
tsallis, _ = create_tsallis(m=m_true, t=t_true, n=n_true, limits=(0, 150)) | ||
assert tsallis.pdf(90.0, norm=False).numpy().item() == pytest.approx( | ||
tsallis_numba.pdf(90.0, m=m_true, t=t_true, n=n_true), 1e-5 | ||
) | ||
np.testing.assert_allclose( | ||
tsallis.pdf(tf.range(0.0, 150, 10_000), norm=False), | ||
tsallis_numba.pdf(tf.range(0.0, 150, 10_000).numpy(), m=m_true, t=t_true, n=n_true), | ||
rtol=1e-5, | ||
) | ||
|
||
sample = tsallis.sample(1000) | ||
assert all(np.isfinite(sample.value())), "Some samples from the tsallis PDF are NaN or infinite" | ||
assert sample.n_events == 1000 | ||
assert all(tf.logical_and(0 <= sample.value(), sample.value() <= 150)) | ||
|
||
|
||
def test_tsallis_integral(): | ||
# Test CDF and integral here | ||
tsallis, obs = create_tsallis(m=m_true, t=t_true, n=n_true, limits=(0, 150)) | ||
full_interval_analytic = zfit.run(tsallis.analytic_integrate(obs, norm=False)) | ||
full_interval_numeric = zfit.run(tsallis.numeric_integrate(obs, norm=False)) | ||
true_integral = 0.835415 | ||
numba_stats_full_integral = tsallis_numba.cdf(150, m=m_true, t=t_true, n=n_true) - tsallis_numba.cdf( | ||
0, m=m_true, t=t_true, n=n_true | ||
) | ||
assert full_interval_analytic == pytest.approx(true_integral, 1e-5) | ||
assert full_interval_numeric == pytest.approx(true_integral, 1e-5) | ||
assert full_interval_analytic == pytest.approx(numba_stats_full_integral, 1e-8) | ||
assert full_interval_numeric == pytest.approx(numba_stats_full_integral, 1e-8) | ||
|
||
analytic_integral = zfit.run(tsallis.analytic_integrate(limits=(20, 60), norm=False)) | ||
numeric_integral = zfit.run(tsallis.numeric_integrate(limits=(20, 60), norm=False)) | ||
numba_stats_integral = tsallis_numba.cdf(60, m=m_true, t=t_true, n=n_true) - tsallis_numba.cdf( | ||
20, m=m_true, t=t_true, n=n_true | ||
) | ||
assert analytic_integral == pytest.approx(numeric_integral, 1e-8) | ||
assert analytic_integral == pytest.approx(numba_stats_integral, 1e-8) | ||
|
||
|
||
# register the pdf here and provide sets of working parameter configurations | ||
def tsallis_params_factory(): | ||
m = zfit.Parameter("m", m_true) | ||
t = zfit.Parameter("t", t_true) | ||
n = zfit.Parameter("n", n_true) | ||
|
||
return {"m": m, "t": t, "n": n} | ||
|
||
|
||
tester.register_pdf(pdf_class=zphys.pdf.Tsallis, params_factories=tsallis_params_factory) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
from typing import Optional | ||
|
||
import tensorflow as tf | ||
import zfit | ||
import zfit.z.numpy as znp | ||
from zfit import run, z | ||
from zfit.core.space import ANY_LOWER, ANY_UPPER, Space | ||
from zfit.util import ztyping | ||
|
||
|
||
@z.function(wraps="tensor") | ||
def tsallis_pdf_func(x, m, t, n): | ||
"""Calculate the Tsallis PDF. | ||
Args: | ||
x: value(s) for which the PDF will be calculated. | ||
m: mass of the particle. | ||
t: width parameter. | ||
n: absolute value of the exponent of the power law. | ||
Returns: | ||
`tf.Tensor`: The calculated PDF values. | ||
Notes: | ||
Based on code from `numba-stats <https://github.com/HDembinski/numba-stats/blob/main/src/numba_stats/tsallis.py>`_. | ||
Formula from CMS, Eur. Phys. J. C (2012) 72:2164 | ||
""" | ||
if run.executing_eagerly(): | ||
if n <= 2: | ||
msg = "n > 2 is required" | ||
raise ValueError(msg) | ||
elif run.numeric_checks: | ||
tf.debugging.assert_greater(n, znp.asarray(2.0), message="n > 2 is required") | ||
|
||
x = z.unstack_x(x) | ||
mt = znp.hypot(m, x) | ||
nt = n * t | ||
c = (n - 1) * (n - 2) / (nt * (nt + (n - 2) * m)) | ||
return c * x * znp.power(1 + (mt - m) / nt, -n) | ||
|
||
|
||
@z.function(wraps="tensor") | ||
def tsallis_cdf_func(x, m, t, n): | ||
"""Calculate the Tsallis CDF. | ||
Args: | ||
x: value(s) for which the CDF will be calculated. | ||
m: mass of the particle. | ||
t: width parameter. | ||
n: absolute value of the exponent of the power law. | ||
Returns: | ||
`tf.Tensor`: The calculated CDF values. | ||
Notes: | ||
Based on code from `numba-stats <https://github.com/HDembinski/numba-stats/blob/main/src/numba_stats/tsallis.py>`_. | ||
Formula from CMS, Eur. Phys. J. C (2012) 72:2164 | ||
""" | ||
if run.executing_eagerly(): | ||
if n <= 2: | ||
msg = "n > 2 is required" | ||
raise ValueError(msg) | ||
elif run.numeric_checks: | ||
tf.debugging.assert_greater(n, znp.asarray(2.0), message="n > 2 is required") | ||
|
||
x = z.unstack_x(x) | ||
mt = znp.hypot(m, x) | ||
nt = n * t | ||
return znp.power((mt - m) / nt + 1, 1 - n) * (m + mt - n * (mt + t)) / (m * (n - 2) + nt) | ||
|
||
|
||
def tsallis_integral(limits: ztyping.SpaceType, params: dict, model) -> tf.Tensor: | ||
"""Calculates the analytic integral of the Tsallis PDF. | ||
Args: | ||
limits: An object with attribute limit1d. | ||
params: A hashmap from which the parameters that defines the PDF will be extracted. | ||
model: Will be ignored. | ||
Returns: | ||
The calculated integral. | ||
""" | ||
lower, upper = limits._rect_limits_tf | ||
m = params["m"] | ||
t = params["t"] | ||
n = params["n"] | ||
lower_cdf = tsallis_cdf_func(x=lower, m=m, t=t, n=n) | ||
upper_cdf = tsallis_cdf_func(x=upper, m=m, t=t, n=n) | ||
return upper_cdf - lower_cdf | ||
|
||
|
||
class Tsallis(zfit.pdf.BasePDF): | ||
_N_OBS = 1 | ||
|
||
def __init__( | ||
self, | ||
m: ztyping.ParamTypeInput, | ||
t: ztyping.ParamTypeInput, | ||
n: ztyping.ParamTypeInput, | ||
obs: ztyping.ObsTypeInput, | ||
*, | ||
extended: Optional[ztyping.ExtendedInputType] = None, | ||
norm: Optional[ztyping.NormInputType] = None, | ||
name: str = "Tsallis", | ||
): | ||
"""Tsallis-Hagedorn PDF. | ||
A generalisation (q-analog) of the exponential distribution based on Tsallis entropy. | ||
It approximately describes the pT distribution charged particles produced in high-energy | ||
minimum bias particle collisions. | ||
Based on code from `numba-stats <https://github.com/HDembinski/numba-stats/blob/main/src/numba_stats/tsallis.py>`_. | ||
Formula from CMS, Eur. Phys. J. C (2012) 72:2164 | ||
Args: | ||
m: Mass of the particle. | ||
t: Width parameter. | ||
n: Absolute value of the exponent of the power law. | ||
obs: |@doc:pdf.init.obs| Observables of the | ||
model. This will be used as the default space of the PDF and, | ||
if not given explicitly, as the normalization range. | ||
The default space is used for example in the sample method: if no | ||
sampling limits are given, the default space is used. | ||
The observables are not equal to the domain as it does not restrict or | ||
truncate the model outside this range. |@docend:pdf.init.obs| | ||
extended: |@doc:pdf.init.extended| The overall yield of the PDF. | ||
If this is parameter-like, it will be used as the yield, | ||
the expected number of events, and the PDF will be extended. | ||
An extended PDF has additional functionality, such as the | ||
``ext_*`` methods and the ``counts`` (for binned PDFs). |@docend:pdf.init.extended| | ||
norm: |@doc:pdf.init.norm| Normalization of the PDF. | ||
By default, this is the same as the default space of the PDF. |@docend:pdf.init.norm| | ||
name: |@doc:pdf.init.name| Human-readable name | ||
or label of | ||
the PDF for better identification. | ||
Has no programmatical functional purpose as identification. |@docend:pdf.init.name| | ||
""" | ||
if run.executing_eagerly(): | ||
if n <= 2: | ||
msg = "n > 2 is required" | ||
raise ValueError(msg) | ||
elif run.numeric_checks: | ||
tf.debugging.assert_greater(n, znp.asarray(2.0), message="n > 2 is required") | ||
|
||
params = {"m": m, "t": t, "n": n} | ||
super().__init__(obs=obs, params=params, name=name, extended=extended, norm=norm) | ||
|
||
def _unnormalized_pdf(self, x: tf.Tensor) -> tf.Tensor: | ||
m = self.params["m"] | ||
t = self.params["t"] | ||
n = self.params["n"] | ||
return tsallis_pdf_func(x=x, m=m, t=t, n=n) | ||
|
||
|
||
tsallis_integral_limits = Space(axes=0, limits=(ANY_LOWER, ANY_UPPER)) | ||
Tsallis.register_analytic_integral(func=tsallis_integral, limits=tsallis_integral_limits) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters