Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-eschle committed Apr 23, 2024
1 parent aa850b0 commit dc71ebc
Show file tree
Hide file tree
Showing 9 changed files with 330 additions and 69 deletions.
2 changes: 2 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 20 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
ci:
autoupdate_commit_msg: "chore: update pre-commit hooks"
autofix_commit_msg: "style: pre-commit fixes"
autoupdate_schedule: quarterly

repos:
- repo: https://github.com/adamchainz/blacken-docs
rev: "1.16.0"
Expand Down Expand Up @@ -137,12 +142,24 @@ repos:
rev: 0.7.1
hooks:
- id: nbstripout

- repo: https://github.com/MarcoGorelli/auto-walrus
rev: 0.3.3
hooks:
- id: auto-walrus

- repo: https://github.com/shssoichiro/oxipng
rev: v9.1.0
# uncomment locally if needed, currently needs rust version not available on pre-commit.ci
# - repo: https://github.com/shssoichiro/oxipng
# rev: v9.1.0
# hooks:
# - id: oxipng

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.4.1"
hooks:
- id: oxipng
- id: ruff
types_or: [python, pyi, jupyter]
args: [--fix, --unsafe-fixes, --show-fixes]
# Run the formatter.
- id: ruff-format
types_or: [python, pyi, jupyter]
135 changes: 70 additions & 65 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,47 +6,52 @@ build-backend = "hatchling.build"
[project]
name = "zfit-pwa"
authors = [
{ name = "Jonas Eschle", email = "[email protected]" },
{ name = "Jonas Eschle", email = "[email protected]" },
]
description = "Tools to adapt to Partial Wave Analysis packages"
readme = "README.md"
license.file = "LICENSE"
requires-python = ">=3.8"
classifiers = [
"Development Status :: 1 - Planning",
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering",
"Typing :: Typed",
"Development Status :: 1 - Planning",
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering",
"Typing :: Typed",
]
dynamic = ["version"]
dependencies = []
dependencies = ["zfit"]

[project.optional-dependencies]
test = [
"pytest >=6",
"pytest-cov >=3",
]
dev = [
"pytest >=6",
"pytest-cov >=3",
"pytest >=6",
"pytest-cov >=3",
]

docs = [
"sphinx>=7.0",
"myst_parser>=0.13",
"sphinx_copybutton",
"sphinx_autodoc_typehints",
"furo>=2023.08.17",
"sphinx>=7.0",
"myst_parser>=0.13",
"sphinx_copybutton",
"sphinx_autodoc_typehints",
"furo>=2023.08.17",
]
compwa = [
"qrules",
"ampform",
"tensorwaves",
]
dev = [
"zfit-pwa[test,docs,compwa]",
"tensorwaves[phsp]",
]

[project.urls]
Expand All @@ -67,22 +72,22 @@ scripts.test = "pytest {args}"

[tool.pytest.ini_options]
minversion = "6.0"
addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"]
xfail_strict = true
filterwarnings = [
"error",
]
#addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"]
#xfail_strict = true
#filterwarnings = [
# "error",
#]
log_cli_level = "INFO"
testpaths = [
"tests",
"tests",
]


[tool.coverage]
run.source = ["zfit_pwa"]
report.exclude_also = [
'\.\.\.',
'if typing.TYPE_CHECKING:',
'\.\.\.',
'if typing.TYPE_CHECKING:',
]

[tool.mypy]
Expand All @@ -106,32 +111,32 @@ src = ["src"]

