Skip to content

Commit

Permalink
test lambdarank_unbiased
Browse files Browse the repository at this point in the history
  • Loading branch information
tubihongfeili committed May 15, 2024
1 parent 82d846b commit 1b20b25
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 1 deletion.
11 changes: 11 additions & 0 deletions jvm-packages/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,17 @@
<module>xgboost4j-spark</module>
<module>xgboost4j-flink</module>
</modules>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<skipTests>true</skipTests>
</configuration>
</plugin>
</plugins>
</build>
</profile>

<profile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>
private def getOrCreateSession = synchronized {
if (currentSession == null) {
currentSession = sparkSessionBuilder.getOrCreate()
currentSession.sparkContext.setLogLevel("ERROR")
currentSession.sparkContext.setLogLevel("INFO")
}
currentSession
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,20 @@ class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
assert(testDF.count() === prediction.length)
}

test("ranking: test position bias") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "0", "verbosity" -> "3",
"objective" -> "rank:ndcg", "num_workers" -> numWorkers, "num_round" -> 5,
"group_col" -> "group", "tree_method" -> treeMethod, "lambdarank_unbiased" -> true, "eval_metric" -> "ndcg")

val trainingDF = buildDataFrameWithGroup(Ranking.train)
val testDF = buildDataFrame(Ranking.test)
val model = new XGBoostRegressor(paramMap).fit(trainingDF)

val prediction = model.transform(testDF).collect()
println("hello---------hongfei")
assert(testDF.count() === prediction.length)
}

test("use weight") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
Expand Down
3 changes: 3 additions & 0 deletions src/objective/lambdarank_obj.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "xgboost/span.h" // for Span, operator!=
#include "xgboost/string_view.h" // for operator<<, StringView
#include "xgboost/task.h" // for ObjInfo
#include "../collective/communicator-inl.h" // for Allreduce, Broadcast, Finalize, GetProcessor...

namespace xgboost::obj {
namespace cpu_impl {
Expand Down Expand Up @@ -314,6 +315,8 @@ class LambdaRankObj : public FitIntercept {
CHECK_EQ(info.weights_.Size(), n_groups) << error::GroupWeight();
}

collective::Print("hongfeili-cpp lambdarank_unbiased: " + std::to_string(param_.lambdarank_unbiased));

if (ti_plus_.Size() == 0 && param_.lambdarank_unbiased) {
CHECK_EQ(iter, 0);
ti_plus_ = linalg::Constant<double>(ctx_, 1.0, p_cache_->MaxPositionSize());
Expand Down

0 comments on commit 1b20b25

Please sign in to comment.