diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e93f576..657b08a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,80 +1,72 @@ ci: autoupdate_schedule: quarterly - repos: - - repo: https://github.com/pre-commit/pre-commit-hooks +- repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 hooks: - - id: check-added-large-files - - id: check-case-conflict - - id: check-merge-conflict - - id: check-symlinks - - id: check-yaml - - id: check-toml - - id: debug-statements - - id: end-of-file-fixer - - id: mixed-line-ending - - id: requirements-txt-fixer - - id: trailing-whitespace - - id: detect-private-key - - id: fix-byte-order-marker - - id: check-ast - -# - repo: https://github.com/PyCQA/docformatter -# rev: v1.7.5 -# hooks: -# - id: docformatter -# args: [ -r, --in-place, --wrap-descriptions, '120', --wrap-summaries, '120', -- ] + - id: check-added-large-files + - id: check-case-conflict + - id: check-merge-conflict + - id: check-symlinks + - id: check-yaml + - id: check-toml + - id: debug-statements + - id: end-of-file-fixer + - id: mixed-line-ending + - id: requirements-txt-fixer + - id: trailing-whitespace + - id: detect-private-key + - id: fix-byte-order-marker + - id: check-ast - repo: https://github.com/pre-commit/pygrep-hooks rev: v1.10.0 hooks: - - id: python-use-type-annotations - - id: python-check-mock-methods - - id: python-no-eval - - id: rst-directive-colons - - - repo: https://github.com/PyCQA/isort + - id: python-use-type-annotations + - id: python-check-mock-methods + - id: python-no-eval + - id: rst-directive-colons +- repo: https://github.com/PyCQA/isort rev: 5.13.2 hooks: - - id: isort - - - repo: https://github.com/asottile/pyupgrade + - id: isort +- repo: https://github.com/asottile/pyupgrade rev: v3.17.0 hooks: - - id: pyupgrade - args: [ --py38-plus ] - - - repo: https://github.com/asottile/setup-cfg-fmt + - id: pyupgrade + args: + - --py38-plus +- repo: https://github.com/asottile/setup-cfg-fmt rev: v2.5.0 hooks: - - id: setup-cfg-fmt - args: [ --max-py-version=3.12, --include-version-classifiers ] - - - # Notebook formatting - - repo: https://github.com/nbQA-dev/nbQA + - id: setup-cfg-fmt + args: + - --max-py-version=3.12 + - --include-version-classifiers +- repo: https://github.com/nbQA-dev/nbQA rev: 1.8.7 hooks: - - id: nbqa-isort - additional_dependencies: [ isort ] - - - id: nbqa-pyupgrade - additional_dependencies: [ pyupgrade ] - args: [ --py38-plus ] - - - repo: https://github.com/mgedmin/check-manifest + - id: nbqa-isort + additional_dependencies: + - isort + - id: nbqa-pyupgrade + additional_dependencies: + - pyupgrade + args: + - --py38-plus +- repo: https://github.com/mgedmin/check-manifest rev: '0.49' hooks: - - id: check-manifest - stages: [ manual ] - - repo: https://github.com/sondrelg/pep585-upgrade - rev: 'v1.0' + - id: check-manifest + stages: + - manual +- repo: https://github.com/sondrelg/pep585-upgrade + rev: v1.0 hooks: - - id: upgrade-type-hints - args: [ '--futures=true' ] - - - repo: https://github.com/MarcoGorelli/auto-walrus + - id: upgrade-type-hints + args: + - --futures=true +- repo: https://github.com/MarcoGorelli/auto-walrus rev: 0.3.4 hooks: - id: auto-walrus @@ -82,9 +74,18 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: "v0.6.9" hooks: - - id: ruff - types_or: [ python, pyi, jupyter ] - args: [ --fix, --unsafe-fixes, --show-fixes , --line-length=120] - # Run the formatter. - - id: ruff-format - types_or: [ python, pyi, jupyter ] + - id: ruff + types_or: + - python + - pyi + - jupyter + args: + - --fix + - --unsafe-fixes + - --show-fixes + - --line-length=120 + - id: ruff-format + types_or: + - python + - pyi + - jupyter diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 4248f7b..079336b 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,6 +7,7 @@ Develop Major Features and Improvements ------------------------------- +- add a RooFit compatibility layer and automatically convert losses, also inside minimizers (through ``SimpleLoss.from_any``) Breaking changes ------------------ diff --git a/tests/roofit/test_loss_compat.py b/tests/roofit/test_loss_compat.py new file mode 100644 index 0000000..8b9db51 --- /dev/null +++ b/tests/roofit/test_loss_compat.py @@ -0,0 +1,67 @@ +import numpy as np +import pytest + + +def test_loss_registry(): + _ = pytest.importorskip("ROOT") + # Copyright (c) 2024 zfit + + import zfit + + import zfit_physics.roofit as zroofit + + # create space + obs = zfit.Space("x", -2, 3) + + # parameters + mu = zfit.Parameter("mu", 1.2, -4, 6) + sigma = zfit.Parameter("sigma", 1.3, 0.5, 10) + + # model building, pdf creation + gauss = zfit.pdf.Gauss(mu=mu, sigma=sigma, obs=obs) + + # data + ndraw = 10_000 + data = np.random.normal(loc=2.0, scale=3.0, size=ndraw) + data = obs.filter(data) # works also for pandas DataFrame + + from ROOT import RooArgSet, RooDataSet, RooFit, RooGaussian, RooRealVar + + mur = RooRealVar("mu", "mu", 1.2, -4, 6) + sigmar = RooRealVar("sigma", "sigma", 1.3, 0.5, 10) + obsr = RooRealVar("x", "x", -2, 3) + gaussr = RooGaussian("gauss", "gauss", obsr, mur, sigmar) + + datar = RooDataSet("data", "data", {obsr}) + for d in data: + obsr.setVal(d) + datar.add(RooArgSet(obsr)) + + # create a loss function + nll = gaussr.createNLL(datar) + + nllz = zfit.loss.UnbinnedNLL(model=gauss, data=data) + + # create a minimizer + tol = 1e-3 + verbosity = 0 + minimizer = zfit.minimize.Minuit(gradient=True, verbosity=verbosity, tol=tol, mode=1) + minimizerzgrad = zfit.minimize.Minuit(gradient=False, verbosity=verbosity, tol=tol, mode=1) + + params = nllz.get_params() + initvals = np.array(params) + + with zfit.param.set_values(params, initvals): + result = minimizer.minimize(nllz) + + with zfit.param.set_values(params, initvals): + result2 = minimizer.minimize(nll) + + assert result.params['mu']['value'] == pytest.approx(result2.params['mu']['value'], rel=1e-3) + assert result.params['sigma']['value'] == pytest.approx(result2.params['sigma']['value'], rel=1e-3) + + with zfit.param.set_values(params, params): + result4 = minimizerzgrad.minimize(nll) + + assert result.params['mu']['value'] == pytest.approx(result4.params['mu']['value'], rel=1e-3) + assert result.params['sigma']['value'] == pytest.approx(result4.params['sigma']['value'], rel=1e-3) diff --git a/zfit_physics/roofit/__init__.py b/zfit_physics/roofit/__init__.py new file mode 100644 index 0000000..e66f8bb --- /dev/null +++ b/zfit_physics/roofit/__init__.py @@ -0,0 +1 @@ +from . import loss, variables diff --git a/zfit_physics/roofit/loss.py b/zfit_physics/roofit/loss.py new file mode 100644 index 0000000..16fa835 --- /dev/null +++ b/zfit_physics/roofit/loss.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024 zfit +from __future__ import annotations + +from contextlib import suppress + +import zfit +from zfit.util.container import convert_to_container + +from .variables import roo2z_param + + +def nll_from_roofit(nll, params=None): + """ + Converts a RooFit NLL (negative log-likelihood) to a Zfit loss object. + + Args: + nll: The RooFit NLL object to be converted. + + Returns: + zfit.loss.SimpleLoss: The converted Zfit loss object. + + Raises: + TypeError: If the provided RooFit loss does not have an error level. + """ + params = {} if params is None else {p.name: p for p in convert_to_container(params)} + + ROOT = None + if "cppyy.gbl.RooAbsReal" in str(type(nll)): + with suppress(ImportError): + import ROOT + if ROOT is None or not isinstance(nll, ROOT.RooAbsReal): + return False # not a RooFit loss + + import zfit + + def roofit_eval(x): + for par, arg in zip(nll.getVariables(), x): + par.setVal(arg) + # following RooMinimizerFcn.cxx + nll.setHideOffset(False) + r = nll.getVal() + nll.setHideOffset(True) + return r + + paramsall = [] + for v in nll.getVariables(): + param = params[name] if (name := v.GetName()) in params else roo2z_param(v) + paramsall.append(param) + + if (errordef := getattr(nll, "defaultErrorLevel", lambda: None)()) is None and ( + errordef := getattr(nll, "errordef", lambda: None)() + ) is None: + msg = ( + "Provided loss is RooFit loss but has not error level. " + "Either set it or create an attribute on the fly (like `nllroofit.errordef = 0.5` " + ) + raise TypeError(msg) + return zfit.loss.SimpleLoss(roofit_eval, paramsall, errordef=errordef, jit=False, gradient="num", hessian="num") + + +zfit.loss.SimpleLoss.register_convertable_loss(nll_from_roofit, priority=50) diff --git a/zfit_physics/roofit/variables.py b/zfit_physics/roofit/variables.py new file mode 100644 index 0000000..e6ead77 --- /dev/null +++ b/zfit_physics/roofit/variables.py @@ -0,0 +1,27 @@ +from __future__ import annotations + + +def roo2z_param(v): + """ + Converts a RooFit RooRealVar to a zfit parameter. + + Args: + v: RooFit RooRealVar to convert. + + Returns: + A zfit.Parameter object with properties copied from the RooFit variable. + """ + import zfit + + name = v.GetName() + value = v.getVal() + label = v.GetTitle() + lower = v.getMin() + upper = v.getMax() + floating = not v.isConstant() + stepsize = None + if v.hasError(): + stepsize = v.getError() + elif v.hasAsymError(): # just take average + stepsize = (v.getErrorHi() - v.getErrorLo()) / 2 + return zfit.Parameter(name, value, lower=lower, upper=upper, floating=floating, step_size=stepsize, label=label)