diff --git a/tiledb/enumeration.py b/tiledb/enumeration.py index d5c086728a..4310efe942 100644 --- a/tiledb/enumeration.py +++ b/tiledb/enumeration.py @@ -33,6 +33,8 @@ def __init__( :type ordered: bool :param values: A Numpy array of values for this enumeration :type values: np.array + :param dtype: The Numpy data type for this enumeration + :type dtype: np.dtype :param ctx: A TileDB context :type ctx: tiledb.Ctx """ diff --git a/tiledb/tests/test_enumeration.py b/tiledb/tests/test_enumeration.py index a1124cc014..f918479d8d 100644 --- a/tiledb/tests/test_enumeration.py +++ b/tiledb/tests/test_enumeration.py @@ -1,3 +1,5 @@ +import re + import numpy as np import pytest from numpy.testing import assert_array_equal @@ -37,6 +39,28 @@ def test_attribute_enumeration(self): attr.enum = "enum" assert attr.enum == "enum" + def test_enumeration_repr(self): + """Doesn't check exact string, just makes sure each component is matched, in case order is changed in the future.""" + enmr = tiledb.Enumeration("e", False, [1, 2, 3]) + # Get its string representation + repr_str = repr(enmr) + + # Define patterns to match each component in the representation + patterns = { + "Enumeration": r"Enumeration", + "name": r"name='e'", + # use regex because it is depending on platform + "dtype": r"dtype=int\d+", + "dtype_name": r"dtype_name='int\d+'", + "cell_val_num": r"cell_val_num=1", + "ordered": r"ordered=False", + "values": r"values=\[1, 2, 3\]", + } + + # Check that each pattern is found in the representation string + for key, pattern in patterns.items(): + assert re.search(pattern, repr_str), f"{key} not found or incorrect in repr" + def test_array_schema_enumeration(self): uri = self.path("test_array_schema_enumeration") dom = tiledb.Domain(tiledb.Dim(domain=(1, 8), tile=1))