Truncnorm #129

Open
wants to merge 22 commits into master
330 changes: 204 additions & 126 deletions R/truncnorm.R
@@ -29,65 +29,110 @@
#' @seealso \code{\link{my_e2truncnorm}}, \code{\link{my_vtruncnorm}}
#'
#' @export
#'
my_etruncnorm = function(a, b, mean = 0, sd = 1) {
# Based on https://github.com/cossio/TruncatedNormal.jl.
do_truncnorm_argchecks(a, b)

# NAs and zero sds are handled separately below; the main computation
# assumes sd > 0.
# Initialize an array of the common (recycled) shape of the inputs to store
# the result.
common_shape = 0 * (a + b + mean + sd)
if (is.matrix(common_shape)) {
res = array(dim = dim(common_shape))
} else {
res = rep(NA, length.out = length(common_shape))
}

# Make sure input sizes all match.
a = rep(a, length.out = length(res))
b = rep(b, length.out = length(res))
mean = rep(mean, length.out = length(res))
sd = rep(sd, length.out = length(res))

# Handle NA and NaN inputs.
isna = is.na(a) | is.na(b) | is.na(mean) | is.na(sd)
res[isna] = NA

# Handle zero sds. Return the mean of the untruncated normal when it lies
# inside the interval [a, b]; otherwise, return the endpoint that is closer
# to the mean.
sd.zero = (sd == 0)
res[!isna & sd.zero & b <= mean] = b[!isna & sd.zero & b <= mean]
res[!isna & sd.zero & a >= mean] = a[!isna & sd.zero & a >= mean]
res[!isna & sd.zero & a < mean & b > mean] = mean[!isna & sd.zero & a < mean & b > mean]

# Restrict attention to entries where sd is nonzero and no input is NA.
a = a[!sd.zero & !isna]
b = b[!sd.zero & !isna]
mean = mean[!sd.zero & !isna]
sd = sd[!sd.zero & !isna]

# Rescale to standard normal distributions.
alpha = (a - mean) / sd
beta = (b - mean) / sd
# Initialize an array for the scaled means.
scaled.mean = rep(NA, length.out = length(alpha))

# Equal endpoints give a point mass, so the first moment is the point itself.
endpoints_equal = (alpha == beta)
scaled.mean[endpoints_equal] = alpha[endpoints_equal]
# Keep track of which entries of scaled.mean have already been computed.
computed = endpoints_equal

# Flip so that β ≥ 0 and |α| ≤ |β|, i.e., either α ≤ 0 ≤ β or 0 < α ≤ β.
flip = !computed & abs(alpha) > abs(beta)
flip[is.na(flip)] = FALSE
orig.alpha = alpha
alpha[flip] = -beta[flip]
beta[flip] = -orig.alpha[flip]
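# (The flip relies on the symmetry E[Z | α ≤ Z ≤ β] = -E[Z | -β ≤ Z ≤ -α];
# flipped entries are negated again before the results are returned.)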

# Both endpoints infinite: the distribution is an untruncated standard
# normal, so the scaled mean is 0.
both_inf = !computed & is.infinite(alpha) & is.infinite(beta)
scaled.mean[both_inf] = 0
computed = computed | both_inf

# Truncated to [α, ∞): the mean simplifies to ϕ(α)/(1 - Φ(α)).
beta_inf = !computed & is.infinite(beta)
scaled.mean[beta_inf] = sqrt(2/pi) / Re(erfcx(alpha[beta_inf] / sqrt(2)))
computed = computed | beta_inf
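# (This uses 1 - Φ(α) = erfc(α/√2)/2 and erfcx(x) = exp(x²)·erfc(x), so
# ϕ(α)/(1 - Φ(α)) = √(2/π)·exp(-α²/2)/erfc(α/√2) = √(2/π)/erfcx(α/√2),
# which avoids underflow for large α.)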

# α ≤ 0 ≤ β: catastrophic cancellation is less of an issue here.
alpha_negative = !computed & alpha <= 0
diff = (beta[alpha_negative] - alpha[alpha_negative]) * (alpha[alpha_negative] + beta[alpha_negative]) / 2
# √(2/π) * expm1(-Δ) * exp(-α²/2) / (erf(α/√2) - erf(β/√2)), with Δ = (β² - α²)/2
scaled.mean[alpha_negative] = sqrt(2/pi) * expm1(-diff) * exp(-alpha[alpha_negative]^2 / 2)
denom = Re(erf(alpha[alpha_negative] / sqrt(2))) - Re(erf(beta[alpha_negative] / sqrt(2)))
scaled.mean[alpha_negative] = scaled.mean[alpha_negative] / denom
computed = computed | alpha_negative

# 0 < α ≤ β
# Strategically avoid catastrophic cancellation as much as possible.
diff = (beta[!computed] - alpha[!computed]) * (alpha[!computed] + beta[!computed]) / 2
denom = exp(-diff) * Re(erfcx(beta[!computed] / sqrt(2))) - Re(erfcx(alpha[!computed] / sqrt(2)))
scaled.mean[!computed] = sqrt(2/pi) * expm1(-diff) / denom
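# (expm1(-diff) retains full precision when diff is tiny, i.e. for very
# narrow intervals where exp(-diff) - 1 would round to zero.)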

# Double-check that the results stay inside [α, β].
scaled.mean[scaled.mean > beta] = beta[scaled.mean > beta]
scaled.mean[scaled.mean < alpha] = alpha[scaled.mean < alpha]

# Flip back.
scaled.mean[flip] = -scaled.mean[flip]

# Transform back to the nonstandard normal case.
res[!sd.zero & !isna] = mean + sd * scaled.mean

# Throw an error if results are far outside the plausible range [a, b].
# Note that a and b were subset above, so compare against the matching
# entries of res.
error_tol = 1
idx = !sd.zero & !isna
if (any(res[idx] < a - error_tol) || any(res[idx] > b + error_tol)) {
stop("Computed expected value of truncated normal is outside plausible range")
}
# Silently correct small errors.
res[idx] = pmin(pmax(res[idx], a), b)

return(res)
}
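One quick sanity check for this function: compare against the textbook formula where that formula is stable, then probe a far-tail interval where it breaks down. A minimal sketch, assuming a standard normal (mean = 0, sd = 1); naive_etruncnorm is a hypothetical helper, not part of this PR:

# Naive closed form for E(Z | a <= Z <= b); stable only for moderate endpoints.
naive_etruncnorm = function(a, b) {
(dnorm(a) - dnorm(b)) / (pnorm(b) - pnorm(a))
}
all.equal(my_etruncnorm(-1, 2), naive_etruncnorm(-1, 2))  # TRUE
# Far right tail: pnorm(30) and pnorm(31) both round to 1 in double
# precision, so the naive form divides by zero, while the erfcx-based
# computation stays accurate.
naive_etruncnorm(30, 31)  # Inf
my_etruncnorm(30, 31)     # ~30.03, consistent with the 30 + 1/30 tail heuristic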
@@ -118,71 +163,116 @@ my_etruncnorm = function(a, b, mean = 0, sd = 1) {
#'
#' @export
#'
# erfcx() and erf() are provided by the RcppFaddeeva package.
library(RcppFaddeeva)

my_e2truncnorm = function(a, b, mean = 0, sd = 1) {
# Based on https://github.com/cossio/TruncatedNormal.jl.
do_truncnorm_argchecks(a, b)

# Initialize an array of the common (recycled) shape of the inputs to store
# the result.
common_shape = 0 * (a + b + mean + sd)
if (is.matrix(common_shape)) {
res = array(dim = dim(common_shape))
} else {
res = rep(NA, length.out = length(common_shape))
}

# Make sure input sizes all match.
a = rep(a, length.out = length(res))
b = rep(b, length.out = length(res))
mean = rep(mean, length.out = length(res))
sd = rep(sd, length.out = length(res))

# Handle NA and NaN inputs.
isna = is.na(a) | is.na(b) | is.na(mean) | is.na(sd)
res[isna] = NA

# Handle zero sds. Return the squared mean of the untruncated normal when
# the mean lies inside [a, b]; otherwise, return the square of the endpoint
# that is closer to the mean.
sd.zero = (sd == 0)
res[!isna & sd.zero & b <= mean] = b[!isna & sd.zero & b <= mean]^2
res[!isna & sd.zero & a >= mean] = a[!isna & sd.zero & a >= mean]^2
res[!isna & sd.zero & a < mean & b > mean] = mean[!isna & sd.zero & a < mean & b > mean]^2

# Restrict attention to entries where sd is nonzero and no input is NA.
a = a[!sd.zero & !isna]
b = b[!sd.zero & !isna]
mean = mean[!sd.zero & !isna]
sd = sd[!sd.zero & !isna]

# Rescale to standard normal distributions.
alpha = (a - mean) / sd
beta = (b - mean) / sd

# Initialize an array for the scaled second moments.
scaled.2mom = rep(NA, length.out = length(alpha))

# Equal endpoints give a point mass, so the second moment is the point
# squared.
endpoints_equal = (alpha == beta)
scaled.2mom[endpoints_equal] = alpha[endpoints_equal]^2
# Keep track of which entries of scaled.2mom have already been computed.
computed = endpoints_equal

# Flip so that β ≥ 0 and |α| ≤ |β|, i.e., either α ≤ 0 ≤ β or 0 < α ≤ β.
# The mean is also flipped so that the cross term below keeps the correct
# sign without flipping back.
flip = !computed & abs(alpha) > abs(beta)
flip[is.na(flip)] = FALSE
orig.alpha = alpha
alpha[flip] = -beta[flip]
beta[flip] = -orig.alpha[flip]
if (any(mean != 0)) {
mean = rep(mean, length.out = length(alpha))
mean[flip] = -mean[flip]
}

# The scaled first moment must be computed on the flipped endpoints so that
# its sign stays consistent with the flipped mean.
scaled.mean = my_etruncnorm(alpha, beta)

# Both endpoints infinite: untruncated standard normal, so the second
# moment is 1.
both_inf = !computed & is.infinite(alpha) & is.infinite(beta)
scaled.2mom[both_inf] = 1
computed = computed | both_inf

# Truncated to [α, ∞): the second moment simplifies to
# 1 + αϕ(α)/(1 - Φ(α)), via the same erfcx identity as in my_etruncnorm.
beta_inf = !computed & is.infinite(beta)
scaled.2mom[beta_inf] = 1 + sqrt(2 / pi) * alpha[beta_inf] / Re(erfcx(alpha[beta_inf] / sqrt(2)))
computed = computed | beta_inf

# α ≤ 0 ≤ β: catastrophic cancellation is less of an issue here.
alpha_negative = !computed & alpha <= 0
ea = sqrt(pi/2) * Re(erf(alpha[alpha_negative] / sqrt(2)))
eb = sqrt(pi/2) * Re(erf(beta[alpha_negative] / sqrt(2)))
fa = ea - alpha[alpha_negative] * exp(-alpha[alpha_negative]^2 / 2)
fb = eb - beta[alpha_negative] * exp(-beta[alpha_negative]^2 / 2)
scaled.2mom[alpha_negative] = (fb - fa) / (eb - ea)
computed = computed | alpha_negative
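# (Here ea = √(π/2)·erf(α/√2) = √(2π)·(Φ(α) - 1/2) and fa = ea - α·exp(-α²/2),
# so (fb - fa)/(eb - ea) = 1 - (βϕ(β) - αϕ(α))/(Φ(β) - Φ(α)), the standard
# second-moment formula; erf differences are stable when the interval
# straddles zero.)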

# 0 < α ≤ β
# Strategically avoid catastrophic cancellation as much as possible.
exdiff = exp((alpha[!computed] - beta[!computed])*(alpha[!computed] + beta[!computed])/2)
ea = sqrt(pi/2) * Re(erfcx(alpha[!computed] / sqrt(2)))
eb = sqrt(pi/2) * Re(erfcx(beta[!computed] / sqrt(2)))
fa = ea + alpha[!computed]
fb = eb + beta[!computed]
scaled.2mom[!computed] = (fa - fb * exdiff) / (ea - eb * exdiff)
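# (This is the same formula after multiplying numerator and denominator by
# √(2π)·exp(α²/2): Φ differences become erfcx differences, which remain
# representable even when α and β are far into the right tail.)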

# Transform the results back to the nonstandard normal case:
# E(X²) = μ² + 2μσE(Z) + σ²E(Z²).
# Catastrophic cancellation is possible because μ² + σ²E(Z²) ≥ 0 while
# 2μσE(Z) can have either sign, so group the terms by relative magnitude.
# If |μ| < σ, compute as μ² + σ(σE(Z²) + 2μE(Z)).
m_sd = abs(mean) < sd
res[!sd.zero & !isna][m_sd] = mean[m_sd]^2 + sd[m_sd] * (sd[m_sd] * scaled.2mom[m_sd] + 2 * mean[m_sd] * scaled.mean[m_sd])
# If σ ≤ |μ|, compute as μ(μ + 2σE(Z)) + σ²E(Z²).
res[!sd.zero & !isna][!m_sd] = mean[!m_sd] * (mean[!m_sd] + 2 * sd[!m_sd] * scaled.mean[!m_sd]) + sd[!m_sd]^2 * scaled.2mom[!m_sd]
# TODO: experiment with whether this grouping is worthwhile.

# Throw an error if results are far outside the plausible range. The second
# moment cannot exceed max(a², b²), and it cannot fall below min(a², b²)
# unless the interval contains zero. Note that a and b were subset above,
# so compare against the matching entries of res.
error_tol = 1
idx = !sd.zero & !isna
lower.bd = ifelse(a > 0 | b < 0, pmin(a^2, b^2), 0)
upper.bd = pmax(a^2, b^2)
if (any(res[idx] < lower.bd - error_tol) || any(res[idx] > upper.bd + error_tol)) {
stop("Computed second moment of truncated normal is outside plausible range")
}

# Silently correct small errors.
res[idx] = pmin(pmax(res[idx], lower.bd), upper.bd)

return(res)
}
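Since this function stitches together several closed forms, a cross-check against direct numerical integration is useful. A minimal sketch, assuming finite, moderate endpoints; second_moment_by_quadrature is a hypothetical helper built on stats::integrate, not part of this PR:

second_moment_by_quadrature = function(a, b, mean = 0, sd = 1) {
Z = pnorm(b, mean, sd) - pnorm(a, mean, sd)
integrate(function(x) x^2 * dnorm(x, mean, sd) / Z, lower = a, upper = b)$value
}
# TRUE up to quadrature tolerance:
all.equal(my_e2truncnorm(-1, 2, mean = 1, sd = 2),
second_moment_by_quadrature(-1, 2, mean = 1, sd = 2))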

@@ -203,30 +293,35 @@ my_e2truncnorm = function(a, b, mean = 0, sd = 1) {
#' @seealso \code{\link{my_etruncnorm}}, \code{\link{my_e2truncnorm}}
#'
#' @export
#'
my_vtruncnorm = function(a, b, mean = 0, sd = 1) {
# Based on https://github.com/cossio/TruncatedNormal.jl.
do_truncnorm_argchecks(a, b)

# Solve the scaled problem first.
alpha = (a - mean) / sd
beta = (b - mean) / sd

m1 = my_etruncnorm(alpha, beta)
m2 = sqrt(my_e2truncnorm(alpha, beta))
# Compute E(Z²) - E(Z)² in difference-of-squares form to limit catastrophic
# cancellation when the two moments are large and nearly equal.
scaled.res = (m2 - m1) * (m1 + m2)

# A point mass (equal endpoints) has zero variance.
isna = is.na(a) | is.na(b) | is.na(mean) | is.na(sd)
scaled.res[(alpha == beta) & sd != 0 & !isna] = 0

# Transform back to the unscaled problem.
res = sd^2 * scaled.res

# Handle zero sds: a point mass has zero variance.
res[sd == 0] = 0

# Throw an error if results are far outside the plausible range (a variance
# cannot be negative).
error_tol = 1
if (any(res < -error_tol & !isna)) {
stop("Computed variance of truncated normal is outside plausible range")
}
# Silently correct small negative errors.
res[res < 0 & !isna] = 0

return(res)
}
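The difference-of-squares factoring above matters most for far-tail intervals, where the scaled moments are large and the variance is tiny. A minimal sketch of that regime; the quoted values are approximate, and the 1/30² figure is just the usual one-sided tail heuristic:

my_etruncnorm(-31, -30)   # ~ -30.03
my_e2truncnorm(-31, -30)  # ~ 902
# E(Z²) - E(Z)² requires cancelling ~902 against ~902; the factored form
# organizes the same subtraction through m2 - m1 and m1 + m2 instead.
my_vtruncnorm(-31, -30)   # ~ 0.0011, on the order of 1/30²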
@@ -237,20 +332,3 @@ do_truncnorm_argchecks = function(a, b) {
if (any(b < a, na.rm = TRUE))
stop("truncnorm functions require that a <= b.")
}
