Skip to content

Commit

Permalink
refactor: use invoke for predict calls and replace do.call for invoke…
Browse files Browse the repository at this point in the history
… where relevant (#68)

* refactor: use invoke for predict calls and replace do.call for invoke where relevant

* fix: set dbscan and co to density instead or partitioning

* refactor: replace custom warningCondition to warningf and make rename check_ to assert_ for checkmate style behaviour

* refactor: remove unnecessary test_true

* refactor: simplify pars check for PAM

* refactor: always use param_set$get_values() to access param_set values
  • Loading branch information
m-muecke authored May 2, 2024
1 parent 67248b1 commit 9fbe815
Show file tree
Hide file tree
Showing 22 changed files with 72 additions and 64 deletions.
3 changes: 2 additions & 1 deletion R/LearnerClustAffinityPropagation.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ LearnerClustAP = R6Class("LearnerClustAP",
},

.predict = function(task) {
sim_func = self$param_set$values$s
pv = self$param_set$get_values()
sim_func = pv$s
exemplar_data = attributes(self$model)$exemplar_data

d = task$data()
Expand Down
13 changes: 9 additions & 4 deletions R/LearnerClustAgnes.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,22 @@ LearnerClustAgnes = R6Class("LearnerClustAgnes",
),
private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
m = invoke(cluster::agnes, x = task$data(), diss = FALSE, .args = pv)
pv = self$param_set$get_values()
m = invoke(cluster::agnes,
x = task$data(),
diss = FALSE,
.args = remove_named(pv, "k")
)
if (self$save_assignments) {
self$assignments = stats::cutree(m, self$param_set$values$k)
self$assignments = stats::cutree(m, pv$k)
}

return(m)
},

.predict = function(task) {
if (self$param_set$values$k > task$nrow) {
pv = self$param_set$get_values(tags = "predict")
if (pv$k > task$nrow) {
stopf("`k` needs to be between 1 and %i", task$nrow)
}

Expand Down
8 changes: 4 additions & 4 deletions R/LearnerClustCMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ LearnerClustCMeans = R6Class("LearnerClustCMeans",
),
private = list(
.train = function(task) {
check_centers_param(self$param_set$values$centers, task, test_data_frame, "centers")

pv = self$param_set$get_values(tags = "train")
assert_centers_param(pv$centers, task, test_data_frame, "centers")

m = invoke(e1071::cmeans, x = task$data(), .args = pv, .opts = allow_partial_matching)
if (self$save_assignments) {
self$assignments = m$cluster
Expand All @@ -74,8 +74,8 @@ LearnerClustCMeans = R6Class("LearnerClustCMeans",
},

.predict = function(task) {
partition = unclass(cl_predict(self$model, newdata = task$data(), type = "class_ids"))
prob = unclass(cl_predict(self$model, newdata = task$data(), type = "memberships"))
partition = unclass(invoke(cl_predict, self$model, newdata = task$data(), type = "class_ids"))
prob = unclass(invoke(cl_predict, self$model, newdata = task$data(), type = "memberships"))
colnames(prob) = seq_len(ncol(prob))

PredictionClust$new(task = task, partition = partition, prob = prob)
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClustCobweb.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ LearnerClustCobweb = R6Class("LearnerClustCobweb",
private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
ctrl = do.call(RWeka::Weka_control, pv)
ctrl = invoke(RWeka::Weka_control, .args = pv)
m = invoke(RWeka::Cobweb, x = task$data(), control = ctrl)
if (self$save_assignments) {
self$assignments = unname(m$class_ids + 1L)
Expand All @@ -53,7 +53,7 @@ LearnerClustCobweb = R6Class("LearnerClustCobweb",
},

.predict = function(task) {
partition = predict(self$model, newdata = task$data(), type = "class") + 1L
partition = invoke(predict, self$model, newdata = task$data(), type = "class") + 1L
PredictionClust$new(task = task, partition = partition)
}
)
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClustDBSCAN.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ LearnerClustDBSCAN = R6Class("LearnerClustDBSCAN",
feature_types = c("logical", "integer", "numeric"),
predict_types = "partition",
param_set = param_set,
properties = c("partitional", "exclusive", "complete"),
properties = c("density", "exclusive", "complete"),
packages = "dbscan",
man = "mlr3cluster::mlr_learners_clust.dbscan",
label = "Density-Based Clustering"
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClustDBSCANfpc.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ LearnerClustDBSCANfpc = R6Class("LearnerClustDBSCANfpc",
feature_types = c("logical", "integer", "numeric"),
predict_types = "partition",
param_set = param_set,
properties = c("partitional", "exclusive", "complete"),
properties = c("density", "exclusive", "complete"),
man = "mlr3cluster::mlr_learners_clust.dbscan_fpc",
label = "Density-Based Clustering with fpc"
)
Expand Down
13 changes: 9 additions & 4 deletions R/LearnerClustDiana.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,22 @@ LearnerClustDiana = R6Class("LearnerClustDiana",
),
private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
m = invoke(cluster::diana, x = task$data(), diss = FALSE, .args = pv)
pv = self$param_set$get_values()
m = invoke(cluster::diana,
x = task$data(),
diss = FALSE,
.args = remove_named(pv, "k")
)
if (self$save_assignments) {
self$assignments = stats::cutree(m, self$param_set$values$k)
self$assignments = stats::cutree(m, pv$k)
}

return(m)
},

.predict = function(task) {
if (test_true(self$param_set$values$k > task$nrow)) {
pv = self$param_set$get_values(tags = "predict")
if (pv$k > task$nrow) {
stopf("`k` needs to be between 1 and %s", task$nrow)
}

Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClustEM.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ LearnerClustEM = R6Class("LearnerClustEM",
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = chartr("_", "-", names(pv))
ctrl = do.call(RWeka::Weka_control, pv)
ctrl = invoke(RWeka::Weka_control, .args = pv)
m = invoke(RWeka::make_Weka_clusterer("weka/clusterers/EM"), x = task$data(), control = ctrl)
if (self$save_assignments) {
self$assignments = unname(m$class_ids + 1L)
Expand All @@ -64,7 +64,7 @@ LearnerClustEM = R6Class("LearnerClustEM",
},

.predict = function(task) {
partition = predict(self$model, newdata = task$data(), type = "class") + 1L
partition = invoke(predict, self$model, newdata = task$data(), type = "class") + 1L
PredictionClust$new(task = task, partition = partition)
}
)
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClustFarthestFirst.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ LearnerClustFarthestFirst = R6Class("LearnerClustFF",
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = chartr("_", "-", names(pv))
ctrl = do.call(RWeka::Weka_control, pv)
ctrl = invoke(RWeka::Weka_control, .args = pv)
m = invoke(RWeka::FarthestFirst, x = task$data(), control = ctrl)
if (self$save_assignments) {
self$assignments = unname(m$class_ids + 1L)
Expand All @@ -54,7 +54,7 @@ LearnerClustFarthestFirst = R6Class("LearnerClustFF",
},

.predict = function(task) {
partition = predict(self$model, newdata = task$data(), type = "class") + 1L
partition = invoke(predict, self$model, newdata = task$data(), type = "class") + 1L
PredictionClust$new(task = task, partition = partition)
}
)
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClustHDBSCAN.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ LearnerClustHDBSCAN = R6Class("LearnerClustHDBSCAN",
feature_types = c("logical", "integer", "numeric"),
predict_types = "partition",
param_set = param_set,
properties = c("partitional", "exclusive", "complete"),
properties = c("density", "exclusive", "complete"),
packages = "dbscan",
man = "mlr3cluster::mlr_learners_clust.hdbscan",
label = "HDBSCAN Clustering"
Expand Down
17 changes: 10 additions & 7 deletions R/LearnerClustHclust.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,26 @@ LearnerClustHclust = R6Class("LearnerClustHclust",
),
private = list(
.train = function(task) {
d = self$param_set$values$distmethod
dist_arg = self$param_set$get_values(tags = c("train", "dist"))
pv = self$param_set$get_values()
dist = invoke(stats::dist,
x = task$data(),
method = ifelse(is.null(d), "euclidean", d), .args = dist_arg
method = pv$d %??% "euclidean",
.args = self$param_set$get_values(tags = c("train", "dist"))
)
m = invoke(stats::hclust,
d = dist,
.args = self$param_set$get_values(tags = c("train", "hclust"))
)
pv = self$param_set$get_values(tags = c("train", "hclust"))
m = invoke(stats::hclust, d = dist, .args = pv)
if (self$save_assignments) {
self$assignments = stats::cutree(m, self$param_set$values$k)
self$assignments = stats::cutree(m, pv$k)
}

return(m)
},

.predict = function(task) {
if (self$param_set$values$k > task$nrow) {
pv = self$param_set$get_values(tags = "predict")
if (pv$k > task$nrow) {
stopf("`k` needs to be between 1 and %i", task$nrow)
}

Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClustKKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ LearnerClustKKMeans = R6Class("LearnerClustKKMeans",
),
private = list(
.train = function(task) {
check_centers_param(self$param_set$values$centers, task, test_data_frame, "centers")

pv = self$param_set$get_values(tags = "train")
assert_centers_param(pv$centers, task, test_data_frame, "centers")

m = invoke(kernlab::kkmeans, x = as.matrix(task$data()), .args = pv)
if (self$save_assignments) {
self$assignments = m[seq_along(m)]
Expand Down
8 changes: 4 additions & 4 deletions R/LearnerClustKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ LearnerClustKMeans = R6Class("LearnerClustKMeans",

private = list(
.train = function(task) {
if ("nstart" %in% names(self$param_set$values) && !test_int(self$param_set$values$centers)) {
pv = self$param_set$get_values(tags = "train")
if (!is.null(pv$nstart) && !test_int(pv$centers)) {
warningf("`nstart` parameter is only relevant when `centers` is integer.")
}

check_centers_param(self$param_set$values$centers, task, test_data_frame, "centers")
assert_centers_param(pv$centers, task, test_data_frame, "centers")

pv = self$param_set$get_values(tags = "train")
m = invoke(stats::kmeans, x = task$data(), .args = pv)
if (self$save_assignments) {
self$assignments = m$cluster
Expand All @@ -68,7 +68,7 @@ LearnerClustKMeans = R6Class("LearnerClustKMeans",
},

.predict = function(task) {
partition = unclass(cl_predict(self$model, newdata = task$data(), type = "class_ids"))
partition = unclass(invoke(cl_predict, self$model, newdata = task$data(), type = "class_ids"))
PredictionClust$new(task = task, partition = partition)
}
)
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClustMclust.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ LearnerClustMclust = R6Class("LearnerClustMclust",
},

.predict = function(task) {
predictions = predict(self$model, newdata = task$data())
predictions = invoke(predict, self$model, newdata = task$data())
partition = as.integer(predictions$classification)
prob = predictions$z
PredictionClust$new(task = task, partition = partition, prob = prob)
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClustMeanShift.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ LearnerClustMeanShift = R6Class("LearnerClustMeanShift",
),
private = list(
.train = function(task) {
if (!is.null(self$param_set$values$subset) && length(self$param_set$values$subset) > task$nrow) {
pv = self$param_set$get_values(tags = "train")
if (!is.null(pv$subset) && length(pv$subset) > task$nrow) {
stopf("`subset` length must be less than or equal to number of observations in task")
}

pv = self$param_set$get_values(tags = "train")
m = invoke(LPCM::ms, X = task$data(), .args = pv)
if (self$save_assignments) {
self$assignments = m$cluster.label
Expand Down
13 changes: 6 additions & 7 deletions R/LearnerClustMiniBatchKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,15 @@ LearnerClustMiniBatchKMeans = R6Class("LearnerClustMiniBatchKMeans",
),
private = list(
.train = function(task) {
check_centers_param(self$param_set$values$CENTROIDS, task, test_matrix, "CENTROIDS")
if (test_matrix(self$param_set$values$CENTROIDS) &&
nrow(self$param_set$values$CENTROIDS) != self$param_set$values$clusters) {
pv = self$param_set$get_values(tags = "train")
assert_centers_param(pv$CENTROIDS, task, test_matrix, "CENTROIDS")
if (test_matrix(pv$CENTROIDS) && nrow(pv$CENTROIDS) != pv$clusters) {
stopf("`CENTROIDS` must have same number of rows as `clusters`")
}

pv = self$param_set$get_values(tags = "train")
m = invoke(ClusterR::MiniBatchKmeans, data = task$data(), .args = pv)
if (self$save_assignments) {
self$assignments = unclass(ClusterR::predict_MBatchKMeans(
self$assignments = unclass(invoke(ClusterR::predict_MBatchKMeans,
data = task$data(),
CENTROIDS = m$centroids,
fuzzy = FALSE
Expand All @@ -82,15 +81,15 @@ LearnerClustMiniBatchKMeans = R6Class("LearnerClustMiniBatchKMeans",

.predict = function(task) {
if (self$predict_type == "partition") {
partition = unclass(ClusterR::predict_MBatchKMeans(
partition = unclass(invoke(ClusterR::predict_MBatchKMeans,
data = task$data(),
CENTROIDS = self$model$centroids,
fuzzy = FALSE
))
partition = as.integer(partition)
pred = PredictionClust$new(task = task, partition = partition)
} else if (self$predict_type == "prob") {
partition = unclass(ClusterR::predict_MBatchKMeans(
partition = unclass(invoke(ClusterR::predict_MBatchKMeans,
data = task$data(),
CENTROIDS = self$model$centroids,
fuzzy = TRUE
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClustOPTICS.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ LearnerClustOPTICS = R6Class("LearnerClustOPTICS",
feature_types = c("logical", "integer", "numeric"),
predict_types = "partition",
param_set = param_set,
properties = c("partitional", "exclusive", "complete"),
properties = c("density", "exclusive", "complete"),
packages = "dbscan",
man = "mlr3cluster::mlr_learners_clust.optics",
label = "OPTICS Clustering"
Expand Down
16 changes: 6 additions & 10 deletions R/LearnerClustPAM.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,17 @@ LearnerClustPAM = R6Class("LearnerClustPAM",
),
private = list(
.train = function(task) {
if (!is.null(self$param_set$values$medoids)) {
if (test_true(length(self$param_set$values$medoids) != self$param_set$values$k)) {
pv = self$param_set$get_values(tags = "train")
if (!is.null(pv$medoids)) {
if (length(pv$medoids) != pv$k) {
stopf("number of `medoids`' needs to match `k`!")
}
r = map_lgl(self$param_set$values$medoids, function(i) {
test_true(i <= task$nrow) && test_true(i >= 1L)
})
if (sum(r) != self$param_set$values$k) {
if (sum(pv$medoids <= task$nrow & pv$medoids >= 1L) != pv$k) {
msg = sprintf("`medoids` need to contain valid indices from 1")
msg = sprintf("%s to %s (number of observations)!", msg, self$param_set$values$k)
stopf(msg)
stopf("%s to %s (number of observations)!", msg, pv$k)
}
}

pv = self$param_set$get_values(tags = "train")
m = invoke(cluster::pam, x = task$data(), diss = FALSE, .args = pv)
if (self$save_assignments) {
self$assignments = m$clustering
Expand All @@ -76,7 +72,7 @@ LearnerClustPAM = R6Class("LearnerClustPAM",
},

.predict = function(task) {
partition = unclass(cl_predict(self$model, newdata = task$data(), type = "class_ids"))
partition = unclass(invoke(cl_predict, self$model, newdata = task$data(), type = "class_ids"))
PredictionClust$new(task = task, partition = partition)
}
)
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClustSimpleKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ LearnerClustSimpleKMeans = R6Class("LearnerClustSimpleKMeans",
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = chartr("_", "-", names(pv))
ctrl = do.call(RWeka::Weka_control, pv)
ctrl = invoke(RWeka::Weka_control, .args = pv)
m = invoke(RWeka::SimpleKMeans, x = task$data(), control = ctrl)
if (self$save_assignments) {
self$assignments = unname(m$class_ids + 1L)
Expand All @@ -68,7 +68,7 @@ LearnerClustSimpleKMeans = R6Class("LearnerClustSimpleKMeans",
},

.predict = function(task) {
partition = predict(self$model, newdata = task$data(), type = "class") + 1L
partition = invoke(predict, self$model, newdata = task$data(), type = "class") + 1L
PredictionClust$new(task = task, partition = partition)
}
)
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClustXMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ LearnerClustXMeans = R6Class("LearnerClustXMeans",
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = chartr("_", "-", names(pv))
ctrl = do.call(RWeka::Weka_control, pv)
ctrl = invoke(RWeka::Weka_control, .args = pv)
m = invoke(RWeka::XMeans, x = task$data(), control = ctrl)
if (self$save_assignments) {
self$assignments = unname(m$class_ids + 1L)
Expand All @@ -67,7 +67,7 @@ LearnerClustXMeans = R6Class("LearnerClustXMeans",
},

.predict = function(task) {
partition = predict(self$model, newdata = task$data(), type = "class") + 1L
partition = invoke(predict, self$model, newdata = task$data(), type = "class") + 1L
PredictionClust$new(task = task, partition = partition)
}
)
Expand Down
5 changes: 2 additions & 3 deletions R/helper.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
warn_prediction_useless = function(id) {
msg = sprintf("Learner '%s' doesn't predict on new data and predictions may not make sense on new data", id)
warning(warningCondition(msg, class = "predictionUselessWarning"))
warningf("Learner '%s' doesn't predict on new data and predictions may not make sense on new data.", id)
}

allow_partial_matching = list(
Expand All @@ -9,7 +8,7 @@ allow_partial_matching = list(
warnPartialMatchDollar = FALSE
)

check_centers_param = function(centers, task, test_class, name) {
assert_centers_param = function(centers, task, test_class, name) {
if (test_class(centers) && ncol(centers) != task$ncol) {
stopf("`%s` must have same number of columns as data.", name)
}
Expand Down
Loading

0 comments on commit 9fbe815

Please sign in to comment.