Skip to content

Commit

Permalink
enh: jit gradient and value
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-eschle committed Oct 16, 2024
1 parent 3003009 commit a590555
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
7 changes: 3 additions & 4 deletions src/zfit_physics/tfpwa/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import zfit
import zfit.z.numpy as znp
from zfit import z
from zfit.core.interfaces import ZfitParameter
from zfit.util.container import convert_to_container

Expand All @@ -32,14 +33,12 @@ def nll_from_fcn(fcn: tf_pwa.model.FCN, *, params: ParamType = None):

# something is off here: for the value, we need to pass the parameters as a dict
# but for the gradient/hesse, we need to pass them as a list
# TODO: activate if https://github.com/jiangyi15/tf-pwa/pull/153 is merged
# @z.function(wraps="loss")
@z.function(wraps="loss")
def eval_func(params):
paramdict = make_paramdict(params)
return fcn(paramdict)

# TODO: activate if https://github.com/jiangyi15/tf-pwa/pull/153 is merged
# @z.function(wraps="loss")
@z.function(wraps="loss")
def eval_grad(params):
return fcn.nll_grad(params)[1]

Expand Down
4 changes: 3 additions & 1 deletion tests/tfpwa/test_basic_example_tfpwa.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def test_example1_tfpwa():
fit_result = config.fit(method="BFGS")

kwargs = dict(gradient='zfit', tol=0.01)
# kwargs = dict(gradient=False, tol=0.01)
# kwargs = dict(tol=0.01)
assert pytest.approx(nll.value(), 0.001) == initial_val
v, g, h = fcn.nll_grad_hessian()
vz, gz, hz = nll.value_gradient_hessian()
Expand All @@ -79,7 +81,7 @@ def test_example1_tfpwa():
np.testing.assert_allclose(g, gz1, atol=0.001)

minimizer = zfit.minimize.Minuit(verbosity=7, **kwargs)
# minimizer = zfit.minimize.ScipyBFGS(verbosity=7, **kwargs) # performs bestamba
# minimizer = zfit.minimize.ScipyBFGS(verbosity=7, **kwargs) # performs best
# minimizer = zfit.minimize.NLoptMMAV1(verbosity=7, **kwargs)
# minimizer = zfit.minimize.ScipyLBFGSBV1(verbosity=7, **kwargs)
# minimizer = zfit.minimize.NLoptLBFGSV1(verbosity=7, **kwargs)
Expand Down

0 comments on commit a590555

Please sign in to comment.