Skip to content

Commit

Permalink
Add support for custom auth providers
Browse files Browse the repository at this point in the history
  • Loading branch information
pkukielka committed Jan 7, 2025
1 parent 4d04c9a commit 1b034fd
Show file tree
Hide file tree
Showing 40 changed files with 413 additions and 141 deletions.
68 changes: 68 additions & 0 deletions agent/scripts/reverse-proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from aiohttp import web, ClientSession
from urllib.parse import urlparse
import asyncio

target_url = ''
port = 5050

async def proxy_handler(request):
async with ClientSession(auto_decompress=False) as session:
print(f'Request to: {request.url}')

# Modify headers here
headers = dict(request.headers)

# Reset the Host header to use target server host instead of the proxy host
if 'Host' in headers:
headers['Host'] = urlparse(target_url).netloc.split(':')[0]

# 'chunked' encoding results in error 400 from Cloudflare, removing it still keeps response chunked anyway
if 'Transfer-Encoding' in headers:
del headers['Transfer-Encoding']

# Use value of 'Authorization: Bearer' to fill 'X-Forwarded-User' and remove 'Authorization' header
if 'Authorization' in headers:
values = headers['Authorization'].split()
if values and values[0] == 'Bearer':
headers['X-Forwarded-User'] = values[1]
del headers['Authorization']

# Forward the request to target
async with session.request(
method=request.method,
url=f'{target_url}{request.path_qs}',
headers=headers,
data=await request.read()
) as response:
proxy_response = web.StreamResponse(
status=response.status,
headers=response.headers
)

await proxy_response.prepare(request)

# Stream the response back
async for chunk in response.content.iter_chunks():
await proxy_response.write(chunk[0])

await proxy_response.write_eof()
return proxy_response

app = web.Application()
app.router.add_route('*', '/{path_info:.*}', proxy_handler)


if __name__ == '__main__':
print('Usage: python reverse_proxy.py [target_url] [proxy_port]')

import sys
if (len(sys.argv) < 2):
print('Please specify target_url')
sys.exit(1)
if len(sys.argv) > 1:
target_url = sys.argv[1]
if len(sys.argv) > 2:
port = int(sys.argv[2])

