Skip to content

Commit

Permalink
clean up getdomain logic
Browse files Browse the repository at this point in the history
  • Loading branch information
jreadey committed May 30, 2024
1 parent ee02ea1 commit 27d830b
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 175 deletions.
191 changes: 24 additions & 167 deletions hsds/domain_sn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
import os.path as op

from aiohttp.web_exceptions import HTTPBadRequest, HTTPForbidden, HTTPNotFound
from aiohttp.web_exceptions import HTTPGone, HTTPInternalServerError
from aiohttp.web_exceptions import HTTPInternalServerError
from aiohttp.web_exceptions import HTTPConflict, HTTPServiceUnavailable
from aiohttp import ClientResponseError
from aiohttp.web import json_response

from .util.httpUtil import getObjectClass, http_post, http_put, http_delete
Expand Down Expand Up @@ -161,14 +160,13 @@ async def get_domains(request):
raise HTTPServiceUnavailable()

# allow domain with / to indicate a folder
prefix = None
try:
prefix = getDomainFromRequest(request, validate=False)
except ValueError:
pass # igore
if not prefix:
folder_path = getDomainFromRequest(request, validate=False)

if not folder_path:
# if there is no domain passed in, get a list of top level domains
prefix = "/"
folder_path = "/"

prefix = getPathForDomain(folder_path) # don't include the bucket if any

if "pattern" not in request.rel_url.query:
pattern = None
Expand All @@ -191,11 +189,6 @@ async def get_domains(request):

log.info(f"get_domains for: {prefix} verbose: {verbose}")

if not prefix.startswith("/"):
msg = "Prefix must start with '/'"
log.warn(msg)
raise HTTPBadRequest(reason=msg)

limit = None
if "Limit" in request.rel_url.query:
try:
Expand All @@ -217,10 +210,13 @@ async def get_domains(request):
bucket = params["bucket"]
elif "X-Hdf-bucket" in request.headers:
bucket = request.headers["X-Hdf-bucket"]
elif getBucketForDomain(folder_path):
bucket = getBucketForDomain(folder_path)
elif "bucket_name" in app and app["bucket_name"]:
bucket = app["bucket_name"]
else:
bucket = None

if not bucket:
msg = "no bucket specified for request"
log.warn(msg)
Expand Down Expand Up @@ -473,22 +469,10 @@ async def GET_Domain(request):
log.response(request, resp=resp)
return resp

log.info(f"got domain: {domain}")
log.info(f"get domain: {domain}")

domain_json = await getDomainJson(app, domain, reload=True)

if domain_json is None:
log.warn(f"domain: {domain} not found")
raise HTTPNotFound()

if "owner" not in domain_json:
log.error("No owner key found in domain")
raise HTTPInternalServerError()

if "acls" not in domain_json:
log.error("No acls key found in domain")
raise HTTPInternalServerError()

log.debug(f"got domain_json: {domain_json}")
# validate that the requesting user has permission to read this domain
# aclCheck throws exception if not authorized
Expand Down Expand Up @@ -703,14 +687,6 @@ async def POST_Domain(request):

domain_json = await getDomainJson(app, domain, reload=True)

if domain_json is None:
log.warn(f"domain: {domain} not found")
raise HTTPNotFound()

if "acls" not in domain_json:
log.error("No acls key found in domain")
raise HTTPInternalServerError()

if "root" not in domain_json:
msg = f"{domain} is a folder, not a domain"
log.warn(msg)
Expand Down Expand Up @@ -845,17 +821,6 @@ async def PUT_Domain(request):
domain_json = await getDomainJson(app, domain, reload=True)
log.debug(f"got domain_json: {domain_json}")

if domain_json is None:
log.warn(f"domain: {domain} not found")
raise HTTPNotFound()

if "owner" not in domain_json:
log.error("No owner key found in domain")
raise HTTPInternalServerError()

if "acls" not in domain_json:
log.error("No acls key found in domain")
raise HTTPInternalServerError()
# throws exception if not allowed
aclCheck(app, domain_json, "update", username)
rsp_json = None
Expand Down Expand Up @@ -980,22 +945,10 @@ async def PUT_Domain(request):
log.warn(msg)
raise HTTPForbidden()

parent_json = None
if not is_toplevel:
try:
parent_json = await getDomainJson(app, parent_domain, reload=True)
except ClientResponseError as ce:
if ce.code == 404:
msg = f"Parent domain: {parent_domain} not found"
log.warn(msg)
raise HTTPNotFound()
elif ce.code == 410:
msg = f"Parent domain: {parent_domain} removed"
log.warn(msg)
raise HTTPGone()
else:
log.error(f"Unexpected error: {ce.code}")
raise HTTPInternalServerError()
if is_toplevel:
parent_json = None
else:
parent_json = await getDomainJson(app, parent_domain, reload=True)

