Skip to content

Commit

Permalink
Merge pull request #13 from jmuchovej/add-pretty-printing
Browse files Browse the repository at this point in the history
Add "pretty printing"
  • Loading branch information
zsunberg authored Jun 10, 2023
2 parents 4766f6e + 6d1e1f6 commit c9e7055
Show file tree
Hide file tree
Showing 13 changed files with 432 additions and 157 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
*.jl.cov
*.jl.*.cov
*.jl.mem
Manifest.toml
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

[compat]
POMDPModels = "0.4"
POMDPXFiles = "0.2"
POMDPTools = "0.1.4"
POMDPXFiles = "0.2"
POMDPs = "0.9"
Reexport = "0.2, 1"
julia = "1"

[extras]
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[targets]
test = ["Downloads", "Test"]
test = ["Downloads", "Test", "SHA"]
12 changes: 6 additions & 6 deletions src/POMDPFiles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ using Reexport
using POMDPs
using POMDPTools
using Printf
using POMDPModels: TabularPOMDP

@reexport using POMDPXFiles # for POMDPAlphas

import POMDPs: action, value

export
POMDPFile,

read_alpha
export read_alpha, read_pomdp
include("reader.jl")

include("read.jl")
include("write.jl")
export numericprint, symbolicprint
include("writer.jl")

end # module
23 changes: 6 additions & 17 deletions src/read.jl → src/reader.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using POMDPModels: TabularPOMDP

const REGEX_FLOATING_POINT = r"[-+]?[0-9]*\.?[0-9]+"

"""
Expand Down Expand Up @@ -40,17 +38,15 @@ const REGEX_FLOATING_POINT = r"[-+]?[0-9]*\.?[0-9]+"
the best action to take for that belief state given the value function.
"""
function read_alpha(filename::AbstractString)

@assert isfile(filename) "filename $(filename) does not exist"

lines = readlines(open(filename))
lines = open(readlines, filename)

alpha_vector_line_indeces = Int[]
vector_length = -1

for i in 1:length(lines)

matches = collect((m.match for m = eachmatch(REGEX_FLOATING_POINT, lines[i])))

if length(matches) > 1
push!(alpha_vector_line_indeces, i)
@assert occursin(r"^(\d)*$", lines[i-1]) "previous line must contain an action index"
Expand Down Expand Up @@ -86,12 +82,8 @@ function read_alpha(filename::AbstractString)
end

function read_pomdp(filename::AbstractString)

lines = open(readlines, filename)

alpha_vector_line_indeces = Int[]
vector_length = -1

discount = 0
num_states = 0
num_actions = 0
Expand All @@ -101,8 +93,6 @@ function read_pomdp(filename::AbstractString)
actions = 0
observations = 0

all_indices = ':'

T_lines = Vector{Int64}()
O_lines = Vector{Int64}()
R_lines = Vector{Int64}()
Expand Down Expand Up @@ -166,7 +156,7 @@ function read_pomdp(filename::AbstractString)
ind2 = 0
ind3 = 0

if length(T_lines) > 0
if length(T_lines) > 0
if length(findall(x->x==':', lines[T_lines[1]])) == 3
for t in T_lines
l = replace(lines[t], ':'=>' ')
Expand Down Expand Up @@ -226,7 +216,7 @@ function read_pomdp(filename::AbstractString)
end
end

if length(O_lines) > 0
if length(O_lines) > 0
if length(findall(x->x==':', lines[O_lines[1]])) == 3
for t in O_lines
l = replace(lines[t], ':'=>' ')
Expand Down Expand Up @@ -342,5 +332,4 @@ function read_pomdp(filename::AbstractString)

m = TabularPOMDP(T, R, O, discount)
return m

end
119 changes: 0 additions & 119 deletions src/write.jl

This file was deleted.

138 changes: 138 additions & 0 deletions src/writer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@

"""
Writes out the alpha vectors in the `.alpha` file format
"""
function Base.write(io::IO, alphas::POMDPAlphas)
alphavectors = alphas.alpha_vectors
for aidx in eachindex(alphas.alpha_actions)
println(io, alphas.alpha_actions[aidx])
vector_as_str = join([@sprintf "%.25f" v for v=alphavectors[aidx]], " ")
println(io, vector_as_str)
end
println(io)
end

