Skip to content

Commit

Permalink
Merge branch 'main' into hamming_dist_join
Browse files Browse the repository at this point in the history
  • Loading branch information
beniaminogreen authored Feb 14, 2024
2 parents 1b98213 + 090a56d commit c7ace44
Show file tree
Hide file tree
Showing 22 changed files with 685 additions and 676 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ src/rust/uncomment.sh
^CRAN-SUBMISSION$
^.*\.Rproj$
^\.Rproj\.user$
.lintr
4 changes: 4 additions & 0 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ on:

name: R-CMD-check

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref }}
cancel-in-progress: true

jobs:
R-CMD-check:
runs-on: ${{ matrix.config.os }}
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/pkgdown.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ on:

name: pkgdown

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref }}
cancel-in-progress: true

jobs:
pkgdown:
runs-on: ubuntu-latest
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ on:

name: test-coverage

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref }}
cancel-in-progress: true

jobs:
test-coverage:
runs-on: ubuntu-latest
Expand Down
7 changes: 7 additions & 0 deletions .lintr
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
linters: linters_with_defaults(
line_length_linter = NULL,
indentation_linter = NULL,
commas_linter = NULL,
infix_spaces_linter = NULL
) # see vignette("lintr")
encoding: "UTF-8"
51 changes: 26 additions & 25 deletions R/em_link.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,40 +35,41 @@
#'
#' @examples
#'
#' inv_logit <- function (x) {
#' exp(x)/(1+exp(x))
#'}
#'n <- 10^6
#'d <- 1:n %% 5 == 0
#'X <- cbind(
#' as.integer(ifelse(d, runif(n)<.8, runif(n)<.2)),
#' as.integer(ifelse(d, runif(n)<.9, runif(n)<.2)),
#' as.integer(ifelse(d, runif(n)<.7, runif(n)<.2)),
#' as.integer(ifelse(d, runif(n)<.6, runif(n)<.2)),
#' as.integer(ifelse(d, runif(n)<.5, runif(n)<.2)),
#' as.integer(ifelse(d, runif(n)<.1, runif(n)<.9)),
#' as.integer(ifelse(d, runif(n)<.1, runif(n)<.9)),
#' as.integer(ifelse(d, runif(n)<.8, runif(n)<.01))
#' )
#' inv_logit <- function(x) {
#' exp(x) / (1 + exp(x))
#' }
#' n <- 10^6
#' d <- 1:n %% 5 == 0
#' X <- cbind(
#' as.integer(ifelse(d, runif(n) < .8, runif(n) < .2)),
#' as.integer(ifelse(d, runif(n) < .9, runif(n) < .2)),
#' as.integer(ifelse(d, runif(n) < .7, runif(n) < .2)),
#' as.integer(ifelse(d, runif(n) < .6, runif(n) < .2)),
#' as.integer(ifelse(d, runif(n) < .5, runif(n) < .2)),
#' as.integer(ifelse(d, runif(n) < .1, runif(n) < .9)),
#' as.integer(ifelse(d, runif(n) < .1, runif(n) < .9)),
#' as.integer(ifelse(d, runif(n) < .8, runif(n) < .01))
#' )
#'
#' # initial guess at class assignments based on a hypothetical logistic
#' # regression. Should be based on domain knowledge, or a handful of hand-coded
#' # observations.
#'
#'x_sum <- rowSums(X)
#'g <- inv_logit((x_sum - mean(x_sum))/sd(x_sum))
#' x_sum <- rowSums(X)
#' g <- inv_logit((x_sum - mean(x_sum)) / sd(x_sum))
#'
#' out <- em_link(X, g,tol=.0001, max_iter = 100)
#' out <- em_link(X, g, tol = .0001, max_iter = 100)
#'
#' @export
em_link <- function (X,g, tol = 10^-6, max_iter = 10^3) {
em_link <- function(X, g, tol = 10^-6, max_iter = 10^3) {
stopifnot(
"There can be no NA's in X (but you can add NA as its own agreement level)" = !anyNA(X)
)

stopifnot("There can be no NA's in X (but you can add NA as its own agreement level)"
= !any(is.na(X)))
stopifnot(
"initial guesses must be valid probabilities (greater than 0 and less than 1)" = all(g < 1 & g > 0)
)

stopifnot("initial guesses must be valid probabilities (greater than 0 and less than 1)"
= all(g < 1 & g > 0))


rust_em_link(X,g, tol, max_iter)
rust_em_link(X, g, tol, max_iter)
}
156 changes: 78 additions & 78 deletions R/euclidean_join_core.R
Original file line number Diff line number Diff line change
@@ -1,95 +1,95 @@
multi_by_validate <- function(a,b, by) {
# first pass to handle dplyr::join_by() call
if (inherits(by, "dplyr_join_by")) {
if (any(by$condition != "==")) {
stop("Inequality joins are not supported.")
}
new_by <- by$y
names(new_by) <- by$x
by <- new_by
multi_by_validate <- function(a, b, by) {
# first pass to handle dplyr::join_by() call
if (inherits(by, "dplyr_join_by")) {
if (any(by$condition != "==")) {
stop("Inequality joins are not supported.")
}
new_by <- by$y
names(new_by) <- by$x
by <- new_by
}

if (is.null(by)) {
by_a <- intersect(names(a), names(b))
by_b <- intersect(names(a), names(b))
if (is.null(by)) {
by_a <- intersect(names(a), names(b))
by_b <- intersect(names(a), names(b))
} else {
if (!is.null(names(by))) {
by_a <- names(by)
by_b <- by
} else {
if (!is.null(names(by))) {
by_a <- names(by)
by_b <- by
} else {
by_a <- by
by_b <- by
}

stopifnot(by_a %in% names(a))
stopifnot(by_b %in% names(b))
by_a <- by
by_b <- by
}
return(list(
by_a,
by_b
))
stopifnot(by_a %in% names(a))
stopifnot(by_b %in% names(b))
}
return(list(
by_a,
by_b
))
}

#` @importFrom stats pnorm
euclidean_join_core <- function (a, b, by = NULL, n_bands = 30, band_width = 10, threshold=1.0, r=.5, progress = FALSE, mode="inner") {

stopifnot("'radius' must be greater than 0" = threshold > 0)

by <- multi_by_validate(a,b,by)
by_a <- by[[1]]
by_b <- by[[2]]
stopifnot("There should be no NA's in by_a[1]"=!any(is.na(dplyr::pull(a,by_a[1]))))
stopifnot("There should be no NA's in by_a[2]"=!any(is.na(dplyr::pull(a,by_a[2]))))
stopifnot("There should be no NA's in by_b[1]"=!any(is.na(dplyr::pull(b,by_b[1]))))
stopifnot("There should be no NA's in by_b[2]"=!any(is.na(dplyr::pull(b,by_b[2]))))
#' @importFrom stats pnorm
euclidean_join_core <- function(a, b, by = NULL, n_bands = 30, band_width = 10, threshold = 1.0, r = .5, progress = FALSE, mode = "inner") {
stopifnot("'radius' must be greater than 0" = threshold > 0)

thresh_prob <- euclidean_probability(threshold, n_bands, band_width, r)
if (thresh_prob < .95) {
str <- paste0("A pair of records at the threshold (", threshold,
") have only a ", round(thresh_prob*100), "% chance of being compared.\n",
"Please consider changing `n_bands` and `band_width`, and `r`.")

warning(str)
}
by <- multi_by_validate(a, b, by)
by_a <- by[[1]]
by_b <- by[[2]]
stopifnot("There should be no NA's in by_a[1]" = !anyNA(a[[by_a[1]]]))
stopifnot("There should be no NA's in by_a[2]" = !anyNA(a[[by_a[2]]]))
stopifnot("There should be no NA's in by_b[1]" = !anyNA(b[[by_b[1]]]))
stopifnot("There should be no NA's in by_b[2]" = !anyNA(b[[by_b[2]]]))

match_table <- rust_p_norm_join(
a_mat = as.matrix(dplyr::select(a, dplyr::all_of(by_a))),
b_mat = as.matrix(dplyr::select(b, dplyr::all_of(by_b))),
radius = threshold,
band_width = band_width,
n_bands = n_bands,
r = r,
progress = progress,
seed = round(runif(1,0,2^32))
thresh_prob <- euclidean_probability(threshold, n_bands, band_width, r)
if (thresh_prob < .95) {
str <- paste0(
"A pair of records at the threshold (", threshold,
") have only a ", round(thresh_prob * 100), "% chance of being compared.\n",
"Please consider changing `n_bands` and `band_width`, and `r`."
)

names_in_both <- intersect(names(a), names(b))
warning(str)
}

names(a)[names(a) %in% names_in_both] <-
paste0(names(a)[names(a) %in% names_in_both], ".x")
names(b)[names(b) %in% names_in_both] <-
paste0(names(b)[names(b) %in% names_in_both], ".y")
match_table <- rust_p_norm_join(
a_mat = as.matrix(dplyr::select(a, dplyr::all_of(by_a))),
b_mat = as.matrix(dplyr::select(b, dplyr::all_of(by_b))),
radius = threshold,
band_width = band_width,
n_bands = n_bands,
r = r,
progress = progress,
seed = round(runif(1, 0, 2^32))
)

matches <- dplyr::bind_cols(a[match_table[, 1], ], b[match_table[, 2], ])
names_in_both <- intersect(names(a), names(b))

# No need to look for rows that don't match
if (mode == "inner") {
return(matches)
}
names(a)[names(a) %in% names_in_both] <-
paste0(names(a)[names(a) %in% names_in_both], ".x")
names(b)[names(b) %in% names_in_both] <-
paste0(names(b)[names(b) %in% names_in_both], ".y")

not_matched_a <- ! seq(nrow(a)) %in% match_table[,1]
not_matched_b <- ! seq(nrow(b)) %in% match_table[,2]
matches <- dplyr::bind_cols(a[match_table[, 1], ], b[match_table[, 2], ])

if (mode == "left") {
matches <- dplyr::bind_rows(matches,a[not_matched_a,])
} else if (mode == "right") {
matches <- dplyr::bind_rows(matches,b[not_matched_b,])
} else if (mode == "full") {
matches <- dplyr::bind_rows(matches,a[not_matched_a,],b[not_matched_b,])
} else if (mode == "anti") {
matches <- dplyr::bind_rows(a[not_matched_a,], b[not_matched_b,])
} else {
stop("Invalid Mode Selected!")
}
# No need to look for rows that don't match
if (mode == "inner") {
return(matches)
}

not_matched_a <- !seq(nrow(a)) %in% match_table[, 1]
not_matched_b <- !seq(nrow(b)) %in% match_table[, 2]

if (mode == "left") {
matches <- dplyr::bind_rows(matches, a[not_matched_a, ])
} else if (mode == "right") {
matches <- dplyr::bind_rows(matches, b[not_matched_b, ])
} else if (mode == "full") {
matches <- dplyr::bind_rows(matches, a[not_matched_a, ], b[not_matched_b, ])
} else if (mode == "anti") {
matches <- dplyr::bind_rows(a[not_matched_a, ], b[not_matched_b, ])
} else {
stop("Invalid Mode Selected!")
}
return(matches)
}
Loading

0 comments on commit c7ace44

Please sign in to comment.