From 0444ee3479b4ca9d469148b51f1e7261051a558f Mon Sep 17 00:00:00 2001 From: Bohdan Siryk Date: Fri, 2 Aug 2024 15:48:17 +0300 Subject: [PATCH] Support of sending queries to the specific node --- cassandra_test.go | 31 ++++++++++++++++++++++++++++++- query_executor.go | 23 ++++++++++++++++++++++- session.go | 24 ++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 2 deletions(-) diff --git a/cassandra_test.go b/cassandra_test.go index 797a7cf7f..8c1439a81 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -3253,7 +3253,6 @@ func TestUnsetColBatch(t *testing.T) { } var id, mInt, count int var mText string - if err := session.Query("SELECT count(*) FROM gocql_test.batchUnsetInsert;").Scan(&count); err != nil { t.Fatalf("Failed to select with err: %v", err) } else if count != 2 { @@ -3288,3 +3287,33 @@ func TestQuery_NamedValues(t *testing.T) { t.Fatal(err) } } + +func TestQuery_SetHost(t *testing.T) { + session := createSession(t) + defer session.Close() + + hosts, err := session.GetHosts() + if err != nil { + t.Fatal(err) + } + + for _, expectedHost := range hosts { + const iterations = 5 + for i := 0; i < iterations; i++ { + var actualHostID string + err := session.Query("SELECT host_id FROM system.local"). + SetHost(expectedHost). + Scan(&actualHostID) + if err != nil { + t.Fatal(err) + } + + if expectedHost.HostID() != actualHostID { + t.Fatalf("Expected query to be executed on host %s, but it was executed on %s", + expectedHost.HostID(), + actualHostID, + ) + } + } + } +} diff --git a/query_executor.go b/query_executor.go index fb68b07f2..869c6259e 100644 --- a/query_executor.go +++ b/query_executor.go @@ -83,7 +83,28 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S } func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { - hostIter := q.policy.Pick(qry) + type hostGetter interface { + getHost() *HostInfo + } + + var hostIter NextHost + // checking if the qry implements hostGetter interface + if hostGetter, ok := qry.(hostGetter); ok { + // checking if the host is specified for the query, + // if it is, the query should be executed at the specified host + if host := hostGetter.getHost(); host != nil { + hostIter = func() SelectedHost { + return (*selectedHost)(host) + } + } + } + + // if host is not specified for the query, + // or it doesn't implement hostGetter interface, + // then a host will be picked by HostSelectionPolicy + if hostIter == nil { + hostIter = q.policy.Pick(qry) + } // check if the query is not marked as idempotent, if // it is, we force the policy to NonSpeculative diff --git a/session.go b/session.go index a600b95f3..8f5fe2713 100644 --- a/session.go +++ b/session.go @@ -936,6 +936,10 @@ type Query struct { // routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex. routingInfo *queryRoutingInfo + + // host specifies the host on which the query should be executed. + // If it is nil, then the host is picked by HostSelectionPolicy + host *HostInfo } type queryRoutingInfo struct { @@ -1423,6 +1427,17 @@ func (q *Query) releaseAfterExecution() { q.decRefCount() } +// SetHosts allows to define on which host the query should be executed. +// If host == nil, then the HostSelectionPolicy will be used to pick a host +func (q *Query) SetHost(host *HostInfo) *Query { + q.host = host + return q +} + +func (q *Query) getHost() *HostInfo { + return q.host +} + // Iter represents an iterator that can be used to iterate over all rows that // were returned by a query. The iterator might send additional queries to the // database during the iteration if paging was enabled. @@ -2174,6 +2189,15 @@ func (t *traceWriter) Trace(traceId []byte) { } } +// GetHosts returns a list of hosts found via queries to system.local and system.peers +func (s *Session) GetHosts() ([]*HostInfo, error) { + hosts, _, err := s.hostSource.GetHosts() + if err != nil { + return nil, err + } + return hosts, nil +} + type ObservedQuery struct { Keyspace string Statement string