Skip to content

Commit

Permalink
Migrate to rely on PostgresQL database
Browse files Browse the repository at this point in the history
The underlying functionality is the same, but we are moving away from
mongo to postgres to ensure we can continue to maintain the app.

We no longer keep soft copies of the data read from the database,
instead reading each row from memory.

We no longer rely on comparing OpTime from Mongo changes, instead relying
on PSQL's LISTEN/NOTIFY process to notify us of changes. However, in order
to avoid overloading with constant update requests, we still poll
the update every two seconds.
  • Loading branch information
Tetrino committed Aug 22, 2023
1 parent 3b4a1b7 commit 9dec2a5
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 124 deletions.
186 changes: 76 additions & 110 deletions lib/router.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
package router

import (
"database/sql"
"fmt"
"net/http"
"net/url"
"os"
"sync"
"time"

"github.com/prometheus/client_golang/prometheus"

"github.com/alphagov/router/handlers"
"github.com/alphagov/router/logger"
"github.com/alphagov/router/triemux"
"github.com/globalsign/mgo"
"github.com/globalsign/mgo/bson"
"github.com/lib/pq"

Check failure on line 15 in lib/router.go

View workflow job for this annotation

GitHub Actions / Test Go

cannot query module due to -mod=vendor
"github.com/prometheus/client_golang/prometheus"
)

