Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

link and invlink should correctly work with Selector and thus Gibbs #542

Merged
merged 12 commits into from
Oct 10, 2023
58 changes: 48 additions & 10 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -903,25 +903,44 @@ function _inner_transform!(vi::VarInfo, vn::VarName, dist, f)
end

function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model)
return _link(varinfo)
return _link(varinfo, spl)
end

function _link(varinfo::UntypedVarInfo)
function _link(varinfo::UntypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
return VarInfo(
_link_metadata!(varinfo, varinfo.metadata),
_link_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))),
Base.Ref(getlogp(varinfo)),
Ref(get_num_produce(varinfo)),
)
end

function _link(varinfo::TypedVarInfo)
function _link(varinfo::TypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
md = map(Base.Fix1(_link_metadata!, varinfo), varinfo.metadata)
md = _link_metadata!(varinfo, varinfo.metadata, Val(getspace(spl)))
# TODO: Update logp, etc.
return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))
end

@generated function _link_metadata!(
varinfo::VarInfo,
metadata::NamedTuple{names},
::Val{space}
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
) where {names,space}
vals = Expr(:tuple)
for f in names
if inspace(f, space) || length(space) == 0
push!(
expr.args,
:(_link_metadata!(varinfo, metadata.$f))
)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
else
push!(vals.args, :(metadata.$f))
end
end

return :(NamedTuple{$names}($vals))
end
function _link_metadata!(varinfo::VarInfo, metadata::Metadata)
vns = metadata.vns

Expand Down Expand Up @@ -972,25 +991,44 @@ end
function invlink(
::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model
)
return _invlink(varinfo)
return _invlink(varinfo, spl)
end

function _invlink(varinfo::UntypedVarInfo)
function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
return VarInfo(
_invlink_metadata!(varinfo, varinfo.metadata),
_invlink_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))),
Base.Ref(getlogp(varinfo)),
Ref(get_num_produce(varinfo)),
)
end

function _invlink(varinfo::TypedVarInfo)
function _invlink(varinfo::TypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
md = map(Base.Fix1(_invlink_metadata!, varinfo), varinfo.metadata)
md = _invlink_metadata!(varinfo, varinfo.metadata, Val(getspace(spl)))
# TODO: Update logp, etc.
return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))
end

@generated function _invlink_metadata!(
varinfo::VarInfo,
metadata::NamedTuple{names},
::Val{space}
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
) where {names,space}
vals = Expr(:tuple)
for f in names
if inspace(f, space) || length(space) == 0
Copy link
Member

@yebai yebai Oct 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if inspace(f, space) || length(space) == 0
# we select all variables in `varinfo` if `space = nothing`,
if inspace(f, space) || length(space) == 0

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

space is never nothing though. At "best" it's an empty tuple. Remember, space !== vns. The scenario with vns === nothing only comes into play in the next call.

push!(
expr.args,
:(_invlink_metadata!(varinfo, metadata.$f))
)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
else
push!(vals.args, :(metadata.$f))
end
end

return :(NamedTuple{$names}($vals))
end
function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata)
vns = metadata.vns

Expand Down