diff --git a/tollbooth.go b/tollbooth.go index 6298ba5..27e008c 100644 --- a/tollbooth.go +++ b/tollbooth.go @@ -346,3 +346,30 @@ func LimitHandler(lmt *limiter.Limiter, next http.Handler) http.Handler { func LimitFuncHandler(lmt *limiter.Limiter, nextFunc func(http.ResponseWriter, *http.Request)) http.Handler { return LimitHandler(lmt, http.HandlerFunc(nextFunc)) } + +// HTTPMiddleware wraps http.Handler with tollbooth limiter +func HTTPMiddleware(lmt *limiter.Limiter) func(http.Handler) http.Handler { + // // set IP lookup only if not set + if lmt.GetIPLookup().Name == "" { + lmt.SetIPLookup(limiter.IPLookup{Name: "RemoteAddr"}) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + http.Error(w, "Context was canceled", http.StatusServiceUnavailable) + return + default: + if httpError := LimitByRequest(lmt, w, r); httpError != nil { + lmt.ExecOnLimitReached(w, r) + w.Header().Add("Content-Type", lmt.GetMessageContentType()) + w.WriteHeader(httpError.StatusCode) + w.Write([]byte(httpError.Message)) //nolint:gosec // not much we can do here with failed write + return + } + next.ServeHTTP(w, r) + } + }) + } +} diff --git a/tollbooth_test.go b/tollbooth_test.go index 9e9eb13..9119f52 100644 --- a/tollbooth_test.go +++ b/tollbooth_test.go @@ -653,3 +653,137 @@ func TestLimitHandlerEmptyHeader(t *testing.T) { wg.Wait() // Block until go func is done. } + +func TestHTTPMiddleware(t *testing.T) { + t.Run("basic request", func(t *testing.T) { + lmt := NewLimiter(1, nil) + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + wrapped := HTTPMiddleware(lmt)(handler) + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/test", nil) + r.RemoteAddr = "127.0.0.1:12345" + wrapped.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, w.Code) + } + }) + + t.Run("rate limit exceeded", func(t *testing.T) { + lmt := NewLimiter(0.1, nil) // only allow one request per 10 seconds + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + wrapped := HTTPMiddleware(lmt)(handler) + + // first request + w1 := httptest.NewRecorder() + r1 := httptest.NewRequest(http.MethodGet, "/test", nil) + r1.RemoteAddr = "127.0.0.1:12345" + wrapped.ServeHTTP(w1, r1) + if w1.Code != http.StatusOK { + t.Errorf("first request: expected status %d, got %d", http.StatusOK, w1.Code) + } + + // immediate second request should fail + w2 := httptest.NewRecorder() + r2 := httptest.NewRequest(http.MethodGet, "/test", nil) + r2.RemoteAddr = "127.0.0.1:12345" + wrapped.ServeHTTP(w2, r2) + if w2.Code != http.StatusTooManyRequests { + t.Errorf("second request: expected status %d, got %d", http.StatusTooManyRequests, w2.Code) + } + if !strings.Contains(w2.Body.String(), "maximum request limit") { + t.Errorf("expected error message containing 'maximum request limit', got %q", w2.Body.String()) + } + }) + + t.Run("context cancelled", func(t *testing.T) { + lmt := NewLimiter(1, nil) + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + wrapped := HTTPMiddleware(lmt)(handler) + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx, cancel := context.WithCancel(r.Context()) + cancel() + r = r.WithContext(ctx) + wrapped.ServeHTTP(w, r) + if w.Code != http.StatusServiceUnavailable { + t.Errorf("expected status %d, got %d", http.StatusServiceUnavailable, w.Code) + } + if !strings.Contains(w.Body.String(), "Context was canceled") { + t.Errorf("expected error message containing 'Context was canceled', got %q", w.Body.String()) + } + }) + + t.Run("custom error handler", func(t *testing.T) { + lmt := NewLimiter(0.1, nil) // only allow one request per 10 seconds + customMsg := "custom limit reached" + lmt.SetMessage(customMsg) + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + wrapped := HTTPMiddleware(lmt)(handler) + + // first request + w1 := httptest.NewRecorder() + r1 := httptest.NewRequest(http.MethodGet, "/test", nil) + r1.RemoteAddr = "127.0.0.1:12345" + wrapped.ServeHTTP(w1, r1) + if w1.Code != http.StatusOK { + t.Errorf("first request: expected status %d, got %d", http.StatusOK, w1.Code) + } + + // immediate second request should fail + w2 := httptest.NewRecorder() + r2 := httptest.NewRequest(http.MethodGet, "/test", nil) + r2.RemoteAddr = "127.0.0.1:12345" + wrapped.ServeHTTP(w2, r2) + if w2.Code != http.StatusTooManyRequests { + t.Errorf("second request: expected status %d, got %d", http.StatusTooManyRequests, w2.Code) + } + if !strings.Contains(w2.Body.String(), customMsg) { + t.Errorf("expected error message containing %q, got %q", customMsg, w2.Body.String()) + } + }) + + t.Run("custom IP lookup", func(t *testing.T) { + lmt := NewLimiter(0.1, nil) + lmt.SetIPLookup(limiter.IPLookup{Name: "X-Real-IP"}) + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + wrapped := HTTPMiddleware(lmt)(handler) + + // first request with IP1 + w1 := httptest.NewRecorder() + r1 := httptest.NewRequest(http.MethodGet, "/test", nil) + r1.Header.Set("X-Real-IP", "5.5.5.5") + wrapped.ServeHTTP(w1, r1) + if w1.Code != http.StatusOK { + t.Errorf("first request: expected status %d, got %d", http.StatusOK, w1.Code) + } + + // second request with IP1 should fail + w2 := httptest.NewRecorder() + r2 := httptest.NewRequest(http.MethodGet, "/test", nil) + r2.Header.Set("X-Real-IP", "5.5.5.5") + wrapped.ServeHTTP(w2, r2) + if w2.Code != http.StatusTooManyRequests { + t.Errorf("second request: expected status %d, got %d", http.StatusTooManyRequests, w2.Code) + } + + // request with IP2 should pass + w3 := httptest.NewRecorder() + r3 := httptest.NewRequest(http.MethodGet, "/test", nil) + r3.Header.Set("X-Real-IP", "6.6.6.6") + wrapped.ServeHTTP(w3, r3) + if w3.Code != http.StatusOK { + t.Errorf("third request: expected status %d, got %d", http.StatusOK, w3.Code) + } + }) +}