From d9b4794e4f0c5b5b341bd990f8546c58360f3f93 Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Tue, 15 Oct 2024 13:49:35 +0100
Subject: [PATCH 1/6] add note

---
 design_notes/interface_for_gibbs.md | 77 +++++++++++++++++++++++++++++
 1 file changed, 77 insertions(+)
 create mode 100644 design_notes/interface_for_gibbs.md

diff --git a/design_notes/interface_for_gibbs.md b/design_notes/interface_for_gibbs.md
new file mode 100644
index 0000000..e776850
--- /dev/null
+++ b/design_notes/interface_for_gibbs.md
@@ -0,0 +1,77 @@
+# Notes on Potential Future Interface for Gibbs Sampling Support
+
+## Background
+
+This document was written after [PR #144](https://github.com/TuringLang/AbstractMCMC.jl/pull/144) was closed.
+
+It was last updated on October 15, 2024. At that time:
+
+- _AbstractMCMC.jl_ was on version 5.5.0
+- _Turing.jl_ was on version 0.34.1
+
+The goal is to document some of the considerations that went into the closed PR mentioned above.
+
+## Gibbs Sampling Considerations
+
+### Recomputing Log Densities for Parameter Groups
+
+Let's consider splitting the model parameters into three groups (assuming the grouping stays fixed between iterations). Each parameter group will have a corresponding sampler state (along with the sampler used for that group).
+
+In the general case, the log densities stored in the states will be incorrect at the time of sampling each group. This is because the values of the other two parameter groups can change from when the current log density was computed, as they get updated within the Gibbs sweep.
+
+### Current Approach: `recompute_logp!!`
+
+_Turing.jl_'s current solution, at the time of writing this, is the `recompute_logp!!` function (see [Tor's comment](https://github.com/TuringLang/AbstractMCMC.jl/issues/85#issuecomment-2061300622) and the [`Gibbs` PR](https://github.com/TuringLang/Turing.jl/pull/2099)).

Here's an example implementation of this function for _AdvancedHMC.jl_, where it is named `recompute_logprob!!` ([permalink](https://github.com/TuringLang/Turing.jl/blob/24e68701b01695bffe69eda9e948e910c1ae2996/src/mcmc/abstractmcmc.jl#L77C1-L90C1)):

```julia
function recompute_logprob!!(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    sampler::AdvancedHMC.AbstractHMCSampler,
    state::AdvancedHMC.HMCState,
)
    # Construct the Hamiltonian.
    hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
    # Re-compute the log-probability and gradient.
    return Accessors.@set state.transition.z = AdvancedHMC.phasepoint(
        hamiltonian, state.transition.z.θ, state.transition.z.r
    )
end
```

### Alternative Approach Proposed in [PR #144](https://github.com/TuringLang/AbstractMCMC.jl/pull/144)

The proposal is to separate `recompute_logp!!` into two functions:

1. A function to compute the log density given the model and sampler state
2. A function to set the computed log density in the sampler state

There are a few considerations with this approach:

- Computing the log density involves a `model`, which may not be defined by the sampler package in the general case. It's unclear if this interface is appropriate, as the model details might be needed to calculate the log density. However, in many situations, the `LogDensityProblems` interface (`LogDensityModel` in `AbstractMCMC`) could be sufficient.
  - One interface consideration is that `LogDensityProblems.logdensity` expects a vector input. For our use case, we may want to reuse the log density stored in the state instead of recomputing it each time. This would require modifying `logdensity` to accept a sampler state and potentially a boolean flag to indicate whether to recompute the log density or not.
- In some cases, samplers require more than just the log joint density. They may also need the log likelihood and log prior separately (see [this discussion](https://github.com/TuringLang/AbstractMCMC.jl/issues/112)).
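To make the proposed split concrete, here is a minimal, self-contained sketch of the two functions. Everything in it is hypothetical: `ToyModel`, `MySamplerState`, `compute_logdensity`, and `set_logdensity!!` are invented stand-ins, not AbstractMCMC or sampler-package API.

```julia
# Hypothetical sampler state that caches a log density alongside the parameters.
struct MySamplerState
    params::Vector{Float64}
    logp::Float64
end

# Toy log-density model: a standard multivariate normal.
struct ToyModel end
logdensity(::ToyModel, x) = -sum(abs2, x) / 2 - length(x) * log(2π) / 2

# 1. Compute the log density given the model and sampler state.
compute_logdensity(model::ToyModel, state::MySamplerState) =
    logdensity(model, state.params)

# 2. Set the computed log density in the (immutable) sampler state.
set_logdensity!!(state::MySamplerState, logp::Real) =
    MySamplerState(state.params, logp)

state = MySamplerState([0.0, 0.0], NaN)  # log density is stale/unset
state = set_logdensity!!(state, compute_logdensity(ToyModel(), state))
```

With immutable states, the `!!` convention signals that the function may return a new state rather than mutate in place, which is why the caller rebinds `state`.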
+ +## Potential Path Forward + +A reasonable next step would be to explore an interface similar to `LogDensityProblems.logdensity`, but with the ability to compute both the log prior and log likelihood. It should also accept alternative inputs and keyword arguments. + +To complement this computation interface, we would need functions to `get` and `set` the log likelihood and log prior from/to the sampler states. + +For situations where model-specific details are required to compute the log density from a sampler state, the necessary abstractions are not yet clear. We will need to consider appropriate abstractions as such use cases emerge. + +## Additional Notes on a More Independent Gibbs Implementation + +### Regarding `AbstractPPL.condition` + +While the `condition` function is a promising idea for Gibbs sampling, it is not currently being utilized in _Turing.jl_'s implementation. Instead, _Turing.jl_ uses a `GibbsContext` for reasons outlined [here](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/experimental/gibbs.jl#L1-L12). Additionally, _JuliaBUGS_ requires caching the Markov blanket when calling `condition`, which means the proposed `Gibbs` implementation in this PR would not be fully compatible. + +### Samplers Should Not Manage Variable Names + +To make `AbstractMCMC.Gibbs` more independent and flexible, it should manage a mapping of `range → sampler` rather than `variable name → sampler`. This means it would maintain a vector of parameter values internally. The responsibility of managing both the variable names and any necessary transformations should be handled by a higher-level interface such as `AbstractPPL` or `DynamicPPL`. + +By separating these concerns, `AbstractMCMC.Gibbs` can focus on the core Gibbs sampling logic while the PPL interface handles the specifics of variable naming and transformations. This modular approach allows for greater flexibility and easier integration with different PPL frameworks. 
However, the issue arises when we have transdimensional parameters. In such cases, the parameter space can change during sampling, making it challenging to maintain a fixed mapping between ranges and samplers.

From 5127bc88f077946206f4bf61b148b28c625c1926 Mon Sep 17 00:00:00 2001
From: Xianda Sun <5433119+sunxd3@users.noreply.github.com>
Date: Thu, 17 Oct 2024 18:41:55 +0800
Subject: [PATCH 2/6] Update design_notes/interface_for_gibbs.md

---
 design_notes/interface_for_gibbs.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/design_notes/interface_for_gibbs.md b/design_notes/interface_for_gibbs.md
index e776850..2cbdf1e 100644
--- a/design_notes/interface_for_gibbs.md
+++ b/design_notes/interface_for_gibbs.md
@@ -15,7 +15,7 @@ The goal is to document some of the considerations that went into the closed PR mentioned above.
 
 ### Recomputing Log Densities for Parameter Groups
 
-Let's consider splitting the model parameters into three groups (assuming the grouping stays fixed between iterations). Each parameter group will have a corresponding sampler state (along with the sampler used for that group).
+Let's consider splitting the model parameters into several groups (assuming the grouping stays fixed between iterations). Each parameter group will have a corresponding sampler state (along with the sampler used for that group).
 
 In the general case, the log densities stored in the states will be incorrect at the time of sampling each group. This is because the values of the other two parameter groups can change from when the current log density was computed, as they get updated within the Gibbs sweep.
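The staleness problem that motivates `recompute_logp!!` can be illustrated without any sampler machinery. The following toy example (illustrative code only) shows how a log density cached alongside one parameter block becomes invalid once another block moves:

```julia
# Toy log joint density over two parameter blocks x and y.
logjoint(x, y) = -(x^2 + y^2) / 2

x, y = 1.0, 2.0
state_x = (value = x, logp = logjoint(x, y))  # cache log density with block x

y = 0.0                         # block y is updated within the Gibbs sweep
stale = state_x.logp            # still reflects the old value of y
fresh = logjoint(state_x.value, y)  # must be recomputed before sampling x
```

Here `stale` is `-2.5` while the correct value after `y` moves is `-0.5`; accepting or rejecting proposals for the `x` block against the cached value would therefore use the wrong density.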
From 9e946c0de83f882de2359b071a5405c89bd1999b Mon Sep 17 00:00:00 2001
From: Xianda Sun <5433119+sunxd3@users.noreply.github.com>
Date: Thu, 17 Oct 2024 18:42:00 +0800
Subject: [PATCH 3/6] Update design_notes/interface_for_gibbs.md

---
 design_notes/interface_for_gibbs.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/design_notes/interface_for_gibbs.md b/design_notes/interface_for_gibbs.md
index 2cbdf1e..7ecb341 100644
--- a/design_notes/interface_for_gibbs.md
+++ b/design_notes/interface_for_gibbs.md
@@ -66,7 +66,7 @@ For situations where model-specific details are required to compute the log dens
 
 ### Regarding `AbstractPPL.condition`
 
-While the `condition` function is a promising idea for Gibbs sampling, it is not currently being utilized in _Turing.jl_'s implementation. Instead, _Turing.jl_ uses a `GibbsContext` for reasons outlined [here](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/experimental/gibbs.jl#L1-L12). Additionally, _JuliaBUGS_ requires caching the Markov blanket when calling `condition`, which means the proposed `Gibbs` implementation in this PR would not be fully compatible.
+While the `condition` function is a promising idea for Gibbs sampling, it is not currently being utilized in _Turing.jl_'s implementation. Instead, _Turing.jl_ uses a `GibbsContext` for reasons outlined [here](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/experimental/gibbs.jl#L1-L12). Additionally, _JuliaBUGS_ requires caching the Markov blanket when calling `condition`, which means the proposed `Gibbs` implementation in the PR above would not be fully compatible.
### Samplers Should Not Manage Variable Names

From 21df935b7cfde1db22583a3e9a220b7dc2cc61a3 Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Tue, 22 Oct 2024 21:02:13 +0100
Subject: [PATCH 4/6] first draft

---
 design_notes/logdensity_interface.md | 202 +++++++++++++++++++++++++++
 1 file changed, 202 insertions(+)
 create mode 100644 design_notes/logdensity_interface.md

diff --git a/design_notes/logdensity_interface.md b/design_notes/logdensity_interface.md
new file mode 100644
index 0000000..ae26553
--- /dev/null
+++ b/design_notes/logdensity_interface.md
@@ -0,0 +1,202 @@
# Proposal for a New LogDensity Function Interface

## Introduction

The goal is to design a flexible and user-friendly interface for log density functions that can handle various model operations, especially in higher-order contexts such as Gibbs sampling. This interface should facilitate:

- **Conditioning**: Incorporating observed data into the model.
- **Fixing**: Fixing certain variables to specific values (like the `do` operator).
- **Generated Quantities**: Computing additional expressions or functions based on the model parameters.
- **Prediction**: Making predictions by fixing parameters and unconditioning on data.

This proposal aims to redefine the interface from the user's perspective, focusing on ease of use and extensibility beyond the traditional probabilistic programming languages (PPLs).

## Proposed Interface

Below is a proposed interface with key functionalities and their implementations.

### Core Functions

#### Check if a Model is Parametric

```julia
# Check if a log density model is parametric
function is_parametric(model::LogDensityModel) -> Bool
    ...
end
```

- **Description**: Determines if the model has a parameter space with a defined dimension.
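As a toy illustration of how concrete model types might opt into this query (the `ToyGaussian` and `ToyDirichletProcess` types are invented for this sketch, not part of the proposal):

```julia
# Invented example types: one with a fixed-dimension parameter space, one without.
struct ToyGaussian
    dim::Int
end
struct ToyDirichletProcess end   # stand-in for a nonparametric model

is_parametric(::ToyGaussian) = true
is_parametric(::ToyDirichletProcess) = false

# Only defined for parametric models, matching the contract above.
dimension(m::ToyGaussian) = m.dim
```

Calling `dimension` on the nonparametric type is simply a `MethodError`, which is one way to encode the "only defined when `is_parametric(model)` is true" contract.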
+- + +#### Get the Dimension of a Parametric Model + +```julia +# Get the dimension of the parameter space (only defined when is_parametric(model) is true) +function dimension(model::LogDensityModel) -> Int + ... +end +``` + +- **Description**: Returns the dimension of the parameter space for parametric models. + +### Log Density Computations + +#### Log-Likelihood + +```julia +# Compute the log-likelihood given parameters +function loglikelihood(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 + ... +end +``` + +- **Description**: Computes the log-likelihood of the data given the model parameters. + +#### Log-Prior + +```julia +# Compute the log-prior given parameters +function logprior(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 + ... +end +``` + +- **Description**: Computes the log-prior probability of the model parameters. + +#### Log-Joint + +```julia +# Compute the log-joint density (log-likelihood + log-prior) +function logjoint(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 + return loglikelihood(model, params) + logprior(model, params) +end +``` + +- **Description**: Computes the total log density by summing the log-likelihood and log-prior. + +### Conditioning and Fixing Variables + +#### Conditioning a Model + +```julia +# Condition the model on observed data +function condition(model::LogDensityModel, data::NamedTuple) -> ConditionedModel + ... +end +``` + +- **Description**: Incorporates observed data into the model, returning a `ConditionedModel`. + +#### Checking if a Model is Conditioned + +```julia +# Check if a model is conditioned +function is_conditioned(model::LogDensityModel) -> Bool + ... +end +``` + +- **Description**: Checks whether the model has been conditioned on data. + +#### Fixing Variables in a Model + +```julia +# Fix certain variables in the model +function fix(model::LogDensityModel, variables::NamedTuple) -> FixedModel + ... 
+end +``` + +- **Description**: Fixes specific variables in the model to given values, returning a `FixedModel`. + +#### Checking if a Model has Fixed Variables + +```julia +# Check if a model has fixed variables +function is_fixed(model::LogDensityModel) -> Bool + ... +end +``` + +- **Description**: Determines if any variables in the model have been fixed. + +### Specialized Models + +#### Conditioned Model Methods + +```julia +# Log-likelihood for a conditioned model +function loglikelihood(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 + ... +end + +# Log-prior for a conditioned model +function logprior(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 + ... +end + +# Log-joint for a conditioned model +function logjoint(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 + return loglikelihood(model, params) + logprior(model, params) +end +``` + +- **Description**: Overrides log density computations to account for the conditioned data. + +#### Fixed Model Methods + +```julia +# Log-likelihood for a fixed model +function loglikelihood(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64 + ... +end + +# Log-prior for a fixed model +function logprior(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64 + ... +end + +# Log-joint for a fixed model +function logjoint(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64 + return loglikelihood(model, data) + logprior(model, data) +end +``` + +- **Description**: Adjusts log density computations based on the fixed variables. + +### Additional Functionalities + +#### Generated Quantities + +```julia +# Compute generated quantities after fixing parameters +function generated_quantities(model::LogDensityModel, fixed_vars::NamedTuple) -> NamedTuple + ... +end +``` + +- **Description**: Computes additional expressions or functions based on the fixed model parameters. 
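For a concrete (and purely illustrative) instance of this signature, consider a toy regression model where the generated quantity is the fitted mean; `ToyRegression` and the returned field name are invented for the sketch:

```julia
# Toy model type holding the observed covariates.
struct ToyRegression
    x::Vector{Float64}
end

# Derived quantity computed from fixed parameter values: the fitted mean.
function generated_quantities(model::ToyRegression, fixed_vars::NamedTuple)
    (; α, β) = fixed_vars
    return (mean_at_x = α .+ β .* model.x,)
end

model = ToyRegression([0.0, 1.0, 2.0])
gq = generated_quantities(model, (α = 1.0, β = 2.0))
# gq.mean_at_x == [1.0, 3.0, 5.0]
```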
+ +#### Prediction + +```julia +# Predict data based on fixed parameters +function predict(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> NamedTuple + ... +end +``` + +- **Description**: Generates predictions by fixing the parameters and unconditioning the data. + +## Advantages of the Proposed Interface + +- **Flexibility**: Allows for advanced model operations like conditioning and fixing, essential for methods like Gibbs sampling. + +- **User-Centric Design**: Focuses on usability from the model user's perspective rather than the PPL implementation side. + +- **Consistency**: Maintains a uniform interface for both parametric and non-parametric models, simplifying the learning curve. + +## Usage Examples + +## Non-Parametric Models \ No newline at end of file From 6117e4efee8f16e6fee8a01eff7946a7bf087099 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 23 Oct 2024 08:02:25 +0100 Subject: [PATCH 5/6] simplify --- design_notes/logdensity_interface.md | 214 ++++----------------------- 1 file changed, 26 insertions(+), 188 deletions(-) diff --git a/design_notes/logdensity_interface.md b/design_notes/logdensity_interface.md index ae26553..79e3dd2 100644 --- a/design_notes/logdensity_interface.md +++ b/design_notes/logdensity_interface.md @@ -1,202 +1,40 @@ # Proposal for a New LogDensity Function Interface -## Introduction + -The goal is to design a flexible and user-friendly interface for log density functions that can handle various model operations, especially in higher-order contexts such as Gibbs sampling. This interface should facilitate: +The goal is to design a flexible, user-friendly interface for log density functions that can handle various model operations, especially in higher-order contexts like Gibbs sampling and Bayesian workflows. -- **Conditioning**: Incorporating observed data into the model. -- **Fixing**: Fixing certain variables to specific values. 
(like `do` operator) -- **Generated Quantities**: Computing additional expressions or functions based on the model parameters. -- **Prediction**: Making predictions by fixing parameters and unconditioning on data. +## Evaluation functions: -This proposal aims to redefine the interface from the user's perspective, focusing on ease of use and extensibility beyond the traditional probabilistic programming languages (PPLs). +1. `evaluate` -## Proposed Interface +## Query functions: -Below is a proposed interface with key functionalities and their implementations. +1. `is_parametric(model)` +2. `dimension(model)` (only defined when `is_parametric(model) == true`) +3. `is_conditioned(model)` +4. `is_fixed(model)` +5. `logjoint(model, params)` +6. `loglikelihood(model, params)` +7. `logprior(model, params)` -### Core Functions +where `params` can be `Vector`, `NamedTuple`, `Dict`, etc. -#### Check if a Model is Parametric +## Transformation functions: -```julia -# Check if a log density model is parametric -function is_parametric(model::LogDensityModel) -> Bool - ... -end -``` +1. `condition(model, conditioned_vars)` +2. `fix(model, fixed_vars)` +3. `factor(model, variables_in_the_factor)` -- **Description**: Determines if the model has a parameter space with a defined dimension. -- +`condition` and `factor` are similar, but `factor` effectively generates a sub-model. -#### Get the Dimension of a Parametric Model +## Higher-order functions: -```julia -# Get the dimension of the parameter space (only defined when is_parametric(model) is true) -function dimension(model::LogDensityModel) -> Int - ... -end -``` +1. `generated_quantities(model, sample, [, expr])` or `generated_quantities(model, sample, f, args...)` + 1. `generated_quantities` computes things from the sampling result. + 2. In `DynamicPPL`, this is the model's return value. For more flexibility, we should allow passing an expression or function. 
(Currently, users can rewrite the model definition to achieve this in `DynamicPPL`, but with limitations. We want to make this more generic.) + 3. `rand` is a special case of `generated_quantities` (when no sample is passed). +2. `predict(model, sample)` -- **Description**: Returns the dimension of the parameter space for parametric models. - -### Log Density Computations - -#### Log-Likelihood - -```julia -# Compute the log-likelihood given parameters -function loglikelihood(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 - ... -end -``` - -- **Description**: Computes the log-likelihood of the data given the model parameters. - -#### Log-Prior - -```julia -# Compute the log-prior given parameters -function logprior(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 - ... -end -``` - -- **Description**: Computes the log-prior probability of the model parameters. - -#### Log-Joint - -```julia -# Compute the log-joint density (log-likelihood + log-prior) -function logjoint(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 - return loglikelihood(model, params) + logprior(model, params) -end -``` - -- **Description**: Computes the total log density by summing the log-likelihood and log-prior. - -### Conditioning and Fixing Variables - -#### Conditioning a Model - -```julia -# Condition the model on observed data -function condition(model::LogDensityModel, data::NamedTuple) -> ConditionedModel - ... -end -``` - -- **Description**: Incorporates observed data into the model, returning a `ConditionedModel`. - -#### Checking if a Model is Conditioned - -```julia -# Check if a model is conditioned -function is_conditioned(model::LogDensityModel) -> Bool - ... -end -``` - -- **Description**: Checks whether the model has been conditioned on data. 
- -#### Fixing Variables in a Model - -```julia -# Fix certain variables in the model -function fix(model::LogDensityModel, variables::NamedTuple) -> FixedModel - ... -end -``` - -- **Description**: Fixes specific variables in the model to given values, returning a `FixedModel`. - -#### Checking if a Model has Fixed Variables - -```julia -# Check if a model has fixed variables -function is_fixed(model::LogDensityModel) -> Bool - ... -end -``` - -- **Description**: Determines if any variables in the model have been fixed. - -### Specialized Models - -#### Conditioned Model Methods - -```julia -# Log-likelihood for a conditioned model -function loglikelihood(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 - ... -end - -# Log-prior for a conditioned model -function logprior(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 - ... -end - -# Log-joint for a conditioned model -function logjoint(model::ConditionedModel, params::Union{Vector, NamedTuple, Dict}) -> Float64 - return loglikelihood(model, params) + logprior(model, params) -end -``` - -- **Description**: Overrides log density computations to account for the conditioned data. - -#### Fixed Model Methods - -```julia -# Log-likelihood for a fixed model -function loglikelihood(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64 - ... -end - -# Log-prior for a fixed model -function logprior(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64 - ... -end - -# Log-joint for a fixed model -function logjoint(model::FixedModel, data::Union{Vector, NamedTuple, Dict}) -> Float64 - return loglikelihood(model, data) + logprior(model, data) -end -``` - -- **Description**: Adjusts log density computations based on the fixed variables. 
- -### Additional Functionalities - -#### Generated Quantities - -```julia -# Compute generated quantities after fixing parameters -function generated_quantities(model::LogDensityModel, fixed_vars::NamedTuple) -> NamedTuple - ... -end -``` - -- **Description**: Computes additional expressions or functions based on the fixed model parameters. - -#### Prediction - -```julia -# Predict data based on fixed parameters -function predict(model::LogDensityModel, params::Union{Vector, NamedTuple, Dict}) -> NamedTuple - ... -end -``` - -- **Description**: Generates predictions by fixing the parameters and unconditioning the data. - -## Advantages of the Proposed Interface - -- **Flexibility**: Allows for advanced model operations like conditioning and fixing, essential for methods like Gibbs sampling. - -- **User-Centric Design**: Focuses on usability from the model user's perspective rather than the PPL implementation side. - -- **Consistency**: Maintains a uniform interface for both parametric and non-parametric models, simplifying the learning curve. - -## Usage Examples - -## Non-Parametric Models \ No newline at end of file +`generated_quantities` can be implemented by `fix`ing the model on `sample` and calling `evaluate`. +`predict` can be implemented by `uncondition`ing the model on `data`, fixing it on `sample`, and calling `evaluate`. 
From 1b0e26ca829abaa49423862735234cdcd684ca32 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 23 Oct 2024 08:05:06 +0100 Subject: [PATCH 6/6] move the discussion to github --- design_notes/logdensity_interface.md | 40 ---------------------------- 1 file changed, 40 deletions(-) delete mode 100644 design_notes/logdensity_interface.md diff --git a/design_notes/logdensity_interface.md b/design_notes/logdensity_interface.md deleted file mode 100644 index 79e3dd2..0000000 --- a/design_notes/logdensity_interface.md +++ /dev/null @@ -1,40 +0,0 @@ -# Proposal for a New LogDensity Function Interface - - - -The goal is to design a flexible, user-friendly interface for log density functions that can handle various model operations, especially in higher-order contexts like Gibbs sampling and Bayesian workflows. - -## Evaluation functions: - -1. `evaluate` - -## Query functions: - -1. `is_parametric(model)` -2. `dimension(model)` (only defined when `is_parametric(model) == true`) -3. `is_conditioned(model)` -4. `is_fixed(model)` -5. `logjoint(model, params)` -6. `loglikelihood(model, params)` -7. `logprior(model, params)` - -where `params` can be `Vector`, `NamedTuple`, `Dict`, etc. - -## Transformation functions: - -1. `condition(model, conditioned_vars)` -2. `fix(model, fixed_vars)` -3. `factor(model, variables_in_the_factor)` - -`condition` and `factor` are similar, but `factor` effectively generates a sub-model. - -## Higher-order functions: - -1. `generated_quantities(model, sample, [, expr])` or `generated_quantities(model, sample, f, args...)` - 1. `generated_quantities` computes things from the sampling result. - 2. In `DynamicPPL`, this is the model's return value. For more flexibility, we should allow passing an expression or function. (Currently, users can rewrite the model definition to achieve this in `DynamicPPL`, but with limitations. We want to make this more generic.) - 3. `rand` is a special case of `generated_quantities` (when no sample is passed). -2. 
`predict(model, sample)` - -`generated_quantities` can be implemented by `fix`ing the model on `sample` and calling `evaluate`. -`predict` can be implemented by `uncondition`ing the model on `data`, fixing it on `sample`, and calling `evaluate`.
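Finally, returning to the earlier note that samplers should not manage variable names: a hedged sketch of a Gibbs object that owns only a `range → sampler` mapping (all names invented; this is not an actual `AbstractMCMC.Gibbs` implementation):

```julia
# Gibbs object that only knows index ranges into a flat parameter vector.
struct RangeGibbs{S}
    ranges::Vector{UnitRange{Int}}   # disjoint slices of the parameter vector
    samplers::Vector{S}              # one "sampler" per slice
end

# A PPL layer (e.g. AbstractPPL/DynamicPPL) would own the name → range mapping:
varname_to_range = Dict(:μ => 1:1, :σ => 2:2, :β => 3:5)

# Here each "sampler" is just a function on its slice, standing in for a real
# conditional sampler.
gibbs = RangeGibbs(
    collect(values(varname_to_range)),
    fill(x -> x .+ 1.0, 3),
)

params = zeros(5)
for (r, s) in zip(gibbs.ranges, gibbs.samplers)
    params[r] = s(params[r])   # one block update of the Gibbs sweep
end
```

The point of the design is visible in the types: `RangeGibbs` never sees `:μ`, `:σ`, or `:β`, only the ranges the PPL layer derived from them.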