diff --git a/lib/router.go b/lib/router.go index 03e56545..eb7b7957 100644 --- a/lib/router.go +++ b/lib/router.go @@ -1,6 +1,7 @@ package router import ( + "database/sql" "fmt" "net/http" "net/url" @@ -8,13 +9,11 @@ import ( "sync" "time" - "github.com/prometheus/client_golang/prometheus" - "github.com/alphagov/router/handlers" "github.com/alphagov/router/logger" "github.com/alphagov/router/triemux" - "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" + "github.com/lib/pq" + "github.com/prometheus/client_golang/prometheus" ) const ( @@ -38,18 +37,18 @@ const ( // come from, Route and Backend should not contain bson fields. // MongoReplicaSet, MongoReplicaSetMember etc. should move out of this module. type Router struct { - mux *triemux.Mux - lock sync.RWMutex - mongoReadToOptime bson.MongoTimestamp - logger logger.Logger - opts Options - ReloadChan chan bool + mux *triemux.Mux + lock sync.RWMutex + logger logger.Logger + opts Options + ReloadChan chan bool } type Options struct { - MongoURL string - MongoDBName string - MongoPollInterval time.Duration + DatabaseURL string + DatabaseName string + Listener *pq.Listener + DatabasePollInterval time.Duration BackendConnTimeout time.Duration BackendHeaderTimeout time.Duration LogFileName string @@ -61,16 +60,6 @@ type Backend struct { SubdomainName string `bson:"subdomain_name"` } -type MongoReplicaSet struct { - Members []MongoReplicaSetMember `bson:"members"` -} - -type MongoReplicaSetMember struct { - Name string `bson:"name"` - Optime bson.MongoTimestamp `bson:"optime"` - Current bool `bson:"self"` -} - type Route struct { IncomingPath string `bson:"incoming_path"` RouteType string `bson:"route_type"` @@ -92,7 +81,7 @@ func RegisterMetrics(r prometheus.Registerer) { // NewRouter returns a new empty router instance. You will need to call // SelfUpdateRoutes() to initialise the self-update process for routes. func NewRouter(o Options) (rt *Router, err error) { - logInfo("router: using mongo poll interval:", o.MongoPollInterval) + logInfo("router: using database poll interval:", o.DatabasePollInterval) logInfo("router: using backend connect timeout:", o.BackendConnTimeout) logInfo("router: using backend header timeout:", o.BackendHeaderTimeout) @@ -102,18 +91,27 @@ func NewRouter(o Options) (rt *Router, err error) { } logInfo("router: logging errors as JSON to", o.LogFileName) - mongoReadToOptime, err := bson.NewMongoTimestamp(time.Date(1970, time.January, 1, 0, 0, 0, 0, time.UTC), 1) + listenerProblemReporter := func(event pq.ListenerEventType, err error) { + if err != nil { + logWarn(fmt.Sprintf("pq: error creating listener for PSQL notify channel: %v)", err)) + return + } + } + + listener := pq.NewListener(o.DatabaseURL, 10*time.Second, time.Minute, listenerProblemReporter) + o.Listener = listener + + err = listener.Listen("events") if err != nil { - return nil, err + panic(err) } reloadChan := make(chan bool, 1) rt = &Router{ - mux: triemux.NewMux(), - mongoReadToOptime: mongoReadToOptime, - logger: l, - opts: o, - ReloadChan: reloadChan, + mux: triemux.NewMux(), + logger: l, + opts: o, + ReloadChan: reloadChan, } go rt.pollAndReload() @@ -150,9 +148,9 @@ func (rt *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { } func (rt *Router) SelfUpdateRoutes() { - logInfo(fmt.Sprintf("router: starting self-update process, polling for route changes every %v", rt.opts.MongoPollInterval)) + logInfo(fmt.Sprintf("router: starting self-update process, polling for route changes every %v", rt.opts.DatabasePollInterval)) - tick := time.Tick(rt.opts.MongoPollInterval) + tick := time.Tick(rt.opts.DatabasePollInterval) for range tick { logDebug("router: polling MongoDB for changes") @@ -172,32 +170,19 @@ func (rt *Router) pollAndReload() { } }() - logDebug("mgo: connecting to", rt.opts.MongoURL) + logDebug("pq: connecting to", rt.opts.DatabaseURL) - sess, err := mgo.Dial(rt.opts.MongoURL) + sess, err := sql.Open("postgres", rt.opts.DatabaseURL) if err != nil { - logWarn(fmt.Sprintf("mgo: error connecting to MongoDB, skipping update (error: %v)", err)) + logWarn(fmt.Sprintf("pq: error connecting to PSQL database, skipping update (error: %v)", err)) return } defer sess.Close() - sess.SetMode(mgo.SecondaryPreferred, true) - - currentMongoInstance, err := rt.getCurrentMongoInstance(sess.DB("admin")) - if err != nil { - logWarn(err) - return - } - logDebug("mgo: communicating with replica set member", currentMongoInstance.Name) - - logDebug("router: polled mongo instance is ", currentMongoInstance.Name) - logDebug("router: polled mongo optime is ", currentMongoInstance.Optime) - logDebug("router: current read-to mongo optime is ", rt.mongoReadToOptime) - - if rt.shouldReload(currentMongoInstance) { + if rt.shouldReload(rt.opts.Listener) { logDebug("router: updates found") - rt.reloadRoutes(sess.DB(rt.opts.MongoDBName), currentMongoInstance.Optime) + rt.reloadRoutes(sess) } else { logDebug("router: no updates found") } @@ -212,7 +197,7 @@ type mongoDatabase interface { // reloadRoutes reloads the routes for this Router instance on the fly. It will // create a new proxy mux, load applications (backends) and routes into it, and // then flip the "mux" pointer in the Router. -func (rt *Router) reloadRoutes(db *mgo.Database, currentOptime bson.MongoTimestamp) { +func (rt *Router) reloadRoutes(db *sql.DB) { defer func() { // increment this metric regardless of whether the route reload succeeded routeReloadCountMetric.Inc() @@ -225,16 +210,14 @@ func (rt *Router) reloadRoutes(db *mgo.Database, currentOptime bson.MongoTimesta logger.NotifySentry(logger.ReportableError{Error: err}) routeReloadErrorCountMetric.Inc() - } else { - rt.mongoReadToOptime = currentOptime } }() logInfo("router: reloading routes") newmux := triemux.NewMux() - backends := rt.loadBackends(db.C("backends")) - loadRoutes(db.C("routes"), newmux, backends) + backends := rt.loadBackends(db) + loadRoutes(db, newmux, backends) routeCount := newmux.RouteCount() rt.lock.Lock() @@ -245,58 +228,44 @@ func (rt *Router) reloadRoutes(db *mgo.Database, currentOptime bson.MongoTimesta routesCountMetric.Set(float64(routeCount)) } -func (rt *Router) getCurrentMongoInstance(db mongoDatabase) (MongoReplicaSetMember, error) { - replicaSetStatus := bson.M{} - - if err := db.Run("replSetGetStatus", &replicaSetStatus); err != nil { - return MongoReplicaSetMember{}, fmt.Errorf("router: couldn't get replica set status from MongoDB, skipping update (error: %w)", err) - } - - replicaSetStatusBytes, err := bson.Marshal(replicaSetStatus) - if err != nil { - return MongoReplicaSetMember{}, fmt.Errorf("router: couldn't marshal replica set status from MongoDB, skipping update (error: %w)", err) - } - - replicaSet := MongoReplicaSet{} - err = bson.Unmarshal(replicaSetStatusBytes, &replicaSet) - if err != nil { - return MongoReplicaSetMember{}, fmt.Errorf("router: couldn't unmarshal replica set status from MongoDB, skipping update (error: %w)", err) - } - - currentInstance := make([]MongoReplicaSetMember, 0) - for _, instance := range replicaSet.Members { - if instance.Current { - currentInstance = append(currentInstance, instance) +func (rt *Router) shouldReload(listener *pq.Listener) bool { + select { + case n := <-listener.Notify: + // n.Extra contains the payload from the notification + logInfo("notification:", n.Channel) + return true + default: + if err := listener.Ping(); err != nil { + panic(err) } + return false } - - logDebug("router: MongoDB instances", currentInstance) - - if len(currentInstance) != 1 { - return MongoReplicaSetMember{}, fmt.Errorf("router: did not find exactly one current MongoDB instance, skipping update (current instances found: %d)", len(currentInstance)) - } - - return currentInstance[0], nil -} - -func (rt *Router) shouldReload(currentMongoInstance MongoReplicaSetMember) bool { - return currentMongoInstance.Optime > rt.mongoReadToOptime } // loadBackends is a helper function which loads backends from the // passed mongo collection, constructs a Handler for each one, and returns // them in map keyed on the backend_id -func (rt *Router) loadBackends(c *mgo.Collection) (backends map[string]http.Handler) { +func (rt *Router) loadBackends(db *sql.DB) (backends map[string]http.Handler) { backend := &Backend{} backends = make(map[string]http.Handler) - iter := c.Find(nil).Iter() + rows, err := db.Query("SELECT * FROM backends") + if err != nil { + logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. (error: %v)", err)) + return + } + + for rows.Next() { + err := rows.Scan(&backend.BackendID, &backend.BackendURL) + if err != nil { + logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. (error: %v)", err)) + return + } - for iter.Next(&backend) { backendURL, err := backend.ParseURL() if err != nil { - logWarn(fmt.Errorf("router: couldn't parse URL %s for backend %s "+ - "(error: %w), skipping", backend.BackendURL, backend.BackendID, err)) + logWarn(fmt.Sprintf("router: couldn't parse URL %s for backends %s "+ + "(error: %v), skipping!", backend.BackendURL, backend.BackendID, err)) continue } @@ -318,10 +287,14 @@ func (rt *Router) loadBackends(c *mgo.Collection) (backends map[string]http.Hand // loadRoutes is a helper function which loads routes from the passed mongo // collection and registers them with the passed proxy mux. -func loadRoutes(c *mgo.Collection, mux *triemux.Mux, backends map[string]http.Handler) { +func loadRoutes(db *sql.DB, mux *triemux.Mux, backends map[string]http.Handler) { route := &Route{} - iter := c.Find(nil).Iter() + rows, err := db.Query("SELECT * FROM routes") + if err != nil { + logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. (error: %v)", err)) + return + } goneHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, "410 Gone", http.StatusGone) @@ -330,7 +303,12 @@ func loadRoutes(c *mgo.Collection, mux *triemux.Mux, backends map[string]http.Ha http.Error(w, "503 Service Unavailable", http.StatusServiceUnavailable) }) - for iter.Next(&route) { + for rows.Next() { + err := rows.Scan(&route.IncomingPath, &route.RouteType, &route.Handler, &route.Disabled, &route.BackendID, &route.RedirectTo, &route.RedirectType, &route.SegmentsMode) + if err != nil { + logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. (error: %v)", err)) + return + } prefix := (route.RouteType == RouteTypePrefix) // the database contains paths with % encoded routes. diff --git a/main.go b/main.go index f2d878e3..85c3b92f 100644 --- a/main.go +++ b/main.go @@ -85,17 +85,17 @@ func main() { router.EnableDebugOutput = os.Getenv("ROUTER_DEBUG") != "" var ( - pubAddr = getenv("ROUTER_PUBADDR", ":8080") - apiAddr = getenv("ROUTER_APIADDR", ":8081") - mongoURL = getenv("ROUTER_MONGO_URL", "127.0.0.1") - mongoDBName = getenv("ROUTER_MONGO_DB", "router") - mongoPollInterval = getenvDuration("ROUTER_MONGO_POLL_INTERVAL", "2s") - errorLogFile = getenv("ROUTER_ERROR_LOG", "STDERR") - tlsSkipVerify = os.Getenv("ROUTER_TLS_SKIP_VERIFY") != "" - beConnTimeout = getenvDuration("ROUTER_BACKEND_CONNECT_TIMEOUT", "1s") - beHeaderTimeout = getenvDuration("ROUTER_BACKEND_HEADER_TIMEOUT", "20s") - feReadTimeout = getenvDuration("ROUTER_FRONTEND_READ_TIMEOUT", "60s") - feWriteTimeout = getenvDuration("ROUTER_FRONTEND_WRITE_TIMEOUT", "60s") + pubAddr = getenv("ROUTER_PUBADDR", ":8080") + apiAddr = getenv("ROUTER_APIADDR", ":8081") + databaseURL = getenv("DATABASE_URL", "postgresql://postgres@127.0.0.1:27017/router?sslmode=disable") + databaseName = getenv("DATABASE_NAME", "router") + dbPollInterval = getenv("ROUTER_POLL_INTERVAL", "2s") + errorLogFile = getenv("ROUTER_ERROR_LOG", "STDERR") + tlsSkipVerify = os.Getenv("ROUTER_TLS_SKIP_VERIFY") != "" + beConnTimeout = getenvDuration("ROUTER_BACKEND_CONNECT_TIMEOUT", "1s") + beHeaderTimeout = getenvDuration("ROUTER_BACKEND_HEADER_TIMEOUT", "20s") + feReadTimeout = getenvDuration("ROUTER_FRONTEND_READ_TIMEOUT", "60s") + feWriteTimeout = getenvDuration("ROUTER_FRONTEND_WRITE_TIMEOUT", "60s") ) log.Printf("using frontend read timeout: %v", feReadTimeout) @@ -111,12 +111,12 @@ func main() { router.RegisterMetrics(prometheus.DefaultRegisterer) rout, err := router.NewRouter(router.Options{ - MongoURL: mongoURL, - MongoDBName: mongoDBName, - MongoPollInterval: mongoPollInterval, - BackendConnTimeout: beConnTimeout, - BackendHeaderTimeout: beHeaderTimeout, - LogFileName: errorLogFile, + DatabaseURL: databaseURL, + DatabaseName: databaseName, + DatabasePolleInterval: dbPollInterval, + BackendConnTimeout: beConnTimeout, + BackendHeaderTimeout: beHeaderTimeout, + LogFileName: errorLogFile, }) if err != nil { log.Fatal(err)