Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
kounelisagis committed Dec 6, 2024
1 parent a1046ed commit 865ecba
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 26 deletions.
43 changes: 20 additions & 23 deletions tiledb/array.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from typing import Dict, List

import numpy as np

Expand Down Expand Up @@ -771,13 +772,13 @@ def dindex(self):
return self.domain_index

def _write_array(
tiledb_array,
self,
subarray,
coordinates: list,
buffer_names: list,
values: list,
labels: dict,
nullmaps: dict,
coordinates: List,
buffer_names: List,
values: List,
labels: Dict,
nullmaps: Dict,
issparse: bool,
):
# used for buffer conversion (local import to avoid circularity)
Expand All @@ -790,7 +791,7 @@ def _write_array(
# Create arrays to hold buffer sizes
nbuffer = nattr + nlabel
if issparse:
nbuffer += tiledb_array.schema.ndim
nbuffer += self.schema.ndim
buffer_sizes = np.zeros((nbuffer,), dtype=np.uint64)

# Create lists for data and offset buffers
Expand All @@ -799,17 +800,16 @@ def _write_array(

# Set data and offset buffers for attributes
for i in range(nattr):
attr = self.schema.attr(i)
# if dtype is ASCII, ensure all characters are valid
if tiledb_array.schema.attr(i).isascii:
if attr.isascii:
try:
values[i] = np.asarray(values[i], dtype=np.bytes_)
except Exception as exc:
raise tiledb.TileDBError(
f'dtype of attr {tiledb_array.schema.attr(i).name} is "ascii" but attr_val contains invalid ASCII characters'
f'dtype of attr {attr.name} is "ascii" but attr_val contains invalid ASCII characters'
)

attr = tiledb_array.schema.attr(i)

if attr.isvar:
try:
if attr.isnullable:
Expand Down Expand Up @@ -847,13 +847,12 @@ def _write_array(
ibuffer = nattr
if issparse:
for dim_idx, coords in enumerate(coordinates):
dim = tiledb_array.schema.domain.dim(dim_idx)
dim = self.schema.domain.dim(dim_idx)
if dim.isvar:
buffer, offsets = array_to_buffer(coords, True, False)
buffer_sizes[ibuffer] = buffer.nbytes // (dim.dtype.itemsize or 1)
else:
buffer, offsets = coords, None
buffer_sizes[ibuffer] = buffer.nbytes // (dim.dtype.itemsize or 1)
buffer_sizes[ibuffer] = buffer.nbytes // (dim.dtype.itemsize or 1)
output_values.append(buffer)
output_offsets.append(offsets)

Expand All @@ -866,30 +865,28 @@ def _write_array(
# Append buffer name
buffer_names.append(label_name)
# Get label data buffer and offsets buffer for the labels
dim_label = tiledb_array.schema.dim_label(label_name)
dim_label = self.schema.dim_label(label_name)
if dim_label.isvar:
buffer, offsets = array_to_buffer(label_values, True, False)
buffer_sizes[ibuffer] = buffer.nbytes // (dim_label.dtype.itemsize or 1)
else:
buffer, offsets = label_values, None
buffer_sizes[ibuffer] = buffer.nbytes // (dim_label.dtype.itemsize or 1)
buffer_sizes[ibuffer] = buffer.nbytes // (dim_label.dtype.itemsize or 1)
# Append the buffers
output_values.append(buffer)
output_offsets.append(offsets)

ibuffer = ibuffer + 1

# Allocate the query
ctx = lt.Context(tiledb_array.ctx)
q = lt.Query(ctx, tiledb_array.array, lt.QueryType.WRITE)
ctx = lt.Context(self.ctx)
q = lt.Query(ctx, self.array, lt.QueryType.WRITE)

# Set the layout
layout = (
q.layout = (
lt.LayoutType.UNORDERED
if issparse
else (lt.LayoutType.COL_MAJOR if isfortran else lt.LayoutType.ROW_MAJOR)
)
q.layout = layout

# Create and set the subarray for the query (dense arrays only)
if not issparse:
Expand Down Expand Up @@ -920,8 +917,8 @@ def _write_array(
q._submit()
q.finalize()

fragment_info = tiledb_array.last_fragment_info
if fragment_info is not False:
fragment_info = self.last_fragment_info
if fragment_info != False:
if not isinstance(fragment_info, dict):
raise ValueError(
f"Expected fragment_info to be a dict, got {type(fragment_info)}"
Expand Down
6 changes: 3 additions & 3 deletions tiledb/cc/query.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,12 @@ void init_query(py::module &m) {
.def("set_data_buffer",
[](Query &q, std::string name, py::array a, uint64_t nelements) {
QueryExperimental::set_data_buffer(
q, name, const_cast<void *>(a.data()), nelements);
q, name, a.mutable_data(), nelements);
})

.def("set_offsets_buffer",
[](Query &q, std::string name, py::array a, uint64_t nelements) {
q.set_offsets_buffer(name, (uint64_t *)(a.data()), nelements);
q.set_offsets_buffer(name, static_cast<uint64_t*>(a.mutable_data()), nelements);
})

.def("set_subarray",
Expand All @@ -90,7 +90,7 @@ void init_query(py::module &m) {

.def("set_validity_buffer",
[](Query &q, std::string name, py::array a, uint64_t nelements) {
q.set_validity_buffer(name, (uint8_t *)(a.data()), nelements);
q.set_validity_buffer(name, static_cast<uint8_t*>(a.mutable_data()), nelements);
})

.def("_submit", &Query::submit, py::call_guard<py::gil_scoped_release>())
Expand Down

0 comments on commit 865ecba

Please sign in to comment.