Skip to content

Commit

Permalink
Use styler::style_pkg()
Browse files Browse the repository at this point in the history
  • Loading branch information
etiennebacher committed Feb 14, 2024
1 parent c46c95f commit 210817e
Show file tree
Hide file tree
Showing 18 changed files with 680 additions and 689 deletions.
51 changes: 26 additions & 25 deletions R/em_link.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,40 +35,41 @@
#'
#' @examples
#'
#' inv_logit <- function (x) {
#' exp(x)/(1+exp(x))
#'}
#'n <- 10^6
#'d <- 1:n %% 5 == 0
#'X <- cbind(
#' as.integer(ifelse(d, runif(n)<.8, runif(n)<.2)),
#' as.integer(ifelse(d, runif(n)<.9, runif(n)<.2)),
#' as.integer(ifelse(d, runif(n)<.7, runif(n)<.2)),
#' as.integer(ifelse(d, runif(n)<.6, runif(n)<.2)),
#' as.integer(ifelse(d, runif(n)<.5, runif(n)<.2)),
#' as.integer(ifelse(d, runif(n)<.1, runif(n)<.9)),
#' as.integer(ifelse(d, runif(n)<.1, runif(n)<.9)),
#' as.integer(ifelse(d, runif(n)<.8, runif(n)<.01))
#' )
#' inv_logit <- function(x) {
#' exp(x) / (1 + exp(x))
#' }
#' n <- 10^6
#' d <- 1:n %% 5 == 0
#' X <- cbind(
#' as.integer(ifelse(d, runif(n) < .8, runif(n) < .2)),
#' as.integer(ifelse(d, runif(n) < .9, runif(n) < .2)),
#' as.integer(ifelse(d, runif(n) < .7, runif(n) < .2)),
#' as.integer(ifelse(d, runif(n) < .6, runif(n) < .2)),
#' as.integer(ifelse(d, runif(n) < .5, runif(n) < .2)),
#' as.integer(ifelse(d, runif(n) < .1, runif(n) < .9)),
#' as.integer(ifelse(d, runif(n) < .1, runif(n) < .9)),
#' as.integer(ifelse(d, runif(n) < .8, runif(n) < .01))
#' )
#'
#' # initial guess at class assignments based on a hypothetical logistic
#' # regression. Should be based on domain knowledge, or a handful of hand-coded
#' # observations.
#'
#'x_sum <- rowSums(X)
#'g <- inv_logit((x_sum - mean(x_sum))/sd(x_sum))
#' x_sum <- rowSums(X)
#' g <- inv_logit((x_sum - mean(x_sum)) / sd(x_sum))
#'
#' out <- em_link(X, g,tol=.0001, max_iter = 100)
#' out <- em_link(X, g, tol = .0001, max_iter = 100)
#'
#' @export
em_link <- function(X, g, tol = 10^-6, max_iter = 10^3) {
  # Both inputs are validated here, before anything is handed to
  # rust_em_link().
  stopifnot(
    "There can be no NA's in X (but you can add NA as its own agreement level)" =
      !anyNA(X)
  )

  # Every initial guess must lie strictly inside (0, 1). An NA in `g` makes
  # the condition NA, which stopifnot() also rejects.
  stopifnot(
    "initial guesses must be valid probabilities (greater than 0 and less than 1)" =
      all(g < 1 & g > 0)
  )

  rust_em_link(X, g, tol, max_iter)
}
157 changes: 79 additions & 78 deletions R/euclidean_join_core.R
Original file line number Diff line number Diff line change
@@ -1,95 +1,96 @@
#' Resolve a join specification into per-table column names
#'
#' @param a,b Objects whose `names()` are the candidate join columns.
#' @param by `NULL` (use every name common to `a` and `b`), a character
#'   vector (same names in both tables), a named character vector (names
#'   refer to `a`, values to `b`), or a `dplyr_join_by` object containing
#'   only `==` conditions.
#' @return An unnamed list of length two: the join columns for `a`, then the
#'   join columns for `b`.
#' @noRd
multi_by_validate <- function(a, b, by) {
  # first pass to handle dplyr::join_by() call: flatten it to the
  # named-vector form (names = columns of `a`, values = columns of `b`)
  if (inherits(by, "dplyr_join_by")) {
    if (any(by$condition != "==")) {
      stop("Inequality joins are not supported.")
    }
    new_by <- by$y
    names(new_by) <- by$x
    by <- new_by
  }

  # No specification: join on every column name the two tables share.
  if (is.null(by)) {
    shared <- intersect(names(a), names(b))
    return(list(shared, shared))
  }

  # `by_b` deliberately keeps any names carried by `by`.
  by_a <- if (is.null(names(by))) by else names(by)
  by_b <- by

  stopifnot(
    "every `by` column for `a` must be a column of `a`" =
      all(by_a %in% names(a))
  )
  stopifnot(
    "every `by` column for `b` must be a column of `b`" =
      all(by_b %in% names(b))
  )

  list(by_a, by_b)
}

