Skip to content

Commit

Permalink
bugfix: Not providing any token in requests results in wrong error me…
Browse files Browse the repository at this point in the history
…ssage (#149)

* Fix wrong error being reported when token missing in request
* Remove a condition that never becomes true
* Add myself to the list of AUTHORS assuming this is good style
* Fix minor style issue
  • Loading branch information
FlorianLoch authored Jul 29, 2021
1 parent c61da38 commit b69cbb3
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 8 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Please keep the list sorted.

adiabatic <[email protected]>
Florian D. Loch <[email protected]>
Google LLC (https://opensource.google.com)
jamesgroat <[email protected]>
Joshua Carp <[email protected]>
Expand Down
16 changes: 11 additions & 5 deletions csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,16 +274,22 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}

// If the token returned from the session store is nil for non-idempotent
// ("unsafe") methods, call the error handler.
if realToken == nil {
// Retrieve the combined token (pad + masked) token...
maskedToken, err := cs.requestToken(r)
if err != nil {
r = envError(r, ErrBadToken)
cs.opts.ErrorHandler.ServeHTTP(w, r)
return
}

if maskedToken == nil {
r = envError(r, ErrNoToken)
cs.opts.ErrorHandler.ServeHTTP(w, r)
return
}

// Retrieve the combined token (pad + masked) token and unmask it.
requestToken := unmask(cs.requestToken(r))
// ... and unmask it.
requestToken := unmask(maskedToken)

// Compare the request token against the real token
if !compareTokens(requestToken, realToken) {
Expand Down
42 changes: 42 additions & 0 deletions csrf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,48 @@ func TestWithReferer(t *testing.T) {
}
}

// Requests without a token should fail with ErrNoToken.
func TestNoTokenProvided(t *testing.T) {
var finalErr error

s := http.NewServeMux()
p := Protect(testKey, ErrorHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
finalErr = FailureReason(r)
})))(s)

var token string
s.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token = Token(r)
}))

// Obtain a CSRF cookie via a GET request.
r, err := http.NewRequest("GET", "http://www.gorillatoolkit.org/", nil)
if err != nil {
t.Fatal(err)
}

rr := httptest.NewRecorder()
p.ServeHTTP(rr, r)

// POST the token back in the header.
r, err = http.NewRequest("POST", "http://www.gorillatoolkit.org/", nil)
if err != nil {
t.Fatal(err)
}

setCookie(rr, r)
// By accident we use the wrong header name for the token...
r.Header.Set("X-CSRF-nekot", token)
r.Header.Set("Referer", "http://www.gorillatoolkit.org/")

rr = httptest.NewRecorder()
p.ServeHTTP(rr, r)

if finalErr != nil && finalErr != ErrNoToken {
t.Fatalf("middleware failed to return correct error: got '%v' want '%v'", finalErr, ErrNoToken)
}
}

func setCookie(rr *httptest.ResponseRecorder, r *http.Request) {
r.Header.Set("Cookie", rr.Header().Get("Set-Cookie"))
}
11 changes: 8 additions & 3 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func unmask(issued []byte) []byte {

// requestToken returns the issued token (pad + masked token) from the HTTP POST
// body or HTTP header. It will return nil if the token fails to decode.
func (cs *csrf) requestToken(r *http.Request) []byte {
func (cs *csrf) requestToken(r *http.Request) ([]byte, error) {
// 1. Check the HTTP header first.
issued := r.Header.Get(cs.opts.RequestHeader)

Expand All @@ -123,14 +123,19 @@ func (cs *csrf) requestToken(r *http.Request) []byte {
}
}

// Return nil (equivalent to empty byte slice) if no token was found
if issued == "" {
return nil, nil
}

// Decode the "issued" (pad + masked) token sent in the request. Return a
// nil byte slice on a decoding error (this will fail upstream).
decoded, err := base64.StdEncoding.DecodeString(issued)
if err != nil {
return nil
return nil, err
}

return decoded
return decoded, nil
}

// generateRandomBytes returns securely generated random bytes.
Expand Down

0 comments on commit b69cbb3

Please sign in to comment.