From f1918866dcafd4dfe42d70f29d555ec11e75d636 Mon Sep 17 00:00:00 2001 From: Jonathon Date: Tue, 22 Aug 2023 12:09:05 +0100 Subject: [PATCH] Migrate to rely on PostgresQL database The underlying functionality is the same, but we are moving away from mongo to postgres to ensure we can continue to maintain the app. We no longer keep soft copies of the data read from the database, instead reading each row from memory. We no longer rely on comparing OpTime from Mongo changes, instead relying on PSQL's LISTEN/NOTIFY process to notify us of changes. However, in order to avoid overloading with constant update requests, we still poll the update every two seconds. --- lib/router.go | 174 ++++++++++++++++++++++---------------------------- main.go | 34 +++++----- 2 files changed, 93 insertions(+), 115 deletions(-) diff --git a/lib/router.go b/lib/router.go index 03e56545..eb7b7957 100644 --- a/lib/router.go +++ b/lib/router.go @@ -1,6 +1,7 @@ package router import ( + "database/sql" "fmt" "net/http" "net/url" @@ -8,13 +9,11 @@ import ( "sync" "time" - "github.com/prometheus/client_golang/prometheus" - "github.com/alphagov/router/handlers" "github.com/alphagov/router/logger" "github.com/alphagov/router/triemux" - "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" + "github.com/lib/pq" + "github.com/prometheus/client_golang/prometheus" ) const ( @@ -38,18 +37,18 @@ const ( // come from, Route and Backend should not contain bson fields. // MongoReplicaSet, MongoReplicaSetMember etc. should move out of this module. type Router struct { - mux *triemux.Mux - lock sync.RWMutex - mongoReadToOptime bson.MongoTimestamp - logger logger.Logger - opts Options - ReloadChan chan bool + mux *triemux.Mux + lock sync.RWMutex + logger logger.Logger + opts Options + ReloadChan chan bool } type Options struct { - MongoURL string - MongoDBName string - MongoPollInterval time.Duration + DatabaseURL string + DatabaseName string + Listener *pq.Listener + DatabasePollInterval time.Duration BackendConnTimeout time.Duration BackendHeaderTimeout time.Duration LogFileName string @@ -61,16 +60,6 @@ type Backend struct { SubdomainName string `bson:"subdomain_name"` } -type MongoReplicaSet struct { - Members []MongoReplicaSetMember `bson:"members"` -} - -type MongoReplicaSetMember struct { - Name string `bson:"name"` - Optime bson.MongoTimestamp `bson:"optime"` - Current bool `bson:"self"` -} - type Route struct { IncomingPath string `bson:"incoming_path"` RouteType string `bson:"route_type"` @@ -92,7 +81,7 @@ func RegisterMetrics(r prometheus.Registerer) { // NewRouter returns a new empty router instance. You will need to call // SelfUpdateRoutes() to initialise the self-update process for routes. func NewRouter(o Options) (rt *Router, err error) { - logInfo("router: using mongo poll interval:", o.MongoPollInterval) + logInfo("router: using database poll interval:", o.DatabasePollInterval) logInfo("router: using backend connect timeout:", o.BackendConnTimeout) logInfo("router: using backend header timeout:", o.BackendHeaderTimeout) @@ -102,18 +91,27 @@ func NewRouter(o Options) (rt *Router, err error) { } logInfo("router: logging errors as JSON to", o.LogFileName) - mongoReadToOptime, err := bson.NewMongoTimestamp(time.Date(1970, time.January, 1, 0, 0, 0, 0, time.UTC), 1) + listenerProblemReporter := func(event pq.ListenerEventType, err error) { + if err != nil { + logWarn(fmt.Sprintf("pq: error creating listener for PSQL notify channel: %v)", err)) + return + } + } + + listener := pq.NewListener(o.DatabaseURL, 10*time.Second, time.Minute, listenerProblemReporter) + o.Listener = listener + + err = listener.Listen("events") if err != nil { - return nil, err + panic(err) } reloadChan := make(chan bool, 1) rt = &Router{ - mux: triemux.NewMux(), - mongoReadToOptime: mongoReadToOptime, - logger: l, - opts: o, - ReloadChan: reloadChan, + mux: triemux.NewMux(), + logger: l, + opts: o, + ReloadChan: reloadChan, } go rt.pollAndReload() @@ -150,9 +148,9 @@ func (rt *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { } func (rt *Router) SelfUpdateRoutes() { - logInfo(fmt.Sprintf("router: starting self-update process, polling for route changes every %v", rt.opts.MongoPollInterval)) + logInfo(fmt.Sprintf("router: starting self-update process, polling for route changes every %v", rt.opts.DatabasePollInterval)) - tick := time.Tick(rt.opts.MongoPollInterval) + tick := time.Tick(rt.opts.DatabasePollInterval) for range tick { logDebug("router: polling MongoDB for changes") @@ -172,32 +170,19 @@ func (rt *Router) pollAndReload() { } }() - logDebug("mgo: connecting to", rt.opts.MongoURL) + logDebug("pq: connecting to", rt.opts.DatabaseURL) - sess, err := mgo.Dial(rt.opts.MongoURL) + sess, err := sql.Open("postgres", rt.opts.DatabaseURL) if err != nil { - logWarn(fmt.Sprintf("mgo: error connecting to MongoDB, skipping update (error: %v)", err)) + logWarn(fmt.Sprintf("pq: error connecting to PSQL database, skipping update (error: %v)", err)) return } defer sess.Close() - sess.SetMode(mgo.SecondaryPreferred, true) - - currentMongoInstance, err := rt.getCurrentMongoInstance(sess.DB("admin")) - if err != nil { - logWarn(err) - return - } - logDebug("mgo: communicating with replica set member", currentMongoInstance.Name) - - logDebug("router: polled mongo instance is ", currentMongoInstance.Name) - logDebug("router: polled mongo optime is ", currentMongoInstance.Optime) - logDebug("router: current read-to mongo optime is ", rt.mongoReadToOptime) - - if rt.shouldReload(currentMongoInstance) { + if rt.shouldReload(rt.opts.Listener) { logDebug("router: updates found") - rt.reloadRoutes(sess.DB(rt.opts.MongoDBName), currentMongoInstance.Optime) + rt.reloadRoutes(sess) } else { logDebug("router: no updates found") } @@ -212,7 +197,7 @@ type mongoDatabase interface { // reloadRoutes reloads the routes for this Router instance on the fly. It will // create a new proxy mux, load applications (backends) and routes into it, and // then flip the "mux" pointer in the Router. -func (rt *Router) reloadRoutes(db *mgo.Database, currentOptime bson.MongoTimestamp) { +func (rt *Router) reloadRoutes(db *sql.DB) { defer func() { // increment this metric regardless of whether the route reload succeeded routeReloadCountMetric.Inc() @@ -225,16 +210,14 @@ func (rt *Router) reloadRoutes(db *mgo.Database, currentOptime bson.MongoTimesta logger.NotifySentry(logger.ReportableError{Error: err}) routeReloadErrorCountMetric.Inc() - } else { - rt.mongoReadToOptime = currentOptime } }() logInfo("router: reloading routes") newmux := triemux.NewMux() - backends := rt.loadBackends(db.C("backends")) - loadRoutes(db.C("routes"), newmux, backends) + backends := rt.loadBackends(db) + loadRoutes(db, newmux, backends) routeCount := newmux.RouteCount() rt.lock.Lock() @@ -245,58 +228,44 @@ func (rt *Router) reloadRoutes(db *mgo.Database, currentOptime bson.MongoTimesta routesCountMetric.Set(float64(routeCount)) } -func (rt *Router) getCurrentMongoInstance(db mongoDatabase) (MongoReplicaSetMember, error) { - replicaSetStatus := bson.M{} - - if err := db.Run("replSetGetStatus", &replicaSetStatus); err != nil { - return MongoReplicaSetMember{}, fmt.Errorf("router: couldn't get replica set status from MongoDB, skipping update (error: %w)", err) - } - - replicaSetStatusBytes, err := bson.Marshal(replicaSetStatus) - if err != nil { - return MongoReplicaSetMember{}, fmt.Errorf("router: couldn't marshal replica set status from MongoDB, skipping update (error: %w)", err) - } - - replicaSet := MongoReplicaSet{} - err = bson.Unmarshal(replicaSetStatusBytes, &replicaSet) - if err != nil { - return MongoReplicaSetMember{}, fmt.Errorf("router: couldn't unmarshal replica set status from MongoDB, skipping update (error: %w)", err) - } - - currentInstance := make([]MongoReplicaSetMember, 0) - for _, instance := range replicaSet.Members { - if instance.Current { - currentInstance = append(currentInstance, instance) +func (rt *Router) shouldReload(listener *pq.Listener) bool { + select { + case n := <-listener.Notify: + // n.Extra contains the payload from the notification + logInfo("notification:", n.Channel) + return true + default: + if err := listener.Ping(); err != nil { + panic(err) } + return false } - - logDebug("router: MongoDB instances", currentInstance) - - if len(currentInstance) != 1 { - return MongoReplicaSetMember{}, fmt.Errorf("router: did not find exactly one current MongoDB instance, skipping update (current instances found: %d)", len(currentInstance)) - } - - return currentInstance[0], nil -} - -func (rt *Router) shouldReload(currentMongoInstance MongoReplicaSetMember) bool { - return currentMongoInstance.Optime > rt.mongoReadToOptime } // loadBackends is a helper function which loads backends from the // passed mongo collection, constructs a Handler for each one, and returns // them in map keyed on the backend_id -func (rt *Router) loadBackends(c *mgo.Collection) (backends map[string]http.Handler) { +func (rt *Router) loadBackends(db *sql.DB) (backends map[string]http.Handler) { backend := &Backend{} backends = make(map[string]http.Handler) - iter := c.Find(nil).Iter() + rows, err := db.Query("SELECT * FROM backends") + if err != nil { + logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. (error: %v)", err)) + return + } + + for rows.Next() { + err := rows.Scan(&backend.BackendID, &backend.BackendURL) + if err != nil { + logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. (error: %v)", err)) + return + } - for iter.Next(&backend) { backendURL, err := backend.ParseURL() if err != nil { - logWarn(fmt.Errorf("router: couldn't parse URL %s for backend %s "+ - "(error: %w), skipping", backend.BackendURL, backend.BackendID, err)) + logWarn(fmt.Sprintf("router: couldn't parse URL %s for backends %s "+ + "(error: %v), skipping!", backend.BackendURL, backend.BackendID, err)) continue } @@ -318,10 +287,14 @@ func (rt *Router) loadBackends(c *mgo.Collection) (backends map[string]http.Hand // loadRoutes is a helper function which loads routes from the passed mongo // collection and registers them with the passed proxy mux. -func loadRoutes(c *mgo.Collection, mux *triemux.Mux, backends map[string]http.Handler) { +func loadRoutes(db *sql.DB, mux *triemux.Mux, backends map[string]http.Handler) { route := &Route{} - iter := c.Find(nil).Iter() + rows, err := db.Query("SELECT * FROM routes") + if err != nil { + logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. (error: %v)", err)) + return + } goneHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, "410 Gone", http.StatusGone) @@ -330,7 +303,12 @@ func loadRoutes(c *mgo.Collection, mux *triemux.Mux, backends map[string]http.Ha http.Error(w, "503 Service Unavailable", http.StatusServiceUnavailable) }) - for iter.Next(&route) { + for rows.Next() { + err := rows.Scan(&route.IncomingPath, &route.RouteType, &route.Handler, &route.Disabled, &route.BackendID, &route.RedirectTo, &route.RedirectType, &route.SegmentsMode) + if err != nil { + logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. (error: %v)", err)) + return + } prefix := (route.RouteType == RouteTypePrefix) // the database contains paths with % encoded routes. diff --git a/main.go b/main.go index f2d878e3..85c3b92f 100644 --- a/main.go +++ b/main.go @@ -85,17 +85,17 @@ func main() { router.EnableDebugOutput = os.Getenv("ROUTER_DEBUG") != "" var ( - pubAddr = getenv("ROUTER_PUBADDR", ":8080") - apiAddr = getenv("ROUTER_APIADDR", ":8081") - mongoURL = getenv("ROUTER_MONGO_URL", "127.0.0.1") - mongoDBName = getenv("ROUTER_MONGO_DB", "router") - mongoPollInterval = getenvDuration("ROUTER_MONGO_POLL_INTERVAL", "2s") - errorLogFile = getenv("ROUTER_ERROR_LOG", "STDERR") - tlsSkipVerify = os.Getenv("ROUTER_TLS_SKIP_VERIFY") != "" - beConnTimeout = getenvDuration("ROUTER_BACKEND_CONNECT_TIMEOUT", "1s") - beHeaderTimeout = getenvDuration("ROUTER_BACKEND_HEADER_TIMEOUT", "20s") - feReadTimeout = getenvDuration("ROUTER_FRONTEND_READ_TIMEOUT", "60s") - feWriteTimeout = getenvDuration("ROUTER_FRONTEND_WRITE_TIMEOUT", "60s") + pubAddr = getenv("ROUTER_PUBADDR", ":8080") + apiAddr = getenv("ROUTER_APIADDR", ":8081") + databaseURL = getenv("DATABASE_URL", "postgresql://postgres@127.0.0.1:27017/router?sslmode=disable") + databaseName = getenv("DATABASE_NAME", "router") + dbPollInterval = getenv("ROUTER_POLL_INTERVAL", "2s") + errorLogFile = getenv("ROUTER_ERROR_LOG", "STDERR") + tlsSkipVerify = os.Getenv("ROUTER_TLS_SKIP_VERIFY") != "" + beConnTimeout = getenvDuration("ROUTER_BACKEND_CONNECT_TIMEOUT", "1s") + beHeaderTimeout = getenvDuration("ROUTER_BACKEND_HEADER_TIMEOUT", "20s") + feReadTimeout = getenvDuration("ROUTER_FRONTEND_READ_TIMEOUT", "60s") + feWriteTimeout = getenvDuration("ROUTER_FRONTEND_WRITE_TIMEOUT", "60s") ) log.Printf("using frontend read timeout: %v", feReadTimeout) @@ -111,12 +111,12 @@ func main() { router.RegisterMetrics(prometheus.DefaultRegisterer) rout, err := router.NewRouter(router.Options{ - MongoURL: mongoURL, - MongoDBName: mongoDBName, - MongoPollInterval: mongoPollInterval, - BackendConnTimeout: beConnTimeout, - BackendHeaderTimeout: beHeaderTimeout, - LogFileName: errorLogFile, + DatabaseURL: databaseURL, + DatabaseName: databaseName, + DatabasePolleInterval: dbPollInterval, + BackendConnTimeout: beConnTimeout, + BackendHeaderTimeout: beHeaderTimeout, + LogFileName: errorLogFile, }) if err != nil { log.Fatal(err)