Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
pcfreak30 committed Jan 13, 2025
1 parent 88697e9 commit 9b85d47
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 3 deletions.
157 changes: 157 additions & 0 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3137,6 +3137,163 @@ func BenchmarkPopulateContext(b *testing.B) {
}
}

// testOptionsMiddleWare returns 200 on an OPTIONS request
func testOptionsMiddleWare(inner http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}
inner.ServeHTTP(w, r)
})
}

// TestRouterOrder Should Pass whichever order route is defined
func TestRouterOrder(t *testing.T) {
type requestCase struct {
request *http.Request
expCode int
}

tests := []struct {
name string
routes []*Route
customMiddleware MiddlewareFunc
requests []requestCase
}{
{
name: "Routes added with same method and intersecting path regex",
routes: []*Route{
new(Route).Path("/a/b").Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})).Methods(http.MethodGet),
new(Route).Path("/a/{a}").Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})).Methods(http.MethodGet),
},
requests: []requestCase{
{
request: newRequest(http.MethodGet, "/a/b"),
expCode: http.StatusNotFound,
},
{
request: newRequest(http.MethodGet, "/a/a"),
expCode: http.StatusOK,
},
},
},
{
name: "Routes added with same method and intersecting path regex, path with pathVariable first",
routes: []*Route{
new(Route).Path("/a/{a}").Handler(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})).Methods(http.MethodGet),
new(Route).Path("/a/b").Handler(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})).Methods(http.MethodGet),
},
requests: []requestCase{
{
request: newRequest(http.MethodGet, "/a/b"),
expCode: http.StatusOK,
},
{
request: newRequest(http.MethodGet, "/a/a"),
expCode: http.StatusOK,
},
},
},
{
name: "Routes added same path - different methods, no path variables",
routes: []*Route{
new(Route).Path("/a/b").Handler(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})).Methods(http.MethodGet),
new(Route).Path("/a/b").Handler(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})).Methods(http.MethodOptions),
},
requests: []requestCase{
{
request: newRequest(http.MethodGet, "/a/b"),
expCode: http.StatusOK,
},
{
request: newRequest(http.MethodOptions, "/a/b"),
expCode: http.StatusNotFound,
},
},
},
{
name: "Routes added same path - different methods, with path variables and middleware",
routes: []*Route{
new(Route).Path("/a/{a}").Handler(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})).Methods(http.MethodGet),
new(Route).Path("/a/b").Handler(nil).Methods(http.MethodOptions),
},
customMiddleware: testOptionsMiddleWare,
requests: []requestCase{
{
request: newRequest(http.MethodGet, "/a/b"),
expCode: http.StatusNotFound,
},
{
request: newRequest(http.MethodOptions, "/a/b"),
expCode: http.StatusOK,
},
},
},
{
name: "Routes added same path - different methods, with path variables and middleware order reversed",
routes: []*Route{
new(Route).Path("/a/b").Handler(nil).Methods(http.MethodOptions),
new(Route).Path("/a/{a}").Handler(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})).Methods(http.MethodGet),
},
customMiddleware: testOptionsMiddleWare,
requests: []requestCase{
{
request: newRequest(http.MethodGet, "/a/b"),
expCode: http.StatusNotFound,
},
{
request: newRequest(http.MethodOptions, "/a/b"),
expCode: http.StatusOK,
},
},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
router := NewRouter()

if test.customMiddleware != nil {
router.Use(test.customMiddleware)
}

router.routes = test.routes
w := NewRecorder()

for _, requestCase := range test.requests {
router.ServeHTTP(w, requestCase.request)

if w.Code != requestCase.expCode {
t.Fatalf("Expected status code %d (got %d)", requestCase.expCode, w.Code)
}
}
})
}
}

// mapToPairs converts a string map to a slice of string pairs
func mapToPairs(m map[string]string) []string {
var i int
Expand Down
9 changes: 6 additions & 3 deletions route.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,14 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
return false
}

if match.MatchErr != nil && r.handler != nil {
// If a route matches, but the HTTP method does not, we do one of two (2) things:
// 1. Reset the match error if we find a matching method later.
// 2. Else, we override the matched handler in the event we have a possible fallback handler for that route.
//
// This prevents propagation of ErrMethodMismatch once a suitable match is found for a Method-Path combination
if match.MatchErr == ErrMethodMismatch {
// We found a route which matches request method, clear MatchErr
match.MatchErr = nil
// Then override the mis-matched handler
match.Handler = r.handler
}

// Yay, we have a match. Let's collect some info about it.
Expand Down

0 comments on commit 9b85d47

Please sign in to comment.