Skip to content

Commit

Permalink
be more permissive with what gets inserted into categorical vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Oct 24, 2023
1 parent f46cc4d commit 7408b96
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 23 deletions.
6 changes: 3 additions & 3 deletions src/categorical_vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ mutable struct CategoricalVector{T}
items::Vector{T}
cdf::Vector{Float64}

CategoricalVector{T}(item::T, weight::Float64) where T = new(T[item], Float64[weight])
CategoricalVector{T}(item::T, weight) where T = new(T[item], Float64[weight])
end

CategoricalVector(item::T, weight::Float64) where T = CategoricalVector{T}(item, weight)
CategoricalVector(item::T, weight) where T = CategoricalVector{T}(item, weight)

n_items(d::CategoricalVector) = length(d.items)

function insert!(c::CategoricalVector{T}, item::T, weight::Float64) where T
function insert!(c::CategoricalVector, item, weight)
push!(c.items, item)
push!(c.cdf, c.cdf[end]+weight)
end
Expand Down
3 changes: 0 additions & 3 deletions test/categorical_tree.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
using POMCPOW
using Base.Test

t = CategoricalTree(1, 1.0)
@test POMCPOW.nleaves(t) == 1
insert!(t, 2, 3.0)
Expand Down
32 changes: 16 additions & 16 deletions test/categorical_vector.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
using POMCPOW
using Base.Test

t = CategoricalVector(1, 1.0)
insert!(t, 2, 3.0)
rand(Base.GLOBAL_RNG, t)

results = Int[]
@time for i in 1:1000
push!(results, rand(Base.GLOBAL_RNG, t))
end
rand(Random.default_rng(), t)

@time for i in 3:1000
insert!(t,i,1.0)
end

@time for i in 1:1000
push!(results, rand(Base.GLOBAL_RNG, t))
end
# results = Int[]
# @time for i in 1:1000
# push!(results, rand(Random.default_rng(), t))
# end
#
# @time for i in 3:1000
# insert!(t,i,1.0)
# end
#
# @time for i in 1:1000
# push!(results, rand(Random.default_rng(), t))
# end

#=
using Plots
histogram(results)
gui()
=#

t2 = CategoricalVector(2, 1.0)
insert!(t2, 2.0, 3) # test that types can get converted correctly
2 changes: 1 addition & 1 deletion test/init_node_sr_belief_error.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function POMDPs.reward(m::SimplePOMDP, s::Int, a::Int, sp::Int)
end
end

function POMDPs.initialstate_distribution(m::SimplePOMDP)
function POMDPs.initialstate(m::SimplePOMDP)
return SparseCat(1:7, ones(7) ./ 7)
end

Expand Down
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ using ParticleFilters
using POMDPTools
using D3Trees

import Random

@testset "all" begin

@testset "POMDPTesting" begin
Expand Down Expand Up @@ -86,6 +88,10 @@ using D3Trees
@test actionvalues(planner, b) isa AbstractVector
end

@testset "categorical vector" begin
include("categorical_vector.jl")
end

@testset "init_node_sr_belief_error" begin
include("init_node_sr_belief_error.jl")
end;
Expand Down

0 comments on commit 7408b96

Please sign in to comment.