Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(credential-provider-ini): pass clientConfig to sso and sso-oidc inner clients #6688

Merged
merged 2 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion packages/credential-provider-ini/src/fromIni.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ export interface FromIniInit extends SourceProfileInit, CredentialProviderOption
roleAssumerWithWebIdentity?: (params: AssumeRoleWithWebIdentityParams) => Promise<AwsCredentialIdentity>;

/**
* STSClientConfig to be used for creating STS Client for assuming role.
* STSClientConfig or SSOClientConfig to be used for creating inner client
* for auth operations.
* @internal
*/
clientConfig?: any;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,39 @@ describe(resolveSsoCredentials.name, () => {
profile: mockProfileName,
});
});

it("passes through clientConfig and parentClientConfig to the fromSSO provider", async () => {
const mockProfileName = "mockProfileName";
const mockCreds: AwsCredentialIdentity = {
accessKeyId: "mockAccessKeyId",
secretAccessKey: "mockSecretAccessKey",
};
const requestHandler = vi.fn();
const logger = vi.fn();

vi.mocked(fromSSO).mockReturnValue(() => Promise.resolve(mockCreds));

const receivedCreds = await resolveSsoCredentials(
mockProfileName,
{},
{
clientConfig: {
requestHandler,
},
parentClientConfig: {
logger,
},
}
);
expect(receivedCreds).toStrictEqual(mockCreds);
expect(fromSSO).toHaveBeenCalledWith({
profile: mockProfileName,
clientConfig: {
requestHandler,
},
parentClientConfig: {
logger,
},
});
});
});
11 changes: 5 additions & 6 deletions packages/credential-provider-ini/src/resolveSsoCredentials.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
import { setCredentialFeature } from "@aws-sdk/core/client";
import type { SsoProfile } from "@aws-sdk/credential-provider-sso";
import type { CredentialProviderOptions } from "@aws-sdk/types";
import type { IniSection, Profile } from "@smithy/types";

import type { FromIniInit } from "./fromIni";

