Skip to content

Commit

Permalink
Add support for custom auth providers
Browse files Browse the repository at this point in the history
  • Loading branch information
pkukielka committed Jan 5, 2025
1 parent f3fe8b6 commit 5bf3cd4
Show file tree
Hide file tree
Showing 39 changed files with 429 additions and 137 deletions.
68 changes: 68 additions & 0 deletions agent/scripts/reverse-proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from aiohttp import web, ClientSession
from urllib.parse import urlparse
import asyncio

target_url = ''
port = 5050

async def proxy_handler(request):
async with ClientSession(auto_decompress=False) as session:
print(f'Request to: {request.url}')

# Modify headers here
headers = dict(request.headers)

# Reset the Host header to use target server host instead of the proxy host
if 'Host' in headers:
headers['Host'] = urlparse(target_url).netloc.split(':')[0]

# 'chunked' encoding results in error 400 from Cloudflare, removing it still keeps response chunked anyway
if 'Transfer-Encoding' in headers:
del headers['Transfer-Encoding']

# Use value of 'Authorization: Bearer' to fill 'X-Forwarded-User' and remove 'Authorization' header
if 'Authorization' in headers:
values = headers['Authorization'].split()
if values and values[0] == 'Bearer':
headers['X-Forwarded-User'] = values[1]
del headers['Authorization']

# Forward the request to target
async with session.request(
method=request.method,
url=f'{target_url}{request.path_qs}',
headers=headers,
data=await request.read()
) as response:
proxy_response = web.StreamResponse(
status=response.status,
headers=response.headers
)

await proxy_response.prepare(request)

# Stream the response back
async for chunk in response.content.iter_chunks():
await proxy_response.write(chunk[0])

await proxy_response.write_eof()
return proxy_response

app = web.Application()
app.router.add_route('*', '/{path_info:.*}', proxy_handler)


if __name__ == '__main__':
print('Usage: python reverse_proxy.py [target_url] [proxy_port]')

import sys
if (len(sys.argv) < 2):
print('Please specify target_url')
sys.exit(1)
if len(sys.argv) > 1:
target_url = sys.argv[1]
if len(sys.argv) > 2:
port = int(sys.argv[2])

print(f'Starting proxy server on port {port} targeting {target_url}...')
web.run_app(app, port=port)
12 changes: 11 additions & 1 deletion agent/src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import path from 'node:path'
import type { Polly, Request } from '@pollyjs/core'
import {
type AccountKeyedChatHistory,
AuthCredentials,
type ChatHistoryKey,
type ClientCapabilities,
type CodyCommand,
Expand All @@ -14,6 +15,7 @@ import {
currentAuthStatusAuthed,
firstNonPendingAuthStatus,
firstResultFromOperation,
resolveAuth,
resolvedConfig,
telemetryRecorder,
waitUntilComplete,
Expand Down Expand Up @@ -70,6 +72,7 @@ import type { FixupActor, FixupFileCollection } from '../../vscode/src/non-stop/
import type { FixupControlApplicator } from '../../vscode/src/non-stop/strategies'
import { authProvider } from '../../vscode/src/services/AuthProvider'
import { localStorage } from '../../vscode/src/services/LocalStorageProvider'
import { secretStorage } from '../../vscode/src/services/SecretStorageProvider'
import { AgentWorkspaceEdit } from '../../vscode/src/testutils/AgentWorkspaceEdit'
import { AgentAuthHandler } from './AgentAuthHandler'
import { AgentFixupControls } from './AgentFixupControls'
Expand Down Expand Up @@ -1482,6 +1485,11 @@ export class Agent extends MessageHandler implements ExtensionClient {
return this.clientInfo?.capabilities ?? undefined
}

private async resolveAuthCredentials(extCfg: ExtensionConfiguration): Promise<AuthCredentials> {
const config = JSON.parse(extCfg.customConfigurationJson ?? '{}')
return resolveAuth(extCfg.serverEndpoint, config, secretStorage)
}

private async handleConfigChanges(
config: ExtensionConfiguration,
params?: { forceAuthentication: boolean }
Expand All @@ -1500,7 +1508,9 @@ export class Agent extends MessageHandler implements ExtensionClient {
},
auth: {
serverEndpoint: config.serverEndpoint,
accessToken: config.accessToken ?? null,
accessTokenOrHeaders:
config.accessToken ??
(await this.resolveAuthCredentials(config)).accessTokenOrHeaders,
},
clientState: {
anonymousUserID: config.anonymousUserID ?? null,
Expand Down
2 changes: 1 addition & 1 deletion agent/src/cli/command-auth/AuthenticatedAccount.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ export class AuthenticatedAccount {
): Promise<AuthenticatedAccount | Error> {
const graphqlClient = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { telemetryLevel: 'agent' },
auth: { accessToken: options.accessToken, serverEndpoint: options.endpoint },
auth: { accessTokenOrHeaders: options.accessToken, serverEndpoint: options.endpoint },
clientState: { anonymousUserID: null },
})
const userInfo = await graphqlClient.getCurrentUserInfo()
Expand Down
4 changes: 2 additions & 2 deletions agent/src/cli/command-auth/command-login.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async function loginAction(
: await captureAccessTokenViaBrowserRedirect(serverEndpoint, spinner)
const client = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { telemetryLevel: 'agent' },
auth: { accessToken: token, serverEndpoint: serverEndpoint },
auth: { accessTokenOrHeaders: token, serverEndpoint: serverEndpoint },
clientState: { anonymousUserID: null },
})
const userInfo = await client.getCurrentUserInfo()
Expand Down Expand Up @@ -256,7 +256,7 @@ async function promptUserAboutLoginMethod(spinner: Ora, options: LoginOptions):
try {
const client = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { telemetryLevel: 'agent' },
auth: { accessToken: options.accessToken, serverEndpoint: options.endpoint },
auth: { accessTokenOrHeaders: options.accessToken, serverEndpoint: options.endpoint },
clientState: { anonymousUserID: null },
})
const userInfo = await client.getCurrentUserInfo()
Expand Down
2 changes: 1 addition & 1 deletion agent/src/cli/command-bench/command-bench.ts
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ export const benchCommand = new commander.Command('bench')
setStaticResolvedConfigurationWithAuthCredentials({
configuration: { customHeaders: {} },
auth: {
accessToken: options.srcAccessToken,
accessTokenOrHeaders: options.srcAccessToken,
serverEndpoint: options.srcEndpoint,
},
})
Expand Down
2 changes: 1 addition & 1 deletion agent/src/cli/command-bench/llm-judge.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export class LlmJudge {
localStorage.setStorage('noop')
setStaticResolvedConfigurationWithAuthCredentials({
configuration: { customHeaders: undefined },
auth: { accessToken: options.srcAccessToken, serverEndpoint: options.srcEndpoint },
auth: { accessTokenOrHeaders: options.srcAccessToken, serverEndpoint: options.srcEndpoint },
})
setClientCapabilities({ configuration: {}, agentCapabilities: undefined })
this.client = new SourcegraphNodeCompletionsClient()
Expand Down
5 changes: 4 additions & 1 deletion agent/src/local-e2e/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ export class LocalSGInstance {
// for checking the LLM configuration section.
this.gqlclient = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { customHeaders: headers, telemetryLevel: 'agent' },
auth: { accessToken: this.params.accessToken, serverEndpoint: this.params.serverEndpoint },
auth: {
accessTokenOrHeaders: this.params.accessToken,
serverEndpoint: this.params.serverEndpoint,
},
clientState: { anonymousUserID: null },
})
}
Expand Down
8 changes: 7 additions & 1 deletion agent/src/vscode-shim.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,15 @@ export function isAuthenticationChange(newConfig: ExtensionConfiguration): boole
return true
}