[tool.ruff.lint]
extend-select = [
"B", # flake8-bugbear
"I", # isort
"ARG", # flake8-unused-arguments
"C4", # flake8-comprehensions
"EM", # flake8-errmsg
"ICN", # flake8-import-conventions
"G", # flake8-logging-format
"PGH", # pygrep-hooks
"PIE", # flake8-pie
"PL", # pylint
"PT", # flake8-pytest-style
"PTH", # flake8-use-pathlib
"RET", # flake8-return
"RUF", # Ruff-specific
"SIM", # flake8-simplify
"T20", # flake8-print
"UP", # pyupgrade
"YTT", # flake8-2020
"EXE", # flake8-executable
"NPY", # NumPy specific rules
"PD", # pandas-vet
"B", # flake8-bugbear
"I", # isort
"ARG", # flake8-unused-arguments
"C4", # flake8-comprehensions
"EM", # flake8-errmsg
"ICN", # flake8-import-conventions
"G", # flake8-logging-format
"PGH", # pygrep-hooks
"PIE", # flake8-pie
"PL", # pylint
"PT", # flake8-pytest-style
"PTH", # flake8-use-pathlib
"RET", # flake8-return
"RUF", # Ruff-specific
"SIM", # flake8-simplify
"T20", # flake8-print
"UP", # pyupgrade
"YTT", # flake8-2020
"EXE", # flake8-executable
"NPY", # NumPy specific rules
"PD", # pandas-vet
]
ignore = [
"PLR09", # Too many <...>
"PLR2004", # Magic value used in comparison
"ISC001", # Conflicts with formatter
"PLR09", # Too many <...>
"PLR2004", # Magic value used in comparison
"ISC001", # Conflicts with formatter
]
isort.required-imports = ["from __future__ import annotations"]
# Uncomment if using a _compat.typing backport
Expand All @@ -148,9 +153,9 @@ ignore-paths = [".*/_version.py"]
reports.output-format = "colorized"
similarities.ignore-imports = "yes"
messages_control.disable = [
"design",
"fixme",
"line-too-long",
"missing-module-docstring",
"wrong-import-position",
"design",
"fixme",
"line-too-long",
"missing-module-docstring",
"wrong-import-position",
]
1 change: 0 additions & 1 deletion src/zfit_pwa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
zfit-pwa: Tools to adapt to Partial Wave Analysis packages
"""


from __future__ import annotations

from ._version import version as __version__
Expand Down
Empty file added src/zfit_pwa/compwa/__init__.py
Empty file.
Empty file added src/zfit_pwa/compwa/data.py
Empty file.
42 changes: 42 additions & 0 deletions src/zfit_pwa/compwa/variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations

from typing import Mapping

import numpy as np
import pandas as pd
import zfit
from zfit.core.interfaces import ZfitUnbinnedData


def obs_from_frame(frame1, frame2=None, bufferfactor=0.01):
obs = []
if frame2 is None:
frame2 = frame1

if isinstance(frame1, ZfitUnbinnedData) or isinstance(frame2, ZfitUnbinnedData):
return frame1.space

if not isinstance(frame1, (Mapping, pd.DataFrame)) or not isinstance(
frame2, (Mapping, pd.DataFrame)
):
raise ValueError(
"frame1 and frame2 have to be either a mapping or a pandas DataFrame, or a zfit Data object. They are currently of type: ",
type(frame1),
type(frame2),
)
for ob in frame2:
minimum = np.min([np.min(frame1[ob]), np.min(frame2[ob])])
maximum = np.max([np.max(frame1[ob]), np.max(frame2[ob])])
dist = maximum - minimum
buffer = bufferfactor * dist
obs.append(
zfit.Space(
ob,
limits=(
minimum - buffer,
maximum + buffer,
),
)
)
obsall = zfit.dimension.combine_spaces(*obs)
return obsall
60 changes: 60 additions & 0 deletions src/zfit_pwa/compwa/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

import types

import numpy as np
import zfit # suppress tf warnings
import zfit.z.numpy as znp
from zfit import supports, z

from zfit_pwa.compwa.variables import obs_from_frame


def patched_call(self, data) -> np.ndarray:
extended_data = {**self.__parameters, **data} # type: ignore[arg-type]
return self.__function(extended_data) # type: ignore[arg-type]


class ComPWAPDF(zfit.pdf.BasePDF):
def __init__(
self, intensity, norm, obs=None, params=None, extended=None, name="ComPWA"
):
"""ComPWA intensity normalized over the *norm* dataset."""
if params is None:
params = {
name: zfit.param.convert_to_parameter(
val, name=name, prefer_constant=False
)
for name, val in intensity.parameters.items()
}
if obs is None:
obs = obs_from_frame(norm)
intensity.__call__ = types.MethodType(patched_call, intensity)
super().__init__(obs, params=params, name=name, extended=extended)
self.intensity = intensity
norm = {ob: znp.array(ar) for ob, ar in zip(self.obs, z.unstack_x(norm))}
self.norm_sample = norm

@supports(norm=True)
def _pdf(self, x, norm):
data = {ob: znp.array(ar) for ob, ar in zip(self.obs, z.unstack_x(x))}
params = {p.name: znp.array(p) for p in self.params.values()}
data |= params

unnormalized_pdf = self._jitted_unnormalized_pdf(data)

if norm is False:
return unnormalized_pdf
else:
norm_sample = self.norm_sample | params
return unnormalized_pdf / self._jitted_normalization(norm_sample)

@z.function(wraps="tensorwaves")
def _jitted_unnormalized_pdf(self, data):
unnormalized_pdf = self.intensity(data)

return unnormalized_pdf

@z.function(wraps="tensorwaves")
def _jitted_normalization(self, norm_sample):
return znp.mean(self._jitted_unnormalized_pdf(norm_sample))
Loading

0 comments on commit dc71ebc

Please sign in to comment.