Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fix and unfix #488

Merged
merged 21 commits into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.0"
version = "0.23.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
18 changes: 17 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,23 @@ Similarly, one can specify with [`AbstractPPL.decondition`](@ref) that certain,
decondition
```

## Fixing and unfixing

We can also fix a collection of variables in a [`Model`](@ref) to certain values using [`fix`](@ref):
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

```@docs
fix
fixed
```

The difference between [`fix`](@ref) and [`condition`](@ref) is described in the docstring of [`fix`](@ref) above.

Similarly, we can "unfix" variables, i.e. return them to their original meaning, using [`unfix`](@ref)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

```@docs
unfix
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
```

## Utilities

It is possible to manually increase (or decrease) the accumulated log density from within a model function.
Expand Down Expand Up @@ -321,4 +338,3 @@ dot_tilde_assume
tilde_observe
dot_tilde_observe
```

2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ export AbstractVarInfo,
pointwise_loglikelihoods,
condition,
decondition,
fix,
unfix,
# Convenience macros
@addlogprob!,
@submodel
Expand Down
45 changes: 42 additions & 3 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,49 @@ function contextual_isassumption(context::ConditionContext, vn)
end
end

# We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`
# We might have nested contexts, e.g. `ConditionContext{.., <:PrefixContext{..., <:ConditionContext}}`
# so we defer to `childcontext` if we haven't concluded that anything yet.
return contextual_isassumption(childcontext(context), vn)
end
function contextual_isassumption(context::PrefixContext, vn)
return contextual_isassumption(childcontext(context), prefix(context, vn))
end

isfixed(expr, vn) = false
isfixed(::Union{Symbol,Expr}, vn) = :($(DynamicPPL.contextual_isfixed)(__context__, $vn))

"""
contextual_isfixed(context, vn)

Return `true` if `vn` is considered fixed by `context`.
"""
contextual_isfixed(::IsLeaf, context, vn) = false
function contextual_isfixed(::IsParent, context, vn)
return contextual_isfixed(childcontext(context), vn)
end
function contextual_isfixed(context::AbstractContext, vn)
return contextual_isfixed(NodeTrait(context), context, vn)
end
function contextual_isfixed(context::PrefixContext, vn)
return contextual_isfixed(childcontext(context), prefix(context, vn))
end
function contextual_isfixed(context::FixedContext, vn)
if has_fixed_value(context, vn)
val = get_fixed_value(context, vn)
# TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler?
if eltype(val) >: Missing && val === missing
return false
else
return true
end
end

# We might have nested contexts, e.g. `FixedContext{.., <:PrefixContext{..., <:FixedContext}}`
# so we defer to `childcontext` if we haven't concluded that anything yet.
return contextual_isfixed(childcontext(context), vn)
end


torfjelde marked this conversation as resolved.
Show resolved Hide resolved
# If we're working with, say, a `Symbol`, then we're not going to `view`.
maybe_view(x) = x
maybe_view(x::Expr) = :(@views($x))
Expand Down Expand Up @@ -341,7 +376,9 @@ function generate_tilde(left, right)
$(AbstractPPL.drop_escape(varname(left))), $dist
)
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $isassumption
if $(DynamicPPL.isfixed(left, vn))
$left = $(DynamicPPL.get_fixed_value_nested)(__context__, $vn)
elseif $isassumption
$(generate_tilde_assume(left, dist, vn))
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
Expand Down Expand Up @@ -400,7 +437,9 @@ function generate_dot_tilde(left, right)
$(AbstractPPL.drop_escape(varname(left))), $right
)
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $isassumption
if $(DynamicPPL.isfixed(left, vn))
$left .= $(DynamicPPL.get_fixed_value_nested)(__context__, $vn)
elseif $isassumption
$(generate_dot_tilde_assume(left, right, vn))
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
Expand Down
176 changes: 176 additions & 0 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,179 @@ function conditioned(context::ConditionContext)
# precedence over decendants of `context`.
return merge(context.values, conditioned(childcontext(context)))
end

struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext
values::Values
context::Ctx
end

const NamedFixedContext{Names} = FixedContext{<:NamedTuple{Names}}
const DictFixedContext = FixedContext{<:AbstractDict}

FixedContext(values) = FixedContext(values, DefaultContext())

# Try to avoid nested `FixedContext`.
function FixedContext(values::NamedTuple, context::NamedFixedContext)
# Note that this potentially overrides values from `context`, thus giving
# precedence to the outmost `FixedContext`.
return FixedContext(merge(context.values, values), childcontext(context))
end

function Base.show(io::IO, context::FixedContext)
return print(io, "FixedContext($(context.values), $(childcontext(context)))")
end

NodeTrait(::FixedContext) = IsParent()
childcontext(context::FixedContext) = context.context
setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child)

"""
has_fixed_value(context::AbstractContext, vn::VarName)

Return `true` if a fixed value for `vn` is found in `context`.
"""
has_fixed_value(context::AbstractContext, vn::VarName) = false
has_fixed_value(context::FixedContext, vn::VarName) = hasvalue(context.values, vn)
function has_fixed_value(context::FixedContext, vns::AbstractArray{<:VarName})
return all(Base.Fix1(hasvalue, context.values), vns)
end

