Skip to content

Commit

Permalink
add callers for grad and hessian
Browse files Browse the repository at this point in the history
  • Loading branch information
jackaraz committed Jan 28, 2024
1 parent 045297d commit a46b31e
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 0 deletions.
11 changes: 11 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,17 @@ Hypothesis testing
asymptotic_calculator.compute_asymptotic_confidence_level
toy_calculator.compute_toy_confidence_level

Gradient Tools
--------------

.. currentmodule:: spey.math

.. autosummary::
:toctree: _generated/

value_and_grad
hessian

Default Backends
----------------

Expand Down
1 change: 1 addition & 0 deletions src/spey/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"BackendBase",
"ConverterBase",
"about",
"math",
]


Expand Down
86 changes: 86 additions & 0 deletions src/spey/math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import Callable, Optional, Tuple

from autograd import numpy

from .interface.statistical_model import StatisticalModel
from .utils import ExpectationType

# pylint: disable=E1101

__all__ = ["value_and_grad", "hessian"]


def __dir__():
return __all__


def value_and_grad(
statistical_model: StatisticalModel,
expected: ExpectationType = ExpectationType.observed,
data: Optional[numpy.ndarray] = None,
) -> Callable[[numpy.ndarray], Tuple[numpy.ndarray, numpy.ndarray]]:
"""
Retreive function to compute negative log-likelihood and its gradient.
Args:
statistical_model (~spey.StatisticalModel): statistical model to be used.
expected (~spey.ExpectationType): Sets which values the fitting algorithm should focus and
p-values to be computed.
* :obj:`~spey.ExpectationType.observed`: Computes the p-values with via post-fit
prescriotion which means that the experimental data will be assumed to be the truth
(default).
* :obj:`~spey.ExpectationType.aposteriori`: Computes the expected p-values with via
post-fit prescriotion which means that the experimental data will be assumed to be
the truth.
* :obj:`~spey.ExpectationType.apriori`: Computes the expected p-values with via pre-fit
prescription which means that the SM will be assumed to be the truth.
data (``numpy.ndarray``, default ``None``): input data that to fit. If `None` observed
data will be used.
Returns:
``Callable[[numpy.ndarray], numpy.ndarray, numpy.ndarray]``:
negative log-likelihood and its gradient with respect to nuisance parameters
"""
val_and_grad = statistical_model.backend.get_objective_function(
expected=expected, data=data, do_grad=True
)
return lambda pars: val_and_grad(numpy.array(pars))


def hessian(
statistical_model: StatisticalModel,
expected: ExpectationType = ExpectationType.observed,
data: Optional[numpy.ndarray] = None,
) -> Callable[[numpy.ndarray], numpy.ndarray]:
r"""
Retreive the function to compute Hessian of negative log-likelihood
.. math::
{\rm Hessian} = -\frac{\partial^2\mathcal{L}(\theta)}{\partial\theta_i\partial\theta_j}
Args:
statistical_model (~spey.StatisticalModel): statistical model to be used.
expected (~spey.ExpectationType): Sets which values the fitting algorithm should focus and
p-values to be computed.
* :obj:`~spey.ExpectationType.observed`: Computes the p-values with via post-fit
prescriotion which means that the experimental data will be assumed to be the truth
(default).
* :obj:`~spey.ExpectationType.aposteriori`: Computes the expected p-values with via
post-fit prescriotion which means that the experimental data will be assumed to be
the truth.
* :obj:`~spey.ExpectationType.apriori`: Computes the expected p-values with via pre-fit
prescription which means that the SM will be assumed to be the truth.
data (``numpy.ndarray``, default ``None``): input data that to fit. If `None` observed
data will be used.
Returns:
``Callable[[numpy.ndarray], numpy.ndarray]``:
function to compute hessian of negative log-likelihood
"""
hess = statistical_model.backend.get_hessian_logpdf_func(expected=expected, data=data)
return lambda pars: -1.0 * hess(numpy.array(pars))

0 comments on commit a46b31e

Please sign in to comment.