/**
* @internal
*/
export const resolveSsoCredentials = async (
profile: string,
profileData: IniSection,
options: CredentialProviderOptions = {}
) => {
export const resolveSsoCredentials = async (profile: string, profileData: IniSection, options: FromIniInit = {}) => {
const { fromSSO } = await import("@aws-sdk/credential-provider-sso");
return fromSSO({
profile,
logger: options.logger,
parentClientConfig: options.parentClientConfig,
clientConfig: options.clientConfig,
})().then((creds) => {
if (profileData.sso_session) {
return setCredentialFeature(creds, "CREDENTIALS_PROFILE_SSO", "r");
Expand Down
2 changes: 2 additions & 0 deletions packages/credential-provider-sso/src/fromSSO.ts
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ export const fromSSO =
ssoRoleName: sso_role_name,
ssoClient: ssoClient,
clientConfig: init.clientConfig,
parentClientConfig: init.parentClientConfig,
profile: profileName,
});
} else if (!ssoStartUrl || !ssoAccountId || !ssoRegion || !ssoRoleName) {
Expand All @@ -150,6 +151,7 @@ export const fromSSO =
ssoRoleName,
ssoClient,
clientConfig: init.clientConfig,
parentClientConfig: init.parentClientConfig,
profile: profileName,
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export const resolveSSOCredentials = async ({
ssoRoleName,
ssoClient,
clientConfig,
parentClientConfig,
profile,
logger,
}: FromSSOInit & SsoCredentialsParameters): Promise<AwsCredentialIdentity> => {
Expand Down Expand Up @@ -65,6 +66,7 @@ export const resolveSSOCredentials = async ({
ssoClient ||
new SSOClient(
Object.assign({}, clientConfig ?? {}, {
logger: clientConfig?.logger ?? parentClientConfig?.logger,
region: clientConfig?.region ?? ssoRegion,
})
);
Expand Down
15 changes: 10 additions & 5 deletions packages/token-providers/src/fromSso.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ describe(fromSso.name, () => {
accessToken: "mockNewAccessToken",
expiresIn: 3600,
refreshToken: "mockNewRefreshToken",
$metadata: {},
};
const mockNewToken = {
token: mockNewTokenFromService.accessToken,
Expand Down Expand Up @@ -166,7 +167,7 @@ describe(fromSso.name, () => {
const { fromSso } = await import("./fromSso");
await expect(fromSso(mockInit)()).resolves.toStrictEqual(mockNewToken);
expect(getNewSsoOidcToken).toHaveBeenCalledTimes(1);
expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region);
expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region, mockInit);

// Simulate token expiration.
const ssoTokenExpiryError = new TokenProviderError(`SSO Token is expired. ${REFRESH_MESSAGE}`, false);
Expand All @@ -182,7 +183,7 @@ describe(fromSso.name, () => {
const { fromSso } = await import("./fromSso");
await expect(fromSso(mockInit)()).resolves.toStrictEqual(mockNewToken);
expect(getNewSsoOidcToken).toHaveBeenCalledTimes(1);
expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region);
expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region, mockInit);

// Return a valid token for second call.
const mockValidSsoToken = {
Expand Down Expand Up @@ -230,7 +231,11 @@ describe(fromSso.name, () => {
token: mockValidSsoTokenInExpiryWindow.accessToken,
expiration: new Date(mockValidSsoTokenInExpiryWindow.expiresAt),
});
expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockValidSsoTokenInExpiryWindow, mockSsoSession.sso_region);
expect(getNewSsoOidcToken).toHaveBeenCalledWith(
mockValidSsoTokenInExpiryWindow,
mockSsoSession.sso_region,
mockInit
);
};

const throwErrorExpiredTokenTest = async (fromSsoImpl: typeof fromSso) => {
Expand All @@ -239,7 +244,7 @@ describe(fromSso.name, () => {
throw ssoTokenExpiryError;
});
await expect(fromSsoImpl(mockInit)()).rejects.toStrictEqual(ssoTokenExpiryError);
expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region);
expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region, mockInit);
};

afterEach(() => {
Expand Down Expand Up @@ -285,7 +290,7 @@ describe(fromSso.name, () => {
const { fromSso } = await import("./fromSso");
await expect(fromSso(mockInit)()).resolves.toStrictEqual(mockNewToken);
expect(getNewSsoOidcToken).toHaveBeenCalledTimes(1);
expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region);
expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region, mockInit);

expect(writeSSOTokenToFile).toHaveBeenCalledWith(mockSsoSessionName, {
...mockSsoToken,
Expand Down
9 changes: 7 additions & 2 deletions packages/token-providers/src/fromSso.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@ import { writeSSOTokenToFile } from "./writeSSOTokenToFile";
*/
const lastRefreshAttemptTime = new Date(0);

export interface FromSsoInit extends SourceProfileInit, CredentialProviderOptions {}
export interface FromSsoInit extends SourceProfileInit, CredentialProviderOptions {
/**
* @see SSOOIDCClientConfig in \@aws-sdk/client-sso-oidc.
*/
clientConfig?: any;
}

/**
* Creates a token provider that will read from SSO token cache or ssoOidc.createToken() call.
Expand Down Expand Up @@ -101,7 +106,7 @@ export const fromSso =

try {
lastRefreshAttemptTime.setTime(Date.now());
const newSsoOidcToken = await getNewSsoOidcToken(ssoToken, ssoRegion);
const newSsoOidcToken = await getNewSsoOidcToken(ssoToken, ssoRegion, init);
validateTokenKey("accessToken", newSsoOidcToken.accessToken);
validateTokenKey("expiresIn", newSsoOidcToken.expiresIn);
const newTokenExpiration = new Date(Date.now() + newSsoOidcToken.expiresIn! * 1000);
Expand Down
8 changes: 4 additions & 4 deletions packages/token-providers/src/getNewSsoOidcToken.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ describe(getNewSsoOidcToken.name, () => {
} catch (error) {
expect(error).toStrictEqual(mockError);
}
expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion);
expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion, {});
expect(mockSend).not.toHaveBeenCalled();
expect(CreateTokenCommand).not.toHaveBeenCalled();
});
Expand All @@ -63,7 +63,7 @@ describe(getNewSsoOidcToken.name, () => {
} catch (error) {
expect(error).toStrictEqual(mockError);
}
expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion);
expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion, {});
expect(mockSendWithError).toHaveBeenCalledWith(mockCreateTokenArgs);
expect(CreateTokenCommand).toHaveBeenCalledWith(mockCreateTokenArgs);
});
Expand All @@ -78,7 +78,7 @@ describe(getNewSsoOidcToken.name, () => {
} catch (error) {
expect(error).toStrictEqual(mockError);
}
expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion);
expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion, {});
expect(mockSend).not.toHaveBeenCalled();
expect(CreateTokenCommand).toHaveBeenCalledWith(mockCreateTokenArgs);
});
Expand All @@ -90,6 +90,6 @@ describe(getNewSsoOidcToken.name, () => {
expect(newSsoOidcToken).toEqual(mockNewToken as any);
expect(CreateTokenCommand).toHaveBeenCalledWith(mockCreateTokenArgs);
expect(mockSend).toHaveBeenCalledWith(mockCreateTokenArgs);
expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion);
expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion, {});
});
});
5 changes: 3 additions & 2 deletions packages/token-providers/src/getNewSsoOidcToken.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import { SSOToken } from "@smithy/shared-ini-file-loader";

import { FromSsoInit } from "./fromSso";
import { getSsoOidcClient } from "./getSsoOidcClient";

/**
* Returns a new SSO OIDC token from ssoOids.createToken() API call.
* @internal
*/
export const getNewSsoOidcToken = async (ssoToken: SSOToken, ssoRegion: string) => {
export const getNewSsoOidcToken = async (ssoToken: SSOToken, ssoRegion: string, init: FromSsoInit = {}) => {
// @ts-ignore Cannot find module '@aws-sdk/client-sso-oidc'
const { CreateTokenCommand } = await import("@aws-sdk/client-sso-oidc");

const ssoOidcClient = await getSsoOidcClient(ssoRegion);
const ssoOidcClient = await getSsoOidcClient(ssoRegion, init);
return ssoOidcClient.send(
new CreateTokenCommand({
clientId: ssoToken.clientId,
Expand Down
33 changes: 17 additions & 16 deletions packages/token-providers/src/getSsoOidcClient.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ vi.mock("@aws-sdk/client-sso-oidc");

describe("getSsoOidcClient", () => {
const mockSsoRegion = "mockSsoRegion";
const mockRequestHandler = {
protocol: "http",
};
const getMockClient = (region: string) => ({ region });

beforeEach(() => {
Expand All @@ -22,24 +25,22 @@ describe("getSsoOidcClient", () => {
expect(SSOOIDCClient).toHaveBeenCalledTimes(1);
});

it("returns SSOOIDC client from hash if already created", async () => {
const { getSsoOidcClient } = await import("./getSsoOidcClient");
expect(await getSsoOidcClient(mockSsoRegion)).toEqual(getMockClient(mockSsoRegion) as any);
expect(SSOOIDCClient).toHaveBeenCalledTimes(1);
expect(await getSsoOidcClient(mockSsoRegion)).toEqual(getMockClient(mockSsoRegion) as any);
expect(SSOOIDCClient).toHaveBeenCalledTimes(1);
});

it("creates new SSOOIDC client per region", async () => {
it("passes through clientConfig and parentClientConfig.logger", async () => {
const { getSsoOidcClient } = await import("./getSsoOidcClient");
const mockSsoRegion1 = `${mockSsoRegion}1`;
expect(await getSsoOidcClient(mockSsoRegion1)).toEqual(getMockClient(mockSsoRegion1) as any);
expect(
await getSsoOidcClient(mockSsoRegion1, {
clientConfig: { requestHandler: mockRequestHandler },
parentClientConfig: { logger: console },
})
).toEqual({
region: mockSsoRegion1,
} as any);
expect(SSOOIDCClient).toHaveBeenCalledTimes(1);
expect(SSOOIDCClient).toHaveBeenCalledWith({ region: mockSsoRegion1 });

const mockSsoRegion2 = `${mockSsoRegion}2`;
expect(await getSsoOidcClient(mockSsoRegion2)).toEqual(getMockClient(mockSsoRegion2) as any);
expect(SSOOIDCClient).toHaveBeenCalledTimes(2);
expect(SSOOIDCClient).toHaveBeenNthCalledWith(2, { region: mockSsoRegion2 });
expect(SSOOIDCClient).toHaveBeenCalledWith({
region: mockSsoRegion1,
requestHandler: mockRequestHandler,
logger: console,
});
});
});
23 changes: 9 additions & 14 deletions packages/token-providers/src/getSsoOidcClient.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
const ssoOidcClientsHash: Record<string, any> = {};
import { FromSsoInit } from "./fromSso";

/**
* Returns a SSOOIDC client for the given region. If the client has already been created,
* it will be returned from the hash.
* Returns a SSOOIDC client for the given region.
* @internal
*/
export const getSsoOidcClient = async (ssoRegion: string) => {
export const getSsoOidcClient = async (ssoRegion: string, init: FromSsoInit = {}) => {
// @ts-ignore Cannot find module '@aws-sdk/client-sso-oidc'
const { SSOOIDCClient } = await import("@aws-sdk/client-sso-oidc");

// return ssoOidsClient if already created.
if (ssoOidcClientsHash[ssoRegion]) {
return ssoOidcClientsHash[ssoRegion];
}

// Create new SSOOIDC client, and store is in hash.
// If we need to support configuration of SsoOidc client in future through code,
// the provision to pass region from client configuration needs to be added.
const ssoOidcClient = new SSOOIDCClient({ region: ssoRegion });
ssoOidcClientsHash[ssoRegion] = ssoOidcClient;
const ssoOidcClient = new SSOOIDCClient(
Object.assign({}, init.clientConfig ?? {}, {
region: ssoRegion ?? init.clientConfig?.region,
logger: init.clientConfig?.logger ?? init.parentClientConfig?.logger,
})
);
return ssoOidcClient;
};
Loading