"""
Write out a `.pomdp` file using the POMDPs.jl interface
Specification: http://cs.brown.edu/research/ai/pomdp/examples/pomdp-file-spec.html
A more recent version of the spec: https://pomdp.org/code/pomdp-file-spec.html
"""
function Base.write(io::IO, pomdp::POMDP; pretty=false)
pretty ? symbolicprint(io, pomdp) : numericprint(io, pomdp)
end

function numericprint(filename::String, pomdp::POMDP)
file = open(filename, "w")
numericprint(file, pomdp)
close(file)
end

function numericprint(io::IO, pomdp::POMDP)
_states = ordered_states(pomdp)
_actions = ordered_actions(pomdp)
_observations = ordered_observations(pomdp)

println(io, "discount: $(discount(pomdp))")
println(io, "values: reward")
println(io, "states: $(length(_states))")
println(io, "actions: $(length(_actions))")
println(io, "observations: $(length(_observations))")
println(io)

for a=_actions
aidx = actionindex(pomdp, a) -1
println(io, "T: $(aidx)")
for s=_states
T = transition(pomdp, s, a)
Trow = join([pdf(T, sp) for sp=_states], " ")
println(io, Trow)
end
println(io)
end

for a=_actions
aidx = actionindex(pomdp, a) -1
println(io, "O: $(aidx)")
for s=_states
O = observation(pomdp, a, s)
Orow = join([pdf(O, o) for o=_observations], " ")
println(io, Orow)
end
print(io, "\n")
end

for a=_actions, s=_states, sp=_states, o=_observations
aidx = actionindex(pomdp, a) - 1
sidx = stateindex(pomdp, s) - 1
spidx = stateindex(pomdp, sp) - 1
oidx = obsindex(pomdp, o) - 1
r = reward(pomdp, s, a, sp, o)
println(io, "R: $(aidx) : $(sidx) : $(spidx) : $(oidx) $(r)")
end

println(io)
end

function normalize(s)
s = string(s)
clean = replace(s, r"[^a-zA-Z0-9]" => "_")
return replace(clean, r"_+" => "_")
end

function symbolicprint(
filename::String, pomdp::POMDP;
sname::Function=normalize, aname::Function=normalize, oname::Function=normalize
)
file = open(filename, "w")
symbolicprint(file, pomdp; sname=sname, aname=aname, oname=oname)
close(file)
end

function symbolicprint(
io::IO, pomdp::POMDP;
sname::Function=normalize, aname::Function=normalize, oname::Function=normalize
)
_states = ordered_states(pomdp)
_actions = ordered_actions(pomdp)
_observations = ordered_observations(pomdp)

println(io, "discount: $(discount(pomdp))")
println(io, "values: reward")
println(io, "states: $(join(map(sname, _states), " "))")
println(io, "actions: $(join(map(aname, _actions), " "))")
println(io, "observations: $(join(map(oname, _observations), " "))")
println(io)

println(io, "# -------------------------------------------------------------------")
println(io, "# TRANSITIONS")
println(io, "T: * : * : * 0.0")
for a=_actions, s=_states, sp=_states
T = transition(pomdp, s, a)
if pdf(T, sp) > 0.
println(io, "T: $(aname(a)) : $(sname(s)) : $(sname(sp)) : $(pdf(T, sp))")
end
end
println(io)

println(io, "# -------------------------------------------------------------------")
println(io, "# OBSERVATIONS")
println(io, "O: * : * : * 0.0")
for a=_actions, sp=_states, o=_observations
O = observation(pomdp, a, sp)
if pdf(O, o) > 0.
println(io, "O: $(aname(a)) : $(sname(sp)) : $(oname(o)) $(pdf(O, o))")
end
end
println(io)

println(io, "# -------------------------------------------------------------------")
println(io, "# REWARDS")
println(io, "R: * : * : * : * 0.0")
for a=_actions, s=_states, sp=_states, o=_observations
r = reward(pomdp, s, a, sp, o)
if r != 0
println(io, "R: $(aname(a)) : $(sname(s)) : $(sname(sp)) : $(oname(o)) $(r)")
end
end
println(io)
end
Loading

0 comments on commit c9e7055

Please sign in to comment.