From 501e60bdae28e19b78a66e48d78a5856d7432586 Mon Sep 17 00:00:00 2001
From: giuseppec
Date: Thu, 13 Oct 2016 09:46:29 +0200
Subject: [PATCH] Issue #625: solve aggregation issue for test.join with probs (#1275)

* Issue #625: solve aggregation issue for test.join with probs

* extend tests
---
Review notes (not part of the commit message; the text between the "---"
line and the first file diff is skipped when the patch is applied):

* R/aggregations.R: the prediction input `y` is now assigned in an if/else,
  so predict types other than "response" and "prob" fall back to df$response
  (what was passed before this patch) instead of leaving `y` unassigned.
  The prob.* column selection uses drop = FALSE so it stays a data.frame.
* tests/testthat/test_base_resample_holdout.R: the difference of the two
  aggregated AUC values is compared as numeric; as.integer() truncated every
  difference smaller than 1 to 0, so the check could hardly fail.
* Hunk lengths and the diffstat were updated for the added lines; the
  "index" blob ids are still those of the original commit.

 R/aggregations.R                            | 10 +++++++++-
 tests/testthat/test_base_resample_cv.R      | 12 ++++++++++++
 tests/testthat/test_base_resample_holdout.R | 11 +++++++++++
 tests/testthat/test_base_resample_repcv.R   |  4 ++++
 4 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/R/aggregations.R b/R/aggregations.R
index d89efd82fb..885713813b 100644
--- a/R/aggregations.R
+++ b/R/aggregations.R
@@ -235,8 +235,16 @@ test.join = makeAggregation(
     df = as.data.frame(pred)
     f = if (length(group)) group[df$iter] else factor(rep(1L, nrow(df)))
     mean(vnapply(split(df, f), function(df) {
+      # "prob" predictions are rebuilt from the prob.<class> columns; every other
+      # predict type keeps using the response column, as before this change
+      if (pred$predict.type == "prob") {
+        y = df[, grepl("^prob[.]", colnames(df)), drop = FALSE]
+        colnames(y) = gsub("^prob[.]", "", colnames(y))
+      } else {
+        y = df$response
+      }
       npred = makePrediction(task.desc = pred$task.desc, row.names = rownames(df),
-        id = NULL, truth = df$truth, predict.type = pred$predict.type, y = df$response,
+        id = NULL, truth = df$truth, predict.type = pred$predict.type, y = y,
         time = NA_real_)
       performance(npred, measure)
     }))
diff --git a/tests/testthat/test_base_resample_cv.R b/tests/testthat/test_base_resample_cv.R
index b91e5ef483..68492ce0f9 100644
--- a/tests/testthat/test_base_resample_cv.R
+++ b/tests/testthat/test_base_resample_cv.R
@@ -49,3 +49,15 @@ test_that("cv instance works is stochastic", {
   rin2 = makeResampleInstance(makeResampleDesc("CV", iters = 2L), size = 500)
   expect_true(!all(sort(rin1$test.inds[[1]]) == sort(rin2$test.inds[[1]])))
 })
+
+
+test_that("test.join works somehow", {
+  lrn = makeLearner("classif.rpart", predict.type = "prob")
+
+  # check if test.join computes acc correctly
+  mm = setAggregation(acc, test.join)
+  r = resample(lrn, sonar.task, cv2, measures = mm)
+  rpred = getRRPredictions(r)
+  expect_equal(as.numeric(r$aggr),
+    mean(getPredictionTruth(rpred) == getPredictionResponse(rpred)))
+})
diff --git a/tests/testthat/test_base_resample_holdout.R b/tests/testthat/test_base_resample_holdout.R
index d8df5d8478..4142b09345 100644
--- a/tests/testthat/test_base_resample_holdout.R
+++ b/tests/testthat/test_base_resample_holdout.R
@@ -19,3 +19,14 @@ test_that("holdout fixed instance works", {
   expect_equal(length(rin$train.inds[[1]]), 5)
   expect_equal(length(rin$test.inds[[1]]), 15)
 })
+
+test_that("holdout test.join works somehow", {
+  lrn = makeLearner("classif.rpart", predict.type = "prob")
+
+  # for holdout test.join and test.mean should be the same
+  rin = makeResampleDesc("Holdout")
+  mm = list(setAggregation(auc, test.join), auc)
+  r = resample(lrn, sonar.task, rin, measures = mm)
+  # compare as numeric: as.integer() would truncate any difference below 1 to 0
+  expect_equal(as.numeric(diff(r$aggr)), 0)
+})
diff --git a/tests/testthat/test_base_resample_repcv.R b/tests/testthat/test_base_resample_repcv.R
index a897103ac3..28fdd81f1a 100644
--- a/tests/testthat/test_base_resample_repcv.R
+++ b/tests/testthat/test_base_resample_repcv.R
@@ -62,4 +62,8 @@ test_that("test.join works somehow", {
   res = resample(learner = lrn, task = task, resampling = rin, measures = measures)
   expect_equal(res$measures.test[, 2L], res$measures.test[, 3L])
   expect_true(diff(res$aggr) > 0)
+
+  lrn = setPredictType(lrn, predict.type = "prob")
+  res.prob = resample(learner = lrn, task = task, resampling = rin, measures = measures)
+  expect_equal(res.prob$measures.test[, 2L], res.prob$measures.test[, 3L])
 })