diff --git a/IntegrationTests/Services/AWSKinesisIntegrationTests/KinesisTests.swift b/IntegrationTests/Services/AWSKinesisIntegrationTests/KinesisTests.swift index b099a58eb24..6f151dc75e1 100644 --- a/IntegrationTests/Services/AWSKinesisIntegrationTests/KinesisTests.swift +++ b/IntegrationTests/Services/AWSKinesisIntegrationTests/KinesisTests.swift @@ -66,6 +66,10 @@ class KinesisTests: XCTestCase { let input = SubscribeToShardInput(consumerARN: consumerARN, shardId: shard?.shardId, startingPosition: KinesisClientTypes.StartingPosition(sequenceNumber: shard?.sequenceNumberRange?.startingSequenceNumber, type: .atSequenceNumber)) let output = try await client.subscribeToShard(input: input) + if let initialResponse = output.dayaffe?.value { + assert(initialResponse.isEmpty) + } + // Monitor the shard event stream for try await event in output.eventStream! { switch event { diff --git a/Sources/Core/AWSClientRuntime/EventStream/AWSMessageDecoder.swift b/Sources/Core/AWSClientRuntime/EventStream/AWSMessageDecoder.swift index e484e571007..49165f51ff6 100644 --- a/Sources/Core/AWSClientRuntime/EventStream/AWSMessageDecoder.swift +++ b/Sources/Core/AWSClientRuntime/EventStream/AWSMessageDecoder.swift @@ -17,6 +17,9 @@ extension AWSEventStream { private var decoder: EventStreamMessageDecoder? private var messageBuffer: [EventStream.Message] = [] private var error: Error? + private var initialMessage: Data = Data() + private var onInitialResponseReceived: ((Data?) -> Void)? + private var didProcessInitialMessage = false private var decodedPayload = Data() private var decodededHeaders: [EventStreamHeader] = [] @@ -44,8 +47,18 @@ extension AWSEventStream { self.logger.debug("onComplete") let message = EventStream.Message(headers: self.decodededHeaders.toHeaders(), payload: self.decodedPayload) - self.messageBuffer.append(message) + if (message.headers.contains(EventStream.Header(name: ":event-type", value: .string("initial-response")))) { + self.initialMessage = message.payload + self.onInitialResponseReceived?(self.initialMessage) + self.didProcessInitialMessage = true + } else { + self.messageBuffer.append(message) + if !self.didProcessInitialMessage { + self.onInitialResponseReceived?(nil) // Signal that initial-response will never come. + self.didProcessInitialMessage = true + } + } // This could be end of the stream, hence reset the state self.decodedPayload = Data() self.decodededHeaders = [] @@ -88,6 +101,22 @@ extension AWSEventStream { return message } + public func awaitInitialResponse() async -> Data? { + return await withCheckedContinuation { continuation in + retrieveInitialResponse { data in + continuation.resume(returning: data) + } + } + } + + public func retrieveInitialResponse(completion: @escaping (Data?) -> Void) { + if self.didProcessInitialMessage { + completion(initialMessage) // Could be nil or populated. + } else { + self.onInitialResponseReceived = completion + } + } + /// Throws an error if one has occurred. /// This should be called before any other methods to make sure /// that the decoder is in a valid state.