From 19bd09edf4687ea0569681e8f138c6f8ca2c919a Mon Sep 17 00:00:00 2001 From: Agisilaos Kounelis <36283973+kounelisagis@users.noreply.github.com> Date: Mon, 1 Jul 2024 21:59:17 +0300 Subject: [PATCH] Fix array.query() incorrectly handling nullables (#1998) --- tiledb/libtiledb.pyx | 12 ++++++------ tiledb/tests/test_enumeration.py | 1 + tiledb/tests/test_libtiledb.py | 4 ++++ 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/tiledb/libtiledb.pyx b/tiledb/libtiledb.pyx index 3f81f6a74e..8dd943976e 100644 --- a/tiledb/libtiledb.pyx +++ b/tiledb/libtiledb.pyx @@ -2170,14 +2170,14 @@ cdef class DenseArrayImpl(Array): if attr.isnullable: data = np.array([values[idx] for idx in result[attr.name].data]) result[attr.name] = np.ma.array( - data, mask=~result[attr.name].mask) + data, mask=result[attr.name].mask) else: result[attr.name] = np.array( [values[idx] for idx in result[attr.name]]) else: if attr.isnullable: result[attr.name] = np.ma.array(result[attr.name].data, - mask=~result[attr.name].mask) + mask=result[attr.name].mask) return result @@ -2429,7 +2429,7 @@ cdef class DenseArrayImpl(Array): out[name] = arr if self.schema.has_attr(name) and self.attr(name).isnullable: - out[name] = np.ma.array(out[name], mask=results[name][2].astype(bool)) + out[name] = np.ma.array(out[name], mask=~results[name][2].astype(bool)) return out @@ -3251,14 +3251,14 @@ cdef class SparseArrayImpl(Array): if attr.isnullable: data = np.array([values[idx] for idx in result[attr.name].data]) result[attr.name] = np.ma.array( - data, mask=~result[attr.name].mask) + data, mask=result[attr.name].mask) else: result[attr.name] = np.array( [values[idx] for idx in result[attr.name]]) else: if attr.isnullable: result[attr.name] = np.ma.array(result[attr.name].data, - mask=~result[attr.name].mask) + mask=result[attr.name].mask) return result @@ -3559,7 +3559,7 @@ cdef class SparseArrayImpl(Array): out[final_name] = arr if self.schema.has_attr(final_name) and self.attr(final_name).isnullable: - out[final_name] = np.ma.array(out[final_name], mask=results[name][2]) + out[final_name] = np.ma.array(out[final_name], mask=~results[name][2].astype(bool)) return out diff --git a/tiledb/tests/test_enumeration.py b/tiledb/tests/test_enumeration.py index f918479d8d..086649ecde 100644 --- a/tiledb/tests/test_enumeration.py +++ b/tiledb/tests/test_enumeration.py @@ -140,6 +140,7 @@ def test_array_schema_enumeration_nullable(self, sparse, pass_df): expected_validity = [False, False, True, False, False] assert_array_equal(A[:]["a"].mask, expected_validity) assert_array_equal(A.df[:]["a"].isna(), expected_validity) + assert_array_equal(A.query(attrs=["a"])[:]["a"].mask, expected_validity) @pytest.mark.parametrize( "dtype, values", diff --git a/tiledb/tests/test_libtiledb.py b/tiledb/tests/test_libtiledb.py index 996b7e0649..bbfa614b9f 100644 --- a/tiledb/tests/test_libtiledb.py +++ b/tiledb/tests/test_libtiledb.py @@ -429,10 +429,12 @@ def test_array_write_nullable(self, sparse, pass_df): expected_validity1 = [False, False, True, False, False] assert_array_equal(A[:]["a1"].mask, expected_validity1) assert_array_equal(A.df[:]["a1"].isna(), expected_validity1) + assert_array_equal(A.query(attrs=["a1"])[:]["a1"].mask, expected_validity1) expected_validity2 = [False, False, True, True, False] assert_array_equal(A[:]["a2"].mask, expected_validity2) assert_array_equal(A.df[:]["a2"].isna(), expected_validity2) + assert_array_equal(A.query(attrs=["a2"])[:]["a2"].mask, expected_validity2) with tiledb.open(uri, "w") as A: dims = pa.array([1, 2, 3, 4, 5]) @@ -452,10 +454,12 @@ def test_array_write_nullable(self, sparse, pass_df): expected_validity1 = [True, True, True, True, True] assert_array_equal(A[:]["a1"].mask, expected_validity1) assert_array_equal(A.df[:]["a1"].isna(), expected_validity1) + assert_array_equal(A.query(attrs=["a1"])[:]["a1"].mask, expected_validity1) expected_validity2 = [True, True, True, True, True] assert_array_equal(A[:]["a2"].mask, expected_validity2) assert_array_equal(A.df[:]["a2"].isna(), expected_validity2) + assert_array_equal(A.query(attrs=["a2"])[:]["a2"].mask, expected_validity2) class DenseArrayTest(DiskTestCase):