#' Core implementation shared by the Euclidean-distance joins
#'
#' Resolves the join columns with `multi_by_validate()`, obtains the matching
#' row pairs from `rust_p_norm_join()`, and assembles the output table for the
#' requested `mode`.
#'
#' @param a,b Tables to join. Column names present in both are suffixed with
#'   `.x` (for `a`) and `.y` (for `b`) in the output.
#' @param by Join specification, resolved by `multi_by_validate()`. The first
#'   two resolved columns of each table are checked for `NA`s.
#' @param n_bands,band_width,r Passed unchanged to `euclidean_probability()`
#'   and `rust_p_norm_join()`.
#' @param threshold Positive number; passed to `rust_p_norm_join()` as
#'   `radius`.
#' @param progress Passed unchanged to `rust_p_norm_join()`.
#' @param mode One of `"inner"`, `"left"`, `"right"`, `"full"` or `"anti"`.
#' @return The matched rows of `a` and `b` bound column-wise. For `"left"`,
#'   `"right"` and `"full"` the unmatched rows of the relevant table(s) are
#'   appended below; `"anti"` returns only the unmatched rows of both tables.
#' @importFrom stats pnorm
#' @noRd
euclidean_join_core <- function(a, b, by = NULL, n_bands = 30,
                                band_width = 10, threshold = 1.0, r = .5,
                                progress = FALSE, mode = "inner") {
  # NOTE(review): the message says 'radius' but the argument is `threshold`;
  # the text is left unchanged in case callers or tests match on it.
  stopifnot("'radius' must be greater than 0" = threshold > 0)

  # Reject an unknown mode up front, before the join is computed.
  valid_modes <- c("inner", "left", "right", "full", "anti")
  if (length(mode) != 1 || !mode %in% valid_modes) {
    stop("Invalid Mode Selected!")
  }

  by <- multi_by_validate(a, b, by)
  by_a <- by[[1]]
  by_b <- by[[2]]
  stopifnot("There should be no NA's in by_a[1]" = !any(is.na(dplyr::pull(a, by_a[1]))))
  stopifnot("There should be no NA's in by_a[2]" = !any(is.na(dplyr::pull(a, by_a[2]))))
  stopifnot("There should be no NA's in by_b[1]" = !any(is.na(dplyr::pull(b, by_b[1]))))
  stopifnot("There should be no NA's in by_b[2]" = !any(is.na(dplyr::pull(b, by_b[2]))))

  # Warn when a pair exactly at the threshold has under a 95% chance of
  # being compared with the current settings.
  thresh_prob <- euclidean_probability(threshold, n_bands, band_width, r)
  if (thresh_prob < .95) {
    str <- paste0(
      "A pair of records at the threshold (", threshold,
      ") have only a ", round(thresh_prob * 100), "% chance of being compared.\n",
      "Please consider changing `n_bands` and `band_width`, and `r`."
    )

    warning(str)
  }

  # The first column of the result is used below as row indices into `a`,
  # the second as row indices into `b`.
  match_table <- rust_p_norm_join(
    a_mat = as.matrix(dplyr::select(a, dplyr::all_of(by_a))),
    b_mat = as.matrix(dplyr::select(b, dplyr::all_of(by_b))),
    radius = threshold,
    band_width = band_width,
    n_bands = n_bands,
    r = r,
    progress = progress,
    seed = round(stats::runif(1, 0, 2^32))
  )

  # Disambiguate columns that exist in both inputs.
  names_in_both <- intersect(names(a), names(b))
  shared_in_a <- names(a) %in% names_in_both
  shared_in_b <- names(b) %in% names_in_both
  names(a)[shared_in_a] <- paste0(names(a)[shared_in_a], ".x")
  names(b)[shared_in_b] <- paste0(names(b)[shared_in_b], ".y")

  matches <- dplyr::bind_cols(a[match_table[, 1], ], b[match_table[, 2], ])

  # No need to look for rows that don't match
  if (mode == "inner") {
    return(matches)
  }

  # seq_len() rather than seq(): seq(0) is c(1, 0), which would select
  # phantom rows from a zero-row table.
  not_matched_a <- !(seq_len(nrow(a)) %in% match_table[, 1])
  not_matched_b <- !(seq_len(nrow(b)) %in% match_table[, 2])

  switch(mode,
    left = dplyr::bind_rows(matches, a[not_matched_a, ]),
    right = dplyr::bind_rows(matches, b[not_matched_b, ]),
    full = dplyr::bind_rows(matches, a[not_matched_a, ], b[not_matched_b, ]),
    anti = dplyr::bind_rows(a[not_matched_a, ], b[not_matched_b, ])
  )
}
Loading

0 comments on commit 210817e

Please sign in to comment.