forked from dbt-labs/dbt-core
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwerkzeug-refresh-token.py
139 lines (116 loc) · 4.01 KB
/
werkzeug-refresh-token.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import argparse
import json
import secrets
import textwrap
from base64 import b64encode
import requests
from werkzeug.utils import redirect
from werkzeug.middleware.dispatcher import DispatcherMiddleware
from werkzeug.wrappers import Request, Response
from werkzeug.serving import run_simple
from urllib.parse import urlencode
def _make_rfp_claim_value():
# from https://tools.ietf.org/id/draft-bradley-oauth-jwt-encoded-state-08.html#rfc.section.4 # noqa
# we can do whatever we want really, so just token.urlsafe?
return secrets.token_urlsafe(112)
def _make_response(client_id, client_secret, refresh_token):
    """Return a werkzeug ``Response`` whose body lists the credentials.

    The body is three ``NAME="value"`` lines -- refresh token, client id,
    client secret -- with no trailing newline. Values are interpolated
    verbatim (no quoting or escaping is applied).
    """
    env_lines = textwrap.dedent(
        f'''\
        SNOWFLAKE_TEST_OAUTH_REFRESH_TOKEN="{refresh_token}"
        SNOWFLAKE_TEST_OAUTH_CLIENT_ID="{client_id}"
        SNOWFLAKE_TEST_OAUTH_CLIENT_SECRET="{client_secret}"'''
    )
    return Response(env_lines)
