diff --git a/R/truncnorm.R b/R/truncnorm.R
index 4dfe62a..ed03b6d 100644
--- a/R/truncnorm.R
+++ b/R/truncnorm.R
@@ -29,65 +29,110 @@
 #' @seealso \code{\link{my_e2truncnorm}}, \code{\link{my_vtruncnorm}}
 #'
 #' @export
-#'
 my_etruncnorm = function(a, b, mean = 0, sd = 1) {
+  # based off of https://github.com/cossio/TruncatedNormal.jl
   do_truncnorm_argchecks(a, b)
 
-  # The case where some sds are zero is handled last. In the meantime, assume
-  # that sd > 0.
+  # initialize array to store result
+  common_shape = 0 * (a + b + mean + sd)
+  if (is.matrix(common_shape)){
+    res = array(dim = dim(common_shape))
+  }
+  else{
+    res = rep(NA, length.out = length(common_shape))
+  }
+
+  #make sure input sizes all match
+  a = rep(a, length.out = length(res))
+  b = rep(b, length.out = length(res))
+  mean = rep(mean, length.out = length(res))
+  sd = rep(sd, length.out = length(res))
+
+  # Handle NA inputs
+  isna = is.na(a) | is.na(b) | is.na(mean) | is.na(sd)
+  res[isna] = NA
+
+  # Handle zero sds. Return the mean of the untruncated normal when it is
+  # located inside of the interval [alpha, beta]. Otherwise, return the
+  # endpoint that is closer to the mean.
+  sd.zero = (sd == 0)
+  res[!isna & sd.zero & b <= mean] = b[!isna & sd.zero & b <= mean]
+  res[!isna & sd.zero & a >= mean] = a[!isna & sd.zero & a >= mean]
+  res[!isna & sd.zero & a < mean & b > mean] = mean[!isna & sd.zero & a < mean & b > mean]
+
+  # Focus in on where sd is nonzero and nothing is NA
+  a = a[!sd.zero & !isna]
+  b = b[!sd.zero & !isna]
+  mean = mean[!sd.zero & !isna]
+  sd = sd[!sd.zero & !isna]
+
+  # Rescale to standard normal distributions
   alpha = (a - mean) / sd
   beta = (b - mean) / sd
+  # initialize array for scaled means
+  scaled.mean = rep(NA, length.out = length(alpha))
 
-  # Flip alpha and beta when: 1. Both are positive (since computations are
-  # unstable when both values of pnorm are close to 1); 2. dnorm(alpha) is
-  # greater than dnorm(beta) (since subtraction is done on the log scale).
-  flip = (alpha > 0 & beta > 0) | (beta > abs(alpha))
+  # point mass bc endpoints are equal
+  # 1st moment is point-value
+  endpoints_equal = (alpha == beta)
+  scaled.mean[endpoints_equal] = alpha[endpoints_equal]
+  # keep track of which spots in scaled.mean are already computed
+  computed = endpoints_equal
+
+  # force to satisfy β ≥ 0 and |α| ≤ |β|
+  # so either α ≤ 0 ≤ β or 0 < α ≤ β
+  flip = !computed & abs(alpha) > abs(beta)
   flip[is.na(flip)] = FALSE
   orig.alpha = alpha
   alpha[flip] = -beta[flip]
   beta[flip] = -orig.alpha[flip]
 
-  dnorm.diff = logscale_sub(dnorm(beta, log = TRUE), dnorm(alpha, log = TRUE))
-  pnorm.diff = logscale_sub(pnorm(beta, log.p = TRUE), pnorm(alpha, log.p = TRUE))
-  scaled.res = -exp(dnorm.diff - pnorm.diff)
+  # both endpoints infinite/untruncated normal distribution
+  # scaled mean is 0
+  both_inf = !computed & is.infinite(alpha) & is.infinite(beta)
+  scaled.mean[both_inf] = 0
+  computed = computed | both_inf
+
+  # truncated to [α,∞)
+  # mean simplifies to ϕ(α)/(1 - Φ(α))
+  beta_inf = !computed & is.infinite(beta)
+  scaled.mean[beta_inf] = sqrt(2/pi) / Re(RcppFaddeeva::erfcx(alpha[beta_inf] / sqrt(2)))
+  computed = computed | beta_inf
+
+  # a ≤ 0 ≤ b
+  #catastrophic cancellation is less of an issue
+  alpha_negative = !computed & alpha <= 0
+  diff = (beta[alpha_negative] - alpha[alpha_negative]) * (alpha[alpha_negative] + beta[alpha_negative]) / 2
+  #√(2/π) * expm1(-Δ) * exp(-α^2 / 2) / erf(β/√2, α/√2)
+  scaled.mean[alpha_negative] = sqrt(2/pi) * expm1(-diff) * exp(-alpha[alpha_negative]^2 / 2)
+  denom = Re(RcppFaddeeva::erf(alpha[alpha_negative] / sqrt(2))) - Re(RcppFaddeeva::erf(beta[alpha_negative] / sqrt(2)))
+  scaled.mean[alpha_negative] = scaled.mean[alpha_negative] / denom
+  computed = computed | alpha_negative
 
-  # Handle the division by zero that occurs when pnorm.diff = -Inf (that is,
-  # when endpoints are approximately equal).
-  endpts.equal = is.infinite(pnorm.diff)
-  scaled.res[endpts.equal] = (alpha[endpts.equal] + beta[endpts.equal]) / 2
+  # 0 < a < b
+  #strategically avoid catastrophic cancellation as much as possible
+  diff = (beta[!computed] - alpha[!computed]) * (alpha[!computed] + beta[!computed]) / 2
+  denom = exp(-diff) * Re(RcppFaddeeva::erfcx(beta[!computed] / sqrt(2))) - Re(RcppFaddeeva::erfcx(alpha[!computed] / sqrt(2)))
+  scaled.mean[!computed] = sqrt(2/pi) * expm1(-diff) / denom
 
-  # When alpha and beta are very large and both negative (due to the flipping
-  # logic, they cannot both be positive), computations can become unstable.
-  # We find such cases by checking that the expectations make sense. When
-  # beta is negative, beta + 1 / beta is a lower bound for the expectation.
-  # Further, it is an increasingly good approximation as beta goes to -Inf
-  # as long as alpha and beta aren't too close to one another. If they are,
-  # then their midpoint can be used as an alternate approximation (and lower
-  # bound).
-  lower.bd = pmax(beta + 1 / beta, (alpha + beta) / 2)
-  bad.idx = (!is.na(beta) & beta < 0
-             & (scaled.res < lower.bd | scaled.res > beta))
-  scaled.res[bad.idx] = lower.bd[bad.idx]
+  #double check that things are within bounds
+  scaled.mean[scaled.mean > beta] = beta[scaled.mean > beta]
+  scaled.mean[scaled.mean < alpha] = alpha[scaled.mean < alpha]
 
   # Flip back.
-  scaled.res[flip] = -scaled.res[flip]
+  scaled.mean[flip] = -scaled.mean[flip]
 
-  res = mean + sd * scaled.res
+  #transform back to nonstandard normal case (sd > 0, non-NA entries only)
+  sub.res = mean + sd * scaled.mean
 
-  # Handle zero sds. Return the mean of the untruncated normal when it is
-  # located inside of the interval [alpha, beta]. Otherwise, return the
-  # endpoint that is closer to the mean.
-  if (any(sd == 0)) {
-    # For the subsetting to work correctly, arguments need to be recycled.
-    a = rep(a, length.out = length(res))
-    b = rep(b, length.out = length(res))
-    mean = rep(mean, length.out = length(res))
-
-    sd.zero = (sd == 0)
-    res[sd.zero & b <= mean] = b[sd.zero & b <= mean]
-    res[sd.zero & a >= mean] = a[sd.zero & a >= mean]
-    res[sd.zero & a < mean & b > mean] = mean[sd.zero & a < mean & b > mean]
+  #throw error if results are far outside the plausible range
+  error_tol = 1
+  if (any(sub.res < a - error_tol | sub.res > b + error_tol, na.rm = TRUE)) {
+    stop("Computed expected value of truncated normal is outside plausible range")
   }
+  #silently correct small errors; a and b were subset above, so clamp before writing back
+  sub.res = pmin(pmax(sub.res, a), b)
+  res[!sd.zero & !isna] = sub.res
 
   return(res)
 }
@@ -118,71 +163,120 @@ my_etruncnorm = function(a, b, mean = 0, sd = 1) {
 #'
 #' @export
 #'
 my_e2truncnorm = function(a, b, mean = 0, sd = 1) {
+  # based off of https://github.com/cossio/TruncatedNormal.jl
   do_truncnorm_argchecks(a, b)
+
+  # initialize array to store result
+  common_shape = 0 * (a + b + mean + sd)
+  if (is.matrix(common_shape)){
+    res = array(dim = dim(common_shape))
+  }
+  else{
+    res = rep(NA, length.out = length(common_shape))
+  }
+
+  #make sure input sizes all match
+  a = rep(a, length.out = length(res))
+  b = rep(b, length.out = length(res))
+  mean = rep(mean, length.out = length(res))
+  sd = rep(sd, length.out = length(res))
+
+  # Handle NA inputs
+  isna = is.na(a) | is.na(b) | is.na(mean) | is.na(sd)
+  res[isna] = NA
 
+  # Handle zero sds. Return the squared mean of the untruncated normal when it
+  # is located inside of the interval [alpha, beta]. Otherwise, return the
+  # square of the endpoint that is closer to the mean.
+  sd.zero = (sd == 0)
+  res[!isna & sd.zero & b <= mean] = b[!isna & sd.zero & b <= mean]^2
+  res[!isna & sd.zero & a >= mean] = a[!isna & sd.zero & a >= mean]^2
+  res[!isna & sd.zero & a < mean & b > mean] = mean[!isna & sd.zero & a < mean & b > mean]^2
+
+  # Focus in on where sd is nonzero and nothing is NA
+  a = a[!sd.zero & !isna]
+  b = b[!sd.zero & !isna]
+  mean = mean[!sd.zero & !isna]
+  sd = sd[!sd.zero & !isna]
+
+  # Rescale to standard normal distributions if sd is nonzero
   alpha = (a - mean) / sd
   beta = (b - mean) / sd
-
-  # Flip alpha and beta when both are positive (as above, but the mean is
-  # also recycled and flipped so that we don't have to flip back).
-  flip = (alpha > 0 & beta > 0)
+  scaled.mean = my_etruncnorm(alpha, beta)
+  # initialize array for scaled 2nd moments
+  scaled.2mom = rep(NA, length.out = length(alpha))
+
+  # point mass bc endpoints are equal
+  # 2nd moment is point-value squared
+  endpoints_equal = (alpha == beta)
+  scaled.2mom[endpoints_equal] = alpha[endpoints_equal] ^ 2
+  # keep track of which spots in scaled.2mom are already computed
+  computed = endpoints_equal
+
+  # force to satisfy β ≥ 0 and |α| ≤ |β|
+  # so either α ≤ 0 ≤ β or 0 < α ≤ β
+  flip = !computed & abs(alpha) > abs(beta)
   flip[is.na(flip)] = FALSE
   orig.alpha = alpha
   alpha[flip] = -beta[flip]
   beta[flip] = -orig.alpha[flip]
-  if (any(mean != 0)) {
-    mean = rep(mean, length.out = length(alpha))
-    mean[flip] = -mean[flip]
-  }
-
-  pnorm.diff = logscale_sub(pnorm(beta, log.p = TRUE), pnorm(alpha, log.p = TRUE))
-  alpha.frac = alpha * exp(dnorm(alpha, log = TRUE) - pnorm.diff)
-  beta.frac = beta * exp(dnorm(beta, log = TRUE) - pnorm.diff)
-
-  # Create a vector or matrix of 1's with NA's in the correct places.
-  if (is.matrix(alpha))
-    scaled.res = array(1, dim = dim(alpha))
-  else
-    scaled.res = rep(1, length.out = length(alpha))
-  is.na(scaled.res) = is.na(flip)
-
-  alpha.idx = is.finite(alpha)
-  scaled.res[alpha.idx] = 1 + alpha.frac[alpha.idx]
-  beta.idx = is.finite(beta)
-  scaled.res[beta.idx] = scaled.res[beta.idx] - beta.frac[beta.idx]
-
-  # Handle approximately equal endpoints.
-  endpts.equal = is.infinite(pnorm.diff)
-  scaled.res[endpts.equal] = (alpha[endpts.equal] + beta[endpts.equal])^2 / 4
-
-  # Check that the results make sense. When beta is negative,
-  # beta^2 + 2 * (1 + 1 / beta^2) is an upper bound for the expected squared
-  # value, and it is typically a good approximation as beta goes to -Inf.
-  # When the endpoints are very close to one another, the expected squared
-  # value of the uniform distribution on [alpha, beta] is a better upper
-  # bound (and approximation).
-  upper.bd1 = beta^2 + 2 * (1 + 1 / beta^2)
-  upper.bd2 = (alpha^2 + alpha * beta + beta^2) / 3
-  upper.bd = pmin(upper.bd1, upper.bd2)
-  bad.idx = (!is.na(beta) & beta < 0
-             & (scaled.res < beta^2 | scaled.res > upper.bd))
-  scaled.res[bad.idx] = upper.bd[bad.idx]
-
-  res = mean^2 + 2 * mean * sd * my_etruncnorm(alpha, beta) + sd^2 * scaled.res
+
+  # both endpoints infinite/untruncated normal distribution
+  # 2nd moment is 1
+  both_inf = !computed & is.infinite(alpha) & is.infinite(beta)
+  scaled.2mom[both_inf] = 1
+  computed = computed | both_inf
+
+  # truncated to [α,∞)
+  # 2nd moment simplifies to 1 + αϕ(α)/(1 - Φ(α))
+  beta_inf = !computed & is.infinite(beta)
+  scaled.2mom[beta_inf] = 1 + sqrt(2 / pi) * alpha[beta_inf] / Re(RcppFaddeeva::erfcx(alpha[beta_inf] / sqrt(2)))
+  computed = computed | beta_inf
+
+  # a ≤ 0 ≤ b
+  #catastrophic cancellation is less of an issue
+  alpha_negative = !computed & alpha <= 0
+  ea = sqrt(pi/2) * Re(RcppFaddeeva::erf(alpha[alpha_negative] / sqrt(2)))
+  eb = sqrt(pi/2) * Re(RcppFaddeeva::erf(beta[alpha_negative] / sqrt(2)))
+  fa = ea - alpha[alpha_negative] * exp(-alpha[alpha_negative]^2 / 2)
+  fb = eb - beta[alpha_negative] * exp(-beta[alpha_negative]^2 / 2)
+  scaled.2mom[alpha_negative] = (fb - fa) / (eb - ea)
+  computed = computed | alpha_negative
+
+  # 0 < a ≤ b
+  #strategically avoid catastrophic cancellation as much as possible
+  exdiff = exp((alpha[!computed] - beta[!computed])*(alpha[!computed] + beta[!computed])/2)
+  ea = sqrt(pi/2) * Re(RcppFaddeeva::erfcx(alpha[!computed] / sqrt(2)))
+  eb = sqrt(pi/2) * Re(RcppFaddeeva::erfcx(beta[!computed] / sqrt(2)))
+  fa = ea + alpha[!computed]
+  fb = eb + beta[!computed]
+  scaled.2mom[!computed] = (fa - fb * exdiff) / (ea - eb * exdiff)
+
+  # transform results back to nonstandard normal case
+  # μ^2 + σ^2 E(Z^2) + 2 μ σ E(Z)
+  # possible catastrophic cancellation because μ^2 + σ^2 E(Z^2) ≧ 0 while 2μσE(Z) has unknown sign
+  # If |μ| < σ , compute as μ^2 + σ(σ E(Z^2) + 2 μ E(Z))
+  m_sd = abs(mean) < sd
+  sub.res = rep(NA_real_, length.out = length(alpha))
+  sub.res[m_sd] = mean[m_sd]^2 + sd[m_sd]*(sd[m_sd] * scaled.2mom[m_sd] + 2 * mean[m_sd] * scaled.mean[m_sd])
+  # If σ ≦ |μ| , compute as μ(μ + 2 σ E(Z)) + σ^2 E(Z^2)
+  sub.res[!m_sd] = mean[!m_sd]*(mean[!m_sd] + 2 * sd[!m_sd] * scaled.mean[!m_sd]) + sd[!m_sd]^2 * scaled.2mom[!m_sd]
+  # TODO experiment with whether the above is a good idea or not...
 
-  # Handle zero sds.
-  if (any(sd == 0)) {
-    a = rep(a, length.out = length(res))
-    b = rep(b, length.out = length(res))
-    mean = rep(mean, length.out = length(res))
-
-    sd.zero = (sd == 0)
-    res[sd.zero & b <= mean] = b[sd.zero & b <= mean]^2
-    res[sd.zero & a >= mean] = a[sd.zero & a >= mean]^2
-    res[sd.zero & a < mean & b > mean] = mean[sd.zero & a < mean & b > mean]^2
+  # E(X^2) lies between the smallest and largest value of x^2 on [a, b];
+  # the smallest value is 0 whenever the interval contains 0.
+  lower.bd = ifelse(a <= 0 & b >= 0, 0, pmin(a^2, b^2))
+  upper.bd = pmax(a^2, b^2)
+  #throw error if results are far outside the plausible range
+  error_tol = 1
+  if (any(sub.res < lower.bd - error_tol | sub.res > upper.bd + error_tol, na.rm = TRUE)) {
+    stop("Computed second moment of truncated normal is outside plausible range")
   }
-
+  #silently correct small errors; a and b were subset above, so clamp before writing back
+  sub.res = pmin(pmax(sub.res, lower.bd), upper.bd)
+  res[!sd.zero & !isna] = sub.res
+
   return(res)
 }
 
@@ -203,30 +297,35 @@ my_e2truncnorm = function(a, b, mean = 0, sd = 1) {
 #' @seealso \code{\link{my_etruncnorm}}, \code{\link{my_e2truncnorm}}
 #'
 #' @export
-#'
 my_vtruncnorm = function(a, b, mean = 0, sd = 1) {
+  # based off of https://github.com/cossio/TruncatedNormal.jl
   do_truncnorm_argchecks(a, b)
 
+  #solve scaled problem
   alpha = (a - mean) / sd
   beta = (b - mean) / sd
 
-  scaled.res = my_e2truncnorm(alpha, beta) - my_etruncnorm(alpha, beta)^2
+  m1 = my_etruncnorm(alpha, beta)
+  m2 = sqrt(my_e2truncnorm(alpha, beta))
+  scaled.res = (m2 - m1) * (m1 + m2)
+
+  # Handle endpoints equal: a point mass has zero variance
+  isna = is.na(a) | is.na(b) | is.na(mean) | is.na(sd)
+  scaled.res[(alpha == beta) & sd != 0 & !isna] = 0
 
-  # When alpha and beta are large and share the same sign, this computation
-  # becomes unstable. A good approximation in this regime is 1 / beta^2
-  # (when alpha and beta are both negative). If my_e2truncnorm and
-  # my_etruncnorm are accurate to the eighth digit, then we can only trust
-  # results up to the second digit if beta^2 and 1 / beta^2 differ by an
-  # order of magnitude no more than 6.
-  smaller.endpt = pmin(abs(alpha), abs(beta))
-  bad.idx = (is.finite(smaller.endpt) & smaller.endpt > 30)
-  scaled.res[bad.idx] = pmin(1 / smaller.endpt[bad.idx]^2,
-                             (beta[bad.idx] - alpha[bad.idx])^2 / 12)
+  #transform back to unscaled
+  res = sd^2 * scaled.res
 
   # Handle zero sds.
-  scaled.res[is.nan(scaled.res)] = 0
+  res[sd == 0] = 0
 
-  res = sd^2 * scaled.res
+  #throw error if results are far outside the plausible range
+  error_tol = 1
+  if (any(res < -error_tol & !isna)) {
+    stop("Computed variance of truncated normal is outside plausible range")
+  }
+  #silently correct small errors
+  res[res < 0 & !isna] = 0
 
   return(res)
 }
@@ -237,20 +336,3 @@ do_truncnorm_argchecks = function(a, b) {
   if (any(b < a, na.rm = TRUE))
     stop("truncnorm functions require that a <= b.")
 }
-
-logscale_sub = function(logx, logy) {
-  # In rare cases, logx can become numerically less than logy. When this
-  # occurs, logx is adjusted and a warning is issued.
-  diff = logx - logy
-  if (any(diff < 0, na.rm = TRUE)) {
-    bad.idx = (diff < 0)
-    bad.idx[is.na(bad.idx)] = FALSE
-    logx[bad.idx] = logy[bad.idx]
-    warning("logscale_sub encountered negative value(s) of logx - logy (min: ",
-            formatC(min(diff[bad.idx]), format = "e", digits = 2), ")")
-  }
-
-  scale.by = logx
-  scale.by[is.infinite(scale.by)] = 0
-  return(log(exp(logx - scale.by) - exp(logy - scale.by)) + scale.by)
-}
diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp
index 612d70e..c04dc3a 100644
--- a/src/RcppExports.cpp
+++ b/src/RcppExports.cpp
@@ -5,6 +5,11 @@
 
 using namespace Rcpp;
 
+#ifdef RCPP_USE_GLOBAL_ROSTREAM
+Rcpp::Rostream<true>&  Rcpp::Rcout = Rcpp::Rcpp_cout_get();
+Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
+#endif
+
 // cxxMixSquarem
 List cxxMixSquarem(NumericMatrix matrix_lik, NumericVector prior, NumericVector pi_init, List control);
 RcppExport SEXP _ashr_cxxMixSquarem(SEXP matrix_likSEXP, SEXP priorSEXP, SEXP pi_initSEXP, SEXP controlSEXP) {
diff --git a/tests/testthat/test_myetruncnorm.R b/tests/testthat/test_myetruncnorm.R
index c5e2198..1568057 100644
--- a/tests/testthat/test_myetruncnorm.R
+++ b/tests/testthat/test_myetruncnorm.R
@@ -1,18 +1,19 @@
 context("my_etruncnorm")
 
 test_that("my_etruncnorm returns expected results", {
-  expect_equal(-100,my_etruncnorm(-Inf,-100,0,1),tolerance=0.01)
-  expect_equal(-100,my_etruncnorm(-Inf,-100,0,0))
-  expect_equal(30,my_etruncnorm(30,100,0,0))
-  real = c(-100,-100,30)
-  a=c(-Inf,-Inf,30)
-  b=c(-100,-100,100)
-  m=c(0,0,0)
-  sd=c(1,0,0)
+  a = c(-Inf,-Inf, 30,NA, 1, 1, 1,1,100,100,-100,-Inf, 0,-1, 1, -2, -2)
+  b = c(-100,-100,100, 1,NA, 1, 1,1,Inf,Inf, -30, Inf, Inf, 1, 2, 2, 3)
+  m = c( 0, 0, 0, 1, 1,NA, 1,1, 0, 0, 0, 5, 0, 0, 0,-5e9, 0)
+  sd = c( 1, 0, 0, 1, 1, 1,NA,1, 1, 0, 0, 2, 1, 1, 1, 2,1e4)
+  real = c(-100,-100, 30,NA,NA,NA,NA,1,100,100, -30, 5,0.80, 0,1.38, -2,0.5)
+  N = length(real)
+  for (idx in seq_len(N)){
+    expect_equal(real[idx],my_etruncnorm(a[idx],b[idx],m[idx],sd[idx]),tolerance=0.01)
+  }
   expect_equal(real,my_etruncnorm(a,b,m,sd),tolerance=0.01)
-  real = matrix(real,3,4)
-  m = matrix(m,3,4)
-  sd = matrix(sd,3,4)
+  real = matrix(real,N,4)
+  m = matrix(m,N,4)
+  sd = matrix(sd,N,4)
   expect_equal(real,my_etruncnorm(a,b,m,sd),tolerance=0.01)
   a=c(0,0)
   b=c(1,2)
@@ -25,23 +26,26 @@ test_that("my_etruncnorm returns expected results", {
   expect_equal(my_etruncnorm(0,9999,-2,3),my_etruncnorm(0,Inf,-2,3),tol=1e-3)
   expect_error(my_etruncnorm(0, 1:2, mean = 0, sd = 1))
   expect_error(my_etruncnorm(1, 0, mean = 0, sd = 1))
+
+  #TODO add test cases from pull request
 })
 
 context("my_vtruncnorm")
 
 test_that("my_vtruncnorm returns expected results", {
-  expect_equal(0, my_vtruncnorm(-Inf, -100), tolerance = 0.01)
-  expect_equal(0, my_vtruncnorm(-Inf, -100, sd = 0))
-  expect_equal(0, my_vtruncnorm(30, 100, sd = 0))
-  real = c(0, 0, 0)
-  a = c(-Inf, -Inf, 30)
-  b = c(-100, -100, 100)
-  m = c(0, 0, 0)
-  sd = c(1, 0, 0)
+  a = c(-Inf,-Inf, 30,NA, 1, 1, 1,1,100,100,-100,-Inf, 0, -1, 1, -2, -2)
+  b = c(-100,-100,100, 1,NA, 1, 1,1,Inf,Inf, -30, Inf, Inf, 1, 2, 2, 3)
+  m = c( 0, 0, 0, 1, 1,NA, 1,1, 0, 0, 0, 5, 0, 0, 0,-5e9, 0)
+  sd = c( 1, 0, 0, 1, 1, 1,NA,1, 1, 0, 0, 2, 2, 1, 1, 2, 1e4)
+  real = c( 0, 0, 0,NA,NA,NA,NA,0, 0, 0, 0, 4,1.45,0.29,0.07, 0,2.08)
+  N = length(real)
+  for (idx in seq_len(N)){
+    expect_equal(real[idx],my_vtruncnorm(a[idx],b[idx],m[idx],sd[idx]),tolerance=0.01)
+  }
   expect_equal(real, my_vtruncnorm(a, b, m, sd), tolerance = 0.01)
-  real = matrix(real, 3, 4)
-  m = matrix(m, 3, 4)
-  sd = matrix(sd, 3, 4)
+  real = matrix(real, N, 4)
+  m = matrix(m, N, 4)
+  sd = matrix(sd, N, 4)
   expect_equal(real, my_vtruncnorm(a, b, m, sd), tolerance = 0.01)
   a = c(0, 0)
   b = c(1, 2)