From b33c8360acdab6f245097e2011353f8379b92dad Mon Sep 17 00:00:00 2001 From: Agis Kounelis Date: Wed, 3 Apr 2024 21:16:23 +0300 Subject: [PATCH] Add check for Group constructor uri type and fix tests --- tiledb/group.py | 6 ++++++ tiledb/tests/test_group.py | 17 ++++++++++++++--- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/tiledb/group.py b/tiledb/group.py index a3435565e2..e38baba611 100644 --- a/tiledb/group.py +++ b/tiledb/group.py @@ -2,6 +2,7 @@ import numpy as np +import tiledb import tiledb.cc as lt from .ctx import Config, Ctx, CtxMixin, default_ctx @@ -264,6 +265,11 @@ def __init__( config: Config = None, ctx: Optional[Ctx] = None, ): + if uri is not None and tiledb.object_type(uri) != "group": + raise ValueError( + f"uri '{uri}' is not a valid TileDB Group path, uri is of type {tiledb.object_type(uri)}" + ) + if mode not in Group._mode_to_query_type: raise ValueError(f"invalid mode {mode}") query_type = Group._mode_to_query_type[mode] diff --git a/tiledb/tests/test_group.py b/tiledb/tests/test_group.py index 7466c1d461..cb36e905bc 100644 --- a/tiledb/tests/test_group.py +++ b/tiledb/tests/test_group.py @@ -171,7 +171,7 @@ def test_group_members(self): grp = tiledb.Group(grp_path, "w") assert os.path.basename(grp.uri) == os.path.basename(grp_path) - array_path = self.path("test_group_members") + array_path = self.path("test_group_members_array") domain = tiledb.Domain(tiledb.Dim(domain=(1, 8), tile=2)) a1 = tiledb.Attr("val", dtype="f8") schema = tiledb.ArraySchema(domain=domain, attrs=(a1,)) @@ -203,7 +203,7 @@ def test_group_members(self): assert grp[1].name is None assert "test_group_members GROUP" in repr(grp) - assert "|-- test_group_members ARRAY" in repr(grp) + assert "|-- test_group_members_array ARRAY" in repr(grp) assert "|-- test_group_0 GROUP" in repr(grp) grp.close() @@ -342,8 +342,19 @@ def test_set_config(self): with tiledb.Group(group_uri, config=cfg) as G: assert len(G) == sz + def test_invalid_object_type(self): + path = self.path() + schema = tiledb.ArraySchema( + domain=tiledb.Domain(tiledb.Dim("id", dtype="ascii")), + attrs=(tiledb.Attr("value", dtype=np.int64),), + sparse=True, + ) + tiledb.Array.create(path, schema) + with self.assertRaises(ValueError): + tiledb.Group(uri=path, mode="w") + def test_group_does_not_exist(self): - with self.assertRaises(tiledb.TileDBError): + with self.assertRaises(ValueError): tiledb.Group("does-not-exist")