class TokenManager:
    """Run Snowflake's OAuth authorization-code flow to obtain a refresh token.

    Exposes two WSGI applications:

    * ``login`` redirects the browser to the account's ``/oauth/authorize``
      endpoint with this client's id, redirect URI and ``state`` value.
    * ``auth`` handles the redirect back: it checks ``state``, exchanges the
      authorization code at ``/oauth/token-request`` and returns the refresh
      token (plus client id/secret) as text.

    Args:
        account_name: Snowflake account name, used as the host prefix in
            ``https://<account_name>.snowflakecomputing.com``.
        client_id: OAuth client id.
        client_secret: OAuth client secret.
        port: local port used in ``redirect_uri`` (default 8080). The caller
            must serve the apps on this same port.
        timeout: seconds to wait for the token endpoint before
            ``requests`` raises a timeout error (default 30).
    """

    def __init__(self, account_name, client_id, client_secret, port=8080, timeout=30):
        self.account_name = account_name
        self.client_id = client_id
        self.client_secret = client_secret
        # Not read by this class; kept so existing attribute access still works.
        self.token = None
        # Sent as ``state`` by ``login`` and compared against it in ``auth``.
        self.rfp_claim = _make_rfp_claim_value()
        self.port = port
        self.timeout = timeout

    @property
    def account_url(self):
        """Base HTTPS URL of the Snowflake account."""
        return f'https://{self.account_name}.snowflakecomputing.com'

    @property
    def auth_url(self):
        """Authorization endpoint the browser is redirected to."""
        return f'{self.account_url}/oauth/authorize'

    @property
    def token_url(self):
        """Token endpoint used to exchange the authorization code."""
        return f'{self.account_url}/oauth/token-request'

    @property
    def redirect_uri(self):
        """Local URI the authorization server redirects back to.

        NOTE(review): presumably this must match the redirect URI registered
        for the OAuth client in Snowflake -- confirm when changing ``port``.
        """
        return f'http://localhost:{self.port}'

    @property
    def headers(self):
        """Headers for the token request: HTTP Basic client auth + form body type."""
        auth = f'{self.client_id}:{self.client_secret}'.encode('ascii')
        encoded_auth = b64encode(auth).decode('ascii')
        return {
            'Authorization': f'Basic {encoded_auth}',
            'Content-type': 'application/x-www-form-urlencoded; charset=utf-8'
        }

    def _code_to_token(self, code):
        """Exchange an authorization ``code`` and return the ``refresh_token``.

        Raises:
            KeyError: the JSON response has no ``refresh_token``; the parsed
                payload is printed first to aid debugging.
            ValueError: the response body is not valid JSON.
            requests.exceptions.Timeout: the token endpoint did not respond
                within ``self.timeout`` seconds.
        """
        data = {
            'grant_type': 'authorization_code',
            'code': code,
            'redirect_uri': self.redirect_uri,
        }
        # requests form-encodes a dict passed as ``data``; without a timeout a
        # stalled endpoint would block this request handler indefinitely.
        resp = requests.post(
            url=self.token_url,
            headers=self.headers,
            data=data,
            timeout=self.timeout,
        )
        payload = resp.json()
        try:
            return payload['refresh_token']
        except KeyError:
            print(payload)
            raise

    @Request.application
    def auth(self, request):
        """WSGI app handling the OAuth redirect back to ``redirect_uri``.

        * ``error`` query parameter present (e.g. the user denied access):
          respond 400 with the error instead of re-starting the flow, which
          previously caused an endless redirect loop through ``/login``.
        * No ``code``: redirect to ``/login`` to start the flow.
        * ``code`` present: require ``state`` to equal ``self.rfp_claim``
          (401 otherwise), then exchange the code and return the credentials.
        """
        error = request.args.get('error')
        if error:
            description = request.args.get('error_description', '')
            message = f'OAuth authorization failed: {error}'
            if description:
                message = f'{message} ({description})'
            return Response(message, status=400)

        code = request.args.get('code')
        if not code:
            return redirect('/login')

        # We were redirected here with a code; make sure it answers the
        # request we started by checking the ``state`` value we sent.
        # compare_digest needs equal-kind operands, hence the bytes encoding.
        state_received = request.args.get('state')
        if state_received is None or not secrets.compare_digest(
            state_received.encode('utf-8'), self.rfp_claim.encode('utf-8')
        ):
            return Response('Invalid RFP claim: MITM?', status=401)

        refresh_token = self._code_to_token(code)
        return _make_response(
            self.client_id,
            self.client_secret,
            refresh_token,
        )

    @Request.application
    def login(self, request):
        """WSGI app that redirects the browser to the authorization endpoint."""
        query = {
            'response_type': 'code',
            'client_id': self.client_id,
            'redirect_uri': self.redirect_uri,
            'state': self.rfp_claim,
        }
        query = urlencode(query)
        return redirect(f'{self.auth_url}?{query}')
def parse_args():
    """Parse the command line: a Snowflake account name and a JSON auth blob.

    Returns:
        argparse.Namespace with ``account_name`` and ``json_blob`` strings.
    """
    cli = argparse.ArgumentParser()
    positional_specs = (
        ('account_name', 'The account name'),
        ('json_blob', 'The json auth blob'),
    )
    for arg_name, arg_help in positional_specs:
        cli.add_argument(arg_name, help=arg_help)
    return cli.parse_args()
def main():
    """Entry point: read CLI arguments and serve the OAuth helper locally.

    The JSON blob must contain ``OAUTH_CLIENT_ID`` and ``OAUTH_CLIENT_SECRET``
    (a missing key raises ``KeyError``). ``/login`` is routed to
    ``TokenManager.login``; every other path goes to ``TokenManager.auth``.
    """
    cli_args = parse_args()
    credentials = json.loads(cli_args.json_blob)
    manager = TokenManager(
        account_name=cli_args.account_name,
        client_id=credentials['OAUTH_CLIENT_ID'],
        client_secret=credentials['OAUTH_CLIENT_SECRET'],
    )
    mounts = {'/login': manager.login}
    wsgi_app = DispatcherMiddleware(manager.auth, mounts)
    run_simple('localhost', manager.port, wsgi_app)
if __name__ == "__main__":
    # Start the local OAuth helper only when executed as a script.
    main()