Skip to content

Commit

Permalink
Smooth estimated U for convex clustering
Browse files Browse the repository at this point in the history
- When performing convex clustering, it appears that
  the V's (active edge IDs / cluster identities)
  converge more rapidly than the associated rows
  of U, so we "smooth" the estimated U by replacing
  each row with the mean of all rows in its cluster
  at each iteration

  Not currently doing this for CBASS since I'm not
  sure how to get the joint row- and column-wise cluster
  info with our current independent post-processing design
  • Loading branch information
michaelweylandt committed Dec 19, 2018
1 parent 190b13a commit d062576
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 8 deletions.
3 changes: 2 additions & 1 deletion R/carp.R
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ CARP <- function(X,
v_zero_indices = carp.sol.path$v_zero_inds,
labels = labels,
dendrogram_scale = dendrogram.scale,
npcs = npcs)
npcs = npcs,
smooth_U = TRUE)

carp.fit <- list(
X = X.orig,
Expand Down
20 changes: 13 additions & 7 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,8 @@ ConvexClusteringPostProcess <- function(X,
labels,
dendrogram_scale,
npcs,
internal_transpose = FALSE){
internal_transpose = FALSE,
smooth_U = FALSE){

n <- NROW(X)
p <- NCOL(X)
Expand All @@ -342,13 +343,18 @@ ConvexClusteringPostProcess <- function(X,
gamma.path = gamma_path,
cardE = num_edges)

cluster_path[["clust.path"]] <- get_cluster_assignments(edge_matrix, cluster_path$sp.path.inter, n)
cluster_path[["clust.path.dups"]] <- duplicated(cluster_path[["clust.path"]], fromList = FALSE)
cluster_fusion_info <- get_cluster_assignments(edge_matrix, cluster_path$sp.path.inter, n)
cluster_path[["clust.path"]] <- cluster_fusion_info
cluster_path[["clust.path.dups"]] <- duplicated(cluster_fusion_info, fromList = FALSE)

U <- array(cluster_path$u.path.inter, dim = c(n, p, length(cluster_path[["clust.path.dups"]])))
rownames(U) <- rownames(X)
colnames(U) <- colnames(X)

if(smooth_U){
U <- smooth_u_clustering(U, cluster_fusion_info)
}

if (internal_transpose) {
## When looking at the column fusions from CBASS, we want U to be
## a 3 tensor of size p by n by K, each slice of which is really U^T
Expand All @@ -366,11 +372,11 @@ ConvexClusteringPostProcess <- function(X,
X_pca <- stats::prcomp(X, scale. = FALSE, center = FALSE)
rotation_matrix <- X_pca$rotation[, seq_len(npcs)]

membership_info <- tibble(Iter = rep(seq_along(cluster_path$clust.path), each = n),
Obs = rep(seq_len(n), times = length(cluster_path$clust.path)),
Cluster = as.vector(vapply(cluster_path$clust.path, function(x) x$membership, double(n))),
membership_info <- tibble(Iter = rep(seq_along(cluster_fusion_info), each = n),
Obs = rep(seq_len(n), times = length(cluster_fusion_info)),
Cluster = as.vector(vapply(cluster_fusion_info, function(x) x$membership, double(n))),
Gamma = rep(cluster_path$gamma.path.inter, each = n),
ObsLabel = rep(labels, times = length(cluster_path$clust.path))) %>%
ObsLabel = rep(labels, times = length(cluster_fusion_info))) %>%
group_by(.data$Iter) %>%
mutate(NCluster = n_distinct(.data$Cluster)) %>%
ungroup() %>%
Expand Down
63 changes: 63 additions & 0 deletions src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,69 @@ void check_weight_matrix(const Eigen::MatrixXd& weight_matrix){
}
}

// U-smoothing for convex clustering
//
// Given cluster memberships, replace rows of U which belong to the same cluster
// with their mutual mean.
//
// Arguments:
//   U_old             - really an N-by-P-by-Q array, passed as a NumericVector
//                       carrying a "dim" attribute of length 3
//   cluster_info_list - a list of length Q, as produced by
//                       get_cluster_assignments(). Element q is itself a list
//                       that is read *by position*:
//                         [0] 1-based cluster ID of each of the N rows
//                         [2] number of clusters in slice q
//                       (element [1], the cluster sizes, is not needed because
//                       the members of each cluster are counted directly)
//
// Returns an array with the same attributes as U_old in which, slice by slice,
// every row whose cluster ID lies in 1..n_clusters has been replaced by the mean
// of all rows sharing that ID. Rows with any other ID (including NA) are
// returned unchanged rather than left uninitialised.
//
// [[Rcpp::export]]
Rcpp::NumericVector smooth_u_clustering(Rcpp::NumericVector U_old, Rcpp::List cluster_info_list){
  Rcpp::IntegerVector U_dims = U_old.attr("dim");
  if(U_dims.size() != 3){
    ClustRVizLogger::error("U must be a three rank tensor.");
  }
  const int N = U_dims(0);
  const int P = U_dims(1);
  const int Q = U_dims(2);

  // Check length of cluster_info
  if(cluster_info_list.size() != Q){
    ClustRVizLogger::error("Dimensions of U and cluster_info do not match");
  }

  // Widen before multiplying so the slice offsets cannot overflow int
  const R_xlen_t slice_size = static_cast<R_xlen_t>(N) * P;

  // Deep copy: carries over dim / dimnames in one step and guarantees that
  // every entry of the result starts from a well-defined value
  Rcpp::NumericVector U = Rcpp::clone(U_old);

  for(int q = 0; q < Q; q++){
    Rcpp::List cluster_info = cluster_info_list[q];
    Rcpp::IntegerVector cluster_ids = cluster_info[0];
    const int n_clusters = Rcpp::as<int>(cluster_info[2]);

    // NOTE(review): as with the checks above, this assumes that
    // ClustRVizLogger::error aborts the call -- confirm
    if(cluster_ids.size() != N){
      ClustRVizLogger::error("Cluster membership must have one entry per row of U");
    }

    if(n_clusters <= 0){
      continue; // Nothing to smooth in this slice
    }

    // Views onto slice q of the input and of the output (no copies made)
    Eigen::Map<const Eigen::MatrixXd> U_old_slice(U_old.begin() + slice_size * q, N, P);
    Eigen::Map<Eigen::MatrixXd>       U_new_slice(U.begin()     + slice_size * q, N, P);

    // Single pass: accumulate per-cluster row sums and member counts
    Eigen::MatrixXd cluster_means  = Eigen::MatrixXd::Zero(n_clusters, P);
    Eigen::VectorXd cluster_counts = Eigen::VectorXd::Zero(n_clusters);

    for(int n = 0; n < N; n++){
      const int id = cluster_ids[n]; // Cluster IDs are 1-based (per R conventions)
      if(id >= 1 && id <= n_clusters){
        cluster_means.row(id - 1) += U_old_slice.row(n); // Subtract 1 to adjust to C++ indexing
        cluster_counts(id - 1)    += 1.0;
      }
    }

    // Turn sums into means; empty clusters are skipped to avoid 0 / 0
    for(int k = 0; k < n_clusters; k++){
      if(cluster_counts(k) > 0){
        cluster_means.row(k) /= cluster_counts(k);
      }
    }

    // Assign new mean where needed...
    for(int n = 0; n < N; n++){
      const int id = cluster_ids[n];
      if(id >= 1 && id <= n_clusters){
        U_new_slice.row(n) = cluster_means.row(id - 1);
      }
    }
  }

  return U;
}

// Tensor projection along the second mode
//
// Given a 3D tensor X in R^{n-by-p-by-q} (observations by features by iterations)
Expand Down
27 changes: 27 additions & 0 deletions tests/testthat/test_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,30 @@ test_that("connectedness check works", {
A <- eye(3); A[1,2] <- A[2,1] <- 1
expect_false(is_connected_adj_mat(A))
})

test_that("U smoothing for CARP works", {
  # smooth_u_clustering() should replace every row of each slice of U by the
  # mean of all rows assigned to the same cluster.
  set.seed(200)

  N <- 50
  P <- 30
  K <- 5

  U <- array(rnorm(N * P), c(N, P, 1))

  # Fake cluster assignments: shuffling a recycled 1..K sequence guarantees that
  # every cluster ID in 1..K occurs at least once, so the IDs are contiguous
  membership <- sample(rep_len(seq_len(K), N))
  # Element order (membership, sizes, count) matters: the C++ routine reads
  # the list by position
  cluster_info <- list(membership = membership,
                       csize = tabulate(membership, nbins = K),
                       no = length(unique(membership)))

  U_smoothed <- smooth_u_clustering(U, list(cluster_info))

  expect_equal(dim(U_smoothed), dim(U))

  U_slice <- U[, , 1]
  for (k in seq_len(K)) {
    in_cluster <- which(membership == k)
    # drop = FALSE keeps a matrix even if a cluster has a single member
    u_row_mean <- colMeans(U_slice[in_cluster, , drop = FALSE])
    for (i in in_cluster) {
      expect_equal(U_smoothed[i, , 1], u_row_mean)
    }
  }
})

0 comments on commit d062576

Please sign in to comment.