From 2dc7562fbee0782003fa2a83b75359ba10264067 Mon Sep 17 00:00:00 2001 From: Agisilaos Kounelis <36283973+kounelisagis@users.noreply.github.com> Date: Thu, 29 Aug 2024 17:34:53 +0300 Subject: [PATCH] Make `from_pandas` respect `column_types` for index dimensions (#2046) --- tiledb/dataframe_.py | 10 ++- tiledb/tests/test_pandas_dataframe.py | 92 ++++++++++++++++++++++++++- 2 files changed, 99 insertions(+), 3 deletions(-) diff --git a/tiledb/dataframe_.py b/tiledb/dataframe_.py index 9543d91a3d..489bb10f5d 100644 --- a/tiledb/dataframe_.py +++ b/tiledb/dataframe_.py @@ -420,7 +420,9 @@ def _sparse_from_dtypes(dtypes, sparse=None): return sparse if sparse is not None else False -def create_dims(df, index_dims, tile=None, full_domain=False, filters=None): +def create_dims( + df, index_dims, column_infos, tile=None, full_domain=False, filters=None +): check_dataframe_deps() import pandas as pd @@ -445,7 +447,10 @@ def create_dims(df, index_dims, tile=None, full_domain=False, filters=None): else: raise ValueError(f"Unknown column or index named {name!r}") - dtype = ColumnInfo.from_values(values).dtype + if name in column_infos: + dtype = column_infos[name].dtype + else: + dtype = ColumnInfo.from_values(values).dtype internal_dtype = dtype if name == "__tiledb_rows" and isinstance(index, pd.RangeIndex): @@ -659,6 +664,7 @@ def _create_array(uri, df, sparse, full_domain, index_dims, column_infos, tiledb dims, dim_metadata = create_dims( df, index_dims, + column_infos, full_domain=full_domain, tile=tiledb_args.get("tile"), filters=tiledb_args.get("dim_filters", True), diff --git a/tiledb/tests/test_pandas_dataframe.py b/tiledb/tests/test_pandas_dataframe.py index 51c02d74e5..72a22f00a8 100644 --- a/tiledb/tests/test_pandas_dataframe.py +++ b/tiledb/tests/test_pandas_dataframe.py @@ -127,7 +127,7 @@ def make_dataframe_categorical(): return df -class TestColumnInfo: +class TestColumnInfo(DiskTestCase): def assertColumnInfo(self, info, info_dtype, info_repr=None, info_nullable=False): assert isinstance(info.dtype, np.dtype) assert info.dtype == info_dtype @@ -250,6 +250,96 @@ def test_not_implemented(self, type_specs): # check that the column name is included in the error message assert "supported (column foo)" in str(exc.value) + def test_apply_dtype_index_ingest(self): + uri = self.path("index_dtype_default_dtype") + tiledb.from_pandas( + uri, + pd.DataFrame({"a": np.arange(0, 20), "b": np.arange(20, 40)}), + sparse=True, + index_dims=["a"], + ) + with tiledb.open(uri) as A: + if sys.platform == "win32" and sys.version_info[:2] == (3, 8): + assert A.schema.domain.dim(0).dtype == np.int32 + else: + assert A.schema.domain.dim(0).dtype == np.int64 + + uri = self.path("index_dtype_casted_dtype") + tiledb.from_pandas( + uri, + pd.DataFrame({"a": np.arange(0, 20), "b": np.arange(20, 40)}), + sparse=True, + index_dims=["a"], + column_types={"a": np.uint8}, + ) + with tiledb.open(uri) as A: + assert A.schema.domain.dim(0).dtype == np.uint8 + + # multiple index dims + uri = self.path("index_dtype_default_dtype_multi") + tiledb.from_pandas( + uri, + pd.DataFrame( + { + "a": np.random.random_sample(20), + "b": [str(uuid.uuid4()) for _ in range(20)], + } + ), + sparse=True, + index_dims=["a", "b"], + ) + with tiledb.open(uri) as A: + assert A.schema.domain.dim(0).dtype == np.float64 + assert A.schema.domain.dim(1).dtype == np.bytes_ + + uri = self.path("index_dtype_casted_dtype_multi") + tiledb.from_pandas( + uri, + pd.DataFrame( + { + "a": np.random.random_sample(20), + "b": [str(uuid.uuid4()) for _ in range(20)], + } + ), + sparse=True, + index_dims=["a", "b"], + column_types={"a": np.float32, "b": np.bytes_}, + ) + with tiledb.open(uri) as A: + assert A.schema.domain.dim(0).dtype == np.float32 + assert A.schema.domain.dim(1).dtype == np.bytes_ + + def test_apply_dtype_index_schema_only(self): + uri = self.path("index_dtype_casted_dtype") + tiledb.from_pandas( + uri, + pd.DataFrame({"a": np.arange(0, 20), "b": np.arange(20, 40)}), + sparse=True, + index_dims=["a"], + column_types={"a": np.uint8}, + mode="schema_only", + ) + with tiledb.open(uri) as A: + assert A.schema.domain.dim(0).dtype == np.uint8 + + uri = self.path("index_dtype_casted_dtype_multi") + tiledb.from_pandas( + uri, + pd.DataFrame( + { + "a": np.random.random_sample(20), + "b": [str(uuid.uuid4()) for _ in range(20)], + } + ), + sparse=True, + index_dims=["a", "b"], + column_types={"a": np.float32, "b": np.bytes_}, + mode="schema_only", + ) + with tiledb.open(uri) as A: + assert A.schema.domain.dim(0).dtype == np.float32 + assert A.schema.domain.dim(1).dtype == np.bytes_ + class TestDimType: @pytest.mark.parametrize(