Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: refactor URL types and supporting classes #989

Merged
merged 19 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -144,15 +144,21 @@ object RuntimeTypes {
val SdkManagedGroup = symbol("SdkManagedGroup")
val addIfManaged = symbol("addIfManaged", isExtension = true)
}

object Text : RuntimeTypePackage(KotlinDependency.CORE, "text") {
object Encoding : RuntimeTypePackage(KotlinDependency.CORE, "text.encoding") {
val decodeBase64 = symbol("decodeBase64")
val decodeBase64Bytes = symbol("decodeBase64Bytes")
val encodeBase64 = symbol("encodeBase64")
val encodeBase64String = symbol("encodeBase64String")
}
}

object Utils : RuntimeTypePackage(KotlinDependency.CORE, "util") {
val Attributes = symbol("Attributes")
val MutableAttributes = symbol("MutableAttributes")
val attributesOf = symbol("attributesOf")
val AttributeKey = symbol("AttributeKey")
val decodeBase64 = symbol("decodeBase64")
val decodeBase64Bytes = symbol("decodeBase64Bytes")
val encodeBase64 = symbol("encodeBase64")
val encodeBase64String = symbol("encodeBase64String")
val ExpiringValue = symbol("ExpiringValue")
val flattenIfPossible = symbol("flattenIfPossible")
val get = symbol("get")
Expand All @@ -162,20 +168,17 @@ object RuntimeTypes {
val putIfAbsentNotNull = symbol("putIfAbsentNotNull")
val ReadThroughCache = symbol("ReadThroughCache")
val truthiness = symbol("truthiness")
val urlEncodeComponent = symbol("urlEncodeComponent", "text")
val toNumber = symbol("toNumber")
val type = symbol("type")
}

object Net : RuntimeTypePackage(KotlinDependency.CORE, "net") {
val Host = symbol("Host")
val parameters = symbol("parameters")
val QueryParameters = symbol("QueryParameters")
val QueryParametersBuilder = symbol("QueryParametersBuilder")
val splitAsQueryParameters = symbol("splitAsQueryParameters")
val toQueryParameters = symbol("toQueryParameters")
val Url = symbol("Url")
val UrlDecoding = symbol("UrlDecoding")

object Url : RuntimeTypePackage(KotlinDependency.CORE, "net.url") {
val QueryParameters = symbol("QueryParameters")
val Url = symbol("Url")
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ class DefaultEndpointProviderGenerator(
private fun renderEndpointRule(rule: EndpointRule) {
withConditions(rule.conditions) {
writer.withBlock("return #T(", ")", RuntimeTypes.SmithyClient.Endpoints.Endpoint) {
writeInline("#T.parse(", RuntimeTypes.Core.Net.Url)
writeInline("#T.parse(", RuntimeTypes.Core.Net.Url.Url)
renderExpression(rule.endpoint.url)
write(", #1T.DecodeAll - #1T.DecodePath),", RuntimeTypes.Core.Net.UrlDecoding)
write("),")

if (rule.endpoint.headers.isNotEmpty()) {
withBlock("headers = #T {", "},", RuntimeTypes.Http.Headers) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,7 @@ class DefaultEndpointProviderTestGenerator(
}

writer.withBlock("val expected = #T(", ")", RuntimeTypes.SmithyClient.Endpoints.Endpoint) {
write(
"uri = #1T.parse(#2S, #3T.DecodeAll - #3T.DecodePath),",
RuntimeTypes.Core.Net.Url,
endpoint.url,
RuntimeTypes.Core.Net.UrlDecoding,
)
write("uri = #T.parse(#S),", RuntimeTypes.Core.Net.Url.Url, endpoint.url)

if (endpoint.headers.isNotEmpty()) {
withBlock("headers = #T {", "},", RuntimeTypes.Http.Headers) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,46 +314,47 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
renderNonBlankGuard(ctx, binding.member, writer)
}

writer.openBlock("val pathSegments = listOf<String>(", ")") {
writer.withBlock("path.decodedSegments {", "}") {
httpTrait.uri.segments.forEach { segment ->
if (segment.isLabel || segment.isGreedyLabel) {
// spec dictates member name and label name MUST be the same
val binding = pathBindings.find { binding ->
binding.memberName == segment.content
} ?: throw CodegenException("failed to find corresponding member for httpLabel `${segment.content}")
}
?: throw CodegenException("failed to find corresponding member for httpLabel `${segment.content}")

// shape must be string, number, boolean, or timestamp
val targetShape = ctx.model.expectShape(binding.member.target)
val memberSymbol = ctx.symbolProvider.toSymbol(binding.member)
val identifier = if (targetShape.isTimestampShape) {
writer.addImport(RuntimeTypes.Core.TimestampFormat)
addImport(RuntimeTypes.Core.TimestampFormat)
val tsFormat = resolver.determineTimestampFormat(
binding.member,
HttpBinding.Location.LABEL,
defaultTimestampFormat,
)
val nullCheck = if (memberSymbol.isNullable) "?" else ""
val tsLabel = formatInstant("input.${binding.member.defaultName()}$nullCheck", tsFormat, forceString = true)
val tsLabel = formatInstant(
"input.${binding.member.defaultName()}$nullCheck",
tsFormat,
forceString = true,
)
tsLabel
} else {
"input.${binding.member.defaultName()}"
}

val encodeSymbol = RuntimeTypes.Http.Util.encodeLabel
val encodeFn = if (segment.isGreedyLabel) {
writer.format("#T(greedy = true)", encodeSymbol)
if (segment.isGreedyLabel) {
write("addAll(#S.split(#S))", "\${$identifier}", '"')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question/correctness: Why is this splitting on "?

Also previously we used a dedicated function encodeLabel which adheres to slightly different encoding semantics than HTTP RFC does. I'm not sure this is correct to replace appending (or splitting greedy labels and appending) as it does now.

Copy link
Contributor Author

@ianbotsf ianbotsf Nov 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right on both accounts. I've fixed the " splitting to be / as intended. I've fixed part of the label encoding but I still have more changes to make to support query parameters.

} else {
writer.format("#T()", encodeSymbol)
write("add(#S)", "\${$identifier}")
}
writer.write("#S.$encodeFn,", "\${$identifier}")
} else {
// literal
writer.write("\"#L\",", segment.content.toEscapedLiteral())
writer.write("add(\"#L\")", segment.content.toEscapedLiteral())
}
}
}

writer.write("""path = pathSegments.joinToString(separator = "/", prefix = "/")""")
} else {
// all literals, inline directly
val resolvedPath = httpTrait.uri.segments.joinToString(
Expand All @@ -363,7 +364,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
it.content.toEscapedLiteral()
},
)
writer.write("path = \"#L\"", resolvedPath)
writer.write("path.encoded = \"#L\"", resolvedPath)
}
}

Expand All @@ -384,45 +385,44 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {

if (queryBindings.isEmpty() && queryLiterals.isEmpty() && queryMapBindings.isEmpty()) return

writer
.withBlock("#T {", "}", RuntimeTypes.Core.Net.parameters) {
queryLiterals.forEach { (key, value) ->
writer.write("append(#S, #S)", key, value)
}

// render length check if applicable
queryBindings.forEach { binding -> renderNonBlankGuard(ctx, binding.member, writer) }
writer.withBlock("parameters.encodedParameters {", "}") {
queryLiterals.forEach { (key, value) ->
writer.write("add(#S, #S)", key, value)
}

renderStringValuesMapParameters(ctx, queryBindings, writer)
// render length check if applicable
queryBindings.forEach { binding -> renderNonBlankGuard(ctx, binding.member, writer) }

queryMapBindings.forEach {
// either Map<String, String> or Map<String, Collection<String>>
// https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html#httpqueryparams-trait
val target = ctx.model.expectShape<MapShape>(it.member.target)
val valueTarget = ctx.model.expectShape(target.value.target)
val fn = when (valueTarget.type) {
ShapeType.STRING -> "append"
ShapeType.LIST, ShapeType.SET -> "appendAll"
else -> throw CodegenException("unexpected value type for httpQueryParams map")
}
renderStringValuesMapParameters(ctx, queryBindings, writer)

val nullCheck = if (target.hasTrait<SparseTrait>()) {
"if (value != null) "
} else {
""
}
queryMapBindings.forEach {
// either Map<String, String> or Map<String, Collection<String>>
// https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html#httpqueryparams-trait
val target = ctx.model.expectShape<MapShape>(it.member.target)
val valueTarget = ctx.model.expectShape(target.value.target)
val fn = when (valueTarget.type) {
ShapeType.STRING -> "add"
ShapeType.LIST, ShapeType.SET -> "addAll"
else -> throw CodegenException("unexpected value type for httpQueryParams map")
}

writer.write("input.${it.member.defaultName()}")
.indent()
// ensure query precedence rules are enforced by filtering keys already set
// (httpQuery bound members take precedence over a query map with same key)
.write("?.filterNot{ contains(it.key) }")
.withBlock("?.forEach { (key, value) ->", "}") {
write("${nullCheck}$fn(key, value)")
}
.dedent()
val nullCheck = if (target.hasTrait<SparseTrait>()) {
"if (value != null) "
} else {
""
}

writer.write("input.${it.member.defaultName()}")
.indent()
// ensure query precedence rules are enforced by filtering keys already set
// (httpQuery bound members take precedence over a query map with same key)
.write("?.filterNot{ contains(it.key) }")
.withBlock("?.forEach { (key, value) ->", "}") {
write("${nullCheck}$fn(key, value)")
}
.dedent()
}
}
}

// shared implementation for rendering members that belong to StringValuesMap (e.g. Header or Query parameters)
Expand Down Expand Up @@ -764,7 +764,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
)
}
is BlobShape -> {
writer.write("builder.#L = response.headers[#S]?.#T()", memberName, headerName, RuntimeTypes.Core.Utils.decodeBase64)
writer.write("builder.#L = response.headers[#S]?.#T()", memberName, headerName, RuntimeTypes.Core.Text.Encoding.decodeBase64)
}
is StringShape -> {
when {
Expand All @@ -779,7 +779,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
)
}
memberTarget.hasTrait<MediaTypeTrait>() -> {
writer.write("builder.#L = response.headers[#S]?.#T()", memberName, headerName, RuntimeTypes.Core.Utils.decodeBase64)
writer.write("builder.#L = response.headers[#S]?.#T()", memberName, headerName, RuntimeTypes.Core.Text.Encoding.decodeBase64)
}
else -> {
writer.write("builder.#L = response.headers[#S]", memberName, headerName)
Expand Down Expand Up @@ -839,7 +839,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
"${enumSymbol.name}.fromValue(it)"
}
collectionMemberTarget.hasTrait<MediaTypeTrait>() -> {
writer.addImport(RuntimeTypes.Core.Utils.decodeBase64)
writer.addImport(RuntimeTypes.Core.Text.Encoding.decodeBase64)
"it.decodeBase64()"
}
else -> ""
Expand Down
Loading
Loading