Skip to content

Commit

Permalink
Make from_pandas respect column_types for index dimensions (#2046)
Browse files Browse the repository at this point in the history
  • Loading branch information
kounelisagis authored Aug 29, 2024
1 parent b7f8c13 commit 2dc7562
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 3 deletions.
10 changes: 8 additions & 2 deletions tiledb/dataframe_.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,9 @@ def _sparse_from_dtypes(dtypes, sparse=None):
return sparse if sparse is not None else False


def create_dims(df, index_dims, tile=None, full_domain=False, filters=None):
def create_dims(
df, index_dims, column_infos, tile=None, full_domain=False, filters=None
):
check_dataframe_deps()
import pandas as pd

Expand All @@ -445,7 +447,10 @@ def create_dims(df, index_dims, tile=None, full_domain=False, filters=None):
else:
raise ValueError(f"Unknown column or index named {name!r}")

dtype = ColumnInfo.from_values(values).dtype
if name in column_infos:
dtype = column_infos[name].dtype
else:
dtype = ColumnInfo.from_values(values).dtype
internal_dtype = dtype

if name == "__tiledb_rows" and isinstance(index, pd.RangeIndex):
Expand Down Expand Up @@ -659,6 +664,7 @@ def _create_array(uri, df, sparse, full_domain, index_dims, column_infos, tiledb
dims, dim_metadata = create_dims(
df,
index_dims,
column_infos,
full_domain=full_domain,
tile=tiledb_args.get("tile"),
filters=tiledb_args.get("dim_filters", True),
Expand Down
92 changes: 91 additions & 1 deletion tiledb/tests/test_pandas_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def make_dataframe_categorical():
return df


class TestColumnInfo:
class TestColumnInfo(DiskTestCase):
def assertColumnInfo(self, info, info_dtype, info_repr=None, info_nullable=False):
assert isinstance(info.dtype, np.dtype)
assert info.dtype == info_dtype
Expand Down Expand Up @@ -250,6 +250,96 @@ def test_not_implemented(self, type_specs):
# check that the column name is included in the error message
assert "supported (column foo)" in str(exc.value)

def test_apply_dtype_index_ingest(self):
uri = self.path("index_dtype_default_dtype")
tiledb.from_pandas(
uri,
pd.DataFrame({"a": np.arange(0, 20), "b": np.arange(20, 40)}),
sparse=True,
index_dims=["a"],
)
with tiledb.open(uri) as A:
if sys.platform == "win32" and sys.version_info[:2] == (3, 8):
assert A.schema.domain.dim(0).dtype == np.int32
else:
assert A.schema.domain.dim(0).dtype == np.int64

uri = self.path("index_dtype_casted_dtype")
tiledb.from_pandas(
uri,
pd.DataFrame({"a": np.arange(0, 20), "b": np.arange(20, 40)}),
sparse=True,
index_dims=["a"],
column_types={"a": np.uint8},
)
with tiledb.open(uri) as A:
assert A.schema.domain.dim(0).dtype == np.uint8

# multiple index dims
uri = self.path("index_dtype_default_dtype_multi")
tiledb.from_pandas(
uri,
pd.DataFrame(
{
"a": np.random.random_sample(20),
"b": [str(uuid.uuid4()) for _ in range(20)],
}
),
sparse=True,
index_dims=["a", "b"],
)
with tiledb.open(uri) as A:
assert A.schema.domain.dim(0).dtype == np.float64
assert A.schema.domain.dim(1).dtype == np.bytes_

uri = self.path("index_dtype_casted_dtype_multi")
tiledb.from_pandas(
uri,
pd.DataFrame(
{
"a": np.random.random_sample(20),
"b": [str(uuid.uuid4()) for _ in range(20)],
}
),
sparse=True,
index_dims=["a", "b"],
column_types={"a": np.float32, "b": np.bytes_},
)
with tiledb.open(uri) as A:
assert A.schema.domain.dim(0).dtype == np.float32
assert A.schema.domain.dim(1).dtype == np.bytes_

def test_apply_dtype_index_schema_only(self):
uri = self.path("index_dtype_casted_dtype")
tiledb.from_pandas(
uri,
pd.DataFrame({"a": np.arange(0, 20), "b": np.arange(20, 40)}),
sparse=True,
index_dims=["a"],
column_types={"a": np.uint8},
mode="schema_only",
)
with tiledb.open(uri) as A:
assert A.schema.domain.dim(0).dtype == np.uint8

uri = self.path("index_dtype_casted_dtype_multi")
tiledb.from_pandas(
uri,
pd.DataFrame(
{
"a": np.random.random_sample(20),
"b": [str(uuid.uuid4()) for _ in range(20)],
}
),
sparse=True,
index_dims=["a", "b"],
column_types={"a": np.float32, "b": np.bytes_},
mode="schema_only",
)
with tiledb.open(uri) as A:
assert A.schema.domain.dim(0).dtype == np.float32
assert A.schema.domain.dim(1).dtype == np.bytes_


class TestDimType:
@pytest.mark.parametrize(
Expand Down

0 comments on commit 2dc7562

Please sign in to comment.