Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streamline internal de/conditioning interface #776

Merged
merged 8 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.33.0"
version = "0.33.1"
sunxd3 marked this conversation as resolved.
Show resolved Hide resolved

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 2 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ DynamicPPL.LogDensityFunction
A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref).

```@docs
|(::Model, ::Any)
|(::Model, ::Union{Tuple,NamedTuple,AbstractDict{<:VarName}})
condition
DynamicPPL.conditioned
```
Expand Down Expand Up @@ -403,6 +403,7 @@ LikelihoodContext
PriorContext
MiniBatchContext
PrefixContext
ConditionContext
```

### Samplers
Expand Down
124 changes: 50 additions & 74 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -309,20 +309,40 @@
return contextualize(model, PrefixContext{Symbol(x)}(model.context))
end

struct ConditionContext{Values,Ctx<:AbstractContext} <: AbstractContext
"""

ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext}

Model context that contains values that are to be conditioned on. The values
can either be a NamedTuple mapping symbols to values, such as `(a=1, b=2)`, or
an AbstractDict mapping varnames to values (e.g. `Dict(@varname(a) => 1,
@varname(b) => 2)`). The former is more performant, but the latter must be used
when there are varnames that cannot be represented as symbols, e.g.
`@varname(x[1])`.
"""
struct ConditionContext{
Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext
} <: AbstractContext
values::Values
context::Ctx
end

const NamedConditionContext{Names} = ConditionContext{<:NamedTuple{Names}}
const DictConditionContext = ConditionContext{<:AbstractDict}

ConditionContext(values) = ConditionContext(values, DefaultContext())

# Try to avoid nested `ConditionContext`.
# Use DefaultContext as the default base context
function ConditionContext(values::Union{NamedTuple,AbstractDict})
return ConditionContext(values, DefaultContext())
end
# Optimisation when there are no values to condition on
ConditionContext(::NamedTuple{()}, context::AbstractContext) = context

Check warning on line 338 in src/contexts.jl

View check run for this annotation

Codecov / codecov/patch

src/contexts.jl#L338

Added line #L338 was not covered by tests
# Collapse consecutive levels of `ConditionContext`. Note that this overrides
# values inside the child context, thus giving precedence to the outermost
# `ConditionContext`.
function ConditionContext(values::NamedTuple, context::NamedConditionContext)
# Note that this potentially overrides values from `context`, thus giving
# precedence to the outmost `ConditionContext`.
return ConditionContext(merge(context.values, values), childcontext(context))
end
function ConditionContext(values::AbstractDict{<:VarName}, context::DictConditionContext)
return ConditionContext(merge(context.values, values), childcontext(context))
end

Expand Down Expand Up @@ -399,43 +419,6 @@
end
end

"""
condition([context::AbstractContext,] values::NamedTuple)
condition([context::AbstractContext]; values...)

Return `ConditionContext` with `values` and `context` if `values` is non-empty,
otherwise return `context` which is [`DefaultContext`](@ref) by default.

See also: [`decondition`](@ref)
"""
AbstractPPL.condition(; values...) = condition(NamedTuple(values))
AbstractPPL.condition(values::NamedTuple) = condition(DefaultContext(), values)
function AbstractPPL.condition(value::Pair{<:VarName}, values::Pair{<:VarName}...)
return condition((value, values...))
end
function AbstractPPL.condition(values::NTuple{<:Any,<:Pair{<:VarName}})
return condition(DefaultContext(), values)
end
AbstractPPL.condition(context::AbstractContext, values::NamedTuple{()}) = context
function AbstractPPL.condition(
context::AbstractContext, values::Union{AbstractDict,NamedTuple}
)
return ConditionContext(values, context)
end
function AbstractPPL.condition(context::AbstractContext; values...)
return condition(context, NamedTuple(values))
end
function AbstractPPL.condition(
context::AbstractContext, value::Pair{<:VarName}, values::Pair{<:VarName}...
)
return condition(context, (value, values...))
end
function AbstractPPL.condition(
context::AbstractContext, values::NTuple{<:Any,Pair{<:VarName}}
)
return condition(context, Dict(values))
end

"""
decondition(context::AbstractContext, syms...)

Expand All @@ -445,41 +428,34 @@

See also: [`condition`](@ref)
"""
AbstractPPL.decondition(::IsLeaf, context, args...) = context
function AbstractPPL.decondition(::IsParent, context, args...)
return setchildcontext(context, decondition(childcontext(context), args...))
decondition_context(::IsLeaf, context, args...) = context
function decondition_context(::IsParent, context, args...)
return setchildcontext(context, decondition_context(childcontext(context), args...))

Check warning on line 433 in src/contexts.jl

View check run for this annotation

Codecov / codecov/patch

src/contexts.jl#L432-L433

Added lines #L432 - L433 were not covered by tests
end
function AbstractPPL.decondition(context, args...)
return decondition(NodeTrait(context), context, args...)
function decondition_context(context, args...)
return decondition_context(NodeTrait(context), context, args...)
end
function AbstractPPL.decondition(context::ConditionContext)
return decondition(childcontext(context))
end
function AbstractPPL.decondition(context::ConditionContext, sym)
return condition(
decondition(childcontext(context), sym), BangBang.delete!!(context.values, sym)
)
function decondition_context(context::ConditionContext)
return decondition_context(childcontext(context))
end
function AbstractPPL.decondition(context::ConditionContext, sym, syms...)
return decondition(
condition(
decondition(childcontext(context), syms...),
BangBang.delete!!(context.values, sym),
),
syms...,
)
end

function AbstractPPL.decondition(
context::NamedConditionContext, vn::VarName{sym}
) where {sym}
return condition(
decondition(childcontext(context), vn), BangBang.delete!!(context.values, sym)
)
function decondition_context(context::ConditionContext, sym, syms...)
new_values = deepcopy(context.values)
for s in (sym, syms...)
new_values = BangBang.delete!!(new_values, s)
end
return if length(new_values) == 0
# No more values left, can unwrap
decondition_context(childcontext(context), syms...)
else
ConditionContext(
new_values, decondition_context(childcontext(context), sym, syms...)
)
end
end
function AbstractPPL.decondition(context::ConditionContext, vn::VarName)
return condition(
decondition(childcontext(context), vn), BangBang.delete!!(context.values, vn)
function decondition_context(context::NamedConditionContext, vn::VarName{sym}) where {sym}
return ConditionContext(
BangBang.delete!!(context.values, sym),
decondition_context(childcontext(context), vn),
)
end

Expand Down
38 changes: 30 additions & 8 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@

See [`condition`](@ref) for more information and examples.
"""
Base.:|(model::Model, values) = condition(model, values)
Base.:|(model::Model, values::Union{Pair,Tuple,NamedTuple,AbstractDict{<:VarName}}) =
condition(model, values)

"""
condition(model::Model; values...)
Expand Down Expand Up @@ -264,11 +265,32 @@
1.0
```
"""
AbstractPPL.condition(model::Model; values...) = condition(model, NamedTuple(values))
function AbstractPPL.condition(model::Model, value, values...)
return contextualize(model, condition(model.context, value, values...))
function AbstractPPL.condition(model::Model, values...)
# Positional arguments - need to handle cases carefully
return contextualize(
model, ConditionContext(_make_conditioning_values(values...), model.context)
)
end
function AbstractPPL.condition(model::Model; values...)
# Keyword arguments -- just convert to a NamedTuple
return contextualize(model, ConditionContext(NamedTuple(values), model.context))
end

"""
_make_conditioning_values(vals...)

