Skip to content

Commit

Permalink
The initial belief can now be specified for pg graph generation.
Browse files Browse the repository at this point in the history
  • Loading branch information
mhahsler committed May 14, 2022
1 parent 61c9a1b commit cd61f71
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 36 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# pomdp 1.0.1-1 (xx/xx/xxxx)

* policy_graph() can now produce policy trees for finite-horizon problems.
* policy_graph() can now produce policy trees for finite-horizon problems and the initial belief can be specified.
* simulate_POMDP(): fixed bug with not using horizon.
* reward() and reward_node_action() have now been separated.
* sample_belief_space() gained method 'trajectories'.
Expand Down
66 changes: 48 additions & 18 deletions R/plot_policy_graph.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,17 @@
#' policy tree for a finite-horizon solution.
#' uses `plot` in \pkg{igraph} with appropriate plotting options.
#'
#' The policy graph nodes represent the segments in the value function. Each
#' segment represents one or more believe states. If available, a pie chart (or the color) in each node
#' represent the average belief of the belief states
#' belonging to the node/segment. This can help with interpreting the policy graph.
#' Each policy graph node represent a segment (or part of a hyperplane) of the value function.
#' Each node represents one or more believe states. If available, a pie chart (or the color) in each node
#' represent the central belief of the belief states
#' belonging to the node (i.e., the center of the hyperplane segment).
#' This can help with interpreting the policy graph.
#'
#' For converged POMDP solution a graph is produced, for finite-horizon solution a policy tree is produced.
#' The levels of the tree and the first number in the node label represent the epochs. Many algorithms produce
#' unused policy graph nodes which are filtered to produce a clean tree structure.
#' Non-converged policies depend on the initial belief and if an initial belief is
#' specified, then different nodes will be filtered and the tree will look different.
#'
#' First, the policy in the solved POMDP is converted into an [igraph] object using `policy_graph()`.
#' Average beliefs for the graph nodes are estimated using `estimate_belief_for_node()` and then the igraph
Expand All @@ -26,10 +33,14 @@
#' @import igraph
#'
#' @param x object of class [POMDP] containing a solved and converged POMDP problem.
#' @param belief the initial belief is used to mark the initial belief state in the
#' grave of a converged solution and to identify the root node in a policy graph for a finite-horizon solution.
#' If `NULL` then the belief is taken from the model definition.
#' @param show_belief logical; estimate belief proportions? If `TRUE` then `estimate_belief_for_nodes()` is used
#' and the belief is visualized as a pie chart in each node.
#' @param legend logical; display a legend for colors used belief proportions?
#' @param engine The plotting engine to be used.
#' @param engine The plotting engine to be used. For `"visNetwork"`, `flip.y = FALSE` can be used
#' to show the root node on top.
#' @param col colors used for the states.
#' @param ... parameters are passed on to `policy_graph()`, `estimate_belief_for_nodes()` and the functions
#' they use. Also, plotting options are passed on to the plotting engine [igraph::plot.igraph()]
Expand Down Expand Up @@ -95,22 +106,28 @@
#' estimate_belief_for_nodes(sol, n = 100)
#'
#' ## policy trees for finite-horizon solutions
#' sol <- solve_POMDP(model = Tiger, horizon = 3, method = "incprune")
#' sol <- solve_POMDP(model = Tiger, horizon = 4, method = "incprune")
#'
#' policy_graph(sol)
#'
#' plot_policy_graph(sol)
#' # Note: the first number in the node id is the epoch.
#'
#' # plot the policy tree for an initial belief of 90% that the tiger is to the left
#' plot_policy_graph(sol, belief = c(0.9, 0.1))
#'
#' @export
policy_graph <- function(x, show_belief = TRUE, col = NULL, ...) {
policy_graph <- function(x, belief = NULL, show_belief = TRUE, col = NULL, ...) {
.solved_POMDP(x)
## FIXME: add initial belief!


if (!x$solution$converged || length(x$solution$pg) > 1)
return(policy_graph_unconverged(x, show_belief, col = col, ...))
policy_graph_unconverged(x, belief, show_belief = show_belief, col = col, ...)
else
policy_graph_converged(x, belief, show_belief = show_belief, col = col, ...)
}

policy_graph_converged <- function(x, belief = NULL, show_belief = TRUE, col = NULL, ...) {

# create policy graph and belief proportions (average belief for each alpha vector)
pg <- x$solution$pg[[1]]

Expand Down Expand Up @@ -147,13 +164,19 @@ policy_graph <- function(x, show_belief = TRUE, col = NULL, ...) {
l <- do.call(rbind, l)
l <- l[!is.na(l$to), ] # remove links to nowhere ('-' in pg)

# creating the initial graph
# creating graph
policy_graph <- graph.edgelist(as.matrix(l[, 1:2]))
edge.attributes(policy_graph) <- list(label = l$label)

# mark the node for the initial belief
if (is.null(belief))
initial_pg_node <- x$solution$initial_pg_node
else
initial_pg_node <- reward_node_action(x, belief = belief)$pg_node

### Note: the space helps with moving the id away from the pie cut.
init <- rep(": ", nrow(pg))
init[x$solution$initial_pg_node] <- ": initial belief"
init[initial_pg_node] <- ": initial belief"

V(policy_graph)$label <- paste0(pg$node, init, "\n", pg$action)

Expand Down Expand Up @@ -185,7 +208,7 @@ policy_graph <- function(x, show_belief = TRUE, col = NULL, ...) {
policy_graph
}

policy_graph_unconverged <- function(x, show_belief = TRUE, col = NULL, ...) {
policy_graph_unconverged <- function(x, belief = NULL, show_belief = TRUE, col = NULL, ...) {
.solved_POMDP(x)

pg <- x$solution$pg
Expand All @@ -202,8 +225,14 @@ policy_graph_unconverged <- function(x, show_belief = TRUE, col = NULL, ...) {
pg[[o]][pg[["epoch"]] == epochs] <- NA
}

# mark the node for the initial belief
if (is.null(belief))
initial_pg_node <- x$solution$initial_pg_node
else
initial_pg_node <- reward_node_action(x, belief = belief)$pg_node

## remove unused nodes
used <- paste0("1-", x$solution$initial_pg_node)
used <- paste0("1-", initial_pg_node)
for(i in seq(epochs)) {
used <- append(used, unlist(pg[pg[["epoch"]] == i & pg[["node"]] %in% used, observations]))
}
Expand Down Expand Up @@ -308,6 +337,7 @@ policy_graph_unconverged <- function(x, show_belief = TRUE, col = NULL, ...) {
#' @rdname policy_graph
#' @export
plot_policy_graph <- function(x,
belief = NULL,
show_belief = TRUE,
legend = TRUE,
engine = c("igraph", "visNetwork"),
Expand All @@ -318,15 +348,15 @@ plot_policy_graph <- function(x,
engine <- match.arg(engine)
switch(
engine,
igraph = .plot.igraph(x, show_belief, legend = legend, col = col, ...),
visNetwork = .plot.visNetwork(x, show_belief, legend = legend, col = col, ...)
igraph = .plot.igraph(x, belief, show_belief = show_belief, legend = legend, col = col, ...),
visNetwork = .plot.visNetwork(x, belief, show_belief = show_belief, legend = legend, col = col, ...)
)
}


.plot.igraph <-
function(x, show_belief, legend, col, edge.curved = NULL, ...) {
pg <- policy_graph(x, belief = show_belief, col = col, ...)
function(x, belief = NULL, show_belief, legend, col, edge.curved = NULL, ...) {
pg <- policy_graph(x, belief, show_belief = show_belief, col = col, ...)

if (is.null(edge.curved))
edge.curved <- .curve_multiple_directed(pg)
Expand Down
18 changes: 12 additions & 6 deletions R/visNetwork.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,20 @@
# Note: legend is not used right now!
.plot.visNetwork <-
function(x,
belief = TRUE,
belief = NULL,
show_belief = TRUE,
legend = NULL,
col = NULL,
smooth = list(type = "continuous"),
layout = NULL,
...) {
check_installed("visNetwork")

pg <- policy_graph(x, belief = belief, col = col)

unconverged <- !x$solution$converged || length(x$solution$pg) > 1
if (is.null(layout))
layout <- ifelse(unconverged, "layout_as_tree", "layout_nicely")

pg <- policy_graph(x, belief, show_belief = show_belief, col = col)

### add tooltip
#V(pg)$title <- paste(htmltools::tags$b(V(pg)$label)
Expand All @@ -23,11 +29,11 @@
))

### colors
if (belief) {
if (show_belief) {
# winner
#V(pg)$color <- V(pg)$pie.color[[1]][sapply(V(pg)$pie, which.max)]

# mixing in rgb spave
# mixing in rgb space
V(pg)$color <- sapply(
seq(length(V(pg))),
FUN = function(i)
Expand All @@ -41,7 +47,7 @@
# do.call(hsv, as.list(rgb2hsv(col2rgb(V(pg)$pie.color[[1]])) %*% V(pg)$pie[[i]])))
}

visNetwork::visIgraph(pg, idToLabel = FALSE, smooth = smooth, ...) %>%
visNetwork::visIgraph(pg, idToLabel = FALSE, layout = layout, smooth = smooth, ...) %>%
visNetwork::visOptions(
highlightNearest = list(enabled = TRUE, degree = 0),
nodesIdSelection = TRUE
Expand Down
30 changes: 23 additions & 7 deletions man/policy_graph.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions tests/testthat/test-solve_POMDP.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,19 @@ expect_equal(pg_horizon$action, pg_stepwise$action) # transitions do not work
context("solve_POMDP and model files")

sol <- solve_POMDP("http://www.pomdp.org/examples/1d.POMDP")
plot_policy_graph(sol, belief = FALSE)
plot_policy_graph(sol, show_belief = FALSE)
policy(sol)

sol <- solve_POMDP("http://www.pomdp.org/examples/cheese.95.POMDP")
plot_policy_graph(sol, belief = FALSE)
plot_policy_graph(sol, show_belief = FALSE)
policy(sol)

sol <- solve_POMDP("http://www.pomdp.org/examples/shuttle.95.POMDP",
parameter = list(fg_points = 10))
plot_policy_graph(sol, belief = FALSE)
plot_policy_graph(sol, show_belief = FALSE)
policy(sol)

sol <- solve_POMDP("http://www.pomdp.org/examples/stand-tiger.95.POMDP")
plot_policy_graph(sol, belief = FALSE)
plot_policy_graph(sol, show_belief = FALSE)
policy(sol)

0 comments on commit cd61f71

Please sign in to comment.