Skip to content

Commit

Permalink
refactor(generator): generators use context.Context instead of stopflag
Browse files Browse the repository at this point in the history
In the process of removing the `stopFlag` from gemini's codebase,
the first step is to migrate the `Value Generators` for partitions.
Using context with generators makes a lot more sense than the
custom-built `stopFlag`. `context` is a built-in package in Go, and
this is its use case - cancellation propagation to background tasks.

Signed-off-by: Dusan Malusev <[email protected]>
  • Loading branch information
CodeLieutenant committed Nov 27, 2024
1 parent 7c5dda0 commit f162158
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 159 deletions.
57 changes: 0 additions & 57 deletions cmd/gemini/generators.go

This file was deleted.

8 changes: 5 additions & 3 deletions cmd/gemini/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,12 +269,14 @@ func run(_ *cobra.Command, _ []string) error {
stop.StartOsSignalsTransmitter(logger, stopFlag, warmupStopFlag)
pump := jobs.NewPump(stopFlag, logger)

gens, err := createGenerators(schema, schemaConfig, intSeed, partitionCount, logger)
distFunc, err := createDistributionFunc(partitionKeyDistribution, partitionCount, intSeed, stdDistMean, oneStdDev)
if err != nil {
return err
return errors.Wrapf(err, "Faile to create distribution function: %s", partitionKeyDistribution)
}
gens.StartAll(stopFlag)

gens := generators.New(ctx, schema, distFunc, schemaConfig.GetPartitionRangeConfig(), intSeed, partitionCount, pkBufferReuseSize, logger)
defer utils.IgnoreError(gens.Close)

if !nonInteractive {
sp := createSpinner(interactive())
ticker := time.NewTicker(time.Second)
Expand Down
107 changes: 63 additions & 44 deletions pkg/generators/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
package generators

import (
"context"
"github.com/pkg/errors"
"go.uber.org/zap"
"golang.org/x/exp/rand"

"github.com/scylladb/gemini/pkg/routingkey"
"github.com/scylladb/gemini/pkg/stop"
"github.com/scylladb/gemini/pkg/typedef"
)

Expand All @@ -37,7 +37,7 @@ type TokenIndex uint64

type DistributionFunc func() TokenIndex

type GeneratorInterface interface {
type Interface interface {
Get() *typedef.ValueWithToken
GetOld() *typedef.ValueWithToken
GiveOld(_ *typedef.ValueWithToken)
Expand All @@ -64,14 +64,6 @@ func (g *Generator) PartitionCount() uint64 {
return g.partitionCount
}

// Generators is a list of generators that can be started together.
type Generators []*Generator

// StartAll calls Start on every generator in the list, in order, handing
// each one the same stop flag.
func (g Generators) StartAll(stopFlag *stop.Flag) {
	for i := range g {
		g[i].Start(stopFlag)
	}
}

type Config struct {
PartitionsDistributionFunc DistributionFunc
PartitionsRangeConfig typedef.PartitionRangeConfig
Expand All @@ -80,9 +72,9 @@ type Config struct {
PkUsedBufferSize uint64
}

func NewGenerator(table *typedef.Table, config *Config, logger *zap.Logger) *Generator {
func NewGenerator(table *typedef.Table, config Config, logger *zap.Logger) Generator {
wakeUpSignal := make(chan struct{})
return &Generator{
return Generator{
partitions: NewPartitions(int(config.PartitionsCount), int(config.PkUsedBufferSize), wakeUpSignal),
partitionCount: config.PartitionsCount,
table: table,
Expand Down Expand Up @@ -135,47 +127,41 @@ func (g *Generator) ReleaseToken(token uint64) {
g.GetPartitionForToken(TokenIndex(token)).releaseToken(token)
}

func (g *Generator) Start(stopFlag *stop.Flag) {
go func() {
g.logger.Info("starting partition key generation loop")
defer g.partitions.CloseAll()
for {
g.fillAllPartitions(stopFlag)
select {
case <-stopFlag.SignalChannel():
g.logger.Debug("stopping partition key generation loop",
zap.Uint64("keys_created", g.cntCreated),
zap.Uint64("keys_emitted", g.cntEmitted))
return
case <-g.wakeUpSignal:
}
func (g *Generator) Start(ctx context.Context) {
defer g.partitions.Close()
g.logger.Info("starting partition key generation loop")
for {
g.fillAllPartitions(ctx)
select {
case <-ctx.Done():
g.logger.Debug("stopping partition key generation loop",
zap.Uint64("keys_created", g.cntCreated),
zap.Uint64("keys_emitted", g.cntEmitted))
return
case <-g.wakeUpSignal:
}
}()
}
}

func (g *Generator) FindAndMarkStalePartitions() {
r := rand.New(rand.NewSource(10))
nonStale := make([]bool, g.partitionCount)
for n := uint64(0); n < g.partitionCount*100; n++ {
values := CreatePartitionKeyValues(g.table, r, &g.partitionsConfig)
token, err := g.routingKeyCreator.GetHash(g.table, values)

for range g.partitionCount * 100 {
token, _, err := g.createPartitionKeyValues(r)
if err != nil {
g.logger.Panic(errors.Wrap(err, "failed to get primary key hash").Error())
g.logger.Panic("failed to get primary key hash", zap.Error(err))
}
nonStale[g.shardOf(token)] = true
}

for idx, val := range nonStale {
if !val {
g.partitions[idx].MarkStale()
if err = g.partition(token).MarkStale(); err != nil {
g.logger.Panic("failed to mark partition as stale", zap.Error(err))
}
}
}

// fillAllPartitions guarantees that each partition was tested to be full
// at least once since the function started and before it ended.
// In other words no partition will be starved.
func (g *Generator) fillAllPartitions(stopFlag *stop.Flag) {
func (g *Generator) fillAllPartitions(ctx context.Context) {
pFilled := make([]bool, len(g.partitions))
allFilled := func() bool {
for idx, filled := range pFilled {
Expand All @@ -188,22 +174,30 @@ func (g *Generator) fillAllPartitions(stopFlag *stop.Flag) {
}
return true
}
for !stopFlag.IsHardOrSoft() {
values := CreatePartitionKeyValues(g.table, g.r, &g.partitionsConfig)
token, err := g.routingKeyCreator.GetHash(g.table, values)

for {
select {
case <-ctx.Done():
return
default:
}

token, values, err := g.createPartitionKeyValues()
if err != nil {
g.logger.Panic(errors.Wrap(err, "failed to get primary key hash").Error())
g.logger.Panic("failed to get primary key hash", zap.Error(err))
}
g.cntCreated++
idx := token % g.partitionCount
partition := g.partitions[idx]

partition := g.partition(token)
if partition.Stale() || partition.inFlight.Has(token) {
continue
}

select {
case partition.values <- &typedef.ValueWithToken{Token: token, Value: values}:
g.cntEmitted++
default:
idx := g.shardOf(token)
if !pFilled[idx] {
pFilled[idx] = true
if allFilled() {
Expand All @@ -217,3 +211,28 @@ func (g *Generator) fillAllPartitions(stopFlag *stop.Flag) {
// shardOf returns the partition index a token belongs to: the token modulo
// the generator's partition count.
func (g *Generator) shardOf(token uint64) int {
	idx := token % g.partitionCount
	return int(idx)
}

// partition returns the partition that owns the given token, using shardOf
// to pick the slot in g.partitions.
func (g *Generator) partition(token uint64) *Partition {
	idx := g.shardOf(token)
	return g.partitions[idx]
}

// createPartitionKeyValues generates one set of values for the table's
// partition keys and hashes it into a token.
//
// A random source may be supplied as the first variadic argument; when it is
// missing or nil the generator's own source g.r is used. Any further
// arguments are ignored.
//
// It returns the token and the generated values, or a zero token, nil values
// and a wrapped error when hashing fails.
func (g *Generator) createPartitionKeyValues(r ...*rand.Rand) (uint64, []any, error) {
	var rnd *rand.Rand
	if len(r) == 0 || r[0] == nil {
		rnd = g.r
	} else {
		rnd = r[0]
	}

	// Capacity is sized up front from the table's own count of partition key values.
	values := make([]any, 0, g.table.PartitionKeysLenValues())
	for _, column := range g.table.PartitionKeys {
		generated := column.Type.GenValue(rnd, &g.partitionsConfig)
		values = append(values, generated...)
	}

	token, hashErr := g.routingKeyCreator.GetHash(g.table, values)
	if hashErr == nil {
		return token, values, nil
	}

	return 0, nil, errors.Wrap(hashErr, "failed to get primary key hash")
}
84 changes: 84 additions & 0 deletions pkg/generators/generators.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright 2019 ScyllaDB
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package generators

import (
"context"
"math"
"sync"

"go.uber.org/zap"

"github.com/scylladb/gemini/pkg/typedef"
)

// Generators bundles the per-table generators created by New with the
// handles needed to stop their background loops.
type Generators struct {
	// Generators holds one Generator per table, in schema.Tables order.
	Generators []Generator
	// wg tracks the goroutines running Generator.Start; Close waits on it.
	wg *sync.WaitGroup
	// cancel cancels the context passed to every Generator.Start call.
	cancel context.CancelFunc
}

// New builds one Generator per table in schema and starts a background
// partition key generation loop for each of them.
//
// The loops run until ctx is cancelled or Close is called on the returned
// Generators; Close additionally waits for every loop to return.
//
// Parameters:
//   - ctx: parent context; a cancellable child is derived from it so Close
//     can stop the loops on its own.
//   - schema: every entry of schema.Tables gets its own generator.
//   - distFunc, partitionRangeConfig, seed, distributionSize,
//     pkBufferReuseSize: copied into the Config shared by all generators
//     (distributionSize becomes Config.PartitionsCount).
//   - logger: parent logger; each generator receives a child logger named
//     "generators-<table name>".
func New(
	ctx context.Context,
	schema *typedef.Schema,
	distFunc DistributionFunc,
	partitionRangeConfig typedef.PartitionRangeConfig,
	seed, distributionSize, pkBufferReuseSize uint64,
	logger *zap.Logger,
) *Generators {
	cfg := Config{
		PartitionsRangeConfig:      partitionRangeConfig,
		PartitionsCount:            distributionSize,
		PartitionsDistributionFunc: distFunc,
		Seed:                       seed,
		PkUsedBufferSize:           pkBufferReuseSize,
	}

	ctx, cancel := context.WithCancel(ctx)
	wg := new(sync.WaitGroup)

	// The capacity equals the number of appends below, so the backing array
	// is never reallocated and pointers to its elements stay valid.
	gs := make([]Generator, 0, len(schema.Tables))

	for _, table := range schema.Tables {
		// Store the generator first and work through a pointer to the stored
		// element, so the background loop and the returned slice refer to
		// the same Generator instead of two diverging copies.
		gs = append(gs, NewGenerator(table, cfg, logger.Named("generators-"+table.Name)))
		g := &gs[len(gs)-1]

		if table.PartitionKeys.ValueVariationsNumber(&partitionRangeConfig) < math.MaxUint32 {
			// Low partition key variation can lead to having stale partitions.
			// Detect and mark them before the generation loop starts: both
			// code paths use the generator's partitions and routing key
			// creator, so they must not run concurrently.
			g.FindAndMarkStalePartitions()
		}

		wg.Add(1)
		go func() {
			defer wg.Done()
			g.Start(ctx)
		}()
	}

	return &Generators{
		Generators: gs,
		wg:         wg,
		cancel:     cancel,
	}
}

// Close cancels the context shared by the generator loops and then blocks
// until every goroutine tracked by the wait group has returned.
// It always returns nil.
func (g *Generators) Close() error {
	stop, workers := g.cancel, g.wg

	stop()
	workers.Wait()

	return nil
}
Loading

0 comments on commit f162158

Please sign in to comment.