Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup encoding Startup message #395

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 0 additions & 52 deletions Sources/PostgresNIO/New/Messages/Startup.swift

This file was deleted.

13 changes: 1 addition & 12 deletions Sources/PostgresNIO/New/PostgresChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
case .wait:
break
case .sendStartupMessage(let authContext):
self.encoder.startup(authContext.toStartupParameters())
self.encoder.startup(user: authContext.username, database: authContext.database)
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil)
case .sendSSLRequest:
self.encoder.ssl()
Expand Down Expand Up @@ -684,17 +684,6 @@ extension PostgresChannelHandler: PSQLRowsDataSource {
}
}

extension AuthContext {
func toStartupParameters() -> PostgresFrontendMessage.Startup.Parameters {
PostgresFrontendMessage.Startup.Parameters(
user: self.username,
database: self.database,
options: nil,
replication: .false
)
}
}

private extension Insecure.MD5.Digest {

private static let lowercaseLookup: [UInt8] = [
Expand Down
48 changes: 44 additions & 4 deletions Sources/PostgresNIO/New/PostgresFrontendMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,50 @@ enum PostgresFrontendMessage: Equatable {
static let requestCode: Int32 = 80877103
}

struct Startup: Hashable {
static let versionThree: Int32 = 0x00_03_00_00

/// Creates a `Startup` with "3.0" as the protocol version.
static func versionThree(parameters: Parameters) -> Startup {
return .init(protocolVersion: Self.versionThree, parameters: parameters)
}

/// The protocol version number. The most significant 16 bits are the major
/// version number (3 for the protocol described here). The least significant
/// 16 bits are the minor version number (0 for the protocol described here).
var protocolVersion: Int32

/// The protocol version number is followed by one or more pairs of parameter
/// name and value strings. A zero byte is required as a terminator after
/// the last name/value pair. `user` is required, others are optional.
struct Parameters: Hashable {
enum Replication {
case `true`
case `false`
case database
}

/// The database user name to connect as. Required; there is no default.
var user: String

/// The database to connect to. Defaults to the user name.
var database: String?

/// Command-line arguments for the backend. (This is deprecated in favor
/// of setting individual run-time parameters.) Spaces within this string are
/// considered to separate arguments, unless escaped with a
/// backslash (\); write \\ to represent a literal backslash.
var options: String?

/// Used to connect in streaming replication mode, where a small set of
/// replication commands can be issued instead of SQL statements. Value
/// can be true, false, or database, and the default is false.
var replication: Replication
}

var parameters: Parameters
}

case bind(Bind)
case cancel(Cancel)
case close(Close)
Expand Down Expand Up @@ -225,7 +269,3 @@ extension PostgresFrontendMessage {
}
}
}

protocol PSQLMessagePayloadEncodable {
func encode(into buffer: inout ByteBuffer)
}
22 changes: 3 additions & 19 deletions Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,34 +13,18 @@ struct PostgresFrontendMessageEncoder {
self.buffer = buffer
}

mutating func startup(_ parameters: PostgresFrontendMessage.Startup.Parameters) {
mutating func startup(user: String, database: String?) {
self.clearIfNeeded()
self.encodeLengthPrefixed { buffer in
buffer.writeInteger(PostgresFrontendMessage.Startup.versionThree)
buffer.writeNullTerminatedString("user")
buffer.writeNullTerminatedString(parameters.user)
buffer.writeNullTerminatedString(user)

if let database = parameters.database {
if let database = database {
buffer.writeNullTerminatedString("database")
buffer.writeNullTerminatedString(database)
}

if let options = parameters.options {
buffer.writeNullTerminatedString("options")
buffer.writeNullTerminatedString(options)
}

switch parameters.replication {
case .database:
buffer.writeNullTerminatedString("replication")
buffer.writeNullTerminatedString("replication")
case .true:
buffer.writeNullTerminatedString("replication")
buffer.writeNullTerminatedString("true")
case .false:
break
}

buffer.writeInteger(UInt8(0))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,7 @@ extension RowDescription: PSQLMessagePayloadEncodable {
}
}
}

protocol PSQLMessagePayloadEncodable {
func encode(into buffer: inout ByteBuffer)
}
82 changes: 35 additions & 47 deletions Tests/PostgresNIOTests/New/Messages/StartupTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,56 +4,44 @@ import NIOCore

class StartupTests: XCTestCase {

func testStartupMessage() {
func testStartupMessageWithDatabase() {
var encoder = PostgresFrontendMessageEncoder(buffer: .init())
var byteBuffer = ByteBuffer()

let replicationValues: [PostgresFrontendMessage.Startup.Parameters.Replication] = [
.`true`,
.`false`,
.database
]

for replication in replicationValues {
let parameters = PostgresFrontendMessage.Startup.Parameters(
user: "test",
database: "abc123",
options: "some options",
replication: replication
)

encoder.startup(parameters)
byteBuffer = encoder.flushBuffer()

let byteBufferLength = Int32(byteBuffer.readableBytes)
XCTAssertEqual(byteBufferLength, byteBuffer.readInteger())
XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger())
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "options")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "some options")
if replication != .false {
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "replication")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), replication.stringValue)
}
XCTAssertEqual(byteBuffer.readInteger(), UInt8(0))

XCTAssertEqual(byteBuffer.readableBytes, 0)
}

let user = "test"
let database = "abc123"

encoder.startup(user: user, database: database)
byteBuffer = encoder.flushBuffer()

let byteBufferLength = Int32(byteBuffer.readableBytes)
XCTAssertEqual(byteBufferLength, byteBuffer.readInteger())
XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger())
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123")
XCTAssertEqual(byteBuffer.readInteger(), UInt8(0))

XCTAssertEqual(byteBuffer.readableBytes, 0)
}
}

extension PostgresFrontendMessage.Startup.Parameters.Replication {
var stringValue: String {
switch self {
case .true:
return "true"
case .false:
return "false"
case .database:
return "replication"
}
func testStartupMessageWithoutDatabase() {
var encoder = PostgresFrontendMessageEncoder(buffer: .init())
var byteBuffer = ByteBuffer()

let user = "test"

encoder.startup(user: user, database: nil)
byteBuffer = encoder.flushBuffer()

let byteBufferLength = Int32(byteBuffer.readableBytes)
XCTAssertEqual(byteBufferLength, byteBuffer.readInteger())
XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger())
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user")
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test")
XCTAssertEqual(byteBuffer.readInteger(), UInt8(0))

XCTAssertEqual(byteBuffer.readableBytes, 0)
}
}