Skip to content

Commit

Permalink
Merge pull request #42 from ploxiln/xheaders_opt
Browse files Browse the repository at this point in the history
add new option "xheaders" for whether trust X-Real-IP request header
  • Loading branch information
ploxiln authored Feb 19, 2020
2 parents 12effa1 + 8c87815 commit afb2d43
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 15 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ Usage of oauth2_proxy:
-validate-url string: Access token validation endpoint
-version: print version string
-whitelist-domain value: allowed domain for redirection after authentication, leading '.' allows subdomains (may be given multiple times)
-xheaders: Trust X-Real-IP request header (appropriate when behind a reverse proxy) (default true)
```


Expand Down
4 changes: 4 additions & 0 deletions contrib/oauth2_proxy.cfg.example
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
# tls_cert_file = ""
# tls_key_file = ""

## whether to trust the X-Real-IP request header for logging
## disable if not running oauth2_proxy behind another reverse-proxy or load-balancer
# xheaders = true

## the OAuth Redirect URL
## defaults to "https://" + requested host header + "/oauth2/callback"
# redirect_url = "https://internalapp.yourcompany.com/oauth2/callback"
Expand Down
16 changes: 8 additions & 8 deletions logging_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,17 @@ type loggingHandler struct {
writer io.Writer
handler http.Handler
enabled bool
xheaders bool
logTemplate *template.Template
}

func LoggingHandler(out io.Writer, h http.Handler, v bool, requestLoggingTpl string) http.Handler {
func LoggingHandler(out io.Writer, h http.Handler, enabled bool, xheaders bool, requestLoggingTpl string) http.Handler {
return loggingHandler{
writer: out,
handler: h,
enabled: v,
logTemplate: template.Must(template.New("request-log").Parse(requestLoggingTpl)),
enabled: enabled,
xheaders: xheaders,
logTemplate: template.Must(template.New("request-log").Parse(requestLoggingTpl + "\n")),
}
}

Expand Down Expand Up @@ -146,9 +148,9 @@ func (h loggingHandler) writeLogLine(username, upstream string, req *http.Reques
}
}

client := req.Header.Get("X-Real-IP")
if client == "" {
client = req.RemoteAddr
client := req.RemoteAddr
if h.xheaders && req.Header.Get("X-Real-IP") != "" {
client = req.Header.Get("X-Real-IP")
}

if c, _, err := net.SplitHostPort(client); err == nil {
Expand All @@ -171,6 +173,4 @@ func (h loggingHandler) writeLogLine(username, upstream string, req *http.Reques
UserAgent: fmt.Sprintf("%q", req.UserAgent()),
Username: username,
})

h.writer.Write([]byte("\n"))
}
2 changes: 1 addition & 1 deletion logging_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) {
w.Write([]byte("test"))
}

h := LoggingHandler(buf, http.HandlerFunc(handler), true, test.Format)
h := LoggingHandler(buf, http.HandlerFunc(handler), true, true, test.Format)

r, _ := http.NewRequest("GET", "/foo/bar", nil)
r.RemoteAddr = "127.0.0.1"
Expand Down
5 changes: 3 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ func mainFlagSet() *flag.FlagSet {
flagSet.String("tls-cert", "", "path to certificate file")
flagSet.String("tls-key", "", "path to private key file")
flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"")
flagSet.Bool("set-xauthrequest", false, "set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode)")
flagSet.Var(&upstreams, "upstream", "the http url(s) of the upstream endpoint or file:// paths for static files. Routing is based on the path")
flagSet.Bool("set-xauthrequest", false, "set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode)")
flagSet.Bool("pass-user-headers", true, "pass X-Forwarded-User and X-Forwarded-Email information to upstream")
flagSet.Bool("pass-basic-auth", true, "pass HTTP Basic Auth header to upstream")
flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header")
Expand Down Expand Up @@ -70,6 +70,7 @@ func mainFlagSet() *flag.FlagSet {
flagSet.Bool("cookie-secure", true, "set secure (HTTPS) cookie flag")
flagSet.Bool("cookie-httponly", true, "set HttpOnly cookie flag")

flagSet.Bool("xheaders", true, "Trust X-Real-IP request header (appropriate when behind a reverse proxy)")
flagSet.Bool("request-logging", true, "Log requests to stdout")
flagSet.String("request-logging-format", defaultRequestLoggingFormat, "Template for request log lines")

Expand Down Expand Up @@ -142,7 +143,7 @@ func main() {
}

s := &Server{
Handler: LoggingHandler(os.Stdout, oauthproxy, opts.RequestLogging, opts.RequestLoggingFormat),
Handler: LoggingHandler(os.Stdout, oauthproxy, opts.RequestLogging, opts.XHeaders, opts.RequestLoggingFormat),
Opts: opts,
}
s.ListenAndServe()
Expand Down
10 changes: 6 additions & 4 deletions oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type OAuthProxy struct {
PassUserHeaders bool
BasicAuthPassword string
PassAccessToken bool
XHeaders bool
CookieCipher *cookie.Cipher
skipAuthRegex []string
skipAuthPreflight bool
Expand Down Expand Up @@ -228,6 +229,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
BasicAuthPassword: opts.BasicAuthPassword,
PassAccessToken: opts.PassAccessToken,
SkipProviderButton: opts.SkipProviderButton,
XHeaders: opts.XHeaders,
CookieCipher: cipher,
templates: loadTemplates(opts.CustomTemplatesDir),
Footer: opts.Footer,
Expand Down Expand Up @@ -500,9 +502,9 @@ func (p *OAuthProxy) IsWhitelistedPath(path string) (ok bool) {
return
}

func getRemoteAddr(req *http.Request) (s string) {
func (p *OAuthProxy) getRemoteAddr(req *http.Request) (s string) {
s = req.RemoteAddr
if req.Header.Get("X-Real-IP") != "" {
if p.XHeaders && req.Header.Get("X-Real-IP") != "" {
s += fmt.Sprintf(" (%q)", req.Header.Get("X-Real-IP"))
}
return
Expand Down Expand Up @@ -574,7 +576,7 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
}

func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
remoteAddr := getRemoteAddr(req)
remoteAddr := p.getRemoteAddr(req)

// finish the oauth cycle
err := req.ParseForm()
Expand Down Expand Up @@ -661,7 +663,7 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {

func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int {
var saveSession, clearSession, revalidated bool
remoteAddr := getRemoteAddr(req)
remoteAddr := p.getRemoteAddr(req)

session, sessionAge, err := p.LoadCookiedSession(req)
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ type Options struct {
Scope string `flag:"scope" cfg:"scope"`
ApprovalPrompt string `flag:"approval-prompt" cfg:"approval_prompt"`

XHeaders bool `flag:"xheaders" cfg:"xheaders"`
RequestLogging bool `flag:"request-logging" cfg:"request_logging"`
RequestLoggingFormat string `flag:"request-logging-format" cfg:"request_logging_format"`

Expand Down Expand Up @@ -118,6 +119,7 @@ func NewOptions() *Options {
PassAccessToken: false,
PassHostHeader: true,
ApprovalPrompt: "force",
XHeaders: true,
RequestLogging: true,
RequestLoggingFormat: defaultRequestLoggingFormat,
}
Expand Down

0 comments on commit afb2d43

Please sign in to comment.