Skip to content

Commit

Permalink
Add option to use base64 URL-safe encoding format for CSRF token
Browse files Browse the repository at this point in the history
  • Loading branch information
Sang-Hyuk committed Dec 6, 2024
1 parent a009743 commit aaa75a9
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 6 deletions.
8 changes: 7 additions & 1 deletion csrf.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package csrf

import (
"encoding/base64"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -87,6 +88,7 @@ type options struct {
// http.Cookie field instead of the "correct" HTTPOnly name that golint suggests.
HttpOnly bool
Secure bool
URLSafe bool
SameSite SameSiteMode
RequestHeader string
FieldName string
Expand Down Expand Up @@ -235,7 +237,11 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

// Save the masked token to the request context
r = contextSave(r, tokenKey, mask(realToken, r))
encoding := base64.StdEncoding
if cs.opts.URLSafe {
encoding = base64.URLEncoding
}
r = contextSave(r, tokenKey, mask(realToken, r, encoding))
// Save the field name to the request context
r = contextSave(r, formKey, cs.opts.FieldName)

Expand Down
12 changes: 9 additions & 3 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TemplateField(r *http.Request) template.HTML {
// token and returning them together as a 64-byte slice. This effectively
// randomises the token on a per-request basis without breaking multiple browser
// tabs/windows.
func mask(realToken []byte, _ *http.Request) string {
func mask(realToken []byte, _ *http.Request, encoding *base64.Encoding) string {
otp, err := generateRandomBytes(tokenLength)
if err != nil {
return ""
Expand All @@ -83,7 +83,7 @@ func mask(realToken []byte, _ *http.Request) string {
// XOR the OTP with the real token to generate a masked token. Append the
// OTP to the front of the masked token to allow unmasking in the subsequent
// request.
return base64.StdEncoding.EncodeToString(append(otp, xorToken(otp, realToken)...))
return encoding.EncodeToString(append(otp, xorToken(otp, realToken)...))
}

// unmask splits the issued token (one-time-pad + masked token) and returns the
Expand Down Expand Up @@ -129,7 +129,13 @@ func (cs *csrf) requestToken(r *http.Request) ([]byte, error) {

// Decode the "issued" (pad + masked) token sent in the request. Return a
// nil byte slice on a decoding error (this will fail upstream).
decoded, err := base64.StdEncoding.DecodeString(issued)
encoding := base64.StdEncoding

if cs.opts.URLSafe {
encoding = base64.URLEncoding
}

decoded, err := encoding.DecodeString(issued)
if err != nil {
return nil, err
}
Expand Down
6 changes: 4 additions & 2 deletions helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,10 @@ func TestMaskUnmaskTokens(t *testing.T) {
t.Fatal(err)
}

issued := mask(realToken, nil)
decoded, err := base64.StdEncoding.DecodeString(issued)
encoding := base64.StdEncoding

issued := mask(realToken, nil, encoding)
decoded, err := encoding.DecodeString(issued)
if err != nil {
t.Fatal(err)
}
Expand Down
7 changes: 7 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ func TrustedOrigins(origins []string) Option {
}
}

// URLSafe changes the base64 encoding format ( URL safe ) of the CSRF token.
func URLSafe(s bool) Option {
return func(cs *csrf) {
cs.opts.URLSafe = s
}
}

// setStore sets the store used by the CSRF middleware.
// Note: this is private (for now) to allow for internal API changes.
func setStore(s store) Option {
Expand Down
20 changes: 20 additions & 0 deletions options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,23 @@ func TestMaxAge(t *testing.T) {
})

}

func TestURLSafe(t *testing.T) {
t.Run("Ensure the default URLSafe is applied", func(t *testing.T) {
handler := Protect(testKey)(nil)
cs := handler.(*csrf)

if cs.opts.URLSafe != false {
t.Fatalf("default URLSafe not applied: got %v (want %v)", cs.opts.URLSafe, false)
}
})

t.Run("Support an explicit URLSafe of true", func(t *testing.T) {
handler := Protect(testKey, URLSafe(true))(nil)
cs := handler.(*csrf)

if cs.opts.URLSafe != true {
t.Fatalf("URLSafe not applied: got %v (want %v)", cs.opts.URLSafe, true)
}
})
}

0 comments on commit aaa75a9

Please sign in to comment.