Skip to content

Commit

Permalink
[feature] UnsafeSkipCheck - skip the CSRF check for the given request.
Browse files Browse the repository at this point in the history
  • Loading branch information
elithrar committed May 28, 2016
2 parents 86f8ce2 + 27465b8 commit d03564e
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 5 deletions.
21 changes: 16 additions & 5 deletions csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ const tokenLength = 32

// Context/session keys & prefixes
const (
tokenKey string = "gorilla.csrf.Token"
formKey string = "gorilla.csrf.Form"
errorKey string = "gorilla.csrf.Error"
cookieName string = "_gorilla_csrf"
errorPrefix string = "gorilla/csrf: "
tokenKey string = "gorilla.csrf.Token"
formKey string = "gorilla.csrf.Form"
errorKey string = "gorilla.csrf.Error"
skipCheckKey string = "gorilla.csrf.Skip"
cookieName string = "_gorilla_csrf"
errorPrefix string = "gorilla/csrf: "
)

var (
Expand Down Expand Up @@ -172,6 +173,16 @@ func Protect(authKey []byte, opts ...Option) func(http.Handler) http.Handler {

// Implements http.Handler for the csrf type.
func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Skip the check if directed to. This should always be a bool.
if val, ok := context.GetOk(r, skipCheckKey); ok {
if skip, ok := val.(bool); ok {
if skip {
cs.h.ServeHTTP(w, r)
return
}
}
}

// Retrieve the token from the session.
// An error represents either a cookie that failed HMAC validation
// or that doesn't exist.
Expand Down
10 changes: 10 additions & 0 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ func FailureReason(r *http.Request) error {
return nil
}

// UnsafeSkipCheck will skip the CSRF check for any requests. This must be
// called before the CSRF middleware.
//
// Note: You should not set this without otherwise securing the request from
// CSRF attacks. The primary use-case for this function is to turn off CSRF
// checks for non-browser clients using authorization tokens against your API.
func UnsafeSkipCheck(r *http.Request) {
context.Set(r, skipCheckKey, true)
}

// TemplateField is a template helper for html/template that provides an <input> field
// populated with a CSRF token.
//
Expand Down
35 changes: 35 additions & 0 deletions helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,38 @@ func TestCompareTokens(t *testing.T) {
t.Fatalf("compareTokens failed on different tokens: got %v want %v", v, !v)
}
}

func TestUnsafeSkipCSRFCheck(t *testing.T) {
s := http.NewServeMux()
skipCheck := func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
UnsafeSkipCheck(r)
h.ServeHTTP(w, r)
}

return http.HandlerFunc(fn)
}

var teapot = 418

// Issue a POST request without a CSRF token in the request.
s.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Set a non-200 header to make the test explicit.
w.WriteHeader(teapot)
}))

r, err := http.NewRequest("POST", "/", nil)
if err != nil {
t.Fatal(err)
}

// Must be used prior to the CSRF handler being invoked.
p := skipCheck(Protect(testKey)(s))
rr := httptest.NewRecorder()
p.ServeHTTP(rr, r)

if status := rr.Code; status != teapot {
t.Fatalf("middleware failed to skip this request: got %v want %v",
status, teapot)
}
}

0 comments on commit d03564e

Please sign in to comment.