Add priorsense-related functionality (#24)
* add priorsense-related functions

* rebase

* add test, remove weights, reorder

* add _get_power_scale_weights as intermediate function

* single call to get_power_scale_weights, use check mark unicode

* minor updates

* use group everywhere

* lint

---------

Co-authored-by: Oriol (VANT Edge) <[email protected]>
aloctavodia and OriolAbril authored Oct 16, 2024
1 parent b09a771 commit d622d99
Showing 9 changed files with 443 additions and 9 deletions.
2 changes: 2 additions & 0 deletions src/arviz_stats/__init__.py
@@ -4,5 +4,7 @@
try:
    from arviz_stats.utils import *
    from arviz_stats.accessors import *
    from arviz_stats.psense import psense, psense_summary

except ModuleNotFoundError:
    pass
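
With this export, the new helpers become importable from the package root. A minimal usage sketch (assuming the example dataset carries the log_likelihood and log_prior groups that power-scaling needs):

    from arviz_base import load_arviz_data
    from arviz_stats import psense, psense_summary

    data = load_arviz_data("centered_eight")  # assumed to include the required groups
    print(psense_summary(data))  # tabular prior/likelihood sensitivity diagnostics
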
26 changes: 26 additions & 0 deletions src/arviz_stats/accessors.py
@@ -68,6 +68,16 @@ def pareto_min_ss(self, dims=None):
        """Compute the minimum effective sample size on the DataArray."""
        return get_function("pareto_min_ss")(self._obj, dims=dims)

    def power_scale_lw(self, alpha=1, dims=None):
        """Compute log weights for power-scaling of the DataArray."""
        return get_function("power_scale_lw")(self._obj, alpha=alpha, dims=dims)

    def power_scale_sense(self, lower_w=None, upper_w=None, delta=None, dims=None):
        """Compute power-scaling sensitivity."""
        return get_function("power_scale_sense")(
            self._obj, lower_w=lower_w, upper_w=upper_w, delta=delta, dims=dims
        )


@xr.register_dataset_accessor("azstats")
class AzStatsDsAccessor(_BaseAccessor):
@@ -179,6 +189,14 @@ def pareto_min_ss(self, dims=None):
        """Compute the min sample size for all variables in the dataset."""
        return self._apply("pareto_min_ss", dims=dims)

    def power_scale_lw(self, dims=None, **kwargs):
        """Compute log weights for power-scaling for all variables in the dataset."""
        return self._apply("power_scale_lw", dims=dims, **kwargs)

    def power_scale_sense(self, dims=None, **kwargs):
        """Compute power-scaling sensitivity."""
        return self._apply("power_scale_sense", dims=dims, **kwargs)


@register_datatree_accessor("azstats")
class AzStatsDtAccessor(_BaseAccessor):
@@ -276,3 +294,11 @@ def thin(self, dims=None, group="posterior", **kwargs):
    def pareto_min_ss(self, dims=None, group="posterior"):
        """Compute the min sample size for all variables in a group of the DataTree."""
        return self._apply("pareto_min_ss", dims=dims, group=group)

    def power_scale_lw(self, dims=None, group="log_likelihood", **kwargs):
        """Compute log weights for power-scaling of the DataTree."""
        return self._apply("power_scale_lw", dims=dims, group=group, **kwargs)

    def power_scale_sense(self, dims=None, group="posterior", **kwargs):
        """Compute power-scaling sensitivity."""
        return self._apply("power_scale_sense", dims=dims, group=group, **kwargs)
27 changes: 27 additions & 0 deletions src/arviz_stats/base/array.py
@@ -140,6 +140,33 @@ def pareto_min_ss(self, ary, chain_axis=-2, draw_axis=-1):
        pms_array = make_ufunc(self._pareto_min_ss, n_output=1, n_input=1, n_dims=2, ravel=False)
        return pms_array(ary)

    def power_scale_lw(self, ary, alpha=0, axes=-1):
        """Compute log weights for power-scaling component by alpha."""
        ary, axes = process_ary_axes(ary, axes)
        psl_ufunc = make_ufunc(
            self._power_scale_lw,
            n_output=1,
            n_input=1,
            n_dims=len(axes),
            ravel=False,
        )
        return psl_ufunc(ary, out_shape=tuple(ary.shape[i] for i in axes), alpha=alpha)

    def power_scale_sense(self, ary, lower_w, upper_w, delta, chain_axis=-2, draw_axis=-1):
        """Compute power-scaling sensitivity."""
        if chain_axis is None:
            ary = np.expand_dims(ary, axis=0)
            lower_w = np.expand_dims(lower_w, axis=0)
            upper_w = np.expand_dims(upper_w, axis=0)
            chain_axis = 0
        ary, _ = process_ary_axes(ary, [chain_axis, draw_axis])
        lower_w, _ = process_ary_axes(lower_w, [chain_axis, draw_axis])
        upper_w, _ = process_ary_axes(upper_w, [chain_axis, draw_axis])
        pss_array = make_ufunc(
            self._power_scale_sense, n_output=1, n_input=3, n_dims=2, ravel=False
        )
        return pss_array(ary, lower_w, upper_w, delta=delta)

    def compute_ranks(self, ary, axes=-1, relative=False):
        """Compute ranks of MCMC samples."""
        ary, axes = process_ary_axes(ary, axes)
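
At the array level, the weight definition itself is a one-liner: power-scaling a model component by alpha multiplies its log-density by alpha, so the importance log weights are (alpha - 1) times the component's log-density draws (`_power_scale_lw` below additionally Pareto-smooths the tails). A self-contained numpy sketch of just that definition:

    import numpy as np

    rng = np.random.default_rng(0)
    log_lik = rng.normal(size=(4, 500))  # toy (chain, draw) log-likelihood draws

    alpha = 1.01
    log_weights = (alpha - 1) * log_lik                # power-scaling log weights
    weights = np.exp(log_weights - log_weights.max())  # stabilized exponentiation
    weights /= weights.sum()                           # normalized importance weights
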
26 changes: 25 additions & 1 deletion src/arviz_stats/base/dataarray.py
@@ -7,7 +7,7 @@

import numpy as np
from arviz_base import rcParams
from xarray import DataArray, apply_ufunc, broadcast, concat
from xarray_einstats.stats import _apply_nonreduce_func

from arviz_stats.base.array import array_stats
@@ -274,5 +274,29 @@ def pareto_min_ss(self, da, dims=None):
            kwargs={"chain_axis": chain_axis, "draw_axis": draw_axis},
        )

    def power_scale_lw(self, da, alpha=0, dims=None):
        """Compute log weights for power-scaling component by alpha."""
        dims = validate_dims(dims)
        return apply_ufunc(
            self.array_class.power_scale_lw,
            da,
            alpha,
            input_core_dims=[dims, []],
            output_core_dims=[dims],
            kwargs={"axes": np.arange(-len(dims), 0, 1)},
        )

    def power_scale_sense(self, da, lower_w, upper_w, delta, dims=None):
        """Compute power-scaling sensitivity."""
        dims, chain_axis, draw_axis = validate_dims_chain_draw_axis(dims)
        return apply_ufunc(
            self.array_class.power_scale_sense,
            *broadcast(da, lower_w, upper_w),
            delta,
            input_core_dims=[dims, dims, dims, []],
            output_core_dims=[[]],
            kwargs={"chain_axis": chain_axis, "draw_axis": draw_axis},
        )


dataarray_stats = BaseDataArray(array_class=array_stats)
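
For reference, a sketch of calling the DataArray-level wrapper directly (illustrative names; `broadcast` above aligns the draws and both weight arrays on shared chain/draw dims before apply_ufunc dispatches to the array backend):

    # da, lower_w, upper_w: DataArrays with ("chain", "draw") dims (hypothetical)
    sense = dataarray_stats.power_scale_sense(da, lower_w, upper_w, delta=0.01)
    # chain/draw are reduced away: one sensitivity value per remaining coordinate
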
91 changes: 83 additions & 8 deletions src/arviz_stats/base/diagnostics.py
@@ -370,6 +370,20 @@ def _pareto_khat(self, ary, r_eff=1, tail="both", log_weights=False):

        n_draws = len(ary)

        n_draws_tail = self._get_ps_tails(n_draws, r_eff, tail=tail)

        if tail == "both":
            khat = max(
                self._ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=t)[1]
                for t in ("left", "right")
            )
        else:
            _, khat = self._ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=tail)

        return khat

    @staticmethod
    def _get_ps_tails(n_draws, r_eff, tail):
        if n_draws > 255:
            n_draws_tail = np.ceil(3 * (n_draws / r_eff) ** 0.5).astype(int)
        else:
@@ -389,14 +403,7 @@ def _pareto_khat(self, ary, r_eff=1, tail="both", log_weights=False):
            warnings.warn("Number of tail draws cannot be less than 5. Changing to 5")
            n_draws_tail = 5

        # (deleted lines: the khat computation moved into _pareto_khat above)
        return n_draws_tail

    def _ps_tail(
        self, ary, n_draws, n_draws_tail, smooth_draws=False, tail="both", log_weights=False
Expand Down Expand Up @@ -543,3 +550,71 @@ def _gpinv(probs, kappa, sigma, mu):
            q = mu + sigma * np.expm1(-kappa * np.log1p(-probs)) / kappa

        return q

    def _power_scale_sense(self, ary, lower_w, upper_w, delta=0.01):
        """Compute power-scaling sensitivity by finite difference second derivative of CJS."""
        ary = np.ravel(ary)
        lower_w = np.ravel(lower_w)
        upper_w = np.ravel(upper_w)
        lower_cjs = max(self._cjs_dist(ary, lower_w), self._cjs_dist(-1 * ary, lower_w))
        upper_cjs = max(self._cjs_dist(ary, upper_w), self._cjs_dist(-1 * ary, upper_w))
        grad = (lower_cjs + upper_cjs) / (2 * np.log2(1 + delta))
        return grad

    def _power_scale_lw(self, ary, alpha):
        """Compute log weights for power-scaling component by alpha."""
        shape = ary.shape
        ary = np.ravel(ary)
        log_weights = (alpha - 1) * ary
        n_draws = len(log_weights)
        r_eff = self._ess_tail(ary, relative=True)
        n_draws_tail = self._get_ps_tails(n_draws, r_eff, tail="both")
        # Pareto-smooth the tail draws of the log weights
        log_weights, _ = self._ps_tail(
            log_weights,
            n_draws,
            n_draws_tail,
            smooth_draws=True,
            log_weights=True,
        )

        return log_weights.reshape(shape)

    @staticmethod
    def _cjs_dist(ary, weights):
        """Calculate the cumulative Jensen-Shannon distance between original and weighted draws."""
        # sort draws and weights
        order = np.argsort(ary)
        ary = ary[order]
        weights = weights[order]

        binwidth = np.diff(ary)

        # empirical CDFs
        cdf_p = np.linspace(1 / len(ary), 1 - 1 / len(ary), len(ary) - 1)
        cdf_q = np.cumsum(weights / np.sum(weights))[:-1]

        # integrals of ECDFs
        cdf_p_int = np.dot(cdf_p, binwidth)
        cdf_q_int = np.dot(cdf_q, binwidth)

        # CJS calculation
        pq_numer = np.log2(cdf_p, out=np.zeros_like(cdf_p), where=cdf_p != 0)
        qp_numer = np.log2(cdf_q, out=np.zeros_like(cdf_q), where=cdf_q != 0)

        denom = 0.5 * (cdf_p + cdf_q)
        denom = np.log2(denom, out=np.zeros_like(denom), where=denom != 0)

        cjs_pq = np.sum(binwidth * (cdf_p * (pq_numer - denom))) + 0.5 / np.log(2) * (
            cdf_q_int - cdf_p_int
        )

        cjs_qp = np.sum(binwidth * (cdf_q * (qp_numer - denom))) + 0.5 / np.log(2) * (
            cdf_p_int - cdf_q_int
        )

        cjs_pq = max(0, cjs_pq)
        cjs_qp = max(0, cjs_qp)

        bound = cdf_p_int + cdf_q_int

        return np.sqrt((cjs_pq + cjs_qp) / bound)
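
Two notes on the pieces above. First, the quantity returned by `_power_scale_sense` averages the two one-sided CJS distances and normalizes by the log2 step size:

\[
s = \frac{D_{\mathrm{CJS}}(p, p_{\text{lower}}) + D_{\mathrm{CJS}}(p, p_{\text{upper}})}{2 \log_2(1 + \delta)},
\]

where p is the original sample and p_lower, p_upper are the draws reweighted by lower_w and upper_w (each distance is taken as the max over the sample and its negation, as in the code). Second, a quick sanity check of the ECDF construction inside `_cjs_dist` (a sketch: uniform weights reproduce the unweighted ECDF, so the distance should be ~0):

    import numpy as np

    rng = np.random.default_rng(1)
    ary = rng.normal(size=1_000)
    weights = np.full(ary.size, 1 / ary.size)  # uniform weights

    order = np.argsort(ary)
    cdf_p = np.linspace(1 / ary.size, 1 - 1 / ary.size, ary.size - 1)
    cdf_q = np.cumsum(weights[order] / weights.sum())[:-1]
    print(np.allclose(cdf_p, cdf_q))  # True: uniform weights leave the ECDF unchanged
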
(diffs for the remaining 4 changed files not shown)
