From f93468688fa7b357fb5d3d9a9c9c9c4c90b85a39 Mon Sep 17 00:00:00 2001 From: Didip Kerabat Date: Wed, 9 Oct 2024 12:45:44 -0700 Subject: [PATCH] We received a vulnerability disclosure due to how we pick a remote IP address. (#99) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * We received a vulnerability disclosure due to how we pick a remote IP address. Disclosure URL: https://gist.github.com/adam-p/4b777de4bda0027f4c3daa45618adcdc This is an attempt to address the situation. 1. We no longer configure SetIPLookups on default. 2. We address the two different SetIPLookups confusion in two different place by removing both of them. 3. We add a new, explicit way, for user to define how IP address should be picked up. Tests are all updated to use the new method of picking IP address. This will be a backward incompatible change so version number has to be bumped to 7. * Make golint happy. * Update documentation. * We don’t need the ability to pick which header to use. * Fix tests. --------- Co-authored-by: Didip Kerabat --- README.md | 45 ++++++++-- libstring/libstring.go | 53 ++++++------ libstring/libstring_test.go | 103 ++++++++++++----------- limiter/limiter.go | 38 ++++++--- limiter/limiter_setter_getter_test.go | 13 --- tollbooth.go | 7 +- tollbooth_benchmark_test.go | 9 +- tollbooth_bug_report_test.go | 14 ++-- tollbooth_test.go | 116 ++++++++++++++++++-------- 9 files changed, 242 insertions(+), 156 deletions(-) diff --git a/README.md b/README.md index f96ac2f..e13a0d5 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,9 @@ This is a generic middleware to rate-limit HTTP requests. **v7.x.x:** Replaced `time/rate` with `embedded time/rate` so that we can support more rate limit headers. +**v8.x.x:** Address `RemoteIP` vulnerability concern by replacing it with `RemoteIPFromIPLookup`, an explicit way to pick the IP address. + + ## Five Minute Tutorial ```go @@ -34,6 +37,7 @@ import ( "net/http" "github.com/didip/tollbooth/v7" + "github.com/didip/tollbooth/v7/limiter" ) func HelloHandler(w http.ResponseWriter, req *http.Request) { @@ -42,7 +46,15 @@ func HelloHandler(w http.ResponseWriter, req *http.Request) { func main() { // Create a request limiter per handler. - http.Handle("/", tollbooth.LimitFuncHandler(tollbooth.NewLimiter(1, nil), HelloHandler)) + lmt := tollbooth.NewLimiter(1, nil) + + // New in version >= 8, you must explicitly define how to pick the IP address. + lmt.SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }) + + http.Handle("/", tollbooth.LimitFuncHandler(lmt, HelloHandler)) http.ListenAndServe(":12345", nil) } ``` @@ -66,10 +78,24 @@ func main() { // every token bucket in it will expire 1 hour after it was initially set. lmt = tollbooth.NewLimiter(1, &limiter.ExpirableOptions{DefaultExpirationTTL: time.Hour}) - // Configure list of places to look for IP address. - // By default it's: "RemoteAddr", "X-Forwarded-For", "X-Real-IP" - // If your application is behind a proxy, set "X-Forwarded-For" first. - lmt.SetIPLookups([]string{"RemoteAddr", "X-Forwarded-For", "X-Real-IP"}) + // New in version >= 8, you must explicitly define how to pick the IP address. + // If IP address cannot be found, rate limiter will not be activated. + lmt.SetIPLookup(limiter.IPLookup{ + // The name of lookup method. + // Possible options are: RemoteAddr, X-Forwarded-For, X-Real-IP, CF-Connecting-IP + // All other headers are considered unknown and will be ignored. + Name: "X-Real-IP", + + // The index position to pick the ip address from a comma separated list. + // The index goes from right to left. + // + // When there are multiple of the same headers, + // we will concat them together in the order of first to last seen. + // And then we pick the IP using this index position. + IndexFromRight: 0, + }) + + // In version >= 8, lmt.SetIPLookups and lmt.GetIPLookups are removed. // Limit only GET and POST requests. lmt.SetMethods([]string{"GET", "POST"}) @@ -89,8 +115,7 @@ func main() { lmt.RemoveHeaderEntries("X-Access-Token", []string{"limitless-token"}) // By the way, the setters are chainable. Example: - lmt.SetIPLookups([]string{"RemoteAddr", "X-Forwarded-For", "X-Real-IP"}). - SetMethods([]string{"GET", "POST"}). + lmt.SetMethods([]string{"GET", "POST"}). SetBasicAuthUsers([]string{"sansa"}). SetBasicAuthUsers([]string{"tyrion"}) ``` @@ -137,6 +162,12 @@ func main() { ```go lmt := tollbooth.NewLimiter(1, nil) + // New in version >= 8, you must explicitly define how to pick the IP address. + lmt.SetIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + IndexFromRight: 0, + }) + // Set a custom message. lmt.SetMessage("You have reached maximum request limit.") diff --git a/libstring/libstring.go b/libstring/libstring.go index 730b654..0e6b334 100644 --- a/libstring/libstring.go +++ b/libstring/libstring.go @@ -5,6 +5,8 @@ import ( "net" "net/http" "strings" + + "github.com/didip/tollbooth/v7/limiter" ) // StringInSlice finds needle in a slice of strings. @@ -17,38 +19,35 @@ func StringInSlice(sliceString []string, needle string) bool { return false } -// RemoteIP finds IP Address given http.Request struct. -func RemoteIP(ipLookups []string, forwardedForIndexFromBehind int, r *http.Request) string { - realIP := r.Header.Get("X-Real-IP") - forwardedFor := r.Header.Get("X-Forwarded-For") - - for _, lookup := range ipLookups { - if lookup == "RemoteAddr" { - // 1. Cover the basic use cases for both ipv4 and ipv6 - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - // 2. Upon error, just return the remote addr. - return r.RemoteAddr - } - return ip +// RemoteIPFromIPLookup picks an ip address explicitly from limiter.IPLookup criteria. +// This function is intended to replace RemoteIP function. +func RemoteIPFromIPLookup(ipLookup limiter.IPLookup, r *http.Request) string { + switch ipLookup.Name { + case "RemoteAddr": + // 1. Cover the basic use cases for both ipv4 and ipv6 + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + // 2. Upon error, just return the remote addr. + return r.RemoteAddr } - if lookup == "X-Forwarded-For" && forwardedFor != "" { - // X-Forwarded-For is potentially a list of addresses separated with "," - parts := strings.Split(forwardedFor, ",") - for i, p := range parts { - parts[i] = strings.TrimSpace(p) - } + return ip - partIndex := len(parts) - 1 - forwardedForIndexFromBehind - if partIndex < 0 { - partIndex = 0 - } + case "X-Forwarded-For", "X-Real-IP", "CF-Connecting-IP": + ipAddrListCommaSeparated := r.Header.Values(ipLookup.Name) - return parts[partIndex] + ipAddrCommaSeparated := strings.Join(ipAddrListCommaSeparated, ",") + + ips := strings.Split(ipAddrCommaSeparated, ",") + for i, p := range ips { + ips[i] = strings.TrimSpace(p) } - if lookup == "X-Real-IP" && realIP != "" { - return realIP + + ipIndex := len(ips) - 1 - ipLookup.IndexFromRight + if ipIndex < 0 { + ipIndex = 0 } + + return ips[ipIndex] } return "" diff --git a/libstring/libstring_test.go b/libstring/libstring_test.go index 08abe3c..aed70ce 100644 --- a/libstring/libstring_test.go +++ b/libstring/libstring_test.go @@ -4,6 +4,8 @@ import ( "net/http" "strings" "testing" + + "github.com/didip/tollbooth/v7/limiter" ) func TestStringInSlice(t *testing.T) { @@ -12,28 +14,7 @@ func TestStringInSlice(t *testing.T) { } } -func TestRemoteIPDefault(t *testing.T) { - ipLookups := []string{"RemoteAddr", "X-Real-IP"} - ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8" - - request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) - if err != nil { - t.Errorf("Unable to create new HTTP request. Error: %v", err) - } - - request.Header.Set("X-Real-IP", ipv6) - - ip := RemoteIP(ipLookups, 0, request) - if ip != request.RemoteAddr { - t.Errorf("Did not get the right IP. IP: %v", ip) - } - if ip == ipv6 { - t.Errorf("X-Real-IP should have been skipped. IP: %v", ip) - } -} - func TestRemoteIPForwardedFor(t *testing.T) { - ipLookups := []string{"X-Forwarded-For", "X-Real-IP", "RemoteAddr"} ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8" request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) @@ -44,7 +25,11 @@ func TestRemoteIPForwardedFor(t *testing.T) { request.Header.Set("X-Forwarded-For", "10.10.10.10") request.Header.Set("X-Real-IP", ipv6) - ip := RemoteIP(ipLookups, 0, request) + ip := RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + IndexFromRight: 0, + }, request) + if ip != "10.10.10.10" { t.Errorf("Did not get the right IP. IP: %v", ip) } @@ -54,7 +39,6 @@ func TestRemoteIPForwardedFor(t *testing.T) { } func TestRemoteIPRealIP(t *testing.T) { - ipLookups := []string{"X-Real-IP", "X-Forwarded-For", "RemoteAddr"} ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8" request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) @@ -65,7 +49,11 @@ func TestRemoteIPRealIP(t *testing.T) { request.Header.Set("X-Forwarded-For", "10.10.10.10") request.Header.Set("X-Real-IP", ipv6) - ip := RemoteIP(ipLookups, 0, request) + ip := RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }, request) + if ip != ipv6 { t.Errorf("Did not get the right IP. IP: %v", ip) } @@ -74,53 +62,64 @@ func TestRemoteIPRealIP(t *testing.T) { } } -func TestRemoteIPMultipleForwardedFor(t *testing.T) { - ipLookups := []string{"X-Forwarded-For", "X-Real-IP", "RemoteAddr"} - ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8" - +func TestRemoteIPMultipleForwardedForIPAddresses(t *testing.T) { request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { t.Errorf("Unable to create new HTTP request. Error: %v", err) } - request.Header.Set("X-Real-IP", ipv6) - - // Missing X-Forwarded-For should not break things - ip := RemoteIP(ipLookups, 0, request) - if ip != ipv6 { - t.Errorf("X-Real-IP should have been chosen because X-Forwarded-For is missing. IP: %v", ip) - } - request.Header.Set("X-Forwarded-For", "10.10.10.10,10.10.10.11") + ip := RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + IndexFromRight: 0, + }, request) + // Should get the last one - ip = RemoteIP(ipLookups, 0, request) if ip != "10.10.10.11" { t.Errorf("Did not get the right IP. IP: %v", ip) } - if ip == ipv6 { - t.Errorf("X-Real-IP should have been skipped. IP: %v", ip) - } + + ip = RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + IndexFromRight: 1, + }, request) // Should get the 2nd from last - ip = RemoteIP(ipLookups, 1, request) if ip != "10.10.10.10" { t.Errorf("Did not get the right IP. IP: %v", ip) } - if ip == ipv6 { - t.Errorf("X-Real-IP should have been skipped. IP: %v", ip) - } // What about index out of bound? RemoteIP should simply choose index 0. - ip = RemoteIP(ipLookups, 2, request) + ip = RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + IndexFromRight: 2, + }, request) + if ip != "10.10.10.10" { t.Errorf("Did not get the right IP. IP: %v", ip) } - if ip == ipv6 { - t.Errorf("X-Real-IP should have been skipped. IP: %v", ip) - } } +func TestRemoteIPMultipleForwardedForHeaders(t *testing.T) { + request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) + if err != nil { + t.Errorf("Unable to create new HTTP request. Error: %v", err) + } + + request.Header.Add("X-Forwarded-For", "8.8.8.8,8.8.4.4") + request.Header.Add("X-Forwarded-For", "10.10.10.10,10.10.10.11") + + ip := RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + IndexFromRight: 0, + }, request) + + // Should get the last header and the last IP + if ip != "10.10.10.11" { + t.Errorf("Did not get the right IP. IP: %v", ip) + } +} func TestCanonicalizeIP(t *testing.T) { tests := []struct { name string @@ -169,10 +168,12 @@ func TestCanonicalizeIP(t *testing.T) { }, } for _, tt := range tests { - tt := tt + ip := tt.ip + want := tt.want + t.Run(tt.name, func(t *testing.T) { - if got := CanonicalizeIP(tt.ip); got != tt.want { - t.Errorf("CanonicalizeIP() = %v, want %v", got, tt.want) + if got := CanonicalizeIP(ip); got != want { + t.Errorf("CanonicalizeIP() = %v, want %v", got, want) } }) } diff --git a/limiter/limiter.go b/limiter/limiter.go index 0153dc8..38636e6 100644 --- a/limiter/limiter.go +++ b/limiter/limiter.go @@ -19,7 +19,6 @@ func New(generalExpirableOptions *ExpirableOptions) *Limiter { SetMessage("You have reached maximum request limit."). SetStatusCode(429). SetOnLimitReached(nil). - SetIPLookups([]string{"RemoteAddr", "X-Forwarded-For", "X-Real-IP"}). SetForwardedForIndexFromBehind(0). SetHeaders(make(map[string][]string)). SetContextValues(make(map[string][]string)). @@ -43,6 +42,18 @@ func New(generalExpirableOptions *ExpirableOptions) *Limiter { return lmt } +// IPLookup is a config struct to define how users want to pick the remote IP address. +type IPLookup struct { + // The name of lookup method. + // Possible options are: RemoteAddr, X-Forwarded-For, X-Real-IP, CF-Connecting-IP + // All other headers are considered unknown and will be ignored. + Name string + + // The index position to pick the ip address from a comma separated list. + // The index goes from right to left. + IndexFromRight int +} + // Limiter is a config struct to limit a particular request handler. type Limiter struct { // Maximum number of requests to limit per second. @@ -66,10 +77,9 @@ type Limiter struct { // An option to write back what you want upon reaching a limit. overrideDefaultResponseWriter bool - // List of places to look up IP address. - // Default is "RemoteAddr", "X-Forwarded-For", "X-Real-IP". - // You can rearrange the order as you like. - ipLookups []string + // Explicitly define how to look up IP address. + // This is intended to replace ipLookups + explicitIPLookup IPLookup forwardedForIndex int @@ -270,10 +280,12 @@ func (l *Limiter) ExecOnLimitReached(w http.ResponseWriter, r *http.Request) { } // SetOverrideDefaultResponseWriter is a thread-safe way of setting the response writer override variable. -func (l *Limiter) SetOverrideDefaultResponseWriter(override bool) { +func (l *Limiter) SetOverrideDefaultResponseWriter(override bool) *Limiter { l.Lock() l.overrideDefaultResponseWriter = override l.Unlock() + + return l } // GetOverrideDefaultResponseWriter is a thread-safe way of getting the response writer override variable. @@ -283,20 +295,22 @@ func (l *Limiter) GetOverrideDefaultResponseWriter() bool { return l.overrideDefaultResponseWriter } -// SetIPLookups is thread-safe way of setting list of places to look up IP address. -func (l *Limiter) SetIPLookups(ipLookups []string) *Limiter { +// SetIPLookup is thread-safe way of setting an explicit way to look up IP address. +// This method is intended to replace SetIPLookups (version 6 or older). +func (l *Limiter) SetIPLookup(lookup IPLookup) *Limiter { l.Lock() - l.ipLookups = ipLookups + l.explicitIPLookup = lookup l.Unlock() return l } -// GetIPLookups is thread-safe way of getting list of places to look up IP address. -func (l *Limiter) GetIPLookups() []string { +// GetIPLookup is thread-safe way of getting an explicit way to look up IP address. +// This method is intended to replace the old GetIPLookups (version 6 or older). +func (l *Limiter) GetIPLookup() IPLookup { l.RLock() defer l.RUnlock() - return l.ipLookups + return l.explicitIPLookup } // SetIgnoreURL is thread-safe way of setting whenever ignore the URL on rate limit keys diff --git a/limiter/limiter_setter_getter_test.go b/limiter/limiter_setter_getter_test.go index a4bd41f..55276c1 100644 --- a/limiter/limiter_setter_getter_test.go +++ b/limiter/limiter_setter_getter_test.go @@ -43,19 +43,6 @@ func TestSetGetStatusCode(t *testing.T) { } } -func TestSetGetIPLookups(t *testing.T) { - lmt := New(nil).SetMax(1) - - // Check default - if len(lmt.GetIPLookups()) != 3 { - t.Errorf("IPLookups field is incorrect. Value: %v", lmt.GetIPLookups()) - } - - if lmt.SetIPLookups([]string{"X-Real-IP"}).GetIPLookups()[0] != "X-Real-IP" { - t.Errorf("IPLookups field is incorrect. Value: %v", lmt.GetIPLookups()) - } -} - func TestSetGetMethods(t *testing.T) { lmt := New(nil).SetMax(1) diff --git a/tollbooth.go b/tollbooth.go index 0dcf82e..684a24c 100644 --- a/tollbooth.go +++ b/tollbooth.go @@ -37,8 +37,7 @@ func setRateLimitResponseHeaders(lmt *limiter.Limiter, w http.ResponseWriter, to func NewLimiter(max float64, tbOptions *limiter.ExpirableOptions) *limiter.Limiter { return limiter.New(tbOptions). SetMax(max). - SetBurst(int(math.Max(1, max))). - SetIPLookups([]string{"X-Forwarded-For", "X-Real-IP", "RemoteAddr"}) + SetBurst(int(math.Max(1, max))) } // LimitByKeys keeps track number of request made by keys separated by pipe. @@ -63,7 +62,7 @@ func ShouldSkipLimiter(lmt *limiter.Limiter, r *http.Request) bool { // --------------------------------- // Filter by remote ip // If we are unable to find remoteIP, skip limiter - remoteIP := libstring.RemoteIP(lmt.GetIPLookups(), lmt.GetForwardedForIndexFromBehind(), r) + remoteIP := libstring.RemoteIPFromIPLookup(lmt.GetIPLookup(), r) remoteIP = libstring.CanonicalizeIP(remoteIP) if remoteIP == "" { return true @@ -195,7 +194,7 @@ func ShouldSkipLimiter(lmt *limiter.Limiter, r *http.Request) bool { // BuildKeys generates a slice of keys to rate-limit by given limiter and request structs. func BuildKeys(lmt *limiter.Limiter, r *http.Request) [][]string { - remoteIP := libstring.RemoteIP(lmt.GetIPLookups(), lmt.GetForwardedForIndexFromBehind(), r) + remoteIP := libstring.RemoteIPFromIPLookup(lmt.GetIPLookup(), r) remoteIP = libstring.CanonicalizeIP(remoteIP) path := r.URL.Path sliceKeys := make([][]string, 0) diff --git a/tollbooth_benchmark_test.go b/tollbooth_benchmark_test.go index a035d83..338156f 100644 --- a/tollbooth_benchmark_test.go +++ b/tollbooth_benchmark_test.go @@ -30,9 +30,12 @@ func BenchmarkLimitByKeysWithExpiringBuckets(b *testing.B) { func BenchmarkBuildKeys(b *testing.B) { lmt := limiter.New(nil).SetMax(1) // Only 1 request per second is allowed. - lmt.SetIPLookups([]string{"X-Real-IP", "RemoteAddr", "X-Forwarded-For"}) - lmt.SetHeaders(make(map[string][]string)) - lmt.SetHeader("X-Real-IP", []string{"2601:7:1c82:4097:59a0:a80b:2841:b8c8"}) + lmt.SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetHeaders(make(map[string][]string)). + SetHeader("X-Real-IP", []string{"2601:7:1c82:4097:59a0:a80b:2841:b8c8"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { diff --git a/tollbooth_bug_report_test.go b/tollbooth_bug_report_test.go index 9c323d8..b8b9fbb 100644 --- a/tollbooth_bug_report_test.go +++ b/tollbooth_bug_report_test.go @@ -59,7 +59,10 @@ Top: var issue66HeaderKey = "X-Customer-ID" func issue66RateLimiter(h http.HandlerFunc, customerIDs []string) (http.HandlerFunc, *limiter.Limiter) { - allocationLimiter := NewLimiter(1, nil).SetMethods([]string{"POST"}) + allocationLimiter := NewLimiter(1, nil).SetMethods([]string{"POST"}). + SetIPLookup(limiter.IPLookup{ + Name: "RemoteAddr", + }) handler := func(w http.ResponseWriter, r *http.Request) { allocationLimiter.SetHeader(issue66HeaderKey, customerIDs) @@ -135,8 +138,7 @@ Expected to receive: %v status code. Got: %v`, func Test_Issue91_BrokenSetMethod_DontBlockGet(t *testing.T) { requestsPerSecond := float64(1) - lmt := NewLimiter(requestsPerSecond, nil) - lmt.SetMethods([]string{"POST"}) + lmt := NewLimiter(requestsPerSecond, nil).SetMethods([]string{"POST"}) methods := lmt.GetMethods() if methods[0] != "POST" { @@ -170,8 +172,10 @@ func Test_Issue91_BrokenSetMethod_DontBlockGet(t *testing.T) { func Test_Issue91_BrokenSetMethod_BlockPost(t *testing.T) { requestsPerSecond := float64(1) - lmt := NewLimiter(requestsPerSecond, nil) - lmt.SetMethods([]string{"POST"}) + lmt := NewLimiter(requestsPerSecond, nil).SetMethods([]string{"POST"}). + SetIPLookup(limiter.IPLookup{ + Name: "RemoteAddr", + }) limitReachedCounter := 0 lmt.SetOnLimitReached(func(http.ResponseWriter, *http.Request) { diff --git a/tollbooth_test.go b/tollbooth_test.go index 9d3f69c..e5865df 100644 --- a/tollbooth_test.go +++ b/tollbooth_test.go @@ -33,7 +33,10 @@ func TestLimitByKeys(t *testing.T) { } func TestDefaultBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) + lmt := NewLimiter(1, nil).SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -59,8 +62,12 @@ func TestDefaultBuildKeys(t *testing.T) { } func TestIgnoreURLBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetIgnoreURL(true) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetIgnoreURL(true) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -79,8 +86,12 @@ func TestIgnoreURLBuildKeys(t *testing.T) { } func TestBasicAuthBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetBasicAuthUsers([]string{"bro"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetBasicAuthUsers([]string{"bro"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -109,8 +120,12 @@ func TestBasicAuthBuildKeys(t *testing.T) { } func TestCustomHeadersBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -138,8 +153,12 @@ func TestCustomHeadersBuildKeys(t *testing.T) { } func TestRequestMethodBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetMethods([]string{"GET"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetMethods([]string{"GET"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -166,8 +185,12 @@ func TestRequestMethodBuildKeys(t *testing.T) { } func TestContextValueBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetContextValue("API-access-level", []string{"basic"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetContextValue("API-access-level", []string{"basic"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -196,9 +219,13 @@ func TestContextValueBuildKeys(t *testing.T) { } func TestRequestMethodAndCustomHeadersBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetMethods([]string{"GET"}) - lmt.SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetMethods([]string{"GET"}). + SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -228,9 +255,13 @@ func TestRequestMethodAndCustomHeadersBuildKeys(t *testing.T) { } func TestRequestMethodAndBasicAuthUsersBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetMethods([]string{"GET"}) - lmt.SetBasicAuthUsers([]string{"bro"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetMethods([]string{"GET"}). + SetBasicAuthUsers([]string{"bro"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -258,10 +289,14 @@ func TestRequestMethodAndBasicAuthUsersBuildKeys(t *testing.T) { } func TestRequestMethodCustomHeadersAndBasicAuthUsersBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetMethods([]string{"GET"}) - lmt.SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) - lmt.SetBasicAuthUsers([]string{"bro"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetMethods([]string{"GET"}). + SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}). + SetBasicAuthUsers([]string{"bro"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -293,11 +328,15 @@ func TestRequestMethodCustomHeadersAndBasicAuthUsersBuildKeys(t *testing.T) { } func TestRequestMethodCustomHeadersAndBasicAuthUsersAndContextValuesBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetMethods([]string{"GET"}) - lmt.SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) - lmt.SetContextValue("API-access-level", []string{"basic"}) - lmt.SetBasicAuthUsers([]string{"bro"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetMethods([]string{"GET"}). + SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}). + SetContextValue("API-access-level", []string{"basic"}). + SetBasicAuthUsers([]string{"bro"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -332,9 +371,12 @@ func TestRequestMethodCustomHeadersAndBasicAuthUsersAndContextValuesBuildKeys(t } func TestLimitHandler(t *testing.T) { - lmt := limiter.New(nil).SetMax(1).SetBurst(1) - lmt.SetIPLookups([]string{"X-Real-IP", "RemoteAddr", "X-Forwarded-For"}) - lmt.SetMethods([]string{"POST"}) + lmt := limiter.New(nil).SetMax(1).SetBurst(1). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetMethods([]string{"POST"}) counter := 0 lmt.SetOnLimitReached(func(http.ResponseWriter, *http.Request) { counter++ }) @@ -405,10 +447,13 @@ func TestLimitHandler(t *testing.T) { } func TestOverrideForResponseWriter(t *testing.T) { - lmt := limiter.New(nil).SetMax(1).SetBurst(1) - lmt.SetIPLookups([]string{"X-Real-IP", "RemoteAddr", "X-Forwarded-For"}) - lmt.SetMethods([]string{"POST"}) - lmt.SetOverrideDefaultResponseWriter(true) + lmt := limiter.New(nil).SetMax(1).SetBurst(1). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetMethods([]string{"POST"}). + SetOverrideDefaultResponseWriter(true) counter := 0 lmt.SetOnLimitReached(func(w http.ResponseWriter, _ *http.Request) { @@ -519,7 +564,10 @@ func (lm *LockMap) Add(key string, incr int64) { func TestLimitHandlerEmptyHeader(t *testing.T) { lmt := limiter.New(nil).SetMax(1).SetBurst(1) - lmt.SetIPLookups([]string{"X-Real-IP", "RemoteAddr", "X-Forwarded-For"}) + lmt.SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }) lmt.SetMethods([]string{"POST"}) lmt.SetHeader("user_id", []string{})