Skip to content

Commit

Permalink
PMM-12640 Fix support of SSL certificates for PostgreSQL. (#2586)
Browse files Browse the repository at this point in the history
* PMM-12640 Fix support of SSL certificates for PostgreSQL.

* PMM-12640 Fix tests.

* PMM-12640 don't send sslsni for old PMM clients.

* PMM-12640 Fix tests.

* PMM-12640 Fix linter.

* PMM-12640 Fix linter.

* PMM-12654 Fix test and address review comments.

* PMM-12640 Fix connection checker.
  • Loading branch information
BupycHuk authored Nov 13, 2023
1 parent ee96d27 commit 80e4c8b
Show file tree
Hide file tree
Showing 23 changed files with 252 additions and 94 deletions.
7 changes: 3 additions & 4 deletions agent/agents/postgres/pgstatstatements/pgstatstatements.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,9 @@ func (m *PGStatStatementsQAN) getStatStatementsExtended(
databases := queryDatabases(q)
usernames := queryUsernames(q)

rows, e := rowsByVersion(q, "WHERE queryid IS NOT NULL AND query IS NOT NULL")
if e != nil {
err = e
return nil, nil, err
rows, err := rowsByVersion(q, "WHERE queryid IS NOT NULL AND query IS NOT NULL")
if err != nil {
return nil, nil, errors.Wrap(err, "couldn't get rows from pg_stat_statements")
}
defer rows.Close() //nolint:errcheck

Expand Down
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ services:
profiles:
- pmm
image: ${PMM_CONTAINER:-perconalab/pmm-server:dev-container}
platform: linux/amd64
# build:
# context: .
# args:
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ require (
github.com/jhunters/bigqueue v1.2.7
github.com/jmoiron/sqlx v1.3.5
github.com/jotaen/kong-completion v0.0.5
github.com/lib/pq v1.10.6
github.com/lib/pq v1.10.9
github.com/minio/minio-go/v7 v7.0.55
github.com/operator-framework/api v0.17.6
github.com/operator-framework/operator-lifecycle-manager v0.24.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -522,8 +522,8 @@ github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.10.6 h1:jbk+ZieJ0D7EVGJYpL9QTz7/YW6UHbmdnZWYyK5cdBs=
github.com/lib/pq v1.10.6/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
Expand Down
39 changes: 25 additions & 14 deletions managed/models/agent_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,15 @@ func (s *Agent) DBConfig(service *Service) *DBConfig {
}
}

type DSNParams struct {
DialTimeout time.Duration
Database string

PostgreSQLSupportsSSLSNI bool
}

// DSN returns DSN string for accessing given Service with this Agent (and implicit driver).
func (s *Agent) DSN(service *Service, dialTimeout time.Duration, database string, tdp *DelimiterPair) string { //nolint:cyclop,maintidx
func (s *Agent) DSN(service *Service, dsnParams DSNParams, tdp *DelimiterPair) string { //nolint:cyclop,maintidx
host := pointer.GetString(service.Address)
port := pointer.GetUint16(service.Port)
socket := pointer.GetString(service.Socket)
Expand All @@ -320,8 +327,8 @@ func (s *Agent) DSN(service *Service, dialTimeout time.Duration, database string
cfg.Net = tcp
cfg.Addr = net.JoinHostPort(host, strconv.Itoa(int(port)))
}
cfg.Timeout = dialTimeout
cfg.DBName = database
cfg.Timeout = dsnParams.DialTimeout
cfg.DBName = dsnParams.Database
cfg.Params = make(map[string]string)
if s.TLS {
switch {
Expand Down Expand Up @@ -349,8 +356,8 @@ func (s *Agent) DSN(service *Service, dialTimeout time.Duration, database string
cfg.Net = tcp
cfg.Addr = net.JoinHostPort(host, strconv.Itoa(int(port)))
}
cfg.Timeout = dialTimeout
cfg.DBName = database
cfg.Timeout = dsnParams.DialTimeout
cfg.DBName = dsnParams.Database
cfg.Params = make(map[string]string)
if s.TLS {
switch {
Expand Down Expand Up @@ -382,8 +389,8 @@ func (s *Agent) DSN(service *Service, dialTimeout time.Duration, database string
cfg.Net = tcp
cfg.Addr = net.JoinHostPort(host, strconv.Itoa(int(port)))
}
cfg.Timeout = dialTimeout
cfg.DBName = database
cfg.Timeout = dsnParams.DialTimeout
cfg.DBName = dsnParams.Database
cfg.Params = make(map[string]string)
if s.TLS {
if s.TLSSkipVerify {
Expand All @@ -400,16 +407,16 @@ func (s *Agent) DSN(service *Service, dialTimeout time.Duration, database string

case QANMongoDBProfilerAgentType, MongoDBExporterType:
q := make(url.Values)
if dialTimeout != 0 {
q.Set("connectTimeoutMS", strconv.Itoa(int(dialTimeout/time.Millisecond)))
q.Set("serverSelectionTimeoutMS", strconv.Itoa(int(dialTimeout/time.Millisecond)))
if dsnParams.DialTimeout != 0 {
q.Set("connectTimeoutMS", strconv.Itoa(int(dsnParams.DialTimeout/time.Millisecond)))
q.Set("serverSelectionTimeoutMS", strconv.Itoa(int(dsnParams.DialTimeout/time.Millisecond)))
}

// https://docs.mongodb.com/manual/reference/connection-string/
// > If the connection string does not specify a database/ you must specify a slash (/)
// between the last host and the question mark (?) that begins the string of options.
path := database
if database == "" {
path := dsnParams.Database
if path == "" {
path = "/"
}

Expand Down Expand Up @@ -475,6 +482,9 @@ func (s *Agent) DSN(service *Service, dialTimeout time.Duration, database string
} else {
sslmode = VerifyCaSSLMode
}
if dsnParams.PostgreSQLSupportsSSLSNI {
q.Set("sslsni", "0")
}
}
q.Set("sslmode", sslmode)

Expand All @@ -493,11 +503,12 @@ func (s *Agent) DSN(service *Service, dialTimeout time.Duration, database string
}
}

if dialTimeout != 0 {
q.Set("connect_timeout", strconv.Itoa(int(dialTimeout.Seconds())))
if dsnParams.DialTimeout != 0 {
q.Set("connect_timeout", strconv.Itoa(int(dsnParams.DialTimeout.Seconds())))
}

address := ""
database := dsnParams.Database
if socket == "" {
address = net.JoinHostPort(host, strconv.Itoa(int(port)))
} else {
Expand Down
40 changes: 20 additions & 20 deletions managed/models/agent_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,15 @@ func TestAgent(t *testing.T) {
} {
t.Run(string(typ), func(t *testing.T) {
agent.AgentType = typ
assert.Equal(t, expected, agent.DSN(service, time.Second, "database", nil))
assert.Equal(t, expected, agent.DSN(service, models.DSNParams{DialTimeout: time.Second, Database: "database"}, nil))
})
}

t.Run("MongoDBNoDatabase", func(t *testing.T) {
agent.AgentType = models.MongoDBExporterType

assert.Equal(t, "mongodb://username:s3cur3%20p%[email protected]:12345/?connectTimeoutMS=1000&directConnection=true&serverSelectionTimeoutMS=1000", agent.DSN(service, time.Second, "", nil))
assert.Equal(t, "mongodb://username:s3cur3%20p%[email protected]:12345/?directConnection=true", agent.DSN(service, 0, "", nil))
assert.Equal(t, "mongodb://username:s3cur3%20p%[email protected]:12345/?connectTimeoutMS=1000&directConnection=true&serverSelectionTimeoutMS=1000", agent.DSN(service, models.DSNParams{DialTimeout: time.Second}, nil))
assert.Equal(t, "mongodb://username:s3cur3%20p%[email protected]:12345/?directConnection=true", agent.DSN(service, models.DSNParams{}, nil))
})
})

Expand All @@ -94,7 +94,7 @@ func TestAgent(t *testing.T) {
} {
t.Run(string(typ), func(t *testing.T) {
agent.AgentType = typ
assert.Equal(t, expected, agent.DSN(service, time.Second, "database", nil))
assert.Equal(t, expected, agent.DSN(service, models.DSNParams{DialTimeout: time.Second, Database: "database"}, nil))
})
}
})
Expand All @@ -113,7 +113,7 @@ func TestAgent(t *testing.T) {
} {
t.Run(string(typ), func(t *testing.T) {
agent.AgentType = typ
assert.Equal(t, expected, agent.DSN(service, time.Second, "database", nil))
assert.Equal(t, expected, agent.DSN(service, models.DSNParams{DialTimeout: time.Second, Database: "database"}, nil))
})
}
})
Expand Down Expand Up @@ -159,7 +159,7 @@ func TestAgent(t *testing.T) {
} {
t.Run(string(typ), func(t *testing.T) {
agent.AgentType = typ
assert.Equal(t, expected, agent.DSN(service, time.Second, "database", nil))
assert.Equal(t, expected, agent.DSN(service, models.DSNParams{DialTimeout: time.Second, Database: "database"}, nil))
})
}

Expand All @@ -169,8 +169,8 @@ func TestAgent(t *testing.T) {
agent.MongoDBOptions.TLSCertificateKeyFilePassword = ""
agent.MongoDBOptions.AuthenticationMechanism = ""

assert.Equal(t, "mongodb://username:s3cur3%20p%[email protected]:12345/?connectTimeoutMS=1000&directConnection=true&serverSelectionTimeoutMS=1000&ssl=true&tlsCaFile={{.TextFiles.caFilePlaceholder}}&tlsCertificateKeyFile={{.TextFiles.certificateKeyFilePlaceholder}}", agent.DSN(service, time.Second, "", nil))
assert.Equal(t, "mongodb://username:s3cur3%20p%[email protected]:12345/?directConnection=true&ssl=true&tlsCaFile={{.TextFiles.caFilePlaceholder}}&tlsCertificateKeyFile={{.TextFiles.certificateKeyFilePlaceholder}}", agent.DSN(service, 0, "", nil))
assert.Equal(t, "mongodb://username:s3cur3%20p%[email protected]:12345/?connectTimeoutMS=1000&directConnection=true&serverSelectionTimeoutMS=1000&ssl=true&tlsCaFile={{.TextFiles.caFilePlaceholder}}&tlsCertificateKeyFile={{.TextFiles.certificateKeyFilePlaceholder}}", agent.DSN(service, models.DSNParams{DialTimeout: time.Second}, nil))
assert.Equal(t, "mongodb://username:s3cur3%20p%[email protected]:12345/?directConnection=true&ssl=true&tlsCaFile={{.TextFiles.caFilePlaceholder}}&tlsCertificateKeyFile={{.TextFiles.certificateKeyFilePlaceholder}}", agent.DSN(service, models.DSNParams{}, nil))
expectedFiles := map[string]string{
"caFilePlaceholder": "cert",
"certificateKeyFilePlaceholder": "key",
Expand All @@ -185,8 +185,8 @@ func TestAgent(t *testing.T) {
agent.MongoDBOptions.AuthenticationMechanism = "MONGO-X509"
agent.MongoDBOptions.AuthenticationDatabase = "$external"

assert.Equal(t, "mongodb://username:s3cur3%20p%[email protected]:12345/?authMechanism=MONGO-X509&authSource=%24external&connectTimeoutMS=1000&directConnection=true&serverSelectionTimeoutMS=1000&ssl=true&tlsCaFile={{.TextFiles.caFilePlaceholder}}&tlsCertificateKeyFile={{.TextFiles.certificateKeyFilePlaceholder}}", agent.DSN(service, time.Second, "", nil))
assert.Equal(t, "mongodb://username:s3cur3%20p%[email protected]:12345/?authMechanism=MONGO-X509&authSource=%24external&directConnection=true&ssl=true&tlsCaFile={{.TextFiles.caFilePlaceholder}}&tlsCertificateKeyFile={{.TextFiles.certificateKeyFilePlaceholder}}", agent.DSN(service, 0, "", nil))
assert.Equal(t, "mongodb://username:s3cur3%20p%[email protected]:12345/?authMechanism=MONGO-X509&authSource=%24external&connectTimeoutMS=1000&directConnection=true&serverSelectionTimeoutMS=1000&ssl=true&tlsCaFile={{.TextFiles.caFilePlaceholder}}&tlsCertificateKeyFile={{.TextFiles.certificateKeyFilePlaceholder}}", agent.DSN(service, models.DSNParams{DialTimeout: time.Second}, nil))
assert.Equal(t, "mongodb://username:s3cur3%20p%[email protected]:12345/?authMechanism=MONGO-X509&authSource=%24external&directConnection=true&ssl=true&tlsCaFile={{.TextFiles.caFilePlaceholder}}&tlsCertificateKeyFile={{.TextFiles.certificateKeyFilePlaceholder}}", agent.DSN(service, models.DSNParams{}, nil))
expectedFiles := map[string]string{
"caFilePlaceholder": "cert",
"certificateKeyFilePlaceholder": "key",
Expand Down Expand Up @@ -217,15 +217,15 @@ func TestAgent(t *testing.T) {
} {
t.Run(string(typ), func(t *testing.T) {
agent.AgentType = typ
assert.Equal(t, expected, agent.DSN(service, time.Second, "database", nil))
assert.Equal(t, expected, agent.DSN(service, models.DSNParams{DialTimeout: time.Second, Database: "database"}, nil))
})
}

t.Run("MongoDBNoDatabase", func(t *testing.T) {
agent.AgentType = models.MongoDBExporterType

assert.Equal(t, "mongodb://username:s3cur3%20p%[email protected]:12345/?connectTimeoutMS=1000&directConnection=true&serverSelectionTimeoutMS=1000&ssl=true&tlsInsecure=true", agent.DSN(service, time.Second, "", nil))
assert.Equal(t, "mongodb://username:s3cur3%20p%[email protected]:12345/?directConnection=true&ssl=true&tlsInsecure=true", agent.DSN(service, 0, "", nil))
assert.Equal(t, "mongodb://username:s3cur3%20p%[email protected]:12345/?connectTimeoutMS=1000&directConnection=true&serverSelectionTimeoutMS=1000&ssl=true&tlsInsecure=true", agent.DSN(service, models.DSNParams{DialTimeout: time.Second}, nil))
assert.Equal(t, "mongodb://username:s3cur3%20p%[email protected]:12345/?directConnection=true&ssl=true&tlsInsecure=true", agent.DSN(service, models.DSNParams{}, nil))
})
})
}
Expand Down Expand Up @@ -255,7 +255,7 @@ func TestPostgresAgentTLS(t *testing.T) {
t.Run(name, func(t *testing.T) {
agent.TLS = testCase.tls
agent.TLSSkipVerify = testCase.tlsSkipVerify
assert.Equal(t, testCase.expected, agent.DSN(service, time.Second, "database", nil))
assert.Equal(t, testCase.expected, agent.DSN(service, models.DSNParams{DialTimeout: time.Second, Database: "database"}, nil))
})
}
}
Expand All @@ -272,7 +272,7 @@ func TestPostgresWithSocket(t *testing.T) {
Socket: pointer.ToString("/var/run/postgres"),
}
expect := "postgres://username@/database?connect_timeout=1&host=%2Fvar%2Frun%2Fpostgres&sslmode=verify-ca"
assert.Equal(t, expect, agent.DSN(service, time.Second, "database", nil))
assert.Equal(t, expect, agent.DSN(service, models.DSNParams{DialTimeout: time.Second, Database: "database"}, nil))
})

t.Run("empty-user-password", func(t *testing.T) {
Expand All @@ -283,7 +283,7 @@ func TestPostgresWithSocket(t *testing.T) {
Socket: pointer.ToString("/var/run/postgres"),
}
expect := "postgres:///database?connect_timeout=1&host=%2Fvar%2Frun%2Fpostgres&sslmode=disable"
assert.Equal(t, expect, agent.DSN(service, time.Second, "database", nil))
assert.Equal(t, expect, agent.DSN(service, models.DSNParams{DialTimeout: time.Second, Database: "database"}, nil))
})

t.Run("dir-with-symbols", func(t *testing.T) {
Expand All @@ -294,7 +294,7 @@ func TestPostgresWithSocket(t *testing.T) {
Socket: pointer.ToString(`/tmp/123\ A0m\%\$\@\8\,\+\-`),
}
expect := "postgres:///database?connect_timeout=1&host=%2Ftmp%2F123%5C+A0m%5C%25%5C%24%5C%40%5C8%5C%2C%5C%2B%5C-&sslmode=disable"
assert.Equal(t, expect, agent.DSN(service, time.Second, "database", nil))
assert.Equal(t, expect, agent.DSN(service, models.DSNParams{DialTimeout: time.Second, Database: "database"}, nil))
})
}

Expand All @@ -310,7 +310,7 @@ func TestMongoWithSocket(t *testing.T) {
Socket: pointer.ToString("/tmp/mongodb-27017.sock"),
}
expect := "mongodb://username@%2Ftmp%2Fmongodb-27017.sock/database?connectTimeoutMS=1000&directConnection=true&serverSelectionTimeoutMS=1000&ssl=true"
assert.Equal(t, expect, agent.DSN(service, time.Second, "database", nil))
assert.Equal(t, expect, agent.DSN(service, models.DSNParams{DialTimeout: time.Second, Database: "database"}, nil))
})

t.Run("empty-user-password", func(t *testing.T) {
Expand All @@ -321,7 +321,7 @@ func TestMongoWithSocket(t *testing.T) {
Socket: pointer.ToString("/tmp/mongodb-27017.sock"),
}
expect := "mongodb://%2Ftmp%2Fmongodb-27017.sock/database?connectTimeoutMS=1000&directConnection=true&serverSelectionTimeoutMS=1000"
assert.Equal(t, expect, agent.DSN(service, time.Second, "database", nil))
assert.Equal(t, expect, agent.DSN(service, models.DSNParams{DialTimeout: time.Second, Database: "database"}, nil))
})

t.Run("dir-with-symbols", func(t *testing.T) {
Expand All @@ -332,7 +332,7 @@ func TestMongoWithSocket(t *testing.T) {
Socket: pointer.ToString(`/tmp/123\ A0m\%\$\@\8\,\+\-/mongodb-27017.sock`),
}
expect := "mongodb://%2Ftmp%2F123%5C%20A0m%5C%25%5C$%5C%40%5C8%5C,%5C+%5C-%2Fmongodb-27017.sock/database?connectTimeoutMS=1000&directConnection=true&serverSelectionTimeoutMS=1000"
assert.Equal(t, expect, agent.DSN(service, time.Second, "database", nil))
assert.Equal(t, expect, agent.DSN(service, models.DSNParams{DialTimeout: time.Second, Database: "database"}, nil))
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,18 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package agents
package models

import (
"fmt"

"github.com/hashicorp/go-version"
"github.com/pkg/errors"
"gopkg.in/reform.v1"

"github.com/percona/pmm/managed/models"
)

var PMMAgentMinVersionForPostgreSQLSSLSni = version.Must(version.NewVersion("2.41.0-0"))

// AgentNotSupportedError is used when the target PMM agent doesn't support the requested functionality.
type AgentNotSupportedError struct {
Functionality string
Expand All @@ -40,15 +40,15 @@ func (e *AgentNotSupportedError) Error() string {

// PMMAgentSupported checks if pmm agent version satisfies required min version.
func PMMAgentSupported(q *reform.Querier, pmmAgentID, functionalityPrefix string, pmmMinVersion *version.Version) error {
pmmAgent, err := models.FindAgentByID(q, pmmAgentID)
pmmAgent, err := FindAgentByID(q, pmmAgentID)
if err != nil {
return errors.Errorf("failed to get PMM Agent: %s", err)
}
return isAgentSupported(pmmAgent, functionalityPrefix, pmmMinVersion)
}

// isAgentSupported contains logic for PMMAgentSupported.
func isAgentSupported(agentModel *models.Agent, functionalityPrefix string, pmmMinVersion *version.Version) error {
func isAgentSupported(agentModel *Agent, functionalityPrefix string, pmmMinVersion *version.Version) error {
if agentModel == nil {
return errors.New("nil agent")
}
Expand All @@ -70,3 +70,15 @@ func isAgentSupported(agentModel *models.Agent, functionalityPrefix string, pmmM
}
return nil
}

func IsPostgreSQLSSLSniSupported(q *reform.Querier, pmmAgentID string) (bool, error) {
err := PMMAgentSupported(q, pmmAgentID, "postgresql SSL sni check", PMMAgentMinVersionForPostgreSQLSSLSni)
switch {
case errors.Is(err, &AgentNotSupportedError{}):
return false, nil
case err == nil:
return true, nil
default:
return false, errors.Wrap(err, "couldn't compare PMM Agent version")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,14 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package agents
package models

import (
"testing"

"github.com/AlekSi/pointer"
"github.com/hashicorp/go-version"
"github.com/stretchr/testify/assert"

"github.com/percona/pmm/managed/models"
)

func TestPMMAgentSupported(t *testing.T) {
Expand Down Expand Up @@ -66,7 +64,7 @@ func TestPMMAgentSupported(t *testing.T) {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
agentModel := models.Agent{
agentModel := Agent{
AgentID: "Test agent ID",
Version: pointer.ToString(test.agentVersion),
}
Expand All @@ -80,7 +78,7 @@ func TestPMMAgentSupported(t *testing.T) {
}

t.Run("No version info", func(t *testing.T) {
err := isAgentSupported(&models.Agent{AgentID: "Test agent ID"}, prefix, version.Must(version.NewVersion("2.30.0")))
err := isAgentSupported(&Agent{AgentID: "Test agent ID"}, prefix, version.Must(version.NewVersion("2.30.0")))
assert.Contains(t, err.Error(), "has no version info")
})

Expand Down
Loading

0 comments on commit 80e4c8b

Please sign in to comment.