Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom Authentication Providers Support for Cody #6526

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions agent/scripts/reverse-proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from aiohttp import web, ClientSession
from urllib.parse import urlparse
import asyncio

target_url = ''
port = 5050

async def proxy_handler(request):
async with ClientSession(auto_decompress=False) as session:
print(f'Request to: {request.url}')

# Modify headers here
headers = dict(request.headers)

# Reset the Host header to use target server host instead of the proxy host
if 'Host' in headers:
headers['Host'] = urlparse(target_url).netloc.split(':')[0]

# 'chunked' encoding results in error 400 from Cloudflare, removing it still keeps response chunked anyway
if 'Transfer-Encoding' in headers:
del headers['Transfer-Encoding']

# Use value of 'Authorization: Bearer' to fill 'X-Forwarded-User' and remove 'Authorization' header
if 'Authorization' in headers:
values = headers['Authorization'].split()
if values and values[0] == 'Bearer':
headers['X-Forwarded-User'] = values[1]
del headers['Authorization']

# Forward the request to target
async with session.request(
method=request.method,
url=f'{target_url}{request.path_qs}',
headers=headers,
data=await request.read()
) as response:
proxy_response = web.StreamResponse(
status=response.status,
headers=response.headers
)

await proxy_response.prepare(request)

# Stream the response back
async for chunk in response.content.iter_chunks():
await proxy_response.write(chunk[0])

await proxy_response.write_eof()
return proxy_response

app = web.Application()
app.router.add_route('*', '/{path_info:.*}', proxy_handler)


if __name__ == '__main__':
print('Usage: python reverse_proxy.py [target_url] [proxy_port]')

import sys
if (len(sys.argv) < 2):
print('Please specify target_url')
sys.exit(1)
if len(sys.argv) > 1:
target_url = sys.argv[1]
if len(sys.argv) > 2:
port = int(sys.argv[2])

print(f'Starting proxy server on port {port} targeting {target_url}...')
web.run_app(app, port=port)
2 changes: 1 addition & 1 deletion agent/src/AgentWorkspaceConfiguration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ export class AgentWorkspaceConfiguration implements vscode.WorkspaceConfiguratio