"""
get_fixed_value(context::AbstractContext, vn::VarName)

Return the fixed value of `vn` in `context`.
"""
function get_fixed_value(context::AbstractContext, vn::VarName)
return error("context $(context) does not contain value for $vn")
end
get_fixed_value(context::FixedContext, vn::VarName) = getvalue(context.values, vn)

"""
has_fixed_value_nested(context, vn)

Return `true` if a fixed value for `vn` is found in `context` or any of its descendants.

This is contrast to [`has_fixed_value(::AbstractContext, ::VarName)`](@ref) which only checks
for `vn` in `context`, not recursively checking if `vn` is in any of its descendants.
"""
function has_fixed_value_nested(context::AbstractContext, vn)
return has_fixed_value_nested(NodeTrait(has_fixed_value_nested, context), context, vn)
end
has_fixed_value_nested(::IsLeaf, context, vn) = has_fixed_value(context, vn)
function has_fixed_value_nested(::IsParent, context, vn)
return has_fixed_value(context, vn) || has_fixed_value_nested(childcontext(context), vn)
end
function has_fixed_value_nested(context::PrefixContext, vn)
return has_fixed_value_nested(childcontext(context), prefix(context, vn))
end

"""
get_fixed_value_nested(context, vn)

Return the fixed value of the parameter corresponding to `vn` from `context` or its descendants.

This is contrast to [`get_fixed_value`](@ref) which only returns the value `vn` in `context`,
not recursively looking into its descendants.
"""
function get_fixed_value_nested(context::AbstractContext, vn)
return get_fixed_value_nested(NodeTrait(get_fixed_value_nested, context), context, vn)
end
function get_fixed_value_nested(::IsLeaf, context, vn)
return error("context $(context) does not contain value for $vn")
end
function get_fixed_value_nested(context::PrefixContext, vn)
return get_fixed_value_nested(childcontext(context), prefix(context, vn))
end
function get_fixed_value_nested(::IsParent, context, vn)
return if has_fixed_value(context, vn)
get_fixed_value(context, vn)
else
get_fixed_value_nested(childcontext(context), vn)
end
end

"""
fix([context::AbstractContext,] values::NamedTuple)
fix([context::AbstractContext]; values...)

Return `FixedContext` with `values` and `context` if `values` is non-empty,
otherwise return `context` which is [`DefaultContext`](@ref) by default.

See also: [`unfix`](@ref)
"""
fix(; values...) = fix(NamedTuple(values))
fix(values::NamedTuple) = fix(DefaultContext(), values)
function fix(value::Pair{<:VarName}, values::Pair{<:VarName}...)
return fix((value, values...))
end
function fix(values::NTuple{<:Any,<:Pair{<:VarName}})
return fix(DefaultContext(), values)
end
fix(context::AbstractContext, values::NamedTuple{()}) = context
function fix(context::AbstractContext, values::Union{AbstractDict,NamedTuple})
return FixedContext(values, context)
end
function fix(context::AbstractContext; values...)
return fix(context, NamedTuple(values))
end
function fix(context::AbstractContext, value::Pair{<:VarName}, values::Pair{<:VarName}...)
return fix(context, (value, values...))
end
function fix(context::AbstractContext, values::NTuple{<:Any,Pair{<:VarName}})
return fix(context, Dict(values))
end

"""
unfix(context::AbstractContext, syms...)

Return `context` but with `syms` no longer fixed.

Note that this recursively traverses contexts, unfixing all along the way.

See also: [`fix`](@ref)
"""
unfix(::IsLeaf, context, args...) = context
function unfix(::IsParent, context, args...)
return setchildcontext(context, unfix(childcontext(context), args...))
end
function unfix(context, args...)
return unfix(NodeTrait(context), context, args...)
end
function unfix(context::FixedContext)
return unfix(childcontext(context))
end
function unfix(context::FixedContext, sym)
return fix(unfix(childcontext(context), sym), BangBang.delete!!(context.values, sym))
end
function unfix(context::FixedContext, sym, syms...)
return unfix(
fix(unfix(childcontext(context), syms...), BangBang.delete!!(context.values, sym)),
syms...,
)
end

function unfix(context::NamedFixedContext, vn::VarName{sym}) where {sym}
return fix(unfix(childcontext(context), vn), BangBang.delete!!(context.values, sym))
end
function unfix(context::FixedContext, vn::VarName)
return fix(unfix(childcontext(context), vn), BangBang.delete!!(context.values, vn))
end

"""
fixed(context::AbstractContext)

Return the values that are fixed under `context`.

Note that this will recursively traverse the context stack and return
a merged version of the fix values.
"""
fixed(context::AbstractContext) = fixed(NodeTrait(fixed, context), context)
fixed(::IsLeaf, context) = ()
fixed(::IsParent, context) = fixed(childcontext(context))
function fixed(context::FixedContext)
# Note the order of arguments to `merge`. The behavior of the rest of DPPL
# is that the outermost `context` takes precendence, hence when resolving
# the `fixed` variables we need to ensure that `context.values` takes
# precedence over decendants of `context`.
return merge(context.values, fixed(childcontext(context)))
end
Loading