-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8c598e7
commit 09a19a2
Showing
6 changed files
with
207 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
|
||
def test_loss_registry(): | ||
_ = pytest.importorskip("ROOT") | ||
# Copyright (c) 2024 zfit | ||
|
||
import zfit | ||
|
||
import zfit_physics.roofit as zroofit | ||
|
||
# create space | ||
obs = zfit.Space("x", -2, 3) | ||
|
||
# parameters | ||
mu = zfit.Parameter("mu", 1.2, -4, 6) | ||
sigma = zfit.Parameter("sigma", 1.3, 0.5, 10) | ||
|
||
# model building, pdf creation | ||
gauss = zfit.pdf.Gauss(mu=mu, sigma=sigma, obs=obs) | ||
|
||
# data | ||
ndraw = 10_000 | ||
data = np.random.normal(loc=2.0, scale=3.0, size=ndraw) | ||
data = obs.filter(data) # works also for pandas DataFrame | ||
|
||
from ROOT import RooArgSet, RooDataSet, RooFit, RooGaussian, RooRealVar | ||
|
||
mur = RooRealVar("mu", "mu", 1.2, -4, 6) | ||
sigmar = RooRealVar("sigma", "sigma", 1.3, 0.5, 10) | ||
obsr = RooRealVar("x", "x", -2, 3) | ||
gaussr = RooGaussian("gauss", "gauss", obsr, mur, sigmar) | ||
|
||
datar = RooDataSet("data", "data", {obsr}) | ||
for d in data: | ||
obsr.setVal(d) | ||
datar.add(RooArgSet(obsr)) | ||
|
||
# create a loss function | ||
nll = gaussr.createNLL(datar) | ||
|
||
nllz = zfit.loss.UnbinnedNLL(model=gauss, data=data) | ||
|
||
# create a minimizer | ||
tol = 1e-3 | ||
verbosity = 0 | ||
minimizer = zfit.minimize.Minuit(gradient=True, verbosity=verbosity, tol=tol, mode=1) | ||
minimizerzgrad = zfit.minimize.Minuit(gradient=False, verbosity=verbosity, tol=tol, mode=1) | ||
|
||
params = nllz.get_params() | ||
initvals = np.array(params) | ||
|
||
with zfit.param.set_values(params, initvals): | ||
result = minimizer.minimize(nllz) | ||
|
||
with zfit.param.set_values(params, initvals): | ||
result2 = minimizer.minimize(nll) | ||
|
||
assert result.params['mu']['value'] == pytest.approx(result2.params['mu']['value'], rel=1e-3) | ||
assert result.params['sigma']['value'] == pytest.approx(result2.params['sigma']['value'], rel=1e-3) | ||
|
||
with zfit.param.set_values(params, params): | ||
result4 = minimizerzgrad.minimize(nll) | ||
|
||
assert result.params['mu']['value'] == pytest.approx(result4.params['mu']['value'], rel=1e-3) | ||
assert result.params['sigma']['value'] == pytest.approx(result4.params['sigma']['value'], rel=1e-3) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from . import loss, variables |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright (c) 2024 zfit | ||
from __future__ import annotations | ||
|
||
from contextlib import suppress | ||
|
||
import zfit | ||
from zfit.util.container import convert_to_container | ||
|
||
from .variables import roo2z_param | ||
|
||
|
||
def nll_from_roofit(nll, params=None): | ||
""" | ||
Converts a RooFit NLL (negative log-likelihood) to a Zfit loss object. | ||
Args: | ||
nll: The RooFit NLL object to be converted. | ||
Returns: | ||
zfit.loss.SimpleLoss: The converted Zfit loss object. | ||
Raises: | ||
TypeError: If the provided RooFit loss does not have an error level. | ||
""" | ||
params = {} if params is None else {p.name: p for p in convert_to_container(params)} | ||
|
||
ROOT = None | ||
if "cppyy.gbl.RooAbsReal" in str(type(nll)): | ||
with suppress(ImportError): | ||
import ROOT | ||
if ROOT is None or not isinstance(nll, ROOT.RooAbsReal): | ||
return False # not a RooFit loss | ||
|
||
import zfit | ||
|
||
def roofit_eval(x): | ||
for par, arg in zip(nll.getVariables(), x): | ||
par.setVal(arg) | ||
# following RooMinimizerFcn.cxx | ||
nll.setHideOffset(False) | ||
r = nll.getVal() | ||
nll.setHideOffset(True) | ||
return r | ||
|
||
paramsall = [] | ||
for v in nll.getVariables(): | ||
param = params[name] if (name := v.GetName()) in params else roo2z_param(v) | ||
paramsall.append(param) | ||
|
||
if (errordef := getattr(nll, "defaultErrorLevel", lambda: None)()) is None and ( | ||
errordef := getattr(nll, "errordef", lambda: None)() | ||
) is None: | ||
msg = ( | ||
"Provided loss is RooFit loss but has not error level. " | ||
"Either set it or create an attribute on the fly (like `nllroofit.errordef = 0.5` " | ||
) | ||
raise TypeError(msg) | ||
return zfit.loss.SimpleLoss(roofit_eval, paramsall, errordef=errordef, jit=False, gradient="num", hessian="num") | ||
|
||
|
||
zfit.loss.SimpleLoss.register_convertable_loss(nll_from_roofit, priority=50) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from __future__ import annotations | ||
|
||
|
||
def roo2z_param(v): | ||
""" | ||
Converts a RooFit RooRealVar to a zfit parameter. | ||
Args: | ||
v: RooFit RooRealVar to convert. | ||
Returns: | ||
A zfit.Parameter object with properties copied from the RooFit variable. | ||
""" | ||
import zfit | ||
|
||
name = v.GetName() | ||
value = v.getVal() | ||
label = v.GetTitle() | ||
lower = v.getMin() | ||
upper = v.getMax() | ||
floating = not v.isConstant() | ||
stepsize = None | ||
if v.hasError(): | ||
stepsize = v.getError() | ||
elif v.hasAsymError(): # just take average | ||
stepsize = (v.getErrorHi() - v.getErrorLo()) / 2 | ||
return zfit.Parameter(name, value, lower=lower, upper=upper, floating=floating, step_size=stepsize, label=label) |