Skip to content

Commit

Permalink
address #4
Browse files Browse the repository at this point in the history
  • Loading branch information
vnijs committed Apr 2, 2020
1 parent 27f658a commit c648605
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 29 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: radiant.model
Type: Package
Title: Model Menu for Radiant: Business Analytics using R and Shiny
Version: 1.3.10
Date: 2020-3-24
Version: 1.3.11
Date: 2020-4-1
Authors@R: person("Vincent", "Nijs", , "[email protected]", c("aut", "cre"))
Description: The Radiant Model menu includes interfaces for linear and logistic
regression, naive Bayes, neural networks, classification and regression trees,
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# radiant.model 1.3.11

* Improvement in calculating PDP for categorical variables in plot.gbt based on suggestion by @benmarchi (https://github.com/radiant-rstats/radiant.model/issues/4)

# radiant.model 1.3.9

* Minor adjustments in anticipation of dplyr 1.0.0
Expand Down
52 changes: 25 additions & 27 deletions R/gbt.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@
#' @examples
#' gbt(titanic, "survived", c("pclass", "sex"), lev = "Yes") %>% summary()
#' gbt(titanic, "survived", c("pclass", "sex")) %>% str()
#' gbt(titanic, "survived", c("pclass", "sex"), eval_metric = paste0("error@", 0.5/6)) %>% str()
#' gbt(titanic, "survived", c("pclass", "sex"), eval_metric = paste0("error@", 0.5 / 6)) %>% str()
#' gbt(diamonds, "price", c("carat", "clarity"), type = "regression") %>% summary()
#' rig_wrap <- function(preds, dtrain) {
#' labels <- xgboost::getinfo(dtrain, "label")
#' value <- rig(preds, labels, lev = 1)
#' list(metric = "rig", value = value)
#' }
#' gbt(titanic, "survived", c("pclass", "sex"), eval_metric = rig_wrap, maximize = TRUE) %>% str()
#'
#' @seealso \code{\link{summary.gbt}} to summarize results
#' @seealso \code{\link{plot.gbt}} to plot results
#' @seealso \code{\link{predict.gbt}} for prediction
Expand All @@ -44,12 +43,12 @@
#'
#' @export
gbt <- function(
dataset, rvar, evar, type = "classification", lev = "",
max_depth = 6, learning_rate = 0.3, min_split_loss = 0,
min_child_weight = 1, subsample = 1,
nrounds = 100, early_stopping_rounds = 10,
nthread = 12, wts = "None", seed = NA,
data_filter = "", envir = parent.frame(), ...
dataset, rvar, evar, type = "classification", lev = "",
max_depth = 6, learning_rate = 0.3, min_split_loss = 0,
min_child_weight = 1, subsample = 1,
nrounds = 100, early_stopping_rounds = 10,
nthread = 12, wts = "None", seed = NA,
data_filter = "", envir = parent.frame(), ...
) {

if (rvar %in% evar) {
Expand Down Expand Up @@ -177,7 +176,7 @@ gbt <- function(
gbt_input$data <- gbt_input$label <- NULL

## needed to work with prediction functions
check <- ""
check <- ""

as.list(environment()) %>% add_class(c("gbt", "model"))
}
Expand All @@ -193,7 +192,6 @@ gbt <- function(
#' @examples
#' result <- gbt(titanic, "survived", "pclass", lev = "Yes")
#' summary(result)
#'
#' @seealso \code{\link{gbt}} to generate results
#' @seealso \code{\link{plot.gbt}} to plot results
#' @seealso \code{\link{predict.gbt}} for prediction
Expand Down Expand Up @@ -265,7 +263,6 @@ summary.gbt <- function(object, prn = TRUE, ...) {
#'
#' @examples
#' result <- gbt(titanic, "survived", c("pclass", "sex"), lev = "Yes")
#'
#' @seealso \code{\link{gbt}} to generate results
#' @seealso \code{\link{summary.gbt}} to summarize results
#' @seealso \code{\link{predict.gbt}} for prediction
Expand All @@ -274,8 +271,8 @@ summary.gbt <- function(object, prn = TRUE, ...) {
#'
#' @export
plot.gbt <- function(
x, plots = "", nrobs = Inf,
shiny = FALSE, custom = FALSE, ...
x, plots = "", nrobs = Inf,
shiny = FALSE, custom = FALSE, ...
) {

if (is.character(x) || !inherits(x$model, "xgb.Booster")) return(x)
Expand All @@ -296,15 +293,17 @@ plot.gbt <- function(
nr <- length(fn)
for (i in seq_len(nr)) {
seed <- x$seed
pdi <- pdp::partial(
dtx_cat <- dtx
dtx_cat[, setdiff(fn, fn[i])] <- 0
pdi <- pdp::partial(
x$model, pred.var = fn[i], plot = FALSE,
prob = x$type == "classification", train = dtx
prob = x$type == "classification", train = dtx_cat
)
effects[i] <- pdi[pdi[[1]] == 1, 2]
}
pgrid <- as.data.frame(matrix(0, ncol = nr))
colnames(pgrid) <- fn
base <- pdp::partial(
base <- pdp::partial(
x$model, pred.var = fn,
pred.grid = pgrid, plot = FALSE,
prob = x$type == "classification", train = dtx
Expand Down Expand Up @@ -381,14 +380,13 @@ plot.gbt <- function(
#' result <- gbt(diamonds, "price", "carat:color", type = "regression")
#' predict(result, pred_cmd = "carat = 1:3")
#' predict(result, pred_data = diamonds) %>% head()
#'
#' @seealso \code{\link{gbt}} to generate the result
#' @seealso \code{\link{summary.gbt}} to summarize results
#'
#' @export
predict.gbt <- function(
object, pred_data = NULL, pred_cmd = "",
dec = 3, envir = parent.frame(), ...
object, pred_data = NULL, pred_cmd = "",
dec = 3, envir = parent.frame(), ...
) {

if (is.character(object)) return(object)
Expand Down Expand Up @@ -491,10 +489,10 @@ print.gbt.predict <- function(x, ..., n = 10)
#'
#' @export
cv.gbt <- function(
object, K = 5, repeats = 1, params = list(),
nrounds = 500, early_stopping_rounds = 10, nthread = 12,
train = NULL, type = "classification",
trace = TRUE, seed = 1234, maximize = NULL, fun, ...
object, K = 5, repeats = 1, params = list(),
nrounds = 500, early_stopping_rounds = 10, nthread = 12,
train = NULL, type = "classification",
trace = TRUE, seed = 1234, maximize = NULL, fun, ...
) {

if (inherits(object, "gbt")) {
Expand All @@ -517,7 +515,7 @@ cv.gbt <- function(
if (is_empty(params_base[["maximize"]])) {
params_base[["maximize"]] <- object$extra_args[["maximize"]]
}
} else if (!inherits(object, "xgb.Booster")) {
} else if (!inherits(object, "xgb.Booster")) {
stop("The model object does not seems to be a Gradient Boosted Tree")
} else {
if (!inherits(train, "xgb.DMatrix")) {
Expand Down Expand Up @@ -554,10 +552,10 @@ cv.gbt <- function(
fun <- params$eval_metric
} else {
fun <- list("custom" = params$eval_metric)
}
}
}
}
}

if (length(shiny::getDefaultReactiveDomain()) > 0) {
trace <- FALSE
incProgress <- shiny::incProgress
Expand Down

0 comments on commit c648605

Please sign in to comment.