Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: refactor URL types and supporting classes #989

Merged
merged 19 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changes/1cd7d354-501b-439a-a01a-3e884558383a.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"id": "1cd7d354-501b-439a-a01a-3e884558383a",
"type": "feature",
"description": "BREAKING: Overhaul URL APIs to clarify content encoding, when data is in which state, and to reduce the number of times data is encoded/decoded"
}
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,17 @@ object RuntimeTypes {
val TimestampFormat = symbol("TimestampFormat", "time")
val ClientException = symbol("ClientException")

object Collections : RuntimeTypePackage(KotlinDependency.CORE, "collections") {
val Attributes = symbol("Attributes")
val attributesOf = symbol("attributesOf")
val AttributeKey = symbol("AttributeKey")
val get = symbol("get")
val mutableMultiMapOf = symbol("mutableMultiMapOf")
val putIfAbsent = symbol("putIfAbsent")
val putIfAbsentNotNull = symbol("putIfAbsentNotNull")
val ReadThroughCache = symbol("ReadThroughCache")
}

object Content : RuntimeTypePackage(KotlinDependency.CORE, "content") {
val BigDecimal = symbol("BigDecimal")
val BigInteger = symbol("BigInteger")
Expand Down Expand Up @@ -144,38 +155,34 @@ object RuntimeTypes {
val SdkManagedGroup = symbol("SdkManagedGroup")
val addIfManaged = symbol("addIfManaged", isExtension = true)
}

object Text : RuntimeTypePackage(KotlinDependency.CORE, "text") {
object Encoding : RuntimeTypePackage(KotlinDependency.CORE, "text.encoding") {
val decodeBase64 = symbol("decodeBase64")
val decodeBase64Bytes = symbol("decodeBase64Bytes")
val encodeBase64 = symbol("encodeBase64")
val encodeBase64String = symbol("encodeBase64String")
val PercentEncoding = symbol("PercentEncoding")
}
}

object Utils : RuntimeTypePackage(KotlinDependency.CORE, "util") {
val Attributes = symbol("Attributes")
val MutableAttributes = symbol("MutableAttributes")
val attributesOf = symbol("attributesOf")
val AttributeKey = symbol("AttributeKey")
val decodeBase64 = symbol("decodeBase64")
val decodeBase64Bytes = symbol("decodeBase64Bytes")
val encodeBase64 = symbol("encodeBase64")
val encodeBase64String = symbol("encodeBase64String")
val ExpiringValue = symbol("ExpiringValue")
val flattenIfPossible = symbol("flattenIfPossible")
val get = symbol("get")
val LazyAsyncValue = symbol("LazyAsyncValue")
val length = symbol("length")
val putIfAbsent = symbol("putIfAbsent")
val putIfAbsentNotNull = symbol("putIfAbsentNotNull")
val ReadThroughCache = symbol("ReadThroughCache")
val truthiness = symbol("truthiness")
val urlEncodeComponent = symbol("urlEncodeComponent", "text")
val toNumber = symbol("toNumber")
val type = symbol("type")
}

object Net : RuntimeTypePackage(KotlinDependency.CORE, "net") {
val Host = symbol("Host")
val parameters = symbol("parameters")
val QueryParameters = symbol("QueryParameters")
val QueryParametersBuilder = symbol("QueryParametersBuilder")
val splitAsQueryParameters = symbol("splitAsQueryParameters")
val toQueryParameters = symbol("toQueryParameters")
val Url = symbol("Url")
val UrlDecoding = symbol("UrlDecoding")

object Url : RuntimeTypePackage(KotlinDependency.CORE, "net.url") {
val QueryParameters = symbol("QueryParameters")
val Url = symbol("Url")
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class AuthSchemeProviderAdapterGenerator {
RuntimeTypes.Auth.Identity.AuthOption,
) {
withBlock("val params = #T {", "}", AuthSchemeParametersGenerator.getSymbol(ctx.settings)) {
addImport(RuntimeTypes.Core.Utils.get)
addImport(RuntimeTypes.Core.Collections.get)
write("operationName = request.context[#T.OperationName]", RuntimeTypes.SmithyClient.SdkClientOption)

if (ctx.settings.api.enableEndpointAuthProvider) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ class DefaultEndpointProviderGenerator(
private fun renderEndpointRule(rule: EndpointRule) {
withConditions(rule.conditions) {
writer.withBlock("return #T(", ")", RuntimeTypes.SmithyClient.Endpoints.Endpoint) {
writeInline("#T.parse(", RuntimeTypes.Core.Net.Url)
writeInline("#T.parse(", RuntimeTypes.Core.Net.Url.Url)
renderExpression(rule.endpoint.url)
write(", #1T.DecodeAll - #1T.DecodePath),", RuntimeTypes.Core.Net.UrlDecoding)
write("),")

if (rule.endpoint.headers.isNotEmpty()) {
withBlock("headers = #T {", "},", RuntimeTypes.Http.Headers) {
Expand All @@ -168,7 +168,7 @@ class DefaultEndpointProviderGenerator(
}

if (rule.endpoint.properties.isNotEmpty()) {
withBlock("attributes = #T {", "},", RuntimeTypes.Core.Utils.attributesOf) {
withBlock("attributes = #T {", "},", RuntimeTypes.Core.Collections.attributesOf) {
rule.endpoint.properties.entries.forEach { (k, v) ->
val kStr = k.toString()

Expand All @@ -180,7 +180,7 @@ class DefaultEndpointProviderGenerator(

// otherwise, we just traverse the value like any other rules expression, object values will
// be rendered as Documents
writeInline("#T(#S) to ", RuntimeTypes.Core.Utils.AttributeKey, kStr)
writeInline("#T(#S) to ", RuntimeTypes.Core.Collections.AttributeKey, kStr)
renderExpression(v)
ensureNewline()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,7 @@ class DefaultEndpointProviderTestGenerator(
}

writer.withBlock("val expected = #T(", ")", RuntimeTypes.SmithyClient.Endpoints.Endpoint) {
write(
"uri = #1T.parse(#2S, #3T.DecodeAll - #3T.DecodePath),",
RuntimeTypes.Core.Net.Url,
endpoint.url,
RuntimeTypes.Core.Net.UrlDecoding,
)
write("uri = #T.parse(#S),", RuntimeTypes.Core.Net.Url.Url, endpoint.url)

if (endpoint.headers.isNotEmpty()) {
withBlock("headers = #T {", "},", RuntimeTypes.Http.Headers) {
Expand All @@ -124,14 +119,14 @@ class DefaultEndpointProviderTestGenerator(
}

if (endpoint.properties.isNotEmpty()) {
withBlock("attributes = #T {", "},", RuntimeTypes.Core.Utils.attributesOf) {
withBlock("attributes = #T {", "},", RuntimeTypes.Core.Collections.attributesOf) {
endpoint.properties.entries.forEach { (k, v) ->
if (k in expectedPropertyRenderers) {
expectedPropertyRenderers[k]!!(writer, Expression.fromNode(v), this@DefaultEndpointProviderTestGenerator)
return@forEach
}

writeInline("#T(#S) to ", RuntimeTypes.Core.Utils.AttributeKey, k)
writeInline("#T(#S) to ", RuntimeTypes.Core.Collections.AttributeKey, k)
renderExpression(Expression.fromNode(v))
ensureNewline()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class EndpointResolverAdapterGenerator(
RuntimeTypes.HttpClient.Operation.ResolveEndpointRequest,
EndpointParametersGenerator.getSymbol(ctx.settings),
) {
writer.addImport(RuntimeTypes.Core.Utils.get)
writer.addImport(RuntimeTypes.Core.Collections.get)
withBlock("return #T {", "}", EndpointParametersGenerator.getSymbol(ctx.settings)) {
// The SEP dictates a specific source order to use when binding parameters (from most specific to least):
// 1. staticContextParams (from operation shape)
Expand Down Expand Up @@ -164,7 +164,7 @@ class EndpointResolverAdapterGenerator(
val inputContextParams = epParameterIndex.inputContextParams(op)

if (inputContextParams.isNotEmpty()) {
writer.addImport(RuntimeTypes.Core.Utils.get)
writer.addImport(RuntimeTypes.Core.Collections.get)
writer.write("@Suppress(#S)", "UNCHECKED_CAST")
val opInputShape = ctx.model.expectShape(op.inputShape)
val inputSymbol = ctx.symbolProvider.toSymbol(opInputShape)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class EndpointDiscovererGenerator(private val ctx: CodegenContext, private val d
) {
write(
"private val cache = #T<DiscoveryParams, #T>(10.#T, #T.System)",
RuntimeTypes.Core.Utils.ReadThroughCache,
RuntimeTypes.Core.Collections.ReadThroughCache,
RuntimeTypes.Core.Net.Host,
KotlinTypes.Time.minutes,
RuntimeTypes.Core.Clock,
Expand All @@ -66,7 +66,7 @@ class EndpointDiscovererGenerator(private val ctx: CodegenContext, private val d
write("")
write(
"""private val discoveryParamsKey = #T<DiscoveryParams>("DiscoveryParams")""",
RuntimeTypes.Core.Utils.AttributeKey,
RuntimeTypes.Core.Collections.AttributeKey,
)
write("private data class DiscoveryParams(private val region: String?, private val identity: String)")
}
Expand All @@ -92,7 +92,7 @@ class EndpointDiscovererGenerator(private val ctx: CodegenContext, private val d
write("")
write("val originalEndpoint = delegate.resolve(request)")
withBlock("#T(", ")", RuntimeTypes.SmithyClient.Endpoints.Endpoint) {
write("originalEndpoint.uri.copy(host = discoveredHost),")
write("originalEndpoint.uri.copy { host = discoveredHost },")
write("originalEndpoint.headers,")
write("originalEndpoint.attributes,")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,46 +314,57 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
renderNonBlankGuard(ctx, binding.member, writer)
}

writer.openBlock("val pathSegments = listOf<String>(", ")") {
httpTrait.uri.segments.forEach { segment ->
if (segment.isLabel || segment.isGreedyLabel) {
// spec dictates member name and label name MUST be the same
val binding = pathBindings.find { binding ->
binding.memberName == segment.content
} ?: throw CodegenException("failed to find corresponding member for httpLabel `${segment.content}")

// shape must be string, number, boolean, or timestamp
val targetShape = ctx.model.expectShape(binding.member.target)
val memberSymbol = ctx.symbolProvider.toSymbol(binding.member)
val identifier = if (targetShape.isTimestampShape) {
writer.addImport(RuntimeTypes.Core.TimestampFormat)
val tsFormat = resolver.determineTimestampFormat(
binding.member,
HttpBinding.Location.LABEL,
defaultTimestampFormat,
)
val nullCheck = if (memberSymbol.isNullable) "?" else ""
val tsLabel = formatInstant("input.${binding.member.defaultName()}$nullCheck", tsFormat, forceString = true)
tsLabel
} else {
"input.${binding.member.defaultName()}"
}
if (httpTrait.uri.segments.isNotEmpty()) {
writer.withBlock("path.encodedSegments {", "}") {
httpTrait.uri.segments.forEach { segment ->
if (segment.isLabel || segment.isGreedyLabel) {
// spec dictates member name and label name MUST be the same
val binding = pathBindings.find { binding ->
binding.memberName == segment.content
}
?: throw CodegenException("failed to find corresponding member for httpLabel `${segment.content}")

// shape must be string, number, boolean, or timestamp
val targetShape = ctx.model.expectShape(binding.member.target)
val memberSymbol = ctx.symbolProvider.toSymbol(binding.member)
val identifier = if (targetShape.isTimestampShape) {
addImport(RuntimeTypes.Core.TimestampFormat)
val tsFormat = resolver.determineTimestampFormat(
binding.member,
HttpBinding.Location.LABEL,
defaultTimestampFormat,
)
val nullCheck = if (memberSymbol.isNullable) "?" else ""
val tsLabel = formatInstant(
"input.${binding.member.defaultName()}$nullCheck",
tsFormat,
forceString = true,
)
tsLabel
} else {
"input.${binding.member.defaultName()}"
}

val encodeFn =
format("#T.SmithyLabel.encode", RuntimeTypes.Core.Text.Encoding.PercentEncoding)

val encodeSymbol = RuntimeTypes.Http.Util.encodeLabel
val encodeFn = if (segment.isGreedyLabel) {
writer.format("#T(greedy = true)", encodeSymbol)
if (segment.isGreedyLabel) {
write("#S.split(#S).mapTo(this) { #L(it) }", "\${$identifier}", '/', encodeFn)
} else {
write("add(#L(#S))", encodeFn, "\${$identifier}")
}
} else {
writer.format("#T()", encodeSymbol)
// literal
val encodeFn = format("#T.Path.encode", RuntimeTypes.Core.Text.Encoding.PercentEncoding)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question/correctness: Why wouldn't this use the same (smithy label) encode function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confusingly, path literals are encoded differently from path labels. In particular, literals have a wider variety of characters they accept without encoding while labels are more conservative. I couldn't find explicit documentation supporting this so I fall back to the RFC 3982 specification of how paths should be encoded.

Note that at least one protocol test enforces that path literal segments are not over-encoded. In this test, encoding literals with SmithyLabel would result in the path /ReDosLiteral/abc/%28a%2B%29%2B which fails to match the expected value of /ReDosLiteral/abc/(a+)+.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might add some of this context as a comment in the code

writer.write("add(#L(\"#L\"))", encodeFn, segment.content.toEscapedLiteral())
}
writer.write("#S.$encodeFn,", "\${$identifier}")
} else {
// literal
writer.write("\"#L\",", segment.content.toEscapedLiteral())
}
}
}

writer.write("""path = pathSegments.joinToString(separator = "/", prefix = "/")""")
if (httpTrait.uri.segments.isEmpty()) {
writer.write("path.trailingSlash = true")
}
} else {
// all literals, inline directly
val resolvedPath = httpTrait.uri.segments.joinToString(
Expand All @@ -363,7 +374,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
it.content.toEscapedLiteral()
},
)
writer.write("path = \"#L\"", resolvedPath)
writer.write("path.encoded = \"#L\"", resolvedPath)
}
}

Expand All @@ -384,12 +395,18 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {

if (queryBindings.isEmpty() && queryLiterals.isEmpty() && queryMapBindings.isEmpty()) return

writer
.withBlock("#T {", "}", RuntimeTypes.Core.Net.parameters) {
queryLiterals.forEach { (key, value) ->
writer.write("append(#S, #S)", key, value)
}
if (queryLiterals.isNotEmpty()) {
writer.withBlock("parameters.decodedParameters {", "}") {
queryLiterals.forEach { (key, value) -> writer.write("add(#S, #S)", key, value) }
}
}

if (queryBindings.isNotEmpty() || queryMapBindings.isNotEmpty()) {
writer.withBlock(
"parameters.decodedParameters(#T.SmithyLabel) {",
"}",
RuntimeTypes.Core.Text.Encoding.PercentEncoding,
) {
// render length check if applicable
queryBindings.forEach { binding -> renderNonBlankGuard(ctx, binding.member, writer) }

Expand All @@ -401,8 +418,8 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
val target = ctx.model.expectShape<MapShape>(it.member.target)
val valueTarget = ctx.model.expectShape(target.value.target)
val fn = when (valueTarget.type) {
ShapeType.STRING -> "append"
ShapeType.LIST, ShapeType.SET -> "appendAll"
ShapeType.STRING -> "add"
ShapeType.LIST, ShapeType.SET -> "addAll"
else -> throw CodegenException("unexpected value type for httpQueryParams map")
}

Expand All @@ -423,6 +440,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
.dedent()
}
}
}
}

// shared implementation for rendering members that belong to StringValuesMap (e.g. Header or Query parameters)
Expand Down Expand Up @@ -764,7 +782,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
)
}
is BlobShape -> {
writer.write("builder.#L = response.headers[#S]?.#T()", memberName, headerName, RuntimeTypes.Core.Utils.decodeBase64)
writer.write("builder.#L = response.headers[#S]?.#T()", memberName, headerName, RuntimeTypes.Core.Text.Encoding.decodeBase64)
}
is StringShape -> {
when {
Expand All @@ -779,7 +797,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
)
}
memberTarget.hasTrait<MediaTypeTrait>() -> {
writer.write("builder.#L = response.headers[#S]?.#T()", memberName, headerName, RuntimeTypes.Core.Utils.decodeBase64)
writer.write("builder.#L = response.headers[#S]?.#T()", memberName, headerName, RuntimeTypes.Core.Text.Encoding.decodeBase64)
}
else -> {
writer.write("builder.#L = response.headers[#S]", memberName, headerName)
Expand Down Expand Up @@ -839,7 +857,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
"${enumSymbol.name}.fromValue(it)"
}
collectionMemberTarget.hasTrait<MediaTypeTrait>() -> {
writer.addImport(RuntimeTypes.Core.Utils.decodeBase64)
writer.addImport(RuntimeTypes.Core.Text.Encoding.decodeBase64)
"it.decodeBase64()"
}
else -> ""
Expand Down
Loading
Loading