log.debug(f"parent_json {parent_domain}: {parent_json}")
if "root" in parent_json and parent_json["root"]:
Expand Down Expand Up @@ -1041,12 +994,8 @@ async def PUT_Domain(request):
bucket = getBucketForDomain(domain)
if bucket:
post_params["bucket"] = bucket
try:
group_json = await http_post(app, req, data=group_json, params=post_params)
except ClientResponseError as ce:
msg = "Error creating root group for domain -- " + str(ce)
log.error(msg)
raise HTTPInternalServerError()
group_json = await http_post(app, req, data=group_json, params=post_params)

else:
log.debug("no root group, creating folder")

Expand Down Expand Up @@ -1094,12 +1043,7 @@ async def PUT_Domain(request):
body["root"] = root_id

log.debug(f"creating domain: {domain} with body: {body}")
try:
domain_json = await http_put(app, req, data=body)
except ClientResponseError as ce:
msg = "Error creating domain state -- " + str(ce)
log.error(msg)
raise HTTPInternalServerError()
domain_json = await http_put(app, req, data=body)

# domain creation successful
# mixin limits
Expand Down Expand Up @@ -1184,18 +1128,7 @@ async def DELETE_Domain(request):
log.warn(msg)
raise HTTPForbidden()

try:
domain_json = await getDomainJson(app, domain, reload=True)
except ClientResponseError as ce:
if ce.code == 404:
log.warn("domain not found")
raise HTTPNotFound()
elif ce.code == 410:
log.warn("domain has been removed")
raise HTTPGone()
else:
log.error(f"unexpected error: {ce.code}")
raise HTTPInternalServerError()
domain_json = await getDomainJson(app, domain, reload=True)

# throws exception if not allowed
aclCheck(app, domain_json, "delete", username)
Expand Down Expand Up @@ -1277,16 +1210,7 @@ async def GET_ACL(request):
checkBucketAccess(app, bucket)

# use reload to get authoritative domain json
try:
domain_json = await getDomainJson(app, domain, reload=True)
except ClientResponseError as ce:
if ce.code in (404, 410):
msg = "domain not found"
log.warn(msg)
raise HTTPNotFound()
else:
log.error(f"unexpected error: {ce.code}")
raise HTTPInternalServerError()
domain_json = await getDomainJson(app, domain, reload=True)

# validate that the requesting user has permission to read ACLs
# in this domain
Expand Down Expand Up @@ -1361,19 +1285,7 @@ async def GET_ACLs(request):
checkBucketAccess(app, bucket)

# use reload to get authoritative domain json
try:
domain_json = await getDomainJson(app, domain, reload=True)
except ClientResponseError:
log.warn("domain not found")
raise HTTPNotFound()

if "owner" not in domain_json:
log.error("No owner key found in domain")
raise HTTPInternalServerError()

if "acls" not in domain_json:
log.error("No acls key found in domain")
raise HTTPInternalServerError()
domain_json = await getDomainJson(app, domain, reload=True)

acls = domain_json["acls"]

Expand Down Expand Up @@ -1498,28 +1410,7 @@ async def GET_Datasets(request):
bucket = getBucketForDomain(domain)

# verify the domain
try:
domain_json = await getDomainJson(app, domain)
except ClientResponseError as ce:
if ce.code == 404:
msg = f"Domain: {domain} not found"
log.warn(msg)
raise HTTPNotFound()
elif ce.code == 410:
msg = f"Domain: {domain} removed"
log.warn(msg)
raise HTTPGone()
else:
log.error(f"Unexpected error: {ce.code}")
raise HTTPInternalServerError()

if "owner" not in domain_json:
log.error("No owner key found in domain")
raise HTTPInternalServerError()

if "acls" not in domain_json:
log.error("No acls key found in domain")
raise HTTPInternalServerError()
domain_json = await getDomainJson(app, domain)

log.debug(f"got domain_json: {domain_json}")
# validate that the requesting user has permission to read this domain
Expand Down Expand Up @@ -1588,24 +1479,7 @@ async def GET_Groups(request):
bucket = getBucketForDomain(domain)

# use reload to get authoritative domain json
try:
domain_json = await getDomainJson(app, domain, reload=True)
except ClientResponseError as ce:
if ce.code == 404:
msg = "domain not found"
log.warn(msg)
raise HTTPNotFound()
else:
log.error(f"Unexpected error: {ce.code}")
raise HTTPInternalServerError()