function getExternalAuthProvidersConfig(cfg: ExtensionConfiguration) {
return JSON.parse(cfg.customConfigurationJson ?? '{}')?.cody?.auth?.externalProviders
}

return (
extensionConfiguration.accessToken !== newConfig.accessToken ||
extensionConfiguration.serverEndpoint !== newConfig.serverEndpoint
extensionConfiguration.serverEndpoint !== newConfig.serverEndpoint ||
getExternalAuthProvidersConfig(extensionConfiguration) !==
getExternalAuthProvidersConfig(newConfig)
)
}

Expand Down
23 changes: 21 additions & 2 deletions lib/shared/src/configuration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@ import type { ReadonlyDeep } from './utils'
* A redirect flow is initiated by the user clicking a link in the browser, while a paste flow is initiated by the user
* manually entering the access from into the VsCode App.
*/
export type TokenSource = 'redirect' | 'paste'
export type TokenSource = 'redirect' | 'paste' | 'custom-auth-provider'

export type AuthHeaders = Record<string, string>

/**
* The user's authentication credentials, which are stored separately from the rest of the
* configuration.
*/
export interface AuthCredentials {
serverEndpoint: string
accessToken: string | null
tokenSource?: TokenSource | undefined
accessTokenOrHeaders: string | AuthHeaders | null
}

export interface AutoEditsTokenLimit {
Expand Down Expand Up @@ -71,6 +73,20 @@ export interface AgenticContextConfiguration {
}
}

export interface ExternalAuthCommand {
commandLine: string[]
environment?: Record<string, string>
workingDir?: string
shell?: string
timeout?: number
windowsHide?: boolean
}

export interface ExternalAuthProvider {
endpoint: string
executable: ExternalAuthCommand
}

