Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(codegen): generate paginators and waiters with default parameters for all optional inputs #959

Merged
merged 5 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changes/4602d073-9393-4496-b63d-85535a2f631a.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"id": "4602d073-9393-4496-b63d-85535a2f631a",
"type": "feature",
"description": "Generate paginators and waiters with a default parameter when input shape has all optional members"
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ fun String.shouldContain(expectedStart: String, expectedEnd: String) {
}

fun <T> List<T>.indexOfSublistOrNull(sublist: List<T>, startFrom: Int = 0): Int? =
drop(startFrom).windowed(sublist.size).indexOf(sublist)
drop(startFrom).windowed(sublist.size).indexOf(sublist).takeIf { it >= 0 }

/** Format a multi-line string suitable for comparison with codegen, defaults to one level of indention. */
fun String.formatForTest(indent: String = " ") =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ fun OperationIndex.operationSignature(

val hasOutputStream = outputShape.map { it.hasStreamingMember(model) }.orElse(false)
val inputParam = input.map {
if (includeOptionalDefault && inputShape.get().isOptional()) "input: $it = $it {}" else "input: $it"
if (includeOptionalDefault && inputShape.get().hasAllOptionalMembers) "input: $it = $it { }" else "input: $it"
}.orElse("")
val outputParam = output.map { ": $it" }.orElse("")

Expand Down Expand Up @@ -245,9 +245,10 @@ fun UnionShape.filterEventStreamErrors(model: Model): Collection<MemberShape> {
}

/**
* Test if a shape is optional.
* Test if a shape has all optional members (no member marked `@required`)
*/
fun Shape.isOptional(): Boolean = members().none { it.isRequired }
val Shape.hasAllOptionalMembers: Boolean
get() = members().none { it.isRequired }

/**
* Derive the input and output symbols for an operation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration
import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
import software.amazon.smithy.kotlin.codegen.model.SymbolProperty
import software.amazon.smithy.kotlin.codegen.model.expectShape
import software.amazon.smithy.kotlin.codegen.model.hasAllOptionalMembers
import software.amazon.smithy.kotlin.codegen.model.hasTrait
import software.amazon.smithy.kotlin.codegen.model.traits.PaginationTruncationMember
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
Expand All @@ -29,7 +30,6 @@ import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.PaginatedTrait

/**
Expand All @@ -54,60 +54,52 @@ class PaginatorGenerator : KotlinIntegration {
?: throw CodegenException("Unexpectedly unable to get PaginationInfo from $service $paginatedOperation")
val paginationItemInfo = getItemDescriptorOrNull(paginationInfo, ctx)

renderPaginatorForOperation(writer, ctx, service, paginatedOperation, paginationInfo, paginationItemInfo)
renderPaginatorForOperation(ctx, writer, paginatedOperation, paginationInfo, paginationItemInfo)
}
}
}

// Render paginator(s) for operation
private fun renderPaginatorForOperation(
writer: KotlinWriter,
ctx: CodegenContext,
service: ServiceShape,
writer: KotlinWriter,
paginatedOperation: OperationShape,
paginationInfo: PaginationInfo,
itemDesc: ItemDescriptor?,
) {
val serviceSymbol = ctx.symbolProvider.toSymbol(service)
val outputSymbol = ctx.symbolProvider.toSymbol(paginationInfo.output)
val inputSymbol = ctx.symbolProvider.toSymbol(paginationInfo.input)
val cursorMember = ctx.model.getShape(paginationInfo.inputTokenMember.target).get()
val cursorSymbol = ctx.symbolProvider.toSymbol(cursorMember)

renderResponsePaginator(
ctx,
writer,
serviceSymbol,
paginatedOperation,
inputSymbol,
paginationInfo.output,
outputSymbol,
paginationInfo,
cursorSymbol,
)

// Optionally generate paginator when nested item is specified on the trait.
if (itemDesc != null) {
renderItemPaginator(
ctx,
writer,
service,
paginatedOperation,
itemDesc,
outputSymbol,
)
}
}

// Generate the paginator that iterates over responses
private fun renderResponsePaginator(
ctx: CodegenContext,
writer: KotlinWriter,
serviceSymbol: Symbol,
operationShape: OperationShape,
inputSymbol: Symbol,
outputShape: StructureShape,
outputSymbol: Symbol,
paginationInfo: PaginationInfo,
cursorSymbol: Symbol,
) {
val service = ctx.model.expectShape<ServiceShape>(ctx.settings.service)
val serviceSymbol = ctx.symbolProvider.toSymbol(service)
val outputShape = paginationInfo.output
val outputSymbol = ctx.symbolProvider.toSymbol(outputShape)
val inputSymbol = ctx.symbolProvider.toSymbol(paginationInfo.input)
val cursorMember = ctx.model.getShape(paginationInfo.inputTokenMember.target).get()
val cursorSymbol = ctx.symbolProvider.toSymbol(cursorMember)

val nextMarkerLiteral = paginationInfo.outputTokenMemberPath.joinToString(separator = "?.") {
it.defaultName()
}
Expand All @@ -124,6 +116,12 @@ class PaginatorGenerator : KotlinIntegration {
""".trimIndent()
val docReturn = "@return A [kotlinx.coroutines.flow.Flow] that can collect [${outputSymbol.name}]"

val inputParameter = if (paginationInfo.input.hasAllOptionalMembers) {
writer.format("initialRequest: #1T = #1T { }", inputSymbol)
} else {
writer.format("initialRequest: #T", inputSymbol)
}

writer.write("")
writer
.dokka(
Expand All @@ -135,11 +133,12 @@ class PaginatorGenerator : KotlinIntegration {
)
.addImportReferences(cursorSymbol, SymbolReference.ContextOption.DECLARE)
.withBlock(
"public fun #T.#LPaginated(initialRequest: #T): #T<#T> =",
"#L fun #T.#LPaginated(#L): #T<#T> =",
"",
ctx.settings.api.visibility,
serviceSymbol,
operationShape.defaultName(),
inputSymbol,
inputParameter,
ExternalTypes.KotlinxCoroutines.Flow,
outputSymbol,
) {
Expand Down Expand Up @@ -180,8 +179,9 @@ class PaginatorGenerator : KotlinIntegration {
""".trimMargin(),
)
.withBlock(
"public fun #T.#LPaginated(block: #T.Builder.() -> #T): #T<#T> =",
"#L fun #T.#LPaginated(block: #T.Builder.() -> #T): #T<#T> =",
"",
ctx.settings.api.visibility,
serviceSymbol,
operationShape.defaultName(),
inputSymbol,
Expand All @@ -195,12 +195,15 @@ class PaginatorGenerator : KotlinIntegration {

// Generate a paginator that iterates over the model-specified item
private fun renderItemPaginator(
ctx: CodegenContext,
writer: KotlinWriter,
serviceShape: ServiceShape,
operationShape: OperationShape,
itemDesc: ItemDescriptor,
outputSymbol: Symbol,
) {
val serviceShape = ctx.model.expectShape<ServiceShape>(ctx.settings.service)
val outputShape = ctx.model.expectShape(operationShape.outputShape)
val outputSymbol = ctx.symbolProvider.toSymbol(outputShape)

writer.write("")
writer.dokka(
"""
Expand All @@ -223,8 +226,9 @@ class PaginatorGenerator : KotlinIntegration {
itemDesc.targetMember.defaultName(serviceShape),
)
.withBlock(
"public fun #T<#T>.#L(): #T<#L> =",
"#L fun #T<#T>.#L(): #T<#L> =",
"",
ctx.settings.api.visibility,
ExternalTypes.KotlinxCoroutines.Flow,
outputSymbol,
itemDesc.itemLiteral,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes
import software.amazon.smithy.kotlin.codegen.core.withBlock
import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
import software.amazon.smithy.kotlin.codegen.model.hasAllOptionalMembers
import java.text.DecimalFormat
import java.text.DecimalFormatSymbols

Expand All @@ -34,12 +35,18 @@ private fun KotlinWriter.renderRetryStrategy(wi: WaiterInfo, asValName: String)
internal fun KotlinWriter.renderWaiter(wi: WaiterInfo) {
write("")
wi.waiter.documentation.ifPresent(::dokka)
val inputParameter = if (wi.input.hasAllOptionalMembers) {
format("request: #1T = #1T { }", wi.inputSymbol)
} else {
format("request: #T", wi.inputSymbol)
}
withBlock(
"public suspend fun #T.#L(request: #T): #T<#T> {",
"#L suspend fun #T.#L(#L): #T<#T> {",
"}",
wi.ctx.settings.api.visibility,
wi.serviceSymbol,
wi.methodName,
wi.inputSymbol,
inputParameter,
RuntimeTypes.Core.Retries.Outcome,
wi.outputSymbol,
) {
Expand All @@ -54,7 +61,8 @@ internal fun KotlinWriter.renderWaiter(wi: WaiterInfo) {
write("")
wi.waiter.documentation.ifPresent(this::dokka)
write(
"public suspend fun #T.#L(block: #T.Builder.() -> Unit): #T<#T> =",
"#L suspend fun #T.#L(block: #T.Builder.() -> Unit): #T<#T> =",
wi.ctx.settings.api.visibility,
wi.serviceSymbol,
wi.methodName,
wi.inputSymbol,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class PaginatorGeneratorTest {
* @param initialRequest A [ListFunctionsRequest] to start pagination
* @return A [kotlinx.coroutines.flow.Flow] that can collect [ListFunctionsResponse]
*/
public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest): Flow<ListFunctionsResponse> =
public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest = ListFunctionsRequest { }): Flow<ListFunctionsResponse> =
flow {
var cursor: kotlin.String? = null
var hasNextPage: Boolean = true
Expand Down Expand Up @@ -204,7 +204,7 @@ class PaginatorGeneratorTest {
* @param initialRequest A [ListFunctionsRequest] to start pagination
* @return A [kotlinx.coroutines.flow.Flow] that can collect [ListFunctionsResponse]
*/
public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest): Flow<ListFunctionsResponse> =
public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest = ListFunctionsRequest { }): Flow<ListFunctionsResponse> =
flow {
var cursor: kotlin.String? = null
var hasNextPage: Boolean = true
Expand Down Expand Up @@ -333,7 +333,7 @@ class PaginatorGeneratorTest {
val actual = testManifest.expectFileString("src/main/kotlin/smithy/kotlin/traits/paginators/Paginators.kt")

val expectedCode = """
public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest): Flow<ListFunctionsResponse> =
public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest = ListFunctionsRequest { }): Flow<ListFunctionsResponse> =
flow {
var cursor: kotlin.String? = null
var hasNextPage: Boolean = true
Expand All @@ -352,4 +352,76 @@ class PaginatorGeneratorTest {

actual.shouldContainOnlyOnceWithDiff(expectedCode)
}

@Test
fun testRenderPaginatorWithRequiredInputMembers() {
val testModelNoItem = """
namespace com.test

use aws.protocols#restJson1

service Lambda {
operations: [ListFunctions]
}

@paginated(
inputToken: "Marker",
outputToken: "NextMarker",
pageSize: "MaxItems"
)
@readonly
@http(method: "GET", uri: "/functions", code: 200)
operation ListFunctions {
input: ListFunctionsRequest,
output: ListFunctionsResponse
}

structure ListFunctionsRequest {
@required
@httpQuery("FunctionVersion")
FunctionVersion: String,
@httpQuery("Marker")
Marker: String,
@httpQuery("MasterRegion")
MasterRegion: String,
@httpQuery("MaxItems")
MaxItems: Integer
}

structure ListFunctionsResponse {
Functions: FunctionConfigurationList,
NextMarker: String
}

list FunctionConfigurationList {
member: FunctionConfiguration
}

structure FunctionConfiguration {
FunctionName: String
}
""".toSmithyModel()
val testContextNoItem = testModelNoItem.newTestContext("Lambda", "com.test")

val codegenContextNoItem = object : CodegenContext {
override val model: Model = testContextNoItem.generationCtx.model
override val symbolProvider: SymbolProvider = testContextNoItem.generationCtx.symbolProvider
override val settings: KotlinSettings = testContextNoItem.generationCtx.settings
override val protocolGenerator: ProtocolGenerator = testContextNoItem.generator
override val integrations: List<KotlinIntegration> = testContextNoItem.generationCtx.integrations
}

val unit = PaginatorGenerator()
unit.writeAdditionalFiles(codegenContextNoItem, testContextNoItem.generationCtx.delegator)

testContextNoItem.generationCtx.delegator.flushWriters()
val testManifest = testContextNoItem.generationCtx.delegator.fileManifest as MockManifest
val actual = testManifest.expectFileString("src/main/kotlin/com/test/paginators/Paginators.kt")

val expected = """
public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest): Flow<ListFunctionsResponse> =
""".trimIndent()

actual.shouldContainOnlyOnceWithDiff(expected)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ class ServiceClientGeneratorTest {
fun `it renders signatures correctly`() {
val expectedSignatures = listOf(
"public suspend fun getFoo(input: GetFooRequest): GetFooResponse",
"public suspend fun getFooNoRequired(input: GetFooNoRequiredRequest = GetFooNoRequiredRequest {}): GetFooNoRequiredResponse",
"public suspend fun getFooNoRequired(input: GetFooNoRequiredRequest = GetFooNoRequiredRequest { }): GetFooNoRequiredResponse",
"public suspend fun getFooSomeRequired(input: GetFooSomeRequiredRequest): GetFooSomeRequiredResponse",
"public suspend fun getFooNoInput(input: GetFooNoInputRequest = GetFooNoInputRequest {}): GetFooNoInputResponse",
"public suspend fun getFooNoInput(input: GetFooNoInputRequest = GetFooNoInputRequest { }): GetFooNoInputResponse",
"public suspend fun getFooNoOutput(input: GetFooNoOutputRequest): GetFooNoOutputResponse",
"public suspend fun getFooStreamingInput(input: GetFooStreamingInputRequest): GetFooStreamingInputResponse",
"public suspend fun <T> getFooStreamingOutput(input: GetFooStreamingOutputRequest, block: suspend (GetFooStreamingOutputResponse) -> T): T",
"public suspend fun <T> getFooStreamingOutputNoInput(input: GetFooStreamingOutputNoInputRequest = GetFooStreamingOutputNoInputRequest {}, block: suspend (GetFooStreamingOutputNoInputResponse) -> T): T",
"public suspend fun <T> getFooStreamingOutputNoInput(input: GetFooStreamingOutputNoInputRequest = GetFooStreamingOutputNoInputRequest { }, block: suspend (GetFooStreamingOutputNoInputResponse) -> T): T",
"public suspend fun getFooStreamingInputNoOutput(input: GetFooStreamingInputNoOutputRequest): GetFooStreamingInputNoOutputResponse",
)
expectedSignatures.forEach {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package software.amazon.smithy.kotlin.codegen.rendering.waiters

import io.kotest.matchers.string.shouldContain
import io.kotest.matchers.string.shouldContainOnlyOnce
import software.amazon.smithy.build.MockManifest
import software.amazon.smithy.codegen.core.SymbolProvider
Expand Down Expand Up @@ -42,7 +43,7 @@ class ServiceWaitersGeneratorTest {
/**
* Wait until a foo exists
*/
public suspend fun TestClient.waitUntilFooExists(request: DescribeFooRequest): Outcome<DescribeFooResponse> {
public suspend fun TestClient.waitUntilFooExists(request: DescribeFooRequest = DescribeFooRequest { }): Outcome<DescribeFooResponse> {
""".trimIndent()
val methodFooter = """
val policy = AcceptorRetryPolicy(request, acceptors)
Expand All @@ -52,6 +53,17 @@ class ServiceWaitersGeneratorTest {
generated.shouldContain(methodHeader, methodFooter)
}

@Test
fun testWaiterSignatureWithRequiredInput() {
val methodHeader = """
/**
* Wait until a foo exists with required input
*/
public suspend fun TestClient.waitUntilFooRequiredExists(request: DescribeFooRequiredRequest): Outcome<DescribeFooRequiredResponse> {
""".trimIndent()
generated.shouldContainOnlyOnceWithDiff(methodHeader)
}

@Test
fun testConvenienceWaiterMethod() {
val expected = """
Expand Down Expand Up @@ -89,7 +101,7 @@ class ServiceWaitersGeneratorTest {
}
}
""".formatForTest()
generated.shouldContainOnlyOnce(expected)
generated.shouldContain(expected)
}

private fun generateService(modelResourceName: String): String {
Expand Down
Loading