From 8639f1aaba20e97b27f898806a57189594963b41 Mon Sep 17 00:00:00 2001 From: Michael Hahsler Date: Mon, 16 May 2022 16:26:35 -0500 Subject: [PATCH] added x_prob() and x_val() functions to access individual parts of the matrices. we use not internally NA to represent * in the POMDP definition. actions, states and observations are now factors in most places. --- DESCRIPTION | 1 + NAMESPACE | 7 +- NEWS.md | 4 +- R/MDP.R | 60 ++++++- R/Maze.R | 2 +- R/POMDP.R | 242 ++++++++++++++++++----------- R/plot_policy_graph.R | 4 +- R/plot_value_function.R | 4 +- R/policy.R | 4 +- R/read_write_POMDP.R | 20 ++- R/read_write_pomdp_solve.R | 2 +- R/round_stochchastic.R | 8 +- R/simulate_MDP.R | 165 ++++++++++++++++++++ R/simulate_POMDP.R | 27 ++-- R/solve_MDP.R | 3 +- R/solve_SARSOP.R | 4 +- R/transition_matrix.R | 94 +++++------ Work/data/create_Three_doors.R | 14 +- Work/data/create_Tiger.R | 1 + data/Maze.rda | Bin 744 -> 771 bytes data/Three_doors.rda | Bin 527 -> 554 bytes data/Tiger.rda | Bin 470 -> 496 bytes man/MDP.Rd | 6 + man/Maze.Rd | 2 +- man/POMDP.Rd | 26 ++-- man/simulate_MDP.Rd | 80 ++++++++++ man/simulate_POMDP.Rd | 4 +- man/solve_MDP.Rd | 4 + man/transition_matrix.Rd | 29 ++-- tests/testthat/test-solve_SARSOP.R | 3 +- 30 files changed, 587 insertions(+), 233 deletions(-) create mode 100644 R/simulate_MDP.R create mode 100644 man/simulate_MDP.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 265b57e..0e3de93 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -48,6 +48,7 @@ Collate: 'reward.R' 'round_stochchastic.R' 'sample_belief_space.R' + 'simulate_MDP.R' 'simulate_POMDP.R' 'solve_MDP.R' 'solve_POMDP.R' diff --git a/NAMESPACE b/NAMESPACE index 35b3db0..112de2d 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -11,8 +11,8 @@ export(R_) export(T_) export(approx_MDP_policy_evaluation) export(estimate_belief_for_nodes) -export(observation_function) export(observation_matrix) +export(observation_prob) export(optimal_action) export(plot_belief_space) export(plot_policy_graph) @@ -23,18 +23,19 @@ export(q_values_MDP) export(random_MDP_policy) export(read_POMDP) export(reward) -export(reward_function) export(reward_matrix) export(reward_node_action) +export(reward_val) export(round_stochastic) export(sample_belief_space) +export(simulate_MDP) export(simulate_POMDP) export(solve_MDP) export(solve_POMDP) export(solve_POMDP_parameter) export(solve_SARSOP) -export(transition_function) export(transition_matrix) +export(transition_prob) export(update_belief) export(write_POMDP) import(graphics) diff --git a/NEWS.md b/NEWS.md index 2a1b2de..07f7cbf 100644 --- a/NEWS.md +++ b/NEWS.md @@ -5,8 +5,10 @@ * reward() and reward_node_action() have now been separated. * sample_belief_space() gained method 'trajectories'. * simulate_POMDP(): supports not epsilon-greedy policies. -* observation_matrix() et al. functions are now created with a separate function ending in _function. +* added x_prob() and x_val() functions to access individual parts of the matrices. * fixed converged finite-horizon case. It now only returns the converged graph/alpha. +* we use not internally NA to represent * in the POMDP definition. +* actions, states and observations are now factors in most places. # pomdp 1.0.1 (03/25/2022) diff --git a/R/MDP.R b/R/MDP.R index 3342557..f90ca7f 100644 --- a/R/MDP.R +++ b/R/MDP.R @@ -13,6 +13,9 @@ #' #' More details on specifying the model components can be found in the documentation #' for [POMDP]. +#' +#' @family MDP +#' #' @include POMDP.R #' @param states a character vector specifying the names of the states. 
#' @param actions a character vector specifying the names of the available @@ -103,7 +106,35 @@ MDP <- function(states, } #' @export -print.MDP <- print.POMDP +print.MDP <- function(x, ...) { + writeLines(paste(paste(class(x), collapse = ", "), + "-", + x$name)) + + if (!is.null(x$discount)) + writeLines(sprintf(" Discount factor: %s", + paste(x$discount, collapse = "+"))) + + if (!is.null(x$horizon)) + writeLines(sprintf(" Horizon: %s epochs", + paste(x$horizon, collapse = " + "))) + + if (.solved_MDP(x)) + writeLines(c( + " Solved:", + sprintf(" Solution converged: %s", + x$solution$converged) + ) + ) + + writeLines(strwrap( + paste("List components:", paste(sQuote(names( + x + )), collapse = ", "), "\n"), + indent = 2, + exdent = 4 + )) +} #' @rdname MDP #' @export @@ -117,14 +148,33 @@ MDP2POMDP <- function(x) { ident_matrix <- diag(length(x$states)) dimnames(ident_matrix) <- list(x$states, x$observations) - x$observation_prob <- list('*' = ident_matrix) + x$observation_prob <- sapply(x$actions, FUN = function(x) ident_matrix, simplify = FALSE) class(x) <- c("MDP", "POMDP", "list") x } -.solved_MDP <- function(x) { +.solved_MDP <- function(x, stop = FALSE) { if (!inherits(x, "MDP")) - stop("x needs to be a POMDP object!") - if (is.null(x$solution)) + stop("x needs to be a MDP object!") + solved <- !is.null(x$solution) + if (stop && !solved) stop("x needs to be a solved MDP. Use solve_MDP() first.") + + solved +} + +## this is .get_pg_index for MDPs +.get_pol_index <- function(model, epoch) { + + epoch <- as.integer(epoch) + if(epoch < 1L) stop("Epoch has to be >= 1") + + ### (converged) infinite horizon POMDPs. We ignore epoch. + if (length(model$solution$policy) == 1L) return(1L) + + ### regular epoch for finite/infinite horizon case + if (epoch > length(model$solution$policy)) + stop("MDP model has only a policy up to epoch ", length(model$solution$policy)) + + return(epoch) } \ No newline at end of file diff --git a/R/Maze.R b/R/Maze.R index 0bdb161..387aa56 100644 --- a/R/Maze.R +++ b/R/Maze.R @@ -18,7 +18,7 @@ #' The # (state `s_5`) in the middle of the maze is an obstruction and not reachable. #' Rewards are associated with transitions. The default reward (penalty) is -0.04. #' Transitioning to + (state `s_12`) gives a reward of 1.0, transitioning to - (state `s_11`) -#' has a reward of -1.0. States `s_11` and `s_12` are terminal states. +#' has a reward of -1.0. States `s_11` and `s_12` are terminal (absorbing) states. #' #' Actions are movements (`north`, `south`, `east`, `west`). The actions are unreliable with a .8 chance #' to move in the correct direction and a 0.1 chance to instead to move in a diff --git a/R/POMDP.R b/R/POMDP.R index 2dc46f4..838709d 100644 --- a/R/POMDP.R +++ b/R/POMDP.R @@ -25,8 +25,9 @@ #' #' State names, actions and observations can be specified as strings or index numbers #' (e.g., `start.state` can be specified as the index of the state in `states`). -#' For the specification as data.frames below, `'*'` can be used to mean -#' any `start.state`, `end.state`, `action` or `observation`. +#' For the specification as data.frames below, `'*'` or `NA` can be used to mean +#' any `start.state`, `end.state`, `action` or `observation`. Note that `'*'` is internally +#' always represented as an `NA`. #' #' The specification below map to the format used by pomdp-solve #' (see \url{http://www.pomdp.org}). @@ -84,14 +85,12 @@ #' #' **Start Belief** #' -#' This belief is used to calculate the total expected cumulative reward -#' printed with the solved model. 
The function [reward()] can be +#' The initial belief state of the agent is a distribution over the states. It is used to calculate the +#' total expected cumulative reward printed with the solved model. The function [reward()] can be #' used to calculate rewards for any belief. #' #' Some methods use this belief to decide which belief states to explore (e.g., -#' the finite grid method). The default initial belief is a uniform -#' distribution over all states. No initial belief state can be used by setting -#' `start = NULL`. +#' the finite grid method). #' #' Options to specify the start belief state are: #' @@ -101,7 +100,10 @@ #'* The string `"uniform"` for a uniform #' distribution over all states. #'* An integer in the range \eqn{1} to \eqn{n} to specify the index of a single starting state. -#'* a string specifying the name of a single starting state. +#'* A string specifying the name of a single starting state. +#' +#' The default initial belief is a uniform +#' distribution over all states. #' #' **Time-dependent POMDPs** #' @@ -135,16 +137,16 @@ #' matrix specifying the terminal rewards via a terminal value function (e.g., #' the alpha component produced by solve_POMDP). A single 0 specifies that all #' terminal values are zero. -#' @param start Specifies the initial probabilities for each state (i.e., the -#' initial belief), typically as a vector or the string `'uniform'` -#' (default). This belief is used to calculate the total expected cumulative +#' @param start Specifies the initial belief state of the agent. A vector with the +#' probability for each state is supplied. Also the string `'uniform'` +#' (default) can be used. The belief is used to calculate the total expected cumulative #' reward. It is also used by some solvers. See Details section for more #' information. #' @param name a string to identify the POMDP problem. #' @param action,start.state,end.state,observation,probability,value Values #' used in the helper functions `O_()`, `R_()`, and `T_()` to #' create an entry for `observation_prob`, `reward`, or -#' `transistion_prob` above, respectively. The default value `'*"'` +#' `transition_prob` above, respectively. The default value `'*"'` #' matches any action/state/observation. #' #' @return The function returns an object of class POMDP which is list of the model specification. 
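As a minimal sketch of the data.frame interface described above (a made-up two-state model using the exported helpers T_(), O_() and R_(); not taken from this patch): the '*' wildcards written here are what the new code stores internally as NA.

library("pomdp")

toy <- POMDP(
  name         = "toy",
  states       = c("s1", "s2"),
  actions      = c("stay", "switch"),
  observations = c("o1", "o2"),
  start        = "uniform",
  discount     = 0.9,
  horizon      = Inf,
  transition_prob = rbind(
    T_("stay",   "*",  "*",  0.5),   # wildcard: any start/end state
    T_("switch", "s1", "s2", 1),
    T_("switch", "s2", "s1", 1)
  ),
  observation_prob = rbind(
    O_("*", "*", "*", 0.5)           # uninformative observations
  ),
  reward = rbind(
    R_("*",      "*",  "*",  "*", -1),  # default reward for every transition
    R_("switch", "s1", "s2", "*", 10)
  )
)

toy$transition_prob   # wildcard action/state entries are now stored as NA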
@@ -261,48 +263,73 @@ POMDP <- function(states, check_and_fix_MDP(x) } -check_func <- function(x, func, name) { - if (is.function(x)) { + + + +# make sure the definition is complete and everything is in the right order and the right factors +check_and_fix_MDP <- function(x) { + + ## TODO: fix and use check_formals + check_func <- function(x, func, name) { req_formals <- head(names(formals(func)), -1) if (!identical(names(formals(x)), req_formals)) stop(name, " function needs formal arguments: ", paste(sQuote(req_formals), collapse = ", ")) } -} - -check_df <- function(x, func, name) { - if (is.data.frame(x)) { + + ## Note: uses x (model) from the surrounding environment + check_df <- function(field, func) { req_columns <- names(formals(func)) - if (!identical(colnames(x), req_columns)) + if (is.null(colnames(field))) + colnames(field) <- req_columns + + if (!identical(colnames(field), req_columns)) stop("The ", - name, + deparse(substitute(field)), " data.frame needs columns named: ", paste(sQuote(req_columns), collapse = ", ")) + + # convert * to NA + field[field == '*'] <- NA + + for (i in grep("action", colnames(field))) { + if(is.numeric(field[[i]])) field[[i]] <- x$actions[field[[i]]] + field[[i]] <- factor(field[[i]], levels = x$actions) + } + for (i in grep("state", colnames(field))) { + if(is.numeric(field[[i]])) field[[i]] <- x$states[field[[i]]] + field[[i]] <- factor(field[[i]], levels = x$states) + } + for (i in grep("observation", colnames(field))){ + if(is.numeric(field[[i]])) field[[i]] <- x$observations[field[[i]]] + field[[i]] <- factor(field[[i]], levels = x$observations) + } + + field } -} - -check_and_fix_MDP <- function(x) { + within(x, { + if (is.numeric(states) && - length(states) == 1) + length(states) == 1L) states <- seq_len(states) states <- as.character(states) if (is.numeric(actions) && - length(actions) == 1) + length(actions) == 1L) actions <- seq_len(actions) actions <- as.character(actions) if (inherits(x, "POMDP")) { if (is.numeric(observations) && - length(observations) == 1) + length(observations) == 1L) observations <- seq_len(observations) observations <- as.character(observations) } discount <- as.numeric(discount) - if (length(discount) != 1 || discount < 0 || discount > 1) + if (length(discount) != 1L || discount < 0 || discount > 1) stop("discount has to be a single value in the range [0,1].") if (!exists("horizon")) @@ -311,11 +338,11 @@ check_and_fix_MDP <- function(x) { if (any(horizon != floor(horizon))) stop("horizon needs to be an integer.") - ## FIXME: check terminal_values + ## TODO: check terminal_values # start if (is.numeric(start) && length(start) == length(states)) { - if (zapsmall(sum(start)) != 1) + if (!sum1(start)) stop("The start probability vector does not add up to 1.") if (is.null(names(start))) names(start) <- states @@ -332,8 +359,10 @@ check_and_fix_MDP <- function(x) { } ## read_POMDP does not parse these! - ## For now, we expand functions into matrices if (!exists("problem")) { + + ## TODO: keep functions. For now we expand functions into matrices + #check_func(transition_prob, T_, "transition_prob") if (is.function(transition_prob)) transition_prob <- transition_matrix(x) @@ -345,63 +374,92 @@ check_and_fix_MDP <- function(x) { if (is.function(observation_prob)) observation_prob <- observation_matrix(x) - check_df(transition_prob, T_, "transition_prob") - check_df(reward, R_, "reward") - if (inherits(x, "POMDP")) - check_df(observation_prob, O_, "observation_prob") - - ## FIXME: check that a is actions! 
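The new check_df() above fills in missing column names from the helper's formals, turns '*' into NA, and coerces the action/state/observation columns to factors with the model's levels. A standalone sketch of that normalization on a hypothetical transition_prob data.frame (states s1/s2 and actions a1/a2 are made up, not from the patch):

df <- data.frame(action      = c("a1", "a2", "a2"),
                 start.state = c("*",  "s1", "s2"),
                 end.state   = c("*",  "s2", "s1"),
                 probability = c(0.5, 1, 1),
                 stringsAsFactors = FALSE)

df[df == "*"] <- NA                                   # '*' is stored as NA
df$action      <- factor(df$action,      levels = c("a1", "a2"))
df$start.state <- factor(df$start.state, levels = c("s1", "s2"))
df$end.state   <- factor(df$end.state,   levels = c("s1", "s2"))
str(df)                                               # factor columns with NA wildcards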
- # if we have matrices then check and add names - for (a in names(transition_prob)) { - if (is.matrix(transition_prob[[a]])) { - if (!identical(dim(transition_prob[[a]]), c(length(states), length(states)))) - stop("transition_prob matrix for action ", - a, - ": has not the right dimensions!") - if (!all(rowSums(transition_prob[[a]]) == 1)) - stop("transition_prob matrix for action ", - a, - ": rows do not add up to 1!") - dimnames(transition_prob[[a]]) <- list(states, states) + if (is.data.frame(transition_prob)) + transition_prob <- check_df(transition_prob, T_) + else { + if (is.null(names(transition_prob))) + names(transition_prob) <- actions + for (a in actions) { + if (is.null(transition_prob[[a]])) + stop("transition_prob for action ", a, " is missing!") + if (is.matrix(transition_prob[[a]])) { + if (!identical(dim(transition_prob[[a]]), c(length(states), length(states)))) + stop("transition_prob matrix for action ", + a, + ": has not the right dimensions!") + if (!sum1(transition_prob[[a]])) + stop("transition_prob matrix for action ", + a, + ": rows do not add up to 1!") + if (is.null(dimnames(transition_prob[[a]]))) + dimnames(transition_prob[[a]]) <- list(states, states) + else + transition_prob[[a]][states, states] + } } } - - for (a in names(reward)) { - for (s in names(reward[[a]])) { - if (is.matrix(reward[[a]][[s]])) { - if (!identical(dim(reward[[a]][[s]]), c(length(states), length(observations)))) - stop( - "reward matrix for action ", - a, - " and start.state ", - s, - ": has not the right dimensions!" - ) - dimnames(reward[[a]][[s]]) <- - list(states, observations) + if (is.data.frame(reward)) + reward <- check_df(reward, R_) + else { + if (is.null(names(reward))) + names(reward) <- actions + for (a in actions) { + if (is.null(reward[[a]])) + stop("reward for action ", a, " is missing!") + for (s in states) { + if (is.null(reward[[a]][[s]])) + stop("reward for action ", a, " and state ", s, " is missing!") + if (is.matrix(reward[[a]][[s]])) { + if (!identical(dim(reward[[a]][[s]]), c(length(states), length(observations)))) + stop( + "reward matrix for action ", + a, + " and start.state ", + s, + ": has not the right dimensions!" 
+ ) + if (is.null(dimnames(reward[[a]][[s]]))) + dimnames(reward[[a]][[s]]) <- + list(states, observations) + else + reward[[a]][[s]][states, observations] + } } } } if (inherits(x, "POMDP")) { - for (a in names(observation_prob)) { - if (is.matrix(observation_prob[[a]])) { - if (!identical(dim(observation_prob[[a]]), c(length(states), length(observations)))) - stop("observation_prob matrix for action ", - a, - ": has not the right dimensions!") - if (!all(rowSums(observation_prob[[a]]) == 1)) - stop("observation_prob matrix for action ", - a, - ": rows do not add up to 1!") - dimnames(observation_prob[[a]]) <- - list(states, observations) + if (is.data.frame(observation_prob)) + observation_prob <- check_df(observation_prob, O_) + else { + if (is.null(names(observation_prob))) + names(observation_prob) <- actions + for (a in actions) { + if (is.null(observation_prob[[a]])) + stop("observation_prob for action ", a, " is missing!") + if (is.matrix(observation_prob[[a]])) { + if (!identical(dim(observation_prob[[a]]), c(length(states), length(observations)))) + stop("observation_prob matrix for action ", + a, + ": has not the right dimensions!") + if (!all(rowSums(observation_prob[[a]]) == 1)) + stop("observation_prob matrix for action ", + a, + ": rows do not add up to 1!") + + if (is.null(dimnames(observation_prob[[a]]))) + dimnames(observation_prob[[a]]) <- + list(states, observations) + else + observation_prob[[a]][states, observations] + } } } } } + # cleanup if (exists("a", inherits = FALSE)) rm(a) @@ -410,9 +468,6 @@ check_and_fix_MDP <- function(x) { }) } - - - #' @export print.POMDP <- function(x, ...) { writeLines(paste(paste(class(x), collapse = ", "), @@ -427,12 +482,13 @@ print.POMDP <- function(x, ...) { writeLines(sprintf(" Horizon: %s epochs", paste(x$horizon, collapse = " + "))) - if (!is.null(x$solution)) + if (.solved_POMDP(x)) writeLines(c( - sprintf(" Solved. Solution converged: %s", + " Solved:", + sprintf(" Solution converged: %s", x$solution$converged), sprintf( - " Total expected reward (for start probabilities): %f", + " Total expected reward: %f", x$solution$total_expected_reward ) )) @@ -448,20 +504,24 @@ print.POMDP <- function(x, ...) { # check if x is a solved POMDP -.solved_POMDP <- function(x) { +.solved_POMDP <- function(x, stop = FALSE) { if (!inherits(x, "POMDP")) stop("x needs to be a POMDP object!") - if (is.null(x$solution)) + + solved <- !is.null(x$solution) + if (stop && !solved) stop("x needs to be a solved POMDP. Use solve_POMDP() first.") -} - + + solved +} + .timedependent_POMDP <- function(x) !is.null(x$horizon) && length(x$horizon) > 1L # get pg and alpha for a epoch .get_pg_index <- function(model, epoch) { - #.solved_POMDP(model) + #.solved_POMDP(model, stop = TRUE) epoch <- as.integer(epoch) if(epoch < 1L) stop("Epoch has to be >= 1") @@ -470,18 +530,16 @@ print.POMDP <- function(x, ...) 
{ if (length(model$solution$pg) == 1L) return(1L) ### regular epoch for finite/infinite horizon case - if (epoch <= length(model$solution$pg)) return(epoch) - - if (epoch > sum(model$horizon)) - stop("POMDP model was only solved for ", sum(model$horizon), " epochs!") - - ### converged finite-horizon case return the last (i.e., converged) epoch - return(length(model$solution$pg)) + if (epoch > length(model$solution$pg)) + stop("POMDP model has only solutions for ", length(model$solution$pg), " epochs!") + + return(epoch) } .get_pg <- function(model, epoch) model$solution$pg[[.get_pg_index(model, epoch)]] + .get_alpha <- function(model, epoch) model$solution$alpha[[.get_pg_index(model, epoch)]] diff --git a/R/plot_policy_graph.R b/R/plot_policy_graph.R index fd8b0c6..419d4d9 100644 --- a/R/plot_policy_graph.R +++ b/R/plot_policy_graph.R @@ -118,7 +118,7 @@ #' #' @export policy_graph <- function(x, belief = NULL, show_belief = TRUE, col = NULL, ...) { - .solved_POMDP(x) + .solved_POMDP(x, stop = TRUE) if (!x$solution$converged || length(x$solution$pg) > 1) policy_graph_unconverged(x, belief, show_belief = show_belief, col = col, ...) @@ -209,7 +209,6 @@ policy_graph_converged <- function(x, belief = NULL, show_belief = TRUE, col = N } policy_graph_unconverged <- function(x, belief = NULL, show_belief = TRUE, col = NULL, ...) { - .solved_POMDP(x) pg <- x$solution$pg observations <- x$observations @@ -345,7 +344,6 @@ plot_policy_graph <- function(x, engine = c("igraph", "visNetwork"), col = NULL, ...) { - .solved_POMDP(x) engine <- match.arg(engine) switch( diff --git a/R/plot_value_function.R b/R/plot_value_function.R index 2070273..981cae9 100644 --- a/R/plot_value_function.R +++ b/R/plot_value_function.R @@ -56,7 +56,7 @@ plot_value_function <- lty = 1, ...) { if (inherits(model, "MDP")) { - .solved_MDP(model) + .solved_MDP(model, stop = TRUE) policy <- policy(model)[[epoch]] @@ -81,7 +81,7 @@ plot_value_function <- at = 0 ) } else { - .solved_POMDP(model) + .solved_POMDP(model, stop = TRUE) if (is.character(projection)) projection <- pmatch(projection, model$states) diff --git a/R/policy.R b/R/policy.R index c2f1770..d1c17bf 100644 --- a/R/policy.R +++ b/R/policy.R @@ -41,7 +41,9 @@ policy <- function(x) x$solution$policy .policy_MDP_from_POMDP <- function(x) { pg <- x$solution$pg - bs <- x$observation_prob[['*']] + + ## all observation_probs should be the same! 
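+  ## (MDP2POMDP() fills observation_prob with one identity matrix per action,
+  ##  so all list elements are identical and the first one can stand in for '*'.)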
+ bs <- x$observation_prob[[1L]] # create a list ith epochs lapply( diff --git a/R/read_write_POMDP.R b/R/read_write_POMDP.R index a3ab08c..e5f8543 100644 --- a/R/read_write_POMDP.R +++ b/R/read_write_POMDP.R @@ -38,6 +38,9 @@ format_fixed <- function(x, digits = 7, debug = "unknown") { write_POMDP <- function(x, file, digits = 7) { if (!inherits(x, "POMDP")) stop("model needs to be a POMDP model use POMDP()!") + + x <- check_and_fix_MDP(x) + with(x, { number_of_states <- length(states) @@ -47,6 +50,7 @@ write_POMDP <- function(x, file, digits = 7) { # we only support rewards and not cost values <- "reward" + ## TODO: we currently convert function to matrix if (is.function(transition_prob)) transition_prob <- transition_matrix(x) if (is.function(observation_prob)) @@ -84,7 +88,7 @@ write_POMDP <- function(x, file, digits = 7) { if (!is.null(start)) { ## if the starting beliefs are given by enumerating the probabilities for each state if (is.numeric(start)) { - if (length(start) == length(states) && zapsmall(sum(start) - 1) == 0) { + if (length(start) == length(states) && sum1(start)) { code <- paste0(code, "start: ", @@ -138,11 +142,16 @@ write_POMDP <- function(x, file, digits = 7) { var_cols <- seq_len(ncol(x) - 1L) value_col <- ncol(x) - # fix indexing - for (j in var_cols) + # fix indexing and convert factor to character + for (j in var_cols) { if (is.numeric(x[[j]])) x[[j]] <- as.integer(x[[j]]) - 1L - + if (is.factor(x[[j]])) { + x[[j]] <- as.character(x[[j]]) + x[[j]][is.na(x[[j]])] <- "*" + } + } + # write lines for (i in 1:nrow(x)) { code <- paste0( @@ -179,7 +188,6 @@ write_POMDP <- function(x, file, digits = 7) { ### Transition Probabilities if (is.data.frame(transition_prob)) { - check_df(transition_prob, T_, "transition_prob") code <- paste0(code, format_POMDP_df(transition_prob, "T", digits)) } else{ @@ -199,7 +207,6 @@ write_POMDP <- function(x, file, digits = 7) { } ### Observation Probabilities if (is.data.frame(observation_prob)) { - check_df(observation_prob, O_, "observation_prob") code <- paste0(code, format_POMDP_df(observation_prob, "O", digits)) } else{ @@ -218,7 +225,6 @@ write_POMDP <- function(x, file, digits = 7) { ### Rewards/Costs if (is.data.frame(reward)) { - check_df(reward, R_, "reward") code <- paste0(code, format_POMDP_df(reward, "R", digits)) } else { diff --git a/R/read_write_pomdp_solve.R b/R/read_write_pomdp_solve.R index 018cfe7..7bd95e4 100644 --- a/R/read_write_pomdp_solve.R +++ b/R/read_write_pomdp_solve.R @@ -152,7 +152,7 @@ # renaming the columns and actions colnames(pg) <- c("node", "action", as.character(model$observations)) - pg[, 2] <- model$actions[pg[, 2]] + pg[, 2] <- factor(pg[, 2], levels = seq(length(model$actions)), labels = model$actions) pg } diff --git a/R/round_stochchastic.R b/R/round_stochchastic.R index 799fd26..ca15886 100644 --- a/R/round_stochchastic.R +++ b/R/round_stochchastic.R @@ -43,5 +43,9 @@ round_stochastic <- function(x, digits = 3) { r } - - +sum1 <- function(x) { + if(is.matrix(x)) + all(apply(x, MARGIN = 1, sum1)) + else + zapsmall(sum(x) - 1) == 0 +} diff --git a/R/simulate_MDP.R b/R/simulate_MDP.R new file mode 100644 index 0000000..b67d080 --- /dev/null +++ b/R/simulate_MDP.R @@ -0,0 +1,165 @@ +## TODO: Reimplement in C++ + +#' Simulate Trajectories in a MDP +#' +#' Simulate trajectories through a MDP. The start state for each +#' trajectory is randomly chosen using the specified belief. The belief is used to choose actions +#' from an epsilon-greedy policy and then update the state. 
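+#'
+#' At each epoch, with probability `epsilon` a random action is drawn;
+#' otherwise the action prescribed by the policy for the current state is
+#' used. For unsolved models `epsilon` defaults to 1 (a pure random walk
+#' through the MDP), for solved models it defaults to 0.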
+#' +#' @family MDP +#' @importFrom stats runif +#' +#' @param model a MDP model. +#' @param n number of trajectories. +#' @param start probability distribution over the states for choosing the +#' starting states for the trajectories. +#' Defaults to "uniform". +#' @param horizon number of epochs for the simulation. If `NULL` then the +#' horizon for the model is used. +#' @param visited_states logical; Should all visited states on the +#' trajectories be returned? If `FALSE` then only the final +#' state is returned. +#' @param epsilon the probability of random actions for using an epsilon-greedy policy. +#' Default for solved models is 0 and for unsolved model 1. +#' @param verbose report used parameters. +#' @return A vector with state ids (in the final epoch or all). Attributes containing action +#' counts, and rewards for each trajectory may be available. +#' @author Michael Hahsler +#' @md +#' @examples +#' data(Maze) +#' +#' # solve the POMDP for 5 epochs and no discounting +#' sol <- solve_MDP(Maze, discount = 1) +#' sol +#' policy(sol) +#' +#' ## Example 1: simulate 10 trajectories, only the final belief state is returned +#' sim <- simulate_MDP(sol, n = 10, horizon = 10, verbose = TRUE) +#' head(sim) +#' +#' # additional data is available as attributes +#' names(attributes(sim)) +#' attr(sim, "avg_reward") +#' colMeans(attr(sim, "action")) +#' +#' ## Example 2: simulate starting always in state s_1 +#' sim <- simulate_MDP(sol, n = 100, start = "s_1", horizon = 10) +#' sim +#' +#' # the average reward is an estimate of the utility in the optimal policy: +#' policy(sol)[[1]][1,] +#' +#' @export +simulate_MDP <- + function(model, + n = 100, + start = NULL, + horizon = NULL, + visited_states = FALSE, + epsilon = NULL, + verbose = FALSE) { + + start <- .translate_belief(start, model = model) + solved <- .solved_MDP(model) + + if (is.null(horizon)) + horizon <- model$horizon + if (is.null(horizon)) + stop("The horizon (number of epochs) has to be specified!") + if (is.infinite(horizon)) + stop("Simulation needs a finite simulation horizon.") + + if (is.null(epsilon)) { + if (!solved) epsilon <- 1 + else epsilon <- 0 + } + + if (!solved && epsilon != 1) + stop("epsilon has to be 1 for unsolved models.") + + disc <- model$discount + if (is.null(disc)) + disc <- 1 + + states <- as.character(model$states) + n_states <- length(states) + actions <- as.character(model$actions) + + trans_m <- transition_matrix(model) + rew_m <- reward_matrix(model) + + # for easier access + pol <- lapply(model$solution$policy, FUN = function(p) structure(p$action, names = p$state)) + + if (verbose) { + cat("Simulating MDP trajectories.\n") + cat("- horizon:", horizon, "\n") + cat("- epsilon:", epsilon, "\n") + cat("- discount factor:", disc, "\n") + cat("- starting state:\n") + print(start) + cat("\n") + } + + st <- replicate(n, expr = { + # find a initial state + + s <- sample(states, 1, prob = start) + + action_cnt <- rep(0L, length(actions)) + names(action_cnt) <- actions + state_cnt <- rep(0L, length(states)) + names(state_cnt) <- states + + rew <- 0 + + if (visited_states) + s_all <- integer(horizon) + + for (j in 1:horizon) { + + if (runif(1) < epsilon) { + a <- sample.int(length(actions), 1L, replace = TRUE) + } else { + a <- pol[[.get_pol_index(model, j)]][s] + } + + action_cnt[a] <- action_cnt[a] + 1L + state_cnt[s] <- state_cnt[s] + 1L + + s_prev <- s + s <- sample.int(length(states), 1L, prob = trans_m[[a]][s,]) + + rew <- rew + rew_m[[a]][[s_prev]][s] * disc ^ (j - 1L) + + if (visited_states) + 
s_all[j] <- s + } + + if (!visited_states) + s_all <- s + + rownames(s_all) <- NULL + attr(s_all, "action_cnt") <- action_cnt + attr(s_all, "state_cnt") <- state_cnt + attr(s_all, "reward") <- rew + s_all + + }, simplify = FALSE) + + ac <- do.call(rbind, lapply(st, attr, "action_cnt")) + rownames(ac) <- NULL + sc <- do.call(rbind, lapply(st, attr, "state_cnt")) + rownames(sc) <- NULL + rew <- do.call(rbind, lapply(st, attr, "reward")) + rownames(rew) <- NULL + st <- do.call(c, st) + + attr(st, "action_cnt") <- ac + attr(st, "state_cnt") <- sc + attr(st, "reward") <- rew + attr(st, "avg_reward") <- mean(rew, na.rm = TRUE) + + st + } diff --git a/R/simulate_POMDP.R b/R/simulate_POMDP.R index 079697d..9f21f69 100644 --- a/R/simulate_POMDP.R +++ b/R/simulate_POMDP.R @@ -4,9 +4,7 @@ #' #' Simulate trajectories through a POMDP. The start state for each #' trajectory is randomly chosen using the specified belief. The belief is used to choose actions -#' from the policy and then updated using observations. For solved POMDPs -#' the optimal actions will be chosen, for unsolved POMDPs random actions will -#' be used. +#' from the the epsilon-greedy policy and then updated using observations. #' #' @family POMDP #' @importFrom stats runif @@ -146,7 +144,7 @@ simulate_POMDP <- bs <- replicate(n, expr = { # find a initial state - s <- sample(states, 1, prob = belief) + s <- sample(states, 1L, prob = belief) b <- belief action_cnt <- rep(0L, length(actions)) @@ -156,6 +154,7 @@ simulate_POMDP <- names(state_cnt) <- states rew <- 0 + e <- 1L if (visited_beliefs) b_all <- matrix( @@ -179,20 +178,20 @@ simulate_POMDP <- # find action (if we have no solution then take a random action) and update state and sample obs if (runif(1) < epsilon) { - a <- sample(actions, 1) + a <- sample.int(length(actions), 1L, replace = TRUE) } else { - # convert index for converged POMDPs - e <- .get_pg_index(model, j) + if(!model$solution$converged) + e <- .get_pg_index(model, j) a <- - as.character(model$solution$pg[[e]][which.max(model$solution$alpha[[e]] %*% b), "action"]) + as.integer(model$solution$pg[[e]][["action"]][which.max(model$solution$alpha[[e]] %*% b)]) } action_cnt[a] <- action_cnt[a] + 1L state_cnt[s] <- state_cnt[s] + 1L s_prev <- s - s <- sample(states, 1, prob = trans_m[[a]][s,]) - o <- sample(obs, 1, prob = obs_m[[a]][s,]) + s <- sample.int(length(states), 1L, prob = trans_m[[a]][s,]) + o <- sample.int(length(obs), 1L, prob = obs_m[[a]][s,]) rew <- rew + rew_m[[a]][[s_prev]][s, o] * disc ^ (j - 1L) @@ -216,13 +215,13 @@ simulate_POMDP <- }, simplify = FALSE) - ac <- Reduce(rbind, lapply(bs, attr, "action_cnt")) + ac <- do.call(rbind, lapply(bs, attr, "action_cnt")) rownames(ac) <- NULL - sc <- Reduce(rbind, lapply(bs, attr, "state_cnt")) + sc <- do.call(rbind, lapply(bs, attr, "state_cnt")) rownames(sc) <- NULL - rew <- Reduce(rbind, lapply(bs, attr, "reward")) + rew <- do.call(rbind, lapply(bs, attr, "reward")) rownames(rew) <- NULL - bs <- Reduce(rbind, bs) + bs <- do.call(rbind, bs) rownames(bs) <- NULL attr(bs, "action_cnt") <- ac diff --git a/R/solve_MDP.R b/R/solve_MDP.R index 6392745..638da9a 100644 --- a/R/solve_MDP.R +++ b/R/solve_MDP.R @@ -1,5 +1,4 @@ -# TODO: -# * actions(s) +# TODO: deal with available actions for states actions(s) #' Solve an MDP Problem #' diff --git a/R/solve_SARSOP.R b/R/solve_SARSOP.R index 82f1f59..894aff9 100644 --- a/R/solve_SARSOP.R +++ b/R/solve_SARSOP.R @@ -149,8 +149,8 @@ if (!is.null(terminal_values)) # package solution policy <- sarsop::read_policyx(policy_file) - 
pg <- data.frame(node = 1:length(policy$action), - action = model$actions[policy$action]) + pg <- data.frame(node = seq_along(policy$action), + action = factor(model$actions[policy$action], levels = model$actions)) alpha <- t(policy$vectors) colnames(alpha) = model$states diff --git a/R/transition_matrix.R b/R/transition_matrix.R index d2c8e3c..145a0f9 100644 --- a/R/transition_matrix.R +++ b/R/transition_matrix.R @@ -1,7 +1,8 @@ #' Extract the Transition, Observation or Reward Information from a POMDP #' #' Converts the description of transition probabilities and observation -#' probabilities in a POMDP into a list of matrices or a function. +#' probabilities in a POMDP into a list of matrices. Individual values or parts of the matrices +#' can be more efficiently retrieved using the functions ending `_prob` and `_val`. #' #' See Details section in [POMDP] for details. #' @@ -9,7 +10,8 @@ #' #' @param x A [POMDP] object. #' @param episode Episode used for time-dependent POMDPs ([POMDP]). -#' @param action only return the matrix for a given action. +#' @param action only return the matrix/value for a given action. +#' @param start.state,end.state,observation name of the state or observation. #' @return A list or a list of lists of matrices. #' @author Michael Hahsler #' @examples @@ -18,20 +20,18 @@ #' # List of |A| transition matrices. One per action in the from states x states #' Tiger$transition_prob #' transition_matrix(Tiger) -#' -#' f <- transition_function(Tiger) -#' args(f) -#' ## listening does not change the tiger's position. -#' f("listen", "tiger-left", "tiger-left") +#' transition_prob(Tiger, action = "listen", start.state = "tiger-left") #' #' # List of |A| observation matrices. One per action in the from states x observations #' Tiger$observation_prob #' observation_matrix(Tiger) +#' observation_prob(Tiger, action = "listen", end.state = "tiger-left") #' #' # List of list of reward matrices. 
1st level is action and second level is the #' # start state in the form end state x observation #' Tiger$reward #' reward_matrix(Tiger) +#' reward_val(Tiger, action = "listen", start.state = "tiger") #' #' # Visualize transition matrix for action 'open-left' #' library("igraph") @@ -66,18 +66,13 @@ transition_matrix <- function(x, episode = 1, action = NULL) { ) } +## TODO: make the access functions more efficient for a single value + #' @rdname transition_matrix #' @export -transition_function <- function(x, episode = 1) { - m <- transition_matrix(x, episode) - - return ({ - m - function(action, start.state, end.state) - m[[action]][start.state, end.state] - }) -} - +transition_prob <- function(x, action, start.state, end.state, episode = 1) { + transition_matrix(x, episode = 1, action = action)[start.state, end.state] +} #' @rdname transition_matrix #' @export @@ -98,48 +93,23 @@ observation_matrix <- function(x, episode = 1, action = NULL) { #' @rdname transition_matrix #' @export -observation_function <- function(x, episode = 1) { - m <- observation_matrix(x, episode) - - return ( { - m - function(action, observation, end.state) - m[[action]][end.state, observation] - }) +observation_prob <- function(x, action, end.state, observation, episode = 1) { + observation_matrix(x, episode = 1, action = action)[end.state, observation] } - #' @rdname transition_matrix #' @export -reward_matrix <- function(x, episode = 1, action = NULL) { +reward_matrix <- function(x, episode = 1, action = NULL, start.state = NULL) { ## action list of s' x o matrices ## action list of s list of s' x o matrices ## if not observations are available then it is a s' vector - .translate_reward(x, episode = episode, action = action) + .translate_reward(x, episode = episode, action = action, start.state = start.state) } #' @rdname transition_matrix #' @export -reward_function <- function(x, episode = 1) { - m <- reward_matrix(x, episode = 1) - - ## MDP has no observations! 
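For reference, a short usage sketch of the new accessors that replace the removed *_function() variants (mirroring the roxygen examples above and assuming the bundled Tiger model):

library("pomdp")
data("Tiger")

# single transition probability P(end.state | start.state, action)
transition_prob(Tiger, action = "listen",
                start.state = "tiger-left", end.state = "tiger-left")

# observation probabilities for one action and end state
observation_prob(Tiger, action = "listen", end.state = "tiger-left")

# end.state x observation reward matrix for one action and start state
reward_matrix(Tiger, action = "open-left", start.state = "tiger-left")

# a single reward value
reward_val(Tiger, action = "open-left", start.state = "tiger-left",
           end.state = "tiger-left", observation = "tiger-left")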
- if (inherits(x, "POMDP")) - return({ - m - function(action, - start.state, - end.state, - observation) - m[[action]][[start.state]][end.state, observation] - }) - - else - return({ - m - function(action, start.state, end.state) - m[[action]][[start.state]][end.state] - }) +reward_val <- function(x, action, start.state, end.state, observation, episode = 1) { + reward_matrix(x, episode = 1, action = action, start.state = start.state)[end.state, observation] } # translate different specifications of transitions, observations and rewards @@ -164,11 +134,11 @@ reward_function <- function(x, episode = 1) { ) for (i in 1:nrow(df)) { - if (df[i, 1] == "*" && df[i, 2] == "*") + if (is.na(df[i, 1]) && is.na(df[i, 2])) m[] <- df[i, 3] - else if (df[i, 1] == "*") + else if (is.na(df[i, 1])) m[, df[i, 2]] <- df[i, 3] - else if (df[i, 2] == "*") + else if (is.na(df[i, 2])) m[df[i, 1],] <- df[i, 3] else m[df[i, 1], df[i, 2]] <- df[i, 3] @@ -188,7 +158,7 @@ reward_function <- function(x, episode = 1) { names(v) <- from for (i in 1:nrow(df)) { - if (df[i, 1] == "*") + if (is.na(df[i, 1])) v[] <- df[i, 2] else v[df[i, 1]] <- df[i, 2] @@ -251,7 +221,7 @@ reward_function <- function(x, episode = 1) { if (is.data.frame(prob)) { prob <- sapply(actions, function(a) { .df2matrix(model, - prob[(prob$action == a | prob$action == "*"), 2:4], + prob[(prob$action == a | is.na(prob$action)), 2:4], from = from, to = to) }, simplify = FALSE, USE.NAMES = TRUE) @@ -324,8 +294,11 @@ reward_function <- function(x, episode = 1) { ## reward is action -> start.state -> end.state x observation -.translate_reward <- function(model, episode = 1, action = NULL) { - states <- model$states +.translate_reward <- function(model, episode = 1, action = NULL, start.state = NULL) { + if (is.null(start.state)) + states <- model$states + else + states <- start.state if (is.null(action)) actions <- model$actions @@ -384,16 +357,16 @@ reward_function <- function(x, episode = 1) { FUN = function(s) { if(!is.null(observations)) { .df2matrix(model, - reward[(reward$action == a | reward$action == "*") & + reward[(reward$action == a | is.na(reward$action)) & (reward$start.state == s | - reward$start.state == "*"), 3:5], + is.na(reward$start.state)), 3:5], from = "states", to = "observations") }else{ ## MDPs have no observations .df2vector(model, - reward[(reward$action == a | reward$action == "*") & + reward[(reward$action == a | is.na(reward$action)) & (reward$start.state == s | - reward$start.state == "*"), c(3,5)], + is.na(reward$start.state)), c(3,5)], from = "states") } }, @@ -442,5 +415,8 @@ reward_function <- function(x, episode = 1) { if(!is.null(action) && length(action) == 1) reward <- reward[[1]] + if(!is.null(start.state) && length(start.state) == 1) + reward <- reward[[1]] + reward } diff --git a/Work/data/create_Three_doors.R b/Work/data/create_Three_doors.R index f3d57fd..f1b6428 100644 --- a/Work/data/create_Three_doors.R +++ b/Work/data/create_Three_doors.R @@ -27,7 +27,7 @@ Three_doors <- POMDP( "open-center" = "uniform", "open-right" = "uniform"), - # the rew helper expects: action, start.state, end.state, observation, value + # the raw helper expects: action, start.state, end.state, observation, value reward = rbind( R_("listen", "*", "*", "*", -1 ), R_("open-left", "*", "*", "*", 10), @@ -43,11 +43,11 @@ Three_doors save(Three_doors, file = "data/Three_doors.rda") -sol <- solve_POMDP(Three_doors) - -plot(sol) -reward(sol) - -plot_belief_space(sol, projection = 1:3, n = 10000) +#sol <- solve_POMDP(Three_doors) +# +#plot(sol) 
+#reward(sol) +# +#plot_belief_space(sol, projection = 1:3, n = 10000) diff --git a/Work/data/create_Tiger.R b/Work/data/create_Tiger.R index b244555..4f8c8cc 100644 --- a/Work/data/create_Tiger.R +++ b/Work/data/create_Tiger.R @@ -33,3 +33,4 @@ Tiger save(Tiger, file = "data/Tiger.rda") + diff --git a/data/Maze.rda b/data/Maze.rda index 7b94f232501635492c2a06eab7ffdbe1e9980ae7..dfaff53bb2e1a3eed926f64e5450faab4c46a252 100644 GIT binary patch literal 771 zcmV+e1N{6SiwFP!000001MOMAZ__{&J`y|4j{=p#!j_GRh&Bbvl%;~H5-K2dtFFmK zwdC69&Q4Gm;7?$MnVG#)=Sc8JFjivl%usV@-<=aDr~Ke-rBY6^{qFhh``$ah`(ioA zhbOJ-NfiK41V(|LO=<%3-P@Os4?yGrcYpyZ^lBWzA)+Jd)GpIwXYAg?H=Yl%*EseG zLAL#fG+H0_8}S5ij}2@GK(wg;Dd{tyZ4OiiTNxspC{%)V!jZ0 zP>LsG_U@Ry$HOtW&vabuor?}}ec3S(B%ZJbh}xh~k7j}bOYlqk@+)iKK&q+y%eMA) zjME>>^$xDru#Xh@Fb67-E{HH^XuY~ z`2@NcX7fnD}?(O)P7*eJ~9C(2Uwd;J`4p=*FUZZesq}X$DOc$;8 z;u>YMIi`3tZvOKDn80|cPE`~*Un35MlFjL`+%)e3e`-t2f<>uF5CS#W0&k`X$ds z(KC8SPwxcH?>4z9y6S_BrhFuxLIJh&1iFys@Z^<`LJ4ney=8>2;>#L8w|bL28%5Wp z=o88QI~AgNA`blPo~nbArWwrqz5~mi+Pw+ z;n7y(wPeO{X1I0_jfGVT7fVbeG`;S(JTY}a$;7}Sh|xc2rK|75->-f=_;?k>^XBUr zLpB7mDUd0!g_)VXQ|CzVBN!_&cxGt0i|@{{lT(`TY!xXdIez!<-uvFW`|jCt z4)>3m)uSpxsEEvhfr=Irk@fb?i^qEinFi|=kclenZXV!)&F0mZTxQqi#Mpz^et?Nz zKMW|fyWNMh-h97Pj||XVKCu-d*~a~kai2Td5RQuUcqBrV@qW41V6VfzVP-HrJ7E^ zXlq?Xep*8qe{*qu|KRQC)cBJ5XxHZ|+jUd^LMTtdUrzZfY6s3Y=zU1}Q#hJ+gTk4o zhr&_$gir{xag<)+B&v6@CacD1Vv?2#ym%d76WhS3a6b_}}s^BmX+Qto`T4H-)pR z=uj0s@!iUjAFX(tOMSWLF310CSi(z5j29j)CfwM1I)r{q$vK`dv(9lh zuv1%KAzr^Bw|{)9KCz$CX!Iw|k#-kT$q}incCe2dr-a!@97v3=1+H`IkzOjPvfKd< zE97#=p=~d4C9T%ZJmL&g=%sOLR1QX~A-nf{f?eux)$u#xok$yE_c7@x0`oPCA}Ps8 a49S^;{Z}%mt|R%3uKxmromXw#8vp=9+;+AA diff --git a/data/Three_doors.rda b/data/Three_doors.rda index 2bac91caff347f9fc2292f11d76ce695338728ac..b9b302e9191126b2de675f3e5ce9271e253c8f9d 100644 GIT binary patch literal 554 zcmV+_0@eK=iwFP!000001Lal0j?*v@HEEJ=lLac357-0O_R3M(Y_}`8z;D1G>{o!q5AX>b;kH*U4Yp&qakL<FbmKwk_mKdmR>lkN(fpNxX=7i)*gUELcyt|=f$@u!du$P&L;x&ktY1kbi33R@C z*{@H!ea^mDNy~2R`rFvF?1#nO`lcao^&8H2a-ze0sCreaBkq_~YC~~(e|LP>{$F>@d*1zWTV?3^)4S)dZ||NPHz0B+r~?Rd^_B{Wxg4oc zf-6Cwxd)`|LIB_5np(8Q#ReMH*W>JZLSyxf+c152SSK%qD=?@04@*_P5=M^ literal 527 zcmV+q0`UDGiwFP!000001Fcm(Zrd;vre(>Fr8F_%93Vrc0`}sin>rZXf*OGfB$L5d zbWA`d1CmbAF3C088{`xXkPGAp8QU#u7ds@y=SYIlBDDxmH3lzvg3+1`CamMRdXNsBGp^3Jj>T3yaE*PD1F@Gyus7%1{oXCpPuzZcg zGxci4zrUODBYw<^7T-Dfp>Ie00EP)xxHlEr{_lOUECm)h72mx2wb|u_FBnw5IpczUTf5(|;Ym2mG6Ldm zE?mI&mB36=1C1~BAIZV(oN69#2y%@@0^!;mubGILQ-;&AO`*(!X+SU@*bF}(4~7t# R={J{@XfhMCR zZNcb8|CuKK851wuO_*)h!qxAHdj|}gnkM+cWMj=urDYx8S2bl@*%O=?X0dV zuoP^{v~miw%|g2CKZVp^px||$Gy8U%@25xQ-QrHa1@=UFAg1=>V=f>QE%_G6_5vbF z35e-EM{EWQZov$uQ{;3;ozaJu9X~y$ij6GGzSa~MB~ZZDfJqnd<2WsxT4G~TQPEbj z&--*rEV6_aO1aV2%z2E!O_kMkESDo9If9oTJ%E&D_sO+gv3lkd3+y{u>$tBvi*+WG4j)ADr;xQI{Q*P zHzG2S773OJr^Au7#$~IjL{vz^tW=99gk?Q_17saHJ;kY;Q}T6HX-Jr)@}Yq5Oc@=- ze7q9QQ^)X_i2oF__{q=FRJ!etmU-H5h27$X49X$Zx$YT0(v>r*Dig@FZf*X2OB)+6 zwjtC>{>h=~UC##bHVd*g>Z+!L*rk)r2>AoWze2D7*wRTinfGGn_4RJKpB?9SzzesN z5{vti#z1T!B<8ZI-jV>dgoQaEi0=jAGn{erW;mLXLa)|%Ja0m?nx=kMuOvynij&um zuz<+}&=0Z1%>h^b>pO8#C@iohu0Hx_Y(zhvGPtpwQOxOMntsKUCqTJm1jOA;xqux= zj+vku>P^(|)WU}u(=1-$WR7_V