Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/awslabs/smithy-kotlin into …
Browse files Browse the repository at this point in the history
…jmes-path-functions
  • Loading branch information
0marperez committed Sep 13, 2023
2 parents 3b996e7 + c236d00 commit efc220e
Show file tree
Hide file tree
Showing 9 changed files with 307 additions and 41 deletions.
5 changes: 0 additions & 5 deletions .changes/d477286c-f799-426b-947c-7cc6982fbcfe.json

This file was deleted.

8 changes: 0 additions & 8 deletions .changes/d47756d3-1127-4ed0-a71f-44ca2daebf9a.json

This file was deleted.

9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Changelog

## [0.27.3] - 09/08/2023

### Features
* [#612](https://github.com/awslabs/aws-sdk-kotlin/issues/612) Add conversions to and from `Flow<ByteArray>` and `ByteStream`
* [#617](https://github.com/awslabs/aws-sdk-kotlin/issues/617) Add conversion to InputStream from ByteStream

### Miscellaneous
* Expose SDK ID in service companion object section writer.

## [0.27.1] - 08/31/2023

### Fixes
Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ kotlin.mpp.stability.nowarn=true
kotlin.native.ignoreDisabledTargets=true

# SDK
sdkVersion=0.27.3-SNAPSHOT
sdkVersion=0.27.4-SNAPSHOT

# kotlin
kotlinVersion=1.8.22
4 changes: 4 additions & 0 deletions runtime/runtime-core/api/runtime-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ public final class aws/smithy/kotlin/runtime/content/ByteStreamKt {
public static final fun cancel (Laws/smithy/kotlin/runtime/content/ByteStream;)V
public static final fun decodeToString (Laws/smithy/kotlin/runtime/content/ByteStream;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static final fun toByteArray (Laws/smithy/kotlin/runtime/content/ByteStream;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static final fun toByteStream (Lkotlinx/coroutines/flow/Flow;Lkotlinx/coroutines/CoroutineScope;Ljava/lang/Long;)Laws/smithy/kotlin/runtime/content/ByteStream;
public static synthetic fun toByteStream$default (Lkotlinx/coroutines/flow/Flow;Lkotlinx/coroutines/CoroutineScope;Ljava/lang/Long;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/content/ByteStream;
public static final fun toFlow (Laws/smithy/kotlin/runtime/content/ByteStream;J)Lkotlinx/coroutines/flow/Flow;
public static synthetic fun toFlow$default (Laws/smithy/kotlin/runtime/content/ByteStream;JILjava/lang/Object;)Lkotlinx/coroutines/flow/Flow;
}

public abstract class aws/smithy/kotlin/runtime/content/Document {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
*/
package aws.smithy.kotlin.runtime.content

import aws.smithy.kotlin.runtime.io.SdkByteReadChannel
import aws.smithy.kotlin.runtime.io.SdkSource
import aws.smithy.kotlin.runtime.io.readToBuffer
import aws.smithy.kotlin.runtime.io.readToByteArray
import aws.smithy.kotlin.runtime.io.*
import aws.smithy.kotlin.runtime.io.internal.SdkDispatchers
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.launch

/**
* Represents an abstract read-only stream of bytes
Expand Down Expand Up @@ -106,3 +107,92 @@ public fun ByteStream.cancel() {
is ByteStream.SourceStream -> stream.readFrom().close()
}
}

/**
* Return a [Flow] that consumes the underlying [ByteStream] when collected.
*
* @param bufferSize the size of the buffers to emit from the flow. All buffers emitted
* will be of this size except for the last one which may be less than the requested buffer size.
* This parameter has no effect for the [ByteStream.Buffer] variant. The emitted [ByteArray]
* will be whatever size the in-memory buffer already is in that case.
*/
public fun ByteStream.toFlow(bufferSize: Long = 8192): Flow<ByteArray> = when (this) {
is ByteStream.Buffer -> flowOf(bytes())
is ByteStream.ChannelStream -> readFrom().toFlow(bufferSize)
is ByteStream.SourceStream -> readFrom().toFlow(bufferSize).flowOn(SdkDispatchers.IO)
}

/**
* Create a [ByteStream] from a [Flow] of byte arrays.
*
* @param scope the [CoroutineScope] to use for launching a coroutine to do the collection in.
* @param contentLength the overall content length of the [Flow] (if known). If set this will be
* used as [ByteStream.contentLength]. Some APIs require a known `Content-Length` header and
* since the total size of the flow can't be calculated without collecting it callers should set this
* parameter appropriately in those cases.
*/
public fun Flow<ByteArray>.toByteStream(
scope: CoroutineScope,
contentLength: Long? = null,
): ByteStream {
val ch = SdkByteChannel(true)
var totalWritten = 0L
val job = scope.launch {
collect { bytes ->
ch.write(bytes)
totalWritten += bytes.size

check(contentLength == null || totalWritten <= contentLength) {
"$totalWritten bytes collected from flow exceeds reported content length of $contentLength"
}
}

check(contentLength == null || totalWritten == contentLength) {
"expected $contentLength bytes collected from flow, got $totalWritten"
}

ch.close()
}

job.invokeOnCompletion { cause ->
ch.close(cause)
}

return object : ByteStream.ChannelStream() {
override val contentLength: Long? = contentLength
override val isOneShot: Boolean = true
override fun readFrom(): SdkByteReadChannel = ch
}
}

private fun SdkByteReadChannel.toFlow(bufferSize: Long): Flow<ByteArray> = flow {
val chan = this@toFlow
val sink = SdkBuffer()
while (!chan.isClosedForRead) {
val rc = chan.read(sink, bufferSize)
if (rc == -1L) break
if (sink.size >= bufferSize) {
val bytes = sink.readByteArray(bufferSize)
emit(bytes)
}
}
if (sink.size > 0L) {
emit(sink.readByteArray())
}
}

private fun SdkSource.toFlow(bufferSize: Long): Flow<ByteArray> = flow {
val source = this@toFlow
val sink = SdkBuffer()
while (true) {
val rc = source.read(sink, bufferSize)
if (rc == -1L) break
if (sink.size >= bufferSize) {
val bytes = sink.readByteArray(bufferSize)
emit(bytes)
}
}
if (sink.size > 0L) {
emit(sink.readByteArray())
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package aws.smithy.kotlin.runtime.content

import aws.smithy.kotlin.runtime.io.SdkByteReadChannel
import aws.smithy.kotlin.runtime.io.SdkSource
import aws.smithy.kotlin.runtime.io.source

fun interface ByteStreamFactory {
fun byteStream(input: ByteArray): ByteStream
companion object {
val BYTE_ARRAY: ByteStreamFactory = ByteStreamFactory { input -> ByteStream.fromBytes(input) }

val SDK_SOURCE: ByteStreamFactory = ByteStreamFactory { input ->
object : ByteStream.SourceStream() {
override fun readFrom(): SdkSource = input.source()
override val contentLength: Long = input.size.toLong()
}
}

val SDK_CHANNEL: ByteStreamFactory = ByteStreamFactory { input ->
object : ByteStream.ChannelStream() {
override fun readFrom(): SdkByteReadChannel = SdkByteReadChannel(input)
override val contentLength: Long = input.size.toLong()
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package aws.smithy.kotlin.runtime.content

import io.kotest.matchers.string.shouldContain
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.test.runTest
import java.lang.RuntimeException
import kotlin.test.*

class ByteStreamBufferFlowTest : ByteStreamFlowTest(ByteStreamFactory.BYTE_ARRAY)
class ByteStreamSourceStreamFlowTest : ByteStreamFlowTest(ByteStreamFactory.SDK_SOURCE)
class ByteStreamChannelSourceFlowTest : ByteStreamFlowTest(ByteStreamFactory.SDK_CHANNEL)

abstract class ByteStreamFlowTest(
private val factory: ByteStreamFactory,
) {
@Test
fun testToFlowWithSizeHint() = runTest {
val data = "a korf is a tiger".repeat(1024).encodeToByteArray()
val bufferSize = 8182 * 2
val byteStream = factory.byteStream(data)
val flow = byteStream.toFlow(bufferSize.toLong())
val buffers = mutableListOf<ByteArray>()
flow.toList(buffers)

val totalCollected = buffers.sumOf { it.size }
assertEquals(data.size, totalCollected)

if (byteStream is ByteStream.Buffer) {
assertEquals(1, buffers.size)
assertContentEquals(data, buffers.first())
} else {
val expectedFullBuffers = data.size / bufferSize
for (i in 0 until expectedFullBuffers) {
val b = buffers[i]
val expected = data.sliceArray((i * bufferSize)until (i * bufferSize + bufferSize))
assertEquals(bufferSize, b.size)
assertContentEquals(expected, b)
}

val last = buffers.last()
val expected = data.sliceArray(((buffers.size - 1) * bufferSize) until data.size)
assertContentEquals(expected, last)
}
}

class FlowToByteStreamTest {
private fun testByteArray(size: Int): ByteArray = ByteArray(size) { i -> i.toByte() }

val data = listOf(
testByteArray(576),
testByteArray(9172),
testByteArray(3278),
)

@Test
fun testFlowToByteStreamReadAll() = runTest {
val flow = data.asFlow()
val scope = CoroutineScope(coroutineContext)
val byteStream = flow.toByteStream(scope)

assertNull(byteStream.contentLength)

val actual = byteStream.toByteArray()
val expected = data.reduce { acc, bytes -> acc + bytes }
assertContentEquals(expected, actual)
}

@Test
fun testContentLengthOverflow() = runTest {
val advertisedContentLength = 1024L
testInvalidContentLength(advertisedContentLength, "9748 bytes collected from flow exceeds reported content length of 1024")
}

@Test
fun testContentLengthUnderflow() = runTest {
val advertisedContentLength = data.sumOf { it.size } + 100L
testInvalidContentLength(advertisedContentLength, "expected 13126 bytes collected from flow, got 13026")
}

private suspend fun testInvalidContentLength(advertisedContentLength: Long, expectedMessage: String) {
val job = Job()
val uncaughtExceptions = mutableListOf<Throwable>()
val exHandler = CoroutineExceptionHandler { _, throwable -> uncaughtExceptions.add(throwable) }
val scope = CoroutineScope(job + exHandler)
val byteStream = data
.asFlow()
.toByteStream(scope, advertisedContentLength)

assertEquals(advertisedContentLength, byteStream.contentLength)

val ex = assertFailsWith<IllegalStateException> {
byteStream.toByteArray()
}

ex.message?.shouldContain(expectedMessage)
assertTrue(job.isCancelled)
job.join()

assertEquals(1, uncaughtExceptions.size)
}

@Test
fun testScopeCancellation() = runTest {
// cancelling the scope should close/cancel the channel
val waiter = Channel<Unit>(1)
val flow = flow {
emit(testByteArray(128))
emit(testByteArray(277))
waiter.receive()
emit(testByteArray(97))
}

val job = Job()
val scope = CoroutineScope(job)
val byteStream = flow.toByteStream(scope)
assertIs<ByteStream.ChannelStream>(byteStream)
assertNull(byteStream.contentLength)
yield()

job.cancel("scope cancelled")
waiter.send(Unit)
job.join()

val ch = byteStream.readFrom()
assertTrue(ch.isClosedForRead)
assertTrue(ch.isClosedForWrite)
assertIs<CancellationException>(ch.closedCause)
ch.closedCause?.message.shouldContain("scope cancelled")
}

@Test
fun testChannelCancellation() = runTest {
// cancelling the channel should cancel the scope (via write failing)
val waiter = Channel<Unit>(1)
val flow = flow {
emit(testByteArray(128))
emit(testByteArray(277))
waiter.receive()
emit(testByteArray(97))
}

val uncaughtExceptions = mutableListOf<Throwable>()
val exHandler = CoroutineExceptionHandler { _, throwable -> uncaughtExceptions.add(throwable) }
val job = Job()
val scope = CoroutineScope(job + exHandler)
val byteStream = flow.toByteStream(scope)
assertIs<ByteStream.ChannelStream>(byteStream)

val ch = byteStream.readFrom()
val cause = RuntimeException("chan cancelled")
ch.cancel(cause)

// unblock the flow
waiter.send(Unit)

job.join()
assertTrue(job.isCancelled)
assertEquals(1, uncaughtExceptions.size)
uncaughtExceptions.first().message.shouldContain("chan cancelled")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,12 @@
*/
package aws.smithy.kotlin.runtime.content

import aws.smithy.kotlin.runtime.io.SdkByteReadChannel
import aws.smithy.kotlin.runtime.io.SdkSource
import aws.smithy.kotlin.runtime.io.source
import java.io.InputStream
import kotlin.test.Test
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals

fun interface ByteStreamFactory {
fun inputStream(input: ByteArray): InputStream
companion object {
val BYTE_ARRAY: ByteStreamFactory = ByteStreamFactory { input -> ByteStream.fromBytes(input).toInputStream() }

val SDK_SOURCE: ByteStreamFactory = ByteStreamFactory { input ->
object : ByteStream.SourceStream() {
override fun readFrom(): SdkSource = input.source()
override val contentLength: Long = input.size.toLong()
}.toInputStream()
}

val SDK_CHANNEL: ByteStreamFactory = ByteStreamFactory { input ->
object : ByteStream.ChannelStream() {
override fun readFrom(): SdkByteReadChannel = SdkByteReadChannel(input)
override val contentLength: Long = input.size.toLong()
}.toInputStream()
}
}
}
fun ByteStreamFactory.inputStream(input: ByteArray): InputStream = byteStream(input).toInputStream()

class ByteStreamBufferInputStreamTest : ByteStreamInputStreamTest(ByteStreamFactory.BYTE_ARRAY)
class ByteStreamSourceStreamInputStreamTest : ByteStreamInputStreamTest(ByteStreamFactory.SDK_SOURCE)
Expand Down

0 comments on commit efc220e

Please sign in to comment.