-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/awslabs/smithy-kotlin into …
…jmes-path-functions
- Loading branch information
Showing
9 changed files
with
307 additions
and
41 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
30 changes: 30 additions & 0 deletions
30
runtime/runtime-core/common/test/aws/smithy/kotlin/runtime/content/ByteStreamFactory.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
/* | ||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package aws.smithy.kotlin.runtime.content | ||
|
||
import aws.smithy.kotlin.runtime.io.SdkByteReadChannel | ||
import aws.smithy.kotlin.runtime.io.SdkSource | ||
import aws.smithy.kotlin.runtime.io.source | ||
|
||
fun interface ByteStreamFactory { | ||
fun byteStream(input: ByteArray): ByteStream | ||
companion object { | ||
val BYTE_ARRAY: ByteStreamFactory = ByteStreamFactory { input -> ByteStream.fromBytes(input) } | ||
|
||
val SDK_SOURCE: ByteStreamFactory = ByteStreamFactory { input -> | ||
object : ByteStream.SourceStream() { | ||
override fun readFrom(): SdkSource = input.source() | ||
override val contentLength: Long = input.size.toLong() | ||
} | ||
} | ||
|
||
val SDK_CHANNEL: ByteStreamFactory = ByteStreamFactory { input -> | ||
object : ByteStream.ChannelStream() { | ||
override fun readFrom(): SdkByteReadChannel = SdkByteReadChannel(input) | ||
override val contentLength: Long = input.size.toLong() | ||
} | ||
} | ||
} | ||
} |
168 changes: 168 additions & 0 deletions
168
runtime/runtime-core/common/test/aws/smithy/kotlin/runtime/content/ByteStreamFlowTest.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
/* | ||
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package aws.smithy.kotlin.runtime.content | ||
|
||
import io.kotest.matchers.string.shouldContain | ||
import kotlinx.coroutines.* | ||
import kotlinx.coroutines.channels.Channel | ||
import kotlinx.coroutines.flow.* | ||
import kotlinx.coroutines.test.runTest | ||
import java.lang.RuntimeException | ||
import kotlin.test.* | ||
|
||
class ByteStreamBufferFlowTest : ByteStreamFlowTest(ByteStreamFactory.BYTE_ARRAY) | ||
class ByteStreamSourceStreamFlowTest : ByteStreamFlowTest(ByteStreamFactory.SDK_SOURCE) | ||
class ByteStreamChannelSourceFlowTest : ByteStreamFlowTest(ByteStreamFactory.SDK_CHANNEL) | ||
|
||
abstract class ByteStreamFlowTest( | ||
private val factory: ByteStreamFactory, | ||
) { | ||
@Test | ||
fun testToFlowWithSizeHint() = runTest { | ||
val data = "a korf is a tiger".repeat(1024).encodeToByteArray() | ||
val bufferSize = 8182 * 2 | ||
val byteStream = factory.byteStream(data) | ||
val flow = byteStream.toFlow(bufferSize.toLong()) | ||
val buffers = mutableListOf<ByteArray>() | ||
flow.toList(buffers) | ||
|
||
val totalCollected = buffers.sumOf { it.size } | ||
assertEquals(data.size, totalCollected) | ||
|
||
if (byteStream is ByteStream.Buffer) { | ||
assertEquals(1, buffers.size) | ||
assertContentEquals(data, buffers.first()) | ||
} else { | ||
val expectedFullBuffers = data.size / bufferSize | ||
for (i in 0 until expectedFullBuffers) { | ||
val b = buffers[i] | ||
val expected = data.sliceArray((i * bufferSize)until (i * bufferSize + bufferSize)) | ||
assertEquals(bufferSize, b.size) | ||
assertContentEquals(expected, b) | ||
} | ||
|
||
val last = buffers.last() | ||
val expected = data.sliceArray(((buffers.size - 1) * bufferSize) until data.size) | ||
assertContentEquals(expected, last) | ||
} | ||
} | ||
|
||
class FlowToByteStreamTest { | ||
private fun testByteArray(size: Int): ByteArray = ByteArray(size) { i -> i.toByte() } | ||
|
||
val data = listOf( | ||
testByteArray(576), | ||
testByteArray(9172), | ||
testByteArray(3278), | ||
) | ||
|
||
@Test | ||
fun testFlowToByteStreamReadAll() = runTest { | ||
val flow = data.asFlow() | ||
val scope = CoroutineScope(coroutineContext) | ||
val byteStream = flow.toByteStream(scope) | ||
|
||
assertNull(byteStream.contentLength) | ||
|
||
val actual = byteStream.toByteArray() | ||
val expected = data.reduce { acc, bytes -> acc + bytes } | ||
assertContentEquals(expected, actual) | ||
} | ||
|
||
@Test | ||
fun testContentLengthOverflow() = runTest { | ||
val advertisedContentLength = 1024L | ||
testInvalidContentLength(advertisedContentLength, "9748 bytes collected from flow exceeds reported content length of 1024") | ||
} | ||
|
||
@Test | ||
fun testContentLengthUnderflow() = runTest { | ||
val advertisedContentLength = data.sumOf { it.size } + 100L | ||
testInvalidContentLength(advertisedContentLength, "expected 13126 bytes collected from flow, got 13026") | ||
} | ||
|
||
private suspend fun testInvalidContentLength(advertisedContentLength: Long, expectedMessage: String) { | ||
val job = Job() | ||
val uncaughtExceptions = mutableListOf<Throwable>() | ||
val exHandler = CoroutineExceptionHandler { _, throwable -> uncaughtExceptions.add(throwable) } | ||
val scope = CoroutineScope(job + exHandler) | ||
val byteStream = data | ||
.asFlow() | ||
.toByteStream(scope, advertisedContentLength) | ||
|
||
assertEquals(advertisedContentLength, byteStream.contentLength) | ||
|
||
val ex = assertFailsWith<IllegalStateException> { | ||
byteStream.toByteArray() | ||
} | ||
|
||
ex.message?.shouldContain(expectedMessage) | ||
assertTrue(job.isCancelled) | ||
job.join() | ||
|
||
assertEquals(1, uncaughtExceptions.size) | ||
} | ||
|
||
@Test | ||
fun testScopeCancellation() = runTest { | ||
// cancelling the scope should close/cancel the channel | ||
val waiter = Channel<Unit>(1) | ||
val flow = flow { | ||
emit(testByteArray(128)) | ||
emit(testByteArray(277)) | ||
waiter.receive() | ||
emit(testByteArray(97)) | ||
} | ||
|
||
val job = Job() | ||
val scope = CoroutineScope(job) | ||
val byteStream = flow.toByteStream(scope) | ||
assertIs<ByteStream.ChannelStream>(byteStream) | ||
assertNull(byteStream.contentLength) | ||
yield() | ||
|
||
job.cancel("scope cancelled") | ||
waiter.send(Unit) | ||
job.join() | ||
|
||
val ch = byteStream.readFrom() | ||
assertTrue(ch.isClosedForRead) | ||
assertTrue(ch.isClosedForWrite) | ||
assertIs<CancellationException>(ch.closedCause) | ||
ch.closedCause?.message.shouldContain("scope cancelled") | ||
} | ||
|
||
@Test | ||
fun testChannelCancellation() = runTest { | ||
// cancelling the channel should cancel the scope (via write failing) | ||
val waiter = Channel<Unit>(1) | ||
val flow = flow { | ||
emit(testByteArray(128)) | ||
emit(testByteArray(277)) | ||
waiter.receive() | ||
emit(testByteArray(97)) | ||
} | ||
|
||
val uncaughtExceptions = mutableListOf<Throwable>() | ||
val exHandler = CoroutineExceptionHandler { _, throwable -> uncaughtExceptions.add(throwable) } | ||
val job = Job() | ||
val scope = CoroutineScope(job + exHandler) | ||
val byteStream = flow.toByteStream(scope) | ||
assertIs<ByteStream.ChannelStream>(byteStream) | ||
|
||
val ch = byteStream.readFrom() | ||
val cause = RuntimeException("chan cancelled") | ||
ch.cancel(cause) | ||
|
||
// unblock the flow | ||
waiter.send(Unit) | ||
|
||
job.join() | ||
assertTrue(job.isCancelled) | ||
assertEquals(1, uncaughtExceptions.size) | ||
uncaughtExceptions.first().message.shouldContain("chan cancelled") | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters