Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

issue-1774: Method to retrieve info about client connections was added #1796

Open
wants to merge 1 commit into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 100 additions & 10 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,95 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{})
return applied, iter, iter.err
}

// connectionType is a custom type that represents the different stages
// of a client connection in a Cassandra cluster. It is used to filter and categorize
// connections based on their current state.
type connectionType string

const (
Ready connectionType = "ready"
Connecting connectionType = "connecting"
Idle connectionType = "idle"
Closed connectionType = "closed"
Failed connectionType = "failed"
)

// ClientConnection represents a client connection to a Cassandra node. It holds detailed
// information about the connection, including the client address, connection stage, driver details,
// and various configuration options.
type ClientConnection struct {
Address string
Port int
ConnectionStage string
DriverName string
DriverVersion string
Hostname string
KeyspaceName *string
ProtocolVersion int
RequestCount int
SSLCipherSuite *string
SSLEnabled bool
SSLProtocol *string
Username string
}

// RetrieveClientConnections retrieves a list of client connections from the
// `system_views.clients` table based on the specified connection type. The function
// queries the Cassandra database for connections with a given `connection_stage` and
// scans the results into a slice of `ClientConnection` structs. It handles nullable
// fields and returns the list of connections or an error if the operation fails.
func (s *Session) RetrieveClientConnections(connectionType connectionType) ([]*ClientConnection, error) {
const stmt = `
SELECT address, port, connection_stage, driver_name, driver_version,
hostname, keyspace_name, protocol_version, request_count,
ssl_cipher_suite, ssl_enabled, ssl_protocol, username
FROM system_views.clients
WHERE connection_stage = ?`

iter := s.control.query(stmt, connectionType)
if iter.NumRows() == 0 {
return nil, ErrConnectionsDoNotExist
}
defer iter.Close()

var connections []*ClientConnection
for {
conn := &ClientConnection{}

// Variables to hold nullable fields
var keyspaceName, sslCipherSuite, sslProtocol *string

if !iter.Scan(
&conn.Address,
&conn.Port,
&conn.ConnectionStage,
&conn.DriverName,
&conn.DriverVersion,
&conn.Hostname,
&keyspaceName,
&conn.ProtocolVersion,
&conn.RequestCount,
&sslCipherSuite,
&conn.SSLEnabled,
&sslProtocol,
&conn.Username,
) {
if err := iter.Close(); err != nil {
return nil, err
}
break
}

conn.KeyspaceName = keyspaceName
conn.SSLCipherSuite = sslCipherSuite
conn.SSLProtocol = sslProtocol

connections = append(connections, conn)
}

return connections, nil
}

type hostMetrics struct {
// Attempts is count of how many times this query has been attempted for this host.
// An attempt is either a retry or fetching next page of results.
Expand Down Expand Up @@ -2279,16 +2368,17 @@ func (e Error) Error() string {
}

var (
ErrNotFound = errors.New("not found")
ErrUnavailable = errors.New("unavailable")
ErrUnsupported = errors.New("feature not supported")
ErrTooManyStmts = errors.New("too many statements")
ErrUseStmt = errors.New("use statements aren't supported. Please see https://github.com/apache/cassandra-gocql-driver for explanation.")
ErrSessionClosed = errors.New("session has been closed")
ErrNoConnections = errors.New("gocql: no hosts available in the pool")
ErrNoKeyspace = errors.New("no keyspace provided")
ErrKeyspaceDoesNotExist = errors.New("keyspace does not exist")
ErrNoMetadata = errors.New("no metadata available")
ErrNotFound = errors.New("not found")
ErrUnavailable = errors.New("unavailable")
ErrUnsupported = errors.New("feature not supported")
ErrTooManyStmts = errors.New("too many statements")
ErrUseStmt = errors.New("use statements aren't supported. Please see https://github.com/gocql/gocql for explanation.")
ErrSessionClosed = errors.New("session has been closed")
ErrNoConnections = errors.New("gocql: no hosts available in the pool")
ErrNoKeyspace = errors.New("no keyspace provided")
ErrKeyspaceDoesNotExist = errors.New("keyspace does not exist")
ErrConnectionsDoNotExist = errors.New("connections do not exist")
ErrNoMetadata = errors.New("no metadata available")
)

type ErrProtocol struct{ error }
Expand Down
74 changes: 74 additions & 0 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,77 @@ func TestIsUseStatement(t *testing.T) {
}
}
}

func TestRetrieveClientConnections(t *testing.T) {
testCases := []struct {
name string
connectionType connectionType
expectedResult []*ClientConnection
expectError bool
}{
{
name: "Valid ready connections",
connectionType: Ready,
expectedResult: []*ClientConnection{
{
Address: "127.0.0.1",
Port: 9042,
ConnectionStage: "ready",
DriverName: "gocql",
DriverVersion: "v1.0.0",
Hostname: "localhost",
KeyspaceName: nil,
ProtocolVersion: 4,
RequestCount: 10,
SSLCipherSuite: nil,
SSLEnabled: true,
SSLProtocol: nil,
Username: "user1",
},
},
expectError: false,
},
{
name: "No connections found",
connectionType: Closed,
expectedResult: nil,
expectError: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session := &Session{
control: &controlConn{},
}

results, err := session.RetrieveClientConnections(tc.connectionType)

if tc.expectError {
if err == nil {
t.Fatalf("expected an error but got none")
}
} else {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !compareClientConnections(results, tc.expectedResult) {
t.Fatalf("expected result %+v, got %+v", tc.expectedResult, results)
}
}
})
}
}

// Helper function to compare two slices of ClientConnection pointers
func compareClientConnections(a, b []*ClientConnection) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if *a[i] != *b[i] {
return false
}
}
return true
}