Allow passing compile_kwargs to step inner functions
ricardoV94 committed Nov 22, 2024
1 parent 8c2ced3 commit 161b859
Showing 4 changed files with 33 additions and 7 deletions.
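For context, a minimal sketch of how the new argument is used from the top-level API once this commit lands. The toy model and the FAST_COMPILE mode are illustrative choices, not part of the commit:

import pymc as pm

with pm.Model():
    x = pm.Normal("x", 0, 1)
    # compile_kwargs is forwarded to the PyTensor compilation of the
    # step methods' inner functions (here: use the FAST_COMPILE mode).
    idata = pm.sample(draws=100, tune=100, compile_kwargs={"mode": "FAST_COMPILE"})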
16 changes: 15 additions & 1 deletion pymc/sampling/mcmc.py
@@ -105,6 +105,7 @@ def instantiate_steppers(
     *,
     step_kwargs: dict[str, dict] | None = None,
     initial_point: PointType | None = None,
+    compile_kwargs: dict | None = None,
 ) -> Step | list[Step]:
     """Instantiate steppers assigned to the model variables.
@@ -147,6 +148,7 @@ def instantiate_steppers(
                 vars=vars,
                 model=model,
                 initial_point=initial_point,
+                compile_kwargs=compile_kwargs,
                 **kwargs,
             )
             steps.append(step)
@@ -435,6 +437,7 @@ def sample(
     callback=None,
     mp_ctx=None,
     blas_cores: int | None | Literal["auto"] = "auto",
+    compile_kwargs: dict | None = None,
     **kwargs,
 ) -> InferenceData: ...

@@ -467,6 +470,7 @@ def sample(
     mp_ctx=None,
     model: Model | None = None,
     blas_cores: int | None | Literal["auto"] = "auto",
+    compile_kwargs: dict | None = None,
     **kwargs,
 ) -> MultiTrace: ...

@@ -498,6 +502,7 @@ def sample(
     mp_ctx=None,
     blas_cores: int | None | Literal["auto"] = "auto",
     model: Model | None = None,
+    compile_kwargs: dict | None = None,
     **kwargs,
 ) -> InferenceData | MultiTrace:
     r"""Draw samples from the posterior using the given step methods.
@@ -599,6 +604,9 @@ def sample(
         See multiprocessing documentation for details.
     model : Model (optional if in ``with`` context)
         Model to sample from. The model needs to have free random variables.
+    compile_kwargs : dict, optional
+        Dictionary with keyword arguments to pass to the functions compiled by the step methods.
+
     Returns
     -------
@@ -796,6 +804,7 @@ def joined_blas_limiter():
             jitter_max_retries=jitter_max_retries,
             tune=tune,
             initvals=initvals,
+            compile_kwargs=compile_kwargs,
             **kwargs,
         )
     else:
@@ -815,6 +824,7 @@
             selected_steps=selected_steps,
             step_kwargs=kwargs,
             initial_point=initial_points[0],
+            compile_kwargs=compile_kwargs,
         )
         if isinstance(step, list):
             step = CompoundStep(step)
@@ -1390,6 +1400,7 @@ def init_nuts(
     jitter_max_retries: int = 10,
     tune: int | None = None,
     initvals: StartDict | Sequence[StartDict | None] | None = None,
+    compile_kwargs: dict | None = None,
     **kwargs,
 ) -> tuple[Sequence[PointType], NUTS]:
     """Set up the mass matrix initialization for NUTS.
@@ -1466,6 +1477,9 @@ def init_nuts(
     if init == "auto":
         init = "jitter+adapt_diag"
 
+    if compile_kwargs is None:
+        compile_kwargs = {}
+
     random_seed_list = _get_seeds_per_chain(random_seed, chains)
 
     _log.info(f"Initializing NUTS using {init}...")
@@ -1477,7 +1491,7 @@
         pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
     ]
 
-    logp_dlogp_func = model.logp_dlogp_function()
+    logp_dlogp_func = model.logp_dlogp_function(**compile_kwargs)
     initial_points = _init_jitter(
         model,
         initvals,
5 changes: 4 additions & 1 deletion pymc/step_methods/arraystep.py
@@ -182,13 +182,16 @@ def __init__(
         logp_dlogp_func=None,
         rng: RandomGenerator = None,
         initial_point=None,
+        compile_kwargs: dict | None = None,
         **pytensor_kwargs,
     ):
         model = modelcontext(model)
 
         if logp_dlogp_func is None:
+            if compile_kwargs is None:
+                compile_kwargs = {}
             logp_dlogp_func = model.logp_dlogp_function(
-                vars, dtype=dtype, initial_point=initial_point, **pytensor_kwargs
+                vars, dtype=dtype, initial_point=initial_point, **compile_kwargs, **pytensor_kwargs
             )
 
         self._logp_dlogp_func = logp_dlogp_func
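Since GradientSharedStep now accepts compile_kwargs and forwards it to model.logp_dlogp_function, gradient-based steppers built on it inherit the option. A sketch, assuming NUTS passes the keyword through to this base class; the model is a placeholder:

import pymc as pm

with pm.Model():
    x = pm.Normal("x", 0, 1)
    # The dict reaches model.logp_dlogp_function(...), so the joint
    # logp-and-gradient function is compiled with these options.
    step = pm.NUTS(compile_kwargs={"mode": "FAST_COMPILE"})
    idata = pm.sample(draws=50, tune=50, step=step)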
14 changes: 10 additions & 4 deletions pymc/step_methods/metropolis.py
@@ -162,6 +162,7 @@ def __init__(
         initial_point: PointType | None = None,
         mode=None,
         rng=None,
+        compile_kwargs: dict | None = None,
         **kwargs,
     ):
         """Create an instance of a Metropolis stepper.
@@ -254,7 +255,7 @@ def __init__(
         self.mode = mode
 
         shared = pm.make_shared_replacements(initial_point, vars, model)
-        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared)
+        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs)
         super().__init__(vars, shared, rng=rng)
 
     def reset_tuning(self):
@@ -856,6 +857,7 @@ def __init__(
         mode=None,
         rng=None,
         initial_point=None,
+        compile_kwargs: dict | None = None,
         **kwargs,
     ):
         model = pm.modelcontext(model)
@@ -891,7 +893,7 @@ def __init__(
         self.mode = mode
 
         shared = pm.make_shared_replacements(initial_values, vars, model)
-        self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared)
+        self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared, compile_kwargs)
         super().__init__(vars, shared, rng=rng, **kwargs)
 
     def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
@@ -1025,6 +1027,7 @@ def __init__(
         tune_drop_fraction: float = 0.9,
         model=None,
         initial_point=None,
+        compile_kwargs: dict | None = None,
         mode=None,
         rng=None,
         **kwargs,
@@ -1074,7 +1077,7 @@ def __init__(
         self.mode = mode
 
         shared = pm.make_shared_replacements(initial_point, vars, model)
-        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared)
+        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs)
         super().__init__(vars, shared, rng=rng)
 
     def reset_tuning(self):
@@ -1165,6 +1168,7 @@ def delta_logp(
     logp: pt.TensorVariable,
     vars: list[pt.TensorVariable],
     shared: dict[pt.TensorVariable, pt.sharedvar.TensorSharedVariable],
+    compile_kwargs: dict | None,
 ) -> pytensor.compile.Function:
     [logp0], inarray0 = join_nonshared_inputs(
         point=point, outputs=[logp], inputs=vars, shared_inputs=shared
@@ -1177,6 +1181,8 @@
     # Replace any potential duplicated RNG nodes
     (logp1,) = replace_rng_nodes((logp1,))
 
-    f = compile_pymc([inarray1, inarray0], logp1 - logp0)
+    if compile_kwargs is None:
+        compile_kwargs = {}
+    f = compile_pymc([inarray1, inarray0], logp1 - logp0, **compile_kwargs)
     f.trust_input = True
     return f
5 changes: 4 additions & 1 deletion pymc/step_methods/slicer.py
@@ -85,6 +85,7 @@ def __init__(
         initial_point=None,
         iter_limit=np.inf,
         rng=None,
+        compile_kwargs: dict | None = None,
         **kwargs,
     ):
         model = modelcontext(model)
@@ -105,7 +106,9 @@
         [logp], raveled_inp = join_nonshared_inputs(
             point=initial_point, outputs=[model.logp()], inputs=vars, shared_inputs=shared
         )
-        self.logp = compile_pymc([raveled_inp], logp)
+        if compile_kwargs is None:
+            compile_kwargs = {}
+        self.logp = compile_pymc([raveled_inp], logp, **compile_kwargs)
         self.logp.trust_input = True
 
         super().__init__(vars, shared, rng=rng, **kwargs)
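For intuition about the pattern shared by delta_logp and Slice above, here is the bare compile-and-trust step in isolation, with pytensor.function standing in for pymc's compile_pymc wrapper and a toy logp graph in place of model.logp():

import numpy as np
import pytensor
import pytensor.tensor as pt

raveled_inp = pt.dvector("raveled_inp")
logp = -0.5 * (raveled_inp**2).sum()  # stand-in for the model's logp
compile_kwargs = {"mode": "FAST_COMPILE"}
logp_fn = pytensor.function([raveled_inp], logp, **compile_kwargs)
logp_fn.trust_input = True  # skip input validation, as the steppers do
print(logp_fn(np.zeros(3)))  # 0.0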
