Skip to content

Commit

Permalink
Move BW to unstable, add simple tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-eschle committed May 2, 2019
1 parent ef3f9e9 commit bc4e0b4
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 25 deletions.
20 changes: 16 additions & 4 deletions tests/test_model_breit_wigner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Example test for a pdf or function"""

import pytest
import zfit
import numpy as np
# Important, do the imports below
from zfit.core.testing import setup_function, teardown_function, tester

Expand All @@ -9,11 +10,23 @@
# specify globals here. Do NOT add any TensorFlow but just pure python
param1_true = 0.3
param2_true = 1.2
obs1 = zfit.Space("obs1", limits=(-4, 5))


def test_special_property1():
def test_bw_pdf():
# test special properties here
assert True
bw = zphys.unstable.pdf.RelativisticBreitWignerSquared(obs=obs1, mres=1., wres=0.3)

integral = bw.integrate(limits=obs1)
assert pytest.approx(1., rel=1e-3) == zfit.run(integral)


def test_bw_func():
# test special properties here
bw = zphys.unstable.func.RelativisticBreitWigner(obs=obs1, mres=1., wres=0.3)

vals = bw.func(x=np.random.uniform(size=(100, 1)))
assert 100 == len(zfit.run(vals))


# register the pdf here and provide sets of working parameter configurations
Expand All @@ -23,5 +36,4 @@ def _bw_params_factory():
wres = zfit.Parameter('wres', param2_true)
return {"mres": mres, "wres": wres}


# tester.register_func(func_class=zfit.pdf.Gauss, params_factories=_bw_params_factory())
55 changes: 36 additions & 19 deletions zfit_physics/models/model_breit_wigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,32 @@ def relativistic_breit_wigner(m2, mres, wres):
return 1. / below_div


def _bw_func(self, x):
"""The Breit Wigner function wrapped to be used within a PDF or a Func directly.
Args:
self:
x:
Returns:
"""
var = x.unstack_x()
if isinstance(var, list):
m_sq = kinematics.mass_squared(tf.reduce_sum(
[kinematics.lorentz_vector(kinematics.vector(px, py, pz), pe)
for px, py, pz, pe in zip(*[iter(var)] * 4)],
axis=0))
elif self.using_m_squared:
m_sq = var
else:
m_sq = var * tf.math.conj(
var) # TODO(Albert): this was squared, but should be mult with conj, right?
mres = self.params['mres']
wres = self.params['wres']
return relativistic_breit_wigner(m_sq, mres, wres)


class RelativisticBreitWigner(zfit.func.BaseFunc):

def __init__(self, obs: ztyping.ObsTypeInput, mres: ztyping.ParamTypeInput, wres: ztyping.ParamTypeInput,
Expand Down Expand Up @@ -41,29 +67,20 @@ def __init__(self, obs: ztyping.ObsTypeInput, mres: ztyping.ParamTypeInput, wres
# HACK end

def _func(self, x):
var = x.unstack_x()
if isinstance(var, list):
m_sq = kinematics.mass_squared(tf.reduce_sum(
[kinematics.lorentz_vector(kinematics.vector(px, py, pz), pe)
for px, py, pz, pe in zip(*[iter(var)] * 4)],
axis=0))
elif self.using_m_squared:
m_sq = var
else:
m_sq = var * tf.math.conj(var) # TODO(Albert): this was squared, but should be mult with conj, right?
mres = self.params['mres']
wres = self.params['wres']
return relativistic_breit_wigner(m_sq, mres, wres)


class RelativisticBreitWignerSquared(RelativisticBreitWigner):
return _bw_func(self, x)


class RelativisticBreitWignerSquared(zfit.pdf.BasePDF):

def __init__(self, obs: ztyping.ObsTypeInput, mres: ztyping.ParamTypeInput, wres: ztyping.ParamTypeInput,
using_m_squared: bool = False, name="RelativisticBreitWignerPDF"):
super().__init__(obs=obs, mres=mres, wres=wres, using_m_squared=using_m_squared, name=name)
self.using_m_squared = using_m_squared

def _func(self, x):
propagator = super()._func(x)
super().__init__(obs=obs, name=name, dtype=zfit.settings.ztypes.float,
params={'mres': mres, 'wres': wres})

def _unnormalized_pdf(self, x):
propagator = _bw_func(self, x)
val = propagator * tf.math.conj(propagator)
val = ztf.to_real(val)
return val
2 changes: 1 addition & 1 deletion zfit_physics/unstable/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import pdf
from . import pdf, func
1 change: 1 addition & 0 deletions zfit_physics/unstable/func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ..models.model_breit_wigner import RelativisticBreitWigner
3 changes: 2 additions & 1 deletion zfit_physics/unstable/pdf.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from ..models.pdf_conv import ConvPDF
from ..models.pdf_kde import GaussianKDE
from ..models.pdf_kde import GaussianKDE
from ..models.model_breit_wigner import RelativisticBreitWignerSquared

0 comments on commit bc4e0b4

Please sign in to comment.