Skip to content

Commit

Permalink
enh: add RooFit compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-eschle committed Oct 13, 2024
1 parent 8c598e7 commit 1d7ff29
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 63 deletions.
127 changes: 64 additions & 63 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,90 +1,91 @@
ci:
autoupdate_schedule: quarterly

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-added-large-files
- id: check-case-conflict
- id: check-merge-conflict
- id: check-symlinks
- id: check-yaml
- id: check-toml
- id: debug-statements
- id: end-of-file-fixer
- id: mixed-line-ending
- id: requirements-txt-fixer
- id: trailing-whitespace
- id: detect-private-key
- id: fix-byte-order-marker
- id: check-ast

# - repo: https://github.com/PyCQA/docformatter
# rev: v1.7.5
# hooks:
# - id: docformatter
# args: [ -r, --in-place, --wrap-descriptions, '120', --wrap-summaries, '120', -- ]
- id: check-added-large-files
- id: check-case-conflict
- id: check-merge-conflict
- id: check-symlinks
- id: check-yaml
- id: check-toml
- id: debug-statements
- id: end-of-file-fixer
- id: mixed-line-ending
- id: requirements-txt-fixer
- id: trailing-whitespace
- id: detect-private-key
- id: fix-byte-order-marker
- id: check-ast

- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
- id: python-use-type-annotations
- id: python-check-mock-methods
- id: python-no-eval
- id: rst-directive-colons

- repo: https://github.com/PyCQA/isort
- id: python-use-type-annotations
- id: python-check-mock-methods
- id: python-no-eval
- id: rst-directive-colons
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort

- repo: https://github.com/asottile/pyupgrade
- id: isort
- repo: https://github.com/asottile/pyupgrade
rev: v3.17.0
hooks:
- id: pyupgrade
args: [ --py38-plus ]

- repo: https://github.com/asottile/setup-cfg-fmt
- id: pyupgrade
args:
- --py38-plus
- repo: https://github.com/asottile/setup-cfg-fmt
rev: v2.5.0
hooks:
- id: setup-cfg-fmt
args: [ --max-py-version=3.12, --include-version-classifiers ]


# Notebook formatting
- repo: https://github.com/nbQA-dev/nbQA
- id: setup-cfg-fmt
args:
- --max-py-version=3.12
- --include-version-classifiers
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.8.7
hooks:
- id: nbqa-isort
additional_dependencies: [ isort ]

- id: nbqa-pyupgrade
additional_dependencies: [ pyupgrade ]
args: [ --py38-plus ]

- repo: https://github.com/mgedmin/check-manifest
- id: nbqa-isort
additional_dependencies:
- isort
- id: nbqa-pyupgrade
additional_dependencies:
- pyupgrade
args:
- --py38-plus
- repo: https://github.com/mgedmin/check-manifest
rev: '0.49'
hooks:
- id: check-manifest
stages: [ manual ]
- repo: https://github.com/sondrelg/pep585-upgrade
rev: 'v1.0'
- id: check-manifest
stages:
- manual
- repo: https://github.com/sondrelg/pep585-upgrade
rev: v1.0
hooks:
- id: upgrade-type-hints
args: [ '--futures=true' ]

- repo: https://github.com/MarcoGorelli/auto-walrus
- id: upgrade-type-hints
args:
- --futures=true
- repo: https://github.com/MarcoGorelli/auto-walrus
rev: 0.3.4
hooks:
- id: auto-walrus

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.6.9"
hooks:
- id: ruff
types_or: [ python, pyi, jupyter ]
args: [ --fix, --unsafe-fixes, --show-fixes , --line-length=120]
# Run the formatter.
- id: ruff-format
types_or: [ python, pyi, jupyter ]
- id: ruff
types_or:
- python
- pyi
- jupyter
args:
- --fix
- --unsafe-fixes
- --show-fixes
- --line-length=120
- id: ruff-format
types_or:
- python
- pyi
- jupyter
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Develop

Major Features and Improvements
-------------------------------
- add a RooFit compatibility layer and automatically convert losses, also inside minimizers (through ``SimpleLoss.from_any``)

Breaking changes
------------------
Expand Down
67 changes: 67 additions & 0 deletions tests/roofit/test_loss_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import numpy as np
import pytest


def test_loss_registry():
_ = pytest.importorskip("ROOT")
# Copyright (c) 2024 zfit

import zfit

import zfit_physics.roofit as zroofit

# create space
obs = zfit.Space("x", -2, 3)

# parameters
mu = zfit.Parameter("mu", 1.2, -4, 6)
sigma = zfit.Parameter("sigma", 1.3, 0.5, 10)

# model building, pdf creation
gauss = zfit.pdf.Gauss(mu=mu, sigma=sigma, obs=obs)

