diff --git a/Package.resolved b/Package.resolved index 8690cab..d80cc5a 100644 --- a/Package.resolved +++ b/Package.resolved @@ -5,8 +5,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/swift-server/async-http-client.git", "state" : { - "revision" : "fc510a39cff61b849bf5cdff17eb2bd6d0777b49", - "version" : "1.11.5" + "revision" : "16f7e62c08c6969899ce6cc277041e868364e5cf", + "version" : "1.19.0" } }, { @@ -14,8 +14,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/vapor/async-kit.git", "state" : { - "revision" : "c3329e444bafbb12d1d312af9191be95348a8175", - "version" : "1.13.0" + "revision" : "eab9edff78e8ace20bd7cb6e792ab46d54f59ab9", + "version" : "1.18.0" } }, { @@ -23,8 +23,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/vapor/console-kit.git", "state" : { - "revision" : "a7e67a1719933318b5ab7eaaed355cde020465b1", - "version" : "4.5.0" + "revision" : "9a12000f4064a2bdc49068d7258292ec1bdc88fc", + "version" : "4.7.0" } }, { @@ -32,8 +32,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/vapor/multipart-kit.git", "state" : { - "revision" : "0d55c35e788451ee27222783c7d363cb88092fab", - "version" : "4.5.2" + "revision" : "1adfd69df2da08f7931d4281b257475e32c96734", + "version" : "4.5.4" } }, { @@ -41,8 +41,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/vapor/routing-kit.git", "state" : { - "revision" : "ffac7b3a127ce1e85fb232f1a6271164628809ad", - "version" : "4.6.0" + "revision" : "e0539da5b60a60d7381f44cdcf04036f456cee2f", + "version" : "4.8.0" } }, { @@ -59,8 +59,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-atomics.git", "state" : { - "revision" : "919eb1d83e02121cdb434c7bfc1f0c66ef17febe", - "version" : "1.0.2" + "revision" : "6c89474e62719ddcc1e9614989fff2f68208fe10", + "version" : "1.1.0" } }, { @@ -77,8 +77,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-collections.git", "state" : { - "revision" : "f504716c27d2e5d4144fa4794b12129301d17729", - "version" : "1.0.3" + "revision" : "937e904258d22af6e447a0b72c0bc67583ef64a2", + "version" : "1.0.4" } }, { @@ -86,8 +86,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-crypto.git", "state" : { - "revision" : "d9825fa541df64b1a7b182178d61b9a82730d01f", - "version" : "2.1.0" + "revision" : "60f13f60c4d093691934dc6cfdf5f508ada1f894", + "version" : "2.6.0" } }, { @@ -95,8 +95,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-log.git", "state" : { - "revision" : "6fe203dc33195667ce1759bf0182975e4653ba1c", - "version" : "1.4.4" + "revision" : "532d8b529501fb73a2455b179e0bbb6d49b652ed", + "version" : "1.5.3" } }, { @@ -104,8 +104,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-metrics.git", "state" : { - "revision" : "53be78637ecd165d1ddedc4e20de69b8f43ec3b7", - "version" : "2.3.2" + "revision" : "971ba26378ab69c43737ee7ba967a896cb74c0d1", + "version" : "2.4.1" } }, { @@ -113,8 +113,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-nio.git", "state" : { - "revision" : "b4e0a274f7f34210e97e2f2c50ab02a10b549250", - "version" : "2.41.1" + "revision" : "cf281631ff10ec6111f2761052aa81896a83a007", + "version" : "2.58.0" } }, { @@ -122,8 +122,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-nio-extras.git", "state" : { - "revision" : "6c84d247754ad77487a6f0694273b89b83efd056", - "version" : "1.14.0" + "revision" : "0e0d0aab665ff1a0659ce75ac003081f2b1c8997", + "version" : "1.19.0" } }, { @@ -131,8 +131,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-nio-http2.git", "state" : { - "revision" : "f9ab1c94c80d568efd762d2a638f25162691d766", - "version" : "1.22.1" + "revision" : "a8ccf13fa62775277a5d56844878c828bbb3be1a", + "version" : "1.27.0" } }, { @@ -140,8 +140,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-nio-ssl.git", "state" : { - "revision" : "ba7c0d7f82affc518147ea61d240330bf7f7ea9b", - "version" : "2.22.1" + "revision" : "320bd978cceb8e88c125dcbb774943a92f6286e9", + "version" : "2.25.0" } }, { @@ -149,8 +149,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-nio-transport-services.git", "state" : { - "revision" : "4e02d9cf35cabfb538c96613272fb027dd0c8692", - "version" : "1.13.1" + "revision" : "e7403c35ca6bb539a7ca353b91cc2d8ec0362d58", + "version" : "1.19.0" } }, { @@ -167,8 +167,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/vapor/vapor.git", "state" : { - "revision" : "dda0de537e7906414dccd551e77095be1e34e3da", - "version" : "4.65.2" + "revision" : "1bb4a2ed94bec7a92f92e82896408c785d068f5c", + "version" : "4.79.0" } }, { @@ -176,8 +176,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/vapor/websocket-kit.git", "state" : { - "revision" : "2d9d2188a08eef4a869d368daab21b3c08510991", - "version" : "2.6.1" + "revision" : "53fe0639a98903858d0196b699720decb42aee7b", + "version" : "2.14.0" } } ], diff --git a/Sources/VaporOAuth/DefaultImplementations/EmptyCodeManager.swift b/Sources/VaporOAuth/DefaultImplementations/EmptyCodeManager.swift index b3d676d..8aa71a9 100644 --- a/Sources/VaporOAuth/DefaultImplementations/EmptyCodeManager.swift +++ b/Sources/VaporOAuth/DefaultImplementations/EmptyCodeManager.swift @@ -15,4 +15,14 @@ public struct EmptyCodeManager: CodeManager { } public func codeUsed(_ code: OAuthCode) {} + + public func getDeviceCode(_ deviceCode: String) -> OAuthDeviceCode? { + return nil + } + + public func generateDeviceCode(userID: String, clientID: String, scopes: [String]?) throws -> String { + return "" + } + + public func deviceCodeUsed(_ deviceCode: OAuthDeviceCode) {} } diff --git a/Sources/VaporOAuth/Models/OAuthDeviceCode.swift b/Sources/VaporOAuth/Models/OAuthDeviceCode.swift new file mode 100644 index 0000000..32ea682 --- /dev/null +++ b/Sources/VaporOAuth/Models/OAuthDeviceCode.swift @@ -0,0 +1,32 @@ +import Foundation + +public final class OAuthDeviceCode { + public let deviceCodeID: String + public let userCode: String + public let clientID: String + public let userID: String? + public let expiryDate: Date + public let scopes: [String]? + + public var extend: [String: Any] = [:] + + public init( + deviceCodeID: String, + userCode: String, + clientID: String, + userID: String?, + expiryDate: Date, + scopes: [String]? + ) { + self.deviceCodeID = deviceCodeID + self.userCode = userCode + self.clientID = clientID + self.userID = userID + self.expiryDate = expiryDate + self.scopes = scopes + } + + public var isExpired: Bool { + return Date() > expiryDate + } +} diff --git a/Sources/VaporOAuth/Protocols/CodeManager.swift b/Sources/VaporOAuth/Protocols/CodeManager.swift index f77695e..8a1a868 100644 --- a/Sources/VaporOAuth/Protocols/CodeManager.swift +++ b/Sources/VaporOAuth/Protocols/CodeManager.swift @@ -1,9 +1,44 @@ /// Responsible for generating and managing OAuth Codes public protocol CodeManager { + + /// Generates an OAuth code for the specified user, client, redirect URI, and scopes. + /// - Parameters: + /// - userID: The ID of the user. + /// - clientID: The ID of the client. + /// - redirectURI: The redirect URI for the client. + /// - scopes: Optional array of scopes. + /// - Returns: The generated OAuth code. + /// - Throws: An error if the code generation fails. func generateCode(userID: String, clientID: String, redirectURI: String, scopes: [String]?) async throws -> String + + /// Retrieves the OAuth code associated with the specified code. + /// - Parameter code: The OAuth code. + /// - Returns: The associated OAuth code, or `nil` if not found. + /// - Throws: An error if the retrieval fails. func getCode(_ code: String) async throws -> OAuthCode? - - // This is explicit to ensure that the code is marked as used or deleted (it could be implied that this is done when you call - // `getCode` but it is called explicitly to remind developers to ensure that codes can't be reused) + + /// Marks the specified OAuth code as used or deleted. + /// - Parameter code: The OAuth code to mark as used or deleted. + /// - Throws: An error if the operation fails. func codeUsed(_ code: OAuthCode) async throws + + /// Generates a device code for the specified user, client, and scopes. + /// - Parameters: + /// - userID: The ID of the user. + /// - clientID: The ID of the client. + /// - scopes: Optional array of scopes. + /// - Returns: The generated device code. + /// - Throws: An error if the code generation fails. + func generateDeviceCode(userID: String, clientID: String, scopes: [String]?) async throws -> String + + /// Retrieves the device code associated with the specified code. + /// - Parameter deviceCode: The device code. + /// - Returns: The associated device code, or `nil` if not found. + /// - Throws: An error if the retrieval fails. + func getDeviceCode(_ deviceCode: String) async throws -> OAuthDeviceCode? + + /// Marks the specified device code as used or deleted. + /// - Parameter deviceCode: The device code to mark as used or deleted. + /// - Throws: An error if the operation fails. + func deviceCodeUsed(_ deviceCode: OAuthDeviceCode) async throws } diff --git a/Sources/VaporOAuth/RouteHandlers/TokenHandler.swift b/Sources/VaporOAuth/RouteHandlers/TokenHandler.swift index 04ad049..91d2117 100644 --- a/Sources/VaporOAuth/RouteHandlers/TokenHandler.swift +++ b/Sources/VaporOAuth/RouteHandlers/TokenHandler.swift @@ -8,6 +8,7 @@ struct TokenHandler { let tokenResponseGenerator: TokenResponseGenerator let authCodeTokenHandler: AuthCodeTokenHandler let passwordTokenHandler: PasswordTokenHandler + let deviceCodeTokenHandler: DeviceCodeTokenHandler init(clientValidator: ClientValidator, tokenManager: TokenManager, scopeValidator: ScopeValidator, codeManager: CodeManager, userManager: UserManager, logger: Logger) { @@ -25,6 +26,9 @@ struct TokenHandler { passwordTokenHandler = PasswordTokenHandler(clientValidator: clientValidator, scopeValidator: scopeValidator, userManager: userManager, logger: logger, tokenManager: tokenManager, tokenResponseGenerator: tokenResponseGenerator) + deviceCodeTokenHandler = DeviceCodeTokenHandler(clientValidator: clientValidator, scopeValidator: scopeValidator, codeManager: codeManager, + tokenManager: tokenManager, + tokenResponseGenerator: tokenResponseGenerator) } func handleRequest(request: Request) async throws -> Response { @@ -42,6 +46,8 @@ struct TokenHandler { return try await clientCredentialsTokenHandler.handleClientCredentialsTokenRequest(request) case OAuthFlowType.refresh.rawValue: return try await refreshTokenHandler.handleRefreshTokenRequest(request) + case OAuthFlowType.deviceCode.rawValue: + return try await deviceCodeTokenHandler.handleDeviceCodeTokenRequest(request) default: return try tokenResponseGenerator.createResponse(error: OAuthResponseParameters.ErrorType.unsupportedGrant, description: "This server does not support the '\(grantType)' grant type") diff --git a/Sources/VaporOAuth/RouteHandlers/TokenHandlers/DeviceCodeTokenHandler.swift b/Sources/VaporOAuth/RouteHandlers/TokenHandlers/DeviceCodeTokenHandler.swift new file mode 100644 index 0000000..dc22d8a --- /dev/null +++ b/Sources/VaporOAuth/RouteHandlers/TokenHandlers/DeviceCodeTokenHandler.swift @@ -0,0 +1,64 @@ +import Vapor + +struct DeviceCodeTokenHandler { + + let clientValidator: ClientValidator + let scopeValidator: ScopeValidator + let codeManager: CodeManager + let tokenManager: TokenManager + let tokenResponseGenerator: TokenResponseGenerator + + func handleDeviceCodeTokenRequest(_ request: Request) async throws -> Response { + guard let deviceCodeString: String = request.content[OAuthRequestParameters.deviceCode] else { + return try tokenResponseGenerator.createResponse(error: OAuthResponseParameters.ErrorType.invalidRequest, + description: "Request was missing the 'device_code' parameter") + } + + guard let clientID: String = request.content[OAuthRequestParameters.clientID] else { + return try tokenResponseGenerator.createResponse(error: OAuthResponseParameters.ErrorType.invalidRequest, + description: "Request was missing the 'client_id' parameter") + } + + do { + try await clientValidator.authenticateClient(clientID: clientID, clientSecret: nil, + grantType: .deviceCode) + } catch { + return try tokenResponseGenerator.createResponse(error: OAuthResponseParameters.ErrorType.invalidClient, + description: "Request had invalid client credentials", status: .unauthorized) + } + + guard let deviceCode = try await codeManager.getDeviceCode(deviceCodeString) else { + let errorDescription = "The device code provided was invalid or expired" + return try tokenResponseGenerator.createResponse(error: OAuthResponseParameters.ErrorType.invalidGrant, + description: errorDescription) + } + + if deviceCode.expiryDate < Date() { + let errorDescription = "The device code provided was invalid or expired" + return try tokenResponseGenerator.createResponse(error: "expired_token", + description: errorDescription) + } + + if let scopes = deviceCode.scopes { + do { + try await scopeValidator.validateScope(clientID: clientID, scopes: scopes) + } catch ScopeError.invalid, ScopeError.unknown { + return try tokenResponseGenerator.createResponse(error: OAuthResponseParameters.ErrorType.invalidScope, + description: "Request contained an invalid or unknown scope") + } + } + + try await codeManager.deviceCodeUsed(deviceCode) + + let expiryTime = 3600 + + let (access, refresh) = try await tokenManager.generateAccessRefreshTokens( + clientID: clientID, userID: deviceCode.userID, + scopes: deviceCode.scopes, + accessTokenExpiryTime: expiryTime + ) + + return try tokenResponseGenerator.createResponse(accessToken: access, refreshToken: refresh, expires: Int(expiryTime), + scope: deviceCode.scopes?.joined(separator: " ")) + } +} diff --git a/Sources/VaporOAuth/Utilities/OAuthFlowType.swift b/Sources/VaporOAuth/Utilities/OAuthFlowType.swift index 6155f84..4f7c45a 100644 --- a/Sources/VaporOAuth/Utilities/OAuthFlowType.swift +++ b/Sources/VaporOAuth/Utilities/OAuthFlowType.swift @@ -5,4 +5,5 @@ public enum OAuthFlowType: String { case clientCredentials = "client_credentials" case refresh = "refresh_token" case tokenIntrospection = "token_introspection" + case deviceCode = "urn:ietf:params:oauth:grant-type:device_code" } diff --git a/Sources/VaporOAuth/Utilities/StringDefines.swift b/Sources/VaporOAuth/Utilities/StringDefines.swift index d331c66..c62b71b 100644 --- a/Sources/VaporOAuth/Utilities/StringDefines.swift +++ b/Sources/VaporOAuth/Utilities/StringDefines.swift @@ -13,6 +13,7 @@ struct OAuthRequestParameters { static let usernname = "username" static let csrfToken = "csrfToken" static let token = "token" + static let deviceCode = "device_code" } struct OAuthResponseParameters { @@ -39,6 +40,7 @@ struct OAuthResponseParameters { static let unsupportedGrant = "unsupported_grant_type" static let invalidGrant = "invalid_grant" static let missingToken = "missing_token" + static let expiredToken = "expired_token" } } diff --git a/Sources/VaporOAuth/Validators/ClientValidator.swift b/Sources/VaporOAuth/Validators/ClientValidator.swift index 0b66db4..e2b1af9 100644 --- a/Sources/VaporOAuth/Validators/ClientValidator.swift +++ b/Sources/VaporOAuth/Validators/ClientValidator.swift @@ -62,6 +62,10 @@ struct ClientValidator { throw ClientError.notFirstParty } } + + if grantType == .deviceCode { + + } } if checkConfidentialClient { diff --git a/Tests/VaporOAuthTests/Fakes/FakeCodeManager.swift b/Tests/VaporOAuthTests/Fakes/FakeCodeManager.swift index 2538d10..0883973 100644 --- a/Tests/VaporOAuthTests/Fakes/FakeCodeManager.swift +++ b/Tests/VaporOAuthTests/Fakes/FakeCodeManager.swift @@ -5,20 +5,36 @@ class FakeCodeManager: CodeManager { private(set) var usedCodes: [String] = [] var codes: [String: OAuthCode] = [:] + var deviceCodes: [String: OAuthDeviceCode] = [:] var generatedCode = UUID().uuidString func getCode(_ code: String) -> OAuthCode? { return codes[code] } + + func getDeviceCode(_ deviceCode: String) -> OAuthDeviceCode? { + return deviceCodes[deviceCode] + } func generateCode(userID: String, clientID: String, redirectURI: String, scopes: [String]?) throws -> String { let code = OAuthCode(codeID: generatedCode, clientID: clientID, redirectURI: redirectURI, userID: userID, expiryDate: Date().addingTimeInterval(60), scopes: scopes) codes[generatedCode] = code return generatedCode } + + func generateDeviceCode(userID: String, clientID: String, scopes: [String]?) throws -> String { // Added to generate a device code + let deviceCode = OAuthDeviceCode(deviceCodeID: generatedCode, userCode: "USER_CODE", clientID: clientID, userID: userID, expiryDate: Date().addingTimeInterval(60), scopes: scopes) + deviceCodes[generatedCode] = deviceCode + return generatedCode + } func codeUsed(_ code: OAuthCode) { usedCodes.append(code.codeID) codes.removeValue(forKey: code.codeID) } + + func deviceCodeUsed(_ deviceCode: OAuthDeviceCode) { + usedCodes.append(deviceCode.deviceCodeID) + deviceCodes.removeValue(forKey: deviceCode.deviceCodeID) + } } diff --git a/Tests/VaporOAuthTests/Fakes/FakeTokenManager.swift b/Tests/VaporOAuthTests/Fakes/FakeTokenManager.swift index 03b8766..1a7138e 100644 --- a/Tests/VaporOAuthTests/Fakes/FakeTokenManager.swift +++ b/Tests/VaporOAuthTests/Fakes/FakeTokenManager.swift @@ -2,39 +2,41 @@ import VaporOAuth import Foundation class FakeTokenManager: TokenManager { - + var accessTokenToReturn = "ACCESS-TOKEN-STRING" var refreshTokenToReturn = "REFRESH-TOKEN-STRING" var refreshTokens: [String: RefreshToken] = [:] var accessTokens: [String: AccessToken] = [:] + var deviceCodes: [String: OAuthDeviceCode] = [:] var currentTime = Date() - + func getRefreshToken(_ refreshToken: String) -> RefreshToken? { return refreshTokens[refreshToken] } - + func getAccessToken(_ accessToken: String) -> AccessToken? { return accessTokens[accessToken] } - + func generateAccessRefreshTokens(clientID: String, userID: String?, scopes: [String]?, accessTokenExpiryTime: Int) throws -> (AccessToken, RefreshToken) { let accessToken = FakeAccessToken(tokenString: accessTokenToReturn, clientID: clientID, userID: userID, scopes: scopes, expiryTime: currentTime.addingTimeInterval(TimeInterval(accessTokenExpiryTime))) let refreshToken = FakeRefreshToken(tokenString: refreshTokenToReturn, clientID: clientID, userID: userID, scopes: scopes) - + accessTokens[accessTokenToReturn] = accessToken refreshTokens[refreshTokenToReturn] = refreshToken return (accessToken, refreshToken) } - + func generateAccessToken(clientID: String, userID: String?, scopes: [String]?, expiryTime: Int) throws -> AccessToken { let accessToken = FakeAccessToken(tokenString: accessTokenToReturn, clientID: clientID, userID: userID, scopes: scopes, expiryTime: currentTime.addingTimeInterval(TimeInterval(expiryTime))) accessTokens[accessTokenToReturn] = accessToken return accessToken } - + func updateRefreshToken(_ refreshToken: RefreshToken, scopes: [String]) { var tempRefreshToken = refreshToken tempRefreshToken.scopes = scopes refreshTokens[refreshToken.tokenString] = tempRefreshToken } + } diff --git a/Tests/VaporOAuthTests/Fakes/StubCodeManager.swift b/Tests/VaporOAuthTests/Fakes/StubCodeManager.swift index 41c2bcf..db4d892 100644 --- a/Tests/VaporOAuthTests/Fakes/StubCodeManager.swift +++ b/Tests/VaporOAuthTests/Fakes/StubCodeManager.swift @@ -15,4 +15,20 @@ class StubCodeManager: CodeManager { func codeUsed(_ code: OAuthCode) { } + + func getDeviceCode(_ deviceCode: String) -> OAuthDeviceCode? { + + return nil + } + + func generateDeviceCode(userID: String, clientID: String, scopes: [String]?) throws -> String { + + return "DEVICE_CODE" + } + + func deviceCodeUsed(_ deviceCode: OAuthDeviceCode) { + + } + + } diff --git a/Tests/VaporOAuthTests/Fakes/StubTokenManager.swift b/Tests/VaporOAuthTests/Fakes/StubTokenManager.swift index ec5c71e..f1ac264 100644 --- a/Tests/VaporOAuthTests/Fakes/StubTokenManager.swift +++ b/Tests/VaporOAuthTests/Fakes/StubTokenManager.swift @@ -2,10 +2,11 @@ import VaporOAuth import Foundation class StubTokenManager: TokenManager { - + var accessToken = "ABCDEF" var refreshToken = "GHIJKL" - + var deviceCodes: [String: OAuthDeviceCode] = [:] + func generateAccessRefreshTokens(clientID: String, userID: String?, scopes: [String]?, accessTokenExpiryTime: Int) throws -> (AccessToken, RefreshToken) { let access = FakeAccessToken(tokenString: accessToken, clientID: clientID, userID: userID, scopes: scopes, expiryTime: Date()) let refresh = FakeRefreshToken(tokenString: refreshToken, clientID: clientID, userID: nil, scopes: scopes) diff --git a/Tests/VaporOAuthTests/GrantTests/DeviceCodeGrantTests.swift b/Tests/VaporOAuthTests/GrantTests/DeviceCodeGrantTests.swift new file mode 100644 index 0000000..85a7793 --- /dev/null +++ b/Tests/VaporOAuthTests/GrantTests/DeviceCodeGrantTests.swift @@ -0,0 +1,158 @@ +import XCTVapor +@testable import VaporOAuth + +class DeviceCodeTokenTests: XCTestCase { + struct ErrorResponse: Decodable { + var error: String + var errorDescription: String + + enum CodingKeys: String, CodingKey { + case error + case errorDescription = "error_description" + } + } + + struct SuccessResponse: Decodable { + var tokenType: String? + var expiresIn: Int? + var accessToken: String? + var refreshToken: String? + var scope: String? + + enum CodingKeys: String, CodingKey { + case tokenType = "token_type" + case expiresIn = "expires_in" + case accessToken = "access_token" + case refreshToken = "refresh_token" + case scope + } + } + + // MARK: - Properties + + var app: Application! + var fakeClientGetter: FakeClientGetter! + var fakeDeviceCodeManager: FakeCodeManager! + var fakeTokenManager: FakeTokenManager! + + let testClientID = "1234567890" + let testDeviceCodeID = "DEVICE_CODE_ID" + let userID = "the-user-id" + let scopes = ["email", "create"] + + // MARK: - Overrides + + override func setUp() async throws { + fakeClientGetter = FakeClientGetter() + fakeDeviceCodeManager = FakeCodeManager() + fakeTokenManager = FakeTokenManager() + + let oauthClient = OAuthClient( + clientID: testClientID, + redirectURIs: ["https://api.brokenhands.io/callback"], + clientSecret: nil, + allowedGrantType: .deviceCode + ) + fakeClientGetter.validClients[testClientID] = oauthClient + + let testDeviceCode = OAuthDeviceCode( + deviceCodeID: testDeviceCodeID, + userCode: "USER_CODE", + clientID: testClientID, + userID: userID, + expiryDate: Date().addingTimeInterval(60), + scopes: scopes + ) + + fakeDeviceCodeManager.deviceCodes[testDeviceCodeID] = testDeviceCode + + app = try TestDataBuilder.getOAuth2Application( + codeManager: fakeDeviceCodeManager, + tokenManager: fakeTokenManager, + clientRetriever: fakeClientGetter + ) + } + + override func tearDown() async throws { + app.shutdown() + try await super.tearDown() + } + + // MARK: - Tests + + func testCorrectErrorAndHeadersReceivedWhenNoGrantTypeSent() async throws { + let response = try await getDeviceCodeResponse(grantType: nil) + + XCTAssertEqual(response.status, .badRequest) + let errorResponse = try response.content.decode(ErrorResponse.self) + XCTAssertEqual(errorResponse.error, "invalid_request") + XCTAssertEqual(errorResponse.errorDescription, "Request was missing the 'grant_type' parameter") + } + + func testCorrectErrorAndHeadersReceivedWhenIncorrectGrantTypeSet() async throws { + let grantType = "some_unknown_type" + let response = try await getDeviceCodeResponse(grantType: grantType) + + XCTAssertEqual(response.status, .badRequest) + let errorResponse = try response.content.decode(ErrorResponse.self) + XCTAssertEqual(errorResponse.error, "unsupported_grant_type") + XCTAssertEqual(errorResponse.errorDescription, "This server does not support the 'some_unknown_type' grant type") + } + + func testCorrectErrorAndHeadersReceivedWhenNoDeviceCodeSent() async throws { + let response = try await getDeviceCodeResponse(deviceCode: nil) + + XCTAssertEqual(response.status, .badRequest) + let errorResponse = try response.content.decode(ErrorResponse.self) + XCTAssertEqual(errorResponse.error, "invalid_request") + XCTAssertEqual(errorResponse.errorDescription, "Request was missing the 'device_code' parameter") + } + + func testCorrectErrorCodeWhenDeviceCodeIsExpired() async throws { + let expiredDeviceCodeID = "expiredDeviceCodeID" + let expiredDeviceCode = OAuthDeviceCode( + deviceCodeID: expiredDeviceCodeID, + userCode: "USER_CODE", + clientID: testClientID, + userID: userID, + expiryDate: Date().addingTimeInterval(-60), // Expired 60 seconds ago + scopes: scopes + ) + fakeDeviceCodeManager.deviceCodes[expiredDeviceCodeID] = expiredDeviceCode + + let response = try await getDeviceCodeResponse(deviceCode: expiredDeviceCodeID) + + XCTAssertEqual(response.status, .badRequest) + let errorResponse = try response.content.decode(ErrorResponse.self) + XCTAssertEqual(errorResponse.error, "expired_token") + XCTAssertEqual(errorResponse.errorDescription, "The device code provided was invalid or expired") + } + + func testThatCorrectResponseReceivedWhenCorrectRequestSent() async throws { + let response = try await getDeviceCodeResponse() + + XCTAssertEqual(response.status, .ok) + let successResponse = try response.content.decode(SuccessResponse.self) + XCTAssertEqual(successResponse.tokenType, "bearer") + XCTAssertNotNil(successResponse.accessToken) + XCTAssertNotNil(successResponse.expiresIn) + XCTAssertEqual(successResponse.scope, scopes.joined(separator: " ")) + } + + // MARK: - Private + + private func getDeviceCodeResponse( + grantType: String? = "urn:ietf:params:oauth:grant-type:device_code", + deviceCode: String? = "DEVICE_CODE_ID", + clientID: String? = "1234567890" + ) async throws -> XCTHTTPResponse { + return try await TestDataBuilder.getTokenRequestResponse( + with: app, + grantType: grantType, + clientID: clientID, + clientSecret: nil, + deviceCode: deviceCode + ) + } + +} diff --git a/Tests/VaporOAuthTests/Helpers/TestDataBuilder.swift b/Tests/VaporOAuthTests/Helpers/TestDataBuilder.swift index f3dfb1a..0e2c7bd 100644 --- a/Tests/VaporOAuthTests/Helpers/TestDataBuilder.swift +++ b/Tests/VaporOAuthTests/Helpers/TestDataBuilder.swift @@ -62,7 +62,8 @@ class TestDataBuilder { scope: String? = nil, username: String? = nil, password: String? = nil, - refreshToken: String? = nil + refreshToken: String? = nil, + deviceCode: String? = nil ) async throws -> XCTHTTPResponse { struct RequestData: Content { var grantType: String? @@ -74,7 +75,8 @@ class TestDataBuilder { var username: String? var password: String? var refreshToken: String? - + var deviceCode: String? + enum CodingKeys: String, CodingKey { case username, password, scope, code case grantType = "grant_type" @@ -82,6 +84,7 @@ class TestDataBuilder { case clientSecret = "client_secret" case redirectURI = "redirect_uri" case refreshToken = "refresh_token" + case deviceCode = "device_code" } } @@ -94,7 +97,8 @@ class TestDataBuilder { scope: scope, username: username, password: password, - refreshToken: refreshToken + refreshToken: refreshToken, + deviceCode: deviceCode ) return try await withCheckedThrowingContinuation { continuation in