interface RawClientConfiguration {
net: NetConfiguration
codebase?: string
Expand Down Expand Up @@ -166,6 +182,9 @@ interface RawClientConfiguration {
*/
overrideServerEndpoint?: string | undefined
overrideAuthToken?: string | undefined

// External auth providers
authExternalProviders?: ExternalAuthProvider[]
}

/**
Expand Down
127 changes: 104 additions & 23 deletions lib/shared/src/configuration/resolver.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import { Observable, map } from 'observable-fns'
import type { AuthCredentials, ClientConfiguration } from '../configuration'
import type {
AuthCredentials,
ClientConfiguration,
ExternalAuthCommand,
ExternalAuthProvider,
TokenSource,
} from '../configuration'
import { logError } from '../logger'
import {
distinctUntilChanged,
Expand Down Expand Up @@ -27,6 +33,7 @@ export interface ConfigurationInput {

export interface ClientSecrets {
getToken(endpoint: string): Promise<string | undefined>
getTokenSource(endpoint: string): Promise<TokenSource | undefined>
}

export interface ClientState {
Expand Down Expand Up @@ -72,39 +79,113 @@ export type PickResolvedConfiguration<Keys extends KeysSpec> = {
: undefined
}

export type ExternalProviderAuthConfiguration = {
authExternalProviders?: ExternalAuthProvider[]
overrideAuthToken?: string | undefined
}

async function executeCommand(cmd: ExternalAuthCommand): Promise<string> {
if (typeof process === 'undefined' || !process.version) {
throw new Error('Command execution is only supported in Node.js environments')
}

const { exec } = await import('node:child_process')
const { promisify } = await import('node:util')
const execAsync = promisify(exec)

const command = cmd.commandLine.join(' ')
const options = {
...cmd,
env: cmd.environment ? { ...process.env, ...cmd.environment } : process.env,
}

const { stdout, stderr } = await execAsync(command, options)
if (stderr) {
throw new Error(`External auth command error: ${stderr}`)
}
return stdout.trim()
}

async function getExternalProviderAuthHeaders(
serverEndpoint: string,
clientConfiguration: ExternalProviderAuthConfiguration
): Promise<Record<string, string> | undefined> {
// Check for external auth provider for this endpoint
const externalProvider = clientConfiguration.authExternalProviders?.find(
provider => normalizeServerEndpointURL(provider.endpoint) === serverEndpoint
)

if (externalProvider) {
const result = await executeCommand(externalProvider.executable)
return JSON.parse(result)
}

return undefined
}

export async function resolveAuth(
endpoint: string,
clientConfiguration: ExternalProviderAuthConfiguration,
clientSecrets: ClientSecrets
): Promise<AuthCredentials> {
let accessTokenOrHeaders = null
let tokenSource: TokenSource | undefined = undefined

const serverEndpoint = normalizeServerEndpointURL(endpoint)

// We must not throw here, because that would result in the `resolvedConfig` observable
// terminating and all callers receiving no further config updates.
const loadTokenFn = () =>
clientSecrets.getToken(serverEndpoint).catch(error => {
throw new Error(
`Failed to get access token for endpoint ${serverEndpoint}: ${error.message || error}`
)
})

if (clientConfiguration.overrideAuthToken) {
accessTokenOrHeaders = clientConfiguration.overrideAuthToken
} else
try {
const authHeaders = await getExternalProviderAuthHeaders(serverEndpoint, clientConfiguration)
if (authHeaders) {
accessTokenOrHeaders = authHeaders
tokenSource = 'custom-auth-provider'
} else {
accessTokenOrHeaders = (await loadTokenFn()) || null
tokenSource = await clientSecrets.getTokenSource(serverEndpoint).catch(_ => undefined)
}
} catch (error) {
throw new Error(`Failed to execute external auth command: ${error}`)
}

return { accessTokenOrHeaders, serverEndpoint, tokenSource }
}

async function resolveConfiguration({
clientConfiguration,
clientSecrets,
clientState,
reinstall: { isReinstalling, onReinstall },
}: ConfigurationInput): Promise<ResolvedConfiguration> {
}: ConfigurationInput): Promise<ResolvedConfiguration | Error> {
const isReinstall = await isReinstalling()
if (isReinstall) {
await onReinstall()
}
// we allow for overriding the server endpoint from config if we haven't
// manually signed in somewhere else
const serverEndpoint = normalizeServerEndpointURL(

const serverEndpoint =
clientConfiguration.overrideServerEndpoint ||
(clientState.lastUsedEndpoint ?? DOTCOM_URL.toString())
)
clientState.lastUsedEndpoint ||
DOTCOM_URL.toString()

// We must not throw here, because that would result in the `resolvedConfig` observable
// terminating and all callers receiving no further config updates.
const loadTokenFn = () =>
clientSecrets.getToken(serverEndpoint).catch(error => {
logError(
'resolveConfiguration',
`Failed to get access token for endpoint ${serverEndpoint}: ${error}`
)
return null
})
const accessToken = clientConfiguration.overrideAuthToken || ((await loadTokenFn()) ?? null)
return {
configuration: clientConfiguration,
clientState,
auth: { accessToken, serverEndpoint },
isReinstall,
try {
const auth = await resolveAuth(serverEndpoint, clientConfiguration, clientSecrets)
return { configuration: clientConfiguration, clientState, auth, isReinstall }
} catch (error) {
// We don't want to throw here, because that would cause the observable to terminate and
// all callers receiving no further config updates.
logError('resolveConfiguration', `Error resolving configuration: ${error}`)
const auth = { accessTokenOrHeaders: null, serverEndpoint, tokenSource: undefined }
return { configuration: clientConfiguration, clientState, auth, isReinstall }
}
}

Expand Down
Loading

0 comments on commit 5bf3cd4

Please sign in to comment.