-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
aa850b0
commit dc71ebc
Showing
9 changed files
with
330 additions
and
69 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,47 +6,52 @@ build-backend = "hatchling.build" | |
[project] | ||
name = "zfit-pwa" | ||
authors = [ | ||
{ name = "Jonas Eschle", email = "[email protected]" }, | ||
{ name = "Jonas Eschle", email = "[email protected]" }, | ||
] | ||
description = "Tools to adapt to Partial Wave Analysis packages" | ||
readme = "README.md" | ||
license.file = "LICENSE" | ||
requires-python = ">=3.8" | ||
classifiers = [ | ||
"Development Status :: 1 - Planning", | ||
"Intended Audience :: Science/Research", | ||
"Intended Audience :: Developers", | ||
"License :: OSI Approved :: BSD License", | ||
"Operating System :: OS Independent", | ||
"Programming Language :: Python", | ||
"Programming Language :: Python :: 3", | ||
"Programming Language :: Python :: 3 :: Only", | ||
"Programming Language :: Python :: 3.8", | ||
"Programming Language :: Python :: 3.9", | ||
"Programming Language :: Python :: 3.10", | ||
"Programming Language :: Python :: 3.11", | ||
"Programming Language :: Python :: 3.12", | ||
"Topic :: Scientific/Engineering", | ||
"Typing :: Typed", | ||
"Development Status :: 1 - Planning", | ||
"Intended Audience :: Science/Research", | ||
"Intended Audience :: Developers", | ||
"License :: OSI Approved :: BSD License", | ||
"Operating System :: OS Independent", | ||
"Programming Language :: Python", | ||
"Programming Language :: Python :: 3", | ||
"Programming Language :: Python :: 3 :: Only", | ||
"Programming Language :: Python :: 3.9", | ||
"Programming Language :: Python :: 3.10", | ||
"Programming Language :: Python :: 3.11", | ||
"Programming Language :: Python :: 3.12", | ||
"Topic :: Scientific/Engineering", | ||
"Typing :: Typed", | ||
] | ||
dynamic = ["version"] | ||
dependencies = [] | ||
dependencies = ["zfit"] | ||
|
||
[project.optional-dependencies] | ||
test = [ | ||
"pytest >=6", | ||
"pytest-cov >=3", | ||
] | ||
dev = [ | ||
"pytest >=6", | ||
"pytest-cov >=3", | ||
"pytest >=6", | ||
"pytest-cov >=3", | ||
] | ||
|
||
docs = [ | ||
"sphinx>=7.0", | ||
"myst_parser>=0.13", | ||
"sphinx_copybutton", | ||
"sphinx_autodoc_typehints", | ||
"furo>=2023.08.17", | ||
"sphinx>=7.0", | ||
"myst_parser>=0.13", | ||
"sphinx_copybutton", | ||
"sphinx_autodoc_typehints", | ||
"furo>=2023.08.17", | ||
] | ||
compwa = [ | ||
"qrules", | ||
"ampform", | ||
"tensorwaves", | ||
] | ||
dev = [ | ||
"zfit-pwa[test,docs,compwa]", | ||
"tensorwaves[phsp]", | ||
] | ||
|
||
[project.urls] | ||
|
@@ -67,22 +72,22 @@ scripts.test = "pytest {args}" | |
|
||
[tool.pytest.ini_options] | ||
minversion = "6.0" | ||
addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] | ||
xfail_strict = true | ||
filterwarnings = [ | ||
"error", | ||
] | ||
#addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] | ||
#xfail_strict = true | ||
#filterwarnings = [ | ||
# "error", | ||
#] | ||
log_cli_level = "INFO" | ||
testpaths = [ | ||
"tests", | ||
"tests", | ||
] | ||
|
||
|
||
[tool.coverage] | ||
run.source = ["zfit_pwa"] | ||
report.exclude_also = [ | ||
'\.\.\.', | ||
'if typing.TYPE_CHECKING:', | ||
'\.\.\.', | ||
'if typing.TYPE_CHECKING:', | ||
] | ||
|
||
[tool.mypy] | ||
|
@@ -106,32 +111,32 @@ src = ["src"] | |
|
||
[tool.ruff.lint] | ||
extend-select = [ | ||
"B", # flake8-bugbear | ||
"I", # isort | ||
"ARG", # flake8-unused-arguments | ||
"C4", # flake8-comprehensions | ||
"EM", # flake8-errmsg | ||
"ICN", # flake8-import-conventions | ||
"G", # flake8-logging-format | ||
"PGH", # pygrep-hooks | ||
"PIE", # flake8-pie | ||
"PL", # pylint | ||
"PT", # flake8-pytest-style | ||
"PTH", # flake8-use-pathlib | ||
"RET", # flake8-return | ||
"RUF", # Ruff-specific | ||
"SIM", # flake8-simplify | ||
"T20", # flake8-print | ||
"UP", # pyupgrade | ||
"YTT", # flake8-2020 | ||
"EXE", # flake8-executable | ||
"NPY", # NumPy specific rules | ||
"PD", # pandas-vet | ||
"B", # flake8-bugbear | ||
"I", # isort | ||
"ARG", # flake8-unused-arguments | ||
"C4", # flake8-comprehensions | ||
"EM", # flake8-errmsg | ||
"ICN", # flake8-import-conventions | ||
"G", # flake8-logging-format | ||
"PGH", # pygrep-hooks | ||
"PIE", # flake8-pie | ||
"PL", # pylint | ||
"PT", # flake8-pytest-style | ||
"PTH", # flake8-use-pathlib | ||
"RET", # flake8-return | ||
"RUF", # Ruff-specific | ||
"SIM", # flake8-simplify | ||
"T20", # flake8-print | ||
"UP", # pyupgrade | ||
"YTT", # flake8-2020 | ||
"EXE", # flake8-executable | ||
"NPY", # NumPy specific rules | ||
"PD", # pandas-vet | ||
] | ||
ignore = [ | ||
"PLR09", # Too many <...> | ||
"PLR2004", # Magic value used in comparison | ||
"ISC001", # Conflicts with formatter | ||
"PLR09", # Too many <...> | ||
"PLR2004", # Magic value used in comparison | ||
"ISC001", # Conflicts with formatter | ||
] | ||
isort.required-imports = ["from __future__ import annotations"] | ||
# Uncomment if using a _compat.typing backport | ||
|
@@ -148,9 +153,9 @@ ignore-paths = [".*/_version.py"] | |
reports.output-format = "colorized" | ||
similarities.ignore-imports = "yes" | ||
messages_control.disable = [ | ||
"design", | ||
"fixme", | ||
"line-too-long", | ||
"missing-module-docstring", | ||
"wrong-import-position", | ||
"design", | ||
"fixme", | ||
"line-too-long", | ||
"missing-module-docstring", | ||
"wrong-import-position", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Mapping | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import zfit | ||
from zfit.core.interfaces import ZfitUnbinnedData | ||
|
||
|
||
def obs_from_frame(frame1, frame2=None, bufferfactor=0.01): | ||
obs = [] | ||
if frame2 is None: | ||
frame2 = frame1 | ||
|
||
if isinstance(frame1, ZfitUnbinnedData) or isinstance(frame2, ZfitUnbinnedData): | ||
return frame1.space | ||
|
||
if not isinstance(frame1, (Mapping, pd.DataFrame)) or not isinstance( | ||
frame2, (Mapping, pd.DataFrame) | ||
): | ||
raise ValueError( | ||
"frame1 and frame2 have to be either a mapping or a pandas DataFrame, or a zfit Data object. They are currently of type: ", | ||
type(frame1), | ||
type(frame2), | ||
) | ||
for ob in frame2: | ||
minimum = np.min([np.min(frame1[ob]), np.min(frame2[ob])]) | ||
maximum = np.max([np.max(frame1[ob]), np.max(frame2[ob])]) | ||
dist = maximum - minimum | ||
buffer = bufferfactor * dist | ||
obs.append( | ||
zfit.Space( | ||
ob, | ||
limits=( | ||
minimum - buffer, | ||
maximum + buffer, | ||
), | ||
) | ||
) | ||
obsall = zfit.dimension.combine_spaces(*obs) | ||
return obsall |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from __future__ import annotations | ||
|
||
import types | ||
|
||
import numpy as np | ||
import zfit # suppress tf warnings | ||
import zfit.z.numpy as znp | ||
from zfit import supports, z | ||
|
||
from zfit_pwa.compwa.variables import obs_from_frame | ||
|
||
|
||
def patched_call(self, data) -> np.ndarray: | ||
extended_data = {**self.__parameters, **data} # type: ignore[arg-type] | ||
return self.__function(extended_data) # type: ignore[arg-type] | ||
|
||
|
||
class ComPWAPDF(zfit.pdf.BasePDF): | ||
def __init__( | ||
self, intensity, norm, obs=None, params=None, extended=None, name="ComPWA" | ||
): | ||
"""ComPWA intensity normalized over the *norm* dataset.""" | ||
if params is None: | ||
params = { | ||
name: zfit.param.convert_to_parameter( | ||
val, name=name, prefer_constant=False | ||
) | ||
for name, val in intensity.parameters.items() | ||
} | ||
if obs is None: | ||
obs = obs_from_frame(norm) | ||
intensity.__call__ = types.MethodType(patched_call, intensity) | ||
super().__init__(obs, params=params, name=name, extended=extended) | ||
self.intensity = intensity | ||
norm = {ob: znp.array(ar) for ob, ar in zip(self.obs, z.unstack_x(norm))} | ||
self.norm_sample = norm | ||
|
||
@supports(norm=True) | ||
def _pdf(self, x, norm): | ||
data = {ob: znp.array(ar) for ob, ar in zip(self.obs, z.unstack_x(x))} | ||
params = {p.name: znp.array(p) for p in self.params.values()} | ||
data |= params | ||
|
||
unnormalized_pdf = self._jitted_unnormalized_pdf(data) | ||
|
||
if norm is False: | ||
return unnormalized_pdf | ||
else: | ||
norm_sample = self.norm_sample | params | ||
return unnormalized_pdf / self._jitted_normalization(norm_sample) | ||
|
||
@z.function(wraps="tensorwaves") | ||
def _jitted_unnormalized_pdf(self, data): | ||
unnormalized_pdf = self.intensity(data) | ||
|
||
return unnormalized_pdf | ||
|
||
@z.function(wraps="tensorwaves") | ||
def _jitted_normalization(self, norm_sample): | ||
return znp.mean(self._jitted_unnormalized_pdf(norm_sample)) |
Oops, something went wrong.