diff --git a/cmd/optimizely/main.go b/cmd/optimizely/main.go index a6edaec3..d58f95b4 100644 --- a/cmd/optimizely/main.go +++ b/cmd/optimizely/main.go @@ -267,7 +267,7 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) // Create default service context sg := server.NewGroup(ctx, conf.Server) // Create a new server group to manage the individual http listeners - optlyCache := optimizely.NewCache(ctx, conf.Client, sdkMetricsRegistry) + optlyCache := optimizely.NewCache(ctx, *conf, sdkMetricsRegistry) optlyCache.Init(conf.SDKKeys) // goroutine to check for signals to gracefully shutdown listeners @@ -281,7 +281,7 @@ func main() { cancel() }() - apiRouter := routers.NewDefaultAPIRouter(optlyCache, conf.API, agentMetricsRegistry) + apiRouter := routers.NewDefaultAPIRouter(optlyCache, *conf, agentMetricsRegistry) adminRouter := routers.NewAdminRouter(*conf) log.Info().Str("version", conf.Version).Msg("Starting services.") diff --git a/config.yaml b/config.yaml index f143f7ba..87b872a0 100644 --- a/config.yaml +++ b/config.yaml @@ -243,3 +243,17 @@ runtime: ## To just read the current rate, pass rate < 0. ## (For n>1 the details of sampling may change.) mutexProfileFraction: 0 + +## synchronization should be enabled when multiple replicas of agent is deployed +## if notification synchronization is enabled, then the active notification event-stream API +## will get the notifications from multiple replicas +synchronization: + pubsub: + redis: + host: "redis.demo.svc:6379" + password: "" + database: 0 + channel: "optimizely-sync" + notification: + enable: false + default: "redis" diff --git a/config/config.go b/config/config.go index 19017a57..428c3992 100644 --- a/config/config.go +++ b/config/config.go @@ -126,6 +126,20 @@ func NewDefaultConfig() *AgentConfig { Webhook: WebhookConfig{ Port: "8085", }, + Synchronization: SyncConfig{ + Pubsub: map[string]interface{}{ + "redis": map[string]interface{}{ + "host": "localhost:6379", + "password": "", + "database": 0, + "channel": "optimizely-notifications", + }, + }, + Notification: NotificationConfig{ + Enable: false, + Default: "redis", + }, + }, } return &config @@ -139,14 +153,27 @@ type AgentConfig struct { SDKKeys []string `yaml:"sdkKeys" json:"sdkKeys"` - Admin AdminConfig `json:"admin"` - API APIConfig `json:"api"` - Log LogConfig `json:"log"` - Tracing TracingConfig `json:"tracing"` - Client ClientConfig `json:"client"` - Runtime RuntimeConfig `json:"runtime"` - Server ServerConfig `json:"server"` - Webhook WebhookConfig `json:"webhook"` + Admin AdminConfig `json:"admin"` + API APIConfig `json:"api"` + Log LogConfig `json:"log"` + Tracing TracingConfig `json:"tracing"` + Client ClientConfig `json:"client"` + Runtime RuntimeConfig `json:"runtime"` + Server ServerConfig `json:"server"` + Webhook WebhookConfig `json:"webhook"` + Synchronization SyncConfig `json:"synchronization"` +} + +// SyncConfig contains Synchronization configuration for the multiple Agent nodes +type SyncConfig struct { + Pubsub map[string]interface{} `json:"pubsub"` + Notification NotificationConfig `json:"notification"` +} + +// NotificationConfig contains Notification Synchronization configuration for the multiple Agent nodes +type NotificationConfig struct { + Enable bool `json:"enable"` + Default string `json:"default"` } // HTTPSDisabledWarning is logged when keyfile and certfile are not provided in server configuration diff --git a/go.mod b/go.mod index fc587b1a..fcb03a04 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/golang-jwt/jwt/v4 v4.5.0 github.com/google/uuid v1.3.0 github.com/lestrrat-go/jwx v0.9.0 - github.com/optimizely/go-sdk v1.8.4-0.20230515121609-7ffed835c991 + github.com/optimizely/go-sdk v1.8.4-0.20230911163718-b10e161e39b8 github.com/orcaman/concurrent-map v1.0.0 github.com/prometheus/client_golang v1.11.0 github.com/rakyll/statik v0.1.7 diff --git a/go.sum b/go.sum index ea29a985..ecd97249 100644 --- a/go.sum +++ b/go.sum @@ -244,8 +244,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs= -github.com/optimizely/go-sdk v1.8.4-0.20230515121609-7ffed835c991 h1:bRoRDKRa7EgSTCEb54qaDuU6IDegQKQumun8buDV/cY= -github.com/optimizely/go-sdk v1.8.4-0.20230515121609-7ffed835c991/go.mod h1:06VK8mwwQTEh7QzP+qivf16tXtXEpoeblqtlhfvWEgk= +github.com/optimizely/go-sdk v1.8.4-0.20230911163718-b10e161e39b8 h1:1LhZsu7IB7LR3PzwIzfP56cdOkUAKRXxW1wljd352sg= +github.com/optimizely/go-sdk v1.8.4-0.20230911163718-b10e161e39b8/go.mod h1:zITWqffjOXsae/Z0PlCN5kgJRgJF/0g/k8RBEsxNrxg= github.com/orcaman/concurrent-map v1.0.0 h1:I/2A2XPCb4IuQWcQhBhSwGfiuybl/J0ev9HDbW65HOY= github.com/orcaman/concurrent-map v1.0.0/go.mod h1:Lu3tH6HLW3feq74c2GC+jIMS/K2CFcDWnWD9XkenwhI= github.com/pelletier/go-toml/v2 v2.0.6 h1:nrzqCb7j9cDFj2coyLNLaZuJTLjWjlaz6nvTvIwycIU= diff --git a/pkg/handlers/notification.go b/pkg/handlers/notification.go index 55b9b7f0..5c391cf8 100644 --- a/pkg/handlers/notification.go +++ b/pkg/handlers/notification.go @@ -1,5 +1,5 @@ /**************************************************************************** - * Copyright 2020, Optimizely, Inc. and contributors * + * Copyright 2020,2023 Optimizely, Inc. and contributors * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * @@ -18,29 +18,42 @@ package handlers import ( + "context" "encoding/json" + "errors" "fmt" "net/http" "strings" + "github.com/go-redis/redis/v8" + "github.com/optimizely/agent/config" "github.com/optimizely/agent/pkg/middleware" + "github.com/optimizely/agent/pkg/syncer" "github.com/optimizely/go-sdk/pkg/notification" "github.com/optimizely/go-sdk/pkg/registry" + "github.com/rs/zerolog" +) + +const ( + LoggerKey = "notification-logger" + SDKKey = "context-sdk-key" ) // A MessageChan is a channel of bytes // Each http handler call creates a new channel and pumps decision service messages onto it. type MessageChan chan []byte +type NotificationReceiverFunc func(context.Context) (<-chan syncer.Event, error) + // types of notifications supported. -var types = map[string]notification.Type{ - string(notification.Decision): notification.Decision, - string(notification.Track): notification.Track, - string(notification.ProjectConfigUpdate): notification.ProjectConfigUpdate, +var types = map[notification.Type]string{ + notification.Decision: string(notification.Decision), + notification.Track: string(notification.Track), + notification.ProjectConfigUpdate: string(notification.ProjectConfigUpdate), } -func getFilter(filters []string) map[string]notification.Type { - notificationsToAdd := map[string]notification.Type{} +func getFilter(filters []string) map[notification.Type]string { + notificationsToAdd := make(map[notification.Type]string) // Parse out the any filters that were added if len(filters) == 0 { notificationsToAdd = types @@ -51,8 +64,8 @@ func getFilter(filters []string) map[string]notification.Type { splits := strings.Split(filter, ",") for _, split := range splits { // if the string is a valid type - if _, ok := types[split]; ok { - notificationsToAdd[split] = notification.Type(split) + if _, ok := types[notification.Type(split)]; ok { + notificationsToAdd[notification.Type(split)] = split } } } @@ -60,103 +73,198 @@ func getFilter(filters []string) map[string]notification.Type { return notificationsToAdd } -// NotificationEventSteamHandler implements the http.Handler interface. -func NotificationEventSteamHandler(w http.ResponseWriter, r *http.Request) { - // Make sure that the writer supports flushing. - flusher, ok := w.(http.Flusher) +func NotificationEventStreamHandler(notificationReceiverFn NotificationReceiverFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Make sure that the writer supports flushing. + flusher, ok := w.(http.Flusher) - if !ok { - http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) - return - } + if !ok { + http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) + return + } + + _, err := middleware.GetOptlyClient(r) + + if err != nil { + RenderError(err, http.StatusUnprocessableEntity, w, r) + return + } + + // Set the headers related to event streaming. + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + // "raw" query string option + // If provided, send raw JSON lines instead of SSE-compliant strings. + raw := len(r.URL.Query()["raw"]) > 0 - _, err := middleware.GetOptlyClient(r) + // Parse the form. + _ = r.ParseForm() - if err != nil { - RenderError(err, http.StatusUnprocessableEntity, w, r) - return + filters := r.Form["filter"] + + // Parse out the any filters that were added + notificationsToAdd := getFilter(filters) + + // Listen to connection close and un-register messageChan + notify := r.Context().Done() + + sdkKey := r.Header.Get(middleware.OptlySDKHeader) + ctx := context.WithValue(r.Context(), SDKKey, sdkKey) + + dataChan, err := notificationReceiverFn(context.WithValue(ctx, LoggerKey, middleware.GetLogger(r))) + if err != nil { + middleware.GetLogger(r).Err(err).Msg("error from receiver") + http.Error(w, "Error from data receiver!", http.StatusInternalServerError) + return + } + + for { + select { + case <-notify: + middleware.GetLogger(r).Debug().Msg("received close on the request. So, we are shutting down this handler") + return + case event := <-dataChan: + _, found := notificationsToAdd[event.Type] + if !found { + continue + } + + jsonEvent, err := json.Marshal(event.Message) + if err != nil { + middleware.GetLogger(r).Err(err).Msg("failed to marshal notification into json") + continue + } + + if raw { + // Raw JSON events, one per line + _, _ = fmt.Fprintf(w, "%s\n", string(jsonEvent)) + } else { + // Server Sent Events compatible + _, _ = fmt.Fprintf(w, "data: %s\n\n", string(jsonEvent)) + } + // Flush the data immediately instead of buffering it for later. + // The flush will fail if the connection is closed. That will cause the handler to exit. + flusher.Flush() + } + } } +} - // Set the headers related to event streaming. - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") +func DefaultNotificationReceiver(ctx context.Context) (<-chan syncer.Event, error) { + logger, ok := ctx.Value(LoggerKey).(*zerolog.Logger) + if !ok { + logger = &zerolog.Logger{} + } + + sdkKey, ok := ctx.Value(SDKKey).(string) + if !ok || sdkKey == "" { + return nil, errors.New("sdk key not found") + } // Each connection registers its own message channel with the NotificationHandler's connections registry - messageChan := make(MessageChan) - // Each connection also adds listeners - sdkKey := r.Header.Get(middleware.OptlySDKHeader) + messageChan := make(chan syncer.Event) nc := registry.GetNotificationCenter(sdkKey) - // Parse the form. - _ = r.ParseForm() - - filters := r.Form["filter"] - // Parse out the any filters that were added - notificationsToAdd := getFilter(filters) + notificationsToAdd := types ids := []struct { int notification.Type }{} - for _, value := range notificationsToAdd { - id, e := nc.AddHandler(value, func(n interface{}) { - jsonEvent, err := json.Marshal(n) - if err != nil { - middleware.GetLogger(r).Error().Msg("encoding notification to json") - } else { - messageChan <- jsonEvent + for notificationType := range notificationsToAdd { + id, e := nc.AddHandler(notificationType, func(n interface{}) { + msg := syncer.Event{ + Type: notificationType, + Message: n, } + messageChan <- msg }) if e != nil { - RenderError(e, http.StatusUnprocessableEntity, w, r) - return + return nil, e } // do defer outside the loop. ids = append(ids, struct { int notification.Type - }{id, value}) + }{id, notificationType}) } - // Remove the decision listener if we exited. - defer func() { - for _, id := range ids { - err := nc.RemoveHandler(id.int, id.Type) - if err != nil { - middleware.GetLogger(r).Error().AnErr("removing notification", err) + go func() { + for { + select { + case <-ctx.Done(): + for _, id := range ids { + err := nc.RemoveHandler(id.int, id.Type) + if err != nil { + logger.Err(err).AnErr("error in removing notification handler", err) + } + } + return } } }() - // "raw" query string option - // If provided, send raw JSON lines instead of SSE-compliant strings. - raw := len(r.Form["raw"]) > 0 - - // Listen to connection close and un-register messageChan - notify := r.Context().Done() - // block waiting or messages broadcast on this connection's messageChan - for { - select { - // Write to the ResponseWriter - case msg := <-messageChan: - if raw { - // Raw JSON events, one per line - _, _ = fmt.Fprintf(w, "%s\n", msg) - } else { - // Server Sent Events compatible - _, _ = fmt.Fprintf(w, "data: %s\n\n", msg) - } - // Flush the data immediately instead of buffering it for later. - // The flush will fail if the connection is closed. That will cause the handler to exit. - flusher.Flush() - case <-notify: - middleware.GetLogger(r).Debug().Msg("received close on the request. So, we are shutting down this handler") - return + return messageChan, nil +} + +func RedisNotificationReceiver(conf config.SyncConfig) NotificationReceiverFunc { + return func(ctx context.Context) (<-chan syncer.Event, error) { + sdkKey, ok := ctx.Value(SDKKey).(string) + if !ok || sdkKey == "" { + return nil, errors.New("sdk key not found") } - } + redisSyncer, err := syncer.NewRedisSyncer(&zerolog.Logger{}, conf, sdkKey) + if err != nil { + return nil, err + } + + client := redis.NewClient(&redis.Options{ + Addr: redisSyncer.Host, + Password: redisSyncer.Password, + DB: redisSyncer.Database, + }) + + // Subscribe to a Redis channel + pubsub := client.Subscribe(ctx, syncer.GetChannelForSDKKey(redisSyncer.Channel, sdkKey)) + + dataChan := make(chan syncer.Event) + + logger, ok := ctx.Value(LoggerKey).(*zerolog.Logger) + if !ok { + logger = &zerolog.Logger{} + } + + go func() { + for { + select { + case <-ctx.Done(): + client.Close() + pubsub.Close() + logger.Debug().Msg("context canceled, redis notification receiver is closed") + return + default: + msg, err := pubsub.ReceiveMessage(ctx) + if err != nil { + logger.Err(err).Msg("failed to receive message from redis") + continue + } + + var event syncer.Event + if err := json.Unmarshal([]byte(msg.Payload), &event); err != nil { + logger.Err(err).Msg("failed to unmarshal redis message") + continue + } + dataChan <- event + } + } + }() + + return dataChan, nil + } } diff --git a/pkg/handlers/notification_test.go b/pkg/handlers/notification_test.go index 595c87b9..82d0b95e 100644 --- a/pkg/handlers/notification_test.go +++ b/pkg/handlers/notification_test.go @@ -19,20 +19,23 @@ package handlers import ( "context" + "errors" + "fmt" "net/http" "net/http/httptest" + "reflect" "testing" "time" - "github.com/optimizely/go-sdk/pkg/notification" - "github.com/optimizely/go-sdk/pkg/registry" - + "github.com/go-chi/chi/v5" + "github.com/optimizely/agent/config" "github.com/optimizely/agent/pkg/middleware" "github.com/optimizely/agent/pkg/optimizely" "github.com/optimizely/agent/pkg/optimizely/optimizelytest" - - "github.com/go-chi/chi/v5" + "github.com/optimizely/agent/pkg/syncer" "github.com/optimizely/go-sdk/pkg/entities" + "github.com/optimizely/go-sdk/pkg/notification" + "github.com/optimizely/go-sdk/pkg/registry" "github.com/stretchr/testify/suite" ) @@ -66,13 +69,15 @@ func (suite *NotificationTestSuite) SetupTest() { EventStreamMW := &NotificationMW{optlyClient} mux.Use(EventStreamMW.ClientCtx) - mux.Get("/notifications/event-stream", NotificationEventSteamHandler) suite.mux = mux suite.tc = testClient } func (suite *NotificationTestSuite) TestFeatureTestFilter() { + conf := config.NewDefaultConfig() + suite.mux.Get("/notifications/event-stream", NotificationEventStreamHandler(getMockNotificationReceiver(conf.Synchronization, false))) + feature := entities.Feature{Key: "one"} suite.tc.AddFeatureTest(feature) @@ -137,16 +142,87 @@ func (suite *NotificationTestSuite) TestTrackAndProjectConfig() { nc := registry.GetNotificationCenter("") + notifications := make([]syncer.Event, 0) + + trackEvent := map[string]string{"test": "value"} + projectConfigUpdateNotification := notification.ProjectConfigUpdateNotification{ + Type: notification.ProjectConfigUpdate, + Revision: suite.tc.ProjectConfig.GetRevision(), + } + + notifications = append(notifications, syncer.Event{Type: notification.Track, Message: trackEvent}) + notifications = append(notifications, syncer.Event{Type: notification.ProjectConfigUpdate, Message: projectConfigUpdateNotification}) + go func() { time.Sleep(1 * time.Second) - _ = nc.Send(notification.Track, map[string]string{"test": "value"}) - projectConfigUpdateNotification := notification.ProjectConfigUpdateNotification{ - Type: notification.ProjectConfigUpdate, - Revision: suite.tc.ProjectConfig.GetRevision(), - } + + _ = nc.Send(notification.Track, trackEvent) + + _ = nc.Send(notification.ProjectConfigUpdate, projectConfigUpdateNotification) + }() + + conf := config.NewDefaultConfig() + suite.mux.Get("/notifications/event-stream", NotificationEventStreamHandler(getMockNotificationReceiver(conf.Synchronization, false, notifications...))) + + suite.mux.ServeHTTP(rec, req.WithContext(ctx1)) + + suite.Equal(http.StatusOK, rec.Code) + + // Unmarshal response + response := string(rec.Body.Bytes()) + suite.Equal(expected, response) +} + +func (suite *NotificationTestSuite) TestTrackAndProjectConfigWithSynchronization() { + event := entities.Event{Key: "one"} + suite.tc.AddEvent(event) + + req := httptest.NewRequest("GET", "/notifications/event-stream", nil) + rec := httptest.NewRecorder() + + expected := `data: {"test":"value"}` + "\n\n" + `data: {"Type":"project_config_update","Revision":"revision"}` + "\n\n" + + // create a cancelable request context + ctx := req.Context() + ctx1, _ := context.WithTimeout(ctx, 3*time.Second) + + nc := registry.GetNotificationCenter("") + + notifications := make([]syncer.Event, 0) + + trackEvent := map[string]string{"test": "value"} + projectConfigUpdateNotification := notification.ProjectConfigUpdateNotification{ + Type: notification.ProjectConfigUpdate, + Revision: suite.tc.ProjectConfig.GetRevision(), + } + + notifications = append(notifications, syncer.Event{Type: notification.Track, Message: trackEvent}) + notifications = append(notifications, syncer.Event{Type: notification.ProjectConfigUpdate, Message: projectConfigUpdateNotification}) + + go func() { + time.Sleep(1 * time.Second) + + _ = nc.Send(notification.Track, trackEvent) + _ = nc.Send(notification.ProjectConfigUpdate, projectConfigUpdateNotification) }() + conf := config.NewDefaultConfig() + conf.Synchronization = config.SyncConfig{ + Pubsub: map[string]interface{}{ + "redis": map[string]interface{}{ + "host": "localhost:6379", + "password": "", + "database": 0, + }, + }, + Notification: config.NotificationConfig{ + Enable: true, + Default: "redis", + }, + } + suite.mux.Get("/notifications/event-stream", NotificationEventStreamHandler(getMockNotificationReceiver(conf.Synchronization, false, notifications...))) + suite.mux.ServeHTTP(rec, req.WithContext(ctx1)) suite.Equal(http.StatusOK, rec.Code) @@ -170,11 +246,19 @@ func (suite *NotificationTestSuite) TestActivateExperimentRaw() { ctx1, _ := context.WithTimeout(ctx, 2*time.Second) nc := registry.GetNotificationCenter("") + decisionEvent := map[string]string{"key": "value"} + + notifications := make([]syncer.Event, 0) + notifications = append(notifications, syncer.Event{Type: notification.Decision, Message: decisionEvent}) + go func() { time.Sleep(1 * time.Second) - nc.Send(notification.Decision, map[string]string{"key": "value"}) + nc.Send(notification.Decision, decisionEvent) }() + conf := config.NewDefaultConfig() + suite.mux.Get("/notifications/event-stream", NotificationEventStreamHandler(getMockNotificationReceiver(conf.Synchronization, false, notifications...))) + suite.mux.ServeHTTP(rec, req.WithContext(ctx1)) suite.Equal(http.StatusOK, rec.Code) @@ -184,6 +268,33 @@ func (suite *NotificationTestSuite) TestActivateExperimentRaw() { suite.Equal(expected, response) } +func (suite *NotificationTestSuite) TestWithFailedNotificationReceiver() { + req := httptest.NewRequest("GET", "/notifications/event-stream", nil) + rec := httptest.NewRecorder() + + // create a cancelable request context + ctx := req.Context() + ctx1, _ := context.WithTimeout(ctx, 2*time.Second) + + nc := registry.GetNotificationCenter("") + decisionEvent := map[string]string{"key": "value"} + + notifications := make([]syncer.Event, 0) + notifications = append(notifications, syncer.Event{Type: notification.Decision, Message: decisionEvent}) + + go func() { + time.Sleep(1 * time.Second) + nc.Send(notification.Decision, decisionEvent) + }() + + conf := config.NewDefaultConfig() + suite.mux.Get("/notifications/event-stream", NotificationEventStreamHandler(getMockNotificationReceiver(conf.Synchronization, true, notifications...))) + + suite.mux.ServeHTTP(rec, req.WithContext(ctx1)) + + suite.Equal(http.StatusInternalServerError, rec.Code) +} + func (suite *NotificationTestSuite) assertError(rec *httptest.ResponseRecorder, msg string, code int) { assertError(suite.T(), rec, msg, code) } @@ -201,8 +312,9 @@ func TestEventStreamMissingOptlyCtx(t *testing.T) { mw := new(NotificationMW) mw.optlyClient = nil + conf := config.NewDefaultConfig() handlers := []func(w http.ResponseWriter, r *http.Request){ - NotificationEventSteamHandler, + NotificationEventStreamHandler(getMockNotificationReceiver(conf.Synchronization, false)), } for _, handler := range handlers { @@ -211,3 +323,122 @@ func TestEventStreamMissingOptlyCtx(t *testing.T) { assertError(t, rec, "optlyClient not available", http.StatusUnprocessableEntity) } } + +func TestDefaultNotificationReceiver(t *testing.T) { + type args struct { + ctx context.Context + } + tests := []struct { + name string + args args + want <-chan syncer.Event + wantErr bool + }{ + { + name: "Test happy path", + args: args{ctx: context.WithValue(context.TODO(), SDKKey, "1221")}, + want: make(chan syncer.Event), + wantErr: false, + }, + { + name: "Test without sdk key", + args: args{ctx: context.TODO()}, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := DefaultNotificationReceiver(tt.args.ctx) + if (err != nil) != tt.wantErr { + t.Errorf("DefaultNotificationReceiver() error = %v, wantErr %v", err, tt.wantErr) + return + } + if reflect.TypeOf(tt.want) != reflect.TypeOf(got) { + t.Errorf("DefaultNotificationReceiver() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRedisNotificationReceiver(t *testing.T) { + conf := config.SyncConfig{ + Pubsub: map[string]interface{}{ + "redis": map[string]interface{}{ + "host": "localhost:6379", + "password": "", + "database": 0, + }, + }, + Notification: config.NotificationConfig{ + Enable: true, + Default: "redis", + }, + } + type args struct { + conf config.SyncConfig + ctx context.Context + } + tests := []struct { + name string + args args + want NotificationReceiverFunc + }{ + { + name: "Test happy path", + args: args{ + conf: conf, + ctx: context.WithValue(context.Background(), SDKKey, "random-sdk-key-1"), + }, + want: func(ctx context.Context) (<-chan syncer.Event, error) { + return make(<-chan syncer.Event), nil + }, + }, + { + name: "Test empty config", + args: args{ + conf: config.SyncConfig{}, + ctx: context.WithValue(context.Background(), SDKKey, "random-sdk-key-2"), + }, + want: func(ctx context.Context) (<-chan syncer.Event, error) { + return nil, errors.New("error") + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := RedisNotificationReceiver(tt.args.conf) + if reflect.TypeOf(got) != reflect.TypeOf(tt.want) { + t.Errorf("RedisNotificationReceiver() = %v, want %v", got, tt.want) + } + + ch1, err1 := got(tt.args.ctx) + ch2, err2 := tt.want(tt.args.ctx) + + if reflect.TypeOf(err1) != reflect.TypeOf(err2) { + fmt.Println(err1, err2) + t.Errorf("error type not matched") + } + + if reflect.TypeOf(ch1) != reflect.TypeOf(ch2) { + t.Errorf("channel type not matched") + } + }) + } +} + +func getMockNotificationReceiver(conf config.SyncConfig, returnError bool, msg ...syncer.Event) NotificationReceiverFunc { + return func(ctx context.Context) (<-chan syncer.Event, error) { + if returnError { + return nil, errors.New("mock error") + } + dataChan := make(chan syncer.Event) + go func() { + time.Sleep(1) + for _, val := range msg { + dataChan <- val + } + }() + return dataChan, nil + } +} diff --git a/pkg/optimizely/cache.go b/pkg/optimizely/cache.go index d9af8388..f1401311 100644 --- a/pkg/optimizely/cache.go +++ b/pkg/optimizely/cache.go @@ -26,6 +26,7 @@ import ( "sync" "github.com/optimizely/agent/config" + "github.com/optimizely/agent/pkg/syncer" "github.com/optimizely/agent/plugins/odpcache" "github.com/optimizely/agent/plugins/userprofileservice" "github.com/optimizely/go-sdk/pkg/client" @@ -40,6 +41,7 @@ import ( odpCachePkg "github.com/optimizely/go-sdk/pkg/odp/cache" cmap "github.com/orcaman/concurrent-map" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) @@ -61,7 +63,7 @@ type OptlyCache struct { } // NewCache returns a new implementation of OptlyCache interface backed by a concurrent map. -func NewCache(ctx context.Context, conf config.ClientConfig, metricsRegistry *MetricsRegistry) *OptlyCache { +func NewCache(ctx context.Context, conf config.AgentConfig, metricsRegistry *MetricsRegistry) *OptlyCache { // TODO is there a cleaner way to handle this translation??? cmLoader := func(sdkkey string, options ...sdkconfig.OptionFunc) SyncedConfigManager { @@ -168,13 +170,14 @@ func regexValidator(sdkKeyRegex string) func(string) bool { } func defaultLoader( - conf config.ClientConfig, + agentConf config.AgentConfig, metricsRegistry *MetricsRegistry, userProfileServiceMap cmap.ConcurrentMap, odpCacheMap cmap.ConcurrentMap, pcFactory func(sdkKey string, options ...sdkconfig.OptionFunc) SyncedConfigManager, bpFactory func(options ...event.BPOptionConfig) *event.BatchEventProcessor) func(clientKey string) (*OptlyClient, error) { - validator := regexValidator(conf.SdkKeyRegex) + clientConf := agentConf.Client + validator := regexValidator(clientConf.SdkKeyRegex) return func(clientKey string) (*OptlyClient, error) { var sdkKey string @@ -211,15 +214,15 @@ func defaultLoader( if datafileAccessToken != "" { configManager = pcFactory( sdkKey, - sdkconfig.WithPollingInterval(conf.PollingInterval), - sdkconfig.WithDatafileURLTemplate(conf.DatafileURLTemplate), + sdkconfig.WithPollingInterval(clientConf.PollingInterval), + sdkconfig.WithDatafileURLTemplate(clientConf.DatafileURLTemplate), sdkconfig.WithDatafileAccessToken(datafileAccessToken), ) } else { configManager = pcFactory( sdkKey, - sdkconfig.WithPollingInterval(conf.PollingInterval), - sdkconfig.WithDatafileURLTemplate(conf.DatafileURLTemplate), + sdkconfig.WithPollingInterval(clientConf.PollingInterval), + sdkconfig.WithDatafileURLTemplate(clientConf.DatafileURLTemplate), ) } @@ -227,13 +230,13 @@ func defaultLoader( return &OptlyClient{}, err } - q := event.NewInMemoryQueue(conf.QueueSize) + q := event.NewInMemoryQueue(clientConf.QueueSize) ep := bpFactory( event.WithSDKKey(sdkKey), - event.WithQueueSize(conf.QueueSize), - event.WithBatchSize(conf.BatchSize), - event.WithEventEndPoint(conf.EventURL), - event.WithFlushInterval(conf.FlushInterval), + event.WithQueueSize(clientConf.QueueSize), + event.WithBatchSize(clientConf.BatchSize), + event.WithEventEndPoint(clientConf.EventURL), + event.WithFlushInterval(clientConf.FlushInterval), event.WithQueue(q), event.WithEventDispatcherMetrics(metricsRegistry), ) @@ -245,11 +248,19 @@ func defaultLoader( client.WithConfigManager(configManager), client.WithExperimentOverrides(forcedVariations), client.WithEventProcessor(ep), - client.WithOdpDisabled(conf.ODP.Disable), + client.WithOdpDisabled(clientConf.ODP.Disable), + } + + if agentConf.Synchronization.Notification.Enable { + redisSyncer, err := syncer.NewRedisSyncer(&zerolog.Logger{}, agentConf.Synchronization, sdkKey) + if err != nil { + return nil, err + } + clientOptions = append(clientOptions, client.WithNotificationCenter(redisSyncer)) } var clientUserProfileService decision.UserProfileService - var rawUPS = getServiceWithType(userProfileServicePlugin, sdkKey, userProfileServiceMap, conf.UserProfileService) + var rawUPS = getServiceWithType(userProfileServicePlugin, sdkKey, userProfileServiceMap, clientConf.UserProfileService) // Check if ups was provided by user if rawUPS != nil { // convert ups to UserProfileService interface @@ -260,7 +271,7 @@ func defaultLoader( } var clientODPCache odpCachePkg.Cache - var rawODPCache = getServiceWithType(odpCachePlugin, sdkKey, odpCacheMap, conf.ODP.SegmentsCache) + var rawODPCache = getServiceWithType(odpCachePlugin, sdkKey, odpCacheMap, clientConf.ODP.SegmentsCache) // Check if odp cache was provided by user if rawODPCache != nil { // convert odpCache to Cache interface @@ -273,7 +284,7 @@ func defaultLoader( segmentManager := odpSegmentPkg.NewSegmentManager( sdkKey, odpSegmentPkg.WithAPIManager( - odpSegmentPkg.NewSegmentAPIManager(sdkKey, utils.NewHTTPRequester(logging.GetLogger(sdkKey, "SegmentAPIManager"), utils.Timeout(conf.ODP.SegmentsRequestTimeout))), + odpSegmentPkg.NewSegmentAPIManager(sdkKey, utils.NewHTTPRequester(logging.GetLogger(sdkKey, "SegmentAPIManager"), utils.Timeout(clientConf.ODP.SegmentsRequestTimeout))), ), odpSegmentPkg.WithSegmentsCache(clientODPCache), ) @@ -282,16 +293,16 @@ func defaultLoader( eventManager := odpEventPkg.NewBatchEventManager( odpEventPkg.WithAPIManager( odpEventPkg.NewEventAPIManager( - sdkKey, utils.NewHTTPRequester(logging.GetLogger(sdkKey, "EventAPIManager"), utils.Timeout(conf.ODP.EventsRequestTimeout)), + sdkKey, utils.NewHTTPRequester(logging.GetLogger(sdkKey, "EventAPIManager"), utils.Timeout(clientConf.ODP.EventsRequestTimeout)), ), ), - odpEventPkg.WithFlushInterval(conf.ODP.EventsFlushInterval), + odpEventPkg.WithFlushInterval(clientConf.ODP.EventsFlushInterval), ) // Create odp manager with custom segment and event manager odpManager := odp.NewOdpManager( sdkKey, - conf.ODP.Disable, + clientConf.ODP.Disable, odp.WithSegmentManager(segmentManager), odp.WithEventManager(eventManager), ) diff --git a/pkg/optimizely/cache_test.go b/pkg/optimizely/cache_test.go index 2beabb4b..bd8678b8 100644 --- a/pkg/optimizely/cache_test.go +++ b/pkg/optimizely/cache_test.go @@ -112,7 +112,7 @@ func (suite *CacheTestSuite) TestNewCache() { sdkMetricsRegistry := NewRegistry(agentMetricsRegistry) // To improve coverage - optlyCache := NewCache(context.Background(), config.ClientConfig{}, sdkMetricsRegistry) + optlyCache := NewCache(context.Background(), config.AgentConfig{}, sdkMetricsRegistry) suite.NotNil(optlyCache) } @@ -420,7 +420,7 @@ func (s *DefaultLoaderTestSuite) TestDefaultLoader() { }, } - loader := defaultLoader(conf, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) + loader := defaultLoader(config.AgentConfig{Client: conf}, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) client, err := loader("sdkkey") s.NoError(err) @@ -474,7 +474,7 @@ func (s *DefaultLoaderTestSuite) TestUPSAndODPCacheHeaderOverridesDefaultKey() { tmpOdpCacheMap := cmap.New() tmpOdpCacheMap.Set("sdkkey", "in-memory") - loader := defaultLoader(conf, s.registry, tmpUPSMap, tmpOdpCacheMap, s.pcFactory, s.bpFactory) + loader := defaultLoader(config.AgentConfig{Client: conf}, s.registry, tmpUPSMap, tmpOdpCacheMap, s.pcFactory, s.bpFactory) client, err := loader("sdkkey") s.NoError(err) @@ -538,7 +538,7 @@ func (s *DefaultLoaderTestSuite) TestFirstSaveConfiguresClientForRedisUPSAndODPC }}, }, } - loader := defaultLoader(conf, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) + loader := defaultLoader(config.AgentConfig{Client: conf}, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) client, err := loader("sdkkey") s.NoError(err) s.NotNil(client.UserProfileService) @@ -596,7 +596,7 @@ func (s *DefaultLoaderTestSuite) TestFirstSaveConfiguresLRUCacheForInMemoryCache }}, }, } - loader := defaultLoader(conf, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) + loader := defaultLoader(config.AgentConfig{Client: conf}, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) client, err := loader("sdkkey") s.NoError(err) s.NotNil(client.odpCache) @@ -627,7 +627,7 @@ func (s *DefaultLoaderTestSuite) TestHttpClientInitializesByDefaultRestUPS() { "rest": map[string]interface{}{}, }}, } - loader := defaultLoader(conf, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) + loader := defaultLoader(config.AgentConfig{Client: conf}, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) client, err := loader("sdkkey") s.NoError(err) s.NotNil(client.UserProfileService) @@ -655,7 +655,7 @@ func (s *DefaultLoaderTestSuite) TestLoaderWithValidUserProfileServices() { }, }}, } - loader := defaultLoader(conf, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) + loader := defaultLoader(config.AgentConfig{Client: conf}, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) client, err := loader("sdkkey") s.NoError(err) @@ -686,7 +686,7 @@ func (s *DefaultLoaderTestSuite) TestLoaderWithValidODPCache() { }}, }, } - loader := defaultLoader(conf, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) + loader := defaultLoader(config.AgentConfig{Client: conf}, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) client, err := loader("sdkkey") s.NoError(err) @@ -709,7 +709,7 @@ func (s *DefaultLoaderTestSuite) TestLoaderWithEmptyUserProfileServices() { conf := config.ClientConfig{ UserProfileService: map[string]interface{}{}, } - loader := defaultLoader(conf, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) + loader := defaultLoader(config.AgentConfig{Client: conf}, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) client, err := loader("sdkkey") s.NoError(err) s.Nil(client.UserProfileService) @@ -726,7 +726,7 @@ func (s *DefaultLoaderTestSuite) TestLoaderWithEmptyODPCache() { SegmentsCache: map[string]interface{}{}, }, } - loader := defaultLoader(conf, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) + loader := defaultLoader(config.AgentConfig{Client: conf}, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) client, err := loader("sdkkey") s.NoError(err) s.Nil(client.odpCache) @@ -743,7 +743,7 @@ func (s *DefaultLoaderTestSuite) TestLoaderWithNoDefaultUserProfileServices() { "mock3": map[string]interface{}{}, }}, } - loader := defaultLoader(conf, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) + loader := defaultLoader(config.AgentConfig{Client: conf}, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) client, err := loader("sdkkey") s.NoError(err) s.Nil(client.UserProfileService) @@ -762,7 +762,7 @@ func (s *DefaultLoaderTestSuite) TestLoaderWithNoDefaultODPCache() { }}, }, } - loader := defaultLoader(conf, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) + loader := defaultLoader(config.AgentConfig{Client: conf}, s.registry, s.upsMap, s.odpCacheMap, s.pcFactory, s.bpFactory) client, err := loader("sdkkey") s.NoError(err) s.Nil(client.odpCache) diff --git a/pkg/routers/api.go b/pkg/routers/api.go index b7d2aac5..7eb15d6e 100644 --- a/pkg/routers/api.go +++ b/pkg/routers/api.go @@ -63,34 +63,37 @@ func forbiddenHandler(message string) http.HandlerFunc { } // NewDefaultAPIRouter creates a new router with the default backing optimizely.Cache -func NewDefaultAPIRouter(optlyCache optimizely.Cache, conf config.APIConfig, metricsRegistry *metrics.Registry) http.Handler { - authProvider := middleware.NewAuth(&conf.Auth) +func NewDefaultAPIRouter(optlyCache optimizely.Cache, conf config.AgentConfig, metricsRegistry *metrics.Registry) http.Handler { + authProvider := middleware.NewAuth(&conf.API.Auth) if authProvider == nil { log.Error().Msg("unable to initialize api auth middleware.") return nil } - authHandler := handlers.NewOAuthHandler(&conf.Auth) + authHandler := handlers.NewOAuthHandler(&conf.API.Auth) if authHandler == nil { log.Error().Msg("unable to initialize api auth handler.") return nil } overrideHandler := handlers.Override - if !conf.EnableOverrides { + if !conf.API.EnableOverrides { overrideHandler = forbiddenHandler("Overrides not enabled") } - nStreamHandler := handlers.NotificationEventSteamHandler - if !conf.EnableNotifications { - nStreamHandler = forbiddenHandler("Notification stream not enabled") + nStreamHandler := forbiddenHandler("Notification stream not enabled") + if conf.API.EnableNotifications { + nStreamHandler = handlers.NotificationEventStreamHandler(handlers.DefaultNotificationReceiver) + if conf.Synchronization.Notification.Enable { + nStreamHandler = handlers.NotificationEventStreamHandler(handlers.RedisNotificationReceiver(conf.Synchronization)) + } } mw := middleware.CachedOptlyMiddleware{Cache: optlyCache} - corsHandler := createCorsHandler(conf.CORS) + corsHandler := createCorsHandler(conf.API.CORS) spec := &APIOptions{ - maxConns: conf.MaxConns, + maxConns: conf.API.MaxConns, metricsRegistry: metricsRegistry, configHandler: handlers.OptimizelyConfig, datafileHandler: handlers.GetDatafile, diff --git a/pkg/routers/api_test.go b/pkg/routers/api_test.go index e9728a14..79bfa428 100644 --- a/pkg/routers/api_test.go +++ b/pkg/routers/api_test.go @@ -331,7 +331,7 @@ func TestAPIV1TestSuite(t *testing.T) { } func TestNewDefaultAPIV1Router(t *testing.T) { - client := NewDefaultAPIRouter(MockCache{}, config.APIConfig{}, metricsRegistry) + client := NewDefaultAPIRouter(MockCache{}, config.AgentConfig{}, metricsRegistry) assert.NotNil(t, client) } @@ -356,7 +356,7 @@ func TestNewDefaultAPIV1RouterInvalidHandlerConfig(t *testing.T) { EnableNotifications: false, EnableOverrides: false, } - client := NewDefaultAPIRouter(MockCache{}, invalidAPIConfig, metricsRegistry) + client := NewDefaultAPIRouter(MockCache{}, config.AgentConfig{API: invalidAPIConfig}, metricsRegistry) assert.Nil(t, client) } @@ -371,12 +371,12 @@ func TestNewDefaultClientRouterInvalidMiddlewareConfig(t *testing.T) { EnableNotifications: false, EnableOverrides: false, } - client := NewDefaultAPIRouter(MockCache{}, invalidAPIConfig, metricsRegistry) + client := NewDefaultAPIRouter(MockCache{}, config.AgentConfig{API: invalidAPIConfig}, metricsRegistry) assert.Nil(t, client) } func TestForbiddenRoutes(t *testing.T) { - mux := NewDefaultAPIRouter(MockCache{}, config.APIConfig{}, metricsRegistry) + mux := NewDefaultAPIRouter(MockCache{}, config.AgentConfig{}, metricsRegistry) routes := []struct { method string diff --git a/pkg/syncer/syncer.go b/pkg/syncer/syncer.go new file mode 100644 index 00000000..aa102188 --- /dev/null +++ b/pkg/syncer/syncer.go @@ -0,0 +1,164 @@ +/**************************************************************************** + * Copyright 2023 Optimizely, Inc. and contributors * + * * + * Licensed under the Apache License, Version 2.0 (the "License"); * + * you may not use this file except in compliance with the License. * + * You may obtain a copy of the License at * + * * + * http://www.apache.org/licenses/LICENSE-2.0 * + * * + * Unless required by applicable law or agreed to in writing, software * + * distributed under the License is distributed on an "AS IS" BASIS, * + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * + * See the License for the specific language governing permissions and * + * limitations under the License. * + ***************************************************************************/ + +// Package syncer provides synchronization across Agent nodes +package syncer + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + + "github.com/go-redis/redis/v8" + "github.com/optimizely/agent/config" + "github.com/optimizely/go-sdk/pkg/notification" + "github.com/rs/zerolog" +) + +const ( + // PubSubDefaultChan will be used as default pubsub channel name + PubSubDefaultChan = "optimizely-sync" + // PubSubRedis is the name of pubsub type of Redis + PubSubRedis = "redis" +) + +var ( + ncCache = make(map[string]*RedisSyncer) + mutexLock = &sync.Mutex{} +) + +// Event holds the notification event with it's type +type Event struct { + Type notification.Type `json:"type"` + Message interface{} `json:"message"` +} + +// RedisSyncer defines Redis pubsub configuration +type RedisSyncer struct { + ctx context.Context + Host string + Password string + Database int + Channel string + logger *zerolog.Logger + sdkKey string +} + +// NewRedisSyncer returns an instance of RedisNotificationSyncer +func NewRedisSyncer(logger *zerolog.Logger, conf config.SyncConfig, sdkKey string) (*RedisSyncer, error) { + mutexLock.Lock() + defer mutexLock.Unlock() + + if nc, found := ncCache[sdkKey]; found { + return nc, nil + } + + if !conf.Notification.Enable { + return nil, errors.New("notification syncer is not enabled") + } + if conf.Notification.Default != PubSubRedis { + return nil, errors.New("redis syncer is not set as default") + } + if conf.Pubsub == nil { + return nil, errors.New("redis config is not given") + } + + redisConfig, found := conf.Pubsub[PubSubRedis].(map[string]interface{}) + if !found { + return nil, errors.New("redis pubsub config not found") + } + + host, ok := redisConfig["host"].(string) + if !ok { + return nil, errors.New("redis host not provided in correct format") + } + password, ok := redisConfig["password"].(string) + if !ok { + return nil, errors.New("redis password not provider in correct format") + } + database, ok := redisConfig["database"].(int) + if !ok { + return nil, errors.New("redis database not provided in correct format") + } + channel, ok := redisConfig["channel"].(string) + if !ok { + channel = PubSubDefaultChan + } + + if logger == nil { + logger = &zerolog.Logger{} + } + + nc := &RedisSyncer{ + ctx: context.Background(), + Host: host, + Password: password, + Database: database, + Channel: channel, + logger: logger, + sdkKey: sdkKey, + } + ncCache[sdkKey] = nc + return nc, nil +} + +func (r *RedisSyncer) WithContext(ctx context.Context) *RedisSyncer { + r.ctx = ctx + return r +} + +// AddHandler is empty but needed to implement notification.Center interface +func (r *RedisSyncer) AddHandler(_ notification.Type, _ func(interface{})) (int, error) { + return 0, nil +} + +// RemoveHandler is empty but needed to implement notification.Center interface +func (r *RedisSyncer) RemoveHandler(_ int, t notification.Type) error { + return nil +} + +// Send will send the notification to the specified channel in the Redis pubsub +func (r *RedisSyncer) Send(t notification.Type, n interface{}) error { + event := Event{ + Type: t, + Message: n, + } + + jsonEvent, err := json.Marshal(event) + if err != nil { + return err + } + + client := redis.NewClient(&redis.Options{ + Addr: r.Host, + Password: r.Password, + DB: r.Database, + }) + defer client.Close() + channel := GetChannelForSDKKey(r.Channel, r.sdkKey) + + if err := client.Publish(r.ctx, channel, jsonEvent).Err(); err != nil { + r.logger.Err(err).Msg("failed to publish json event to pub/sub") + return err + } + return nil +} + +func GetChannelForSDKKey(channel, key string) string { + return fmt.Sprintf("%s-%s", channel, key) +}