From 8be32814238b097ad09415e339fc096174e71bd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dominik=20Pa=C4=BEo?= Date: Tue, 26 Nov 2024 17:47:55 +0100 Subject: [PATCH] WIP --- Sources/Base/OAuth2Base.swift | 14 +- Sources/Base/OAuth2RequestPerformer.swift | 22 +- Sources/Base/OAuth2Requestable.swift | 64 +++-- Sources/Base/OAuth2Response.swift | 17 +- Sources/DataLoader/OAuth2DataLoader.swift | 40 +-- Sources/Flows/OAuth2.swift | 278 +++++++------------- Sources/Flows/OAuth2ClientCredentials.swift | 36 ++- Sources/Flows/OAuth2CodeGrant.swift | 29 +- Sources/Flows/OAuth2DeviceGrant.swift | 123 ++++----- Sources/Flows/OAuth2DynReg.swift | 35 +-- Sources/Flows/OAuth2PasswordGrant.swift | 81 +++--- 11 files changed, 311 insertions(+), 428 deletions(-) diff --git a/Sources/Base/OAuth2Base.swift b/Sources/Base/OAuth2Base.swift index 36216e1..97bfcfb 100644 --- a/Sources/Base/OAuth2Base.swift +++ b/Sources/Base/OAuth2Base.swift @@ -132,14 +132,8 @@ open class OAuth2Base: OAuth2Securable { set { clientConfig.customUserAgent = newValue } } - - /// This closure is internally used with `authorize(params:callback:)` and only exposed for subclassing reason, do not mess with it! - public final var didAuthorizeOrFail: ((_ parameters: OAuth2JSON?, _ error: OAuth2Error?) -> Void)? - /// Returns true if the receiver is currently authorizing. - public final var isAuthorizing: Bool { - return nil != didAuthorizeOrFail - } + public final var isAuthorizing: Bool = false /// Returns true if the receiver is currently exchanging the refresh token. public final var isExchangingRefreshToken: Bool = false @@ -277,8 +271,7 @@ open class OAuth2Base: OAuth2Securable { storeTokensToKeychain() } callOnMainThread() { - self.didAuthorizeOrFail?(parameters, nil) - self.didAuthorizeOrFail = nil + self.isAuthorizing = false self.internalAfterAuthorizeOrFail?(false, nil) self.afterAuthorizeOrFail?(parameters, nil) } @@ -301,8 +294,7 @@ open class OAuth2Base: OAuth2Securable { finalError = OAuth2Error.requestCancelled } callOnMainThread() { - self.didAuthorizeOrFail?(nil, finalError) - self.didAuthorizeOrFail = nil + self.isAuthorizing = false self.internalAfterAuthorizeOrFail?(true, finalError) self.afterAuthorizeOrFail?(nil, finalError) } diff --git a/Sources/Base/OAuth2RequestPerformer.swift b/Sources/Base/OAuth2RequestPerformer.swift index 2e68c32..34b4583 100644 --- a/Sources/Base/OAuth2RequestPerformer.swift +++ b/Sources/Base/OAuth2RequestPerformer.swift @@ -17,14 +17,12 @@ The class `OAuth2DataTaskRequestPerformer` implements this protocol and is by de public protocol OAuth2RequestPerformer { /** - This method should start executing the given request, returning a URLSessionTask if it chooses to do so. **You do not neet to call - `resume()` on this task**, it's supposed to already have started. It is being returned so you may be able to do additional stuff. + This method should execute the given request asynchronously. - parameter request: An URLRequest object that provides the URL, cache policy, request type, body data or body stream, and so on. - - parameter completionHandler: The completion handler to call when the load request is complete. - - returns: An already running session task + - returns: Data and response. */ - func perform(request: URLRequest, completionHandler: @escaping (Data?, URLResponse?, Error?) -> Void) -> URLSessionTask? + func perform(request: URLRequest) async throws -> (Data, URLResponse) } @@ -36,7 +34,6 @@ open class OAuth2DataTaskRequestPerformer: OAuth2RequestPerformer { /// The URLSession that should be used. public var session: URLSession - /** Designated initializer. */ @@ -45,18 +42,13 @@ open class OAuth2DataTaskRequestPerformer: OAuth2RequestPerformer { } /** - This method should start executing the given request, returning a URLSessionTask if it chooses to do so. **You do not neet to call - `resume()` on this task**, it's supposed to already have started. It is being returned so you may be able to do additional stuff. + This method should execute the given request asynchronously. - parameter request: An URLRequest object that provides the URL, cache policy, request type, body data or body stream, and so on. - - parameter completionHandler: The completion handler to call when the load request is complete. - - returns: An already running session data task + - returns: Data and response. */ - @discardableResult - open func perform(request: URLRequest, completionHandler: @escaping (Data?, URLResponse?, Error?) -> Void) -> URLSessionTask? { - let task = session.dataTask(with: request, completionHandler: completionHandler) - task.resume() - return task + open func perform(request: URLRequest) async throws -> (Data, URLResponse) { + try await session.data(for: request) } } diff --git a/Sources/Base/OAuth2Requestable.swift b/Sources/Base/OAuth2Requestable.swift index c027487..5268459 100644 --- a/Sources/Base/OAuth2Requestable.swift +++ b/Sources/Base/OAuth2Requestable.swift @@ -105,41 +105,40 @@ open class OAuth2Requestable { open var requestPerformer: OAuth2RequestPerformer? /** - Perform the supplied request and call the callback with the response JSON dict or an error. This method is intended for authorization + Perform the supplied request and return the response JSON dict or throw an error. This method is intended for authorization calls, not for data calls outside of the OAuth2 dance. - This implementation uses the shared `NSURLSession` and executes a data task. If the server responds with an error, this will be - converted into an error according to information supplied in the response JSON (if availale). - - The callback returns a response object that is easy to use, like so: - - perform(request: req) { response in - do { - let data = try response.responseData() - // do what you must with `data` as Data and `response.response` as HTTPURLResponse - } - catch let error { - // the request failed because of `error` - } - } - - Easy, right? + This implementation uses the shared `NSURLSession`. If the server responds with an error, this will be + converted into an error according to information supplied in the response JSON (if available). - parameter request: The request to execute - - parameter callback: The callback to call when the request completes/fails. Looks terrifying, see above on how to use it + - returns : OAuth2 response */ - open func perform(request: URLRequest, callback: @escaping ((OAuth2Response) -> Void)) { + open func perform(request: URLRequest) async -> OAuth2Response { self.logger?.trace("OAuth2", msg: "REQUEST\n\(request.debugDescription)\n---") let performer = requestPerformer ?? OAuth2DataTaskRequestPerformer(session: session) requestPerformer = performer - let task = performer.perform(request: request) { sessData, sessResponse, error in - self.abortableTask = nil - self.logger?.trace("OAuth2", msg: "RESPONSE\n\(sessResponse?.debugDescription ?? "no response")\n\n\(String(data: sessData ?? Data(), encoding: String.Encoding.utf8) ?? "no data")\n---") - let http = (sessResponse as? HTTPURLResponse) ?? HTTPURLResponse(url: request.url!, statusCode: 499, httpVersion: nil, headerFields: nil)! - let response = OAuth2Response(data: sessData, request: request, response: http, error: error) - callback(response) + + do { + // TODO: add support for aborting the request, see https://www.hackingwithswift.com/quick-start/concurrency/how-to-cancel-a-task + let (sessData, sessResponse) = try await performer.perform(request: request) + self.logger?.trace("OAuth2", msg: "RESPONSE\n\(sessResponse.debugDescription)\n\n\(String(data: sessData, encoding: String.Encoding.utf8) ?? "no data")\n---") + + guard let response = sessResponse as? HTTPURLResponse else { + throw CommonError.castError( + from: String(describing: sessResponse.self), + to: String(describing: HTTPURLResponse.self) + ) + } + + return OAuth2Response(data: sessData, request: request, response: response, error: nil) + + } catch { + self.logger?.trace("OAuth2", msg: "RESPONSE\nno response\n\nno data\n---") + + let http = HTTPURLResponse(url: request.url!, statusCode: 499, httpVersion: nil, headerFields: nil)! + return OAuth2Response(data: nil, request: request, response: http, error: error) } - abortableTask = task } /// Currently running abortable session task. @@ -222,3 +221,16 @@ public func callOnMainThread(_ callback: (() -> Void)) { } } +// TODO: move to a separate file +enum CommonError: Error { + case castError(from: String, to: String) +} + +extension CommonError: CustomStringConvertible { + public var description: String { + switch self { + case .castError(from: let from, to: let to): + return "Could not cast \(from) to \(to)" + } + } +} diff --git a/Sources/Base/OAuth2Response.swift b/Sources/Base/OAuth2Response.swift index 8016036..6b907e0 100644 --- a/Sources/Base/OAuth2Response.swift +++ b/Sources/Base/OAuth2Response.swift @@ -26,15 +26,14 @@ Encapsulates a URLResponse to a URLRequest. Instances of this class are returned from `OAuth2Requestable` calls, they can be used like so: - perform(request: req) { response in - do { - let data = try response.responseData() - // do what you must with `data` as Data and `response.response` as HTTPURLResponse - } - catch let error { - // the request failed because of `error` - } - } + await perform(request: req) + do { + let data = try response.responseData() + // do what you must with `data` as Data and `response.response` as HTTPURLResponse + } + catch let error { + // the request failed because of `error` + } */ open class OAuth2Response { diff --git a/Sources/DataLoader/OAuth2DataLoader.swift b/Sources/DataLoader/OAuth2DataLoader.swift index d5674ff..a747a03 100644 --- a/Sources/DataLoader/OAuth2DataLoader.swift +++ b/Sources/DataLoader/OAuth2DataLoader.swift @@ -80,7 +80,7 @@ open class OAuth2DataLoader: OAuth2Requestable { - parameter request: The request to execute - parameter callback: The callback to call when the request completes/fails. Looks terrifying, see above on how to use it */ - override open func perform(request: URLRequest, callback: @escaping ((OAuth2Response) -> Void)) { + open func perform(request: URLRequest, callback: @escaping ((OAuth2Response) -> Void)) { perform(request: request, retry: true, callback: callback) } @@ -112,7 +112,9 @@ open class OAuth2DataLoader: OAuth2Requestable { return } - super.perform(request: request) { response in + Task { + let response = await super.perform(request: request) + do { if self.alsoIntercept403, 403 == response.response.statusCode { throw OAuth2Error.unauthorizedClient(nil) @@ -126,16 +128,19 @@ open class OAuth2DataLoader: OAuth2Requestable { if retry { self.enqueue(request: request, callback: callback) self.oauth2.clientConfig.accessToken = nil - self.attemptToAuthorize() { json, error in - - // dequeue all if we're authorized, throw all away if something went wrong - if nil != json { - self.retryAll() - } - else { - self.throwAllAway(with: error ?? OAuth2Error.requestCancelled) + + + do { + let json = try await self.attemptToAuthorize() + guard json != nil else { + throw OAuth2Error.requestCancelled } + + self.retryAll() + } catch { + self.throwAllAway(with: error.asOAuth2Error) } + } else { callback(response) @@ -157,14 +162,15 @@ open class OAuth2DataLoader: OAuth2Requestable { - parameter callback: The callback passed on from `authorize(callback:)`. Authorization finishes successfully (auth parameters will be non-nil but may be an empty dict), fails (error will be non-nil) or is canceled (both params and error are nil) */ - open func attemptToAuthorize(callback: @escaping ((OAuth2JSON?, OAuth2Error?) -> Void)) { - if !isAuthorizing { - isAuthorizing = true - oauth2.authorize() { authParams, error in - self.isAuthorizing = false - callback(authParams, error) - } + open func attemptToAuthorize() async throws -> OAuth2JSON? { + guard !self.isAuthorizing else { + return nil } + + self.isAuthorizing = true + let authParams = try await oauth2.authorize() + self.isAuthorizing = false + return authParams } diff --git a/Sources/Flows/OAuth2.swift b/Sources/Flows/OAuth2.swift index a655661..95fdd1c 100644 --- a/Sources/Flows/OAuth2.swift +++ b/Sources/Flows/OAuth2.swift @@ -97,76 +97,37 @@ open class OAuth2: OAuth2Base { calling the callback with a failure. If client_id is not set but a "registration_uri" has been provided, a dynamic client registration will be attempted and if it success, an access token will be requested. - - parameter params: Optional key/value pairs to pass during authorization and token refresh - - parameter callback: The callback to call when authorization finishes (parameters will be non-nil but may be an empty dict), fails or - is canceled (error will be non-nil, e.g. `.requestCancelled` if auth was aborted) + - parameter params: Optional key/value pairs to pass during authorization and token refresh + - returns: JSON dictionary or nil */ - public final func authorize(params: OAuth2StringDict? = nil, callback: @escaping ((OAuth2JSON?, OAuth2Error?) -> Void)) { - if isAuthorizing { - callback(nil, OAuth2Error.alreadyAuthorizing) - return + public final func authorize(params: OAuth2StringDict? = nil) async throws -> OAuth2JSON? { + guard !self.isAuthorizing else { + throw OAuth2Error.alreadyAuthorizing } - if isExchangingRefreshToken { - callback(nil, OAuth2Error.alreadyExchangingRefreshToken) - return + guard !isExchangingRefreshToken else { + throw OAuth2Error.alreadyExchangingRefreshToken } - didAuthorizeOrFail = callback + self.isAuthorizing = true logger?.debug("OAuth2", msg: "Starting authorization") - tryToObtainAccessTokenIfNeeded(params: params) { successParams, error in - if let successParams = successParams { + + do { + if let successParams = try await tryToObtainAccessTokenIfNeeded(params: params) { self.didAuthorize(withParameters: successParams) + return successParams } - else if let error = error { - self.didFail(with: error) - } - else { - self.registerClientIfNeeded() { json, error in - if let error = error { - self.didFail(with: error) - } - else { - do { - assert(Thread.isMainThread) - try self.doAuthorize(params: params) - } - catch let error { - self.didFail(with: error.asOAuth2Error) - } - } - } - } - } - } - - /** - Shortcut function to start embedded authorization from the given context (a UIViewController on iOS, an NSWindow on OS X). - - This method sets `authConfig.authorizeEmbedded = true` and `authConfig.authorizeContext = <# context #>`, then calls `authorize()` - - - parameter from: The context to start authorization from, depends on platform (UIViewController or NSWindow, see `authorizeContext`) - - parameter params: Optional key/value pairs to pass during authorization - - parameter callback: The callback to call when authorization finishes (parameters will be non-nil but may be an empty dict), fails or - is canceled (error will be non-nil, e.g. `.requestCancelled` if auth was aborted) - */ - @available(*, deprecated, message: "Use ASWebAuthenticationSession (preferred) or SFSafariWebViewController. This will be removed in v6.") - open func authorizeEmbedded(from context: AnyObject, params: OAuth2StringDict? = nil, callback: @escaping ((_ authParameters: OAuth2JSON?, _ error: OAuth2Error?) -> Void)) { - if isAuthorizing { // `authorize()` will check this, but we want to exit before changing `authConfig` - callback(nil, OAuth2Error.alreadyAuthorizing) - return - } - - if (isExchangingRefreshToken) { - callback(nil, OAuth2Error.alreadyExchangingRefreshToken) - return + + _ = try await self.registerClientIfNeeded() + try await self.doAuthorize(params: params) + return nil + + } catch { + self.didFail(with: error.asOAuth2Error) + throw error.asOAuth2Error } - - authConfig.authorizeEmbedded = true - authConfig.authorizeContext = context - authorize(params: params, callback: callback) } - + /** If the instance has an accessToken, checks if its expiry time has not yet passed. If we don't have an expiry date we assume the token is still valid. @@ -192,33 +153,26 @@ open class OAuth2: OAuth2Base { Indicates, in the callback, whether the client has been able to obtain an access token that is likely to still work (but there is no guarantee!) or not. - - parameter params: Optional key/value pairs to pass during authorization - - parameter callback: The callback to call once the client knows whether it has an access token or not; if `success` is true an - access token is present + - parameter params: Optional key/value pairs to pass during authorization + - returns: TODO */ - open func tryToObtainAccessTokenIfNeeded(params: OAuth2StringDict? = nil, callback: @escaping ((OAuth2JSON?, OAuth2Error?) -> Void)) { + open func tryToObtainAccessTokenIfNeeded(params: OAuth2StringDict? = nil) async throws -> OAuth2JSON? { if hasUnexpiredAccessToken() { logger?.debug("OAuth2", msg: "Have an apparently unexpired access token") - callback(OAuth2JSON(), nil) + return OAuth2JSON() } else { logger?.debug("OAuth2", msg: "No access token, checking if a refresh token is available") - doRefreshToken(params: params) { successParams, error in - if let successParams = successParams { - callback(successParams, nil) - } - else { - var returnedError: OAuth2Error? = nil - if let err = error { - self.logger?.debug("OAuth2", msg: "Error refreshing token: \(err)") - switch err { - case .noRefreshToken, .noClientId, .unauthorizedClient: - returnedError = nil - default: - returnedError = err - } - } - callback(nil, returnedError) + do { + return try await self.doRefreshToken(params: params) + } catch { + self.logger?.debug("OAuth2", msg: "Error refreshing token: \(error)") + + switch error.asOAuth2Error { + case .noRefreshToken, .noClientId, .unauthorizedClient: + return nil + default: + throw error } } } @@ -232,7 +186,7 @@ open class OAuth2: OAuth2Base { - parameter params: Optional key/value pairs to pass during authorization */ - open func doAuthorize(params: OAuth2StringDict? = nil) throws { + open func doAuthorize(params: OAuth2StringDict? = nil) async throws { if authConfig.authorizeEmbedded { try doAuthorizeEmbedded(with: authConfig, params: params) } @@ -375,35 +329,29 @@ open class OAuth2: OAuth2Base { If the request returns an error, the refresh token is thrown away. - parameter params: Optional key/value pairs to pass during token refresh - - parameter callback: The callback to call after the refresh token exchange has finished + - returns: OAuth2 JSON dictionary */ - open func doRefreshToken(params: OAuth2StringDict? = nil, callback: @escaping ((OAuth2JSON?, OAuth2Error?) -> Void)) { + open func doRefreshToken(params: OAuth2StringDict? = nil) async throws -> OAuth2JSON { do { let post = try tokenRequestForTokenRefresh(params: params).asURLRequest(for: self) logger?.debug("OAuth2", msg: "Using refresh token to receive access token from \(post.url?.description ?? "nil")") - perform(request: post) { response in - do { - let data = try response.responseData() - let json = try self.parseRefreshTokenResponseData(data) - if response.response.statusCode >= 400 { - self.clientConfig.refreshToken = nil - throw OAuth2Error.generic("Failed with status \(response.response.statusCode)") - } - self.logger?.debug("OAuth2", msg: "Did use refresh token for access token [\(nil != self.clientConfig.accessToken)]") - if (self.useKeychain) { - self.storeTokensToKeychain() - } - callback(json, nil) - } - catch let error { - self.logger?.debug("OAuth2", msg: "Error refreshing access token: \(error)") - callback(nil, error.asOAuth2Error) - } + let response = await perform(request: post) + let data = try response.responseData() + let json = try self.parseRefreshTokenResponseData(data) + if response.response.statusCode >= 400 { + self.clientConfig.refreshToken = nil + throw OAuth2Error.generic("Failed with status \(response.response.statusCode)") } + self.logger?.debug("OAuth2", msg: "Did use refresh token for access token [\(nil != self.clientConfig.accessToken)]") + if (self.useKeychain) { + self.storeTokensToKeychain() + } + + return json } - catch let error { - callback(nil, error.asOAuth2Error) + catch { + throw error.asOAuth2Error } } @@ -433,16 +381,17 @@ open class OAuth2: OAuth2Base { return req } - + /** Exchanges the subject's refresh token for audience client. see: https://datatracker.ietf.org/doc/html/rfc8693 see: https://www.scottbrady91.com/oauth/delegation-patterns-for-oauth-20 - parameter audienceClientId: The client ID of the audience requesting for its own refresh token + - parameter traceId: Unique identifier for debugging purposes. - parameter params: Optional key/value pairs to pass during token exchange - - parameter callback: The callback to call after the exchange of refresh token has finished + - returns: Exchanged refresh token */ - open func doExchangeRefreshToken(audienceClientId: String, traceId: String, params: OAuth2StringDict? = nil, callback: @escaping ((String?, OAuth2Error?) -> Void)) { + open func doExchangeRefreshToken(audienceClientId: String, traceId: String, params: OAuth2StringDict? = nil) async throws -> String { do { guard !self.isExchangingRefreshToken else { throw OAuth2Error.alreadyExchangingRefreshToken @@ -452,44 +401,36 @@ open class OAuth2: OAuth2Base { let post = try tokenRequestForExchangeRefreshToken(audienceClientId: audienceClientId, params: params).asURLRequest(for: self) logger?.debug("OAuth2", msg: "Exchanging refresh token for client with ID \(audienceClientId) from \(post.url?.description ?? "nil") [trace=\(traceId)]") - perform(request: post) { response in - do { - let data = try response.responseData() - let json = try self.parseExchangeRefreshTokenResponseData(data) - if response.response.statusCode >= 400 { - self.clientConfig.refreshToken = nil - throw OAuth2Error.generic("Failed with status \(response.response.statusCode)") - } - - // The `access_token` field contains the `requested_token_type` = the exchanged (audience) refresh token in our case. - // - // Explanation: - // The security token issued by the authorization server in response to the token exchange request. The access_token parameter - // from Section 5.1 of [RFC6749] is used here to carry the requested token, which allows this token exchange protocol to use the - // existing OAuth 2.0 request and response constructs defined for the token endpoint. - // **The identifier access_token is used for historical reasons and the issued token need not be an OAuth access token.** - // See: https://tools.ietf.org/id/draft-ietf-oauth-token-exchange-12.html#rfc.section.2.2.1 - guard let exchangedRefreshToken = json["access_token"] as? String else { - throw OAuth2Error.generic("Exchange refresh token didn't return exchanged refresh token (response.access_token) [trace=\(traceId)]") - } - self.logger?.debug("OAuth2", msg: "Did use refresh token for exchanging refresh token [trace=\(traceId)]") - self.logger?.trace("OAuth2", msg: "Exchanged refresh token in [trace=\(traceId)] is [\(exchangedRefreshToken)]") - if self.useKeychain { - self.storeTokensToKeychain() - } - self.isExchangingRefreshToken = false - callback(exchangedRefreshToken, nil) - } catch let error { - self.logger?.debug("OAuth2", msg: "Error exchanging refresh token in [trace=\(traceId)]: \(error)") - self.isExchangingRefreshToken = false - - callback(nil, error.asOAuth2Error) - } + let response = await perform(request: post) + let data = try response.responseData() + let json = try self.parseExchangeRefreshTokenResponseData(data) + if response.response.statusCode >= 400 { + self.clientConfig.refreshToken = nil + throw OAuth2Error.generic("Failed with status \(response.response.statusCode)") } - } catch let error { + + /// The `access_token` field contains the `requested_token_type` = the exchanged (audience) refresh token in our case. + /// + /// **Explanation:** + /// The security token issued by the authorization server in response to the token exchange request. The access_token parameter + /// from Section 5.1 of [RFC6749] is used here to carry the requested token, which allows this token exchange protocol to use the + /// existing OAuth 2.0 request and response constructs defined for the token endpoint. + /// **The identifier access_token is used for historical reasons and the issued token need not be an OAuth access token.** + /// See: https://tools.ietf.org/id/draft-ietf-oauth-token-exchange-12.html#rfc.section.2.2.1 + guard let exchangedRefreshToken = json["access_token"] as? String else { + throw OAuth2Error.generic("Exchange refresh token didn't return exchanged refresh token (response.access_token) [trace=\(traceId)]") + } + self.logger?.debug("OAuth2", msg: "Did use refresh token for exchanging refresh token [trace=\(traceId)]") + self.logger?.trace("OAuth2", msg: "Exchanged refresh token in [trace=\(traceId)] is [\(exchangedRefreshToken)]") + if self.useKeychain { + self.storeTokensToKeychain() + } + self.isExchangingRefreshToken = false + return exchangedRefreshToken + } catch { self.logger?.debug("OAuth2", msg: "Error exchanging refresh in [trace=\(traceId)] token: \(error)") self.isExchangingRefreshToken = false - callback(nil, error.asOAuth2Error) + throw error.asOAuth2Error } } @@ -529,34 +470,27 @@ open class OAuth2: OAuth2Base { - parameter resourcePath: The path of the resource requesting for its own access token - parameter params: Optional key/value pairs to pass during token exchange - - parameter callback: The callback to call after the exchange of resource access token has finished + - returns: Exchanged access token */ - open func doExchangeAccessTokenForResource(resourcePath: String, params: OAuth2StringDict? = nil, callback: @escaping ((String?, OAuth2Error?) -> Void)) { + open func doExchangeAccessTokenForResource(resourcePath: String, params: OAuth2StringDict? = nil) async throws -> String { do { let post = try tokenRequestForExchangeAccessTokenForResource(resourcePath: resourcePath, params: params).asURLRequest(for: self) logger?.debug("OAuth2", msg: "Exchanging access token for resource \(resourcePath) from \(post.url?.description ?? "nil")") - perform(request: post) { response in - do { - let data = try response.responseData() - let json = try self.parseAccessTokenResponse(data: data) - if response.response.statusCode >= 400 { - self.clientConfig.accessToken = nil - throw OAuth2Error.generic("Failed with status \(response.response.statusCode)") - } - guard let exchangedAccessToken = json["access_token"] as? String else { - throw OAuth2Error.generic("Exchange access token for resource didn't return exchanged access token (response.access_token)") - } - callback(exchangedAccessToken, nil) - } catch let error { - self.logger?.warn("OAuth2", msg: "Error exchanging access token for resource: \(error)") - - callback(nil, error.asOAuth2Error) - } + let response = await perform(request: post) + let data = try response.responseData() + let json = try self.parseAccessTokenResponse(data: data) + if response.response.statusCode >= 400 { + self.clientConfig.accessToken = nil + throw OAuth2Error.generic("Failed with status \(response.response.statusCode)") } + guard let exchangedAccessToken = json["access_token"] as? String else { + throw OAuth2Error.generic("Exchange access token for resource didn't return exchanged access token (response.access_token)") + } + return exchangedAccessToken } catch let error { self.logger?.debug("OAuth2", msg: "Error exchanging access token for resource \(resourcePath): \(error)") - callback(nil, error.asOAuth2Error) + throw error.asOAuth2Error } } @@ -569,27 +503,19 @@ open class OAuth2: OAuth2Base { calls `onBeforeDynamicClientRegistration()` -- if it is non-nil -- and uses the returned `OAuth2DynReg` instance -- if it is non-nil. If both are nil, instantiates a blank `OAuth2DynReg` instead, then attempts client registration. - - parameter callback: The callback to call on the main thread; if both json and error is nil no registration was attempted; error is nil - on success + - returns: JSON dictionary or nil if no registration was attempted; */ - public func registerClientIfNeeded(callback: @escaping ((OAuth2JSON?, OAuth2Error?) -> Void)) { + @MainActor + public func registerClientIfNeeded() async throws -> OAuth2JSON? { if nil != clientId || !type(of: self).clientIdMandatory { - callOnMainThread() { - callback(nil, nil) - } + return nil } else if let url = clientConfig.registrationURL { let dynreg = onBeforeDynamicClientRegistration?(url as URL) ?? OAuth2DynReg() - dynreg.register(client: self) { json, error in - callOnMainThread() { - callback(json, error?.asOAuth2Error) - } - } + return try await dynreg.register(client: self) } else { - callOnMainThread() { - callback(nil, OAuth2Error.noRegistrationURL) - } + throw OAuth2Error.noRegistrationURL } } } diff --git a/Sources/Flows/OAuth2ClientCredentials.swift b/Sources/Flows/OAuth2ClientCredentials.swift index eb5e8c3..1196f53 100644 --- a/Sources/Flows/OAuth2ClientCredentials.swift +++ b/Sources/Flows/OAuth2ClientCredentials.swift @@ -34,14 +34,14 @@ open class OAuth2ClientCredentials: OAuth2 { return OAuth2GrantTypes.clientCredentials } - override open func doAuthorize(params inParams: OAuth2StringDict? = nil) { - self.obtainAccessToken(params: inParams) { params, error in - if let error = error { + override open func doAuthorize(params inParams: OAuth2StringDict? = nil) async { + Task { + do { + let result = try await self.obtainAccessToken() + self.didAuthorize(withParameters: result) + } catch { self.didFail(with: error.asOAuth2Error) } - else { - self.didAuthorize(withParameters: params ?? OAuth2JSON()) - } } } @@ -71,27 +71,21 @@ open class OAuth2ClientCredentials: OAuth2 { Uses `accessTokenRequest(params:)` to create the request, which you can subclass to change implementation specifics. - - parameter callback: The callback to call after the process has finished + - returns: OAuth2 JSON dictionary */ - public func obtainAccessToken(params: OAuth2StringDict? = nil, callback: @escaping ((_ params: OAuth2JSON?, _ error: OAuth2Error?) -> Void)) { + public func obtainAccessToken(params: OAuth2StringDict? = nil) async throws -> OAuth2JSON { do { let post = try accessTokenRequest(params: params).asURLRequest(for: self) logger?.debug("OAuth2", msg: "Requesting new access token from \(post.url?.description ?? "nil")") - perform(request: post) { response in - do { - let data = try response.responseData() - let params = try self.parseAccessTokenResponse(data: data) - self.logger?.debug("OAuth2", msg: "Did get access token [\(nil != self.clientConfig.accessToken)]") - callback(params, nil) - } - catch let error { - callback(nil, error.asOAuth2Error) - } - } + let response = await perform(request: post) + let data = try response.responseData() + let params = try self.parseAccessTokenResponse(data: data) + self.logger?.debug("OAuth2", msg: "Did get access token [\(nil != self.clientConfig.accessToken)]") + return params } - catch let error { - callback(nil, error.asOAuth2Error) + catch { + throw error.asOAuth2Error } } } diff --git a/Sources/Flows/OAuth2CodeGrant.swift b/Sources/Flows/OAuth2CodeGrant.swift index 8d01288..a5a8b5d 100644 --- a/Sources/Flows/OAuth2CodeGrant.swift +++ b/Sources/Flows/OAuth2CodeGrant.swift @@ -81,7 +81,9 @@ open class OAuth2CodeGrant: OAuth2 { logger?.debug("OAuth2", msg: "Handling redirect URL \(redirect.description)") do { let code = try validateRedirectURL(redirect) - exchangeCodeForToken(code) + Task { + await exchangeCodeForToken(code) + } } catch let error { didFail(with: error.asOAuth2Error) @@ -93,7 +95,7 @@ open class OAuth2CodeGrant: OAuth2 { Uses `accessTokenRequest(params:)` to create the request, which you can subclass to change implementation specifics. */ - public func exchangeCodeForToken(_ code: String) { + public func exchangeCodeForToken(_ code: String) async { do { guard !code.isEmpty else { throw OAuth2Error.prerequisiteFailed("I don't have a code to exchange, let the user authorize first") @@ -102,22 +104,15 @@ open class OAuth2CodeGrant: OAuth2 { let post = try accessTokenRequest(with: code).asURLRequest(for: self) logger?.debug("OAuth2", msg: "Exchanging code \(code) for access token at \(post.url!)") - perform(request: post) { response in - do { - let data = try response.responseData() - let params = try self.parseAccessTokenResponse(data: data) - if response.response.statusCode >= 400 { - throw OAuth2Error.generic("Failed with status \(response.response.statusCode)") - } - self.logger?.debug("OAuth2", msg: "Did exchange code for access [\(nil != self.clientConfig.accessToken)] and refresh [\(nil != self.clientConfig.refreshToken)] tokens") - self.didAuthorize(withParameters: params) - } - catch let error { - self.didFail(with: error.asOAuth2Error) - } + let response = await perform(request: post) + let data = try response.responseData() + let params = try self.parseAccessTokenResponse(data: data) + if response.response.statusCode >= 400 { + throw OAuth2Error.generic("Failed with status \(response.response.statusCode)") } - } - catch let error { + self.logger?.debug("OAuth2", msg: "Did exchange code for access [\(nil != self.clientConfig.accessToken)] and refresh [\(nil != self.clientConfig.refreshToken)] tokens") + self.didAuthorize(withParameters: params) + } catch { didFail(with: error.asOAuth2Error) } } diff --git a/Sources/Flows/OAuth2DeviceGrant.swift b/Sources/Flows/OAuth2DeviceGrant.swift index 5d48ba0..edc90e2 100644 --- a/Sources/Flows/OAuth2DeviceGrant.swift +++ b/Sources/Flows/OAuth2DeviceGrant.swift @@ -79,28 +79,20 @@ open class OAuth2DeviceGrant: OAuth2 { /** Start the device authorization flow. - - parameter params: Optional key/value pairs to pass during authorize device request - - parameter callback: The callback to call after the device authorization response has been received + - parameter params: Optional key/value pairs to pass during authorize device request + - returns: The device authorization response. */ - public func start(useNonTextualTransmission: Bool = false, params: OAuth2StringDict? = nil, completion: @escaping (DeviceAuthorization?, Error?) -> Void) { - authorizeDevice(params: params) { result, error in - guard let result else { - if let error { - self.logger?.warn("OAuth2", msg: "Unable to get device code: \(error)") - } - completion(nil, error) - return - } + public func start(useNonTextualTransmission: Bool = false, params: OAuth2StringDict? = nil) async throws -> DeviceAuthorization { + do { + let result = try await authorizeDevice(params: params) guard let deviceCode = result["device_code"] as? String, let userCode = result["user_code"] as? String, let verificationUri = result["verification_uri"] as? String, let verificationUrl = URL(string: verificationUri), - let expiresIn = result["expires_in"] as? Int else { - let error = OAuth2Error.generic("The response doesn't contain all required fields.") - self.logger?.warn("OAuth2", msg: String(describing: error)) - completion(nil, error) - return + let expiresIn = result["expires_in"] as? Int + else { + throw OAuth2Error.generic("The response doesn't contain all required fields.") } var verificationUrlComplete: URL? @@ -109,86 +101,83 @@ open class OAuth2DeviceGrant: OAuth2 { } if useNonTextualTransmission, let url = verificationUrlComplete { - do { - try self.authorizer.openAuthorizeURLInBrowser(url) - } catch let error { - completion(nil, error) - } + try self.authorizer.openAuthorizeURLInBrowser(url) } let pollingInterval = result["interval"] as? TimeInterval ?? 5 - self.getDeviceAccessToken(deviceCode: deviceCode, interval: pollingInterval) { params, error in - if let params { + + Task { + do { + let params = try await self.getDeviceAccessToken(deviceCode: deviceCode, interval: pollingInterval) self.didAuthorize(withParameters: params) - } - else if let error { + } catch { self.didFail(with: error.asOAuth2Error) } } - let deviceAuthorization = DeviceAuthorization(userCode: userCode, verificationUrl: verificationUrl, verificationUrlComplete: verificationUrlComplete, expiresIn: expiresIn) - completion(deviceAuthorization, nil) + return DeviceAuthorization( + userCode: userCode, + verificationUrl: verificationUrl, + verificationUrlComplete: verificationUrlComplete, + expiresIn: expiresIn + ) + + } catch { + self.logger?.warn("OAuth2", msg: "Unable to get device code: \(error)") // TODO improve message to cover different scenarios + throw error } } - private func authorizeDevice(params: OAuth2StringDict?, completion: @escaping (OAuth2JSON?, Error?) -> Void) { + private func authorizeDevice(params: OAuth2StringDict?) async throws -> OAuth2JSON { do { let post = try deviceAuthorizationRequest(params: params).asURLRequest(for: self) logger?.debug("OAuth2", msg: "Obtaining device code from \(post.url!)") - perform(request: post) { response in - do { - let data = try response.responseData() - let params = try self.parseDeviceAuthorizationResponse(data: data) - completion(params, nil) - } - catch let error { - completion(nil, error.asOAuth2Error) - } - } + let response = await self.perform(request: post) + let data = try response.responseData() + return try self.parseDeviceAuthorizationResponse(data: data) + } catch let error { - completion(nil, error.asOAuth2Error) + throw error.asOAuth2Error } } - private func getDeviceAccessToken(deviceCode: String, interval: TimeInterval, completion: @escaping (OAuth2JSON?, Error?) -> Void) { + private func getDeviceAccessToken(deviceCode: String, interval: TimeInterval) async throws -> OAuth2JSON { do { let post = try deviceAccessTokenRequest(with: deviceCode).asURLRequest(for: self) logger?.debug("OAuth2", msg: "Obtaining access token for device with code \(deviceCode) from \(post.url!)") - perform(request: post) { response in - do { - let data = try response.responseData() - let params = try self.parseAccessTokenResponse(data: data) - completion(params, nil) - } - catch let error { - let oaerror = error.asOAuth2Error - - if oaerror == .authorizationPending(nil) { - self.logger?.debug("OAuth2", msg: "AuthorizationPending, repeating in \(interval) seconds.") - DispatchQueue.main.asyncAfter(deadline: .now() + interval) { - self.getDeviceAccessToken(deviceCode: deviceCode, interval: interval, completion: completion) - } - } else if oaerror == .slowDown(nil) { - let updatedInterval = interval + 5 // The 5 seconds increase is required by the RFC8628 standard (https://www.rfc-editor.org/rfc/rfc8628#section-3.5) - self.logger?.debug("OAuth2", msg: "SlowDown, repeating in \(updatedInterval) seconds.") - DispatchQueue.main.asyncAfter(deadline: .now() + updatedInterval) { - self.getDeviceAccessToken(deviceCode: deviceCode, interval: updatedInterval, completion: completion) - } - } else { - completion(nil, oaerror) - } - } - } + let response = await self.perform(request: post) + let data = try response.responseData() + return try self.parseAccessTokenResponse(data: data) } - catch let error { - completion(nil, error.asOAuth2Error) + catch { + let oaerror = error.asOAuth2Error + + if oaerror == .authorizationPending(nil) { + self.logger?.debug("OAuth2", msg: "AuthorizationPending, repeating in \(interval) seconds.") + try await Task.sleep(seconds: interval) + return try await self.getDeviceAccessToken(deviceCode: deviceCode, interval: interval) + } else if oaerror == .slowDown(nil) { + let updatedInterval = interval + 5 // The 5 seconds increase is required by the RFC8628 standard (https://www.rfc-editor.org/rfc/rfc8628#section-3.5) + self.logger?.debug("OAuth2", msg: "SlowDown, repeating in \(updatedInterval) seconds.") + try await Task.sleep(seconds: updatedInterval) + return try await self.getDeviceAccessToken(deviceCode: deviceCode, interval: updatedInterval) + } + + throw error.asOAuth2Error } } } +fileprivate extension Task where Success == Never, Failure == Never { + static func sleep(seconds: Double) async throws { + let duration = UInt64(seconds * 1_000_000_000) + try await Task.sleep(nanoseconds: duration) + } +} + public struct DeviceAuthorization { public let userCode: String public let verificationUrl: URL diff --git a/Sources/Flows/OAuth2DynReg.swift b/Sources/Flows/OAuth2DynReg.swift index 9a1abf5..c93d428 100644 --- a/Sources/Flows/OAuth2DynReg.swift +++ b/Sources/Flows/OAuth2DynReg.swift @@ -51,32 +51,25 @@ open class OAuth2DynReg { Register the given client. - parameter client: The client to register and update with client credentials, when successful - - parameter callback: The callback to call when done with the registration response (JSON) and/or an error + - returns: JSON response */ - open func register(client: OAuth2, callback: @escaping ((_ json: OAuth2JSON?, _ error: OAuth2Error?) -> Void)) { + open func register(client: OAuth2) async throws -> OAuth2JSON { do { let req = try registrationRequest(for: client) client.logger?.debug("OAuth2", msg: "Registering client at \(req.url!) with scopes “\(client.scope ?? "(none)")”") - client.perform(request: req) { response in - do { - let data = try response.responseData() - let dict = try self.parseRegistrationResponse(data: data, client: client) - try client.assureNoErrorInResponse(dict) - if response.response.statusCode >= 400 { - client.logger?.warn("OAuth2", msg: "Registration failed with \(response.response.statusCode)") - } - else { - self.didRegisterWith(json: dict, client: client) - } - callback(dict, nil) - } - catch let error { - callback(nil, error.asOAuth2Error) - } + + let response = await client.perform(request: req) + let data = try response.responseData() + let dict = try self.parseRegistrationResponse(data: data, client: client) + try client.assureNoErrorInResponse(dict) + if response.response.statusCode >= 400 { + client.logger?.warn("OAuth2", msg: "Registration failed with \(response.response.statusCode)") + } else { + self.didRegisterWith(json: dict, client: client) } - } - catch let error { - callback(nil, error.asOAuth2Error) + return dict + } catch { + throw error.asOAuth2Error } } diff --git a/Sources/Flows/OAuth2PasswordGrant.swift b/Sources/Flows/OAuth2PasswordGrant.swift index 8f9a186..728ce75 100644 --- a/Sources/Flows/OAuth2PasswordGrant.swift +++ b/Sources/Flows/OAuth2PasswordGrant.swift @@ -100,18 +100,16 @@ open class OAuth2PasswordGrant: OAuth2 { - parameter params: Optional key/value pairs to pass during authorization */ - override open func doAuthorize(params: OAuth2StringDict? = nil) throws { + override open func doAuthorize(params: OAuth2StringDict? = nil) async throws { if username?.isEmpty ?? true || password?.isEmpty ?? true { try askForCredentials() } else { - obtainAccessToken(params: params) { params, error in - if let error = error { - self.didFail(with: error) - } - else { - self.didAuthorize(withParameters: params ?? OAuth2JSON()) - } + do { + let resultParams = try await obtainAccessToken(params: params) + self.didAuthorize(withParameters: resultParams) + } catch { + self.didFail(with: error.asOAuth2Error) } } } @@ -150,27 +148,20 @@ open class OAuth2PasswordGrant: OAuth2 { - parameter username: The username to try against the server - parameter password: The password to try against the server - - parameter completionHandler: The closure to call once the server responded. The response's JSON is send if the server accepted the - given credentials. If the JSON is empty, see the error field for more information about the failure. + - returns: The response JSON */ - public func tryCredentials(username: String, password: String, errorHandler: @escaping (OAuth2Error) -> Void) { + @discardableResult public func tryCredentials(username: String, password: String) async throws -> OAuth2JSON { self.username = username self.password = password - // perform the request - obtainAccessToken(params: customAuthParams) { params, error in - - // reset credentials on error - if let error = error { - self.username = nil - self.password = nil - errorHandler(error) - } - - // automatically end the authorization process with a success - else { - self.didAuthorize(withParameters: params ?? OAuth2JSON()) - } + do { + let params = try await self.obtainAccessToken(params: customAuthParams) + self.didAuthorize(withParameters: params) + return params + } catch { + self.username = nil + self.password = nil + throw error } } @@ -224,37 +215,31 @@ open class OAuth2PasswordGrant: OAuth2 { Uses `accessTokenRequest(params:)` to create the request, which you can subclass to change implementation specifics. - parameter params: Optional key/value pairs to pass during authorization - - parameter callback: The callback to call after the request has returned + - returns:: OAuth2 JSON dictionary */ - public func obtainAccessToken(params: OAuth2StringDict? = nil, callback: @escaping ((_ params: OAuth2JSON?, _ error: OAuth2Error?) -> Void)) { + public func obtainAccessToken(params: OAuth2StringDict? = nil) async throws -> OAuth2JSON { do { let post = try accessTokenRequest(params: params).asURLRequest(for: self) logger?.debug("OAuth2", msg: "Requesting new access token from \(post.url?.description ?? "nil")") - perform(request: post) { response in - do { - let data = try response.responseData() - let dict = try self.parseAccessTokenResponse(data: data) - if response.response.statusCode >= 400 { - throw OAuth2Error.generic("Failed with status \(response.response.statusCode)") - } - self.logger?.debug("OAuth2", msg: "Did get access token [\(nil != self.clientConfig.accessToken)]") - callback(dict, nil) - } - catch OAuth2Error.unauthorizedClient { // TODO: which one is it? - callback(nil, OAuth2Error.wrongUsernamePassword) - } - catch OAuth2Error.forbidden { // TODO: which one is it? - callback(nil, OAuth2Error.wrongUsernamePassword) - } - catch let error { - self.logger?.debug("OAuth2", msg: "Error obtaining access token: \(error)") - callback(nil, error.asOAuth2Error) - } + let response = await self.perform(request: post) + let data = try response.responseData() + let dict = try self.parseAccessTokenResponse(data: data) + if response.response.statusCode >= 400 { + throw OAuth2Error.generic("Failed with status \(response.response.statusCode)") } + self.logger?.debug("OAuth2", msg: "Did get access token [\(nil != self.clientConfig.accessToken)]") + return dict + } + catch OAuth2Error.unauthorizedClient { // TODO: which one is it? + throw OAuth2Error.wrongUsernamePassword + } + catch OAuth2Error.forbidden { // TODO: which one is it? + throw OAuth2Error.wrongUsernamePassword } catch { - callback(nil, error.asOAuth2Error) + self.logger?.debug("OAuth2", msg: "Error obtaining access token: \(error)") + throw error } } }