Skip to content

Commit

Permalink
refactor: simplify and unify transformation process
Browse files Browse the repository at this point in the history
- In model.R & pareto.R: remove decompSpendDist from  both scripts to reduce memory leak. Use xDecompAgg subsets instead
- In transformation.R & response.R: unify transformation namings in run_transformation and robyn_response
- In response.R: remove exposure extrapolation because it's already done in robyn_input. Also add inflexion point to output.
- In plots.R: fix onepager saturation plot issues
- In pareto.R: rewrite run_dt_resp() as response_wrapper and align transformation logic & naming.
- In pareto: Replace foreach response loop with lapply for simplicity.
- In pareto.R: Simplify plot data generation process, esp for saturation curve plot, actual vs predicted plot & immediate vs carryover plot.
- In pareto.R: Remove redundancy in xDecompVecCollect -> remove type rawMedia, rawSpend, predictedExposure, saturatedMedia & saturatedSpendReversed. Only keep adstockedMedia & decompMedia for response curve plotting.
  • Loading branch information
gufengzhou committed Oct 30, 2024
1 parent 88e46c2 commit 9158a93
Show file tree
Hide file tree
Showing 9 changed files with 308 additions and 385 deletions.
1 change: 1 addition & 0 deletions R/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ importFrom(dplyr,distinct)
importFrom(dplyr,ends_with)
importFrom(dplyr,everything)
importFrom(dplyr,filter)
importFrom(dplyr,full_join)
importFrom(dplyr,group_by)
importFrom(dplyr,lag)
importFrom(dplyr,left_join)
Expand Down
2 changes: 1 addition & 1 deletion R/R/imports.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#' @importFrom doRNG %dorng%
#' @importFrom doParallel registerDoParallel stopImplicitCluster
#' @importFrom dplyr across any_of arrange as_tibble bind_rows case_when contains desc distinct
#' everything filter group_by lag left_join mutate n pull rename row_number select slice
#' everything filter full_join group_by lag left_join mutate n pull rename row_number select slice
#' summarise summarise_all ungroup all_of bind_cols mutate_at starts_with ends_with tally n_distinct
#' @importFrom foreach foreach %dopar% getDoParWorkers registerDoSEQ
#' @import ggplot2
Expand Down
83 changes: 48 additions & 35 deletions R/R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ robyn_train <- function(InputCollect, hyper_collect,
OutputModels[[1]]$trial <- 1
# Set original solID (to overwrite default 1_1_1)
if ("solID" %in% names(dt_hyper_fixed)) {
these <- c("resultHypParam", "xDecompVec", "xDecompAgg", "decompSpendDist")
these <- c("resultHypParam", "xDecompVec", "xDecompAgg")
for (tab in these) OutputModels[[1]]$resultCollect[[tab]]$solID <- dt_hyper_fixed$solID
}
} else {
Expand Down Expand Up @@ -405,9 +405,9 @@ robyn_train <- function(InputCollect, hyper_collect,
seed = seed + ngt,
quiet = quiet
)
check_coef0 <- any(model_output$resultCollect$decompSpendDist$decomp.rssd == Inf)
check_coef0 <- any(model_output$resultCollect$resultHypParam$decomp.rssd == Inf)
if (check_coef0) {
num_coef0_mod <- filter(model_output$resultCollect$decompSpendDist, is.infinite(.data$decomp.rssd)) %>%
num_coef0_mod <- filter(model_output$resultCollect$resultHypParam, is.infinite(.data$decomp.rssd)) %>%
distinct(.data$iterNG, .data$iterPar) %>%
nrow()
num_coef0_mod <- ifelse(num_coef0_mod > iterations, iterations, num_coef0_mod)
Expand Down Expand Up @@ -515,6 +515,7 @@ robyn_mmm <- function(InputCollect,
rollingWindowLength <- InputCollect$rollingWindowLength
paid_media_spends <- InputCollect$paid_media_spends
paid_media_selected <- InputCollect$paid_media_selected
exposure_vars <- InputCollect$exposure_vars
organic_vars <- InputCollect$organic_vars
context_vars <- InputCollect$context_vars
prophet_vars <- InputCollect$prophet_vars
Expand Down Expand Up @@ -542,6 +543,9 @@ robyn_mmm <- function(InputCollect,
mean_spend = unlist(summarise_all(temp, mean))
) %>%
mutate(spend_share = .data$total_spend / sum(.data$total_spend))
temp <- select(dt_inputTrain, all_of(c(exposure_vars, organic_vars))) %>% summarise_all(mean) %>% unlist
temp <- data.frame(rn = c(exposure_vars, organic_vars), mean_exposure = temp)
dt_spendShare <- full_join(dt_spendShare, temp, by = "rn")
# When not refreshing, dt_spendShareRF = dt_spendShare
refreshAddedStartWhich <- which(dt_modRollWind$ds == refreshAddedStart)
temp <- select(dt_inputTrain, all_of(paid_media_spends)) %>%
Expand All @@ -556,8 +560,14 @@ robyn_mmm <- function(InputCollect,
) %>%
mutate(spend_share = .data$total_spend / sum(.data$total_spend))
# Join both dataframes into a single one
temp <- select(dt_inputTrain, all_of(c(exposure_vars, organic_vars))) %>%
slice(refreshAddedStartWhich:rollingWindowLength) %>%
summarise_all(mean) %>% unlist
temp <- data.frame(rn = c(exposure_vars, organic_vars), mean_exposure = temp)
dt_spendShareRF <- full_join(dt_spendShareRF, temp, by = "rn")
dt_spendShare <- left_join(dt_spendShare, dt_spendShareRF, "rn", suffix = c("", "_refresh"))


################################################
#### Get lambda
lambda_min_ratio <- 0.0001 # default value from glmnet
Expand Down Expand Up @@ -666,7 +676,7 @@ robyn_mmm <- function(InputCollect,
window_end_loc = InputCollect$rollingWindowEndWhich,
dt_mod = InputCollect$dt_mod,
adstock = InputCollect$adstock,
hyperparameters = hypParamSam, ...)
dt_hyppar = hypParamSam, ...)
dt_modSaturated <- temp$dt_modSaturated
dt_saturatedImmediate <- temp$dt_saturatedImmediate
dt_saturatedCarryover <- temp$dt_saturatedCarryover
Expand Down Expand Up @@ -805,36 +815,38 @@ robyn_mmm <- function(InputCollect,
#####################################
#### DECOMP.RSSD: Business error
# Sum of squared distance between decomp share and spend share to be minimized
dt_decompSpendDist <- decompCollect$xDecompAgg %>%
filter(.data$rn %in% paid_media_selected) %>%
dt_loss_calc <- decompCollect$xDecompAgg %>%
filter(.data$rn %in% c(paid_media_selected, organic_vars)) %>%
select(
.data$rn, .data$xDecompAgg, .data$xDecompPerc, .data$xDecompMeanNon0Perc,
.data$xDecompMeanNon0, .data$xDecompPercRF, .data$xDecompMeanNon0PercRF,
.data$xDecompMeanNon0RF
) %>%
left_join(
select(
dt_spendShare,
.data$rn, .data$spend_share, .data$spend_share_refresh,
.data$mean_spend, .data$total_spend
),
by = "rn"
) %>%
mutate(
effect_share = .data$xDecompPerc / sum(.data$xDecompPerc),
effect_share_refresh = .data$xDecompPercRF / sum(.data$xDecompPercRF)
.data$rn, .data$xDecompPerc, .data$xDecompPercRF
)
dt_decompSpendDist <- left_join(
filter(decompCollect$xDecompAgg, .data$rn %in% paid_media_selected),
select(dt_decompSpendDist, .data$rn, contains("_spend"), contains("_share")),
dt_loss_calc <- dt_loss_calc %>% left_join(
select(
dt_spendShare,
c("rn", "spend_share", "spend_share_refresh","mean_spend",
"total_spend", "mean_exposure", "mean_exposure_refresh")
),
by = "rn"
)
dt_loss_calc <- bind_rows(
dt_loss_calc %>% filter(.data$rn %in% paid_media_selected) %>%
mutate(
effect_share = .data$xDecompPerc / sum(.data$xDecompPerc),
effect_share_refresh = .data$xDecompPercRF / sum(.data$xDecompPercRF)
),
dt_loss_calc %>% filter(.data$rn %in% organic_vars) %>%
mutate(
effect_share = NA, effect_share_refresh = NA)
) %>% select(-c("xDecompPerc", "xDecompPercRF"))
decompCollect$xDecompAgg <- left_join(
decompCollect$xDecompAgg, dt_loss_calc, by = "rn")
dt_loss_calc <- dt_loss_calc %>% filter(.data$rn %in% paid_media_selected)
if (!refresh) {
decomp.rssd <- sqrt(sum((dt_decompSpendDist$effect_share - dt_decompSpendDist$spend_share)^2))
decomp.rssd <- sqrt(sum((dt_loss_calc$effect_share - dt_loss_calc$spend_share)^2))
# Penalty for models with more 0-coefficients
if (rssd_zero_penalty) {
is_0eff <- round(dt_decompSpendDist$effect_share, 4) == 0
share_0eff <- sum(is_0eff) / length(dt_decompSpendDist$effect_share)
is_0eff <- round(dt_loss_calc$effect_share, 4) == 0
share_0eff <- sum(is_0eff) / length(dt_loss_calc$effect_share)
decomp.rssd <- decomp.rssd * (1 + share_0eff)
}
} else {
Expand All @@ -856,7 +868,8 @@ robyn_mmm <- function(InputCollect,
# When all media in this iteration have 0 coefficients
if (is.nan(decomp.rssd)) {
decomp.rssd <- Inf
dt_decompSpendDist$effect_share <- 0
decompCollect$xDecompAgg <- decompCollect$xDecompAgg %>%
mutate(effect_share = ifelse(is.na(.data$effect_share), NA, 0))
}

#####################################
Expand Down Expand Up @@ -907,8 +920,8 @@ robyn_mmm <- function(InputCollect,
bind_cols(common)
}

resultCollect[["decompSpendDist"]] <- dt_decompSpendDist %>%
bind_cols(common)
# resultCollect[["decompSpendDist"]] <- dt_decompSpendDist %>%
# bind_cols(common)

resultCollect <- append(resultCollect, as.list(common))
return(resultCollect)
Expand Down Expand Up @@ -1015,11 +1028,11 @@ robyn_mmm <- function(InputCollect,
arrange(.data$mape, .data$liftMedia, .data$liftStart))
}

resultCollect[["decompSpendDist"]] <- as_tibble(bind_rows(
lapply(resultCollectNG, function(x) {
bind_rows(lapply(x, function(y) y$decompSpendDist))
})
))
# resultCollect[["decompSpendDist"]] <- as_tibble(bind_rows(
# lapply(resultCollectNG, function(x) {
# bind_rows(lapply(x, function(y) y$decompSpendDist))
# })
# ))

resultCollect$iter <- length(resultCollect$mape)
resultCollect$elapsed.min <- sysTimeDopar[3] / 60
Expand Down
Loading

0 comments on commit 9158a93

Please sign in to comment.