# data
ndraw = 10_000
data = np.random.normal(loc=2.0, scale=3.0, size=ndraw)
data = obs.filter(data) # works also for pandas DataFrame

from ROOT import RooArgSet, RooDataSet, RooFit, RooGaussian, RooRealVar

mur = RooRealVar("mu", "mu", 1.2, -4, 6)
sigmar = RooRealVar("sigma", "sigma", 1.3, 0.5, 10)
obsr = RooRealVar("x", "x", -2, 3)
gaussr = RooGaussian("gauss", "gauss", obsr, mur, sigmar)

datar = RooDataSet("data", "data", {obsr})
for d in data:
obsr.setVal(d)
datar.add(RooArgSet(obsr))

# create a loss function
nll = gaussr.createNLL(datar)

nllz = zfit.loss.UnbinnedNLL(model=gauss, data=data)

# create a minimizer
tol = 1e-3
verbosity = 0
minimizer = zfit.minimize.Minuit(gradient=True, verbosity=verbosity, tol=tol, mode=1)
minimizerzgrad = zfit.minimize.Minuit(gradient=False, verbosity=verbosity, tol=tol, mode=1)

params = nllz.get_params()
initvals = np.array(params)

with zfit.param.set_values(params, initvals):
result = minimizer.minimize(nllz)

with zfit.param.set_values(params, initvals):
result2 = minimizer.minimize(nll)

assert result.params['mu']['value'] == pytest.approx(result2.params['mu']['value'], rel=1e-3)
assert result.params['sigma']['value'] == pytest.approx(result2.params['sigma']['value'], rel=1e-3)

with zfit.param.set_values(params, params):
result4 = minimizerzgrad.minimize(nll)

assert result.params['mu']['value'] == pytest.approx(result4.params['mu']['value'], rel=1e-3)
assert result.params['sigma']['value'] == pytest.approx(result4.params['sigma']['value'], rel=1e-3)
1 change: 1 addition & 0 deletions zfit_physics/roofit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import loss, variables
61 changes: 61 additions & 0 deletions zfit_physics/roofit/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) 2024 zfit
from __future__ import annotations

from contextlib import suppress

import zfit
from zfit.util.container import convert_to_container

from .variables import roo2z_param


def nll_from_roofit(nll, params=None):
"""
Converts a RooFit NLL (negative log-likelihood) to a Zfit loss object.
Args:
nll: The RooFit NLL object to be converted.
Returns:
zfit.loss.SimpleLoss: The converted Zfit loss object.
Raises:
TypeError: If the provided RooFit loss does not have an error level.
"""
params = {} if params is None else {p.name: p for p in convert_to_container(params)}

ROOT = None
if "cppyy.gbl.RooAbsReal" in str(type(nll)):
with suppress(ImportError):
import ROOT
if ROOT is None or not isinstance(nll, ROOT.RooAbsReal):
return False # not a RooFit loss

import zfit

def roofit_eval(x):
for par, arg in zip(nll.getVariables(), x):
par.setVal(arg)
# following RooMinimizerFcn.cxx
nll.setHideOffset(False)
r = nll.getVal()
nll.setHideOffset(True)
return r

paramsall = []
for v in nll.getVariables():
param = params[name] if (name := v.GetName()) in params else roo2z_param(v)
paramsall.append(param)

if (errordef := getattr(nll, "defaultErrorLevel", lambda: None)()) is None and (
errordef := getattr(nll, "errordef", lambda: None)()
) is None:
msg = (
"Provided loss is RooFit loss but has not error level. "
"Either set it or create an attribute on the fly (like `nllroofit.errordef = 0.5` "
)
raise TypeError(msg)
return zfit.loss.SimpleLoss(roofit_eval, paramsall, errordef=errordef, jit=False, gradient="num", hessian="num")


zfit.loss.SimpleLoss.register_convertable_loss(nll_from_roofit, priority=50)
27 changes: 27 additions & 0 deletions zfit_physics/roofit/variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from __future__ import annotations


def roo2z_param(v):
"""
Converts a RooFit RooRealVar to a zfit parameter.
Args:
v: RooFit RooRealVar to convert.
Returns:
A zfit.Parameter object with properties copied from the RooFit variable.
"""
import zfit

name = v.GetName()
value = v.getVal()
label = v.GetTitle()
lower = v.getMin()
upper = v.getMax()
floating = not v.isConstant()
stepsize = None
if v.hasError():
stepsize = v.getError()
elif v.hasAsymError(): # just take average
stepsize = (v.getErrorHi() - v.getErrorLo()) / 2
return zfit.Parameter(name, value, lower=lower, upper=upper, floating=floating, step_size=stepsize, label=label)

0 comments on commit 1d7ff29

Please sign in to comment.