Skip to content

Commit

Permalink
Fix selection of multiple fields in body (#298)
Browse files Browse the repository at this point in the history
* Fix selection of multiple fields in body

* Fix field selection for binary request
  • Loading branch information
mattjala authored Jan 26, 2024
1 parent 9fa156c commit c5fab12
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 8 deletions.
17 changes: 12 additions & 5 deletions hsds/chunk_sn.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,18 +242,20 @@ def _getSelectDtype(params, dset_dtype, body=None):
""" if a field list is defined in params or body,
create a sub-type of the dset dtype. Else,
just return the dset dtype. """

select_fields = None
kw = "fields"

if isinstance(body, dict) and kw in body:
select_fields = body[kw]
log.debug(f"fields value in body: {select_fields}")
elif kw in params:
fields_param = params.get(kw)
log.debug(f"fields param: {fields_param}")
select_fields = fields_param.split(":")
select_fields = params.get(kw)
log.debug(f"fields param: {select_fields}")
else:
select_fields = None

if select_fields:
select_fields = select_fields.split(":")
try:
select_dtype = getSubType(dset_dtype, select_fields)
except TypeError as te:
Expand Down Expand Up @@ -558,6 +560,10 @@ async def _doHyperslabWrite(app,
log.info(f"_doHyperslabWrite on {dset_id} - page: {page_number}")
type_json = dset_json["type"]
item_size = getItemSize(type_json)

if (select_dtype is not None):
item_size = select_dtype.itemsize

layout = getChunkLayout(dset_json)

num_chunks = getNumChunks(page, layout)
Expand Down Expand Up @@ -706,6 +712,7 @@ async def PUT_Value(request):
log.debug(f"got body: {body}")

select_dtype = _getSelectDtype(params, dset_dtype, body=body)
select_item_size = select_dtype.itemsize
append_rows = _getAppendRows(params, dset_json, body=body)

if append_rows:
Expand Down Expand Up @@ -904,7 +911,7 @@ async def PUT_Value(request):
log.debug(f"non-streaming data, setting page list to: {slices}")
else:
max_request_size = int(config.get("max_request_size"))
pages = getSelectionPagination(slices, dims, item_size, max_request_size)
pages = getSelectionPagination(slices, dims, select_item_size, max_request_size)
log.debug(f"getSelectionPagination returned: {len(pages)} pages")

for page_number in range(len(pages)):
Expand Down
87 changes: 84 additions & 3 deletions tests/integ/pointsel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy as np
import helper
import config
from hsds.util.arrayUtil import arrayToBytes


class PointSelTest(unittest.TestCase):
Expand Down Expand Up @@ -1605,7 +1606,7 @@ def testPostCompoundDataset(self):
# create 1d dataset
#

field_names = ("x1", "X2", "x3", "X4", "X5")
field_names = ("x1", "x2", "x3", "x4", "x5")

fields = []
for field_name in field_names:
Expand Down Expand Up @@ -1643,7 +1644,7 @@ def testPostCompoundDataset(self):
for i in range(len(points)):
self.assertEqual(ret_value[i], [0, 0, 0, 0, 0])

# write to the dset by fields
# write to the dset by field
for field in field_names:
x = int(field[1]) # get the number part of the field name
data = [(x * i) for i in range(num_elements)]
Expand All @@ -1666,7 +1667,87 @@ def testPostCompoundDataset(self):
self.assertEqual(len(ret_value), len(points))
for i in range(len(points)):
self.assertEqual(ret_value[i], [x * points[i]])
return

# Write "100" to first field and "200" to second field through body
data = [(100, 200) for i in range(num_elements)]
payload = {"value": data, "fields": field_names[0] + ":" + field_names[1]}
req = self.endpoint + "/datasets/" + dset_id + "/value"
rsp = self.session.put(req, data=json.dumps(payload), headers=headers)
self.assertEqual(rsp.status_code, 200)

# read back entire dataset and check values
rsp = self.session.get(req, headers=headers)
self.assertEqual(rsp.status_code, 200)
rspJson = json.loads(rsp.text)
self.assertTrue("value" in rspJson)
ret_value = np.array(rspJson["value"], dtype=int)
self.assertTrue(np.array_equal(ret_value[:, 0],
np.full(shape=num_elements, fill_value=100, dtype=int)))
self.assertTrue(np.array_equal(ret_value[:, 1],
np.full(shape=num_elements, fill_value=200, dtype=int)))
for i in range(2, 5):
self.assertTrue(np.array_equal(ret_value[:, i], [(i + 1) * j for j in range(100)]))

# Write 300 to third field and 400 to fourth field through URL
data = [(300, 400) for i in range(num_elements)]
payload = {"value": data}
req = self.endpoint + "/datasets/" + dset_id + \
"/value?fields=" + field_names[2] + ":" + field_names[3]
rsp = self.session.put(req, data=json.dumps(payload), headers=headers)
self.assertEqual(rsp.status_code, 200)

# read back entire dataset and check values
req = self.endpoint + "/datasets/" + dset_id + "/value"
rsp = self.session.get(req, headers=headers)
self.assertEqual(rsp.status_code, 200)
rspJson = json.loads(rsp.text)
self.assertTrue("value" in rspJson)
ret_value = np.array(rspJson["value"], dtype=int)
for i in range(1, 4):
expected = np.full(shape=num_elements, fill_value=((i + 1) * 100), dtype=int)
self.assertTrue(np.array_equal(ret_value[:, i], expected))
self.assertTrue(np.array_equal(ret_value[:, 4], [5 * j for j in range(100)]))

# Test non-adjacent fields
# Write 1000 to first field and 500 to fifth field through body
data = [(1000, 500) for i in range(num_elements)]
payload = {"value": data, "fields": field_names[0] + ":" + field_names[4]}
req = self.endpoint + "/datasets/" + dset_id + "/value"
rsp = self.session.put(req, data=json.dumps(payload), headers=headers)
self.assertEqual(rsp.status_code, 200)

# read back entire dataset and check values
rsp = self.session.get(req, headers=headers)
self.assertEqual(rsp.status_code, 200)
rspJson = json.loads(rsp.text)
self.assertTrue("value" in rspJson)
ret_value = np.array(rspJson["value"], dtype=int)
self.assertTrue(np.array_equal(ret_value[:, 0],
np.full(shape=num_elements, fill_value=1000, dtype=int)))
for i in range(2, 5):
self.assertTrue(np.array_equal(ret_value[:, i], [(i + 1) * 100 for j in range(100)]))

# try to write to first field through binary request
arr = np.array([(10000,) for i in range(num_elements)], dtype=np.int32)
data = arrayToBytes(arr)
req = self.endpoint + "/datasets/" + dset_id + "/value?fields=" + field_names[0]
headers["Content-Type"] = "application/octet-stream"
rsp = self.session.put(req, data=data, headers=headers)
self.assertEqual(rsp.status_code, 200)

# read back entire dataset and check values
req = self.endpoint + "/datasets/" + dset_id + "/value"
headers["Content-Type"] = "application/json"
rsp = self.session.get(req, headers=headers)
self.assertEqual(rsp.status_code, 200)
rspJson = json.loads(rsp.text)
self.assertTrue("value" in rspJson)
ret_value = np.array(rspJson["value"], dtype=int)
print(f"ret value = {ret_value}")
self.assertTrue(np.array_equal(ret_value[:, 0],
np.full(shape=num_elements, fill_value=10000, dtype=int)))
for i in range(2, 5):
self.assertTrue(np.array_equal(ret_value[:, i], [(i + 1) * 100 for j in range(100)]))


if __name__ == "__main__":
Expand Down

0 comments on commit c5fab12

Please sign in to comment.