print(f'Starting proxy server on port {port} targeting {target_url}...')
web.run_app(app, port=port)
7 changes: 3 additions & 4 deletions agent/src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1486,11 +1486,10 @@ export class Agent extends MessageHandler implements ExtensionClient {
config: ExtensionConfiguration,
params?: { forceAuthentication: boolean }
): Promise<AuthStatus> {
const isAuthChange = vscode_shim.isAuthenticationChange(config)
const isAuthChange = vscode_shim.isTokenOrEndpointChange(config)
vscode_shim.setExtensionConfiguration(config)

// If this is an authentication change we need to reauthenticate prior to firing events
// that update the clients
// If this is an token or endpoint change we need to save them prior to firing events that update the clients
try {
if (isAuthChange || params?.forceAuthentication) {
await authProvider.validateAndStoreCredentials(
Expand All @@ -1500,7 +1499,7 @@ export class Agent extends MessageHandler implements ExtensionClient {
},
auth: {
serverEndpoint: config.serverEndpoint,
accessToken: config.accessToken ?? null,
accessTokenOrHeaders: config.accessToken || null,
},
clientState: {
anonymousUserID: config.anonymousUserID ?? null,
Expand Down
2 changes: 1 addition & 1 deletion agent/src/cli/command-auth/AuthenticatedAccount.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ export class AuthenticatedAccount {
): Promise<AuthenticatedAccount | Error> {
const graphqlClient = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { telemetryLevel: 'agent' },
auth: { accessToken: options.accessToken, serverEndpoint: options.endpoint },
auth: { accessTokenOrHeaders: options.accessToken, serverEndpoint: options.endpoint },
clientState: { anonymousUserID: null },
})
const userInfo = await graphqlClient.getCurrentUserInfo()
Expand Down
4 changes: 2 additions & 2 deletions agent/src/cli/command-auth/command-login.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async function loginAction(
: await captureAccessTokenViaBrowserRedirect(serverEndpoint, spinner)
const client = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { telemetryLevel: 'agent' },
auth: { accessToken: token, serverEndpoint: serverEndpoint },
auth: { accessTokenOrHeaders: token, serverEndpoint: serverEndpoint },
clientState: { anonymousUserID: null },
})
const userInfo = await client.getCurrentUserInfo()
Expand Down Expand Up @@ -256,7 +256,7 @@ async function promptUserAboutLoginMethod(spinner: Ora, options: LoginOptions):
try {
const client = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { telemetryLevel: 'agent' },
auth: { accessToken: options.accessToken, serverEndpoint: options.endpoint },
auth: { accessTokenOrHeaders: options.accessToken, serverEndpoint: options.endpoint },
clientState: { anonymousUserID: null },
})
const userInfo = await client.getCurrentUserInfo()
Expand Down
2 changes: 1 addition & 1 deletion agent/src/cli/command-bench/command-bench.ts
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ export const benchCommand = new commander.Command('bench')
setStaticResolvedConfigurationWithAuthCredentials({
configuration: { customHeaders: {} },
auth: {
accessToken: options.srcAccessToken,
accessTokenOrHeaders: options.srcAccessToken,
serverEndpoint: options.srcEndpoint,
},
})
Expand Down
2 changes: 1 addition & 1 deletion agent/src/cli/command-bench/llm-judge.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export class LlmJudge {
localStorage.setStorage('noop')
setStaticResolvedConfigurationWithAuthCredentials({
configuration: { customHeaders: undefined },
auth: { accessToken: options.srcAccessToken, serverEndpoint: options.srcEndpoint },
auth: { accessTokenOrHeaders: options.srcAccessToken, serverEndpoint: options.srcEndpoint },
})
setClientCapabilities({ configuration: {}, agentCapabilities: undefined })
this.client = new SourcegraphNodeCompletionsClient()
Expand Down
5 changes: 4 additions & 1 deletion agent/src/local-e2e/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ export class LocalSGInstance {
// for checking the LLM configuration section.
this.gqlclient = SourcegraphGraphQLAPIClient.withStaticConfig({
configuration: { customHeaders: headers, telemetryLevel: 'agent' },
auth: { accessToken: this.params.accessToken, serverEndpoint: this.params.serverEndpoint },
auth: {
accessTokenOrHeaders: this.params.accessToken,
serverEndpoint: this.params.serverEndpoint,
},
clientState: { anonymousUserID: null },
})
}
Expand Down
3 changes: 2 additions & 1 deletion agent/src/vscode-shim.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ export let extensionConfiguration: ExtensionConfiguration | undefined
export function setExtensionConfiguration(newConfig: ExtensionConfiguration): void {
extensionConfiguration = newConfig
}
export function isAuthenticationChange(newConfig: ExtensionConfiguration): boolean {

export function isTokenOrEndpointChange(newConfig: ExtensionConfiguration): boolean {
if (!extensionConfiguration) {
return true
}
Expand Down
23 changes: 21 additions & 2 deletions lib/shared/src/configuration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@ import type { ReadonlyDeep } from './utils'
* A redirect flow is initiated by the user clicking a link in the browser, while a paste flow is initiated by the user
* manually entering the access from into the VsCode App.
*/
export type TokenSource = 'redirect' | 'paste'
export type TokenSource = 'redirect' | 'paste' | 'custom-auth-provider'

export type AuthHeaders = Record<string, string>

/**
* The user's authentication credentials, which are stored separately from the rest of the
* configuration.
*/
export interface AuthCredentials {
serverEndpoint: string
accessToken: string | null
tokenSource?: TokenSource | undefined
accessTokenOrHeaders: string | AuthHeaders | null
}

export interface AutoEditsTokenLimit {
Expand Down Expand Up @@ -71,6 +73,20 @@ export interface AgenticContextConfiguration {
}
}

export interface ExternalAuthCommand {
commandLine: string[]
environment?: Record<string, string>
workingDir?: string
shell?: string
timeout?: number
windowsHide?: boolean
}

export interface ExternalAuthProvider {
endpoint: string
executable: ExternalAuthCommand
}

interface RawClientConfiguration {
net: NetConfiguration
codebase?: string
Expand Down Expand Up @@ -166,6 +182,9 @@ interface RawClientConfiguration {
*/
overrideServerEndpoint?: string | undefined
overrideAuthToken?: string | undefined

// External auth providers
authExternalProviders?: ExternalAuthProvider[]
}

/**
Expand Down
123 changes: 100 additions & 23 deletions lib/shared/src/configuration/resolver.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import { Observable, map } from 'observable-fns'
import type { AuthCredentials, ClientConfiguration } from '../configuration'
import type {
AuthCredentials,
ClientConfiguration,
ExternalAuthCommand,
TokenSource,
} from '../configuration'
import { logError } from '../logger'
import {
distinctUntilChanged,
Expand Down Expand Up @@ -27,6 +32,7 @@ export interface ConfigurationInput {

export interface ClientSecrets {
getToken(endpoint: string): Promise<string | undefined>
getTokenSource(endpoint: string): Promise<TokenSource | undefined>
}

export interface ClientState {
Expand Down Expand Up @@ -72,39 +78,110 @@ export type PickResolvedConfiguration<Keys extends KeysSpec> = {
: undefined
}

async function executeCommand(cmd: ExternalAuthCommand): Promise<string> {
if (typeof process === 'undefined' || !process.version) {
throw new Error('Command execution is only supported in Node.js environments')
}

const { exec } = await import('node:child_process')
const { promisify } = await import('node:util')
const execAsync = promisify(exec)

const command = cmd.commandLine.join(' ')
const options = {
...cmd,
env: cmd.environment ? { ...process.env, ...cmd.environment } : process.env,
}

const { stdout, stderr } = await execAsync(command, options)
if (stderr) {
throw new Error(`External auth command error: ${stderr}`)
}
return stdout.trim()
}

async function getExternalProviderAuthHeaders(
serverEndpoint: string,
clientConfiguration: ClientConfiguration
): Promise<Record<string, string> | undefined> {
// Check for external auth provider for this endpoint
const externalProvider = clientConfiguration.authExternalProviders?.find(
provider => normalizeServerEndpointURL(provider.endpoint) === serverEndpoint
)

if (externalProvider) {
const result = await executeCommand(externalProvider.executable)
return JSON.parse(result)
}

return undefined
}

export async function resolveAuth(
endpoint: string,
clientConfiguration: ClientConfiguration,
clientSecrets: ClientSecrets
): Promise<AuthCredentials> {
let accessTokenOrHeaders = null
let tokenSource: TokenSource | undefined = undefined

const serverEndpoint = normalizeServerEndpointURL(
clientConfiguration.overrideServerEndpoint || endpoint
)

// We must not throw here, because that would result in the `resolvedConfig` observable
// terminating and all callers receiving no further config updates.
const loadTokenFn = () =>
clientSecrets.getToken(serverEndpoint).catch(error => {
throw new Error(
`Failed to get access token for endpoint ${serverEndpoint}: ${error.message || error}`
)
})

if (clientConfiguration.overrideAuthToken) {
accessTokenOrHeaders = clientConfiguration.overrideAuthToken
} else
try {
const authHeaders = await getExternalProviderAuthHeaders(serverEndpoint, clientConfiguration)
if (authHeaders) {
accessTokenOrHeaders = authHeaders
tokenSource = 'custom-auth-provider'
} else {
accessTokenOrHeaders = (await loadTokenFn()) || null
tokenSource = await clientSecrets.getTokenSource(serverEndpoint).catch(_ => undefined)
}
} catch (error) {
throw new Error(`Failed to execute external auth command: ${error}`)
}

return { accessTokenOrHeaders, serverEndpoint, tokenSource }
}

async function resolveConfiguration({
clientConfiguration,
clientSecrets,
clientState,
reinstall: { isReinstalling, onReinstall },
}: ConfigurationInput): Promise<ResolvedConfiguration> {
}: ConfigurationInput): Promise<ResolvedConfiguration | Error> {
const isReinstall = await isReinstalling()
if (isReinstall) {
await onReinstall()
}
// we allow for overriding the server endpoint from config if we haven't
// manually signed in somewhere else
const serverEndpoint = normalizeServerEndpointURL(

const serverEndpoint =
clientConfiguration.overrideServerEndpoint ||
(clientState.lastUsedEndpoint ?? DOTCOM_URL.toString())
)
clientState.lastUsedEndpoint ||
DOTCOM_URL.toString()

// We must not throw here, because that would result in the `resolvedConfig` observable
// terminating and all callers receiving no further config updates.
const loadTokenFn = () =>
clientSecrets.getToken(serverEndpoint).catch(error => {
logError(
'resolveConfiguration',
`Failed to get access token for endpoint ${serverEndpoint}: ${error}`
)
return null
})
const accessToken = clientConfiguration.overrideAuthToken || ((await loadTokenFn()) ?? null)
return {
configuration: clientConfiguration,
clientState,
auth: { accessToken, serverEndpoint },
isReinstall,
try {
const auth = await resolveAuth(serverEndpoint, clientConfiguration, clientSecrets)
return { configuration: clientConfiguration, clientState, auth, isReinstall }
} catch (error) {
// We don't want to throw here, because that would cause the observable to terminate and
// all callers receiving no further config updates.
logError('resolveConfiguration', `Error resolving configuration: ${error}`)
const auth = { accessTokenOrHeaders: null, serverEndpoint, tokenSource: undefined }
return { configuration: clientConfiguration, clientState, auth, isReinstall }
}
}

Expand Down
2 changes: 1 addition & 1 deletion lib/shared/src/experimentation/FeatureFlagProvider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ describe('FeatureFlagProvider', () => {
beforeAll(() => {
vi.useFakeTimers()
mockResolvedConfig({
auth: { accessToken: null, serverEndpoint: 'https://example.com' },
auth: { accessTokenOrHeaders: null, serverEndpoint: 'https://example.com' },
})
mockAuthStatus(AUTH_STATUS_FIXTURE_AUTHED)
})
Expand Down
6 changes: 1 addition & 5 deletions lib/shared/src/models/sync.ts
Original file line number Diff line number Diff line change
Expand Up @@ -464,11 +464,7 @@ async function fetchServerSideModels(
): Promise<ServerModelConfiguration | undefined> {
// Fetch the data via REST API.
// NOTE: We may end up exposing this data via GraphQL, it's still TBD.
const client = new RestClient(
config.auth.serverEndpoint,
config.auth.accessToken ?? undefined,
config.configuration.customHeaders
)
const client = new RestClient(config.auth, config.configuration.customHeaders)
return await client.getAvailableModels(signal)
}

Expand Down
Loading

0 comments on commit 1b034fd

Please sign in to comment.