Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clean up getdomain logic #368

Merged
merged 5 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 24 additions & 167 deletions hsds/domain_sn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
import os.path as op

from aiohttp.web_exceptions import HTTPBadRequest, HTTPForbidden, HTTPNotFound
from aiohttp.web_exceptions import HTTPGone, HTTPInternalServerError
from aiohttp.web_exceptions import HTTPInternalServerError
from aiohttp.web_exceptions import HTTPConflict, HTTPServiceUnavailable
from aiohttp import ClientResponseError
from aiohttp.web import json_response

from .util.httpUtil import getObjectClass, http_post, http_put, http_delete
Expand Down Expand Up @@ -161,14 +160,13 @@ async def get_domains(request):
raise HTTPServiceUnavailable()

# allow domain with / to indicate a folder
prefix = None
try:
prefix = getDomainFromRequest(request, validate=False)
except ValueError:
pass # igore
if not prefix:
folder_path = getDomainFromRequest(request, validate=False)

if not folder_path:
# if there is no domain passed in, get a list of top level domains
prefix = "/"
folder_path = "/"

prefix = getPathForDomain(folder_path) # don't include the bucket if any

if "pattern" not in request.rel_url.query:
pattern = None
Expand All @@ -191,11 +189,6 @@ async def get_domains(request):

log.info(f"get_domains for: {prefix} verbose: {verbose}")

if not prefix.startswith("/"):
msg = "Prefix must start with '/'"
log.warn(msg)
raise HTTPBadRequest(reason=msg)

limit = None
if "Limit" in request.rel_url.query:
try:
Expand All @@ -217,10 +210,13 @@ async def get_domains(request):
bucket = params["bucket"]
elif "X-Hdf-bucket" in request.headers:
bucket = request.headers["X-Hdf-bucket"]
elif getBucketForDomain(folder_path):
bucket = getBucketForDomain(folder_path)
elif "bucket_name" in app and app["bucket_name"]:
bucket = app["bucket_name"]
else:
bucket = None

if not bucket:
msg = "no bucket specified for request"
log.warn(msg)
Expand Down Expand Up @@ -473,22 +469,10 @@ async def GET_Domain(request):
log.response(request, resp=resp)
return resp

log.info(f"got domain: {domain}")
log.info(f"get domain: {domain}")

domain_json = await getDomainJson(app, domain, reload=True)

if domain_json is None:
log.warn(f"domain: {domain} not found")
raise HTTPNotFound()

if "owner" not in domain_json:
log.error("No owner key found in domain")
raise HTTPInternalServerError()

if "acls" not in domain_json:
log.error("No acls key found in domain")
raise HTTPInternalServerError()

log.debug(f"got domain_json: {domain_json}")
# validate that the requesting user has permission to read this domain
# aclCheck throws exception if not authorized
Expand Down Expand Up @@ -703,14 +687,6 @@ async def POST_Domain(request):

domain_json = await getDomainJson(app, domain, reload=True)

if domain_json is None:
log.warn(f"domain: {domain} not found")
raise HTTPNotFound()

if "acls" not in domain_json:
log.error("No acls key found in domain")
raise HTTPInternalServerError()

if "root" not in domain_json:
msg = f"{domain} is a folder, not a domain"
log.warn(msg)
Expand Down Expand Up @@ -845,17 +821,6 @@ async def PUT_Domain(request):
domain_json = await getDomainJson(app, domain, reload=True)
log.debug(f"got domain_json: {domain_json}")

if domain_json is None:
log.warn(f"domain: {domain} not found")
raise HTTPNotFound()

if "owner" not in domain_json:
log.error("No owner key found in domain")
raise HTTPInternalServerError()

if "acls" not in domain_json:
log.error("No acls key found in domain")
raise HTTPInternalServerError()
# throws exception if not allowed
aclCheck(app, domain_json, "update", username)
rsp_json = None
Expand Down Expand Up @@ -980,22 +945,10 @@ async def PUT_Domain(request):
log.warn(msg)
raise HTTPForbidden()

