Skip to content

Commit

Permalink
update credentials to use new rest-tools syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
dsschult committed Aug 27, 2024
1 parent 2fddba2 commit 7df2350
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 40 deletions.
90 changes: 54 additions & 36 deletions iceprod/credentials/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import motor.motor_asyncio
import requests.exceptions
from rest_tools.client import RestClient, ClientCredentialsAuth
from rest_tools.server import RestServer
from rest_tools.server import RestServer, ArgumentHandler, ArgumentSource
from tornado.web import HTTPError
from tornado.web import RequestHandler as TornadoRequestHandler
from wipac_dev_tools import from_environment
Expand Down Expand Up @@ -59,8 +59,20 @@ async def check_attr_auth(self, arg, val, role):
raise HTTPError(500, 'auth could not be completed')

async def create(self, db, base_data):
url = self.get_json_body_argument('url', type=str, strict_type=True)
credential_type = self.get_json_body_argument('type', type=str, choices=['s3', 'oauth'], strict_type=True)
now = time.time()
argo = ArgumentHandler(ArgumentSource.JSON_BODY_ARGUMENTS, self)
argo.add_argument('url', type=str, required=True)
argo.add_argument('type', type=str, choices=['s3', 'oauth'], required=True)
argo.add_argument('buckets', type=list, default=[], required=False)
argo.add_argument('access_key', type=str, default='', required=False)
argo.add_argument('secret_key', type=str, default='', required=False)
argo.add_argument('access_token', type=str, default='', required=False)
argo.add_argument('refresh_token', type=str, default='', required=False)
argo.add_argument('expire_date', type=float, default=now, required=False)
argo.add_argument('last_use', type=float, default=now, required=False)
args = vars(argo.parse_args())
url = args['url']
credential_type = args['type']

base_data['url'] = url
data = base_data.copy()
Expand All @@ -69,28 +81,24 @@ async def create(self, db, base_data):
})

if credential_type == 's3':
buckets = self.get_json_body_argument('buckets', type=list, strict_type=True)
access_key = self.get_json_body_argument('access_key', type=str, strict_type=True)
secret_key = self.get_json_body_argument('secret_key', type=str, strict_type=True)
if not buckets:
if not args['buckets']:
raise HTTPError(400, reason='must specify bucket(s)')
data['buckets'] = buckets
data['access_key'] = access_key
data['secret_key'] = secret_key
if not args['access_key']:
raise HTTPError(400, reason='must specify access_key')
if not args['secret_key']:
raise HTTPError(400, reason='must specify secret_key')

elif credential_type == 'oauth':
access_token = self.get_json_body_argument('access_token', default='', type=str, strict_type=True)
refresh_token = self.get_json_body_argument('refresh_token', default='', type=str, strict_type=True)
now = time.time()
exp = self.get_json_body_argument('expire_date', default=now, type=float)
last_use = self.get_json_body_argument('last_use', default=now, type=float)
data['buckets'] = args['buckets']
data['access_key'] = args['access_key']
data['secret_key'] = args['secret_key']

if (not access_token) and not refresh_token:
elif credential_type == 'oauth':
if (not args['access_token']) and not args['refresh_token']:
raise HTTPError(400, reason='must specify either access or refresh tokens')
data['access_token'] = access_token
data['refresh_token'] = refresh_token
data['expiration'] = exp
data['last_use'] = last_use
data['access_token'] = args['access_token']
data['refresh_token'] = args['refresh_token']
data['expiration'] = args['expire_date']
data['last_use'] = args['last_use']

if 'refresh_token' in data and not data.get('access_token', ''):
new_cred = await self.refresh_service.refresh_cred(data)
Expand All @@ -108,17 +116,21 @@ async def create(self, db, base_data):
)

async def patch_cred(self, db, base_data):
base_data['url'] = self.get_json_body_argument('url', type=str, strict_type=True)
argo = ArgumentHandler(ArgumentSource.JSON_BODY_ARGUMENTS, self)
argo.add_argument('url', type=str, required=True)
argo.add_argument('buckets', type=list, default=[], required=False)
argo.add_argument('access_key', type=str, default='', required=False)
argo.add_argument('secret_key', type=str, default='', required=False)
argo.add_argument('access_token', type=str, default='', required=False)
argo.add_argument('refresh_token', type=str, default='', required=False)
argo.add_argument('expiration', type=float, default=0, required=False)
argo.add_argument('last_use', type=float, default=0, required=False)
args = vars(argo.parse_args())
url = args['url']

data = {}
buckets = self.get_json_body_argument('buckets', default=[], type=list, strict_type=True)
if buckets:
data['buckets'] = buckets
for key in ('access_key', 'secret_key', 'access_token', 'refresh_token'):
if val := self.get_json_body_argument(key, default='', type=str, strict_type=True):
data[key] = val
for key in ('expiration', 'last_use'):
if val := self.get_json_body_argument(key, default=None, type=float):
for key in ('buckets', 'access_key', 'secret_key', 'access_token', 'refresh_token', 'expiration', 'last_use'):
if val := args[key]:
data[key] = val

if 'refresh_token' in data and 'access_token' not in data:
Expand Down Expand Up @@ -245,9 +257,12 @@ async def delete(self, groupname):
raise HTTPError(403, 'unauthorized')

args = {'groupname': groupname}
url = self.get_json_body_argument('url', default='', type=str)
if url:
args['url'] = url

argo = ArgumentHandler(ArgumentSource.JSON_BODY_ARGUMENTS, self)
argo.add_argument('url', type=str, default='', required=False)
body_args = argo.parse_args()
if body_args.url:
args['url'] = body_args.url

await self.db.group_creds.delete_many(args)
self.write({})
Expand Down Expand Up @@ -336,9 +351,12 @@ async def delete(self, username):
raise HTTPError(403, 'unauthorized')

args = {'username': username}
url = self.get_json_body_argument('url', default='', type=str)
if url:
args['url'] = url

argo = ArgumentHandler(ArgumentSource.JSON_BODY_ARGUMENTS, self)
argo.add_argument('url', type=str, default='', required=False)
body_args = argo.parse_args()
if body_args.url:
args['url'] = body_args.url

await self.db.user_creds.delete_many(args)
self.write({})
Expand Down
10 changes: 6 additions & 4 deletions tests/credentials/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ async def test_credentials_groups_s3(server):
await client.request('POST', f'/groups/{GROUP}/credentials', data3)

ret = await client.request('GET', f'/groups/{GROUP}/credentials')
data3['groupname'] = GROUP
assert ret == {data['url']: data3, data2['url']: data2}
data3_out = data3.copy()
data3_out['groupname'] = GROUP
assert ret == {data['url']: data3_out, data2['url']: data2}

await client.request('DELETE', f'/groups/{GROUP}/credentials', {'url': 'http://foo'})

Expand Down Expand Up @@ -270,8 +271,9 @@ async def test_credentials_users_s3(server):
await client.request('POST', f'/users/{USER}/credentials', data3)

ret = await client.request('GET', f'/users/{USER}/credentials')
data3['username'] = USER
assert ret == {data['url']: data3, data2['url']: data2}
data3_out = data3.copy()
data3_out['username'] = USER
assert ret == {data['url']: data3_out, data2['url']: data2}

await client.request('DELETE', f'/users/{USER}/credentials', {'url': 'http://foo'})

Expand Down

0 comments on commit 7df2350

Please sign in to comment.