From 9e112dcd276ead8f870fb5c35e894b80d9fe6e6d Mon Sep 17 00:00:00 2001 From: Dmitry Verkhoturov Date: Thu, 12 May 2022 21:39:53 +0200 Subject: [PATCH 01/11] bump golangci-lint, fix discovered problems (#100) --- .github/workflows/ci.yml | 2 +- .golangci.yml | 4 ++-- libstring/libstring.go | 1 - libstring/libstring_test.go | 1 + tollbooth_test.go | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e56894b..1bd413e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,7 +26,7 @@ jobs: go build -race - name: install golangci-lint - run: curl -sfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh| sh -s -- -b $GITHUB_WORKSPACE v1.26.0 + run: curl -sfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh| sh -s -- -b $GITHUB_WORKSPACE v1.45.2 - name: run golangci-lint run: $GITHUB_WORKSPACE/golangci-lint run --out-format=github-actions diff --git a/.golangci.yml b/.golangci.yml index 98b2ab1..880786a 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,7 +1,7 @@ linters: enable: - megacheck - - golint + - revive - govet - unconvert - megacheck @@ -18,7 +18,7 @@ linters: - varcheck - stylecheck - gochecknoinits - - scopelint + - exportloopref - gocritic - nakedret - gosimple diff --git a/libstring/libstring.go b/libstring/libstring.go index c9355e2..730b654 100644 --- a/libstring/libstring.go +++ b/libstring/libstring.go @@ -69,7 +69,6 @@ func CanonicalizeIP(ip string) string { case ':': // IPv6 isIPv6 = true - break } } if !isIPv6 { diff --git a/libstring/libstring_test.go b/libstring/libstring_test.go index 0bc8713..08abe3c 100644 --- a/libstring/libstring_test.go +++ b/libstring/libstring_test.go @@ -169,6 +169,7 @@ func TestCanonicalizeIP(t *testing.T) { }, } for _, tt := range tests { + tt := tt t.Run(tt.name, func(t *testing.T) { if got := CanonicalizeIP(tt.ip); got != tt.want { t.Errorf("CanonicalizeIP() = %v, want %v", got, tt.want) diff --git a/tollbooth_test.go b/tollbooth_test.go index 2bbf020..eafd3e2 100644 --- a/tollbooth_test.go +++ b/tollbooth_test.go @@ -174,7 +174,7 @@ func TestContextValueBuildKeys(t *testing.T) { } request.Header.Set("X-Real-IP", "172.217.0.46") - //nolint:golint,staticcheck // limiter.SetContextValue requires string as a key, so we have to live with that + //nolint:revive,staticcheck // limiter.SetContextValue requires string as a key, so we have to live with that request = request.WithContext(context.WithValue(request.Context(), "API-access-level", "basic")) sliceKeys := BuildKeys(lmt, request) @@ -306,7 +306,7 @@ func TestRequestMethodCustomHeadersAndBasicAuthUsersAndContextValuesBuildKeys(t request.Header.Set("X-Real-IP", "172.217.0.46") request.Header.Set("X-Auth-Token", "totally-top-secret") request.SetBasicAuth("bro", "tato") - //nolint:golint,staticcheck // limiter.SetContextValue requires string as a key, so we have to live with that + //nolint:revive,staticcheck // limiter.SetContextValue requires string as a key, so we have to live with that request = request.WithContext(context.WithValue(request.Context(), "API-access-level", "basic")) sliceKeys := BuildKeys(lmt, request) From a7634c70944aeafb7a043be87bb5e43e151689e6 Mon Sep 17 00:00:00 2001 From: Dmitry Verkhoturov Date: Mon, 6 Jun 2022 17:29:38 +0200 Subject: [PATCH 02/11] update go modules (#101) --- go.mod | 6 ++---- go.sum | 25 ++++++++----------------- 2 files changed, 10 insertions(+), 21 deletions(-) diff --git a/go.mod b/go.mod index 500727a..cc721be 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,6 @@ module github.com/didip/tollbooth/v6 go 1.12 require ( - github.com/go-pkgz/expirable-cache v0.0.3 - github.com/kr/pretty v0.1.0 // indirect - golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 - gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + github.com/go-pkgz/expirable-cache v0.1.0 + golang.org/x/time v0.0.0-20220411224347-583f2d630306 ) diff --git a/go.sum b/go.sum index b48c5c0..2a51c2f 100644 --- a/go.sum +++ b/go.sum @@ -1,24 +1,15 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-pkgz/expirable-cache v0.0.3 h1:rTh6qNPp78z0bQE6HDhXBHUwqnV9i09Vm6dksJLXQDc= -github.com/go-pkgz/expirable-cache v0.0.3/go.mod h1:+IauqN00R2FqNRLCLA+X5YljQJrwB179PfiAoMPlTlQ= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/go-pkgz/expirable-cache v0.1.0 h1:3bw0m8vlTK8qlwz5KXuygNBTkiKRTPrAGXU0Ej2AC1g= +github.com/go-pkgz/expirable-cache v0.1.0/go.mod h1:GTrEl0X+q0mPNqN6dtcQXksACnzCBQ5k/k1SwXJsZKs= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 h1:NusfzzA6yGQ+ua51ck7E3omNUX/JuqbFSaRGqU8CcLI= -golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/time v0.0.0-20220411224347-583f2d630306 h1:+gHMid33q6pen7kv9xvT+JRinntgeXO2AeZVd0AWD3w= +golang.org/x/time v0.0.0-20220411224347-583f2d630306/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 49cb666ddb18bdc0c60dfb49a88dc9d935ca5b83 Mon Sep 17 00:00:00 2001 From: Dmitry Verkhoturov Date: Mon, 4 Jul 2022 03:41:46 +0200 Subject: [PATCH 03/11] Embed time/rate, set RateLimit headers (#96) * embed time/rate package * set RateLimit headers --- README.md | 7 + go.mod | 5 +- go.sum | 2 - internal/time/AUTHORS | 3 + internal/time/CONTRIBUTING.md | 26 ++ internal/time/CONTRIBUTORS | 3 + internal/time/LICENSE | 27 ++ internal/time/PATENTS | 22 ++ internal/time/README.md | 19 ++ internal/time/rate/rate.go | 396 ++++++++++++++++++++++++++ internal/time/rate/rate_test.go | 482 ++++++++++++++++++++++++++++++++ limiter/limiter.go | 13 +- tollbooth.go | 31 +- tollbooth_test.go | 27 ++ 14 files changed, 1053 insertions(+), 10 deletions(-) create mode 100644 internal/time/AUTHORS create mode 100644 internal/time/CONTRIBUTING.md create mode 100644 internal/time/CONTRIBUTORS create mode 100644 internal/time/LICENSE create mode 100644 internal/time/PATENTS create mode 100644 internal/time/README.md create mode 100644 internal/time/rate/rate.go create mode 100644 internal/time/rate/rate_test.go diff --git a/README.md b/README.md index 31f37b5..aeaced3 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,13 @@ func main() { * `X-Rate-Limit-Request-Remote-Addr` The rejected request `RemoteAddr`. + Upon both success and rejection [RateLimit](https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-ratelimit-headers) headers are sent: + + * `RateLimit-Limit` The maximum request limit within the time window (1s). + + * `RateLimit-Reset` The rate-limiter time window duration in seconds (always 1s). + + * `RateLimit-Remaining` The remaining tokens. 5. Customize your own message or function when limit is reached. diff --git a/go.mod b/go.mod index cc721be..ab3ff2b 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,4 @@ module github.com/didip/tollbooth/v6 go 1.12 -require ( - github.com/go-pkgz/expirable-cache v0.1.0 - golang.org/x/time v0.0.0-20220411224347-583f2d630306 -) +require github.com/go-pkgz/expirable-cache v0.1.0 diff --git a/go.sum b/go.sum index 2a51c2f..fdcdb33 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,6 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -golang.org/x/time v0.0.0-20220411224347-583f2d630306 h1:+gHMid33q6pen7kv9xvT+JRinntgeXO2AeZVd0AWD3w= -golang.org/x/time v0.0.0-20220411224347-583f2d630306/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= diff --git a/internal/time/AUTHORS b/internal/time/AUTHORS new file mode 100644 index 0000000..15167cd --- /dev/null +++ b/internal/time/AUTHORS @@ -0,0 +1,3 @@ +# This source code refers to The Go Authors for copyright purposes. +# The master list of authors is in the main Go distribution, +# visible at http://tip.golang.org/AUTHORS. diff --git a/internal/time/CONTRIBUTING.md b/internal/time/CONTRIBUTING.md new file mode 100644 index 0000000..d0485e8 --- /dev/null +++ b/internal/time/CONTRIBUTING.md @@ -0,0 +1,26 @@ +# Contributing to Go + +Go is an open source project. + +It is the work of hundreds of contributors. We appreciate your help! + +## Filing issues + +When [filing an issue](https://golang.org/issue/new), make sure to answer these five questions: + +1. What version of Go are you using (`go version`)? +2. What operating system and processor architecture are you using? +3. What did you do? +4. What did you expect to see? +5. What did you see instead? + +General questions should go to the [golang-nuts mailing list](https://groups.google.com/group/golang-nuts) instead of the issue tracker. +The gophers there will answer or ask you to file an issue if you've tripped over a bug. + +## Contributing code + +Please read the [Contribution Guidelines](https://golang.org/doc/contribute.html) +before sending patches. + +Unless otherwise noted, the Go source files are distributed under +the BSD-style license found in the LICENSE file. diff --git a/internal/time/CONTRIBUTORS b/internal/time/CONTRIBUTORS new file mode 100644 index 0000000..1c4577e --- /dev/null +++ b/internal/time/CONTRIBUTORS @@ -0,0 +1,3 @@ +# This source code was written by the Go contributors. +# The master list of contributors is in the main Go distribution, +# visible at http://tip.golang.org/CONTRIBUTORS. diff --git a/internal/time/LICENSE b/internal/time/LICENSE new file mode 100644 index 0000000..6a66aea --- /dev/null +++ b/internal/time/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/internal/time/PATENTS b/internal/time/PATENTS new file mode 100644 index 0000000..7330990 --- /dev/null +++ b/internal/time/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/internal/time/README.md b/internal/time/README.md new file mode 100644 index 0000000..705a497 --- /dev/null +++ b/internal/time/README.md @@ -0,0 +1,19 @@ +# Go Time + +[![Go Reference](https://pkg.go.dev/badge/golang.org/x/time.svg)](https://pkg.go.dev/golang.org/x/time) + +This repository provides supplementary Go time packages. + +## Download/Install + +The easiest way to install is to run `go get -u golang.org/x/time`. You can +also manually git clone the repository to `$GOPATH/src/golang.org/x/time`. + +## Report Issues / Send Patches + +This repository uses Gerrit for code changes. To learn how to submit changes to +this repository, see https://golang.org/doc/contribute.html. + +The main issue tracker for the time repository is located at +https://github.com/golang/go/issues. Prefix your issue with "x/time:" in the +subject line, so it is easy to find. diff --git a/internal/time/rate/rate.go b/internal/time/rate/rate.go new file mode 100644 index 0000000..6c3b442 --- /dev/null +++ b/internal/time/rate/rate.go @@ -0,0 +1,396 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package rate provides a rate limiter. +package rate + +import ( + "context" + "fmt" + "math" + "sync" + "time" +) + +// Limit defines the maximum frequency of some events. +// Limit is represented as number of events per second. +// A zero Limit allows no events. +type Limit float64 + +// Inf is the infinite rate limit; it allows all events (even if burst is zero). +const Inf = Limit(math.MaxFloat64) + +// Every converts a minimum time interval between events to a Limit. +func Every(interval time.Duration) Limit { + if interval <= 0 { + return Inf + } + return 1 / Limit(interval.Seconds()) +} + +// A Limiter controls how frequently events are allowed to happen. +// It implements a "token bucket" of size b, initially full and refilled +// at rate r tokens per second. +// Informally, in any large enough time interval, the Limiter limits the +// rate to r tokens per second, with a maximum burst size of b events. +// As a special case, if r == Inf (the infinite rate), b is ignored. +// See https://en.wikipedia.org/wiki/Token_bucket for more about token buckets. +// +// The zero value is a valid Limiter, but it will reject all events. +// Use NewLimiter to create non-zero Limiters. +// +// Limiter has three main methods, Allow, Reserve, and Wait. +// Most callers should use Wait. +// +// Each of the three methods consumes a single token. +// They differ in their behavior when no token is available. +// If no token is available, Allow returns false. +// If no token is available, Reserve returns a reservation for a future token +// and the amount of time the caller must wait before using it. +// If no token is available, Wait blocks until one can be obtained +// or its associated context.Context is canceled. +// +// The methods AllowN, ReserveN, and WaitN consume n tokens. +type Limiter struct { + mu sync.Mutex + limit Limit + burst int + tokens float64 + // last is the last time the limiter's tokens field was updated + last time.Time + // lastEvent is the latest time of a rate-limited event (past or future) + lastEvent time.Time +} + +// Limit returns the maximum overall event rate. +func (lim *Limiter) Limit() Limit { + lim.mu.Lock() + defer lim.mu.Unlock() + return lim.limit +} + +// Burst returns the maximum burst size. Burst is the maximum number of tokens +// that can be consumed in a single call to Allow, Reserve, or Wait, so higher +// Burst values allow more events to happen at once. +// A zero Burst allows no events, unless limit == Inf. +func (lim *Limiter) Burst() int { + lim.mu.Lock() + defer lim.mu.Unlock() + return lim.burst +} + +// NewLimiter returns a new Limiter that allows events up to rate r and permits +// bursts of at most b tokens. +func NewLimiter(r Limit, b int) *Limiter { + return &Limiter{ + limit: r, + burst: b, + } +} + +// Allow is shorthand for AllowN(time.Now(), 1). +func (lim *Limiter) Allow() bool { + return lim.AllowN(time.Now(), 1) +} + +// TokensAt returns the number of tokens available for the given time. +func (lim *Limiter) TokensAt(t time.Time) float64 { + lim.mu.Lock() + _, _, tokens := lim.advance(t) // does not mutate lim + lim.mu.Unlock() + return tokens +} + +// AllowN reports whether n events may happen at time now. +// Use this method if you intend to drop / skip events that exceed the rate limit. +// Otherwise use Reserve or Wait. +func (lim *Limiter) AllowN(now time.Time, n int) bool { + return lim.reserveN(now, n, 0).ok +} + +// A Reservation holds information about events that are permitted by a Limiter to happen after a delay. +// A Reservation may be canceled, which may enable the Limiter to permit additional events. +type Reservation struct { + ok bool + lim *Limiter + tokens int + timeToAct time.Time + // This is the Limit at reservation time, it can change later. + limit Limit +} + +// OK returns whether the limiter can provide the requested number of tokens +// within the maximum wait time. If OK is false, Delay returns InfDuration, and +// Cancel does nothing. +func (r *Reservation) OK() bool { + return r.ok +} + +// Delay is shorthand for DelayFrom(time.Now()). +func (r *Reservation) Delay() time.Duration { + return r.DelayFrom(time.Now()) +} + +// InfDuration is the duration returned by Delay when a Reservation is not OK. +const InfDuration = time.Duration(1<<63 - 1) + +// DelayFrom returns the duration for which the reservation holder must wait +// before taking the reserved action. Zero duration means act immediately. +// InfDuration means the limiter cannot grant the tokens requested in this +// Reservation within the maximum wait time. +func (r *Reservation) DelayFrom(now time.Time) time.Duration { + if !r.ok { + return InfDuration + } + delay := r.timeToAct.Sub(now) + if delay < 0 { + return 0 + } + return delay +} + +// Cancel is shorthand for CancelAt(time.Now()). +func (r *Reservation) Cancel() { + r.CancelAt(time.Now()) +} + +// CancelAt indicates that the reservation holder will not perform the reserved action +// and reverses the effects of this Reservation on the rate limit as much as possible, +// considering that other reservations may have already been made. +func (r *Reservation) CancelAt(now time.Time) { + if !r.ok { + return + } + + r.lim.mu.Lock() + defer r.lim.mu.Unlock() + + if r.lim.limit == Inf || r.tokens == 0 || r.timeToAct.Before(now) { + return + } + + // calculate tokens to restore + // The duration between lim.lastEvent and r.timeToAct tells us how many tokens were reserved + // after r was obtained. These tokens should not be restored. + restoreTokens := float64(r.tokens) - r.limit.tokensFromDuration(r.lim.lastEvent.Sub(r.timeToAct)) + if restoreTokens <= 0 { + return + } + // advance time to now + now, _, tokens := r.lim.advance(now) + // calculate new number of tokens + tokens += restoreTokens + if burst := float64(r.lim.burst); tokens > burst { + tokens = burst + } + // update state + r.lim.last = now + r.lim.tokens = tokens + if r.timeToAct == r.lim.lastEvent { + prevEvent := r.timeToAct.Add(r.limit.durationFromTokens(float64(-r.tokens))) + if !prevEvent.Before(now) { + r.lim.lastEvent = prevEvent + } + } +} + +// Reserve is shorthand for ReserveN(time.Now(), 1). +func (lim *Limiter) Reserve() *Reservation { + return lim.ReserveN(time.Now(), 1) +} + +// ReserveN returns a Reservation that indicates how long the caller must wait before n events happen. +// The Limiter takes this Reservation into account when allowing future events. +// The returned Reservation’s OK() method returns false if n exceeds the Limiter's burst size. +// Usage example: +// r := lim.ReserveN(time.Now(), 1) +// if !r.OK() { +// // Not allowed to act! Did you remember to set lim.burst to be > 0 ? +// return +// } +// time.Sleep(r.Delay()) +// Act() +// Use this method if you wish to wait and slow down in accordance with the rate limit without dropping events. +// If you need to respect a deadline or cancel the delay, use Wait instead. +// To drop or skip events exceeding rate limit, use Allow instead. +func (lim *Limiter) ReserveN(now time.Time, n int) *Reservation { + r := lim.reserveN(now, n, InfDuration) + return &r +} + +// Wait is shorthand for WaitN(ctx, 1). +func (lim *Limiter) Wait(ctx context.Context) (err error) { + return lim.WaitN(ctx, 1) +} + +// WaitN blocks until lim permits n events to happen. +// It returns an error if n exceeds the Limiter's burst size, the Context is +// canceled, or the expected wait time exceeds the Context's Deadline. +// The burst limit is ignored if the rate limit is Inf. +func (lim *Limiter) WaitN(ctx context.Context, n int) (err error) { + lim.mu.Lock() + burst := lim.burst + limit := lim.limit + lim.mu.Unlock() + + if n > burst && limit != Inf { + return fmt.Errorf("rate: Wait(n=%d) exceeds limiter's burst %d", n, burst) + } + // Check if ctx is already cancelled + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + // Determine wait limit + now := time.Now() + waitLimit := InfDuration + if deadline, ok := ctx.Deadline(); ok { + waitLimit = deadline.Sub(now) + } + // Reserve + r := lim.reserveN(now, n, waitLimit) + if !r.ok { + return fmt.Errorf("rate: Wait(n=%d) would exceed context deadline", n) + } + // Wait if necessary + delay := r.DelayFrom(now) + if delay == 0 { + return nil + } + t := time.NewTimer(delay) + defer t.Stop() + select { + case <-t.C: + // We can proceed. + return nil + case <-ctx.Done(): + // Context was canceled before we could proceed. Cancel the + // reservation, which may permit other events to proceed sooner. + r.Cancel() + return ctx.Err() + } +} + +// SetLimit is shorthand for SetLimitAt(time.Now(), newLimit). +func (lim *Limiter) SetLimit(newLimit Limit) { + lim.SetLimitAt(time.Now(), newLimit) +} + +// SetLimitAt sets a new Limit for the limiter. The new Limit, and Burst, may be violated +// or underutilized by those which reserved (using Reserve or Wait) but did not yet act +// before SetLimitAt was called. +func (lim *Limiter) SetLimitAt(now time.Time, newLimit Limit) { + lim.mu.Lock() + defer lim.mu.Unlock() + + now, _, tokens := lim.advance(now) + + lim.last = now + lim.tokens = tokens + lim.limit = newLimit +} + +// SetBurst is shorthand for SetBurstAt(time.Now(), newBurst). +func (lim *Limiter) SetBurst(newBurst int) { + lim.SetBurstAt(time.Now(), newBurst) +} + +// SetBurstAt sets a new burst size for the limiter. +func (lim *Limiter) SetBurstAt(now time.Time, newBurst int) { + lim.mu.Lock() + defer lim.mu.Unlock() + + now, _, tokens := lim.advance(now) + + lim.last = now + lim.tokens = tokens + lim.burst = newBurst +} + +// reserveN is a helper method for AllowN, ReserveN, and WaitN. +// maxFutureReserve specifies the maximum reservation wait duration allowed. +// reserveN returns Reservation, not *Reservation, to avoid allocation in AllowN and WaitN. +func (lim *Limiter) reserveN(now time.Time, n int, maxFutureReserve time.Duration) Reservation { + lim.mu.Lock() + + if lim.limit == Inf { + lim.mu.Unlock() + return Reservation{ + ok: true, + lim: lim, + tokens: n, + timeToAct: now, + } + } + + now, last, tokens := lim.advance(now) + + // Calculate the remaining number of tokens resulting from the request. + tokens -= float64(n) + + // Calculate the wait duration + var waitDuration time.Duration + if tokens < 0 { + waitDuration = lim.limit.durationFromTokens(-tokens) + } + + // Decide result + ok := n <= lim.burst && waitDuration <= maxFutureReserve + + // Prepare reservation + r := Reservation{ + ok: ok, + lim: lim, + limit: lim.limit, + } + if ok { + r.tokens = n + r.timeToAct = now.Add(waitDuration) + } + + // Update state + if ok { + lim.last = now + lim.tokens = tokens + lim.lastEvent = r.timeToAct + } else { + lim.last = last + } + + lim.mu.Unlock() + return r +} + +// advance calculates and returns an updated state for lim resulting from the passage of time. +// lim is not changed. +// advance requires that lim.mu is held. +func (lim *Limiter) advance(now time.Time) (newNow time.Time, newLast time.Time, newTokens float64) { + last := lim.last + if now.Before(last) { + last = now + } + + // Calculate the new number of tokens, due to time that passed. + elapsed := now.Sub(last) + delta := lim.limit.tokensFromDuration(elapsed) + tokens := lim.tokens + delta + if burst := float64(lim.burst); tokens > burst { + tokens = burst + } + return now, last, tokens +} + +// durationFromTokens is a unit conversion function from the number of tokens to the duration +// of time it takes to accumulate them at a rate of limit tokens per second. +func (limit Limit) durationFromTokens(tokens float64) time.Duration { + seconds := tokens / float64(limit) + return time.Duration(float64(time.Second) * seconds) +} + +// tokensFromDuration is a unit conversion function from a time duration to the number of tokens +// which could be accumulated during that duration at a rate of limit tokens per second. +func (limit Limit) tokensFromDuration(d time.Duration) float64 { + return d.Seconds() * float64(limit) +} diff --git a/internal/time/rate/rate_test.go b/internal/time/rate/rate_test.go new file mode 100644 index 0000000..1c5e9e7 --- /dev/null +++ b/internal/time/rate/rate_test.go @@ -0,0 +1,482 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.7 +// +build go1.7 + +package rate + +import ( + "context" + "math" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestLimit(t *testing.T) { + if Limit(10) == Inf { + t.Errorf("Limit(10) == Inf should be false") + } +} + +func closeEnough(a, b Limit) bool { + return (math.Abs(float64(a)/float64(b)) - 1.0) < 1e-9 +} + +func TestEvery(t *testing.T) { + cases := []struct { + interval time.Duration + lim Limit + }{ + {0, Inf}, + {-1, Inf}, + {1 * time.Nanosecond, Limit(1e9)}, + {1 * time.Microsecond, Limit(1e6)}, + {1 * time.Millisecond, Limit(1e3)}, + {10 * time.Millisecond, Limit(100)}, + {100 * time.Millisecond, Limit(10)}, + {1 * time.Second, Limit(1)}, + {2 * time.Second, Limit(0.5)}, + {time.Duration(2.5 * float64(time.Second)), Limit(0.4)}, + {4 * time.Second, Limit(0.25)}, + {10 * time.Second, Limit(0.1)}, + {time.Duration(math.MaxInt64), Limit(1e9 / float64(math.MaxInt64))}, + } + for _, tc := range cases { + lim := Every(tc.interval) + if !closeEnough(lim, tc.lim) { + t.Errorf("Every(%v) = %v want %v", tc.interval, lim, tc.lim) + } + } +} + +const ( + d = 100 * time.Millisecond +) + +var ( + t0 = time.Now() + t1 = t0.Add(time.Duration(1) * d) + t2 = t0.Add(time.Duration(2) * d) + t3 = t0.Add(time.Duration(3) * d) + t4 = t0.Add(time.Duration(4) * d) + t5 = t0.Add(time.Duration(5) * d) + t9 = t0.Add(time.Duration(9) * d) +) + +type allow struct { + t time.Time + n int + ok bool +} + +func run(t *testing.T, lim *Limiter, allows []allow) { + t.Helper() + for i, allow := range allows { + ok := lim.AllowN(allow.t, allow.n) + if ok != allow.ok { + t.Errorf("step %d: lim.AllowN(%v, %v) = %v want %v", + i, allow.t, allow.n, ok, allow.ok) + } + } +} + +func TestLimiterBurst1(t *testing.T) { + run(t, NewLimiter(10, 1), []allow{ + {t0, 1, true}, + {t0, 1, false}, + {t0, 1, false}, + {t1, 1, true}, + {t1, 1, false}, + {t1, 1, false}, + {t2, 2, false}, // burst size is 1, so n=2 always fails + {t2, 1, true}, + {t2, 1, false}, + }) +} + +func TestLimiterBurst3(t *testing.T) { + run(t, NewLimiter(10, 3), []allow{ + {t0, 2, true}, + {t0, 2, false}, + {t0, 1, true}, + {t0, 1, false}, + {t1, 4, false}, + {t2, 1, true}, + {t3, 1, true}, + {t4, 1, true}, + {t4, 1, true}, + {t4, 1, false}, + {t4, 1, false}, + {t9, 3, true}, + {t9, 0, true}, + }) +} + +func TestLimiterJumpBackwards(t *testing.T) { + run(t, NewLimiter(10, 3), []allow{ + {t1, 1, true}, // start at t1 + {t0, 1, true}, // jump back to t0, two tokens remain + {t0, 1, true}, + {t0, 1, false}, + {t0, 1, false}, + {t1, 1, true}, // got a token + {t1, 1, false}, + {t1, 1, false}, + {t2, 1, true}, // got another token + {t2, 1, false}, + {t2, 1, false}, + }) +} + +// Ensure that tokensFromDuration doesn't produce +// rounding errors by truncating nanoseconds. +// See golang.org/issues/34861. +func TestLimiter_noTruncationErrors(t *testing.T) { + if !NewLimiter(0.7692307692307693, 1).Allow() { + t.Fatal("expected true") + } +} + +func TestSimultaneousRequests(t *testing.T) { + const ( + limit = 1 + burst = 5 + numRequests = 15 + ) + var ( + wg sync.WaitGroup + numOK = uint32(0) + ) + + // Very slow replenishing bucket. + lim := NewLimiter(limit, burst) + + // Tries to take a token, atomically updates the counter and decreases the wait + // group counter. + f := func() { + defer wg.Done() + if ok := lim.Allow(); ok { + atomic.AddUint32(&numOK, 1) + } + } + + wg.Add(numRequests) + for i := 0; i < numRequests; i++ { + go f() + } + wg.Wait() + if numOK != burst { + t.Errorf("numOK = %d, want %d", numOK, burst) + } +} + +func TestLongRunningQPS(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + if runtime.GOOS == "openbsd" { + t.Skip("low resolution time.Sleep invalidates test (golang.org/issue/14183)") + return + } + + // The test runs for a few seconds executing many requests and then checks + // that overall number of requests is reasonable. + const ( + limit = 100 + burst = 100 + ) + var numOK = int32(0) + + lim := NewLimiter(limit, burst) + + var wg sync.WaitGroup + f := func() { + if ok := lim.Allow(); ok { + atomic.AddInt32(&numOK, 1) + } + wg.Done() + } + + start := time.Now() + end := start.Add(5 * time.Second) + for time.Now().Before(end) { + wg.Add(1) + go f() + + // This will still offer ~500 requests per second, but won't consume + // outrageous amount of CPU. + time.Sleep(2 * time.Millisecond) + } + wg.Wait() + elapsed := time.Since(start) + ideal := burst + (limit * float64(elapsed) / float64(time.Second)) + + // We should never get more requests than allowed. + if want := int32(ideal + 1); numOK > want { + t.Errorf("numOK = %d, want %d (ideal %f)", numOK, want, ideal) + } + // We should get very close to the number of requests allowed. + if want := int32(0.999 * ideal); numOK < want { + t.Errorf("numOK = %d, want %d (ideal %f)", numOK, want, ideal) + } +} + +type request struct { + t time.Time + n int + act time.Time + ok bool +} + +// dFromDuration converts a duration to a multiple of the global constant d +func dFromDuration(dur time.Duration) int { + // Adding a millisecond to be swallowed by the integer division + // because we don't care about small inaccuracies + return int((dur + time.Millisecond) / d) +} + +// dSince returns multiples of d since t0 +func dSince(t time.Time) int { + return dFromDuration(t.Sub(t0)) +} + +func runReserve(t *testing.T, lim *Limiter, req request) *Reservation { + t.Helper() + return runReserveMax(t, lim, req, InfDuration) +} + +func runReserveMax(t *testing.T, lim *Limiter, req request, maxReserve time.Duration) *Reservation { + t.Helper() + r := lim.reserveN(req.t, req.n, maxReserve) + if r.ok && (dSince(r.timeToAct) != dSince(req.act)) || r.ok != req.ok { + t.Errorf("lim.reserveN(t%d, %v, %v) = (t%d, %v) want (t%d, %v)", + dSince(req.t), req.n, maxReserve, dSince(r.timeToAct), r.ok, dSince(req.act), req.ok) + } + return &r +} + +func TestSimpleReserve(t *testing.T) { + lim := NewLimiter(10, 2) + + runReserve(t, lim, request{t0, 2, t0, true}) + runReserve(t, lim, request{t0, 2, t2, true}) + runReserve(t, lim, request{t3, 2, t4, true}) +} + +func TestMix(t *testing.T) { + lim := NewLimiter(10, 2) + + runReserve(t, lim, request{t0, 3, t1, false}) // should return false because n > Burst + runReserve(t, lim, request{t0, 2, t0, true}) + run(t, lim, []allow{{t1, 2, false}}) // not enough tokens - don't allow + runReserve(t, lim, request{t1, 2, t2, true}) + run(t, lim, []allow{{t1, 1, false}}) // negative tokens - don't allow + run(t, lim, []allow{{t3, 1, true}}) +} + +func TestCancelInvalid(t *testing.T) { + lim := NewLimiter(10, 2) + + runReserve(t, lim, request{t0, 2, t0, true}) + r := runReserve(t, lim, request{t0, 3, t3, false}) + r.CancelAt(t0) // should have no effect + runReserve(t, lim, request{t0, 2, t2, true}) // did not get extra tokens +} + +func TestCancelLast(t *testing.T) { + lim := NewLimiter(10, 2) + + runReserve(t, lim, request{t0, 2, t0, true}) + r := runReserve(t, lim, request{t0, 2, t2, true}) + r.CancelAt(t1) // got 2 tokens back + runReserve(t, lim, request{t1, 2, t2, true}) +} + +func TestCancelTooLate(t *testing.T) { + lim := NewLimiter(10, 2) + + runReserve(t, lim, request{t0, 2, t0, true}) + r := runReserve(t, lim, request{t0, 2, t2, true}) + r.CancelAt(t3) // too late to cancel - should have no effect + runReserve(t, lim, request{t3, 2, t4, true}) +} + +func TestCancel0Tokens(t *testing.T) { + lim := NewLimiter(10, 2) + + runReserve(t, lim, request{t0, 2, t0, true}) + r := runReserve(t, lim, request{t0, 1, t1, true}) + runReserve(t, lim, request{t0, 1, t2, true}) + r.CancelAt(t0) // got 0 tokens back + runReserve(t, lim, request{t0, 1, t3, true}) +} + +func TestCancel1Token(t *testing.T) { + lim := NewLimiter(10, 2) + + runReserve(t, lim, request{t0, 2, t0, true}) + r := runReserve(t, lim, request{t0, 2, t2, true}) + runReserve(t, lim, request{t0, 1, t3, true}) + r.CancelAt(t2) // got 1 token back + runReserve(t, lim, request{t2, 2, t4, true}) +} + +func TestCancelMulti(t *testing.T) { + lim := NewLimiter(10, 4) + + runReserve(t, lim, request{t0, 4, t0, true}) + rA := runReserve(t, lim, request{t0, 3, t3, true}) + runReserve(t, lim, request{t0, 1, t4, true}) + rC := runReserve(t, lim, request{t0, 1, t5, true}) + rC.CancelAt(t1) // get 1 token back + rA.CancelAt(t1) // get 2 tokens back, as if C was never reserved + runReserve(t, lim, request{t1, 3, t5, true}) +} + +func TestReserveJumpBack(t *testing.T) { + lim := NewLimiter(10, 2) + + runReserve(t, lim, request{t1, 2, t1, true}) // start at t1 + runReserve(t, lim, request{t0, 1, t1, true}) // should violate Limit,Burst + runReserve(t, lim, request{t2, 2, t3, true}) +} + +func TestReserveJumpBackCancel(t *testing.T) { + lim := NewLimiter(10, 2) + + runReserve(t, lim, request{t1, 2, t1, true}) // start at t1 + r := runReserve(t, lim, request{t1, 2, t3, true}) + runReserve(t, lim, request{t1, 1, t4, true}) + r.CancelAt(t0) // cancel at t0, get 1 token back + runReserve(t, lim, request{t1, 2, t4, true}) // should violate Limit,Burst +} + +func TestReserveSetLimit(t *testing.T) { + lim := NewLimiter(5, 2) + + runReserve(t, lim, request{t0, 2, t0, true}) + runReserve(t, lim, request{t0, 2, t4, true}) + lim.SetLimitAt(t2, 10) + runReserve(t, lim, request{t2, 1, t4, true}) // violates Limit and Burst +} + +func TestReserveSetBurst(t *testing.T) { + lim := NewLimiter(5, 2) + + runReserve(t, lim, request{t0, 2, t0, true}) + runReserve(t, lim, request{t0, 2, t4, true}) + lim.SetBurstAt(t3, 4) + runReserve(t, lim, request{t0, 4, t9, true}) // violates Limit and Burst +} + +func TestReserveSetLimitCancel(t *testing.T) { + lim := NewLimiter(5, 2) + + runReserve(t, lim, request{t0, 2, t0, true}) + r := runReserve(t, lim, request{t0, 2, t4, true}) + lim.SetLimitAt(t2, 10) + r.CancelAt(t2) // 2 tokens back + runReserve(t, lim, request{t2, 2, t3, true}) +} + +func TestReserveMax(t *testing.T) { + lim := NewLimiter(10, 2) + maxT := d + + runReserveMax(t, lim, request{t0, 2, t0, true}, maxT) + runReserveMax(t, lim, request{t0, 1, t1, true}, maxT) // reserve for close future + runReserveMax(t, lim, request{t0, 1, t2, false}, maxT) // time to act too far in the future +} + +type wait struct { + name string + ctx context.Context + n int + delay int // in multiples of d + nilErr bool +} + +func runWait(t *testing.T, lim *Limiter, w wait) { + t.Helper() + start := time.Now() + err := lim.WaitN(w.ctx, w.n) + delay := time.Since(start) + if (w.nilErr && err != nil) || (!w.nilErr && err == nil) || w.delay != dFromDuration(delay) { + errString := "" + if !w.nilErr { + errString = "" + } + t.Errorf("lim.WaitN(%v, lim, %v) = %v with delay %v ; want %v with delay %v", + w.name, w.n, err, delay, errString, d*time.Duration(w.delay)) + } +} + +func TestWaitSimple(t *testing.T) { + lim := NewLimiter(10, 3) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + runWait(t, lim, wait{"already-cancelled", ctx, 1, 0, false}) + + runWait(t, lim, wait{"exceed-burst-error", context.Background(), 4, 0, false}) + + runWait(t, lim, wait{"act-now", context.Background(), 2, 0, true}) + runWait(t, lim, wait{"act-later", context.Background(), 3, 2, true}) +} + +func TestWaitCancel(t *testing.T) { + lim := NewLimiter(10, 3) + + ctx, cancel := context.WithCancel(context.Background()) + runWait(t, lim, wait{"act-now", ctx, 2, 0, true}) // after this lim.tokens = 1 + go func() { + time.Sleep(d) + cancel() + }() + runWait(t, lim, wait{"will-cancel", ctx, 3, 1, false}) + // should get 3 tokens back, and have lim.tokens = 2 + t.Logf("tokens:%v last:%v lastEvent:%v", lim.tokens, lim.last, lim.lastEvent) + runWait(t, lim, wait{"act-now-after-cancel", context.Background(), 2, 0, true}) +} + +func TestWaitTimeout(t *testing.T) { + lim := NewLimiter(10, 3) + + ctx, cancel := context.WithTimeout(context.Background(), d) + defer cancel() + runWait(t, lim, wait{"act-now", ctx, 2, 0, true}) + runWait(t, lim, wait{"w-timeout-err", ctx, 3, 0, false}) +} + +func TestWaitInf(t *testing.T) { + lim := NewLimiter(Inf, 0) + + runWait(t, lim, wait{"exceed-burst-no-error", context.Background(), 3, 0, true}) +} + +func BenchmarkAllowN(b *testing.B) { + lim := NewLimiter(Every(1*time.Second), 1) + now := time.Now() + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + lim.AllowN(now, 1) + } + }) +} + +func BenchmarkWaitNNoDelay(b *testing.B) { + lim := NewLimiter(Limit(b.N), b.N) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + lim.WaitN(ctx, 1) + } +} diff --git a/limiter/limiter.go b/limiter/limiter.go index 4a7a08a..33b3e36 100644 --- a/limiter/limiter.go +++ b/limiter/limiter.go @@ -7,7 +7,8 @@ import ( "time" cache "github.com/go-pkgz/expirable-cache" - "golang.org/x/time/rate" + + "github.com/didip/tollbooth/v6/internal/time/rate" ) // New is a constructor for Limiter. @@ -597,3 +598,13 @@ func (l *Limiter) LimitReached(key string) bool { return l.limitReachedWithTokenBucketTTL(key, ttl) } + +// Tokens returns current amount of tokens left in the Bucket identified by key. +func (l *Limiter) Tokens(key string) int { + expiringMap, found := l.tokenBuckets.Get(key) + if !found { + return 0 + } + + return int(expiringMap.(*rate.Limiter).TokensAt(time.Now())) +} diff --git a/tollbooth.go b/tollbooth.go index d3fa240..a17271e 100644 --- a/tollbooth.go +++ b/tollbooth.go @@ -25,6 +25,14 @@ func setResponseHeaders(lmt *limiter.Limiter, w http.ResponseWriter, r *http.Req w.Header().Add("X-Rate-Limit-Request-Remote-Addr", r.RemoteAddr) } +// setRateLimitResponseHeaders configures RateLimit-Limit, RateLimit-Remaining and RateLimit-Reset +// as seen at https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-ratelimit-headers +func setRateLimitResponseHeaders(lmt *limiter.Limiter, w http.ResponseWriter, tokensLeft int) { + w.Header().Add("RateLimit-Limit", fmt.Sprintf("%d", int(math.Round(lmt.GetMax())))) + w.Header().Add("RateLimit-Reset", "1") + w.Header().Add("RateLimit-Remaining", fmt.Sprintf("%d", tokensLeft)) +} + // NewLimiter is a convenience function to limiter.New. func NewLimiter(max float64, tbOptions *limiter.ExpirableOptions) *limiter.Limiter { return limiter.New(tbOptions). @@ -36,11 +44,18 @@ func NewLimiter(max float64, tbOptions *limiter.ExpirableOptions) *limiter.Limit // LimitByKeys keeps track number of request made by keys separated by pipe. // It returns HTTPError when limit is exceeded. func LimitByKeys(lmt *limiter.Limiter, keys []string) *errors.HTTPError { + err, _ := LimitByKeysAndReturn(lmt, keys) + return err +} + +// LimitByKeysAndReturn keeps track number of request made by keys separated by pipe. +// It returns HTTPError when limit is exceeded, and also returns the current limit value. +func LimitByKeysAndReturn(lmt *limiter.Limiter, keys []string) (*errors.HTTPError, int) { if lmt.LimitReached(strings.Join(keys, "|")) { - return &errors.HTTPError{Message: lmt.GetMessage(), StatusCode: lmt.GetStatusCode()} + return &errors.HTTPError{Message: lmt.GetMessage(), StatusCode: lmt.GetStatusCode()}, 0 } - return nil + return nil, lmt.Tokens(strings.Join(keys, "|")) } // ShouldSkipLimiter is a series of filter that decides if request should be limited or not. @@ -281,14 +296,24 @@ func LimitByRequest(lmt *limiter.Limiter, w http.ResponseWriter, r *http.Request sliceKeys := BuildKeys(lmt, r) + // Get the lowest value over all keys to return in headers. + // Start with high arbitrary number so that any limit returned would be lower and would + // overwrite the value we start with. + var tokensLeft = math.MaxInt32 + // Loop sliceKeys and check if one of them has error. for _, keys := range sliceKeys { - httpError := LimitByKeys(lmt, keys) + httpError, keysLimit := LimitByKeysAndReturn(lmt, keys) + if tokensLeft > keysLimit { + tokensLeft = keysLimit + } if httpError != nil { + setRateLimitResponseHeaders(lmt, w, tokensLeft) return httpError } } + setRateLimitResponseHeaders(lmt, w, tokensLeft) return nil } diff --git a/tollbooth_test.go b/tollbooth_test.go index eafd3e2..4df67f0 100644 --- a/tollbooth_test.go +++ b/tollbooth_test.go @@ -355,6 +355,16 @@ func TestLimitHandler(t *testing.T) { if status := rr.Code; status != http.StatusOK { t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK) } + // check RateLimit headers + if value := rr.Result().Header[http.CanonicalHeaderKey("RateLimit-Limit")]; len(value) < 1 || value[0] != "1" { + t.Errorf("handler returned wrong value: got %s want %s", value, "1") + } + if value := rr.Result().Header[http.CanonicalHeaderKey("RateLimit-Reset")]; len(value) < 1 || value[0] != "1" { + t.Errorf("handler returned wrong value: got %s want %s", value, "1") + } + if value := rr.Result().Header[http.CanonicalHeaderKey("RateLimit-Remaining")]; len(value) < 1 || value[0] != "0" { + t.Errorf("handler returned wrong value: got %s want %s", value, "0") + } ch := make(chan int) go func() { @@ -367,6 +377,23 @@ func TestLimitHandler(t *testing.T) { if status := rr.Code; status != http.StatusTooManyRequests { t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusTooManyRequests) } + // check X-Rate-Limit headers + if value := rr.Result().Header[http.CanonicalHeaderKey("X-Rate-Limit-Limit")]; len(value) < 1 || value[0] != "1.00" { + t.Errorf("X-Rate-Limit-Limit has wrong value: got %s want %v", value, "1.00") + } + if value := rr.Result().Header[http.CanonicalHeaderKey("X-Rate-Limit-Duration")]; len(value) < 1 || value[0] != "1" { + t.Errorf("X-Rate-Limit-Duration has wrong value: got %s want %v", value, "1") + } + // check RateLimit headers + if value := rr.Result().Header[http.CanonicalHeaderKey("RateLimit-Limit")]; len(value) < 1 || value[0] != "1" { + t.Errorf("RateLimit-Limit has wrong value: got %s want %v", value, "1") + } + if value := rr.Result().Header[http.CanonicalHeaderKey("RateLimit-Reset")]; len(value) < 1 || value[0] != "1" { + t.Errorf("RateLimit-Reset has wrong value: got %s want %v", value, "1") + } + if value := rr.Result().Header[http.CanonicalHeaderKey("RateLimit-Remaining")]; len(value) < 1 || value[0] != "0" { + t.Errorf("RateLimit-Remaining has wrong value: got %s want %v", value, "0") + } // OnLimitReached should be called if counter != 1 { t.Errorf("onLimitReached was not called") From b96dae2fb8761b3f2e3c956f87231e3243ec43b1 Mon Sep 17 00:00:00 2001 From: Didip Kerabat Date: Sun, 3 Jul 2022 18:44:26 -0700 Subject: [PATCH 04/11] Update README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index aeaced3..fdb6143 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,8 @@ This is a generic middleware to rate-limit HTTP requests. **v6.x.x:** Replaced `go-cache` with `github.com/go-pkgz/expirable-cache` because `go-cache` leaks goroutines. +**v7.x.x:** Replaced `time/rate` with `embedded time/rate` so that we can support more rate limit headers. + ## Five Minute Tutorial ```go From 376acbc70f61a82c7f07a0643bd89e426a14426c Mon Sep 17 00:00:00 2001 From: Didip Kerabat Date: Wed, 6 Jul 2022 08:03:17 -0700 Subject: [PATCH 05/11] move to v7 --- README.md | 6 +++--- go.mod | 2 +- limiter/limiter.go | 2 +- tollbooth.go | 6 +++--- tollbooth_benchmark_test.go | 2 +- tollbooth_bug_report_test.go | 2 +- tollbooth_test.go | 2 +- 7 files changed, 11 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index fdb6143..8a45bd8 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ package main import ( "net/http" - "github.com/didip/tollbooth/v6" + "github.com/didip/tollbooth/v7" ) func HelloHandler(w http.ResponseWriter, req *http.Request) { @@ -54,8 +54,8 @@ func main() { import ( "time" - "github.com/didip/tollbooth/v6" - "github.com/didip/tollbooth/v6/limiter" + "github.com/didip/tollbooth/v7" + "github.com/didip/tollbooth/v7/limiter" ) lmt := tollbooth.NewLimiter(1, nil) diff --git a/go.mod b/go.mod index ab3ff2b..b93f0bc 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/didip/tollbooth/v6 +module github.com/didip/tollbooth/v7 go 1.12 diff --git a/limiter/limiter.go b/limiter/limiter.go index 33b3e36..c64a7f2 100644 --- a/limiter/limiter.go +++ b/limiter/limiter.go @@ -8,7 +8,7 @@ import ( cache "github.com/go-pkgz/expirable-cache" - "github.com/didip/tollbooth/v6/internal/time/rate" + "github.com/didip/tollbooth/v7/internal/time/rate" ) // New is a constructor for Limiter. diff --git a/tollbooth.go b/tollbooth.go index a17271e..4545c0a 100644 --- a/tollbooth.go +++ b/tollbooth.go @@ -7,9 +7,9 @@ import ( "net/http" "strings" - "github.com/didip/tollbooth/v6/errors" - "github.com/didip/tollbooth/v6/libstring" - "github.com/didip/tollbooth/v6/limiter" + "github.com/didip/tollbooth/v7/errors" + "github.com/didip/tollbooth/v7/libstring" + "github.com/didip/tollbooth/v7/limiter" ) // setResponseHeaders configures X-Rate-Limit-Limit and X-Rate-Limit-Duration diff --git a/tollbooth_benchmark_test.go b/tollbooth_benchmark_test.go index 488e113..a035d83 100644 --- a/tollbooth_benchmark_test.go +++ b/tollbooth_benchmark_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/didip/tollbooth/v6/limiter" + "github.com/didip/tollbooth/v7/limiter" ) func BenchmarkLimitByKeys(b *testing.B) { diff --git a/tollbooth_bug_report_test.go b/tollbooth_bug_report_test.go index 5b03bc1..4ab584e 100644 --- a/tollbooth_bug_report_test.go +++ b/tollbooth_bug_report_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/didip/tollbooth/v6/limiter" + "github.com/didip/tollbooth/v7/limiter" ) // See: https://github.com/didip/tollbooth/issues/48 diff --git a/tollbooth_test.go b/tollbooth_test.go index 4df67f0..ccf1d8f 100644 --- a/tollbooth_test.go +++ b/tollbooth_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/didip/tollbooth/v6/limiter" + "github.com/didip/tollbooth/v7/limiter" ) func TestLimitByKeys(t *testing.T) { From 0b61d553623024c838a74b6e1b7523e30c8c640f Mon Sep 17 00:00:00 2001 From: Xinyu Xie <55250846+Xinyu-bot@users.noreply.github.com> Date: Thu, 22 Sep 2022 00:19:00 +0800 Subject: [PATCH 06/11] Skip header value matching if entry is empty slice (#104) * Skip header value matching if entry is empty slice * Test function name tweak * Fix race cond for test by atomic int & sync map * Fix test build error Repleace atomic.Int64 and sync.Map with a naive map with mutex lock. Doing so because the build tool is at go version 1.14, while atomic.Int64 is introduced at go 1.19 I think. * Shorten test function name to pass ci lint * Simplify test case to pass cyclomatic complexity constraint --- tollbooth.go | 4 ++ tollbooth_test.go | 118 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+) diff --git a/tollbooth.go b/tollbooth.go index 4545c0a..0dcf82e 100644 --- a/tollbooth.go +++ b/tollbooth.go @@ -112,6 +112,10 @@ func ShouldSkipLimiter(lmt *limiter.Limiter, r *http.Request) bool { requestHeadersDefinedInLimiter = false for headerKey, headerValues := range lmtHeaders { + if len(headerValues) == 0 { + requestHeadersDefinedInLimiter = true + continue + } for _, headerValue := range headerValues { if r.Header.Get(headerKey) == headerValue { requestHeadersDefinedInLimiter = true diff --git a/tollbooth_test.go b/tollbooth_test.go index ccf1d8f..25d197d 100644 --- a/tollbooth_test.go +++ b/tollbooth_test.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "testing" "time" @@ -487,3 +488,120 @@ func isInSlice(key string, keys []string) bool { } return false } + +type LockMap struct { + m map[string]int64 + sync.Mutex +} + +func (lm *LockMap) Set(key string, value int64) { + lm.Lock() + lm.m[key] = value + lm.Unlock() +} + +func (lm *LockMap) Get(key string) (int64, bool) { + lm.Lock() + value, ok := lm.m[key] + lm.Unlock() + return value, ok +} + +func (lm *LockMap) Add(key string, incr int64) { + lm.Lock() + if val, ok := lm.m[key]; ok { + lm.m[key] = val + incr + } else { + lm.m[key] = incr + } + lm.Unlock() +} + +func TestLimitHandlerEmptyHeader(t *testing.T) { + lmt := limiter.New(nil).SetMax(1).SetBurst(1) + lmt.SetIPLookups([]string{"X-Real-IP", "RemoteAddr", "X-Forwarded-For"}) + lmt.SetMethods([]string{"POST"}) + lmt.SetHeader("user_id", []string{}) + + counterMap := &LockMap{m: map[string]int64{}} + lmt.SetOnLimitReached(func(w http.ResponseWriter, r *http.Request) { + _, _ = w, r + counterMap.Add(r.Header.Get("user_id"), 1) + }) + + handler := LimitHandler(lmt, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = r + w.Write([]byte(`hello world`)) + })) + + req, err := http.NewRequest("POST", "/doesntmatter", nil) + if err != nil { + t.Fatal(err) + } + + req.Header.Set("X-Real-IP", "2601:7:1c82:4097:59a0:a80b:2841:b8c8") + req.Header.Set("user_id", "0") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + { // Should not be limited + if status := rr.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK) + } + // check RateLimit headers + if value := rr.Result().Header[http.CanonicalHeaderKey("RateLimit-Limit")]; len(value) < 1 || value[0] != "1" { + t.Errorf("handler returned wrong value: got %s want %s", value, "1") + } + if value := rr.Result().Header[http.CanonicalHeaderKey("RateLimit-Reset")]; len(value) < 1 || value[0] != "1" { + t.Errorf("handler returned wrong value: got %s want %s", value, "1") + } + if value := rr.Result().Header[http.CanonicalHeaderKey("RateLimit-Remaining")]; len(value) < 1 || value[0] != "0" { + t.Errorf("handler returned wrong value: got %s want %s", value, "0") + } + } + + wg := sync.WaitGroup{} + wg.Add(1) + + // same user_id, should be limited + go func() { + defer wg.Done() + + req1, _ := http.NewRequest("POST", "/doesntmatter", nil) + req1.Header.Set("X-Real-IP", "2601:7:1c82:4097:59a0:a80b:2841:b8c8") + req1.Header.Set("user_id", "0") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req1) + // Should be limited + { + if status := rr.Code; status != http.StatusTooManyRequests { + t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusTooManyRequests) + } + // check X-Rate-Limit headers + if value := rr.Result().Header[http.CanonicalHeaderKey("X-Rate-Limit-Limit")]; len(value) < 1 || value[0] != "1.00" { + t.Errorf("X-Rate-Limit-Limit has wrong value: got %s want %v", value, "1.00") + } + if value := rr.Result().Header[http.CanonicalHeaderKey("X-Rate-Limit-Duration")]; len(value) < 1 || value[0] != "1" { + t.Errorf("X-Rate-Limit-Duration has wrong value: got %s want %v", value, "1") + } + // check RateLimit headers + if value := rr.Result().Header[http.CanonicalHeaderKey("RateLimit-Limit")]; len(value) < 1 || value[0] != "1" { + t.Errorf("RateLimit-Limit has wrong value: got %s want %v", value, "1") + } + if value := rr.Result().Header[http.CanonicalHeaderKey("RateLimit-Reset")]; len(value) < 1 || value[0] != "1" { + t.Errorf("RateLimit-Reset has wrong value: got %s want %v", value, "1") + } + if value := rr.Result().Header[http.CanonicalHeaderKey("RateLimit-Remaining")]; len(value) < 1 || value[0] != "0" { + t.Errorf("RateLimit-Remaining has wrong value: got %s want %v", value, "0") + } + // OnLimitReached should be called + if aint, ok := counterMap.Get(req1.Header.Get("user_id")); ok { + if aint == 0 { + t.Errorf("onLimitReached was not called") + } + } + } + }() + + wg.Wait() // Block until go func is done. +} From 604e3765373786d021e144bf0b8da2b561386702 Mon Sep 17 00:00:00 2001 From: Didip Kerabat Date: Wed, 21 Sep 2022 20:30:30 -0700 Subject: [PATCH 07/11] Upgrade everything to Go 1.19 --- .github/workflows/ci.yml | 4 ++-- go.mod | 2 +- go.sum | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1bd413e..dc32917 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,10 +11,10 @@ jobs: runs-on: ubuntu-latest steps: - - name: set up go 1.14 + - name: set up go 1.19 uses: actions/setup-go@v1 with: - go-version: 1.14 + go-version: 1.19 id: go - name: checkout diff --git a/go.mod b/go.mod index b93f0bc..c889a20 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,5 @@ module github.com/didip/tollbooth/v7 -go 1.12 +go 1.19 require github.com/go-pkgz/expirable-cache v0.1.0 diff --git a/go.sum b/go.sum index fdcdb33..e6d03c3 100644 --- a/go.sum +++ b/go.sum @@ -7,7 +7,6 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 89842701ee69008f8f43c7f57704a11194d355f3 Mon Sep 17 00:00:00 2001 From: Didip Kerabat Date: Wed, 30 Nov 2022 07:46:14 -0800 Subject: [PATCH 08/11] update README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 8a45bd8..f96ac2f 100644 --- a/README.md +++ b/README.md @@ -171,6 +171,8 @@ Sometimes, other frameworks require a little bit of shim to use Tollbooth. These ## My other Go libraries +* [ErrStack](https://github.com/didip/errstack): A small library to combine errors and also display filename and line number. + * [Stopwatch](https://github.com/didip/stopwatch): A small library to measure latency of things. Useful if you want to report latency data to Graphite. * [LaborUnion](https://github.com/didip/laborunion): A dynamic worker pool library. From da073739d0f0adb52c8bceb4242ebc60e0f8444b Mon Sep 17 00:00:00 2001 From: Dmitry Verkhoturov Date: Wed, 8 May 2024 05:40:43 +0200 Subject: [PATCH 09/11] update GitHub actions, fix code according to golancilint latest version (#108) --- .github/workflows/ci.yml | 15 +++++++-------- .golangci.yml | 6 +----- tollbooth_bug_report_test.go | 12 ++++++------ tollbooth_test.go | 8 ++++---- 4 files changed, 18 insertions(+), 23 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index dc32917..f26c51b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,21 +12,20 @@ jobs: steps: - name: set up go 1.19 - uses: actions/setup-go@v1 + uses: actions/setup-go@v5 with: - go-version: 1.19 + go-version: "1.19" id: go - name: checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: build and test run: | go test -timeout=60s -race go build -race - - name: install golangci-lint - run: curl -sfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh| sh -s -- -b $GITHUB_WORKSPACE v1.45.2 - - - name: run golangci-lint - run: $GITHUB_WORKSPACE/golangci-lint run --out-format=github-actions + - name: golangci-lint + uses: golangci/golangci-lint-action@v4 + with: + version: latest diff --git a/.golangci.yml b/.golangci.yml index 880786a..5d0a4b6 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,21 +1,17 @@ linters: enable: - - megacheck - revive - govet - unconvert - megacheck - - structcheck - gas - gocyclo - dupl - misspell - unparam - - varcheck - - deadcode + - unused - typecheck - ineffassign - - varcheck - stylecheck - gochecknoinits - exportloopref diff --git a/tollbooth_bug_report_test.go b/tollbooth_bug_report_test.go index 4ab584e..9c323d8 100644 --- a/tollbooth_bug_report_test.go +++ b/tollbooth_bug_report_test.go @@ -23,9 +23,9 @@ func Test_Issue48_RequestTerminatedEvenOnLowVolumeOnSameIP(t *testing.T) { lmt.SetMethods([]string{"GET"}) limitReachedCounter := 0 - lmt.SetOnLimitReached(func(w http.ResponseWriter, r *http.Request) { limitReachedCounter++ }) + lmt.SetOnLimitReached(func(http.ResponseWriter, *http.Request) { limitReachedCounter++ }) - handler := LimitHandler(lmt, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := LimitHandler(lmt, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Write([]byte(`hello world`)) })) @@ -74,7 +74,7 @@ func Test_Issue66_CustomRateLimitByHeaderValues(t *testing.T) { customerID1 := "1234" customerID2 := "5678" - h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + h := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}) h, allocationLimiter := issue66RateLimiter(h, []string{customerID1, customerID2}) testServer := httptest.NewServer(h) @@ -145,7 +145,7 @@ func Test_Issue91_BrokenSetMethod_DontBlockGet(t *testing.T) { // ------------------------------------------------------------------- - handler := LimitHandler(lmt, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := LimitHandler(lmt, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Write([]byte(`hello world`)) })) @@ -174,7 +174,7 @@ func Test_Issue91_BrokenSetMethod_BlockPost(t *testing.T) { lmt.SetMethods([]string{"POST"}) limitReachedCounter := 0 - lmt.SetOnLimitReached(func(w http.ResponseWriter, r *http.Request) { + lmt.SetOnLimitReached(func(http.ResponseWriter, *http.Request) { limitReachedCounter++ }) @@ -185,7 +185,7 @@ func Test_Issue91_BrokenSetMethod_BlockPost(t *testing.T) { // ------------------------------------------------------------------- - handler := LimitHandler(lmt, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := LimitHandler(lmt, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Write([]byte(`hello world`)) })) diff --git a/tollbooth_test.go b/tollbooth_test.go index 25d197d..9d3f69c 100644 --- a/tollbooth_test.go +++ b/tollbooth_test.go @@ -337,9 +337,9 @@ func TestLimitHandler(t *testing.T) { lmt.SetMethods([]string{"POST"}) counter := 0 - lmt.SetOnLimitReached(func(w http.ResponseWriter, r *http.Request) { counter++ }) + lmt.SetOnLimitReached(func(http.ResponseWriter, *http.Request) { counter++ }) - handler := LimitHandler(lmt, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := LimitHandler(lmt, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Write([]byte(`hello world`)) })) @@ -411,14 +411,14 @@ func TestOverrideForResponseWriter(t *testing.T) { lmt.SetOverrideDefaultResponseWriter(true) counter := 0 - lmt.SetOnLimitReached(func(w http.ResponseWriter, r *http.Request) { + lmt.SetOnLimitReached(func(w http.ResponseWriter, _ *http.Request) { w.Header().Add("Content-Type", "application/json") w.WriteHeader(http.StatusNotAcceptable) w.Write([]byte("rejecting the large amount of requests")) counter++ }) - handler := LimitHandler(lmt, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := LimitHandler(lmt, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Write([]byte(`hello world`)) })) From 6b2b2c4e31a46dff3ed1bab7e5e906c7f6188b14 Mon Sep 17 00:00:00 2001 From: John Jarvis Date: Sat, 22 Jun 2024 03:03:14 +0200 Subject: [PATCH 10/11] Fixes deadlock in ExecOnLimitReached (#107) This moves the the mutex unlock in `ExecOnLimitReached` so that it isn't around the function that gets executed. Including the function in the lock may result in a deadlock if there are any method calls in the function that call `RLock` again. --- limiter/limiter.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/limiter/limiter.go b/limiter/limiter.go index c64a7f2..5561e6c 100644 --- a/limiter/limiter.go +++ b/limiter/limiter.go @@ -261,9 +261,9 @@ func (l *Limiter) SetOnLimitReached(fn func(w http.ResponseWriter, r *http.Reque // ExecOnLimitReached is thread-safe way of executing after-rejection function when limit is reached. func (l *Limiter) ExecOnLimitReached(w http.ResponseWriter, r *http.Request) { l.RLock() - defer l.RUnlock() - fn := l.onLimitReached + l.RUnlock() + if fn != nil { fn(w, r) } From 95418adbf23e1f2415354688a4fb1c3592145913 Mon Sep 17 00:00:00 2001 From: Dmitry Verkhoturov Date: Sat, 22 Jun 2024 16:35:06 +0200 Subject: [PATCH 11/11] update go-pkgz/expirable-cache to version with generics (#109) --- go.mod | 2 +- go.sum | 17 ++++++----------- limiter/limiter.go | 30 +++++++++++++++--------------- 3 files changed, 22 insertions(+), 27 deletions(-) diff --git a/go.mod b/go.mod index c889a20..5998196 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/didip/tollbooth/v7 go 1.19 -require github.com/go-pkgz/expirable-cache v0.1.0 +require github.com/go-pkgz/expirable-cache/v3 v3.0.0 diff --git a/go.sum b/go.sum index e6d03c3..98c711e 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,7 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-pkgz/expirable-cache v0.1.0 h1:3bw0m8vlTK8qlwz5KXuygNBTkiKRTPrAGXU0Ej2AC1g= -github.com/go-pkgz/expirable-cache v0.1.0/go.mod h1:GTrEl0X+q0mPNqN6dtcQXksACnzCBQ5k/k1SwXJsZKs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/go-pkgz/expirable-cache/v3 v3.0.0 h1:u3/gcu3sabLYiTCevoRKv+WzjIn5oo7P8XtiXBeRDLw= +github.com/go-pkgz/expirable-cache/v3 v3.0.0/go.mod h1:2OQiDyEGQalYecLWmXprm3maPXeVb5/6/X7yRPYTzec= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/limiter/limiter.go b/limiter/limiter.go index 5561e6c..0153dc8 100644 --- a/limiter/limiter.go +++ b/limiter/limiter.go @@ -6,7 +6,7 @@ import ( "sync" "time" - cache "github.com/go-pkgz/expirable-cache" + cache "github.com/go-pkgz/expirable-cache/v3" "github.com/didip/tollbooth/v7/internal/time/rate" ) @@ -36,9 +36,9 @@ func New(generalExpirableOptions *ExpirableOptions) *Limiter { lmt.generalExpirableOptions.DefaultExpirationTTL = 87600 * time.Hour } - lmt.tokenBuckets, _ = cache.NewCache(cache.TTL(lmt.generalExpirableOptions.DefaultExpirationTTL)) + lmt.tokenBuckets = cache.NewCache[string, *rate.Limiter]().WithTTL(lmt.generalExpirableOptions.DefaultExpirationTTL) - lmt.basicAuthUsers, _ = cache.NewCache(cache.TTL(lmt.generalExpirableOptions.DefaultExpirationTTL)) + lmt.basicAuthUsers = cache.NewCache[string, bool]().WithTTL(lmt.generalExpirableOptions.DefaultExpirationTTL) return lmt } @@ -81,17 +81,17 @@ type Limiter struct { generalExpirableOptions *ExpirableOptions // List of basic auth usernames to limit. - basicAuthUsers cache.Cache + basicAuthUsers cache.Cache[string, bool] // Map of HTTP headers to limit. // Empty means skip headers checking. - headers map[string]cache.Cache + headers map[string]cache.Cache[string, bool] // Map of Context values to limit. - contextValues map[string]cache.Cache + contextValues map[string]cache.Cache[string, bool] // Map of limiters with TTL - tokenBuckets cache.Cache + tokenBuckets cache.Cache[string, *rate.Limiter] // Ignore URL on the rate limiter keys ignoreURL bool @@ -383,7 +383,7 @@ func (l *Limiter) DeleteExpiredTokenBuckets() { // SetHeaders is thread-safe way of setting map of HTTP headers to limit. func (l *Limiter) SetHeaders(headers map[string][]string) *Limiter { if l.headers == nil { - l.headers = make(map[string]cache.Cache) + l.headers = make(map[string]cache.Cache[string, bool]) } for header, entries := range headers { @@ -419,7 +419,7 @@ func (l *Limiter) SetHeader(header string, entries []string) *Limiter { } if !found { - existing, _ = cache.NewCache(cache.TTL(ttl)) + existing = cache.NewCache[string, bool]().WithTTL(ttl) } for _, entry := range entries { @@ -450,7 +450,7 @@ func (l *Limiter) RemoveHeader(header string) *Limiter { } l.Lock() - l.headers[header], _ = cache.NewCache(cache.TTL(ttl)) + l.headers[header] = cache.NewCache[string, bool]().WithTTL(ttl) l.Unlock() return l @@ -476,7 +476,7 @@ func (l *Limiter) RemoveHeaderEntries(header string, entriesForRemoval []string) // SetContextValues is thread-safe way of setting map of HTTP headers to limit. func (l *Limiter) SetContextValues(contextValues map[string][]string) *Limiter { if l.contextValues == nil { - l.contextValues = make(map[string]cache.Cache) + l.contextValues = make(map[string]cache.Cache[string, bool]) } for contextValue, entries := range contextValues { @@ -512,7 +512,7 @@ func (l *Limiter) SetContextValue(contextValue string, entries []string) *Limite } if !found { - existing, _ = cache.NewCache(cache.TTL(ttl)) + existing = cache.NewCache[string, bool]().WithTTL(ttl) } for _, entry := range entries { @@ -543,7 +543,7 @@ func (l *Limiter) RemoveContextValue(contextValue string) *Limiter { } l.Lock() - l.contextValues[contextValue], _ = cache.NewCache(cache.TTL(ttl)) + l.contextValues[contextValue] = cache.NewCache[string, bool]().WithTTL(ttl) l.Unlock() return l @@ -585,7 +585,7 @@ func (l *Limiter) limitReachedWithTokenBucketTTL(key string, tokenBucketTTL time return false } - return !expiringMap.(*rate.Limiter).Allow() + return !expiringMap.Allow() } // LimitReached returns a bool indicating if the Bucket identified by key ran out of tokens. @@ -606,5 +606,5 @@ func (l *Limiter) Tokens(key string) int { return 0 } - return int(expiringMap.(*rate.Limiter).TokensAt(time.Now())) + return int(expiringMap.TokensAt(time.Now())) }