Skip to content

Commit

Permalink
feat: codegen changes (#609)
Browse files Browse the repository at this point in the history
* Add codegen for service specific auth scheme resolver protocol, service specific default auth scheme resolver struct, and service specific auth scheme resolver parameters struct.

* Make ASR throw if passed in ASR params doesn't have region field for SigV4 auth scheme & fix ktlint issues.

* Clean up middlewares.

* Remove auth scheme and signing middlewares from operation stack of protocol tests.

* Update test cases to include new middlewares.

* Codegen more descriptive comment for empty service specific auth scheme resolver protocol.

* Add codegen test for auth scheme resolver generation.

* Move region in middleware context from sdk side to smithy side.

* Remove AWSClientRuntime dependency - signingProperties will be set in auth scheme customization hooks instead.

* Move auth schemes from service specific config to general AWS config.

---------

Co-authored-by: Sichan Yoo <[email protected]>
  • Loading branch information
sichanyoo and Sichan Yoo authored Nov 9, 2023
1 parent 07bb6b0 commit 1b62ed9
Show file tree
Hide file tree
Showing 20 changed files with 631 additions and 10 deletions.
14 changes: 12 additions & 2 deletions Sources/ClientRuntime/Auth/HTTPAuthAPI/AuthOption.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@

public struct AuthOption {
let schemeID: String
var identityProperties: Attributes
var signingProperties: Attributes
public var identityProperties: Attributes
public var signingProperties: Attributes

public init (
schemeID: String,
identityProperties: Attributes = Attributes(),
signingProperties: Attributes = Attributes()
) {
self.schemeID = schemeID
self.identityProperties = identityProperties
self.signingProperties = signingProperties
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
//

public protocol AuthSchemeResolver {
func resolveAuthScheme(params: AuthSchemeResolverParameters) -> [AuthOption]
func resolveAuthScheme(params: AuthSchemeResolverParameters) throws -> [AuthOption]
func constructParameters(context: HttpContext) throws -> AuthSchemeResolverParameters
}
4 changes: 4 additions & 0 deletions Sources/ClientRuntime/Middleware/Attribute.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,8 @@ public struct Attributes {
public mutating func remove<T>(key: AttributeKey<T>) {
attributes.removeValue(forKey: key.name)
}

public func getSize() -> Int {
return attributes.count
}
}
47 changes: 46 additions & 1 deletion Sources/ClientRuntime/Networking/Http/HttpContext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ public class HttpContext: MiddlewareContext {
return attributes.get(key: AttributeKeys.method)!
}

public func getOperation() -> String? {
return attributes.get(key: AttributeKeys.operation)
}

/// The partition ID to be used for this context.
///
/// Requests made with the same partition ID will be grouped together for retry throttling purposes.
Expand All @@ -79,6 +83,10 @@ public class HttpContext: MiddlewareContext {
return attributes.get(key: AttributeKeys.path)!
}

public func getRegion() -> String? {
return attributes.get(key: AttributeKeys.region)
}

public func getSelectedAuthScheme() -> SelectedAuthScheme? {
return attributes.get(key: AttributeKeys.selectedAuthScheme)
}
Expand All @@ -87,6 +95,14 @@ public class HttpContext: MiddlewareContext {
return attributes.get(key: AttributeKeys.serviceName)!
}

public func getSigningName() -> String? {
return attributes.get(key: AttributeKeys.signingName)
}

public func getSigningRegion() -> String? {
return attributes.get(key: AttributeKeys.signingRegion)
}

public func isBidirectionalStreamingEnabled() -> Bool {
return attributes.get(key: AttributeKeys.bidirectionalStreaming) ?? false
}
Expand Down Expand Up @@ -128,6 +144,14 @@ public class HttpContextBuilder {
return self
}

@discardableResult
public func withAuthSchemes(value: [AuthScheme]) -> HttpContextBuilder {
for scheme in value {
self.withAuthScheme(value: scheme)
}
return self
}

@discardableResult
public func withDecoder(value: ResponseDecoder) -> HttpContextBuilder {
self.attributes.set(key: AttributeKeys.decoder, value: value)
Expand Down Expand Up @@ -202,6 +226,12 @@ public class HttpContextBuilder {
return self
}

@discardableResult
public func withRegion(value: String?) -> HttpContextBuilder {
self.attributes.set(key: AttributeKeys.region, value: value)
return self
}

@discardableResult
public func withResponse(value: HttpResponse) -> HttpContextBuilder {
self.response = value
Expand All @@ -220,6 +250,18 @@ public class HttpContextBuilder {
return self
}

@discardableResult
public func withSigningName(value: String) -> HttpContextBuilder {
self.attributes.set(key: AttributeKeys.signingName, value: value)
return self
}

@discardableResult
public func withSigningRegion(value: String?) -> HttpContextBuilder {
self.attributes.set(key: AttributeKeys.signingRegion, value: value)
return self
}

public func build() -> HttpContext {
return HttpContext(attributes: attributes)
}
Expand All @@ -244,9 +286,12 @@ public enum AttributeKeys {
public static let operation = AttributeKey<String>(name: "Operation")
public static let partitionId = AttributeKey<String>(name: "PartitionID")
public static let path = AttributeKey<String>(name: "Path")
public static let region = AttributeKey<String>(name: "Region")
public static let selectedAuthScheme = AttributeKey<SelectedAuthScheme>(name: "SelectedAuthScheme")
public static let serviceName = AttributeKey<String>(name: "ServiceName")
public static let signingName = AttributeKey<String>(name: "SigningName")
public static let signingRegion = AttributeKey<String>(name: "SigningRegion")

// The attribute key used to store a credentials provider configured on service client config onto middleware context.
public static let awsIdResolver = AttributeKey<any IdentityResolver>(name: "AWSIDResolver")
public static let awsIdResolver = AttributeKey<any IdentityResolver>(name: "\(IdentityKind.aws)")
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ public struct AuthSchemeMiddleware<OperationStackOutput: HttpResponseBinding,
// Construct auth scheme resolver parameters
let resolverParams = try resolver.constructParameters(context: context)
// Retrieve valid auth options for the operation at hand
let validAuthOptions = resolver.resolveAuthScheme(params: resolverParams)
let validAuthOptions = try resolver.resolveAuthScheme(params: resolverParams)

// Create IdentityResolverConfiguration
guard let identityResolvers = context.getIdentityResolvers() else {
let identityResolvers = context.getIdentityResolvers()
guard let identityResolvers, identityResolvers.getSize() > 0 else {
throw ClientError.authError("No identity resolver has been configured on the service.")
}
let identityResolverConfig = DefaultIdentityResolverConfiguration(configuredIdResolvers: identityResolvers)
Expand Down Expand Up @@ -94,7 +95,7 @@ public struct AuthSchemeMiddleware<OperationStackOutput: HttpResponseBinding,
// If no auth scheme could be resolved, throw an error
guard let selectedAuthScheme else {
throw ClientError.authError(
"Could not resolve auth scheme for the operation call.\nLog:\n\(log.joined(separator: "\n"))"
"Could not resolve auth scheme for the operation call. Log: \(log.joined(separator: ","))"
)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
package software.amazon.smithy.swift.codegen

import software.amazon.smithy.aws.traits.ServiceTrait
import software.amazon.smithy.aws.traits.auth.SigV4Trait
import software.amazon.smithy.aws.traits.auth.UnsignedPayloadTrait
import software.amazon.smithy.model.knowledge.ServiceIndex
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.AuthTrait
import software.amazon.smithy.model.traits.OptionalAuthTrait
import software.amazon.smithy.model.traits.Trait
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
import software.amazon.smithy.swift.codegen.integration.ServiceTypes
import software.amazon.smithy.swift.codegen.utils.clientName
import software.amazon.smithy.swift.codegen.utils.toLowerCamelCase

class AuthSchemeResolverGenerator() {
fun render(ctx: ProtocolGenerator.GenerationContext) {
val rootNamespace = ctx.settings.moduleName
val serviceIndex = ServiceIndex(ctx.model)

ctx.delegator.useFileWriter("./$rootNamespace/${ClientRuntimeTypes.Core.AuthSchemeResolver.name}.swift") {
renderResolverParams(serviceIndex, ctx, it)
it.write("")
renderResolverProtocol(ctx, it)
it.write("")
renderDefaultResolver(serviceIndex, ctx, it)
it.write("")
it.addImport(SwiftDependency.CLIENT_RUNTIME.target)
}
}

private fun renderResolverParams(
serviceIndex: ServiceIndex,
ctx: ProtocolGenerator.GenerationContext,
writer: SwiftWriter
) {
writer.apply {
openBlock(
"public struct ${getSdkId(ctx)}${ClientRuntimeTypes.Core.AuthSchemeResolverParameters.name}: \$L {",
"}",
ServiceTypes.AuthSchemeResolverParams
) {
write("public let operation: String")
// If service supports SigV4 auth scheme, it's a special-case
// Region has to be in params in addition to operation string from AuthSchemeResolver protocol
if (serviceIndex.getEffectiveAuthSchemes(ctx.service).contains(SigV4Trait.ID)) {
write("// Region is used for SigV4 auth scheme")
write("public let region: String?")
}
}
}
}

private fun renderResolverProtocol(ctx: ProtocolGenerator.GenerationContext, writer: SwiftWriter) {
writer.apply {
openBlock(
"public protocol ${getSdkId(ctx)}${ClientRuntimeTypes.Core.AuthSchemeResolver.name}: \$L {",
"}",
ServiceTypes.AuthSchemeResolver
) {
// This is just a parent protocol that all auth scheme resolvers of a given service must conform to.
write("// Intentionally empty.")
write("// This is the parent protocol that all auth scheme resolver implementations of")
write("// the service ${getSdkId(ctx)} must conform to.")
}
}
}

private fun renderDefaultResolver(
serviceIndex: ServiceIndex,
ctx: ProtocolGenerator.GenerationContext,
writer: SwiftWriter
) {
val sdkId = getSdkId(ctx)
val defaultResolverName = "Default$sdkId${ClientRuntimeTypes.Core.AuthSchemeResolver.name}"
val serviceProtocolName = sdkId + ClientRuntimeTypes.Core.AuthSchemeResolver.name

writer.apply {
openBlock(
"public struct \$L: \$L {",
"}",
defaultResolverName,
serviceProtocolName
) {
renderResolveAuthSchemeMethod(serviceIndex, ctx, writer)
write("")
renderConstructParametersMethod(
serviceIndex.getEffectiveAuthSchemes(ctx.service).contains(SigV4Trait.ID),
sdkId + ClientRuntimeTypes.Core.AuthSchemeResolverParameters.name,
writer
)
}
}
}

private fun renderResolveAuthSchemeMethod(
serviceIndex: ServiceIndex,
ctx: ProtocolGenerator.GenerationContext,
writer: SwiftWriter
) {
val sdkId = getSdkId(ctx)
val serviceParamsName = sdkId + ClientRuntimeTypes.Core.AuthSchemeResolverParameters.name

writer.apply {
openBlock(
"public func resolveAuthScheme(params: \$L) throws -> [AuthOption] {",
"}",
ServiceTypes.AuthSchemeResolverParams
) {
// Return value of array of auth options
write("var validAuthOptions = Array<AuthOption>()")

// Cast params to service specific params object
openBlock(
"guard let serviceParams = params as? \$L else {",
"}",
serviceParamsName
) {
write("throw ClientError.authError(\"Service specific auth scheme parameters type must be passed to auth scheme resolver.\")")
}

renderSwitchBlock(serviceIndex, ctx, this)
}
}
}

private fun renderSwitchBlock(
serviceIndex: ServiceIndex,
ctx: ProtocolGenerator.GenerationContext,
writer: SwiftWriter
) {
writer.apply {
// Switch block for iterating over operation name cases
openBlock("switch serviceParams.operation {", "}") {
// Handle each operation name case
val operations = ctx.service.operations
operations.filter { op ->
val opShape = ctx.model.getShape(op).get() as OperationShape
opShape.hasTrait(AuthTrait::class.java) ||
opShape.hasTrait(OptionalAuthTrait::class.java) ||
opShape.hasTrait(UnsignedPayloadTrait::class.java)
}.forEach { op ->
val opName = op.name.toLowerCamelCase()
val sdkId = getSdkId(ctx)
val validSchemesForOp = serviceIndex.getEffectiveAuthSchemes(
ctx.service, op, ServiceIndex.AuthSchemeMode.NO_AUTH_AWARE
)
renderOperationSwitchCase(
sdkId,
ctx.model.getShape(op).get() as OperationShape,
opName,
validSchemesForOp,
writer
)
}
// Handle default case, where operations default to auth schemes defined on service shape
val validSchemesForService = serviceIndex.getEffectiveAuthSchemes(ctx.service, ServiceIndex.AuthSchemeMode.NO_AUTH_AWARE)
renderDefaultSwitchCase(getSdkId(ctx), validSchemesForService, writer)
}

// Return result
write("return validAuthOptions")
}
}

private fun renderOperationSwitchCase(sdkId: String, opShape: OperationShape, opName: String, schemes: Map<ShapeId, Trait>, writer: SwiftWriter) {
writer.apply {
write("case \"$opName\":")
indent()
schemes.forEach {
if (it.key == SigV4Trait.ID) {
write("var sigV4Option = AuthOption(schemeID: \"${it.key}\")")
write("sigV4Option.signingProperties.set(key: AttributeKeys.signingName, value: \"${(it.value as SigV4Trait).name}\")")
openBlock("guard let region = serviceParams.region else {", "}") {
val errorMessage = "\"Missing region in auth scheme parameters for SigV4 auth scheme.\""
write("throw ClientError.authError($errorMessage)")
}
write("sigV4Option.signingProperties.set(key: AttributeKeys.signingRegion, value: region)")
write("validAuthOptions.append(sigV4Option)")
} else {
write("validAuthOptions.append(AuthOption(schemeID: \"${it.key}\"))")
}
}
dedent()
}
}

private fun renderDefaultSwitchCase(sdkId: String, schemes: Map<ShapeId, Trait>, writer: SwiftWriter) {
writer.apply {
write("default:")
indent()
schemes.forEach {
if (it.key == SigV4Trait.ID) {
write("var sigV4Option = AuthOption(schemeID: \"${it.key}\")")
write("sigV4Option.signingProperties.set(key: AttributeKeys.signingName, value: \"${(it.value as SigV4Trait).name}\")")
openBlock("guard let region = serviceParams.region else {", "}") {
val errorMessage = "\"Missing region in auth scheme parameters for SigV4 auth scheme.\""
write("throw ClientError.authError($errorMessage)")
}
write("sigV4Option.signingProperties.set(key: AttributeKeys.signingRegion, value: region)")
write("validAuthOptions.append(sigV4Option)")
} else {
write("validAuthOptions.append(AuthOption(schemeID: \"${it.key}\"))")
}
}
dedent()
}
}

private fun renderConstructParametersMethod(
hasSigV4: Boolean,
returnTypeName: String,
writer: SwiftWriter
) {
writer.apply {
openBlock(
"public func constructParameters(context: HttpContext) throws -> \$L {",
"}",
ServiceTypes.AuthSchemeResolverParams
) {
openBlock("guard let opName = context.getOperation() else {", "}") {
write("throw ClientError.dataNotFound(\"Operation name not configured in middleware context for auth scheme resolver params construction.\")")
}
if (hasSigV4) {
write("let opRegion = context.getRegion()")
write("return $returnTypeName(operation: opName, region: opRegion)")
} else {
write("return $returnTypeName(operation: opName)")
}
}
}
}

// Utility function for returning sdkId from generation context
fun getSdkId(ctx: ProtocolGenerator.GenerationContext): String {
return if (ctx.service.hasTrait(ServiceTrait::class.java))
ctx.service.getTrait(ServiceTrait::class.java).get().sdkId.clientName()
else ctx.settings.sdkId.clientName()
}
}
Loading

0 comments on commit 1b62ed9

Please sign in to comment.