if "owner" not in domain_json:
log.error("No owner key found in domain")
raise HTTPInternalServerError()

if "acls" not in domain_json:
log.error("No acls key found in domain")
raise HTTPInternalServerError()
domain_json = await getDomainJson(app, domain, reload=True)

log.debug(f"got domain_json: {domain_json}")
# validate that the requesting user has permission to read this domain
Expand Down Expand Up @@ -1673,24 +1547,7 @@ async def GET_Datatypes(request):
bucket = getBucketForDomain(domain)

# use reload to get authoritative domain json
try:
domain_json = await getDomainJson(app, domain, reload=True)
except ClientResponseError as ce:
if ce.code in (404, 410):
msg = "domain not found"
log.warn(msg)
raise HTTPNotFound()
else:
log.error(f"Unexpected Error: {ce.code})")
raise HTTPInternalServerError()

if "owner" not in domain_json:
log.error("No owner key found in domain")
raise HTTPInternalServerError()

if "acls" not in domain_json:
log.error("No acls key found in domain")
raise HTTPInternalServerError()
domain_json = await getDomainJson(app, domain, reload=True)

log.debug(f"got domain_json: {domain_json}")
# validate that the requesting user has permission to read this domain
Expand Down
29 changes: 28 additions & 1 deletion hsds/servicenode_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from aiohttp.web_exceptions import HTTPBadRequest, HTTPForbidden, HTTPGone, HTTPConflict
from aiohttp.web_exceptions import HTTPNotFound, HTTPInternalServerError
from aiohttp.client_exceptions import ClientOSError, ClientError
from aiohttp import ClientResponseError

from .util.authUtil import getAclKeys
from .util.arrayUtil import encodeData
Expand Down Expand Up @@ -63,7 +64,33 @@ async def getDomainJson(app, domain, reload=False):

log.debug(f"sending dn req: {req} params: {params}")

domain_json = await http_get(app, req, params=params)
try:
domain_json = await http_get(app, req, params=params)
except HTTPNotFound:
log.warn(f"domain: {domain} not found")
raise
except HTTPGone:
log.warn(f"domain: {domain} has been removed")
raise
except ClientResponseError as ce:
# shouldn't get this if we are catching relevent exceptions
# in http_get...
log.error(f"Unexpected ClientResponseError: {ce}")

if ce.code == 404:
log.warn("domain not found")
raise HTTPNotFound()
elif ce.code == 410:
log.warn("domain has been removed")
raise HTTPGone()
else:
log.error(f"unexpected error: {ce.code}")
raise HTTPInternalServerError()

if not domain_json:
msg = f"nothing returned (and no exceptionraised) for domain: {domain}"
log.error(msg)
raise HTTPInternalServerError()

if "owner" not in domain_json:
log.warn("No owner key found in domain")
Expand Down
8 changes: 4 additions & 4 deletions hsds/util/domainUtil.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def validateDomainKey(domain_key):


def getDomainFromRequest(request, validate=True):
# print("gotDomainFromRequest:", request, "validate=", validate)
print(f"tmp - getDomainFromRequest: {request}, validate={validate}")
app = request.app
domain = None
bucket = None
Expand All @@ -279,7 +279,7 @@ def getDomainFromRequest(request, validate=True):
domain = request.headers["X-Hdf-domain"]
else:
return None

print(f"tmp - domain: {domain}")
if domain.startswith("hdf5:/"):
# strip off the prefix to make following logic easier
domain = domain[6:]
Expand All @@ -297,8 +297,8 @@ def getDomainFromRequest(request, validate=True):
else:
pass # no bucket specified

if bucket and validate:
if not isValidBucketName(bucket):
if bucket:
if validate and not isValidBucketName(bucket):
raise ValueError(f"bucket name: {bucket} is not valid")
if domain[0] == "/":
domain = bucket + domain
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/domain_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1188,7 +1188,7 @@ def testGetDomains(self):

headers = helper.getRequestHeaders(domain=folder + "/")
req = helper.getEndpoint() + "/domains"
rsp = self.session.get(req, headers=headers) # , params=params)
rsp = self.session.get(req, headers=headers)
self.assertEqual(rsp.status_code, 200)
self.assertEqual(rsp.headers["content-type"], "application/json; charset=utf-8")
rspJson = json.loads(rsp.text)
Expand Down
Loading

0 comments on commit 27d830b

Please sign in to comment.