Skip to content

Commit

Permalink
Issue #625: solve aggregation issue for test.join with probs (#1275)
Browse files Browse the repository at this point in the history
* Issue #625: solve aggregation issue for test.join with probs

* extend tests
  • Loading branch information
giuseppec authored and larskotthoff committed Oct 13, 2016
1 parent b46e562 commit 501e60b
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 1 deletion.
7 changes: 6 additions & 1 deletion R/aggregations.R
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,13 @@ test.join = makeAggregation(
df = as.data.frame(pred)
f = if (length(group)) group[df$iter] else factor(rep(1L, nrow(df)))
mean(vnapply(split(df, f), function(df) {
if (pred$predict.type == "response") y = df$response
if (pred$predict.type == "prob") {
y = df[,grepl("^prob[.]", colnames(df))]
colnames(y) = gsub("^prob[.]", "", colnames(y))
}
npred = makePrediction(task.desc = pred$task.desc, row.names = rownames(df),
id = NULL, truth = df$truth, predict.type = pred$predict.type, y = df$response,
id = NULL, truth = df$truth, predict.type = pred$predict.type, y = y,
time = NA_real_)
performance(npred, measure)
}))
Expand Down
12 changes: 12 additions & 0 deletions tests/testthat/test_base_resample_cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,15 @@ test_that("cv instance works is stochastic", {
rin2 = makeResampleInstance(makeResampleDesc("CV", iters = 2L), size = 500)
expect_true(!all(sort(rin1$test.inds[[1]]) == sort(rin2$test.inds[[1]])))
})


test_that("test.join works somehow", {
lrn = makeLearner("classif.rpart", predict.type = "prob")

# check if test.join computes acc correctly
mm = setAggregation(acc, test.join)
r = resample(lrn, sonar.task, cv2, measures = mm)
rpred = getRRPredictions(r)
expect_equal(as.numeric(r$aggr),
mean(getPredictionTruth(rpred) == getPredictionResponse(rpred)))
})
10 changes: 10 additions & 0 deletions tests/testthat/test_base_resample_holdout.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,13 @@ test_that("holdout fixed instance works", {
expect_equal(length(rin$train.inds[[1]]), 5)
expect_equal(length(rin$test.inds[[1]]), 15)
})

test_that("holdout test.join works somehow", {
lrn = makeLearner("classif.rpart", predict.type = "prob")

# for holdout test.join and test.mean should be the same
rin = makeResampleDesc("Holdout")
mm = list(setAggregation(auc, test.join), auc)
r = resample(lrn, sonar.task, rin, measures = mm)
expect_equal(as.integer(diff(r$aggr)), 0)
})
4 changes: 4 additions & 0 deletions tests/testthat/test_base_resample_repcv.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,8 @@ test_that("test.join works somehow", {
res = resample(learner = lrn, task = task, resampling = rin, measures = measures)
expect_equal(res$measures.test[, 2L], res$measures.test[, 3L])
expect_true(diff(res$aggr) > 0)

lrn = setPredictType(lrn, predict.type = "prob")
res.prob = resample(learner = lrn, task = task, resampling = rin, measures = measures)
expect_equal(res.prob$measures.test[, 2L], res.prob$measures.test[, 3L])
})

0 comments on commit 501e60b

Please sign in to comment.