Skip to content

Commit

Permalink
Streamline internal de/conditioning interface (#776)
Browse files Browse the repository at this point in the history
* Remove `condition` type piracy

* Add tests for model conditioning syntax

* Add tests for ConditionContext/decondition_context

* Format

* Bump patch version

* Add ConditionContext docstring to docs

* Fix type annotation of | in docs

* Fix remaining bugs e.g. in nested `decondition_context`
  • Loading branch information
penelopeysm authored Jan 10, 2025
1 parent 003ff2f commit e673b69
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 84 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.33.0"
version = "0.33.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 2 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ DynamicPPL.LogDensityFunction
A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref).

```@docs
|(::Model, ::Any)
|(::Model, ::Union{Tuple,NamedTuple,AbstractDict{<:VarName}})
condition
DynamicPPL.conditioned
```
Expand Down Expand Up @@ -403,6 +403,7 @@ LikelihoodContext
PriorContext
MiniBatchContext
PrefixContext
ConditionContext
```

### Samplers
Expand Down
124 changes: 50 additions & 74 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -309,20 +309,40 @@ function prefix(model::Model, ::Val{x}) where {x}
return contextualize(model, PrefixContext{Symbol(x)}(model.context))
end

struct ConditionContext{Values,Ctx<:AbstractContext} <: AbstractContext
"""
ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext}
Model context that contains values that are to be conditioned on. The values
can either be a NamedTuple mapping symbols to values, such as `(a=1, b=2)`, or
an AbstractDict mapping varnames to values (e.g. `Dict(@varname(a) => 1,
@varname(b) => 2)`). The former is more performant, but the latter must be used
when there are varnames that cannot be represented as symbols, e.g.
`@varname(x[1])`.
"""
struct ConditionContext{
Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext
} <: AbstractContext
values::Values
context::Ctx
end

const NamedConditionContext{Names} = ConditionContext{<:NamedTuple{Names}}
const DictConditionContext = ConditionContext{<:AbstractDict}

ConditionContext(values) = ConditionContext(values, DefaultContext())

# Try to avoid nested `ConditionContext`.
# Use DefaultContext as the default base context
function ConditionContext(values::Union{NamedTuple,AbstractDict})
return ConditionContext(values, DefaultContext())
end
# Optimisation when there are no values to condition on
ConditionContext(::NamedTuple{()}, context::AbstractContext) = context
# Collapse consecutive levels of `ConditionContext`. Note that this overrides
# values inside the child context, thus giving precedence to the outermost
# `ConditionContext`.
function ConditionContext(values::NamedTuple, context::NamedConditionContext)
# Note that this potentially overrides values from `context`, thus giving
# precedence to the outmost `ConditionContext`.
return ConditionContext(merge(context.values, values), childcontext(context))
end
function ConditionContext(values::AbstractDict{<:VarName}, context::DictConditionContext)
return ConditionContext(merge(context.values, values), childcontext(context))
end

Expand Down Expand Up @@ -399,43 +419,6 @@ function getconditioned_nested(::IsParent, context, vn)
end
end

"""
condition([context::AbstractContext,] values::NamedTuple)
condition([context::AbstractContext]; values...)
Return `ConditionContext` with `values` and `context` if `values` is non-empty,
otherwise return `context` which is [`DefaultContext`](@ref) by default.
See also: [`decondition`](@ref)
"""
AbstractPPL.condition(; values...) = condition(NamedTuple(values))
AbstractPPL.condition(values::NamedTuple) = condition(DefaultContext(), values)
function AbstractPPL.condition(value::Pair{<:VarName}, values::Pair{<:VarName}...)
return condition((value, values...))
end
function AbstractPPL.condition(values::NTuple{<:Any,<:Pair{<:VarName}})
return condition(DefaultContext(), values)
end
AbstractPPL.condition(context::AbstractContext, values::NamedTuple{()}) = context
function AbstractPPL.condition(
context::AbstractContext, values::Union{AbstractDict,NamedTuple}
)
return ConditionContext(values, context)
end
function AbstractPPL.condition(context::AbstractContext; values...)
return condition(context, NamedTuple(values))
end
function AbstractPPL.condition(
context::AbstractContext, value::Pair{<:VarName}, values::Pair{<:VarName}...
)
return condition(context, (value, values...))
end
function AbstractPPL.condition(
context::AbstractContext, values::NTuple{<:Any,Pair{<:VarName}}
)
return condition(context, Dict(values))
end

"""
decondition(context::AbstractContext, syms...)
Expand All @@ -445,41 +428,34 @@ Note that this recursively traverses contexts, deconditioning all along the way.
See also: [`condition`](@ref)
"""
AbstractPPL.decondition(::IsLeaf, context, args...) = context
function AbstractPPL.decondition(::IsParent, context, args...)
return setchildcontext(context, decondition(childcontext(context), args...))
decondition_context(::IsLeaf, context, args...) = context
function decondition_context(::IsParent, context, args...)
return setchildcontext(context, decondition_context(childcontext(context), args...))
end
function AbstractPPL.decondition(context, args...)
return decondition(NodeTrait(context), context, args...)
function decondition_context(context, args...)
return decondition_context(NodeTrait(context), context, args...)
end
function AbstractPPL.decondition(context::ConditionContext)
return decondition(childcontext(context))
end
function AbstractPPL.decondition(context::ConditionContext, sym)
return condition(
decondition(childcontext(context), sym), BangBang.delete!!(context.values, sym)
)
function decondition_context(context::ConditionContext)
return decondition_context(childcontext(context))
end
function AbstractPPL.decondition(context::ConditionContext, sym, syms...)
return decondition(
condition(
decondition(childcontext(context), syms...),
BangBang.delete!!(context.values, sym),
),
syms...,
)
end

function AbstractPPL.decondition(
context::NamedConditionContext, vn::VarName{sym}
) where {sym}
return condition(
decondition(childcontext(context), vn), BangBang.delete!!(context.values, sym)
)
function decondition_context(context::ConditionContext, sym, syms...)
new_values = deepcopy(context.values)
for s in (sym, syms...)
new_values = BangBang.delete!!(new_values, s)
end
return if length(new_values) == 0
# No more values left, can unwrap
decondition_context(childcontext(context), syms...)
else
ConditionContext(
new_values, decondition_context(childcontext(context), sym, syms...)
)
end
end
function AbstractPPL.decondition(context::ConditionContext, vn::VarName)
return condition(
decondition(childcontext(context), vn), BangBang.delete!!(context.values, vn)
function decondition_context(context::NamedConditionContext, vn::VarName{sym}) where {sym}
return ConditionContext(
BangBang.delete!!(context.values, sym),
decondition_context(childcontext(context), vn),
)
end

Expand Down
38 changes: 30 additions & 8 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ Return a `Model` which now treats variables on the right-hand side as observatio
See [`condition`](@ref) for more information and examples.
"""
Base.:|(model::Model, values) = condition(model, values)
Base.:|(model::Model, values::Union{Pair,Tuple,NamedTuple,AbstractDict{<:VarName}}) =
condition(model, values)

"""
condition(model::Model; values...)
Expand Down Expand Up @@ -264,11 +265,32 @@ julia> conditioned_model_dict()
1.0
```
"""
AbstractPPL.condition(model::Model; values...) = condition(model, NamedTuple(values))
function AbstractPPL.condition(model::Model, value, values...)
return contextualize(model, condition(model.context, value, values...))
function AbstractPPL.condition(model::Model, values...)
# Positional arguments - need to handle cases carefully
return contextualize(
model, ConditionContext(_make_conditioning_values(values...), model.context)
)
end
function AbstractPPL.condition(model::Model; values...)
# Keyword arguments -- just convert to a NamedTuple
return contextualize(model, ConditionContext(NamedTuple(values), model.context))
end

"""
_make_conditioning_values(vals...)
Convert different types of input to either a `NamedTuple` or `AbstractDict` of
conditioning values, suitable for storage in a `ConditionContext`.
This handles all the cases where `vals` is either already a NamedTuple or
AbstractDict (e.g. `model | (x=1, y=2)`), as well as if they are splatted (e.g.
`condition(model, x=1, y=2)`).
"""
_make_conditioning_values(values::Union{NamedTuple,AbstractDict}) = values
_make_conditioning_values(values::Tuple{Pair{<:VarName}}) = Dict(values)
_make_conditioning_values(v::Pair{<:Symbol}, vs::Pair{<:Symbol}...) = NamedTuple(v, vs...)
_make_conditioning_values(v::Pair{<:VarName}, vs::Pair{<:VarName}...) = Dict(v, vs...)

"""
decondition(model::Model)
decondition(model::Model, variables...)
Expand Down Expand Up @@ -379,7 +401,7 @@ true
```
"""
function AbstractPPL.decondition(model::Model, syms...)
return contextualize(model, decondition(model.context, syms...))
return contextualize(model, decondition_context(model.context, syms...))
end

"""
Expand Down Expand Up @@ -413,7 +435,7 @@ julia> # Returns all the variables we have conditioned on + their values.
(x = 100.0, m = 1.0)
julia> # Nested ones also work (note that `PrefixContext` does nothing to the result).
cm = condition(contextualize(m, PrefixContext{:a}(condition(m=1.0))), x=100.0);
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((m=1.0,)))), x=100.0);
julia> conditioned(cm)
(x = 100.0, m = 1.0)
Expand All @@ -425,15 +447,15 @@ julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed
a.m
julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation.
cm = condition(contextualize(m, PrefixContext{:a}(condition(var"a.m"=1.0))), x=100.0);
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((var"a.m"=1.0,)))), x=100.0);
julia> conditioned(cm).x
100.0
julia> conditioned(cm).var"a.m"
1.0
julia> keys(VarInfo(cm)) # <= no variables are sampled
julia> keys(VarInfo(cm)) # No variables are sampled
VarName[]
```
"""
Expand Down
83 changes: 83 additions & 0 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using DynamicPPL:
PointwiseLogdensityContext,
contextual_isassumption,
ConditionContext,
decondition_context,
hasconditioned,
getconditioned,
hasconditioned_nested,
Expand Down Expand Up @@ -196,6 +197,88 @@ end
@test EnzymeCore.EnzymeRules.inactive_type(typeof(context))
end

@testset "ConditionContext" begin
@testset "Nesting" begin
@testset "NamedTuple" begin
n1 = (x=1, y=2)
n2 = (x=3,)
# Values from outer context should override inner one
ctx1 = ConditionContext(n1, ConditionContext(n2))
@test ctx1.values == (x=1, y=2)
# Check that the two ConditionContexts are collapsed
@test childcontext(ctx1) isa DefaultContext
# Then test the nesting the other way round
ctx2 = ConditionContext(n2, ConditionContext(n1))
@test ctx2.values == (x=3, y=2)
@test childcontext(ctx2) isa DefaultContext
end

@testset "Dict" begin
# Same tests as NamedTuple above
d1 = Dict(@varname(x) => 1, @varname(y) => 2)
d2 = Dict(@varname(x) => 3)
ctx1 = ConditionContext(d1, ConditionContext(d2))
@test ctx1.values == Dict(@varname(x) => 1, @varname(y) => 2)
@test childcontext(ctx1) isa DefaultContext
ctx2 = ConditionContext(d2, ConditionContext(d1))
@test ctx2.values == Dict(@varname(x) => 3, @varname(y) => 2)
@test childcontext(ctx2) isa DefaultContext
end
end

@testset "decondition_context" begin
@testset "NamedTuple" begin
ctx = ConditionContext((x=1, y=2, z=3))
# Decondition all variables
@test decondition_context(ctx) isa DefaultContext
# Decondition only some variables
dctx = decondition_context(ctx, :x)
@test dctx isa ConditionContext
@test dctx.values == (y=2, z=3)
dctx = decondition_context(ctx, :y, :z)
@test dctx isa ConditionContext
@test dctx.values == (x=1,)
# Decondition all variables manually
@test decondition_context(ctx, :x, :y, :z) isa DefaultContext
end

@testset "Dict" begin
ctx = ConditionContext(
Dict(@varname(x) => 1, @varname(y) => 2, @varname(z) => 3)
)
# Decondition all variables
@test decondition_context(ctx) isa DefaultContext
# Decondition only some variables
dctx = decondition_context(ctx, @varname(x))
@test dctx isa ConditionContext
@test dctx.values == Dict(@varname(y) => 2, @varname(z) => 3)
dctx = decondition_context(ctx, @varname(y), @varname(z))
@test dctx isa ConditionContext
@test dctx.values == Dict(@varname(x) => 1)
# Decondition all variables manually
@test decondition_context(ctx, @varname(x), @varname(y), @varname(z)) isa
DefaultContext
end

@testset "Nesting" begin
ctx = ConditionContext(
(x=1, y=2), ConditionContext(Dict(@varname(a) => 3, @varname(b) => 4))
)
# Decondition an outer variable
dctx = decondition_context(ctx, :x)
@test dctx.values == (y=2,)
@test childcontext(dctx).values == Dict(@varname(a) => 3, @varname(b) => 4)
# Decondition an inner variable
dctx = decondition_context(ctx, @varname(a))
@test dctx.values == (x=1, y=2)
@test childcontext(dctx).values == Dict(@varname(b) => 4)
# Try deconditioning everything
dctx = decondition_context(ctx)
@test dctx isa DefaultContext
end
end
end

@testset "FixedContext" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
retval = model()
Expand Down
Loading

2 comments on commit e673b69

@penelopeysm
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

  • Reworked internals of condition and decondition. There are no changes to the public-facing API, but internally you can no longer use condition and decondition on an AbstractContext, you can only use it on a DynamicPPL.Model. If you want to modify a context, use ConditionContext and decondition_context.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/122758

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.33.1 -m "<description of version>" e673b69210e85b60199f2ccd8165226cb03cf040
git push origin v0.33.1

Please sign in to comment.