Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CASSGO-4 Support of sending queries to the specific node #1793

Open
wants to merge 1 commit into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3253,7 +3253,6 @@ func TestUnsetColBatch(t *testing.T) {
}
var id, mInt, count int
var mText string

if err := session.Query("SELECT count(*) FROM gocql_test.batchUnsetInsert;").Scan(&count); err != nil {
t.Fatalf("Failed to select with err: %v", err)
} else if count != 2 {
Expand Down Expand Up @@ -3288,3 +3287,33 @@ func TestQuery_NamedValues(t *testing.T) {
t.Fatal(err)
}
}

func TestQuery_SetHost(t *testing.T) {
session := createSession(t)
defer session.Close()

hosts, err := session.GetHosts()
if err != nil {
t.Fatal(err)
}

for _, expectedHost := range hosts {
const iterations = 5
for i := 0; i < iterations; i++ {
var actualHostID string
err := session.Query("SELECT host_id FROM system.local").
SetHost(expectedHost).
Scan(&actualHostID)
if err != nil {
t.Fatal(err)
}

if expectedHost.HostID() != actualHostID {
t.Fatalf("Expected query to be executed on host %s, but it was executed on %s",
expectedHost.HostID(),
actualHostID,
)
}
}
}
}
23 changes: 22 additions & 1 deletion query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,28 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S
}

func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
hostIter := q.policy.Pick(qry)
type hostGetter interface {
getHost() *HostInfo
}

var hostIter NextHost
// checking if the qry implements hostGetter interface
if hostGetter, ok := qry.(hostGetter); ok {
// checking if the host is specified for the query,
// if it is, the query should be executed at the specified host
if host := hostGetter.getHost(); host != nil {
hostIter = func() SelectedHost {
return (*selectedHost)(host)
}
}
}

// if host is not specified for the query,
// or it doesn't implement hostGetter interface,
// then a host will be picked by HostSelectionPolicy
if hostIter == nil {
hostIter = q.policy.Pick(qry)
}

// check if the query is not marked as idempotent, if
// it is, we force the policy to NonSpeculative
Expand Down
24 changes: 24 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,10 @@ type Query struct {

// routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex.
routingInfo *queryRoutingInfo

// host specifies the host on which the query should be executed.
// If it is nil, then the host is picked by HostSelectionPolicy
host *HostInfo
}

type queryRoutingInfo struct {
Expand Down Expand Up @@ -1423,6 +1427,17 @@ func (q *Query) releaseAfterExecution() {
q.decRefCount()
}

// SetHosts allows to define on which host the query should be executed.
// If host == nil, then the HostSelectionPolicy will be used to pick a host
func (q *Query) SetHost(host *HostInfo) *Query {
q.host = host
return q
}

func (q *Query) getHost() *HostInfo {
return q.host
}

// Iter represents an iterator that can be used to iterate over all rows that
// were returned by a query. The iterator might send additional queries to the
// database during the iteration if paging was enabled.
Expand Down Expand Up @@ -2174,6 +2189,15 @@ func (t *traceWriter) Trace(traceId []byte) {
}
}

// GetHosts returns a list of hosts found via queries to system.local and system.peers
func (s *Session) GetHosts() ([]*HostInfo, error) {
hosts, _, err := s.hostSource.GetHosts()
if err != nil {
return nil, err
}
return hosts, nil
}

type ObservedQuery struct {
Keyspace string
Statement string
Expand Down