Skip to content

Commit

Permalink
Handle EmptyQueryResponse (#500)
Browse files Browse the repository at this point in the history
  • Loading branch information
MahdiBM authored Aug 21, 2024
1 parent cd5318a commit 3de37e6
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ struct ExtendedQueryStateMachine {
case parameterDescriptionReceived(ExtendedQueryContext)
case rowDescriptionReceived(ExtendedQueryContext, [RowDescription.Column])
case noDataMessageReceived(ExtendedQueryContext)

case emptyQueryResponseReceived

/// A state that is used if a noData message was received before. If a row description was received `bufferingRows` is
/// used after receiving a `bindComplete` message
case bindCompleteReceived(ExtendedQueryContext)
Expand Down Expand Up @@ -122,7 +123,7 @@ struct ExtendedQueryStateMachine {
return .forwardStreamError(.queryCancelled, read: true)
}

case .commandComplete, .error, .drain:
case .commandComplete, .emptyQueryResponseReceived, .error, .drain:
// the stream has already finished.
return .wait

Expand Down Expand Up @@ -229,6 +230,7 @@ struct ExtendedQueryStateMachine {
.messagesSent,
.parseCompleteReceived,
.parameterDescriptionReceived,
.emptyQueryResponseReceived,
.bindCompleteReceived,
.streaming,
.drain,
Expand Down Expand Up @@ -268,6 +270,7 @@ struct ExtendedQueryStateMachine {
.parseCompleteReceived,
.parameterDescriptionReceived,
.noDataMessageReceived,
.emptyQueryResponseReceived,
.rowDescriptionReceived,
.bindCompleteReceived,
.commandComplete,
Expand All @@ -285,7 +288,7 @@ struct ExtendedQueryStateMachine {
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise):
return self.avoidingStateMachineCoW { state -> Action in
state = .commandComplete(commandTag: commandTag)
let result = QueryResult(value: .noRows(commandTag), logger: context.logger)
let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger)
return .succeedQuery(eventLoopPromise, with: result)
}

Expand All @@ -309,6 +312,7 @@ struct ExtendedQueryStateMachine {
.parseCompleteReceived,
.parameterDescriptionReceived,
.noDataMessageReceived,
.emptyQueryResponseReceived,
.rowDescriptionReceived,
.commandComplete,
.error:
Expand All @@ -319,7 +323,22 @@ struct ExtendedQueryStateMachine {
}

mutating func emptyQueryResponseReceived() -> Action {
preconditionFailure("Unimplemented")
guard case .bindCompleteReceived(let queryContext) = self.state else {
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
}

switch queryContext.query {
case .unnamed(_, let eventLoopPromise),
.executeStatement(_, let eventLoopPromise):
return self.avoidingStateMachineCoW { state -> Action in
state = .emptyQueryResponseReceived
let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger)
return .succeedQuery(eventLoopPromise, with: result)
}

case .prepareStatement(_, _, _, _):
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
}
}

mutating func errorReceived(_ errorMessage: PostgresBackendMessage.ErrorResponse) -> Action {
Expand All @@ -336,7 +355,7 @@ struct ExtendedQueryStateMachine {
return self.setAndFireError(error)
case .streaming, .drain:
return self.setAndFireError(error)
case .commandComplete:
case .commandComplete, .emptyQueryResponseReceived:
return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage)))
case .error:
preconditionFailure("""
Expand Down Expand Up @@ -382,6 +401,7 @@ struct ExtendedQueryStateMachine {
.parseCompleteReceived,
.parameterDescriptionReceived,
.noDataMessageReceived,
.emptyQueryResponseReceived,
.rowDescriptionReceived,
.bindCompleteReceived:
preconditionFailure("Requested to consume next row without anything going on.")
Expand All @@ -405,6 +425,7 @@ struct ExtendedQueryStateMachine {
.parseCompleteReceived,
.parameterDescriptionReceived,
.noDataMessageReceived,
.emptyQueryResponseReceived,
.rowDescriptionReceived,
.bindCompleteReceived:
return .wait
Expand Down Expand Up @@ -449,6 +470,7 @@ struct ExtendedQueryStateMachine {
}
case .initialized,
.commandComplete,
.emptyQueryResponseReceived,
.drain,
.error:
// we already have the complete stream received, now we are waiting for a
Expand Down Expand Up @@ -495,7 +517,7 @@ struct ExtendedQueryStateMachine {
return .forwardStreamError(error, read: true)
}

case .commandComplete, .error:
case .commandComplete, .emptyQueryResponseReceived, .error:
preconditionFailure("""
This state must not be reached. If the query `.isComplete`, the
ConnectionStateMachine must not send any further events to the substate machine.
Expand All @@ -507,7 +529,7 @@ struct ExtendedQueryStateMachine {

var isComplete: Bool {
switch self.state {
case .commandComplete, .error:
case .commandComplete, .emptyQueryResponseReceived, .error:
return true

case .noDataMessageReceived(let context), .rowDescriptionReceived(let context, _):
Expand Down
66 changes: 38 additions & 28 deletions Sources/PostgresNIO/New/PSQLRowStream.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import Logging

struct QueryResult {
enum Value: Equatable {
case noRows(String)
case noRows(PSQLRowStream.StatementSummary)
case rowDescription([RowDescription.Column])
}

Expand All @@ -16,25 +16,30 @@ struct QueryResult {
final class PSQLRowStream: @unchecked Sendable {
private typealias AsyncSequenceSource = NIOThrowingAsyncSequenceProducer<DataRow, Error, AdaptiveRowBuffer, PSQLRowStream>.Source

enum StatementSummary: Equatable {
case tag(String)
case emptyResponse
}

enum Source {
case stream([RowDescription.Column], PSQLRowsDataSource)
case noRows(Result<String, Error>)
case noRows(Result<StatementSummary, Error>)
}

let eventLoop: EventLoop
let logger: Logger

private enum BufferState {
case streaming(buffer: CircularBuffer<DataRow>, dataSource: PSQLRowsDataSource)
case finished(buffer: CircularBuffer<DataRow>, commandTag: String)
case finished(buffer: CircularBuffer<DataRow>, summary: StatementSummary)
case failure(Error)
}

private enum DownstreamState {
case waitingForConsumer(BufferState)
case iteratingRows(onRow: (PostgresRow) throws -> (), EventLoopPromise<Void>, PSQLRowsDataSource)
case waitingForAll([PostgresRow], EventLoopPromise<[PostgresRow]>, PSQLRowsDataSource)
case consumed(Result<String, Error>)
case consumed(Result<StatementSummary, Error>)
case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource, onFinish: @Sendable () -> ())
}

Expand All @@ -52,9 +57,9 @@ final class PSQLRowStream: @unchecked Sendable {
case .stream(let rowDescription, let dataSource):
self.rowDescription = rowDescription
bufferState = .streaming(buffer: .init(), dataSource: dataSource)
case .noRows(.success(let commandTag)):
case .noRows(.success(let summary)):
self.rowDescription = []
bufferState = .finished(buffer: .init(), commandTag: commandTag)
bufferState = .finished(buffer: .init(), summary: summary)
case .noRows(.failure(let error)):
self.rowDescription = []
bufferState = .failure(error)
Expand Down Expand Up @@ -98,12 +103,12 @@ final class PSQLRowStream: @unchecked Sendable {
self.downstreamState = .asyncSequence(source, dataSource, onFinish: onFinish)
self.executeActionBasedOnYieldResult(yieldResult, source: dataSource)

case .finished(let buffer, let commandTag):
case .finished(let buffer, let summary):
_ = source.yield(contentsOf: buffer)
source.finish()
onFinish()
self.downstreamState = .consumed(.success(commandTag))
self.downstreamState = .consumed(.success(summary))

case .failure(let error):
source.finish(error)
self.downstreamState = .consumed(.failure(error))
Expand Down Expand Up @@ -190,12 +195,12 @@ final class PSQLRowStream: @unchecked Sendable {
dataSource.request(for: self)
return promise.futureResult

case .finished(let buffer, let commandTag):
case .finished(let buffer, let summary):
let rows = buffer.map {
PostgresRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription)
}

self.downstreamState = .consumed(.success(commandTag))
self.downstreamState = .consumed(.success(summary))
return self.eventLoop.makeSucceededFuture(rows)

case .failure(let error):
Expand Down Expand Up @@ -247,8 +252,8 @@ final class PSQLRowStream: @unchecked Sendable {
}

return promise.futureResult
case .finished(let buffer, let commandTag):

case .finished(let buffer, let summary):
do {
for data in buffer {
let row = PostgresRow(
Expand All @@ -259,7 +264,7 @@ final class PSQLRowStream: @unchecked Sendable {
try onRow(row)
}

self.downstreamState = .consumed(.success(commandTag))
self.downstreamState = .consumed(.success(summary))
return self.eventLoop.makeSucceededVoidFuture()
} catch {
self.downstreamState = .consumed(.failure(error))
Expand Down Expand Up @@ -292,7 +297,7 @@ final class PSQLRowStream: @unchecked Sendable {

case .waitingForConsumer(.finished), .waitingForConsumer(.failure):
preconditionFailure("How can new rows be received, if an end was already signalled?")

case .iteratingRows(let onRow, let promise, let dataSource):
do {
for data in newRows {
Expand Down Expand Up @@ -347,25 +352,25 @@ final class PSQLRowStream: @unchecked Sendable {
private func receiveEnd(_ commandTag: String) {
switch self.downstreamState {
case .waitingForConsumer(.streaming(buffer: let buffer, _)):
self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, commandTag: commandTag))
case .waitingForConsumer(.finished), .waitingForConsumer(.failure):
self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, summary: .tag(commandTag)))

case .waitingForConsumer(.finished), .waitingForConsumer(.failure), .consumed(.success(.emptyResponse)):
preconditionFailure("How can we get another end, if an end was already signalled?")

case .iteratingRows(_, let promise, _):
self.downstreamState = .consumed(.success(commandTag))
self.downstreamState = .consumed(.success(.tag(commandTag)))
promise.succeed(())

case .waitingForAll(let rows, let promise, _):
self.downstreamState = .consumed(.success(commandTag))
self.downstreamState = .consumed(.success(.tag(commandTag)))
promise.succeed(rows)

case .asyncSequence(let source, _, let onFinish):
self.downstreamState = .consumed(.success(commandTag))
self.downstreamState = .consumed(.success(.tag(commandTag)))
source.finish()
onFinish()

case .consumed:
case .consumed(.success(.tag)), .consumed(.failure):
break
}
}
Expand All @@ -375,7 +380,7 @@ final class PSQLRowStream: @unchecked Sendable {
case .waitingForConsumer(.streaming):
self.downstreamState = .waitingForConsumer(.failure(error))

case .waitingForConsumer(.finished), .waitingForConsumer(.failure):
case .waitingForConsumer(.finished), .waitingForConsumer(.failure), .consumed(.success(.emptyResponse)):
preconditionFailure("How can we get another end, if an end was already signalled?")

case .iteratingRows(_, let promise, _):
Expand All @@ -391,7 +396,7 @@ final class PSQLRowStream: @unchecked Sendable {
consumer.finish(error)
onFinish()

case .consumed:
case .consumed(.success(.tag)), .consumed(.failure):
break
}
}
Expand All @@ -413,10 +418,15 @@ final class PSQLRowStream: @unchecked Sendable {
}

var commandTag: String {
guard case .consumed(.success(let commandTag)) = self.downstreamState else {
guard case .consumed(.success(let consumed)) = self.downstreamState else {
preconditionFailure("commandTag may only be called if all rows have been consumed")
}
return commandTag
switch consumed {
case .tag(let tag):
return tag
case .emptyResponse:
return ""
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions Sources/PostgresNIO/New/PostgresChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -550,9 +550,9 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
)
self.rowStream = rows

case .noRows(let commandTag):
case .noRows(let summary):
rows = PSQLRowStream(
source: .noRows(.success(commandTag)),
source: .noRows(.success(summary)),
eventLoop: context.channel.eventLoop,
logger: result.logger
)
Expand Down
5 changes: 1 addition & 4 deletions Sources/PostgresNIO/PostgresDatabase+Query.swift
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,7 @@ public struct PostgresQueryMetadata: Sendable {

init?(string: String) {
let parts = string.split(separator: " ")
guard parts.count >= 1 else {
return nil
}
switch parts[0] {
switch parts.first {
case "INSERT":
// INSERT oid rows
guard parts.count == 3 else {
Expand Down
19 changes: 19 additions & 0 deletions Tests/IntegrationTests/PSQLIntegrationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,25 @@ final class IntegrationTests: XCTestCase {
XCTAssertEqual(foo, "hello")
}

func testQueryNothing() throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

var conn: PostgresConnection?
XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait())
defer { XCTAssertNoThrow(try conn?.close().wait()) }

var _result: PostgresQueryResult?
XCTAssertNoThrow(_result = try conn?.query("""
-- Some comments
""", logger: .psqlTest).wait())

let result = try XCTUnwrap(_result)
XCTAssertEqual(result.rows, [])
XCTAssertEqual(result.metadata.command, "")
}

func testDecodeIntegers() {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class ExtendedQueryStateMachineTests: XCTestCase {
XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait)
XCTAssertEqual(state.noDataReceived(), .wait)
XCTAssertEqual(state.bindCompleteReceived(), .wait)
XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows("DELETE 1"), logger: logger)))
XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows(.tag("DELETE 1")), logger: logger)))
XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery)
}

Expand Down Expand Up @@ -77,7 +77,25 @@ class ExtendedQueryStateMachineTests: XCTestCase {
XCTAssertEqual(state.commandCompletedReceived("SELECT 2"), .forwardStreamComplete([row5, row6], commandTag: "SELECT 2"))
XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery)
}


func testExtendedQueryWithNoQuery() {
var state = ConnectionStateMachine.readyForQuery()

let logger = Logger.psqlTest
let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self)
promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all.
let query: PostgresQuery = "-- some comments"
let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise)

XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query))
XCTAssertEqual(state.parseCompleteReceived(), .wait)
XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait)
XCTAssertEqual(state.noDataReceived(), .wait)
XCTAssertEqual(state.bindCompleteReceived(), .wait)
XCTAssertEqual(state.emptyQueryResponseReceived(), .succeedQuery(promise, with: .init(value: .noRows(.emptyResponse), logger: logger)))
XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery)
}

func testReceiveTotallyUnexpectedMessageInQuery() {
var state = ConnectionStateMachine.readyForQuery()

Expand Down
Loading

0 comments on commit 3de37e6

Please sign in to comment.