parent_json = None
if not is_toplevel:
try:
parent_json = await getDomainJson(app, parent_domain, reload=True)
except ClientResponseError as ce:
if ce.code == 404:
msg = f"Parent domain: {parent_domain} not found"
log.warn(msg)
raise HTTPNotFound()
elif ce.code == 410:
msg = f"Parent domain: {parent_domain} removed"
log.warn(msg)
raise HTTPGone()
else:
log.error(f"Unexpected error: {ce.code}")
raise HTTPInternalServerError()
if is_toplevel:
parent_json = None
else:
parent_json = await getDomainJson(app, parent_domain, reload=True)

log.debug(f"parent_json {parent_domain}: {parent_json}")
if "root" in parent_json and parent_json["root"]:
Expand Down Expand Up @@ -1041,12 +994,8 @@ async def PUT_Domain(request):
bucket = getBucketForDomain(domain)
if bucket:
post_params["bucket"] = bucket
try:
group_json = await http_post(app, req, data=group_json, params=post_params)
except ClientResponseError as ce:
msg = "Error creating root group for domain -- " + str(ce)
log.error(msg)
raise HTTPInternalServerError()
group_json = await http_post(app, req, data=group_json, params=post_params)

else:
log.debug("no root group, creating folder")

Expand Down Expand Up @@ -1094,12 +1043,7 @@ async def PUT_Domain(request):
body["root"] = root_id

log.debug(f"creating domain: {domain} with body: {body}")
try:
domain_json = await http_put(app, req, data=body)
except ClientResponseError as ce:
msg = "Error creating domain state -- " + str(ce)
log.error(msg)
raise HTTPInternalServerError()
domain_json = await http_put(app, req, data=body)

# domain creation successful
# mixin limits
Expand Down Expand Up @@ -1184,18 +1128,7 @@ async def DELETE_Domain(request):
log.warn(msg)
raise HTTPForbidden()

try:
domain_json = await getDomainJson(app, domain, reload=True)
except ClientResponseError as ce:
if ce.code == 404:
log.warn("domain not found")
raise HTTPNotFound()
elif ce.code == 410:
log.warn("domain has been removed")
raise HTTPGone()
else:
log.error(f"unexpected error: {ce.code}")
raise HTTPInternalServerError()
domain_json = await getDomainJson(app, domain, reload=True)

# throws exception if not allowed
aclCheck(app, domain_json, "delete", username)
Expand Down Expand Up @@ -1277,16 +1210,7 @@ async def GET_ACL(request):
checkBucketAccess(app, bucket)

# use reload to get authoritative domain json
try:
domain_json = await getDomainJson(app, domain, reload=True)
except ClientResponseError as ce:
if ce.code in (404, 410):
msg = "domain not found"
log.warn(msg)
raise HTTPNotFound()
else:
log.error(f"unexpected error: {ce.code}")
raise HTTPInternalServerError()
domain_json = await getDomainJson(app, domain, reload=True)

# validate that the requesting user has permission to read ACLs
# in this domain
Expand Down Expand Up @@ -1361,19 +1285,7 @@ async def GET_ACLs(request):
checkBucketAccess(app, bucket)

# use reload to get authoritative domain json
try:
domain_json = await getDomainJson(app, domain, reload=True)
except ClientResponseError:
log.warn("domain not found")
raise HTTPNotFound()

if "owner" not in domain_json:
log.error("No owner key found in domain")
raise HTTPInternalServerError()

if "acls" not in domain_json:
log.error("No acls key found in domain")
raise HTTPInternalServerError()
domain_json = await getDomainJson(app, domain, reload=True)

acls = domain_json["acls"]

Expand Down Expand Up @@ -1498,28 +1410,7 @@ async def GET_Datasets(request):
bucket = getBucketForDomain(domain)

# verify the domain
try:
domain_json = await getDomainJson(app, domain)
except ClientResponseError as ce:
if ce.code == 404:
msg = f"Domain: {domain} not found"
log.warn(msg)
raise HTTPNotFound()
elif ce.code == 410:
msg = f"Domain: {domain} removed"
log.warn(msg)
raise HTTPGone()
else:
log.error(f"Unexpected error: {ce.code}")
raise HTTPInternalServerError()

if "owner" not in domain_json:
log.error("No owner key found in domain")
raise HTTPInternalServerError()

if "acls" not in domain_json:
log.error("No acls key found in domain")
raise HTTPInternalServerError()
domain_json = await getDomainJson(app, domain)

