diff --git a/pkg/amqp10_client/common.go b/pkg/amqp10_client/common.go index d94a7c0..1153d29 100644 --- a/pkg/amqp10_client/common.go +++ b/pkg/amqp10_client/common.go @@ -8,7 +8,7 @@ import ( "github.com/rabbitmq/omq/pkg/log" ) -func amqpVHost(connectionString string) string { +func hostAndVHost(connectionString string) (string, string) { uri, err := url.Parse(connectionString) if err != nil { log.Error("failed to parse connection string", "error", err.Error()) @@ -20,5 +20,5 @@ func amqpVHost(connectionString string) string { vhost = strings.TrimPrefix(uri.Path, "/") } - return "vhost:" + vhost + return uri.Hostname(), "vhost:" + vhost } diff --git a/pkg/amqp10_client/consumer.go b/pkg/amqp10_client/consumer.go index 73e72a2..b2c4558 100644 --- a/pkg/amqp10_client/consumer.go +++ b/pkg/amqp10_client/consumer.go @@ -2,6 +2,7 @@ package amqp10_client import ( "context" + "crypto/tls" "fmt" "os" "strconv" @@ -29,9 +30,11 @@ type Amqp10Consumer struct { func NewConsumer(cfg config.Config, id int) *Amqp10Consumer { // open connection + hostname, vhost := hostAndVHost(cfg.ConsumerUri) conn, err := amqp.Dial(context.TODO(), cfg.ConsumerUri, &amqp.ConnOptions{ - HostName: amqpVHost(cfg.ConsumerUri), - }) + HostName: vhost, + TLSConfig: &tls.Config{ + ServerName: hostname}}) if err != nil { log.Error("consumer failed to connect", "protocol", "amqp-1.0", "consumerId", id, "error", err.Error()) return nil diff --git a/pkg/amqp10_client/publisher.go b/pkg/amqp10_client/publisher.go index b726db3..dd5f79c 100644 --- a/pkg/amqp10_client/publisher.go +++ b/pkg/amqp10_client/publisher.go @@ -2,6 +2,7 @@ package amqp10_client import ( "context" + "crypto/tls" "math/rand" "strconv" "time" @@ -28,8 +29,11 @@ type Amqp10Publisher struct { func NewPublisher(cfg config.Config, n int) *Amqp10Publisher { // open connection + hostname, vhost := hostAndVHost(cfg.PublisherUri) conn, err := amqp.Dial(context.TODO(), cfg.PublisherUri, &amqp.ConnOptions{ - HostName: amqpVHost(cfg.PublisherUri)}) + HostName: vhost, + TLSConfig: &tls.Config{ + ServerName: hostname}}) if err != nil { log.Error("publisher connection failed", "protocol", "amqp-1.0", "publisherId", n, "error", err.Error()) return nil