fixed converged finite-horizon case. It now only returns the converged graph/alpha.
mhahsler committed May 15, 2022
1 parent cd61f71 commit a1513af
Showing 6 changed files with 36 additions and 36 deletions.
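As a rough illustration of what this commit changes (not part of the commit itself; the Tiger example ships with pomdp and the horizon value is arbitrary):

library(pomdp)
data("Tiger")

# An arbitrary finite horizon; the Tiger problem typically converges well before it.
sol <- solve_POMDP(Tiger, horizon = 100)

sol$solution$converged      # TRUE if the solver converged before the horizon
length(sol$solution$pg)     # with this fix: only the converged policy graph is kept
length(sol$solution$alpha)  # matching single set of alpha vectors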
1 change: 1 addition & 0 deletions NEWS.md
@@ -6,6 +6,7 @@
* sample_belief_space() gained method 'trajectories'.
* simulate_POMDP(): now supports epsilon-greedy policies.
* observation_matrix() et al. functions are now created with a separate function ending in _function.
* fixed converged finite-horizon case. It now only returns the converged graph/alpha.
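A minimal sketch of the two feature entries above (not from the commit; Tiger and the parameter values are arbitrary, and the `epsilon` argument name for simulate_POMDP() is an assumption):

library(pomdp)
data("Tiger")

# sample_belief_space(): new method 'trajectories'
b <- sample_belief_space(Tiger, n = 10, method = "trajectories")

# simulate_POMDP(): epsilon-greedy simulation of a solved policy
# (epsilon is assumed to be the probability of taking a random action)
sol <- solve_POMDP(Tiger)
simulate_POMDP(sol, n = 100, epsilon = 0.1)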

# pomdp 1.0.1 (03/25/2022)

35 changes: 13 additions & 22 deletions R/POMDP.R
@@ -455,36 +455,28 @@ print.POMDP <- function(x, ...) {
stop("x needs to be a solved POMDP. Use solve_POMDP() first.")
}

.timedependent_POMDP <- function(x) {
.timedependent_POMDP <- function(x)
!is.null(x$horizon) && length(x$horizon) > 1L
}


# get pg and alpha for an epoch
.get_pg_index <- function(model, epoch) {
.solved_POMDP(model)
#.solved_POMDP(model)

if (epoch < 1)
stop("Epoch has to be >= 1")
epoch <- as.integer(epoch)
if(epoch < 1L) stop("Epoch has to be >= 1")

h <- model$horizon
l <- length(model$solution$pg)
### (converged) infinite horizon POMDPs. We ignore epoch.
if (length(model$solution$pg) == 1L) return(1L)

if (epoch > h)
stop("POMDP model was only solved for ", h, " epochs!")
### regular epoch for finite/infinite horizon case
if (epoch <= length(model$solution$pg)) return(epoch)

### (converged) infinite horizon POMDPs
if (is.infinite(h))
epoch <- 1L
if (epoch > sum(model$horizon))
stop("POMDP model was only solved for ", sum(model$horizon), " epochs!")

### converged finite horizon model
else {
if (epoch <= h - l)
epoch <- 1L
else
epoch <- epoch - (h - l)
}

epoch
### converged finite-horizon case: return the last (i.e., converged) epoch
return(length(model$solution$pg))
}
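To trace the rewritten branches, here is a hypothetical mock (not package code): a model with a total horizon of 10 whose solution kept 3 policy-graph entries. .get_pg_index() is internal, so outside the package namespace it would have to be reached via pomdp:::.get_pg_index().

# Hypothetical mock object: 3 stored policy-graph entries, total horizon 10.
mock <- list(horizon = 10, solution = list(pg = vector("list", 3L)))

.get_pg_index(mock, 2)     # 2  -- regular epoch, within the stored graphs
.get_pg_index(mock, 7)     # 3  -- past the stored graphs: reuse the converged one
# .get_pg_index(mock, 11)  # error: model was only solved for 10 epochs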

.get_pg <-
@@ -494,7 +486,6 @@ print.POMDP <- function(x, ...) {
function(model, epoch)
model$solution$alpha[[.get_pg_index(model, epoch)]]


#' @rdname POMDP
#' @export
O_ <-
4 changes: 3 additions & 1 deletion R/plot_policy_graph.R
@@ -222,6 +222,8 @@ policy_graph_unconverged <- function(x, belief = NULL, show_belief = TRUE, col =
pg[["node"]] <- paste0(pg[["epoch"]], "-", pg[["node"]])
for(o in observations) {
pg[[o]] <- paste0(pg[["epoch"]] + 1L, "-", pg[[o]])

## these should be NA. Make sure they are
pg[[o]][pg[["epoch"]] == epochs] <- NA
}

@@ -231,7 +233,7 @@ policy_graph_unconverged <- function(x, belief = NULL, show_belief = TRUE, col =
else
initial_pg_node <- reward_node_action(x, belief = belief)$pg_node

## remove unused nodes
## remove unreached nodes
used <- paste0("1-", initial_pg_node)
for(i in seq(epochs)) {
used <- append(used, unlist(pg[pg[["epoch"]] == i & pg[["node"]] %in% used, observations]))
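A quick way to exercise this code path (a sketch, not from the commit; horizon 3 is arbitrary and should yield an unconverged, epoch-dependent policy graph):

library(pomdp)
data("Tiger")

sol3 <- solve_POMDP(Tiger, horizon = 3)
plot_policy_graph(sol3)   # last-epoch transitions are NA; unreached nodes are dropped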
5 changes: 0 additions & 5 deletions R/read_write_pomdp_solve.R
@@ -149,11 +149,6 @@
)
pg <- pg + 1 #index has to start from 1 not 0

### FIXME: I am not sure we need this now
#if (dim(pg)[2]==1 ) {
# pg <- t(pg)
#}

# renaming the columns and actions
colnames(pg) <-
c("node", "action", as.character(model$observations))
22 changes: 16 additions & 6 deletions R/solve_POMDP.R
@@ -116,8 +116,9 @@
#' - `converged` did the solution converge?
#' - `initial_belief` used initial beliefs.
#' - `total_expected_reward` reward from the initial beliefs.
#' - `pg`, `initial_pg_node` a list representing the policy graph. A converged solution has
#' only a single list elements.
#' - `pg`, `initial_pg_node` a list representing the policy graph. The epochs are
#' the list entries. A converged infinite-horizon solution has
#' only a single list element. Finite-horizon solutions may converge early, resulting in a shorter list.
#' - `belief_states` used belief states.
#' - `alpha` value function as hyperplanes representing the nodes in the policy graph.
#' - `policy` the policy.
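A sketch of how these documented components are typically accessed (Tiger and the horizon are arbitrary; the field names are the ones listed above):

library(pomdp)
data("Tiger")

sol <- solve_POMDP(Tiger, horizon = 5)

sol$solution$total_expected_reward   # reward from the initial belief
sol$solution$initial_pg_node         # starting node in the policy graph
sol$solution$pg[[1]]                 # policy graph entry for epoch 1
sol$solution$alpha[[1]]              # alpha vectors (value function) for epoch 1
sol$solution$policy                  # the policy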
@@ -498,11 +499,23 @@ solve_POMDP <- function(model,
cat("Convergence: Finite-horizon POMDP converged early at epoch:",
i - 1,
"\n")
converged <- i - 1
converged <- TRUE

# we only need to keep the first pg element with the graph
pg <- tail(pg, n = 1L)
alpha <- tail(alpha, n = 1L)

break
}
}

## make transitions in last epoch NA for non-converged solutions
if (!converged)
pg[[1L]][, as.character(model$observations)] <- NA

## order by epoch
alpha <- rev(alpha)
pg <- rev(pg)

if (method == "grid" &&
!converged &&
@@ -511,9 +524,6 @@
"The grid method for finite horizon did not converge. The value function and the calculated reward values may not be valid with negative reward in the reward matrix. Use method 'simulate_POMDP()' to estimate the reward or use solution method 'incprune'."
)

alpha <- rev(alpha)
pg <- rev(pg)

}

# read belief states if available (method: grid)
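A sketch checking the second part of this change (not from the commit; horizon 3 is arbitrary): for a solution that does not converge within the horizon, the observation columns of the last epoch's policy graph should now be NA.

library(pomdp)
data("Tiger")

sol3 <- solve_POMDP(Tiger, horizon = 3)
if (!sol3$solution$converged)
  print(sol3$solution$pg[[3]])   # observation columns should be NA in the final epoch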
5 changes: 3 additions & 2 deletions man/solve_POMDP.Rd
