diff --git a/mux/router.go b/mux/router.go index fc19b875..d72e8301 100644 --- a/mux/router.go +++ b/mux/router.go @@ -93,9 +93,19 @@ func pathMatch(pattern Route, path string) bool { return pattern.regexMatcher.regexp.MatchString(path) } +// FilterPath checks the unfiltered input path or pattern against a blacklist and transforms them into valid paths +// of the same semantic meaning +func FilterPath(unfiltered string) string { + if unfiltered == "" { + return "/" + } + return unfiltered +} + // Find a handler on a handler map given a path string // Most-specific (longest) pattern wins func (r *Router) Match(path string, routeParams *RouteParams) (matchedRoute *Route, matchedPattern string) { + path = FilterPath(path) r.m.RLock() n := 0 for pattern, route := range r.z { @@ -127,10 +137,7 @@ func (r *Router) Match(path string, routeParams *RouteParams) (matchedRoute *Rou // Handle adds a handler to the Router for pattern. func (r *Router) Handle(pattern string, handler Handler) error { - switch pattern { - case "", "/": - pattern = "/" - } + pattern = FilterPath(pattern) if handler == nil { return errors.New("nil handler") @@ -170,10 +177,7 @@ func (r *Router) DefaultHandleFunc(handler func(w ResponseWriter, r *Message)) { // HandleRemove deregistrars the handler specific for pattern from the Router. func (r *Router) HandleRemove(pattern string) error { - switch pattern { - case "", "/": - pattern = "/" - } + pattern = FilterPath(pattern) r.m.Lock() defer r.m.Unlock() if _, ok := r.z[pattern]; ok { @@ -185,6 +189,7 @@ func (r *Router) HandleRemove(pattern string) error { // GetRoute obtains route from the pattern it has been assigned func (r *Router) GetRoute(pattern string) *Route { + pattern = FilterPath(pattern) r.m.RLock() defer r.m.RUnlock() if route, ok := r.z[pattern]; ok { diff --git a/mux/router_test.go b/mux/router_test.go index cb5b5eca..4a56fbf3 100644 --- a/mux/router_test.go +++ b/mux/router_test.go @@ -5,6 +5,7 @@ import ( "github.com/plgd-dev/go-coap/v3/mux" "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" // TODO: replace with standard maps package as soon as Go dependency hits 1.20 ) type routeTest struct { @@ -78,6 +79,34 @@ func TestMux(t *testing.T) { pathTemplate: `/{category:a|b/c}`, shouldMatch: true, }, + { + title: "root path without slash", + vars: map[string]string{}, + path: "", + pathTemplate: "", + shouldMatch: true, + }, + { + title: "root path with slash", + vars: map[string]string{}, + path: "/", + pathTemplate: "/", + shouldMatch: true, + }, + { + title: "root path without slash", + vars: map[string]string{}, + path: "", + pathTemplate: "/", + shouldMatch: true, + }, + { + title: "root path with slash", + vars: map[string]string{}, + path: "/", + pathTemplate: "", + shouldMatch: true, + }, } for _, test := range tests { @@ -116,29 +145,14 @@ func testRoute(t *testing.T, router *mux.Router, test routeTest) { t.Errorf("(%v) %v:\nPath: %#v\nPathTemplate: %#v\nVars: %v\n", test.title, msg, test.path, test.pathTemplate, test.vars) } if test.shouldMatch { - if vars != nil && !stringMapEqual(vars, routeParams.Vars) { + if vars != nil && !maps.Equal(vars, routeParams.Vars) { t.Errorf("(%v) Vars not equal: expected %v, got %v", test.title, vars, routeParams.Vars) return } - if routeParams.PathTemplate != test.pathTemplate { - t.Errorf("(%v) PathTemplate not equal: expected %v, got %v", test.title, test.pathTemplate, test.pathTemplate) + if routeParams.PathTemplate != mux.FilterPath(test.pathTemplate) { + t.Errorf("(%v) PathTemplate not equal: expected %v, got %v", test.title, test.pathTemplate, routeParams.PathTemplate) return } } } - -// stringMapEqual checks the equality of two string maps -func stringMapEqual(m1, m2 map[string]string) bool { - nil1 := m1 == nil - nil2 := m2 == nil - if nil1 != nil2 || len(m1) != len(m2) { - return false - } - for k, v := range m1 { - if v != m2[k] { - return false - } - } - return true -}