Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly place the SSL channel handler in front of the PostgresChannelHandler #527

Merged
merged 3 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Sources/PostgresNIO/Connection/PostgresConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,18 @@ public final class PostgresConnection: @unchecked Sendable {
func start(configuration: InternalConfiguration) -> EventLoopFuture<Void> {
// 1. configure handlers

let configureSSLCallback: ((Channel) throws -> ())?
let configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> ())?

switch configuration.tls.base {
case .prefer(let context), .require(let context):
configureSSLCallback = { channel in
configureSSLCallback = { channel, postgresChannelHandler in
channel.eventLoop.assertInEventLoop()

let sslHandler = try NIOSSLClientHandler(
context: context,
serverHostname: configuration.serverNameForTLS
)
try channel.pipeline.syncOperations.addHandler(sslHandler, position: .first)
try channel.pipeline.syncOperations.addHandler(sslHandler, position: .before(postgresChannelHandler))
}
case .disable:
configureSSLCallback = nil
Expand Down
8 changes: 4 additions & 4 deletions Sources/PostgresNIO/New/PostgresChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
private var decoder: NIOSingleStepByteToMessageProcessor<PostgresBackendMessageDecoder>
private var encoder: PostgresFrontendMessageEncoder!
private let configuration: PostgresConnection.InternalConfiguration
private let configureSSLCallback: ((Channel) throws -> Void)?
private let configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)?

private var listenState = ListenStateMachine()
private var preparedStatementState = PreparedStatementStateMachine()
Expand All @@ -29,7 +29,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
configuration: PostgresConnection.InternalConfiguration,
eventLoop: EventLoop,
logger: Logger,
configureSSLCallback: ((Channel) throws -> Void)?
configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)?
) {
self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData)
self.eventLoop = eventLoop
Expand All @@ -46,7 +46,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
eventLoop: EventLoop,
state: ConnectionStateMachine = .init(.initialized),
logger: Logger = .psqlNoOpLogger,
configureSSLCallback: ((Channel) throws -> Void)?
configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)?
) {
self.state = state
self.eventLoop = eventLoop
Expand Down Expand Up @@ -439,7 +439,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
// This method must only be called, if we signalized the StateMachine before that we are
// able to setup a SSL connection.
do {
try self.configureSSLCallback!(context.channel)
try self.configureSSLCallback!(context.channel, self)
let action = self.state.sslHandlerAdded()
self.run(action, with: context)
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class PostgresChannelHandlerTests: XCTestCase {
var config = self.testConnectionConfiguration()
XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration())))
var addSSLCallbackIsHit = false
let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel in
let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in
addSSLCallbackIsHit = true
}
let embedded = EmbeddedChannel(handlers: [
Expand Down Expand Up @@ -84,7 +84,7 @@ class PostgresChannelHandlerTests: XCTestCase {
var config = self.testConnectionConfiguration()
XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration())))
var addSSLCallbackIsHit = false
let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel in
let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in
addSSLCallbackIsHit = true
}
let eventHandler = TestEventHandler()
Expand Down Expand Up @@ -114,7 +114,7 @@ class PostgresChannelHandlerTests: XCTestCase {
func testSSLUnsupportedClosesConnection() throws {
let config = self.testConnectionConfiguration(tls: .require(try NIOSSLContext(configuration: .makeClientConfiguration())))

let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel in
let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in
XCTFail("This callback should never be exectuded")
throw PSQLError.sslUnsupported
}
Expand Down
Loading