diff --git a/python/setup.py b/python/setup.py index cedd6338..9c9e9706 100644 --- a/python/setup.py +++ b/python/setup.py @@ -19,7 +19,7 @@ copyfile(lib_path, path.join(dirname, "fedtree", path.basename(lib_path))) setuptools.setup(name="fedtree", - version="1.0.3", + version="1.0.4", packages=["fedtree"], package_dir={"python": "fedtree"}, description="A federated learning library for trees", diff --git a/src/FedTree/scikit_fedtree.cpp b/src/FedTree/scikit_fedtree.cpp index c01f1ff5..5ba1eb84 100644 --- a/src/FedTree/scikit_fedtree.cpp +++ b/src/FedTree/scikit_fedtree.cpp @@ -234,6 +234,11 @@ extern "C" { test_dataset.label.emplace_back(group_label[i]); } } + else{ + for (int i = 0; i < num_class; ++i) { + test_dataset.label.emplace_back(i); + } + } // predict SyncArray y_predict; vector> boosted_model_in_mem; @@ -274,8 +279,16 @@ extern "C" { test_dataset.load_from_sparse(row_size, val, row_ptr, col_ptr, NULL, group, num_group, model_param); set_logger(verbose); test_dataset.label.clear(); - for (int i = 0; i < num_class; ++i) { - test_dataset.label.emplace_back(group_label[i]); + if(group_label != NULL) { + test_dataset.label.clear(); + for (int i = 0; i < num_class; ++i) { + test_dataset.label.emplace_back(group_label[i]); + } + } + else{ + for (int i = 0; i < num_class; ++i) { + test_dataset.label.emplace_back(i); + } } // predict SyncArray y_predict;