diff --git a/csrf.go b/csrf.go index 6d12cc0..fe5e933 100644 --- a/csrf.go +++ b/csrf.go @@ -15,11 +15,12 @@ const tokenLength = 32 // Context/session keys & prefixes const ( - tokenKey string = "gorilla.csrf.Token" - formKey string = "gorilla.csrf.Form" - errorKey string = "gorilla.csrf.Error" - cookieName string = "_gorilla_csrf" - errorPrefix string = "gorilla/csrf: " + tokenKey string = "gorilla.csrf.Token" + formKey string = "gorilla.csrf.Form" + errorKey string = "gorilla.csrf.Error" + skipCheckKey string = "gorilla.csrf.Skip" + cookieName string = "_gorilla_csrf" + errorPrefix string = "gorilla/csrf: " ) var ( @@ -172,6 +173,16 @@ func Protect(authKey []byte, opts ...Option) func(http.Handler) http.Handler { // Implements http.Handler for the csrf type. func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Skip the check if directed to. This should always be a bool. + if val, ok := context.GetOk(r, skipCheckKey); ok { + if skip, ok := val.(bool); ok { + if skip { + cs.h.ServeHTTP(w, r) + return + } + } + } + // Retrieve the token from the session. // An error represents either a cookie that failed HMAC validation // or that doesn't exist. diff --git a/helpers.go b/helpers.go index e01d117..e0f23e0 100644 --- a/helpers.go +++ b/helpers.go @@ -38,6 +38,16 @@ func FailureReason(r *http.Request) error { return nil } +// UnsafeSkipCheck will skip the CSRF check for any requests. This must be +// called before the CSRF middleware. +// +// Note: You should not set this without otherwise securing the request from +// CSRF attacks. The primary use-case for this function is to turn off CSRF +// checks for non-browser clients using authorization tokens against your API. +func UnsafeSkipCheck(r *http.Request) { + context.Set(r, skipCheckKey, true) +} + // TemplateField is a template helper for html/template that provides an field // populated with a CSRF token. // diff --git a/helpers_test.go b/helpers_test.go index 09fbe3e..f1340ea 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -265,3 +265,38 @@ func TestCompareTokens(t *testing.T) { t.Fatalf("compareTokens failed on different tokens: got %v want %v", v, !v) } } + +func TestUnsafeSkipCSRFCheck(t *testing.T) { + s := http.NewServeMux() + skipCheck := func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + UnsafeSkipCheck(r) + h.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) + } + + var teapot = 418 + + // Issue a POST request without a CSRF token in the request. + s.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Set a non-200 header to make the test explicit. + w.WriteHeader(teapot) + })) + + r, err := http.NewRequest("POST", "/", nil) + if err != nil { + t.Fatal(err) + } + + // Must be used prior to the CSRF handler being invoked. + p := skipCheck(Protect(testKey)(s)) + rr := httptest.NewRecorder() + p.ServeHTTP(rr, r) + + if status := rr.Code; status != teapot { + t.Fatalf("middleware failed to skip this request: got %v want %v", + status, teapot) + } +}