Skip to content

Commit

Permalink
fix(credential-provider-ini): pass clientConfig to sso and sso-oidc i…
Browse files Browse the repository at this point in the history
…nner clients (#6688)

* fix(credential-provider-ini): pass clientConfig to sso and sso-oidc inner clients

* fix: undefined check
  • Loading branch information
kuhe authored Nov 22, 2024
1 parent 0982bc4 commit 0ca3da3
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 50 deletions.
3 changes: 2 additions & 1 deletion packages/credential-provider-ini/src/fromIni.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ export interface FromIniInit extends SourceProfileInit, CredentialProviderOption
roleAssumerWithWebIdentity?: (params: AssumeRoleWithWebIdentityParams) => Promise<AwsCredentialIdentity>;

/**
* STSClientConfig to be used for creating STS Client for assuming role.
* STSClientConfig or SSOClientConfig to be used for creating inner client
* for auth operations.
* @internal
*/
clientConfig?: any;
Expand Down
35 changes: 35 additions & 0 deletions packages/credential-provider-ini/src/resolveSsoCredentials.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,39 @@ describe(resolveSsoCredentials.name, () => {
profile: mockProfileName,
});
});

it("passes through clientConfig and parentClientConfig to the fromSSO provider", async () => {
const mockProfileName = "mockProfileName";
const mockCreds: AwsCredentialIdentity = {
accessKeyId: "mockAccessKeyId",
secretAccessKey: "mockSecretAccessKey",
};
const requestHandler = vi.fn();
const logger = vi.fn();

vi.mocked(fromSSO).mockReturnValue(() => Promise.resolve(mockCreds));

const receivedCreds = await resolveSsoCredentials(
mockProfileName,
{},
{
clientConfig: {
requestHandler,
},
parentClientConfig: {
logger,
},
}
);
expect(receivedCreds).toStrictEqual(mockCreds);
expect(fromSSO).toHaveBeenCalledWith({
profile: mockProfileName,
clientConfig: {
requestHandler,
},
parentClientConfig: {
logger,
},
});
});
});
11 changes: 5 additions & 6 deletions packages/credential-provider-ini/src/resolveSsoCredentials.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
import { setCredentialFeature } from "@aws-sdk/core/client";
import type { SsoProfile } from "@aws-sdk/credential-provider-sso";
import type { CredentialProviderOptions } from "@aws-sdk/types";
import type { IniSection, Profile } from "@smithy/types";

import type { FromIniInit } from "./fromIni";