log.debug(f"got domain_json: {domain_json}")
# validate that the requesting user has permission to read this domain
Expand Down Expand Up @@ -1588,24 +1479,7 @@ async def GET_Groups(request):
bucket = getBucketForDomain(domain)

# use reload to get authoritative domain json
try:
domain_json = await getDomainJson(app, domain, reload=True)
except ClientResponseError as ce:
if ce.code == 404:
msg = "domain not found"
log.warn(msg)
raise HTTPNotFound()
else:
log.error(f"Unexpected error: {ce.code}")
raise HTTPInternalServerError()

if "owner" not in domain_json:
log.error("No owner key found in domain")
raise HTTPInternalServerError()

if "acls" not in domain_json:
log.error("No acls key found in domain")
raise HTTPInternalServerError()
domain_json = await getDomainJson(app, domain, reload=True)

log.debug(f"got domain_json: {domain_json}")
# validate that the requesting user has permission to read this domain
Expand Down Expand Up @@ -1673,24 +1547,7 @@ async def GET_Datatypes(request):
bucket = getBucketForDomain(domain)

# use reload to get authoritative domain json
try:
domain_json = await getDomainJson(app, domain, reload=True)
except ClientResponseError as ce:
if ce.code in (404, 410):
msg = "domain not found"
log.warn(msg)
raise HTTPNotFound()
else:
log.error(f"Unexpected Error: {ce.code})")
raise HTTPInternalServerError()

if "owner" not in domain_json:
log.error("No owner key found in domain")
raise HTTPInternalServerError()

if "acls" not in domain_json:
log.error("No acls key found in domain")
raise HTTPInternalServerError()
domain_json = await getDomainJson(app, domain, reload=True)

log.debug(f"got domain_json: {domain_json}")
# validate that the requesting user has permission to read this domain
Expand Down
29 changes: 28 additions & 1 deletion hsds/servicenode_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from aiohttp.web_exceptions import HTTPBadRequest, HTTPForbidden, HTTPGone, HTTPConflict
from aiohttp.web_exceptions import HTTPNotFound, HTTPInternalServerError
from aiohttp.client_exceptions import ClientOSError, ClientError
from aiohttp import ClientResponseError

from .util.authUtil import getAclKeys
from .util.arrayUtil import encodeData
Expand Down Expand Up @@ -63,7 +64,33 @@ async def getDomainJson(app, domain, reload=False):

log.debug(f"sending dn req: {req} params: {params}")

domain_json = await http_get(app, req, params=params)
try:
domain_json = await http_get(app, req, params=params)
except HTTPNotFound:
log.warn(f"domain: {domain} not found")
raise
except HTTPGone:
log.warn(f"domain: {domain} has been removed")
raise
except ClientResponseError as ce:
# shouldn't get this if we are catching relevant exceptions
# in http_get...
log.error(f"Unexpected ClientResponseError: {ce}")

if ce.code == 404:
log.warn("domain not found")
raise HTTPNotFound()
elif ce.code == 410:
log.warn("domain has been removed")
raise HTTPGone()
else:
log.error(f"unexpected error: {ce.code}")
raise HTTPInternalServerError()

if not domain_json:
msg = f"nothing returned (and no exceptionraised) for domain: {domain}"
log.error(msg)
raise HTTPInternalServerError()

if "owner" not in domain_json:
log.warn("No owner key found in domain")
Expand Down
7 changes: 3 additions & 4 deletions hsds/util/domainUtil.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def validateDomainKey(domain_key):


def getDomainFromRequest(request, validate=True):
# print("gotDomainFromRequest:", request, "validate=", validate)
# print(f"getDomainFromRequest: {request}, validate={validate}")
app = request.app
domain = None
bucket = None
Expand All @@ -279,7 +279,6 @@ def getDomainFromRequest(request, validate=True):
domain = request.headers["X-Hdf-domain"]
else:
return None

if domain.startswith("hdf5:/"):
# strip off the prefix to make following logic easier
domain = domain[6:]
Expand All @@ -297,8 +296,8 @@ def getDomainFromRequest(request, validate=True):
else:
pass # no bucket specified

if bucket and validate:
if not isValidBucketName(bucket):
if bucket:
if validate and not isValidBucketName(bucket):
raise ValueError(f"bucket name: {bucket} is not valid")
if domain[0] == "/":
domain = bucket + domain
Expand Down
Loading
Loading