const (
Expand All @@ -38,18 +37,18 @@ const (
// come from, Route and Backend should not contain bson fields.
// MongoReplicaSet, MongoReplicaSetMember etc. should move out of this module.
type Router struct {
mux *triemux.Mux
lock sync.RWMutex
mongoReadToOptime bson.MongoTimestamp
logger logger.Logger
opts Options
ReloadChan chan bool
mux *triemux.Mux
lock sync.RWMutex
logger logger.Logger
opts Options
ReloadChan chan bool
}

type Options struct {
MongoURL string
MongoDBName string
MongoPollInterval time.Duration
DatabaseURL string
DatabaseName string
Listener *pq.Listener
DatabasePollInterval time.Duration
BackendConnTimeout time.Duration
BackendHeaderTimeout time.Duration
LogFileName string
Expand All @@ -61,16 +60,6 @@ type Backend struct {
SubdomainName string `bson:"subdomain_name"`
}

type MongoReplicaSet struct {
Members []MongoReplicaSetMember `bson:"members"`
}

type MongoReplicaSetMember struct {
Name string `bson:"name"`
Optime bson.MongoTimestamp `bson:"optime"`
Current bool `bson:"self"`
}

type Route struct {
IncomingPath string `bson:"incoming_path"`
RouteType string `bson:"route_type"`
Expand All @@ -92,7 +81,7 @@ func RegisterMetrics(r prometheus.Registerer) {
// NewRouter returns a new empty router instance. You will need to call
// SelfUpdateRoutes() to initialise the self-update process for routes.
func NewRouter(o Options) (rt *Router, err error) {
logInfo("router: using mongo poll interval:", o.MongoPollInterval)
logInfo("router: using database poll interval:", o.DatabasePollInterval)
logInfo("router: using backend connect timeout:", o.BackendConnTimeout)
logInfo("router: using backend header timeout:", o.BackendHeaderTimeout)

Expand All @@ -102,18 +91,27 @@ func NewRouter(o Options) (rt *Router, err error) {
}
logInfo("router: logging errors as JSON to", o.LogFileName)

mongoReadToOptime, err := bson.NewMongoTimestamp(time.Date(1970, time.January, 1, 0, 0, 0, 0, time.UTC), 1)
listenerProblemReporter := func(event pq.ListenerEventType, err error) {
if err != nil {
logWarn(fmt.Sprintf("pq: error creating listener for PSQL notify channel: %v)", err))
return
}
}

listener := pq.NewListener(o.DatabaseURL, 10*time.Second, time.Minute, listenerProblemReporter)
o.Listener = listener

err = listener.Listen("events")
if err != nil {
return nil, err
panic(err)
}

reloadChan := make(chan bool, 1)
rt = &Router{
mux: triemux.NewMux(),
mongoReadToOptime: mongoReadToOptime,
logger: l,
opts: o,
ReloadChan: reloadChan,
mux: triemux.NewMux(),
logger: l,
opts: o,
ReloadChan: reloadChan,
}

go rt.pollAndReload()
Expand Down Expand Up @@ -150,9 +148,9 @@ func (rt *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

func (rt *Router) SelfUpdateRoutes() {
logInfo(fmt.Sprintf("router: starting self-update process, polling for route changes every %v", rt.opts.MongoPollInterval))
logInfo(fmt.Sprintf("router: starting self-update process, polling for route changes every %v", rt.opts.DatabasePollInterval))

tick := time.Tick(rt.opts.MongoPollInterval)
tick := time.Tick(rt.opts.DatabasePollInterval)
for range tick {
logDebug("router: polling MongoDB for changes")

Expand All @@ -172,47 +170,30 @@ func (rt *Router) pollAndReload() {
}
}()

logDebug("mgo: connecting to", rt.opts.MongoURL)
logDebug("pq: connecting to", rt.opts.DatabaseURL)

sess, err := mgo.Dial(rt.opts.MongoURL)
sess, err := sql.Open("postgres", rt.opts.DatabaseURL)
if err != nil {
logWarn(fmt.Sprintf("mgo: error connecting to MongoDB, skipping update (error: %v)", err))
logWarn(fmt.Sprintf("pq: error connecting to PSQL database, skipping update (error: %v)", err))
return
}

defer sess.Close()
sess.SetMode(mgo.SecondaryPreferred, true)

currentMongoInstance, err := rt.getCurrentMongoInstance(sess.DB("admin"))
if err != nil {
logWarn(err)
return
}

logDebug("mgo: communicating with replica set member", currentMongoInstance.Name)

logDebug("router: polled mongo instance is ", currentMongoInstance.Name)
logDebug("router: polled mongo optime is ", currentMongoInstance.Optime)
logDebug("router: current read-to mongo optime is ", rt.mongoReadToOptime)

if rt.shouldReload(currentMongoInstance) {
if rt.shouldReload(rt.opts.Listener) {
logDebug("router: updates found")
rt.reloadRoutes(sess.DB(rt.opts.MongoDBName), currentMongoInstance.Optime)
rt.reloadRoutes(sess)
} else {
logDebug("router: no updates found")
}
}()
}
}

type mongoDatabase interface {
Run(command interface{}, result interface{}) error
}

// reloadRoutes reloads the routes for this Router instance on the fly. It will
// create a new proxy mux, load applications (backends) and routes into it, and
// then flip the "mux" pointer in the Router.
func (rt *Router) reloadRoutes(db *mgo.Database, currentOptime bson.MongoTimestamp) {
func (rt *Router) reloadRoutes(db *sql.DB) {
defer func() {
// increment this metric regardless of whether the route reload succeeded
routeReloadCountMetric.Inc()
Expand All @@ -225,16 +206,14 @@ func (rt *Router) reloadRoutes(db *mgo.Database, currentOptime bson.MongoTimesta
logger.NotifySentry(logger.ReportableError{Error: err})

routeReloadErrorCountMetric.Inc()
} else {
rt.mongoReadToOptime = currentOptime
}
}()

logInfo("router: reloading routes")
newmux := triemux.NewMux()

backends := rt.loadBackends(db.C("backends"))
loadRoutes(db.C("routes"), newmux, backends)
backends := rt.loadBackends(db)
loadRoutes(db, newmux, backends)
routeCount := newmux.RouteCount()

rt.lock.Lock()
Expand All @@ -245,58 +224,44 @@ func (rt *Router) reloadRoutes(db *mgo.Database, currentOptime bson.MongoTimesta
routesCountMetric.Set(float64(routeCount))
}

func (rt *Router) getCurrentMongoInstance(db mongoDatabase) (MongoReplicaSetMember, error) {
replicaSetStatus := bson.M{}

if err := db.Run("replSetGetStatus", &replicaSetStatus); err != nil {
return MongoReplicaSetMember{}, fmt.Errorf("router: couldn't get replica set status from MongoDB, skipping update (error: %w)", err)
}

replicaSetStatusBytes, err := bson.Marshal(replicaSetStatus)
if err != nil {
return MongoReplicaSetMember{}, fmt.Errorf("router: couldn't marshal replica set status from MongoDB, skipping update (error: %w)", err)
}

replicaSet := MongoReplicaSet{}
err = bson.Unmarshal(replicaSetStatusBytes, &replicaSet)
if err != nil {
return MongoReplicaSetMember{}, fmt.Errorf("router: couldn't unmarshal replica set status from MongoDB, skipping update (error: %w)", err)
}

currentInstance := make([]MongoReplicaSetMember, 0)
for _, instance := range replicaSet.Members {
if instance.Current {
currentInstance = append(currentInstance, instance)
func (rt *Router) shouldReload(listener *pq.Listener) bool {
select {
case n := <-listener.Notify:
// n.Extra contains the payload from the notification
logInfo("notification:", n.Channel)
return true
default:
if err := listener.Ping(); err != nil {
panic(err)
}
return false
}

logDebug("router: MongoDB instances", currentInstance)

if len(currentInstance) != 1 {
return MongoReplicaSetMember{}, fmt.Errorf("router: did not find exactly one current MongoDB instance, skipping update (current instances found: %d)", len(currentInstance))
}

return currentInstance[0], nil
}

func (rt *Router) shouldReload(currentMongoInstance MongoReplicaSetMember) bool {
return currentMongoInstance.Optime > rt.mongoReadToOptime
}

// loadBackends is a helper function which loads backends from the
// passed mongo collection, constructs a Handler for each one, and returns
// them in map keyed on the backend_id
func (rt *Router) loadBackends(c *mgo.Collection) (backends map[string]http.Handler) {
func (rt *Router) loadBackends(db *sql.DB) (backends map[string]http.Handler) {
backend := &Backend{}
backends = make(map[string]http.Handler)

iter := c.Find(nil).Iter()
rows, err := db.Query("SELECT * FROM backends")
if err != nil {
logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. (error: %v)", err))
return
}

for rows.Next() {
err := rows.Scan(&backend.BackendID, &backend.BackendURL)
if err != nil {
logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. (error: %v)", err))
return
}

for iter.Next(&backend) {
backendURL, err := backend.ParseURL()
if err != nil {
logWarn(fmt.Errorf("router: couldn't parse URL %s for backend %s "+
"(error: %w), skipping", backend.BackendURL, backend.BackendID, err))
logWarn(fmt.Sprintf("router: couldn't parse URL %s for backends %s "+
"(error: %v), skipping!", backend.BackendURL, backend.BackendID, err))
continue
}

Expand All @@ -309,19 +274,19 @@ func (rt *Router) loadBackends(c *mgo.Collection) (backends map[string]http.Hand
)
}

if err := iter.Err(); err != nil {
panic(err)
}

return
}

// loadRoutes is a helper function which loads routes from the passed mongo
// collection and registers them with the passed proxy mux.
func loadRoutes(c *mgo.Collection, mux *triemux.Mux, backends map[string]http.Handler) {
func loadRoutes(db *sql.DB, mux *triemux.Mux, backends map[string]http.Handler) {
route := &Route{}

iter := c.Find(nil).Iter()
rows, err := db.Query("SELECT * FROM routes")
if err != nil {
logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. (error: %v)", err))
return
}

goneHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "410 Gone", http.StatusGone)
Expand All @@ -330,7 +295,12 @@ func loadRoutes(c *mgo.Collection, mux *triemux.Mux, backends map[string]http.Ha
http.Error(w, "503 Service Unavailable", http.StatusServiceUnavailable)
})

for iter.Next(&route) {
for rows.Next() {
err := rows.Scan(&route.IncomingPath, &route.RouteType, &route.Handler, &route.Disabled, &route.BackendID, &route.RedirectTo, &route.RedirectType, &route.SegmentsMode)
if err != nil {
logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. (error: %v)", err))
return
}
prefix := (route.RouteType == RouteTypePrefix)

// the database contains paths with % encoded routes.
Expand Down Expand Up @@ -379,10 +349,6 @@ func loadRoutes(c *mgo.Collection, mux *triemux.Mux, backends map[string]http.Ha
continue
}
}

if err := iter.Err(); err != nil {
panic(err)
}
}

func (be *Backend) ParseURL() (*url.URL, error) {
Expand Down
33 changes: 19 additions & 14 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,17 @@ func main() {

router.EnableDebugOutput = os.Getenv("ROUTER_DEBUG") != ""
var (
pubAddr = getenv("ROUTER_PUBADDR", ":8080")
apiAddr = getenv("ROUTER_APIADDR", ":8081")
mongoURL = getenv("ROUTER_MONGO_URL", "127.0.0.1")
mongoDBName = getenv("ROUTER_MONGO_DB", "router")
mongoPollInterval = getenvDuration("ROUTER_MONGO_POLL_INTERVAL", "2s")
errorLogFile = getenv("ROUTER_ERROR_LOG", "STDERR")
tlsSkipVerify = os.Getenv("ROUTER_TLS_SKIP_VERIFY") != ""
beConnTimeout = getenvDuration("ROUTER_BACKEND_CONNECT_TIMEOUT", "1s")
beHeaderTimeout = getenvDuration("ROUTER_BACKEND_HEADER_TIMEOUT", "20s")
feReadTimeout = getenvDuration("ROUTER_FRONTEND_READ_TIMEOUT", "60s")
feWriteTimeout = getenvDuration("ROUTER_FRONTEND_WRITE_TIMEOUT", "60s")
pubAddr = getenv("ROUTER_PUBADDR", ":8080")
apiAddr = getenv("ROUTER_APIADDR", ":8081")
databaseURL = getenv("DATABASE_URL", "postgresql://postgres@127.0.0.1:27017/router?sslmode=disable")
databaseName = getenv("DATABASE_NAME", "router")
dbPollInterval = getenv("ROUTER_POLL_INTERVAL", "2s")
errorLogFile = getenv("ROUTER_ERROR_LOG", "STDERR")
tlsSkipVerify = os.Getenv("ROUTER_TLS_SKIP_VERIFY") != ""
beConnTimeout = getenvDuration("ROUTER_BACKEND_CONNECT_TIMEOUT", "1s")
beHeaderTimeout = getenvDuration("ROUTER_BACKEND_HEADER_TIMEOUT", "20s")
feReadTimeout = getenvDuration("ROUTER_FRONTEND_READ_TIMEOUT", "60s")
feWriteTimeout = getenvDuration("ROUTER_FRONTEND_WRITE_TIMEOUT", "60s")
)

log.Printf("using frontend read timeout: %v", feReadTimeout)
Expand All @@ -110,10 +110,15 @@ func main() {

router.RegisterMetrics(prometheus.DefaultRegisterer)

parsedPollInterval, err := time.ParseDuration(dbPollInterval)
if err != nil {
log.Fatal(err)
}

rout, err := router.NewRouter(router.Options{
MongoURL: mongoURL,
MongoDBName: mongoDBName,
MongoPollInterval: mongoPollInterval,
DatabaseURL: databaseURL,
DatabaseName: databaseName,
DatabasePollInterval: parsedPollInterval,
BackendConnTimeout: beConnTimeout,
BackendHeaderTimeout: beHeaderTimeout,
LogFileName: errorLogFile,
Expand Down

0 comments on commit 9dec2a5

Please sign in to comment.