/**
* @internal
*/
export const resolveSsoCredentials = async (
profile: string,
profileData: IniSection,
options: CredentialProviderOptions = {}
) => {
export const resolveSsoCredentials = async (profile: string, profileData: IniSection, options: FromIniInit = {}) => {
const { fromSSO } = await import("@aws-sdk/credential-provider-sso");
return fromSSO({
profile,
logger: options.logger,
parentClientConfig: options.parentClientConfig,
clientConfig: options.clientConfig,
})().then((creds) => {
if (profileData.sso_session) {
return setCredentialFeature(creds, "CREDENTIALS_PROFILE_SSO", "r");
Expand Down
2 changes: 2 additions & 0 deletions packages/credential-provider-sso/src/fromSSO.ts
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ export const fromSSO =
ssoRoleName: sso_role_name,
ssoClient: ssoClient,
clientConfig: init.clientConfig,
parentClientConfig: init.parentClientConfig,
profile: profileName,
});
} else if (!ssoStartUrl || !ssoAccountId || !ssoRegion || !ssoRoleName) {
Expand All @@ -150,6 +151,7 @@ export const fromSSO =
ssoRoleName,
ssoClient,
clientConfig: init.clientConfig,
parentClientConfig: init.parentClientConfig,
profile: profileName,
});
}
Expand Down
2 changes: 2 additions & 0 deletions packages/credential-provider-sso/src/resolveSSOCredentials.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export const resolveSSOCredentials = async ({
ssoRoleName,
ssoClient,
clientConfig,
parentClientConfig,
profile,
logger,
}: FromSSOInit & SsoCredentialsParameters): Promise<AwsCredentialIdentity> => {
Expand Down Expand Up @@ -65,6 +66,7 @@ export const resolveSSOCredentials = async ({
ssoClient ||
new SSOClient(
Object.assign({}, clientConfig ?? {}, {
logger: clientConfig?.logger ?? parentClientConfig?.logger,
region: clientConfig?.region ?? ssoRegion,
})
);
Expand Down
15 changes: 10 additions & 5 deletions packages/token-providers/src/fromSso.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ describe(fromSso.name, () => {
accessToken: "mockNewAccessToken",
expiresIn: 3600,
refreshToken: "mockNewRefreshToken",
$metadata: {},
};
const mockNewToken = {
token: mockNewTokenFromService.accessToken,
Expand Down Expand Up @@ -166,7 +167,7 @@ describe(fromSso.name, () => {
const { fromSso } = await import("./fromSso");
await expect(fromSso(mockInit)()).resolves.toStrictEqual(mockNewToken);
expect(getNewSsoOidcToken).toHaveBeenCalledTimes(1);
expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region);
expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region, mockInit);

// Simulate token expiration.
const ssoTokenExpiryError = new TokenProviderError(`SSO Token is expired. ${REFRESH_MESSAGE}`, false);
Expand All @@ -182,7 +183,7 @@ describe(fromSso.name, () => {
const { fromSso } = await import("./fromSso");
await expect(fromSso(mockInit)()).resolves.toStrictEqual(mockNewToken);
expect(getNewSsoOidcToken).toHaveBeenCalledTimes(1);
expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region);
expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region, mockInit);

// Return a valid token for second call.
const mockValidSsoToken = {
Expand Down Expand Up @@ -230,7 +231,11 @@ describe(fromSso.name, () => {
token: mockValidSsoTokenInExpiryWindow.accessToken,
expiration: new Date(mockValidSsoTokenInExpiryWindow.expiresAt),
});
expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockValidSsoTokenInExpiryWindow, mockSsoSession.sso_region);
expect(getNewSsoOidcToken).toHaveBeenCalledWith(
mockValidSsoTokenInExpiryWindow,
mockSsoSession.sso_region,
mockInit
);
};

const throwErrorExpiredTokenTest = async (fromSsoImpl: typeof fromSso) => {
Expand All @@ -239,7 +244,7 @@ describe(fromSso.name, () => {
throw ssoTokenExpiryError;
});
await expect(fromSsoImpl(mockInit)()).rejects.toStrictEqual(ssoTokenExpiryError);
expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region);
expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region, mockInit);
};

afterEach(() => {
Expand Down Expand Up @@ -285,7 +290,7 @@ describe(fromSso.name, () => {
const { fromSso } = await import("./fromSso");
await expect(fromSso(mockInit)()).resolves.toStrictEqual(mockNewToken);
expect(getNewSsoOidcToken).toHaveBeenCalledTimes(1);
expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region);
expect(getNewSsoOidcToken).toHaveBeenCalledWith(mockSsoToken, mockSsoSession.sso_region, mockInit);

expect(writeSSOTokenToFile).toHaveBeenCalledWith(mockSsoSessionName, {
...mockSsoToken,
Expand Down
9 changes: 7 additions & 2 deletions packages/token-providers/src/fromSso.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@ import { writeSSOTokenToFile } from "./writeSSOTokenToFile";
*/
const lastRefreshAttemptTime = new Date(0);

export interface FromSsoInit extends SourceProfileInit, CredentialProviderOptions {}
export interface FromSsoInit extends SourceProfileInit, CredentialProviderOptions {
/**
* @see SSOOIDCClientConfig in \@aws-sdk/client-sso-oidc.
*/
clientConfig?: any;
}

/**
* Creates a token provider that will read from SSO token cache or ssoOidc.createToken() call.
Expand Down Expand Up @@ -101,7 +106,7 @@ export const fromSso =

try {
lastRefreshAttemptTime.setTime(Date.now());
const newSsoOidcToken = await getNewSsoOidcToken(ssoToken, ssoRegion);
const newSsoOidcToken = await getNewSsoOidcToken(ssoToken, ssoRegion, init);
validateTokenKey("accessToken", newSsoOidcToken.accessToken);
validateTokenKey("expiresIn", newSsoOidcToken.expiresIn);
const newTokenExpiration = new Date(Date.now() + newSsoOidcToken.expiresIn! * 1000);
Expand Down
8 changes: 4 additions & 4 deletions packages/token-providers/src/getNewSsoOidcToken.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ describe(getNewSsoOidcToken.name, () => {
} catch (error) {
expect(error).toStrictEqual(mockError);
}
expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion);
expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion, {});
expect(mockSend).not.toHaveBeenCalled();
expect(CreateTokenCommand).not.toHaveBeenCalled();
});
Expand All @@ -63,7 +63,7 @@ describe(getNewSsoOidcToken.name, () => {
} catch (error) {
expect(error).toStrictEqual(mockError);
}
expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion);
expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion, {});
expect(mockSendWithError).toHaveBeenCalledWith(mockCreateTokenArgs);
expect(CreateTokenCommand).toHaveBeenCalledWith(mockCreateTokenArgs);
});
Expand All @@ -78,7 +78,7 @@ describe(getNewSsoOidcToken.name, () => {
} catch (error) {
expect(error).toStrictEqual(mockError);
}
expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion);
expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion, {});
expect(mockSend).not.toHaveBeenCalled();
expect(CreateTokenCommand).toHaveBeenCalledWith(mockCreateTokenArgs);
});
Expand All @@ -90,6 +90,6 @@ describe(getNewSsoOidcToken.name, () => {
expect(newSsoOidcToken).toEqual(mockNewToken as any);
expect(CreateTokenCommand).toHaveBeenCalledWith(mockCreateTokenArgs);
expect(mockSend).toHaveBeenCalledWith(mockCreateTokenArgs);
expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion);
expect(getSsoOidcClient).toHaveBeenCalledWith(mockSsoRegion, {});
});
});
5 changes: 3 additions & 2 deletions packages/token-providers/src/getNewSsoOidcToken.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import { SSOToken } from "@smithy/shared-ini-file-loader";

import { FromSsoInit } from "./fromSso";
import { getSsoOidcClient } from "./getSsoOidcClient";

/**
* Returns a new SSO OIDC token from ssoOids.createToken() API call.
* @internal
*/
export const getNewSsoOidcToken = async (ssoToken: SSOToken, ssoRegion: string) => {
export const getNewSsoOidcToken = async (ssoToken: SSOToken, ssoRegion: string, init: FromSsoInit = {}) => {
// @ts-ignore Cannot find module '@aws-sdk/client-sso-oidc'
const { CreateTokenCommand } = await import("@aws-sdk/client-sso-oidc");

const ssoOidcClient = await getSsoOidcClient(ssoRegion);
const ssoOidcClient = await getSsoOidcClient(ssoRegion, init);
return ssoOidcClient.send(
new CreateTokenCommand({
clientId: ssoToken.clientId,
Expand Down
33 changes: 17 additions & 16 deletions packages/token-providers/src/getSsoOidcClient.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ vi.mock("@aws-sdk/client-sso-oidc");

describe("getSsoOidcClient", () => {
const mockSsoRegion = "mockSsoRegion";
const mockRequestHandler = {
protocol: "http",
};
const getMockClient = (region: string) => ({ region });

beforeEach(() => {
Expand All @@ -22,24 +25,22 @@ describe("getSsoOidcClient", () => {
expect(SSOOIDCClient).toHaveBeenCalledTimes(1);
});

it("returns SSOOIDC client from hash if already created", async () => {
const { getSsoOidcClient } = await import("./getSsoOidcClient");
expect(await getSsoOidcClient(mockSsoRegion)).toEqual(getMockClient(mockSsoRegion) as any);
expect(SSOOIDCClient).toHaveBeenCalledTimes(1);
expect(await getSsoOidcClient(mockSsoRegion)).toEqual(getMockClient(mockSsoRegion) as any);
expect(SSOOIDCClient).toHaveBeenCalledTimes(1);
});

it("creates new SSOOIDC client per region", async () => {
it("passes through clientConfig and parentClientConfig.logger", async () => {
const { getSsoOidcClient } = await import("./getSsoOidcClient");
const mockSsoRegion1 = `${mockSsoRegion}1`;
expect(await getSsoOidcClient(mockSsoRegion1)).toEqual(getMockClient(mockSsoRegion1) as any);
expect(
await getSsoOidcClient(mockSsoRegion1, {
clientConfig: { requestHandler: mockRequestHandler },
parentClientConfig: { logger: console },
})
).toEqual({
region: mockSsoRegion1,
} as any);
expect(SSOOIDCClient).toHaveBeenCalledTimes(1);
expect(SSOOIDCClient).toHaveBeenCalledWith({ region: mockSsoRegion1 });

const mockSsoRegion2 = `${mockSsoRegion}2`;
expect(await getSsoOidcClient(mockSsoRegion2)).toEqual(getMockClient(mockSsoRegion2) as any);
expect(SSOOIDCClient).toHaveBeenCalledTimes(2);
expect(SSOOIDCClient).toHaveBeenNthCalledWith(2, { region: mockSsoRegion2 });
expect(SSOOIDCClient).toHaveBeenCalledWith({
region: mockSsoRegion1,
requestHandler: mockRequestHandler,
logger: console,
});
});
});
23 changes: 9 additions & 14 deletions packages/token-providers/src/getSsoOidcClient.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
const ssoOidcClientsHash: Record<string, any> = {};
import { FromSsoInit } from "./fromSso";

/**
* Returns a SSOOIDC client for the given region. If the client has already been created,
* it will be returned from the hash.
* Returns a SSOOIDC client for the given region.
* @internal
*/
export const getSsoOidcClient = async (ssoRegion: string) => {
export const getSsoOidcClient = async (ssoRegion: string, init: FromSsoInit = {}) => {
// @ts-ignore Cannot find module '@aws-sdk/client-sso-oidc'
const { SSOOIDCClient } = await import("@aws-sdk/client-sso-oidc");

// return ssoOidsClient if already created.
if (ssoOidcClientsHash[ssoRegion]) {
return ssoOidcClientsHash[ssoRegion];
}

// Create new SSOOIDC client, and store is in hash.
// If we need to support configuration of SsoOidc client in future through code,
// the provision to pass region from client configuration needs to be added.
const ssoOidcClient = new SSOOIDCClient({ region: ssoRegion });
ssoOidcClientsHash[ssoRegion] = ssoOidcClient;
const ssoOidcClient = new SSOOIDCClient(
Object.assign({}, init.clientConfig ?? {}, {
region: ssoRegion ?? init.clientConfig?.region,
logger: init.clientConfig?.logger ?? init.parentClientConfig?.logger,
})
);
return ssoOidcClient;
};

0 comments on commit 0ca3da3

Please sign in to comment.