Convert different types of input to either a `NamedTuple` or `AbstractDict` of
conditioning values, suitable for storage in a `ConditionContext`.

This handles all the cases where `vals` is either already a NamedTuple or
AbstractDict (e.g. `model | (x=1, y=2)`), as well as if they are splatted (e.g.
`condition(model, x=1, y=2)`).
"""
_make_conditioning_values(values::Union{NamedTuple,AbstractDict}) = values
_make_conditioning_values(values::Tuple{Pair{<:VarName}}) = Dict(values)
_make_conditioning_values(v::Pair{<:Symbol}, vs::Pair{<:Symbol}...) = NamedTuple(v, vs...)

Check warning on line 291 in src/model.jl

View check run for this annotation

Codecov / codecov/patch

src/model.jl#L291

Added line #L291 was not covered by tests
_make_conditioning_values(v::Pair{<:VarName}, vs::Pair{<:VarName}...) = Dict(v, vs...)

"""
decondition(model::Model)
decondition(model::Model, variables...)
Expand Down Expand Up @@ -379,7 +401,7 @@
```
"""
function AbstractPPL.decondition(model::Model, syms...)
return contextualize(model, decondition(model.context, syms...))
return contextualize(model, decondition_context(model.context, syms...))
end

"""
Expand Down Expand Up @@ -413,7 +435,7 @@
(x = 100.0, m = 1.0)

julia> # Nested ones also work (note that `PrefixContext` does nothing to the result).
cm = condition(contextualize(m, PrefixContext{:a}(condition(m=1.0))), x=100.0);
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((m=1.0,)))), x=100.0);

julia> conditioned(cm)
(x = 100.0, m = 1.0)
Expand All @@ -425,15 +447,15 @@
a.m

julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation.
cm = condition(contextualize(m, PrefixContext{:a}(condition(var"a.m"=1.0))), x=100.0);
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((var"a.m"=1.0,)))), x=100.0);

julia> conditioned(cm).x
100.0

julia> conditioned(cm).var"a.m"
1.0

julia> keys(VarInfo(cm)) # <= no variables are sampled
julia> keys(VarInfo(cm)) # No variables are sampled
VarName[]
```
"""
Expand Down
83 changes: 83 additions & 0 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using DynamicPPL:
PointwiseLogdensityContext,
contextual_isassumption,
ConditionContext,
decondition_context,
hasconditioned,
getconditioned,
hasconditioned_nested,
Expand Down Expand Up @@ -196,6 +197,88 @@ end
@test EnzymeCore.EnzymeRules.inactive_type(typeof(context))
end

@testset "ConditionContext" begin
@testset "Nesting" begin
@testset "NamedTuple" begin
n1 = (x=1, y=2)
n2 = (x=3,)
# Values from outer context should override inner one
ctx1 = ConditionContext(n1, ConditionContext(n2))
@test ctx1.values == (x=1, y=2)
# Check that the two ConditionContexts are collapsed
@test childcontext(ctx1) isa DefaultContext
# Then test the nesting the other way round
ctx2 = ConditionContext(n2, ConditionContext(n1))
@test ctx2.values == (x=3, y=2)
@test childcontext(ctx2) isa DefaultContext
end

@testset "Dict" begin
# Same tests as NamedTuple above
d1 = Dict(@varname(x) => 1, @varname(y) => 2)
d2 = Dict(@varname(x) => 3)
ctx1 = ConditionContext(d1, ConditionContext(d2))
@test ctx1.values == Dict(@varname(x) => 1, @varname(y) => 2)
@test childcontext(ctx1) isa DefaultContext
ctx2 = ConditionContext(d2, ConditionContext(d1))
@test ctx2.values == Dict(@varname(x) => 3, @varname(y) => 2)
@test childcontext(ctx2) isa DefaultContext
end
end

@testset "decondition_context" begin
@testset "NamedTuple" begin
ctx = ConditionContext((x=1, y=2, z=3))
# Decondition all variables
@test decondition_context(ctx) isa DefaultContext
# Decondition only some variables
dctx = decondition_context(ctx, :x)
@test dctx isa ConditionContext
@test dctx.values == (y=2, z=3)
dctx = decondition_context(ctx, :y, :z)
@test dctx isa ConditionContext
@test dctx.values == (x=1,)
# Decondition all variables manually
@test decondition_context(ctx, :x, :y, :z) isa DefaultContext
end

@testset "Dict" begin
ctx = ConditionContext(
Dict(@varname(x) => 1, @varname(y) => 2, @varname(z) => 3)
)
# Decondition all variables
@test decondition_context(ctx) isa DefaultContext
# Decondition only some variables
dctx = decondition_context(ctx, @varname(x))
@test dctx isa ConditionContext
@test dctx.values == Dict(@varname(y) => 2, @varname(z) => 3)
dctx = decondition_context(ctx, @varname(y), @varname(z))
@test dctx isa ConditionContext
@test dctx.values == Dict(@varname(x) => 1)
# Decondition all variables manually
@test decondition_context(ctx, @varname(x), @varname(y), @varname(z)) isa
DefaultContext
end

@testset "Nesting" begin
ctx = ConditionContext(
(x=1, y=2), ConditionContext(Dict(@varname(a) => 3, @varname(b) => 4))
)
# Decondition an outer variable
dctx = decondition_context(ctx, :x)
@test dctx.values == (y=2,)
@test childcontext(dctx).values == Dict(@varname(a) => 3, @varname(b) => 4)
# Decondition an inner variable
dctx = decondition_context(ctx, @varname(a))
@test dctx.values == (x=1, y=2)
@test childcontext(dctx).values == Dict(@varname(b) => 4)
# Try deconditioning everything
dctx = decondition_context(ctx)
@test dctx isa DefaultContext
end
end
end

@testset "FixedContext" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
retval = model()
Expand Down
Loading
Loading