Skip to content

Commit

Permalink
optimise gradients
Browse files Browse the repository at this point in the history
* using vectorised `value_and_grad` instead of grad
  • Loading branch information
jackaraz committed Jan 26, 2024
1 parent 20fd804 commit 045297d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 12 deletions.
8 changes: 2 additions & 6 deletions src/spey/backends/default_pdf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union

from autograd import grad, hessian, jacobian
from autograd import value_and_grad, hessian, jacobian
from autograd import numpy as np
from scipy.optimize import NonlinearConstraint

Expand Down Expand Up @@ -251,11 +251,7 @@ def negative_loglikelihood(pars: np.ndarray) -> np.ndarray:
) - self.constraint_model.log_prob(pars)

if do_grad:
grad_negative_loglikelihood = grad(negative_loglikelihood, argnum=0)
return lambda pars: (
negative_loglikelihood(pars),
grad_negative_loglikelihood(pars),
)
return value_and_grad(negative_loglikelihood, argnum=0)

return negative_loglikelihood

Expand Down
8 changes: 2 additions & 6 deletions src/spey/backends/default_pdf/simple_pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Callable, List, Optional, Text, Tuple, Union

from autograd import grad, hessian
from autograd import value_and_grad, hessian
from autograd import numpy as np

from spey._version import __version__
Expand Down Expand Up @@ -162,11 +162,7 @@ def negative_loglikelihood(pars: np.ndarray) -> np.ndarray:
return -self.main_model.log_prob(pars, data)

if do_grad:
grad_negative_loglikelihood = grad(negative_loglikelihood, argnum=0)
return lambda pars: (
negative_loglikelihood(pars),
grad_negative_loglikelihood(pars),
)
return value_and_grad(negative_loglikelihood, argnum=0)

return negative_loglikelihood

Expand Down

0 comments on commit 045297d

Please sign in to comment.