Skip to content

Commit

Permalink
feat: add support for requiresLength trait and Transfer-Encoding: Chu…
Browse files Browse the repository at this point in the history
…nked (#604)
  • Loading branch information
dayaffe authored Oct 23, 2023
1 parent 62ad3a2 commit 818a5c7
Show file tree
Hide file tree
Showing 8 changed files with 294 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@ public struct ContentLengthMiddleware<OperationStackOutput: HttpResponseBinding>

private let contentLengthHeaderName = "Content-Length"

public init() {}
private var requiresLength: Bool = false

private var unsignedPayload: Bool = false

public init(requiresLength: Bool = false, unsignedPayload: Bool = false) {
self.requiresLength = requiresLength
self.unsignedPayload = unsignedPayload
}

public func handle<H>(context: Context,
input: MInput,
Expand All @@ -22,8 +29,16 @@ public struct ContentLengthMiddleware<OperationStackOutput: HttpResponseBinding>
case .stream(let stream):
if let length = stream.length {
input.headers.update(name: "Content-Length", value: String(length))
} else if !requiresLength && unsignedPayload {
// only for HTTP/1.1 requests, will be removed in all HTTP/2 requests
input.headers.update(name: "Transfer-Encoding", value: "Chunked")
} else {
input.headers.update(name: "Transfer-Encoded", value: "Chunked")
let operation = context.attributes.get(key: AttributeKey<String>(name: "Operation"))
?? "Error getting operation name"
let errorMessage = unsignedPayload ?
"Missing content-length for operation: \(operation)" :
"Missing content-length for SigV4 signing on operation: \(operation)"
throw StreamError.notSupported(errorMessage)
}
default:
input.headers.update(name: "Content-Length", value: "0")
Expand Down
3 changes: 3 additions & 0 deletions Sources/ClientRuntime/Networking/Http/SdkHttpRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ extension SdkHttpRequest {
httpRequest.path = [endpoint.path, endpoint.queryItemString].compactMap { $0 }.joined(separator: "?")
httpRequest.addHeaders(headers: headers.toHttpHeaders())

// Remove the "Transfer-Encoding" header if it exists since h2 does not support it
httpRequest.removeHeader(name: "Transfer-Encoding")

// HTTP2Request used with manual writes hence we need to set the body to nil
// so that CRT does not write the body for us (we will write it manually)
httpRequest.body = nil
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0.

import XCTest
import SmithyTestUtil
@testable import ClientRuntime

class ContentLengthMiddlewareTests: XCTestCase {
private var builtContext: HttpContext!
private var stack: OperationStack<MockInput, MockOutput, MockMiddlewareError>!

override func setUpWithError() throws {
try super.setUpWithError()
builtContext = HttpContextBuilder()
.withMethod(value: .get)
.withPath(value: "/")
.withEncoder(value: JSONEncoder())
.withDecoder(value: JSONDecoder())
.withOperation(value: "Test Operation")
.build()
stack = OperationStack<MockInput, MockOutput, MockMiddlewareError>(id: "Test Operation")
}

func testTransferEncodingChunkedSetWhenStreamLengthIsNil() async throws {
addContentLengthMiddlewareWith(requiresLength: false, unsignedPayload: true)
forceEmptyStream()
try await AssertHeadersArePresent(expectedHeaders: ["Transfer-Encoding": "Chunked"])
}

func testContentLengthSetWhenStreamLengthAvailableAndRequiresLengthSet() async throws {
addContentLengthMiddlewareWith(requiresLength: true, unsignedPayload: false)
try await AssertHeadersArePresent(expectedHeaders: ["Content-Length": "0"])
}

func testContentLengthSetWhenRequiresLengthAndUnsignedPayload() async throws {
addContentLengthMiddlewareWith(requiresLength: true, unsignedPayload: true)
try await AssertHeadersArePresent(expectedHeaders: ["Content-Length": "0"])
}

func testRequiresLengthSetWithNilStreamShouldThrowError() async throws {
addContentLengthMiddlewareWith(requiresLength: true, unsignedPayload: false)
forceEmptyStream()
do {
try await AssertHeadersArePresent(expectedHeaders: ["Content-Length": "0"])
XCTFail("Should throw error")
} catch let error as StreamError {
switch error {
case .notSupported("Missing content-length for SigV4 signing on operation: Test Operation"), .notSupported("Missing content-length for operation: Test Operation"):
// The error matches one of the expected cases, test passes
break
default:
XCTFail("Error is not StreamError.notSupported with expected message")
}
}
}

private func addContentLengthMiddlewareWith(requiresLength: Bool, unsignedPayload: Bool) {
stack.finalizeStep.intercept(
position: .before,
middleware: ContentLengthMiddleware(requiresLength: requiresLength, unsignedPayload: unsignedPayload)
)
}

private func forceEmptyStream() {
// Force stream length to be nil
stack.finalizeStep.intercept(position: .before, id: "set nil stream length") { (context, input, next) -> OperationOutput<MockOutput> in
input.body = .stream(BufferedStream()) // Set the stream length to nil
return try await next.handle(context: context, input: input)
}
}

private func AssertHeadersArePresent(expectedHeaders: [String: String], file: StaticString = #file, line: UInt = #line) async throws -> Void {
let mockHandler = MockHandler { (_, input) in
for (key, value) in expectedHeaders {
XCTAssert(input.headers.value(for: key) == value, file: file, line: line)
}
let httpResponse = HttpResponse(body: HttpBody.none, statusCode: HttpStatusCode.ok)
let mockOutput = try! MockOutput(httpResponse: httpResponse, decoder: nil)
let output = OperationOutput<MockOutput>(httpResponse: httpResponse, output: mockOutput)
return output
}

_ = try await stack.handleMiddleware(context: builtContext, input: MockInput(), next: mockHandler)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package software.amazon.smithy.swift.codegen.integration

import software.amazon.smithy.aws.traits.auth.UnsignedPayloadTrait
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.knowledge.HttpBindingIndex
Expand All @@ -30,6 +31,7 @@ import software.amazon.smithy.model.traits.HttpPrefixHeadersTrait
import software.amazon.smithy.model.traits.HttpQueryParamsTrait
import software.amazon.smithy.model.traits.HttpQueryTrait
import software.amazon.smithy.model.traits.MediaTypeTrait
import software.amazon.smithy.model.traits.RequiresLengthTrait
import software.amazon.smithy.model.traits.StreamingTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.swift.codegen.ClientRuntimeTypes
Expand Down Expand Up @@ -60,6 +62,7 @@ import software.amazon.smithy.swift.codegen.integration.serde.UnionEncodeGenerat
import software.amazon.smithy.swift.codegen.middleware.OperationMiddlewareGenerator
import software.amazon.smithy.swift.codegen.model.ShapeMetadata
import software.amazon.smithy.swift.codegen.model.bodySymbol
import software.amazon.smithy.swift.codegen.model.findStreamingMember
import software.amazon.smithy.swift.codegen.model.hasEventStreamMember
import software.amazon.smithy.swift.codegen.model.hasTrait
import software.amazon.smithy.utils.OptionalUtils
Expand Down Expand Up @@ -91,9 +94,8 @@ fun formatHeaderOrQueryValue(
memberShape: MemberShape,
location: HttpBinding.Location,
bindingIndex: HttpBindingIndex,
defaultTimestampFormat: TimestampFormatTrait.Format
defaultTimestampFormat: TimestampFormatTrait.Format,
): Pair<String, Boolean> {

return when (val shape = ctx.model.expectShape(memberShape.target)) {
is TimestampShape -> {
val timestampFormat = bindingIndex.determineTimestampFormat(memberShape, location, defaultTimestampFormat)
Expand Down Expand Up @@ -165,7 +167,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
writer.openBlock(
"extension $symbolName: \$N {",
"}",
SwiftTypes.Protocols.Encodable
SwiftTypes.Protocols.Encodable,
) {
writer.addImport(SwiftDependency.CLIENT_RUNTIME.target)

Expand Down Expand Up @@ -286,7 +288,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
private fun generateCodingKeysForMembers(
ctx: ProtocolGenerator.GenerationContext,
writer: SwiftWriter,
members: List<MemberShape>
members: List<MemberShape>,
) {
codingKeysGenerator.generateCodingKeysForMembers(ctx, writer, members)
}
Expand All @@ -298,7 +300,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
val inputType = ctx.model.expectShape(operation.input.get())
var metadata = mapOf<ShapeMetadata, Any>(
Pair(ShapeMetadata.OPERATION_SHAPE, operation),
Pair(ShapeMetadata.SERVICE_VERSION, ctx.service.version)
Pair(ShapeMetadata.SERVICE_VERSION, ctx.service.version),
)
shapesInfo.put(inputType, metadata)
}
Expand Down Expand Up @@ -336,7 +338,6 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
}

private fun resolveShapesNeedingCodableConformance(ctx: ProtocolGenerator.GenerationContext): Set<Shape> {

val topLevelOutputMembers = getHttpBindingOperations(ctx).flatMap {
val outputShape = ctx.model.expectShape(it.output.get())
outputShape.members()
Expand Down Expand Up @@ -390,7 +391,8 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
RelationshipType.LIST_MEMBER,
RelationshipType.SET_MEMBER,
RelationshipType.MAP_VALUE,
RelationshipType.UNION_MEMBER -> true
RelationshipType.UNION_MEMBER,
-> true
else -> false
}
}.forEach {
Expand All @@ -403,6 +405,29 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
return resolved
}

// Checks for @requiresLength trait
// Returns true if the operation:
// - has a streaming member with @httpPayload trait
// - target is a blob shape with @requiresLength trait
private fun hasRequiresLengthTrait(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Boolean {
if (op.input.isPresent) {
val inputShape = ctx.model.expectShape(op.input.get())
val streamingMember = inputShape.findStreamingMember(ctx.model)
if (streamingMember != null) {
val targetShape = ctx.model.expectShape(streamingMember.target)
if (targetShape != null) {
return streamingMember.hasTrait<HttpPayloadTrait>() &&
targetShape.isBlobShape &&
targetShape.hasTrait<RequiresLengthTrait>()
}
}
}
return false
}

// Checks for @unsignedPayload trait on an operation
private fun hasUnsignedPayloadTrait(op: OperationShape): Boolean = op.hasTrait<UnsignedPayloadTrait>()

override fun generateProtocolClient(ctx: ProtocolGenerator.GenerationContext) {
val symbol = ctx.symbolProvider.toSymbol(ctx.service)
ctx.delegator.useFileWriter("./${ctx.settings.moduleName}/${symbol.name}.swift") { writer ->
Expand All @@ -414,7 +439,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
serviceSymbol.name,
defaultContentType,
httpProtocolCustomizable,
operationMiddleware
operationMiddleware,
)
clientGenerator.render()
}
Expand All @@ -433,7 +458,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
operationMiddleware.appendMiddleware(operation, ContentTypeMiddleware(ctx.model, ctx.symbolProvider, resolver.determineRequestContentType(operation)))
operationMiddleware.appendMiddleware(operation, OperationInputBodyMiddleware(ctx.model, ctx.symbolProvider))

operationMiddleware.appendMiddleware(operation, ContentLengthMiddleware(ctx.model, shouldRenderEncodableConformance))
operationMiddleware.appendMiddleware(operation, ContentLengthMiddleware(ctx.model, shouldRenderEncodableConformance, hasRequiresLengthTrait(ctx, operation), hasUnsignedPayloadTrait(operation)))

operationMiddleware.appendMiddleware(operation, DeserializeMiddleware(ctx.model, ctx.symbolProvider))
operationMiddleware.appendMiddleware(operation, LoggingMiddleware(ctx.model, ctx.symbolProvider))
Expand Down Expand Up @@ -463,15 +488,15 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
members: List<MemberShape>,
writer: SwiftWriter,
defaultTimestampFormat: TimestampFormatTrait.Format,
path: String? = null
path: String? = null,
)
protected abstract fun renderStructDecode(
ctx: ProtocolGenerator.GenerationContext,
shapeMetaData: Map<ShapeMetadata, Any>,
members: List<MemberShape>,
writer: SwiftWriter,
defaultTimestampFormat: TimestampFormatTrait.Format,
path: String
path: String,
)
protected abstract fun addProtocolSpecificMiddleware(ctx: ProtocolGenerator.GenerationContext, operation: OperationShape)

Expand All @@ -487,11 +512,11 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
for (operation in topDownIndex.getContainedOperations(ctx.service)) {
OptionalUtils.ifPresentOrElse(
Optional.of(getProtocolHttpBindingResolver(ctx, defaultContentType).httpTrait(operation)::class.java),
{ containedOperations.add(operation) }
{ containedOperations.add(operation) },
) {
LOGGER.warning(
"Unable to fetch $protocolName protocol request bindings for ${operation.id} because " +
"it does not have an http binding trait"
"it does not have an http binding trait",
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import software.amazon.smithy.swift.codegen.middleware.MiddlewarePosition
import software.amazon.smithy.swift.codegen.middleware.MiddlewareRenderable
import software.amazon.smithy.swift.codegen.middleware.MiddlewareStep

class ContentLengthMiddleware(val model: Model, private val alwaysIntercept: Boolean) : MiddlewareRenderable {
class ContentLengthMiddleware(val model: Model, private val alwaysIntercept: Boolean, private val requiresLength: Boolean, private val unsignedPayload: Boolean) : MiddlewareRenderable {

override val name = "ContentLengthMiddleware"

Expand All @@ -20,17 +20,17 @@ class ContentLengthMiddleware(val model: Model, private val alwaysIntercept: Boo
override fun render(
writer: SwiftWriter,
op: OperationShape,
operationStackName: String
operationStackName: String,
) {
val hasHttpBody = MiddlewareShapeUtils.hasHttpBody(model, op)
if (hasHttpBody || alwaysIntercept) {
writer.write(
"\$L.\$L.intercept(position: \$L, middleware: \$N())",
operationStackName,
middlewareStep.stringValue(),
position.stringValue(),
ClientRuntimeTypes.Middleware.ContentLengthMiddleware
)
val str = "requiresLength: $requiresLength, unsignedPayload: $unsignedPayload"
val middlewareArgs = str.takeIf { requiresLength || unsignedPayload } ?: ""

val interceptStatement = "$operationStackName.${middlewareStep.stringValue()}.intercept(" +
"position: ${position.stringValue()}, middleware: ${ClientRuntimeTypes.Middleware.ContentLengthMiddleware}($middlewareArgs))"

writer.write(interceptStatement)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class TestHttpProtocolClientGeneratorFactory : HttpProtocolClientGeneratorFactor
private fun getClientProperties(ctx: ProtocolGenerator.GenerationContext): List<ClientProperty> {
return mutableListOf(
DefaultRequestEncoder(),
DefaultResponseDecoder()
DefaultResponseDecoder(),
)
}

Expand Down Expand Up @@ -125,6 +125,7 @@ extension InlineDocumentAsPayloadOutput: ClientRuntime.HttpResponseBinding {
""".trimIndent()
contents.shouldContainOnlyOnce(expectedContents)
}

@Test
fun `default fooMap to an empty map if keysForFooMap is empty`() {
val contents = getModelFileContents("example", "HttpPrefixHeadersOutput+HttpResponseBinding.swift", newTestContext.manifest)
Expand Down
Loading

0 comments on commit 818a5c7

Please sign in to comment.