From dd552e235518f4426d92ea59460cae5e10b39179 Mon Sep 17 00:00:00 2001 From: getumen Date: Tue, 6 Aug 2024 11:07:38 +0900 Subject: [PATCH] Fix bug in CUML (#52) * chore: Update Dockerfile and add dependencies in /testdata * chore: Update Dockerfile to remove unused dependencies * fix go bindings * fix bug in FIL --- src/fil.cu | 14 +++----------- tests/CMakeLists.txt | 4 ++-- tests/fil_test.cpp | 9 +++++---- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/src/fil.cu b/src/fil.cu index 691816a..98a6afb 100644 --- a/src/fil.cu +++ b/src/fil.cu @@ -37,21 +37,13 @@ namespace char const *filename, TreeliteModelHandle *model_handle) { - std::string json_config = "{\"allow_unknown_field\": True}"; + std::string json_config = "{}"; switch (model_type) { case ModelType::XGBoost: - return TreeliteLoadXGBoostModel(filename, json_config.c_str(), model_handle); + return TreeliteLoadXGBoostModelLegacyBinary(filename, json_config.c_str(), model_handle); case ModelType::XGBoostJSON: { - std::ifstream file(filename); // Replace with your file name - if (!file.is_open()) { - return -1; - } - std::string content((std::istreambuf_iterator(file)), - std::istreambuf_iterator()); - file.close(); - - return TreeliteLoadXGBoostModelFromString(content.c_str(), content.length(), json_config.c_str(), model_handle); + return TreeliteLoadXGBoostModel(filename, json_config.c_str(), model_handle); } case ModelType::LightGBM: return TreeliteLoadLightGBMModel(filename, json_config.c_str(), model_handle); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6827bf9..1ac5a50 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -18,8 +18,8 @@ add_executable( cuml_test memory_resource_test.cpp clustering_test.cpp - # fil_test.cpp - # linear_regression_test.cpp + fil_test.cpp + linear_regression_test.cpp ) target_compile_options(cuml_test PRIVATE $<$:--expt-extended-lambda --expt-relaxed-constexpr>) diff --git a/tests/fil_test.cpp b/tests/fil_test.cpp index 17ef08f..c25a86a 100644 --- a/tests/fil_test.cpp +++ b/tests/fil_test.cpp @@ -7,9 +7,10 @@ #include "cuml4c/memory_resource.h" #include "cuml4c/fil.h" -TEST(FILTest, TestTreelite) + +TEST(FILTest, TestTreeliteJSON) { - std::string json_config = "{\"allow_unknown_field\": True}"; + std::string json_config = "{}"; TreeliteModelHandle handle; auto res = TreeliteLoadXGBoostModel("testdata/xgboost.json", json_config.c_str(), &handle); @@ -25,10 +26,10 @@ TEST(FILTest, TestFIL) CreateDeviceResourceHandle(&device_resource_handle); DeviceMemoryResource mr; - UseArenaMemoryResource(&mr, 64 * 1024); + UseArenaMemoryResource(&mr, 1024 * 1024); FILModelHandle handle; - auto res = FILLoadModel(device_resource_handle, 0, "testdata/xgboost.model", 0, true, 0.5, 0, 0, 1, 0, &handle); + auto res = FILLoadModel(device_resource_handle, 1, "testdata/xgboost.json", 0, true, 0.5, 0, 0, 1, 0, &handle); EXPECT_EQ(res, 0); std::vector feature;