diff --git a/session.go b/session.go index a600b95f3..314b3006e 100644 --- a/session.go +++ b/session.go @@ -800,6 +800,95 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{}) return applied, iter, iter.err } +// connectionType is a custom type that represents the different stages +// of a client connection in a Cassandra cluster. It is used to filter and categorize +// connections based on their current state. +type connectionType string + +const ( + Ready connectionType = "ready" + Connecting connectionType = "connecting" + Idle connectionType = "idle" + Closed connectionType = "closed" + Failed connectionType = "failed" +) + +// ClientConnection represents a client connection to a Cassandra node. It holds detailed +// information about the connection, including the client address, connection stage, driver details, +// and various configuration options. +type ClientConnection struct { + Address string + Port int + ConnectionStage string + DriverName string + DriverVersion string + Hostname string + KeyspaceName *string + ProtocolVersion int + RequestCount int + SSLCipherSuite *string + SSLEnabled bool + SSLProtocol *string + Username string +} + +// RetrieveClientConnections retrieves a list of client connections from the +// `system_views.clients` table based on the specified connection type. The function +// queries the Cassandra database for connections with a given `connection_stage` and +// scans the results into a slice of `ClientConnection` structs. It handles nullable +// fields and returns the list of connections or an error if the operation fails. +func (s *Session) RetrieveClientConnections(connectionType connectionType) ([]*ClientConnection, error) { + const stmt = ` + SELECT address, port, connection_stage, driver_name, driver_version, + hostname, keyspace_name, protocol_version, request_count, + ssl_cipher_suite, ssl_enabled, ssl_protocol, username + FROM system_views.clients + WHERE connection_stage = ?` + + iter := s.control.query(stmt, connectionType) + if iter.NumRows() == 0 { + return nil, ErrConnectionsDoNotExist + } + defer iter.Close() + + var connections []*ClientConnection + for { + conn := &ClientConnection{} + + // Variables to hold nullable fields + var keyspaceName, sslCipherSuite, sslProtocol *string + + if !iter.Scan( + &conn.Address, + &conn.Port, + &conn.ConnectionStage, + &conn.DriverName, + &conn.DriverVersion, + &conn.Hostname, + &keyspaceName, + &conn.ProtocolVersion, + &conn.RequestCount, + &sslCipherSuite, + &conn.SSLEnabled, + &sslProtocol, + &conn.Username, + ) { + if err := iter.Close(); err != nil { + return nil, err + } + break + } + + conn.KeyspaceName = keyspaceName + conn.SSLCipherSuite = sslCipherSuite + conn.SSLProtocol = sslProtocol + + connections = append(connections, conn) + } + + return connections, nil +} + type hostMetrics struct { // Attempts is count of how many times this query has been attempted for this host. // An attempt is either a retry or fetching next page of results. @@ -2279,16 +2368,17 @@ func (e Error) Error() string { } var ( - ErrNotFound = errors.New("not found") - ErrUnavailable = errors.New("unavailable") - ErrUnsupported = errors.New("feature not supported") - ErrTooManyStmts = errors.New("too many statements") - ErrUseStmt = errors.New("use statements aren't supported. Please see https://github.com/apache/cassandra-gocql-driver for explanation.") - ErrSessionClosed = errors.New("session has been closed") - ErrNoConnections = errors.New("gocql: no hosts available in the pool") - ErrNoKeyspace = errors.New("no keyspace provided") - ErrKeyspaceDoesNotExist = errors.New("keyspace does not exist") - ErrNoMetadata = errors.New("no metadata available") + ErrNotFound = errors.New("not found") + ErrUnavailable = errors.New("unavailable") + ErrUnsupported = errors.New("feature not supported") + ErrTooManyStmts = errors.New("too many statements") + ErrUseStmt = errors.New("use statements aren't supported. Please see https://github.com/gocql/gocql for explanation.") + ErrSessionClosed = errors.New("session has been closed") + ErrNoConnections = errors.New("gocql: no hosts available in the pool") + ErrNoKeyspace = errors.New("no keyspace provided") + ErrKeyspaceDoesNotExist = errors.New("keyspace does not exist") + ErrConnectionsDoNotExist = errors.New("connections do not exist") + ErrNoMetadata = errors.New("no metadata available") ) type ErrProtocol struct{ error } diff --git a/session_test.go b/session_test.go index 0319a8a4c..95a3cf44e 100644 --- a/session_test.go +++ b/session_test.go @@ -347,3 +347,77 @@ func TestIsUseStatement(t *testing.T) { } } } + +func TestRetrieveClientConnections(t *testing.T) { + testCases := []struct { + name string + connectionType connectionType + expectedResult []*ClientConnection + expectError bool + }{ + { + name: "Valid ready connections", + connectionType: Ready, + expectedResult: []*ClientConnection{ + { + Address: "127.0.0.1", + Port: 9042, + ConnectionStage: "ready", + DriverName: "gocql", + DriverVersion: "v1.0.0", + Hostname: "localhost", + KeyspaceName: nil, + ProtocolVersion: 4, + RequestCount: 10, + SSLCipherSuite: nil, + SSLEnabled: true, + SSLProtocol: nil, + Username: "user1", + }, + }, + expectError: false, + }, + { + name: "No connections found", + connectionType: Closed, + expectedResult: nil, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + session := &Session{ + control: &controlConn{}, + } + + results, err := session.RetrieveClientConnections(tc.connectionType) + + if tc.expectError { + if err == nil { + t.Fatalf("expected an error but got none") + } + } else { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !compareClientConnections(results, tc.expectedResult) { + t.Fatalf("expected result %+v, got %+v", tc.expectedResult, results) + } + } + }) + } +} + +// Helper function to compare two slices of ClientConnection pointers +func compareClientConnections(a, b []*ClientConnection) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if *a[i] != *b[i] { + return false + } + } + return true +}