enable limit for followlinks #299

Merged (2 commits) on Jan 29, 2024
2 changes: 1 addition & 1 deletion hsds/attr_sn.py
@@ -158,7 +158,7 @@ async def GET_Attributes(request):
msg = f"DomainCrawler returned: {len(attributes)} objects"
log.info(msg)
else:
- # just get attributes for this objects
+ # just get attributes for this object
kwargs = {"bucket": bucket}
if include_data:
kwargs["include_data"] = True
45 changes: 37 additions & 8 deletions hsds/domain_crawl.py
@@ -61,6 +61,7 @@ def __init__(
self._create_order = create_order
self._pattern = pattern
self._limit = limit
+ self._count = 0  # items collected
self._replace = replace
self._max_tasks = max_tasks
self._q = asyncio.Queue()
@@ -175,16 +176,27 @@ async def get_attributes(self, obj_id, attr_names):
log.error(f"unexpected exception from post request: {e}")
status = 500

+ follow_links = self._follow_links
if isOK(status):
log.debug(f"got attributes: {attributes}")
+ if self._limit:
+ left = self._limit - self._count
+ if len(attributes) > left:
+ # truncate the attribute list
+ msg = f"limit reached, returning {left} attributes out"
+ msg += f" of {len(attributes)} for {obj_id}"
+ log.warn(msg)
+ attributes = attributes[:left]
+ follow_links = False
+ self._count += len(attributes)
self._obj_dict[obj_id] = attributes
else:
log.warn(f"Domain crawler - got {status} status for obj_id {obj_id}")
self._obj_dict[obj_id] = {"status": status}

collection = getCollectionForId(obj_id)

if collection == "groups" and self._follow_links:
if collection == "groups" and follow_links:
links = None
status = 200
try:
@@ -263,7 +275,9 @@ async def get_obj_json(self, obj_id):
status = 500
log.debug(f"getObjectJson status: {status}")

- if obj_json is None:
+ if isOK(status):
+ log.debug(f"got obj json for: {obj_id}")
+ else:
msg = f"DomainCrawler - getObjectJson for {obj_id} "
if status >= 500:
msg += f"failed, status: {status}"
@@ -273,6 +287,8 @@
log.warn(msg)
return

self._obj_dict[obj_id] = {"status": status}

log.debug(f"DomainCrawler - got json for {obj_id}")
log.debug(f"obj_json: {obj_json}")

@@ -344,7 +360,7 @@ async def get_links(self, grp_id, titles=None):
status = 500
log.debug(f"get_links status: {status}")

- if links is None:
+ if not isOK(status):
msg = f"DomainCrawler - get_links for {grp_id} "
if status >= 500:
msg += f"failed, status: {status}"
@@ -366,14 +382,27 @@
msg = f"getLinks with pattern: {pattern} returning "
msg += f"{len(filtered_links)} links from {len(links)}"
log.debug(msg)
log.debug(f"save to obj_dict: {filtered_links}")
self._obj_dict[grp_id] = filtered_links
new_links = filtered_links
else:
log.debug(f"save to obj_dict: {links}")
self._obj_dict[grp_id] = links # store the links
new_links = links # store the links

+ follow_links = self._follow_links
+ # check that we are not exceeding the limit
+ if self._limit:
+ left = self._limit - self._count
+ if left < len(new_links):
+ # will need to truncate this list
+ msg = f"limit reached, adding {left} new links out"
+ msg += f" of {len(new_links)} for {grp_id}"
+ log.warn(msg)
+ new_links = new_links[:left]
+ follow_links = False # no need to search more
+ self._count += len(new_links)
+ log.debug(f"adding {len(new_links)} to obj_dict for {grp_id}")
+ self._obj_dict[grp_id] = new_links

# if follow_links, add any group links to the lookup ids set
- if self._follow_links:
+ if follow_links:
self.follow_links(grp_id, links)

async def put_links(self, grp_id, link_items):
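The same bookkeeping pattern appears in both get_attributes and get_links above: work out how many items the limit still allows, truncate the current batch if needed, add the batch size to self._count, and stop following links once the budget is used up. A minimal standalone sketch of that pattern follows; the function name and signature are illustrative, not part of the hsds code:

def apply_limit(batch, limit, count):
    """Clip a batch of crawled items to the remaining budget.

    Returns the possibly truncated batch, the updated running count,
    and a flag telling the caller whether to keep following links.
    """
    if not limit:
        # no limit configured: keep everything and keep crawling
        return batch, count + len(batch), True
    left = limit - count
    if len(batch) > left:
        batch = batch[:left]  # truncate to what is still allowed
        return batch, count + len(batch), False  # budget exhausted, stop crawling
    return batch, count + len(batch), True

# example: a limit of 2 with an empty count clips a batch of three items
items, count, keep_going = apply_limit(["a", "b", "c"], limit=2, count=0)
# items == ["a", "b"], count == 2, keep_going is False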
2 changes: 2 additions & 0 deletions hsds/link_sn.py
@@ -99,6 +99,8 @@ async def GET_Links(request):

kwargs = {"action": "get_link", "bucket": bucket, "follow_links": True}
kwargs["include_links"] = True
+ if limit:
+ kwargs["limit"] = limit
items = [group_id, ]
crawler = DomainCrawler(app, items, **kwargs)

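The handler change above only has an effect when the client sends a Limit parameter together with follow_links. A hedged usage sketch, modeled on the integration tests below; the endpoint, headers, and root_uuid values are placeholders:

import requests

endpoint = "http://localhost:5101"  # placeholder hsds endpoint
headers = {}  # auth headers would go here
root_uuid = "g-0000"  # id of the root group (placeholder)

req = endpoint + "/groups/" + root_uuid + "/links"
params = {"follow_links": 1, "Limit": 5}  # crawl sub-groups, return at most 5 links total
rsp = requests.get(req, params=params, headers=headers)
obj_map = rsp.json()["links"]  # map of group id to its list of links
total = sum(len(links) for links in obj_map.values())  # should not exceed the Limit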
22 changes: 4 additions & 18 deletions tests/integ/attr_test.py
@@ -2368,8 +2368,9 @@ def testGetRecursive(self):
self.assertTrue("value" in attrJson)

# same thing with Limit
+ limit = 3
req = helper.getEndpoint() + "/groups/" + root_uuid + "/attributes"
params = {"follow_links": "1", "Limit": 1}
params = {"follow_links": "1", "Limit": limit}
rsp = self.session.get(req, params=params, headers=headers)
self.assertEqual(rsp.status_code, 200)
rspJson = json.loads(rsp.text)
@@ -2378,24 +2379,9 @@
self.assertEqual(len(obj_map), 10)
attr_count = 0
for obj_id in obj_map:
- self.assertTrue(len(obj_map[obj_id]) <= 1)
+ self.assertTrue(len(obj_map[obj_id]) <= limit)
attr_count += len(obj_map[obj_id])
- self.assertEqual(attr_count, 2)
- for obj_id in (root_uuid, d111_uuid):
- # these are the only two objects with attributes
- self.assertTrue(obj_id in obj_map)
- obj_attrs = obj_map[obj_id]
- self.assertEqual(len(obj_attrs), 1)
- for attrJson in obj_attrs:
- self.assertTrue("name" in attrJson)
- attr_name = attrJson["name"]
- self.assertTrue(attr_name in ("attr1", "attr2"))
- self.assertTrue("type" in attrJson)
- self.assertTrue("shape" in attrJson)
- shapeJson = attrJson["shape"]
- self.assertEqual(shapeJson["class"], "H5S_SIMPLE")
- self.assertTrue("created" in attrJson)
- self.assertFalse("value" in attrJson)
+ self.assertEqual(attr_count, limit)

# do a get with encoding
req = helper.getEndpoint() + "/groups/" + root_uuid + "/attributes"
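With the crawler change, Limit now caps the total number of attributes returned across the whole crawl rather than the number returned per object, which is why the detailed per-object assertions were replaced with a check on the total. A small illustrative check of the new expectation; the obj_map contents here are hypothetical:

limit = 3
obj_map = {"g-root": ["attr1", "attr2"], "d-111": ["attr1"]}  # hypothetical crawl result
attr_count = sum(len(attrs) for attrs in obj_map.values())
assert attr_count == limit  # the limit bounds the total across all objects
assert all(len(attrs) <= limit for attrs in obj_map.values())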
18 changes: 18 additions & 0 deletions tests/integ/link_test.py
@@ -625,6 +625,24 @@ def testGetRecursive(self):
self.assertEqual(softlink_count, len(expected_soft_links))
self.assertEqual(extlink_count, len(expected_external_links))

+ # test follow with limit
+ # get links for root group and other groups recursively
+ req = helper.getEndpoint() + "/groups/" + root_uuid + "/links"
+ limit = 5
+ params = {"follow_links": 1, "Limit": limit}
+ rsp = self.session.get(req, params=params, headers=headers)
+ self.assertEqual(rsp.status_code, 200)
+ rspJson = json.loads(rsp.text)
+ self.assertTrue("hrefs" in rspJson)
+ hrefs = rspJson["hrefs"]
+ self.assertEqual(len(hrefs), 3)
+ self.assertTrue("links" in rspJson)
+ obj_map = rspJson["links"] # map of obj_ids to links
+ link_count = 0
+ for obj_id in obj_map:
+ link_count += len(obj_map[obj_id])
+ self.assertEqual(link_count, limit)

def testGetPattern(self):
# test getting links from an existing domain, with a glob filter
domain = helper.getTestDomain("tall.h5")