From 068277d2dbc602c05eb97d033ef90642d029c542 Mon Sep 17 00:00:00 2001 From: Piotr Kukielka Date: Thu, 2 Jan 2025 17:17:21 +0100 Subject: [PATCH] Initial implementation --- .vscode/settings.json | 10 +- agent/scripts/reverse-proxy.py | 68 ++++++++++ agent/src/agent.ts | 12 +- .../cli/command-auth/AuthenticatedAccount.ts | 2 +- agent/src/cli/command-auth/command-login.ts | 4 +- agent/src/cli/command-bench/command-bench.ts | 2 +- agent/src/cli/command-bench/llm-judge.ts | 2 +- agent/src/local-e2e/helpers.ts | 5 +- agent/src/vscode-shim.ts | 8 +- lib/shared/src/configuration.ts | 23 +++- lib/shared/src/configuration/resolver.ts | 127 ++++++++++++++---- .../FeatureFlagProvider.test.ts | 2 +- lib/shared/src/models/sync.ts | 6 +- .../completions/browserClient.ts | 11 +- .../src/sourcegraph-api/graphql/client.ts | 29 ++-- lib/shared/src/sourcegraph-api/rest/client.ts | 23 ++-- lib/shared/src/sourcegraph-api/utils.ts | 17 +++ vscode/package.json | 51 +++++++ vscode/src/auth/auth.ts | 25 ++-- vscode/src/auth/token-receiver.ts | 7 +- .../autoedits/adapters/cody-gateway.test.ts | 2 +- vscode/src/autoedits/adapters/cody-gateway.ts | 7 +- vscode/src/chat/agentic/DeepCody.test.ts | 2 +- vscode/src/chat/chat-view/ChatController.ts | 33 +++-- vscode/src/completions/default-client.ts | 11 +- vscode/src/completions/nodeClient.ts | 5 +- vscode/src/completions/providers/fireworks.ts | 7 +- vscode/src/configuration.test.ts | 3 + vscode/src/configuration.ts | 2 + .../rewrite-keyword-query.test.ts | 2 +- vscode/src/main.ts | 9 +- .../src/notifications/setup-notification.ts | 2 +- vscode/src/services/AuthProvider.test.ts | 10 +- vscode/src/services/LocalStorageProvider.ts | 6 +- vscode/src/services/UpstreamHealthProvider.ts | 8 +- .../open-telemetry/CodyTraceExport.ts | 12 +- .../OpenTelemetryService.node.ts | 2 +- .../services/open-telemetry/trace-sender.ts | 16 +-- vscode/src/testutils/mocks.ts | 1 + vscode/webviews/AppWrapperForTest.tsx | 2 +- 40 files changed, 438 insertions(+), 138 deletions(-) create mode 100644 agent/scripts/reverse-proxy.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 7dd5a69e09b1..4e19fcfea7a5 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -42,5 +42,13 @@ "[typescriptreact]": { "editor.defaultFormatter": "biomejs.biome" }, - "rust-analyzer.procMacro.ignored": { "napi-derive": ["napi"] } + "rust-analyzer.procMacro.ignored": { + "napi-derive": [ + "napi" + ] + }, + "debug.javascript.defaultRuntimeExecutable": { + "pwa-node": "/Users/pkukielka/.local/share/mise/shims/node" + }, + "python.defaultInterpreterPath": "/Users/pkukielka/.local/share/mise/shims/python" } diff --git a/agent/scripts/reverse-proxy.py b/agent/scripts/reverse-proxy.py new file mode 100644 index 000000000000..53b41871954c --- /dev/null +++ b/agent/scripts/reverse-proxy.py @@ -0,0 +1,68 @@ +from aiohttp import web, ClientSession +from urllib.parse import urlparse +import asyncio + +target_url = '' +port = 5050 + +async def proxy_handler(request): + async with ClientSession(auto_decompress=False) as session: + print(f'Request to: {request.url}') + + # Modify headers here + headers = dict(request.headers) + + # Reset the Host header to use target server host instead of the proxy host + if 'Host' in headers: + headers['Host'] = urlparse(target_url).netloc.split(':')[0] + + # 'chunked' encoding results in error 400 from Cloudflare, removing it still keeps response chunked anyway + if 'Transfer-Encoding' in headers: + del headers['Transfer-Encoding'] + + # Use value of 'Authorization: Bearer' to fill 'X-Forwarded-User' and remove 'Authorization' header + if 'Authorization' in headers: + values = headers['Authorization'].split() + if values and values[0] == 'Bearer': + headers['X-Forwarded-User'] = values[1] + del headers['Authorization'] + + # Forward the request to target + async with session.request( + method=request.method, + url=f'{target_url}{request.path_qs}', + headers=headers, + data=await request.read() + ) as response: + proxy_response = web.StreamResponse( + status=response.status, + headers=response.headers + ) + + await proxy_response.prepare(request) + + # Stream the response back + async for chunk in response.content.iter_chunks(): + await proxy_response.write(chunk[0]) + + await proxy_response.write_eof() + return proxy_response + +app = web.Application() +app.router.add_route('*', '/{path_info:.*}', proxy_handler) + + +if __name__ == '__main__': + print('Usage: python reverse_proxy.py [target_url] [proxy_port]') + + import sys + if (len(sys.argv) < 2): + print('Please specify target_url') + sys.exit(1) + if len(sys.argv) > 1: + target_url = sys.argv[1] + if len(sys.argv) > 2: + port = int(sys.argv[2]) + + print(f'Starting proxy server on port {port} targeting {target_url}...') + web.run_app(app, port=port) diff --git a/agent/src/agent.ts b/agent/src/agent.ts index becb7fee688a..96cf75a70270 100644 --- a/agent/src/agent.ts +++ b/agent/src/agent.ts @@ -4,6 +4,7 @@ import path from 'node:path' import type { Polly, Request } from '@pollyjs/core' import { type AccountKeyedChatHistory, + AuthCredentials, type ChatHistoryKey, type ClientCapabilities, type CodyCommand, @@ -14,6 +15,7 @@ import { currentAuthStatusAuthed, firstNonPendingAuthStatus, firstResultFromOperation, + resolveAuth, resolvedConfig, telemetryRecorder, waitUntilComplete, @@ -70,6 +72,7 @@ import type { FixupActor, FixupFileCollection } from '../../vscode/src/non-stop/ import type { FixupControlApplicator } from '../../vscode/src/non-stop/strategies' import { authProvider } from '../../vscode/src/services/AuthProvider' import { localStorage } from '../../vscode/src/services/LocalStorageProvider' +import { secretStorage } from '../../vscode/src/services/SecretStorageProvider' import { AgentWorkspaceEdit } from '../../vscode/src/testutils/AgentWorkspaceEdit' import { AgentAuthHandler } from './AgentAuthHandler' import { AgentFixupControls } from './AgentFixupControls' @@ -1482,6 +1485,11 @@ export class Agent extends MessageHandler implements ExtensionClient { return this.clientInfo?.capabilities ?? undefined } + private async resolveAuthCredentials(extCfg: ExtensionConfiguration): Promise { + const config = JSON.parse(extCfg.customConfigurationJson ?? '{}') + return resolveAuth(extCfg.serverEndpoint, config, secretStorage) + } + private async handleConfigChanges( config: ExtensionConfiguration, params?: { forceAuthentication: boolean } @@ -1500,7 +1508,9 @@ export class Agent extends MessageHandler implements ExtensionClient { }, auth: { serverEndpoint: config.serverEndpoint, - accessToken: config.accessToken ?? null, + accessTokenOrHeaders: + config.accessToken ?? + (await this.resolveAuthCredentials(config)).accessTokenOrHeaders, }, clientState: { anonymousUserID: config.anonymousUserID ?? null, diff --git a/agent/src/cli/command-auth/AuthenticatedAccount.ts b/agent/src/cli/command-auth/AuthenticatedAccount.ts index 9a228e3fb224..91f0d8c4c694 100644 --- a/agent/src/cli/command-auth/AuthenticatedAccount.ts +++ b/agent/src/cli/command-auth/AuthenticatedAccount.ts @@ -42,7 +42,7 @@ export class AuthenticatedAccount { ): Promise { const graphqlClient = SourcegraphGraphQLAPIClient.withStaticConfig({ configuration: { telemetryLevel: 'agent' }, - auth: { accessToken: options.accessToken, serverEndpoint: options.endpoint }, + auth: { accessTokenOrHeaders: options.accessToken, serverEndpoint: options.endpoint }, clientState: { anonymousUserID: null }, }) const userInfo = await graphqlClient.getCurrentUserInfo() diff --git a/agent/src/cli/command-auth/command-login.ts b/agent/src/cli/command-auth/command-login.ts index 2cc9aedb745d..d02dcac380c8 100644 --- a/agent/src/cli/command-auth/command-login.ts +++ b/agent/src/cli/command-auth/command-login.ts @@ -123,7 +123,7 @@ async function loginAction( : await captureAccessTokenViaBrowserRedirect(serverEndpoint, spinner) const client = SourcegraphGraphQLAPIClient.withStaticConfig({ configuration: { telemetryLevel: 'agent' }, - auth: { accessToken: token, serverEndpoint: serverEndpoint }, + auth: { accessTokenOrHeaders: token, serverEndpoint: serverEndpoint }, clientState: { anonymousUserID: null }, }) const userInfo = await client.getCurrentUserInfo() @@ -256,7 +256,7 @@ async function promptUserAboutLoginMethod(spinner: Ora, options: LoginOptions): try { const client = SourcegraphGraphQLAPIClient.withStaticConfig({ configuration: { telemetryLevel: 'agent' }, - auth: { accessToken: options.accessToken, serverEndpoint: options.endpoint }, + auth: { accessTokenOrHeaders: options.accessToken, serverEndpoint: options.endpoint }, clientState: { anonymousUserID: null }, }) const userInfo = await client.getCurrentUserInfo() diff --git a/agent/src/cli/command-bench/command-bench.ts b/agent/src/cli/command-bench/command-bench.ts index 4cb09ce25e51..ac16e60df4aa 100644 --- a/agent/src/cli/command-bench/command-bench.ts +++ b/agent/src/cli/command-bench/command-bench.ts @@ -329,7 +329,7 @@ export const benchCommand = new commander.Command('bench') setStaticResolvedConfigurationWithAuthCredentials({ configuration: { customHeaders: {} }, auth: { - accessToken: options.srcAccessToken, + accessTokenOrHeaders: options.srcAccessToken, serverEndpoint: options.srcEndpoint, }, }) diff --git a/agent/src/cli/command-bench/llm-judge.ts b/agent/src/cli/command-bench/llm-judge.ts index 8df9d4510c2f..10b48f77b3e7 100644 --- a/agent/src/cli/command-bench/llm-judge.ts +++ b/agent/src/cli/command-bench/llm-judge.ts @@ -16,7 +16,7 @@ export class LlmJudge { localStorage.setStorage('noop') setStaticResolvedConfigurationWithAuthCredentials({ configuration: { customHeaders: undefined }, - auth: { accessToken: options.srcAccessToken, serverEndpoint: options.srcEndpoint }, + auth: { accessTokenOrHeaders: options.srcAccessToken, serverEndpoint: options.srcEndpoint }, }) setClientCapabilities({ configuration: {}, agentCapabilities: undefined }) this.client = new SourcegraphNodeCompletionsClient() diff --git a/agent/src/local-e2e/helpers.ts b/agent/src/local-e2e/helpers.ts index 703a4817c061..da6de8615a87 100644 --- a/agent/src/local-e2e/helpers.ts +++ b/agent/src/local-e2e/helpers.ts @@ -96,7 +96,10 @@ export class LocalSGInstance { // for checking the LLM configuration section. this.gqlclient = SourcegraphGraphQLAPIClient.withStaticConfig({ configuration: { customHeaders: headers, telemetryLevel: 'agent' }, - auth: { accessToken: this.params.accessToken, serverEndpoint: this.params.serverEndpoint }, + auth: { + accessTokenOrHeaders: this.params.accessToken, + serverEndpoint: this.params.serverEndpoint, + }, clientState: { anonymousUserID: null }, }) } diff --git a/agent/src/vscode-shim.ts b/agent/src/vscode-shim.ts index 9b856be75ae2..71135cde381d 100644 --- a/agent/src/vscode-shim.ts +++ b/agent/src/vscode-shim.ts @@ -141,9 +141,15 @@ export function isAuthenticationChange(newConfig: ExtensionConfiguration): boole return true } + function getExternalAuthProvidersConfig(cfg: ExtensionConfiguration) { + return JSON.parse(cfg.customConfigurationJson ?? '{}')?.cody?.auth?.externalProviders + } + return ( extensionConfiguration.accessToken !== newConfig.accessToken || - extensionConfiguration.serverEndpoint !== newConfig.serverEndpoint + extensionConfiguration.serverEndpoint !== newConfig.serverEndpoint || + getExternalAuthProvidersConfig(extensionConfiguration) !== + getExternalAuthProvidersConfig(newConfig) ) } diff --git a/lib/shared/src/configuration.ts b/lib/shared/src/configuration.ts index 7659244cec19..b69e49837b17 100644 --- a/lib/shared/src/configuration.ts +++ b/lib/shared/src/configuration.ts @@ -9,7 +9,9 @@ import type { ReadonlyDeep } from './utils' * A redirect flow is initiated by the user clicking a link in the browser, while a paste flow is initiated by the user * manually entering the access from into the VsCode App. */ -export type TokenSource = 'redirect' | 'paste' +export type TokenSource = 'redirect' | 'paste' | 'custom-auth-provider' + +export type AuthHeaders = Record /** * The user's authentication credentials, which are stored separately from the rest of the @@ -17,8 +19,8 @@ export type TokenSource = 'redirect' | 'paste' */ export interface AuthCredentials { serverEndpoint: string - accessToken: string | null tokenSource?: TokenSource | undefined + accessTokenOrHeaders: string | AuthHeaders | null } export interface AutoEditsTokenLimit { @@ -71,6 +73,20 @@ export interface AgenticContextConfiguration { } } +export interface ExternalAuthCommand { + commandLine: string[] + environment?: Record + workingDir?: string + shell?: string + timeout?: number + windowsHide?: boolean +} + +export interface ExternalAuthProvider { + endpoint: string + executable: ExternalAuthCommand +} + interface RawClientConfiguration { net: NetConfiguration codebase?: string @@ -166,6 +182,9 @@ interface RawClientConfiguration { */ overrideServerEndpoint?: string | undefined overrideAuthToken?: string | undefined + + // External auth providers + authExternalProviders?: ExternalAuthProvider[] } /** diff --git a/lib/shared/src/configuration/resolver.ts b/lib/shared/src/configuration/resolver.ts index 331309fbfc7f..70f2df425f4d 100644 --- a/lib/shared/src/configuration/resolver.ts +++ b/lib/shared/src/configuration/resolver.ts @@ -1,5 +1,11 @@ import { Observable, map } from 'observable-fns' -import type { AuthCredentials, ClientConfiguration } from '../configuration' +import type { + AuthCredentials, + ClientConfiguration, + ExternalAuthCommand, + ExternalAuthProvider, + TokenSource, +} from '../configuration' import { logError } from '../logger' import { distinctUntilChanged, @@ -27,6 +33,7 @@ export interface ConfigurationInput { export interface ClientSecrets { getToken(endpoint: string): Promise + getTokenSource(endpoint: string): Promise } export interface ClientState { @@ -72,39 +79,113 @@ export type PickResolvedConfiguration = { : undefined } +export type ExternalProviderAuthConfiguration = { + authExternalProviders?: ExternalAuthProvider[] + overrideAuthToken?: string | undefined +} + +async function executeCommand(cmd: ExternalAuthCommand): Promise { + if (typeof process === 'undefined' || !process.version) { + throw new Error('Command execution is only supported in Node.js environments') + } + + const { exec } = await import('node:child_process') + const { promisify } = await import('node:util') + const execAsync = promisify(exec) + + const command = cmd.commandLine.join(' ') + const options = { + ...cmd, + env: cmd.environment ? { ...process.env, ...cmd.environment } : process.env, + } + + const { stdout, stderr } = await execAsync(command, options) + if (stderr) { + throw new Error(`External auth command error: ${stderr}`) + } + return stdout.trim() +} + +async function getExternalProviderAuthHeaders( + serverEndpoint: string, + clientConfiguration: ExternalProviderAuthConfiguration +): Promise | undefined> { + // Check for external auth provider for this endpoint + const externalProvider = clientConfiguration.authExternalProviders?.find( + provider => normalizeServerEndpointURL(provider.endpoint) === serverEndpoint + ) + + if (externalProvider) { + const result = await executeCommand(externalProvider.executable) + return JSON.parse(result) + } + + return undefined +} + +export async function resolveAuth( + endpoint: string, + clientConfiguration: ExternalProviderAuthConfiguration, + clientSecrets: ClientSecrets +): Promise { + let accessTokenOrHeaders = null + let tokenSource: TokenSource | undefined = undefined + + const serverEndpoint = normalizeServerEndpointURL(endpoint) + + // We must not throw here, because that would result in the `resolvedConfig` observable + // terminating and all callers receiving no further config updates. + const loadTokenFn = () => + clientSecrets.getToken(serverEndpoint).catch(error => { + throw new Error( + `Failed to get access token for endpoint ${serverEndpoint}: ${error.message || error}` + ) + }) + + if (clientConfiguration.overrideAuthToken) { + accessTokenOrHeaders = clientConfiguration.overrideAuthToken + } else + try { + const authHeaders = await getExternalProviderAuthHeaders(serverEndpoint, clientConfiguration) + if (authHeaders) { + accessTokenOrHeaders = authHeaders + tokenSource = 'custom-auth-provider' + } else { + accessTokenOrHeaders = (await loadTokenFn()) || null + tokenSource = await clientSecrets.getTokenSource(serverEndpoint).catch(_ => undefined) + } + } catch (error) { + throw new Error(`Failed to execute external auth command: ${error}`) + } + + return { accessTokenOrHeaders, serverEndpoint, tokenSource } +} + async function resolveConfiguration({ clientConfiguration, clientSecrets, clientState, reinstall: { isReinstalling, onReinstall }, -}: ConfigurationInput): Promise { +}: ConfigurationInput): Promise { const isReinstall = await isReinstalling() if (isReinstall) { await onReinstall() } - // we allow for overriding the server endpoint from config if we haven't - // manually signed in somewhere else - const serverEndpoint = normalizeServerEndpointURL( + + const serverEndpoint = clientConfiguration.overrideServerEndpoint || - (clientState.lastUsedEndpoint ?? DOTCOM_URL.toString()) - ) + clientState.lastUsedEndpoint || + DOTCOM_URL.toString() - // We must not throw here, because that would result in the `resolvedConfig` observable - // terminating and all callers receiving no further config updates. - const loadTokenFn = () => - clientSecrets.getToken(serverEndpoint).catch(error => { - logError( - 'resolveConfiguration', - `Failed to get access token for endpoint ${serverEndpoint}: ${error}` - ) - return null - }) - const accessToken = clientConfiguration.overrideAuthToken || ((await loadTokenFn()) ?? null) - return { - configuration: clientConfiguration, - clientState, - auth: { accessToken, serverEndpoint }, - isReinstall, + try { + const auth = await resolveAuth(serverEndpoint, clientConfiguration, clientSecrets) + return { configuration: clientConfiguration, clientState, auth, isReinstall } + } catch (error) { + // We don't want to throw here, because that would cause the observable to terminate and + // all callers receiving no further config updates. + logError('resolveConfiguration', `Error resolving configuration: ${error}`) + const auth = { accessTokenOrHeaders: null, serverEndpoint, tokenSource: undefined } + return { configuration: clientConfiguration, clientState, auth, isReinstall } } } diff --git a/lib/shared/src/experimentation/FeatureFlagProvider.test.ts b/lib/shared/src/experimentation/FeatureFlagProvider.test.ts index 43bea01c31d3..31a24411d939 100644 --- a/lib/shared/src/experimentation/FeatureFlagProvider.test.ts +++ b/lib/shared/src/experimentation/FeatureFlagProvider.test.ts @@ -24,7 +24,7 @@ describe('FeatureFlagProvider', () => { beforeAll(() => { vi.useFakeTimers() mockResolvedConfig({ - auth: { accessToken: null, serverEndpoint: 'https://example.com' }, + auth: { accessTokenOrHeaders: null, serverEndpoint: 'https://example.com' }, }) mockAuthStatus(AUTH_STATUS_FIXTURE_AUTHED) }) diff --git a/lib/shared/src/models/sync.ts b/lib/shared/src/models/sync.ts index 58fa9107e69b..6c25eecad31c 100644 --- a/lib/shared/src/models/sync.ts +++ b/lib/shared/src/models/sync.ts @@ -448,11 +448,7 @@ async function fetchServerSideModels( ): Promise { // Fetch the data via REST API. // NOTE: We may end up exposing this data via GraphQL, it's still TBD. - const client = new RestClient( - config.auth.serverEndpoint, - config.auth.accessToken ?? undefined, - config.configuration.customHeaders - ) + const client = new RestClient(config.auth, config.configuration.customHeaders) return await client.getAvailableModels(signal) } diff --git a/lib/shared/src/sourcegraph-api/completions/browserClient.ts b/lib/shared/src/sourcegraph-api/completions/browserClient.ts index 65b8d7ec8996..8475b4f3aa74 100644 --- a/lib/shared/src/sourcegraph-api/completions/browserClient.ts +++ b/lib/shared/src/sourcegraph-api/completions/browserClient.ts @@ -4,6 +4,7 @@ import { dependentAbortController } from '../../common/abortController' import { currentResolvedConfig } from '../../configuration/resolver' import { isError } from '../../utils' import { addClientInfoParams, addCodyClientIdentificationHeaders } from '../client-name-version' +import { addAuthHeaders } from '../utils' import { CompletionsResponseBuilder } from './CompletionsResponseBuilder' import { type CompletionRequestParameters, SourcegraphCompletionsClient } from './client' @@ -39,10 +40,9 @@ export class SourcegraphBrowserCompletionsClient extends SourcegraphCompletionsC ...requestParams.customHeaders, } as HeadersInit) addCodyClientIdentificationHeaders(headersInstance) + addAuthHeaders(config.auth, headersInstance, url) headersInstance.set('Content-Type', 'application/json; charset=utf-8') - if (config.auth.accessToken) { - headersInstance.set('Authorization', `token ${config.auth.accessToken}`) - } + const parameters = new URLSearchParams(globalThis.location.search) const trace = parameters.get('trace') if (trace) { @@ -132,9 +132,8 @@ export class SourcegraphBrowserCompletionsClient extends SourcegraphCompletionsC ...requestParams.customHeaders, }) addCodyClientIdentificationHeaders(headersInstance) - if (auth.accessToken) { - headersInstance.set('Authorization', `token ${auth.accessToken}`) - } + addAuthHeaders(auth, headersInstance, url) + if (new URLSearchParams(globalThis.location.search).get('trace')) { headersInstance.set('X-Sourcegraph-Should-Trace', 'true') } diff --git a/lib/shared/src/sourcegraph-api/graphql/client.ts b/lib/shared/src/sourcegraph-api/graphql/client.ts index 0bca06fbc360..ca6ed79997c2 100644 --- a/lib/shared/src/sourcegraph-api/graphql/client.ts +++ b/lib/shared/src/sourcegraph-api/graphql/client.ts @@ -17,6 +17,7 @@ import { addTraceparent, wrapInActiveSpan } from '../../tracing' import { isError } from '../../utils' import { addCodyClientIdentificationHeaders } from '../client-name-version' import { isAbortError } from '../errors' +import { addAuthHeaders } from '../utils' import { type GraphQLResultCache, ObservableInvalidatedGraphQLResultCacheFactory } from './cache' import { BUILTIN_PROMPTS_QUERY, @@ -1552,23 +1553,23 @@ export class SourcegraphGraphQLAPIClient { const headers = new Headers(config.configuration?.customHeaders as HeadersInit | undefined) headers.set('Content-Type', 'application/json; charset=utf-8') - if (config.auth.accessToken) { - headers.set('Authorization', `token ${config.auth.accessToken}`) - } if (config.clientState.anonymousUserID && !process.env.CODY_WEB_DONT_SET_SOME_HEADERS) { headers.set('X-Sourcegraph-Actor-Anonymous-UID', config.clientState.anonymousUserID) } + const url = new URL( + buildGraphQLUrl({ + request: query, + baseUrl: config.auth.serverEndpoint, + }) + ) + addTraceparent(headers) addCodyClientIdentificationHeaders(headers) + addAuthHeaders(config.auth, headers, url) const queryName = query.match(QUERY_TO_NAME_REGEXP)?.[1] - const url = buildGraphQLUrl({ - request: query, - baseUrl: config.auth.serverEndpoint, - }) - const { abortController, timeoutSignal } = dependentAbortControllerWithTimeout(signal) return wrapInActiveSpan(`graphql.fetch${queryName ? `.${queryName}` : ''}`, () => fetch(url, { @@ -1579,7 +1580,7 @@ export class SourcegraphGraphQLAPIClient { }) .then(verifyResponseCode) .then(response => response.json() as T) - .catch(catchHTTPError(url, timeoutSignal)) + .catch(catchHTTPError(url.href, timeoutSignal)) ) } @@ -1605,17 +1606,15 @@ export class SourcegraphGraphQLAPIClient { const headers = new Headers(config.configuration?.customHeaders as HeadersInit | undefined) headers.set('Content-Type', 'application/json; charset=utf-8') - if (config.auth.accessToken) { - headers.set('Authorization', `token ${config.auth.accessToken}`) - } if (config.clientState.anonymousUserID && !process.env.CODY_WEB_DONT_SET_SOME_HEADERS) { headers.set('X-Sourcegraph-Actor-Anonymous-UID', config.clientState.anonymousUserID) } + const url = new URL(urlPath, config.auth.serverEndpoint) + addTraceparent(headers) addCodyClientIdentificationHeaders(headers) - - const url = new URL(urlPath, config.auth.serverEndpoint).href + addAuthHeaders(config.auth, headers, url) const { abortController, timeoutSignal } = dependentAbortControllerWithTimeout(signal) return wrapInActiveSpan(`httpapi.fetch${queryName ? `.${queryName}` : ''}`, () => @@ -1627,7 +1626,7 @@ export class SourcegraphGraphQLAPIClient { }) .then(verifyResponseCode) .then(response => response.json() as T) - .catch(catchHTTPError(url, timeoutSignal)) + .catch(catchHTTPError(url.href, timeoutSignal)) ) } } diff --git a/lib/shared/src/sourcegraph-api/rest/client.ts b/lib/shared/src/sourcegraph-api/rest/client.ts index 551ada9c70c0..e433a2ba1f7c 100644 --- a/lib/shared/src/sourcegraph-api/rest/client.ts +++ b/lib/shared/src/sourcegraph-api/rest/client.ts @@ -1,5 +1,6 @@ import type { ServerModelConfiguration } from '../../models/modelsService' +import { type AuthCredentials, addAuthHeaders } from '../..' import { fetch } from '../../fetch' import { logError } from '../../logger' import { addTraceparent, wrapInActiveSpan } from '../../tracing' @@ -20,14 +21,12 @@ import { verifyResponseCode } from '../graphql/client' */ export class RestClient { /** - * @param endpointUrl URL to the sourcegraph instance, e.g. "https://sourcegraph.acme.com". - * @param accessToken User access token to contact the sourcegraph instance. - * @param customHeaders Custom headers (primary is used by Cody Web case when Sourcegraph client - * providers set of custom headers to make sure that auth flow will work properly + * Creates a new REST client to interact with a Sourcegraph instance. + * @param auth Authentication credentials containing endpoint URL and access token + * @param customHeaders Additional headers for requests (used by Cody Web to ensure proper auth flow) */ constructor( - private endpointUrl: string, - private accessToken: string | undefined, + private auth: AuthCredentials, private customHeaders: Record | undefined ) {} @@ -35,15 +34,15 @@ export class RestClient { // "name" is a developer-friendly term to label the request's trace span. private getRequest(name: string, urlSuffix: string, signal?: AbortSignal): Promise { const headers = new Headers(this.customHeaders) - if (this.accessToken) { - headers.set('Authorization', `token ${this.accessToken}`) - } - addCodyClientIdentificationHeaders(headers) - addTraceparent(headers) - const endpoint = new URL(this.endpointUrl) + const endpoint = new URL(this.auth.serverEndpoint) endpoint.pathname = urlSuffix const url = endpoint.href + + addCodyClientIdentificationHeaders(headers) + addAuthHeaders(this.auth, headers, endpoint) + addTraceparent(headers) + return wrapInActiveSpan(`rest-api.${name}`, () => fetch(url, { method: 'GET', diff --git a/lib/shared/src/sourcegraph-api/utils.ts b/lib/shared/src/sourcegraph-api/utils.ts index a6d161b27903..1d58f2d80466 100644 --- a/lib/shared/src/sourcegraph-api/utils.ts +++ b/lib/shared/src/sourcegraph-api/utils.ts @@ -2,6 +2,9 @@ // of a character, returns the remaining bytes of the partial character in a // new buffer. Note! This assumes that the prefix of buf *is* valid UTF8--it // only examines the bytes of the last character in the buffer and assumes it + +import type { AuthCredentials } from '..' + // will find an initial byte before the start of the buffer. export function toPartialUtf8String(buf: Buffer): { str: string; buf: Buffer } { if (buf.length === 0) { @@ -32,3 +35,17 @@ export function toPartialUtf8String(buf: Buffer): { str: string; buf: Buffer } { buf: Buffer.from(buf.slice(lastValidByteOffsetExclusive)), } } + +export function addAuthHeaders(auth: AuthCredentials, headers: Headers, url: URL): void { + // We want to be sure we sent authorization headers only to the valid endpoint + if (auth.accessTokenOrHeaders && url.host === new URL(auth.serverEndpoint).host) { + if (typeof auth.accessTokenOrHeaders === 'string') { + headers.set('Authorization', `token ${auth.accessTokenOrHeaders}`) + } else { + // Add headers as-is when accessTokenOrHeaders is a record of headers + for (const [key, value] of Object.entries(auth.accessTokenOrHeaders)) { + headers.set(key, value) + } + } + } +} diff --git a/vscode/package.json b/vscode/package.json index 9ce993f5c499..617ac48d6282 100644 --- a/vscode/package.json +++ b/vscode/package.json @@ -1316,6 +1316,57 @@ "~/.mitmproxy/mitmproxy-ca-cert.pem", "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----" ] + }, + "cody.auth.externalProviders": { + "type": "array", + "markdownDescription": "Configure external authentication providers for Cody requests. Each provider consists of a command that generates HTTP headers used for authentication for a given endpoint.\n\n**How it works:**\n1. The specified command outputs a JSON object with header-value pairs\n2. These headers are included in authenticated Cody requests to the specified endpoint\n3. HTTP authentication proxy need to be used to enable custom authentication flows (e.g. JWT tokens, Oath2, etc)\n\nSee [HTTP Authentication Proxies](https://sourcegraph.com/docs/admin/auth#http-authentication-proxies) for proxy configuration.", + "items": { + "type": "object", + "required": ["endpoint", "executable"], + "properties": { + "endpoint": { + "type": "string", + "description": "The endpoint URL of the Sourcegraph instance" + }, + "executable": { + "type": "object", + "required": ["commandLine"], + "properties": { + "commandLine": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Command line arguments to execute the command." + }, + "environment": { + "type": "object", + "additionalProperties": { + "type": "string" + }, + "description": "Environment variables to set when executing the command." + }, + "workingDir": { + "type": "string", + "description": "Working directory for executing the command." + }, + "shell": { + "type": "string", + "description": "If true, runs command inside of a shell. Uses '/bin/sh' on Unix, and process.env.ComSpec on Windows. A different shell can be specified as a string." + }, + "timeout": { + "type": "number", + "description": "Timeout for executing the command in miliseconds." + }, + "windowsHide": { + "type": "boolean", + "description": "Whether to hide the console window that would normally be created for the child process on Windows." + } + } + } + } + }, + "default": [] } } }, diff --git a/vscode/src/auth/auth.ts b/vscode/src/auth/auth.ts index 3f490c9b16d5..fa7ba40626f7 100644 --- a/vscode/src/auth/auth.ts +++ b/vscode/src/auth/auth.ts @@ -12,12 +12,14 @@ import { cenv, clientCapabilities, currentAuthStatus, + currentResolvedConfig, getAuthErrorMessage, getCodyAuthReferralCode, graphqlClient, isDotCom, isError, isNetworkLikeError, + resolveAuth, telemetryRecorder, } from '@sourcegraph/cody-shared' import { isSourcegraphToken } from '../chat/protocol' @@ -83,16 +85,15 @@ export async function showSignInMenu( break } default: { - // Auto log user if token for the selected instance was found in secret + // Auto log user if token for the selected instance was found in secret or custom provider is configured const selectedEndpoint = item.uri - const token = await secretStorage.getToken(selectedEndpoint) - const tokenSource = await secretStorage.getTokenSource(selectedEndpoint) - let authStatus = token - ? await authProvider.validateAndStoreCredentials( - { serverEndpoint: selectedEndpoint, accessToken: token, tokenSource }, - 'store-if-valid' - ) + const { configuration } = await currentResolvedConfig() + const auth = await resolveAuth(selectedEndpoint, configuration, secretStorage) + + let authStatus = auth.accessTokenOrHeaders + ? await authProvider.validateAndStoreCredentials(auth, 'store-if-valid') : undefined + if (!authStatus?.authenticated) { const newToken = await showAccessTokenInputBox(selectedEndpoint) if (!newToken) { @@ -101,7 +102,7 @@ export async function showSignInMenu( authStatus = await authProvider.validateAndStoreCredentials( { serverEndpoint: selectedEndpoint, - accessToken: newToken, + accessTokenOrHeaders: newToken, tokenSource: 'paste', }, 'store-if-valid' @@ -233,7 +234,7 @@ async function signinMenuForInstanceUrl(instanceUrl: string): Promise { return } const authStatus = await authProvider.validateAndStoreCredentials( - { serverEndpoint: instanceUrl, accessToken: accessToken, tokenSource: 'paste' }, + { serverEndpoint: instanceUrl, accessTokenOrHeaders: accessToken, tokenSource: 'paste' }, 'store-if-valid' ) telemetryRecorder.recordEvent('cody.auth.signin.token', 'clicked', { @@ -312,7 +313,7 @@ export async function tokenCallbackHandler(uri: vscode.Uri): Promise { } const authStatus = await authProvider.validateAndStoreCredentials( - { serverEndpoint: endpoint, accessToken: token, tokenSource: 'redirect' }, + { serverEndpoint: endpoint, accessTokenOrHeaders: token, tokenSource: 'redirect' }, 'store-if-valid' ) telemetryRecorder.recordEvent('cody.auth.fromCallback.web', 'succeeded', { @@ -410,7 +411,7 @@ export async function validateCredentials( clientConfig?: CodyClientConfig ): Promise { // An access token is needed except for Cody Web, which uses cookies. - if (!config.auth.accessToken && !clientCapabilities().isCodyWeb) { + if (!config.auth.accessTokenOrHeaders && !clientCapabilities().isCodyWeb) { return { authenticated: false, endpoint: config.auth.serverEndpoint, pendingValidation: false } } diff --git a/vscode/src/auth/token-receiver.ts b/vscode/src/auth/token-receiver.ts index 0a93195690fe..e7104ba1947d 100644 --- a/vscode/src/auth/token-receiver.ts +++ b/vscode/src/auth/token-receiver.ts @@ -14,7 +14,7 @@ const FIVE_MINUTES = 5 * 60 * 1000 // the user follow a redirect. export function startTokenReceiver( endpoint: string, - onNewToken: (credentials: Pick) => void, + onNewToken: (credentials: Pick) => void, timeout = FIVE_MINUTES ): Promise { const endpointUrl = new URL(endpoint) @@ -46,7 +46,10 @@ export function startTokenReceiver( 'accessToken' in json && typeof json.accessToken === 'string' ) { - onNewToken({ serverEndpoint: endpoint, accessToken: json.accessToken }) + onNewToken({ + serverEndpoint: endpoint, + accessTokenOrHeaders: json.accessToken, + }) res.writeHead(200, headers) res.write('ok') diff --git a/vscode/src/autoedits/adapters/cody-gateway.test.ts b/vscode/src/autoedits/adapters/cody-gateway.test.ts index f3fcbdb4dd09..dd527ea7847f 100644 --- a/vscode/src/autoedits/adapters/cody-gateway.test.ts +++ b/vscode/src/autoedits/adapters/cody-gateway.test.ts @@ -33,7 +33,7 @@ describe('CodyGatewayAdapter', () => { mockResolvedConfig({ configuration: {}, auth: { - accessToken: 'sgp_local_f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0', + accessTokenOrHeaders: 'sgp_local_f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0', serverEndpoint: DOTCOM_URL.toString(), }, }) diff --git a/vscode/src/autoedits/adapters/cody-gateway.ts b/vscode/src/autoedits/adapters/cody-gateway.ts index 65dd5c3f652d..d83a2c39bc85 100644 --- a/vscode/src/autoedits/adapters/cody-gateway.ts +++ b/vscode/src/autoedits/adapters/cody-gateway.ts @@ -35,7 +35,12 @@ export class CodyGatewayAdapter implements AutoeditsModelAdapter { private async getApiKey(): Promise { const resolvedConfig = await currentResolvedConfig() - const fastPathAccessToken = dotcomTokenToGatewayToken(resolvedConfig.auth.accessToken) + // TODO (pkukielka): Check if fastpath should support custom auth providers and how + const accessToken = + typeof resolvedConfig.auth.accessTokenOrHeaders === 'string' + ? resolvedConfig.auth.accessTokenOrHeaders + : null + const fastPathAccessToken = dotcomTokenToGatewayToken(accessToken) if (!fastPathAccessToken) { autoeditsOutputChannelLogger.logError('getApiKey', 'FastPath access token is not available') throw new Error('FastPath access token is not available') diff --git a/vscode/src/chat/agentic/DeepCody.test.ts b/vscode/src/chat/agentic/DeepCody.test.ts index d93e45f403a8..e70eda93ef93 100644 --- a/vscode/src/chat/agentic/DeepCody.test.ts +++ b/vscode/src/chat/agentic/DeepCody.test.ts @@ -45,7 +45,7 @@ describe('DeepCody', () => { } as any) beforeEach(async () => { - mockResolvedConfig({ configuration: {} }) + mockResolvedConfig({ configuration: {}, auth: { serverEndpoint: DOTCOM_URL.toString() } }) mockClientCapabilities(CLIENT_CAPABILITIES_FIXTURE) mockAuthStatus(codyProAuthStatus) localStorageData = {} diff --git a/vscode/src/chat/chat-view/ChatController.ts b/vscode/src/chat/chat-view/ChatController.ts index b42d1bcf3e95..cbcb3ceef1c4 100644 --- a/vscode/src/chat/chat-view/ChatController.ts +++ b/vscode/src/chat/chat-view/ChatController.ts @@ -23,6 +23,7 @@ import { type SerializedChatInteraction, type SerializedChatTranscript, type SerializedPromptEditorState, + type TokenSource, addMessageListenersForExtensionAPI, authStatus, cenv, @@ -53,6 +54,7 @@ import { ps, recordErrorToSpan, reformatBotMessageForChat, + resolveAuth, resolvedConfig, serializeChatMessage, shareReplay, @@ -409,10 +411,6 @@ export class ChatController implements vscode.Disposable, vscode.WebviewViewProv ) break case 'auth': { - if (message.authKind === 'callback' && message.endpoint) { - redirectToEndpointLogin(message.endpoint) - break - } if (message.authKind === 'simplified-onboarding') { const endpoint = DOTCOM_URL.href @@ -452,19 +450,36 @@ export class ChatController implements vscode.Disposable, vscode.WebviewViewProv } break } - if (message.authKind === 'signin' && message.endpoint) { + if ( + (message.authKind === 'signin' || message.authKind === 'callback') && + message.endpoint + ) { try { const { endpoint, value: token } = message + let accessTokenOrHeaders = null + let tokenSource: TokenSource | undefined = undefined + + if (token) { + accessTokenOrHeaders = token + tokenSource = 'paste' + } else { + const { configuration } = await currentResolvedConfig() + const auth = await resolveAuth(endpoint, configuration, secretStorage) + accessTokenOrHeaders = auth.accessTokenOrHeaders + tokenSource = auth.tokenSource + } + const credentials = { serverEndpoint: endpoint, - accessToken: token || (await secretStorage.getToken(endpoint)) || null, - tokenSource: token ? 'paste' : await secretStorage.getTokenSource(endpoint), + accessTokenOrHeaders, + tokenSource: tokenSource, } - if (!credentials.accessToken) { + if (!credentials.accessTokenOrHeaders) { return redirectToEndpointLogin(credentials.serverEndpoint) } await authProvider.validateAndStoreCredentials(credentials, 'always-store') } catch (error) { + void vscode.window.showErrorMessage(`Authentication failed: ${error}`) this.postError(new Error(`Authentication failed: ${error}`)) } break @@ -500,7 +515,7 @@ export class ChatController implements vscode.Disposable, vscode.WebviewViewProv const authStatus = await authProvider.validateAndStoreCredentials( { serverEndpoint: DOTCOM_URL.href, - accessToken: token, + accessTokenOrHeaders: token, }, 'store-if-valid' ) diff --git a/vscode/src/completions/default-client.ts b/vscode/src/completions/default-client.ts index 7b3912745d6c..992b174d5642 100644 --- a/vscode/src/completions/default-client.ts +++ b/vscode/src/completions/default-client.ts @@ -14,6 +14,8 @@ import { RateLimitError, type SerializedCodeCompletionsParams, TracedError, + addAuthHeaders, + addCodyClientIdentificationHeaders, addTraceparent, contextFiltersProvider, createSSEIterator, @@ -49,8 +51,8 @@ class DefaultCodeCompletionsClient implements CodeCompletionsClient { const { auth, configuration } = await currentResolvedConfig() const query = new URLSearchParams(getClientInfoParams()) - const url = new URL(`/.api/completions/code?${query.toString()}`, auth.serverEndpoint).href - const log = autocompleteLifecycleOutputChannelLogger?.startCompletion(params, url) + const url = new URL(`/.api/completions/code?${query.toString()}`, auth.serverEndpoint) + const log = autocompleteLifecycleOutputChannelLogger?.startCompletion(params, url.href) const { signal } = abortController return tracer.startActiveSpan( @@ -69,9 +71,8 @@ class DefaultCodeCompletionsClient implements CodeCompletionsClient { // Force HTTP connection reuse to reduce latency. // c.f. https://github.com/microsoft/vscode/issues/173861 headers.set('Content-Type', 'application/json; charset=utf-8') - if (auth.accessToken) { - headers.set('Authorization', `token ${auth.accessToken}`) - } + addCodyClientIdentificationHeaders(headers) + addAuthHeaders(auth, headers, url) if (tracingFlagEnabled) { headers.set('X-Sourcegraph-Should-Trace', '1') diff --git a/vscode/src/completions/nodeClient.ts b/vscode/src/completions/nodeClient.ts index 4f2fc6500492..cd561d66491f 100644 --- a/vscode/src/completions/nodeClient.ts +++ b/vscode/src/completions/nodeClient.ts @@ -13,6 +13,7 @@ import { NetworkError, RateLimitError, SourcegraphCompletionsClient, + addAuthHeaders, addClientInfoParams, addCodyClientIdentificationHeaders, currentResolvedConfig, @@ -95,13 +96,13 @@ export class SourcegraphNodeCompletionsClient extends SourcegraphCompletionsClie // responses afterwards. 'Accept-Encoding': 'gzip;q=0', 'X-Sourcegraph-Interaction-ID': interactionId || '', - ...(auth.accessToken ? { Authorization: `token ${auth.accessToken}` } : null), ...configuration?.customHeaders, ...requestParams.customHeaders, ...getTraceparentHeaders(), Connection: 'keep-alive', }) addCodyClientIdentificationHeaders(headers) + addAuthHeaders(auth, headers, url) const request = requestFn( url, @@ -299,7 +300,6 @@ export class SourcegraphNodeCompletionsClient extends SourcegraphCompletionsClie const headers = new Headers({ 'Content-Type': 'application/json', 'Accept-Encoding': 'gzip;q=0', - ...(auth.accessToken ? { Authorization: `token ${auth.accessToken}` } : null), ...configuration.customHeaders, ...requestParams.customHeaders, ...getTraceparentHeaders(), @@ -307,6 +307,7 @@ export class SourcegraphNodeCompletionsClient extends SourcegraphCompletionsClie }) addCodyClientIdentificationHeaders(headers) + addAuthHeaders(auth, headers, url) const response = await fetch(url.toString(), { method: 'POST', diff --git a/vscode/src/completions/providers/fireworks.ts b/vscode/src/completions/providers/fireworks.ts index 8f8548dc5b25..ee0a66b3822c 100644 --- a/vscode/src/completions/providers/fireworks.ts +++ b/vscode/src/completions/providers/fireworks.ts @@ -133,7 +133,12 @@ class FireworksProvider extends Provider { typeof process !== 'undefined' if (canFastPathBeUsed) { - const fastPathAccessToken = dotcomTokenToGatewayToken(config.auth.accessToken) + // TODO (pkukielka): Check if fastpath should support custom auth providers and how + const accessToken = + typeof config.auth.accessTokenOrHeaders === 'string' + ? config.auth.accessTokenOrHeaders + : null + const fastPathAccessToken = dotcomTokenToGatewayToken(accessToken) const localFastPathAccessToken = process.env.NODE_ENV === 'development' diff --git a/vscode/src/configuration.test.ts b/vscode/src/configuration.test.ts index 635a00b57e7b..a744497e53e8 100644 --- a/vscode/src/configuration.test.ts +++ b/vscode/src/configuration.test.ts @@ -135,6 +135,8 @@ describe('getConfiguration', () => { return false case 'cody.agentic.context.experimentalOptions': return { shell: { allow: ['git'] } } + case 'cody.auth.externalProviders': + return undefined default: assert(false, `unexpected key: ${key}`) } @@ -204,6 +206,7 @@ describe('getConfiguration', () => { overrideAuthToken: undefined, overrideServerEndpoint: undefined, + authExternalProviders: undefined, } satisfies ClientConfiguration) }) }) diff --git a/vscode/src/configuration.ts b/vscode/src/configuration.ts index a619ff054cba..5950f774fdf3 100644 --- a/vscode/src/configuration.ts +++ b/vscode/src/configuration.ts @@ -98,6 +98,8 @@ export function getConfiguration( agenticContextExperimentalShell: config.get(CONFIG_KEY.agenticContextExperimentalShell, false), agenticContextExperimentalOptions: config.get(CONFIG_KEY.agenticContextExperimentalOptions, {}), + authExternalProviders: config.get(CONFIG_KEY.authExternalProviders, undefined), + /** * Hidden settings for internal use only. */ diff --git a/vscode/src/local-context/rewrite-keyword-query.test.ts b/vscode/src/local-context/rewrite-keyword-query.test.ts index 099ab22236e5..71b294725e98 100644 --- a/vscode/src/local-context/rewrite-keyword-query.test.ts +++ b/vscode/src/local-context/rewrite-keyword-query.test.ts @@ -30,7 +30,7 @@ describe('rewrite-query', () => { mockResolvedConfig({ configuration: { customHeaders: {} }, auth: { - accessToken: + accessTokenOrHeaders: TESTING_CREDENTIALS.dotcom.token ?? TESTING_CREDENTIALS.dotcom.redactedToken, serverEndpoint: TESTING_CREDENTIALS.dotcom.serverEndpoint, }, diff --git a/vscode/src/main.ts b/vscode/src/main.ts index 5ba73e55b4e5..ffdbc7d799c2 100644 --- a/vscode/src/main.ts +++ b/vscode/src/main.ts @@ -666,8 +666,13 @@ async function registerTestCommands( } }), // Access token - this is only used in configuration tests - vscode.commands.registerCommand('cody.test.token', async (serverEndpoint, accessToken) => - authProvider.validateAndStoreCredentials({ serverEndpoint, accessToken }, 'always-store') + vscode.commands.registerCommand( + 'cody.test.token', + async (serverEndpoint, accessTokenOrHeaders) => + authProvider.validateAndStoreCredentials( + { serverEndpoint, accessTokenOrHeaders }, + 'always-store' + ) ) ) } diff --git a/vscode/src/notifications/setup-notification.ts b/vscode/src/notifications/setup-notification.ts index 0c0812aea15a..0b2deb2046c6 100644 --- a/vscode/src/notifications/setup-notification.ts +++ b/vscode/src/notifications/setup-notification.ts @@ -8,7 +8,7 @@ import { telemetryRecorder } from '@sourcegraph/cody-shared' import { showActionNotification } from '.' export const showSetupNotification = async (auth: AuthCredentials): Promise => { - if (auth.serverEndpoint && auth.accessToken) { + if (auth.serverEndpoint && auth.accessTokenOrHeaders) { // User has already attempted to configure Cody. // Regardless of if they are authenticated or not, we don't want to prompt them. return diff --git a/vscode/src/services/AuthProvider.test.ts b/vscode/src/services/AuthProvider.test.ts index 2ff7fd5a7ab0..fe24fb475a52 100644 --- a/vscode/src/services/AuthProvider.test.ts +++ b/vscode/src/services/AuthProvider.test.ts @@ -81,7 +81,7 @@ describe('AuthProvider', () => { const { values, clearValues } = readValuesFrom(authStatus) resolvedConfig.next({ configuration: {}, - auth: { serverEndpoint: 'https://example.com/', accessToken: 't' }, + auth: { serverEndpoint: 'https://example.com/', accessTokenOrHeaders: 't' }, clientState: { anonymousUserID: '123' }, } satisfies PartialDeep as ResolvedConfiguration) @@ -108,7 +108,7 @@ describe('AuthProvider', () => { validateCredentialsMock.mockReturnValue(asyncValue(authedAuthStatusBob, 10)) resolvedConfig.next({ configuration: {}, - auth: { serverEndpoint: 'https://other.example.com/', accessToken: 't2' }, + auth: { serverEndpoint: 'https://other.example.com/', accessTokenOrHeaders: 't2' }, clientState: { anonymousUserID: '123' }, } satisfies PartialDeep as ResolvedConfiguration) await vi.advanceTimersByTimeAsync(1) @@ -156,7 +156,7 @@ describe('AuthProvider', () => { const { values, clearValues } = readValuesFrom(authStatus) resolvedConfig.next({ configuration: {}, - auth: { serverEndpoint: 'https://example.com/', accessToken: 't' }, + auth: { serverEndpoint: 'https://example.com/', accessTokenOrHeaders: 't' }, clientState: { anonymousUserID: '123' }, } satisfies PartialDeep as ResolvedConfiguration) @@ -176,7 +176,7 @@ describe('AuthProvider', () => { const promise = authProvider.validateAndStoreCredentials( { configuration: {}, - auth: { serverEndpoint: 'https://other.example.com/', accessToken: 't2' }, + auth: { serverEndpoint: 'https://other.example.com/', accessTokenOrHeaders: 't2' }, clientState: { anonymousUserID: '123' }, }, 'always-store' @@ -212,7 +212,7 @@ describe('AuthProvider', () => { const { values, clearValues } = readValuesFrom(authStatus) resolvedConfig.next({ configuration: {}, - auth: { serverEndpoint: 'https://example.com/', accessToken: 't' }, + auth: { serverEndpoint: 'https://example.com/', accessTokenOrHeaders: 't' }, clientState: { anonymousUserID: '123' }, } satisfies PartialDeep as ResolvedConfiguration) diff --git a/vscode/src/services/LocalStorageProvider.ts b/vscode/src/services/LocalStorageProvider.ts index 70728275c3bb..1d027b1268fb 100644 --- a/vscode/src/services/LocalStorageProvider.ts +++ b/vscode/src/services/LocalStorageProvider.ts @@ -114,7 +114,7 @@ class LocalStorage implements LocalStorageForModelPreferences { * would give an inconsistent view of the state. */ public async saveEndpointAndToken( - credentials: Pick + credentials: Pick ): Promise { if (!credentials.serverEndpoint) { return @@ -129,10 +129,10 @@ class LocalStorage implements LocalStorageForModelPreferences { // Pass `false` to avoid firing the change event until we've stored all of the values. await this.set(this.LAST_USED_ENDPOINT, serverEndpoint, false) await this.addEndpointHistory(serverEndpoint, false) - if (credentials.accessToken) { + if (credentials.accessTokenOrHeaders && typeof credentials.accessTokenOrHeaders === 'string') { await secretStorage.storeToken( serverEndpoint, - credentials.accessToken, + credentials.accessTokenOrHeaders, credentials.tokenSource ) } diff --git a/vscode/src/services/UpstreamHealthProvider.ts b/vscode/src/services/UpstreamHealthProvider.ts index 8667dd2bc75c..a5fb5be15675 100644 --- a/vscode/src/services/UpstreamHealthProvider.ts +++ b/vscode/src/services/UpstreamHealthProvider.ts @@ -1,5 +1,6 @@ import { type BrowserOrNodeResponse, + addAuthHeaders, addCodyClientIdentificationHeaders, addTraceparent, currentResolvedConfig, @@ -94,11 +95,10 @@ class UpstreamHealthProvider implements vscode.Disposable { addTraceparent(sharedHeaders) addCodyClientIdentificationHeaders(sharedHeaders) - const upstreamHeaders = new Headers(sharedHeaders) - if (auth.accessToken) { - upstreamHeaders.set('Authorization', `token ${auth.accessToken}`) - } const url = new URL('/healthz', auth.serverEndpoint) + const upstreamHeaders = new Headers(sharedHeaders) + addAuthHeaders(auth, upstreamHeaders, url) + const upstreamResult = await wrapInActiveSpan('upstream-latency.upstream', span => { span.setAttribute('sampled', true) return measureLatencyToUri(upstreamHeaders, url.toString()) diff --git a/vscode/src/services/open-telemetry/CodyTraceExport.ts b/vscode/src/services/open-telemetry/CodyTraceExport.ts index b0e0d482daab..57380f09f040 100644 --- a/vscode/src/services/open-telemetry/CodyTraceExport.ts +++ b/vscode/src/services/open-telemetry/CodyTraceExport.ts @@ -1,6 +1,7 @@ import type { ExportResult } from '@opentelemetry/core' import { OTLPTraceExporter } from '@opentelemetry/exporter-trace-otlp-http' import type { ReadableSpan } from '@opentelemetry/sdk-trace-base' +import { type AuthCredentials, addAuthHeaders } from '@sourcegraph/cody-shared' const MAX_TRACE_RETAIN_MS = 60 * 1000 @@ -10,15 +11,16 @@ export class CodyTraceExporter extends OTLPTraceExporter { constructor({ traceUrl, - accessToken, + auth, isTracingEnabled, - }: { traceUrl: string; accessToken: string | null; isTracingEnabled: boolean }) { + }: { traceUrl: string; auth: AuthCredentials | null; isTracingEnabled: boolean }) { + const headers = new Headers() + if (auth) addAuthHeaders(auth, headers, new URL(traceUrl)) + super({ url: traceUrl, httpAgentOptions: { rejectUnauthorized: false }, - headers: { - ...(accessToken ? { Authorization: `token ${accessToken}` } : {}), - }, + headers: Object.fromEntries(headers.entries()), }) this.isTracingEnabled = isTracingEnabled } diff --git a/vscode/src/services/open-telemetry/OpenTelemetryService.node.ts b/vscode/src/services/open-telemetry/OpenTelemetryService.node.ts index 49ba2d4ac170..53c5920a5756 100644 --- a/vscode/src/services/open-telemetry/OpenTelemetryService.node.ts +++ b/vscode/src/services/open-telemetry/OpenTelemetryService.node.ts @@ -79,7 +79,7 @@ export class OpenTelemetryService { new CodyTraceExporter({ traceUrl, isTracingEnabled: this.isTracingEnabled, - accessToken: auth.accessToken, + auth, }) ) ) diff --git a/vscode/src/services/open-telemetry/trace-sender.ts b/vscode/src/services/open-telemetry/trace-sender.ts index d7e7bb410d82..fe89821f24ff 100644 --- a/vscode/src/services/open-telemetry/trace-sender.ts +++ b/vscode/src/services/open-telemetry/trace-sender.ts @@ -1,5 +1,4 @@ -import { currentResolvedConfig } from '@sourcegraph/cody-shared' -import fetch from 'node-fetch' +import { addAuthHeaders, currentResolvedConfig, fetch } from '@sourcegraph/cody-shared' import { logDebug, logError } from '../../output-channel-logger' /** @@ -22,18 +21,19 @@ export const TraceSender = { */ async function doSendTraceData(spanData: any): Promise { const { auth } = await currentResolvedConfig() - if (!auth.accessToken) { + if (!auth.accessTokenOrHeaders) { logError('TraceSender', 'Cannot send trace data: not authenticated') throw new Error('Not authenticated') } - const traceUrl = new URL('/-/debug/otlp/v1/traces', auth.serverEndpoint).toString() + const traceUrl = new URL('/-/debug/otlp/v1/traces', auth.serverEndpoint) + + const headers = new Headers({ 'Content-Type': 'application/json' }) + addAuthHeaders(auth, headers, traceUrl) + const response = await fetch(traceUrl, { method: 'POST', - headers: { - 'Content-Type': 'application/json', - ...(auth.accessToken ? { Authorization: `token ${auth.accessToken}` } : {}), - }, + headers: headers, body: spanData, }) diff --git a/vscode/src/testutils/mocks.ts b/vscode/src/testutils/mocks.ts index d22b053a60c4..a39b87bc55d2 100644 --- a/vscode/src/testutils/mocks.ts +++ b/vscode/src/testutils/mocks.ts @@ -934,4 +934,5 @@ export const DEFAULT_VSCODE_SETTINGS = { experimentalGuardrailsTimeoutSeconds: undefined, overrideAuthToken: undefined, overrideServerEndpoint: undefined, + authExternalProviders: undefined, } satisfies ClientConfiguration diff --git a/vscode/webviews/AppWrapperForTest.tsx b/vscode/webviews/AppWrapperForTest.tsx index 8cc967a8051b..2e404c5d79e2 100644 --- a/vscode/webviews/AppWrapperForTest.tsx +++ b/vscode/webviews/AppWrapperForTest.tsx @@ -113,7 +113,7 @@ export const AppWrapperForTest: FunctionComponent<{ children: ReactNode }> = ({ detectIntent: () => Observable.of(), resolvedConfig: () => Observable.of({ - auth: { accessToken: 'abc', serverEndpoint: 'https://example.com' }, + auth: { accessTokenOrHeaders: 'abc', serverEndpoint: 'https://example.com' }, configuration: { autocomplete: true, devModels: [{ model: 'my-model', provider: 'my-provider' }],