diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index b43f096c31ee..8f8137df1f48 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -422,7 +422,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data, * \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64`` * \param nrow Number of rows * \param ncol Number of columns - * \param is_row_major 1 for row-major, 0 for column-major + * \param is_row_major Pointer to the data layouts. 1 for row-major, 0 for column-major * \param parameters Additional parameters * \param reference Used to align bin mapper with other dataset, nullptr means isn't used * \param[out] out Created dataset @@ -433,7 +433,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMats(int32_t nmat, int data_type, int32_t* nrow, int32_t ncol, - int is_row_major, + int* is_row_major, const char* parameters, const DatasetHandle reference, DatasetHandle* out); diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 7b152fd2b006..5fa3cb872ca7 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -2342,6 +2342,7 @@ def __init_from_list_np2d( ptr_data = (ctypes.POINTER(ctypes.c_double) * len(mats))() else: ptr_data = (ctypes.POINTER(ctypes.c_float) * len(mats))() + layouts = (ctypes.c_int * len(mats))() holders = [] type_ptr_data = -1 @@ -2355,15 +2356,13 @@ def __init_from_list_np2d( nrow[i] = mat.shape[0] - if mat.dtype == np.float32 or mat.dtype == np.float64: - mats[i] = np.asarray(mat.reshape(mat.size), dtype=mat.dtype) - else: # change non-float data to float data, need to copy - mats[i] = np.array(mat.reshape(mat.size), dtype=np.float32) + mat, layout = _np2d_to_np1d(mat) - chunk_ptr_data, chunk_type_ptr_data, holder = _c_float_array(mats[i]) + chunk_ptr_data, chunk_type_ptr_data, holder = _c_float_array(mat) if type_ptr_data != -1 and chunk_type_ptr_data != type_ptr_data: raise ValueError("Input chunks must have same type") ptr_data[i] = chunk_ptr_data + layouts[i] = layout type_ptr_data = chunk_type_ptr_data holders.append(holder) @@ -2375,7 +2374,7 @@ def __init_from_list_np2d( ctypes.c_int(type_ptr_data), nrow.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), ctypes.c_int32(ncol), - ctypes.c_int(_C_API_IS_ROW_MAJOR), + layouts, _c_str(params_str), ref_dataset, ctypes.byref(self._handle), diff --git a/src/c_api.cpp b/src/c_api.cpp index cf6577ad5e2c..967d00d3bd15 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -1309,7 +1309,7 @@ int LGBM_DatasetCreateFromMat(const void* data, data_type, &nrow, ncol, - is_row_major, + &is_row_major, parameters, reference, out); @@ -1320,7 +1320,7 @@ int LGBM_DatasetCreateFromMats(int32_t nmat, int data_type, int32_t* nrow, int32_t ncol, - int is_row_major, + int* is_row_major, const char* parameters, const DatasetHandle reference, DatasetHandle* out) { @@ -1337,7 +1337,7 @@ int LGBM_DatasetCreateFromMats(int32_t nmat, std::vector(int row_idx)>> get_row_fun; for (int j = 0; j < nmat; ++j) { - get_row_fun.push_back(RowFunctionFromDenseMatric(data[j], nrow[j], ncol, data_type, is_row_major)); + get_row_fun.push_back(RowFunctionFromDenseMatric(data[j], nrow[j], ncol, data_type, is_row_major[j])); } if (reference == nullptr) { diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index bdd4d3f58b80..f34be5cc1574 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -983,3 +983,17 @@ def test_equal_datasets_from_row_major_and_col_major_data(tmp_path): # check datasets are equal assert filecmp.cmp(ds_row_path, ds_col_path) + + +def test_equal_datasets_from_one_and_several_matrices_w_different_layouts(rng, tmp_path): + # several matrices + mats = [np.require(rng.random(size=(100, 2)), requirements=order) for order in ("C", "F", "F", "C")] + several_path = tmp_path / "several.txt" + lgb.Dataset(mats)._dump_text(several_path) + + # one matrix + mat = np.vstack(mats) + one_path = tmp_path / "one.txt" + lgb.Dataset(mat)._dump_text(one_path) + + assert filecmp.cmp(one_path, several_path)