Skip to content

Commit

Permalink
issue-1774: Function to retrieve info about client connections was added
Browse files Browse the repository at this point in the history
  • Loading branch information
testisnullus committed Aug 19, 2024
1 parent 34fdeeb commit b2e98d5
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 10 deletions.
110 changes: 100 additions & 10 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,95 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{})
return applied, iter, iter.err
}

// connectionType is a custom type that represents the different stages
// of a client connection in a Cassandra cluster. It is used to filter and categorize
// connections based on their current state.
type connectionType string

const (
Ready connectionType = "ready"
Connecting connectionType = "connecting"
Idle connectionType = "idle"
Closed connectionType = "closed"
Failed connectionType = "failed"
)

// ClientConnection represents a client connection to a Cassandra node. It holds detailed
// information about the connection, including the client address, connection stage, driver details,
// and various configuration options.
type ClientConnection struct {
Address string
Port int
ConnectionStage string
DriverName string
DriverVersion string
Hostname string
KeyspaceName *string
ProtocolVersion int
RequestCount int
SSLCipherSuite *string
SSLEnabled bool
SSLProtocol *string
Username string
}

// RetrieveClientConnections retrieves a list of client connections from the
// `system_views.clients` table based on the specified connection type. The function
// queries the Cassandra database for connections with a given `connection_stage` and
// scans the results into a slice of `ClientConnection` structs. It handles nullable
// fields and returns the list of connections or an error if the operation fails.
func (s *Session) RetrieveClientConnections(connectionType connectionType) ([]*ClientConnection, error) {
const stmt = `
SELECT address, port, connection_stage, driver_name, driver_version,
hostname, keyspace_name, protocol_version, request_count,
ssl_cipher_suite, ssl_enabled, ssl_protocol, username
FROM system_views.clients
WHERE connection_stage = ?`

iter := s.control.query(stmt, connectionType)
if iter.NumRows() == 0 {
return nil, ErrConnectionsDoNotExist
}
defer iter.Close()

var connections []*ClientConnection
for {
conn := &ClientConnection{}

// Variables to hold nullable fields
var keyspaceName, sslCipherSuite, sslProtocol *string

if !iter.Scan(
&conn.Address,
&conn.Port,
&conn.ConnectionStage,
&conn.DriverName,
&conn.DriverVersion,
&conn.Hostname,
&keyspaceName,
&conn.ProtocolVersion,
&conn.RequestCount,
&sslCipherSuite,
&conn.SSLEnabled,
&sslProtocol,
&conn.Username,
) {
if err := iter.Close(); err != nil {
return nil, err
}
break
}

conn.KeyspaceName = keyspaceName
conn.SSLCipherSuite = sslCipherSuite
conn.SSLProtocol = sslProtocol

connections = append(connections, conn)
}

return connections, nil
}

type hostMetrics struct {
// Attempts is count of how many times this query has been attempted for this host.
// An attempt is either a retry or fetching next page of results.
Expand Down Expand Up @@ -2259,16 +2348,17 @@ func (e Error) Error() string {
}

var (
ErrNotFound = errors.New("not found")
ErrUnavailable = errors.New("unavailable")
ErrUnsupported = errors.New("feature not supported")
ErrTooManyStmts = errors.New("too many statements")
ErrUseStmt = errors.New("use statements aren't supported. Please see https://github.com/gocql/gocql for explanation.")
ErrSessionClosed = errors.New("session has been closed")
ErrNoConnections = errors.New("gocql: no hosts available in the pool")
ErrNoKeyspace = errors.New("no keyspace provided")
ErrKeyspaceDoesNotExist = errors.New("keyspace does not exist")
ErrNoMetadata = errors.New("no metadata available")
ErrNotFound = errors.New("not found")
ErrUnavailable = errors.New("unavailable")
ErrUnsupported = errors.New("feature not supported")
ErrTooManyStmts = errors.New("too many statements")
ErrUseStmt = errors.New("use statements aren't supported. Please see https://github.com/gocql/gocql for explanation.")
ErrSessionClosed = errors.New("session has been closed")
ErrNoConnections = errors.New("gocql: no hosts available in the pool")
ErrNoKeyspace = errors.New("no keyspace provided")
ErrKeyspaceDoesNotExist = errors.New("keyspace does not exist")
ErrConnectionsDoNotExist = errors.New("connections do not exist")
ErrNoMetadata = errors.New("no metadata available")
)

type ErrProtocol struct{ error }
Expand Down
74 changes: 74 additions & 0 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,77 @@ func TestIsUseStatement(t *testing.T) {
}
}
}

func TestRetrieveClientConnections(t *testing.T) {
testCases := []struct {
name string
connectionType connectionType
expectedResult []*ClientConnection
expectError bool
}{
{
name: "Valid ready connections",
connectionType: Ready,
expectedResult: []*ClientConnection{
{
Address: "127.0.0.1",
Port: 9042,
ConnectionStage: "ready",
DriverName: "gocql",
DriverVersion: "v1.0.0",
Hostname: "localhost",
KeyspaceName: nil,
ProtocolVersion: 4,
RequestCount: 10,
SSLCipherSuite: nil,
SSLEnabled: true,
SSLProtocol: nil,
Username: "user1",
},
},
expectError: false,
},
{
name: "No connections found",
connectionType: Closed,
expectedResult: nil,
expectError: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session := &Session{
control: &controlConn{},
}

results, err := session.RetrieveClientConnections(tc.connectionType)

if tc.expectError {
if err == nil {
t.Fatalf("expected an error but got none")
}
} else {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !compareClientConnections(results, tc.expectedResult) {
t.Fatalf("expected result %+v, got %+v", tc.expectedResult, results)
}
}
})
}
}

// Helper function to compare two slices of ClientConnection pointers
func compareClientConnections(a, b []*ClientConnection) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if *a[i] != *b[i] {
return false
}
}
return true
}

0 comments on commit b2e98d5

Please sign in to comment.