function mergeWithBaseConfig(config: any) {
for (const [key, value] of Object.entries(config)) {
if (typeof value === 'object') {
if (typeof value === 'object' && !Array.isArray(value)) {
const existing = _.get(baseConfig, key) ?? {}
const merged = _.merge(existing, value)
_.set(baseConfig, key, merged)
Expand Down
7 changes: 3 additions & 4 deletions agent/src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1486,11 +1486,10 @@ export class Agent extends MessageHandler implements ExtensionClient {
config: ExtensionConfiguration,
params?: { forceAuthentication: boolean }
): Promise<AuthStatus> {
const isAuthChange = vscode_shim.isAuthenticationChange(config)
const isAuthChange = vscode_shim.isTokenOrEndpointChange(config)
vscode_shim.setExtensionConfiguration(config)

// If this is an authentication change we need to reauthenticate prior to firing events
// that update the clients
// If this is an token or endpoint change we need to save them prior to firing events that update the clients
try {
if (isAuthChange || params?.forceAuthentication) {
await authProvider.validateAndStoreCredentials(
Expand All @@ -1500,7 +1499,7 @@ export class Agent extends MessageHandler implements ExtensionClient {
},
auth: {
serverEndpoint: config.serverEndpoint,
accessToken: config.accessToken ?? null,
accessTokenOrHeaders: config.accessToken || null,
},
clientState: {
anonymousUserID: config.anonymousUserID ?? null,
Expand Down
2 changes: 1 addition & 1 deletion agent/src/cli/command-auth/AuthenticatedAccount.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ export class AuthenticatedAccount {
): Promise<AuthenticatedAccount | Error> {
const graphqlClient = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { telemetryLevel: 'agent' },
auth: { accessToken: options.accessToken, serverEndpoint: options.endpoint },
auth: { accessTokenOrHeaders: options.accessToken, serverEndpoint: options.endpoint },
clientState: { anonymousUserID: null },
})
const userInfo = await graphqlClient.getCurrentUserInfo()
Expand Down
4 changes: 2 additions & 2 deletions agent/src/cli/command-auth/command-login.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async function loginAction(
: await captureAccessTokenViaBrowserRedirect(serverEndpoint, spinner)
const client = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { telemetryLevel: 'agent' },
auth: { accessToken: token, serverEndpoint: serverEndpoint },
auth: { accessTokenOrHeaders: token, serverEndpoint: serverEndpoint },
clientState: { anonymousUserID: null },
})
const userInfo = await client.getCurrentUserInfo()
Expand Down Expand Up @@ -256,7 +256,7 @@ async function promptUserAboutLoginMethod(spinner: Ora, options: LoginOptions):
try {
const client = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { telemetryLevel: 'agent' },
auth: { accessToken: options.accessToken, serverEndpoint: options.endpoint },
auth: { accessTokenOrHeaders: options.accessToken, serverEndpoint: options.endpoint },
clientState: { anonymousUserID: null },
})
const userInfo = await client.getCurrentUserInfo()
Expand Down
2 changes: 1 addition & 1 deletion agent/src/cli/command-bench/command-bench.ts
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ export const benchCommand = new commander.Command('bench')
setStaticResolvedConfigurationWithAuthCredentials({
configuration: { customHeaders: {} },
auth: {
accessToken: options.srcAccessToken,
accessTokenOrHeaders: options.srcAccessToken,
serverEndpoint: options.srcEndpoint,
},
})
Expand Down
2 changes: 1 addition & 1 deletion agent/src/cli/command-bench/llm-judge.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export class LlmJudge {
localStorage.setStorage('noop')
setStaticResolvedConfigurationWithAuthCredentials({
configuration: { customHeaders: undefined },
auth: { accessToken: options.srcAccessToken, serverEndpoint: options.srcEndpoint },
auth: { accessTokenOrHeaders: options.srcAccessToken, serverEndpoint: options.srcEndpoint },
})
setClientCapabilities({ configuration: {}, agentCapabilities: undefined })
this.client = new SourcegraphNodeCompletionsClient()
Expand Down
5 changes: 4 additions & 1 deletion agent/src/local-e2e/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ export class LocalSGInstance {
// for checking the LLM configuration section.
this.gqlclient = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { customHeaders: headers, telemetryLevel: 'agent' },
auth: { accessToken: this.params.accessToken, serverEndpoint: this.params.serverEndpoint },
auth: {
accessTokenOrHeaders: this.params.accessToken,
serverEndpoint: this.params.serverEndpoint,
},
clientState: { anonymousUserID: null },
})
}
Expand Down
3 changes: 2 additions & 1 deletion agent/src/vscode-shim.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ export let extensionConfiguration: ExtensionConfiguration | undefined
export function setExtensionConfiguration(newConfig: ExtensionConfiguration): void {
extensionConfiguration = newConfig
}
export function isAuthenticationChange(newConfig: ExtensionConfiguration): boolean {

export function isTokenOrEndpointChange(newConfig: ExtensionConfiguration): boolean {
if (!extensionConfiguration) {
return true
}
Expand Down
23 changes: 21 additions & 2 deletions lib/shared/src/configuration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@ import type { ReadonlyDeep } from './utils'
* A redirect flow is initiated by the user clicking a link in the browser, while a paste flow is initiated by the user
* manually entering the access from into the VsCode App.
*/
export type TokenSource = 'redirect' | 'paste'
export type TokenSource = 'redirect' | 'paste' | 'custom-auth-provider'

export type AuthHeaders = Record<string, string>

/**
* The user's authentication credentials, which are stored separately from the rest of the
* configuration.
*/
export interface AuthCredentials {
serverEndpoint: string
accessToken: string | null
tokenSource?: TokenSource | undefined
accessTokenOrHeaders: string | AuthHeaders | null
}

export interface AutoEditsTokenLimit {
Expand Down Expand Up @@ -71,6 +73,20 @@ export interface AgenticContextConfiguration {
}
}

export interface ExternalAuthCommand {
commandLine: string[]
environment?: Record<string, string>
workingDir?: string
shell?: string
timeout?: number
windowsHide?: boolean
}

export interface ExternalAuthProvider {
endpoint: string
executable: ExternalAuthCommand
}

interface RawClientConfiguration {
net: NetConfiguration
codebase?: string
Expand Down Expand Up @@ -166,6 +182,9 @@ interface RawClientConfiguration {
*/
overrideServerEndpoint?: string | undefined
overrideAuthToken?: string | undefined

// External auth providers
authExternalProviders?: ExternalAuthProvider[]
}

/**
Expand Down
123 changes: 100 additions & 23 deletions lib/shared/src/configuration/resolver.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import { Observable, map } from 'observable-fns'
import type { AuthCredentials, ClientConfiguration } from '../configuration'
import type {
AuthCredentials,
ClientConfiguration,
ExternalAuthCommand,
TokenSource,
} from '../configuration'
import { logError } from '../logger'
import {
distinctUntilChanged,
Expand Down Expand Up @@ -27,6 +32,7 @@ export interface ConfigurationInput {

export interface ClientSecrets {
getToken(endpoint: string): Promise<string | undefined>
getTokenSource(endpoint: string): Promise<TokenSource | undefined>
}

export interface ClientState {
Expand Down Expand Up @@ -72,39 +78,110 @@ export type PickResolvedConfiguration<Keys extends KeysSpec> = {
: undefined
}

async function executeCommand(cmd: ExternalAuthCommand): Promise<string> {
if (typeof process === 'undefined' || !process.version) {
throw new Error('Command execution is only supported in Node.js environments')
}

const { exec } = await import('node:child_process')
const { promisify } = await import('node:util')
const execAsync = promisify(exec)

const command = cmd.commandLine.join(' ')
const options = {
...cmd,
env: cmd.environment ? { ...process.env, ...cmd.environment } : process.env,
}

const { stdout, stderr } = await execAsync(command, options)
if (stderr) {
throw new Error(`External auth command error: ${stderr}`)
}
return stdout.trim()
}

async function getExternalProviderAuthHeaders(
serverEndpoint: string,
clientConfiguration: ClientConfiguration
): Promise<Record<string, string> | undefined> {
// Check for external auth provider for this endpoint
const externalProvider = clientConfiguration.authExternalProviders?.find(
provider => normalizeServerEndpointURL(provider.endpoint) === serverEndpoint
)

if (externalProvider) {
const result = await executeCommand(externalProvider.executable)
return JSON.parse(result)
}

return undefined
}

export async function resolveAuth(
endpoint: string,
clientConfiguration: ClientConfiguration,
clientSecrets: ClientSecrets
): Promise<AuthCredentials> {
let accessTokenOrHeaders = null
let tokenSource: TokenSource | undefined = undefined

const serverEndpoint = normalizeServerEndpointURL(
clientConfiguration.overrideServerEndpoint || endpoint
)

// We must not throw here, because that would result in the `resolvedConfig` observable
// terminating and all callers receiving no further config updates.
const loadTokenFn = () =>
clientSecrets.getToken(serverEndpoint).catch(error => {
throw new Error(
`Failed to get access token for endpoint ${serverEndpoint}: ${error.message || error}`
)
})

if (clientConfiguration.overrideAuthToken) {
accessTokenOrHeaders = clientConfiguration.overrideAuthToken
} else
try {
const authHeaders = await getExternalProviderAuthHeaders(serverEndpoint, clientConfiguration)
if (authHeaders) {
accessTokenOrHeaders = authHeaders
tokenSource = 'custom-auth-provider'
} else {
accessTokenOrHeaders = (await loadTokenFn()) || null
tokenSource = await clientSecrets.getTokenSource(serverEndpoint).catch(_ => undefined)
}
} catch (error) {
throw new Error(`Failed to execute external auth command: ${error}`)
}

return { accessTokenOrHeaders, serverEndpoint, tokenSource }
}

async function resolveConfiguration({
clientConfiguration,
clientSecrets,
clientState,
reinstall: { isReinstalling, onReinstall },
}: ConfigurationInput): Promise<ResolvedConfiguration> {
}: ConfigurationInput): Promise<ResolvedConfiguration | Error> {
const isReinstall = await isReinstalling()
if (isReinstall) {
await onReinstall()
}
// we allow for overriding the server endpoint from config if we haven't
// manually signed in somewhere else
const serverEndpoint = normalizeServerEndpointURL(

const serverEndpoint =
clientConfiguration.overrideServerEndpoint ||
(clientState.lastUsedEndpoint ?? DOTCOM_URL.toString())
)
clientState.lastUsedEndpoint ||
DOTCOM_URL.toString()

// We must not throw here, because that would result in the `resolvedConfig` observable
// terminating and all callers receiving no further config updates.
const loadTokenFn = () =>
clientSecrets.getToken(serverEndpoint).catch(error => {
logError(
'resolveConfiguration',
`Failed to get access token for endpoint ${serverEndpoint}: ${error}`
)
return null
})
const accessToken = clientConfiguration.overrideAuthToken || ((await loadTokenFn()) ?? null)
return {
configuration: clientConfiguration,
clientState,
auth: { accessToken, serverEndpoint },
isReinstall,
try {
const auth = await resolveAuth(serverEndpoint, clientConfiguration, clientSecrets)
return { configuration: clientConfiguration, clientState, auth, isReinstall }
} catch (error) {
// We don't want to throw here, because that would cause the observable to terminate and
// all callers receiving no further config updates.
logError('resolveConfiguration', `Error resolving configuration: ${error}`)
const auth = { accessTokenOrHeaders: null, serverEndpoint, tokenSource: undefined }
return { configuration: clientConfiguration, clientState, auth, isReinstall }
}
}

Expand Down
2 changes: 1 addition & 1 deletion lib/shared/src/experimentation/FeatureFlagProvider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ describe('FeatureFlagProvider', () => {
beforeAll(() => {
vi.useFakeTimers()
mockResolvedConfig({
auth: { accessToken: null, serverEndpoint: 'https://example.com' },
auth: { accessTokenOrHeaders: null, serverEndpoint: 'https://example.com' },
})
mockAuthStatus(AUTH_STATUS_FIXTURE_AUTHED)
})
Expand Down
6 changes: 1 addition & 5 deletions lib/shared/src/models/sync.ts
Original file line number Diff line number Diff line change
Expand Up @@ -464,11 +464,7 @@ async function fetchServerSideModels(
): Promise<ServerModelConfiguration | undefined> {
// Fetch the data via REST API.
// NOTE: We may end up exposing this data via GraphQL, it's still TBD.
const client = new RestClient(
config.auth.serverEndpoint,
config.auth.accessToken ?? undefined,
config.configuration.customHeaders
)
const client = new RestClient(config.auth, config.configuration.customHeaders)
return await client.getAvailableModels(signal)
}

Expand Down
Loading
Loading