diff --git a/Project.toml b/Project.toml index 7e35f8f..c84452e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AdaOPS" uuid = "eadfb9d8-44f1-454c-a5eb-0663ee7d74a1" repo = "git@github.com:LAMDA-POMDP/AdaOPS.jl.git" -version = "0.5.1" +version = "0.5.2" [deps] BasicPOMCP = "d721219e-3fc6-5570-a8ef-e5402f47c49e" diff --git a/src/AdaOPS.jl b/src/AdaOPS.jl index 5cf55d0..0b6f8d5 100644 --- a/src/AdaOPS.jl +++ b/src/AdaOPS.jl @@ -170,7 +170,7 @@ mutable struct AdaOPSTree{S,A,O} ba::Int end -mutable struct AdaOPSPlanner{S, A, O, P<:POMDP{S,A,O}, N, B, RNG<:AbstractRNG} <: Policy +mutable struct AdaOPSPlanner{S, A, O, P<:POMDP{S,A,O}, N, B, OD, RNG<:AbstractRNG} <: Policy sol::AdaOPSSolver{N, RNG} pomdp::P bounds::B @@ -186,6 +186,7 @@ mutable struct AdaOPSPlanner{S, A, O, P<:POMDP{S,A,O}, N, B, RNG<:AbstractRNG} < obs_w::Vector{Float64} u::Vector{Float64} l::Vector{Float64} + obs_dists::Vector{OD} tree::Union{Nothing, AdaOPSTree{S,A,O}} end @@ -199,11 +200,12 @@ function AdaOPSPlanner(sol::AdaOPSSolver{N}, pomdp::POMDP{S,A,O}) where {S,A,O,N m_max = sol.m_max access_cnt = zeros_like(sol.grid) norm_w = Vector{Float64}[Vector{Float64}(undef, m_min) for i in 1:m_max] + obs_dists = Vector{typeof(observation(pomdp, first(actions(pomdp)), rand(initialstate(pomdp))))}(undef, m_max) return AdaOPSPlanner(deepcopy(sol), pomdp, bounds, discounts, rng, WeightedParticleBelief(Vector{S}(undef, m_max), ones(m_max), m_max), sizehint!(O[], m_max), Dict{O, Int}(), sizehint!(Vector{Float64}[], m_max), norm_w, access_cnt, sizehint!(Float64[], m_max), sizehint!(Float64[], m_max), sizehint!(Float64[], m_max), - nothing) + obs_dists, nothing) end solver(p::AdaOPSPlanner) = p.sol diff --git a/src/tree.jl b/src/tree.jl index 27a3421..447ec55 100644 --- a/src/tree.jl +++ b/src/tree.jl @@ -64,7 +64,7 @@ function expand!(D::AdaOPSTree, b::Int, p::AdaOPSPlanner) for a in acts empty_buffer!(p) S, O, R = propagate_particles(D, belief, a, p) - gen_packing!(D, S, O, belief, a, p) + gen_packing!(D, O, belief, p) D.ba += 1 # increase ba count n_obs = length(p.w) # number of new obs @@ -207,6 +207,7 @@ function propagate_particles(D::AdaOPSTree, belief::WeightedParticleBelief, a, p else sp, o, r = @gen(:sp, :o, :r)(p.pomdp, s, a, p.rng) Rsum += w * r + p.obs_dists[i] = observation(p.pomdp, a, sp) push!(S, sp) obs_ind = get(p.obs_ind_dict, o, 0) if obs_ind !== 0 @@ -221,16 +222,16 @@ function propagate_particles(D::AdaOPSTree, belief::WeightedParticleBelief, a, p return S, O, Rsum/weight_sum(belief) end -function gen_packing!(D::AdaOPSTree, S, O, belief::WeightedParticleBelief, a, p::AdaOPSPlanner) +function gen_packing!(D::AdaOPSTree, O, belief::WeightedParticleBelief, p::AdaOPSPlanner) sol = solver(p) - m = length(S) + m = n_particles(belief) w = weights(belief) next_obs = 1 # denote the index of the next observation branch for i in eachindex(O) w′ = resize!(D.weights[D.b+next_obs], m) o = O[i] - reweight!(w′, w, S, a, o, p.pomdp) + reweight!(w′, w, o, p.obs_dists) # check if the observation is already covered by the packing w′ .= w′ ./ sum(w′) obs_ind = in_packing(w′, p.w, sol.delta) @@ -253,13 +254,13 @@ function gen_packing!(D::AdaOPSTree, S, O, belief::WeightedParticleBelief, a, p: return nothing end -function reweight!(w′::AbstractVector{Float64}, w::AbstractVector{Float64}, S::AbstractVector, a, o, m) +function reweight!(w′::AbstractVector{Float64}, w::AbstractVector{Float64}, o, obs_dists) @inbounds for i in eachindex(w′) if w[i] == 0.0 w′[i] = 0.0 else # w′[i] = w[i] * obs_weight(m, Φ[i], a, S[i], o) - w′[i] = w[i] * pdf(observation(m, a, S[i]), o) + w′[i] = w[i] * pdf(obs_dists[i], o) end end end