From 43ee3c1e938cf6f7978e5bc65c01a1f5c5e70a63 Mon Sep 17 00:00:00 2001 From: Trivikram Kamat <16024985+trivikr@users.noreply.github.com> Date: Wed, 6 Nov 2024 15:04:23 -0800 Subject: [PATCH] chore(middleware-flexible-checksums): delay checksum validation until stream read (#6629) --- clients/client-s3/test/e2e/S3.e2e.spec.ts | 29 ++++++--- .../package.json | 1 + .../src/getChecksum.spec.ts | 16 +---- .../src/getChecksum.ts | 13 +--- .../src/validateChecksumFromResponse.spec.ts | 63 ++++++++++++++++--- .../src/validateChecksumFromResponse.ts | 18 +++++- 6 files changed, 99 insertions(+), 41 deletions(-) diff --git a/clients/client-s3/test/e2e/S3.e2e.spec.ts b/clients/client-s3/test/e2e/S3.e2e.spec.ts index e29cbaf0539c..e1307c8170b6 100644 --- a/clients/client-s3/test/e2e/S3.e2e.spec.ts +++ b/clients/client-s3/test/e2e/S3.e2e.spec.ts @@ -1,6 +1,6 @@ import "@aws-sdk/signature-v4-crt"; -import { S3, SelectObjectContentEventStream } from "@aws-sdk/client-s3"; +import { ChecksumAlgorithm, S3, SelectObjectContentEventStream } from "@aws-sdk/client-s3"; import { afterAll, afterEach, beforeAll, describe, expect, test as it } from "vitest"; import { getIntegTestResources } from "../../../../tests/e2e/get-integ-test-resources"; @@ -24,9 +24,7 @@ describe("@aws-sdk/client-s3", () => { Key = ``; - client = new S3({ - region, - }); + client = new S3({ region }); }); describe("PutObject", () => { @@ -74,26 +72,43 @@ describe("@aws-sdk/client-s3", () => { await client.deleteObject({ Bucket, Key }); }); - it("should succeed with valid body payload", async () => { + it("should succeed with valid body payload with checksums", async () => { // prepare the object. const body = createBuffer("1MB"); + let bodyChecksum = ""; + + const bodyChecksumReader = (next) => async (args) => { + const checksumValue = args.request.headers["x-amz-checksum-crc32"]; + if (checksumValue) { + bodyChecksum = checksumValue; + } + return next(args); + }; + client.middlewareStack.addRelativeTo(bodyChecksumReader, { + name: "bodyChecksumReader", + relation: "before", + toMiddleware: "deserializerMiddleware", + }); try { - await client.putObject({ Bucket, Key, Body: body }); + await client.putObject({ Bucket, Key, Body: body, ChecksumAlgorithm: ChecksumAlgorithm.CRC32 }); } catch (e) { console.error("failed to put"); throw e; } + expect(bodyChecksum).not.toEqual(""); + try { // eslint-disable-next-line no-var - var result = await client.getObject({ Bucket, Key }); + var result = await client.getObject({ Bucket, Key, ChecksumMode: "ENABLED" }); } catch (e) { console.error("failed to get"); throw e; } expect(result.$metadata.httpStatusCode).toEqual(200); + expect(result.ChecksumCRC32).toEqual(bodyChecksum); const { Readable } = require("stream"); expect(result.Body).toBeInstanceOf(Readable); }); diff --git a/packages/middleware-flexible-checksums/package.json b/packages/middleware-flexible-checksums/package.json index 740cb1a4dab4..a66eeca17ef1 100644 --- a/packages/middleware-flexible-checksums/package.json +++ b/packages/middleware-flexible-checksums/package.json @@ -40,6 +40,7 @@ "@smithy/types": "^3.6.0", "@smithy/util-middleware": "^3.0.8", "@smithy/util-utf8": "^3.0.0", + "@smithy/util-stream": "^3.2.1", "tslib": "^2.6.2" }, "devDependencies": { diff --git a/packages/middleware-flexible-checksums/src/getChecksum.spec.ts b/packages/middleware-flexible-checksums/src/getChecksum.spec.ts index 2a8b37001468..383d30ec8026 100644 --- a/packages/middleware-flexible-checksums/src/getChecksum.spec.ts +++ b/packages/middleware-flexible-checksums/src/getChecksum.spec.ts @@ -1,15 +1,12 @@ import { afterEach, beforeEach, describe, expect, test as it, vi } from "vitest"; import { getChecksum } from "./getChecksum"; -import { isStreaming } from "./isStreaming"; import { stringHasher } from "./stringHasher"; -vi.mock("./isStreaming"); vi.mock("./stringHasher"); describe(getChecksum.name, () => { const mockOptions = { - streamHasher: vi.fn(), checksumAlgorithmFn: vi.fn(), base64Encoder: vi.fn(), }; @@ -26,21 +23,10 @@ describe(getChecksum.name, () => { vi.clearAllMocks(); }); - it("gets checksum from streamHasher if body is streaming", async () => { - vi.mocked(isStreaming).mockReturnValue(true); - mockOptions.streamHasher.mockResolvedValue(mockRawOutput); - const checksum = await getChecksum(mockBody, mockOptions); - expect(checksum).toEqual(mockOutput); - expect(stringHasher).not.toHaveBeenCalled(); - expect(mockOptions.streamHasher).toHaveBeenCalledWith(mockOptions.checksumAlgorithmFn, mockBody); - }); - - it("gets checksum from stringHasher if body is not streaming", async () => { - vi.mocked(isStreaming).mockReturnValue(false); + it("gets checksum from stringHasher", async () => { vi.mocked(stringHasher).mockResolvedValue(mockRawOutput); const checksum = await getChecksum(mockBody, mockOptions); expect(checksum).toEqual(mockOutput); - expect(mockOptions.streamHasher).not.toHaveBeenCalled(); expect(stringHasher).toHaveBeenCalledWith(mockOptions.checksumAlgorithmFn, mockBody); }); }); diff --git a/packages/middleware-flexible-checksums/src/getChecksum.ts b/packages/middleware-flexible-checksums/src/getChecksum.ts index 4452fe466397..fd1b3091741e 100644 --- a/packages/middleware-flexible-checksums/src/getChecksum.ts +++ b/packages/middleware-flexible-checksums/src/getChecksum.ts @@ -1,18 +1,11 @@ -import { ChecksumConstructor, Encoder, HashConstructor, StreamHasher } from "@smithy/types"; +import { ChecksumConstructor, Encoder, HashConstructor } from "@smithy/types"; -import { isStreaming } from "./isStreaming"; import { stringHasher } from "./stringHasher"; export interface GetChecksumDigestOptions { - streamHasher: StreamHasher; checksumAlgorithmFn: ChecksumConstructor | HashConstructor; base64Encoder: Encoder; } -export const getChecksum = async ( - body: unknown, - { streamHasher, checksumAlgorithmFn, base64Encoder }: GetChecksumDigestOptions -) => { - const digest = isStreaming(body) ? streamHasher(checksumAlgorithmFn, body) : stringHasher(checksumAlgorithmFn, body); - return base64Encoder(await digest); -}; +export const getChecksum = async (body: unknown, { checksumAlgorithmFn, base64Encoder }: GetChecksumDigestOptions) => + base64Encoder(await stringHasher(checksumAlgorithmFn, body)); diff --git a/packages/middleware-flexible-checksums/src/validateChecksumFromResponse.spec.ts b/packages/middleware-flexible-checksums/src/validateChecksumFromResponse.spec.ts index 8a8944bbe9a2..7829cdd13d2f 100644 --- a/packages/middleware-flexible-checksums/src/validateChecksumFromResponse.spec.ts +++ b/packages/middleware-flexible-checksums/src/validateChecksumFromResponse.spec.ts @@ -1,4 +1,5 @@ import { HttpResponse } from "@smithy/protocol-http"; +import { createChecksumStream } from "@smithy/util-stream"; import { afterEach, beforeEach, describe, expect, test as it, vi } from "vitest"; import { PreviouslyResolved } from "./configuration"; @@ -6,21 +7,24 @@ import { ChecksumAlgorithm } from "./constants"; import { getChecksum } from "./getChecksum"; import { getChecksumAlgorithmListForResponse } from "./getChecksumAlgorithmListForResponse"; import { getChecksumLocationName } from "./getChecksumLocationName"; +import { isStreaming } from "./isStreaming"; import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction"; import { validateChecksumFromResponse } from "./validateChecksumFromResponse"; +vi.mock("@smithy/util-stream"); vi.mock("./getChecksum"); vi.mock("./getChecksumLocationName"); vi.mock("./getChecksumAlgorithmListForResponse"); +vi.mock("./isStreaming"); vi.mock("./selectChecksumAlgorithmFunction"); describe(validateChecksumFromResponse.name, () => { const mockConfig = { - streamHasher: vi.fn(), base64Encoder: vi.fn(), } as unknown as PreviouslyResolved; const mockBody = {}; + const mockBodyStream = { isStream: true }; const mockHeaders = {}; const mockResponse = { body: mockBody, @@ -50,6 +54,7 @@ describe(validateChecksumFromResponse.name, () => { vi.mocked(getChecksumAlgorithmListForResponse).mockImplementation((responseAlgorithms) => responseAlgorithms); vi.mocked(selectChecksumAlgorithmFunction).mockReturnValue(mockChecksumAlgorithmFn); vi.mocked(getChecksum).mockResolvedValue(mockChecksum); + vi.mocked(createChecksumStream).mockReturnValue(mockBodyStream); }); afterEach(() => { @@ -85,31 +90,56 @@ describe(validateChecksumFromResponse.name, () => { }); describe("successful validation", () => { - afterEach(() => { + const validateCalls = (isStream: boolean, checksumAlgoFn: ChecksumAlgorithm) => { expect(getChecksumAlgorithmListForResponse).toHaveBeenCalledWith(mockResponseAlgorithms); expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1); - expect(getChecksum).toHaveBeenCalledTimes(1); - }); - it("when checksum is populated for first algorithm", async () => { + if (isStream) { + expect(getChecksum).not.toHaveBeenCalled(); + expect(createChecksumStream).toHaveBeenCalledTimes(1); + expect(createChecksumStream).toHaveBeenCalledWith({ + expectedChecksum: mockChecksum, + checksumSourceLocation: checksumAlgoFn, + checksum: new mockChecksumAlgorithmFn(), + source: mockBody, + base64Encoder: mockConfig.base64Encoder, + }); + } else { + expect(getChecksum).toHaveBeenCalledTimes(1); + expect(getChecksum).toHaveBeenCalledWith(mockBody, { + checksumAlgorithmFn: mockChecksumAlgorithmFn, + base64Encoder: mockConfig.base64Encoder, + }); + expect(createChecksumStream).not.toHaveBeenCalled(); + } + }; + + it.each([false, true])("when checksum is populated for first algorithm when streaming: %s", async (isStream) => { + vi.mocked(isStreaming).mockReturnValue(isStream); const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[0], mockChecksum); await validateChecksumFromResponse(responseWithChecksum, mockOptions); expect(getChecksumLocationName).toHaveBeenCalledTimes(1); expect(getChecksumLocationName).toHaveBeenCalledWith(mockResponseAlgorithms[0]); + validateCalls(isStream, mockResponseAlgorithms[0]); }); - it("when checksum is populated for second algorithm", async () => { + it.each([false, true])("when checksum is populated for second algorithm when streaming: %s", async (isStream) => { + vi.mocked(isStreaming).mockReturnValue(isStream); const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[1], mockChecksum); await validateChecksumFromResponse(responseWithChecksum, mockOptions); expect(getChecksumLocationName).toHaveBeenCalledTimes(2); expect(getChecksumLocationName).toHaveBeenNthCalledWith(1, mockResponseAlgorithms[0]); expect(getChecksumLocationName).toHaveBeenNthCalledWith(2, mockResponseAlgorithms[1]); + validateCalls(isStream, mockResponseAlgorithms[1]); }); }); - it("throw error if checksum value is not accurate", async () => { + it("throw error if checksum value is not accurate when not streaming", async () => { + vi.mocked(isStreaming).mockReturnValue(false); + const incorrectChecksum = "incorrectChecksum"; const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[0], incorrectChecksum); + try { await validateChecksumFromResponse(responseWithChecksum, mockOptions); fail("should throw checksum mismatch error"); @@ -119,9 +149,28 @@ describe(validateChecksumFromResponse.name, () => { ` in response header "${mockResponseAlgorithms[0]}".` ); } + expect(getChecksumAlgorithmListForResponse).toHaveBeenCalledWith(mockResponseAlgorithms); expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1); expect(getChecksumLocationName).toHaveBeenCalledTimes(1); expect(getChecksum).toHaveBeenCalledTimes(1); + expect(createChecksumStream).not.toHaveBeenCalled(); + }); + + it("return if checksum value is not accurate when streaming, as error will be thrown when stream is consumed", async () => { + vi.mocked(isStreaming).mockReturnValue(true); + + // This override does not matter for the purpose of unit test, but is kept for completeness. + const incorrectChecksum = "incorrectChecksum"; + const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[0], incorrectChecksum); + + await validateChecksumFromResponse(responseWithChecksum, mockOptions); + + expect(getChecksumAlgorithmListForResponse).toHaveBeenCalledWith(mockResponseAlgorithms); + expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1); + expect(getChecksumLocationName).toHaveBeenCalledTimes(1); + expect(getChecksum).not.toHaveBeenCalled(); + expect(createChecksumStream).toHaveBeenCalledTimes(1); + expect(responseWithChecksum.body).toBe(mockBodyStream); }); }); diff --git a/packages/middleware-flexible-checksums/src/validateChecksumFromResponse.ts b/packages/middleware-flexible-checksums/src/validateChecksumFromResponse.ts index a25028fb2ec7..862f97bc2f0c 100644 --- a/packages/middleware-flexible-checksums/src/validateChecksumFromResponse.ts +++ b/packages/middleware-flexible-checksums/src/validateChecksumFromResponse.ts @@ -1,10 +1,13 @@ import { HttpResponse } from "@smithy/protocol-http"; +import { Checksum } from "@smithy/types"; +import { createChecksumStream } from "@smithy/util-stream"; import { PreviouslyResolved } from "./configuration"; import { ChecksumAlgorithm } from "./constants"; import { getChecksum } from "./getChecksum"; import { getChecksumAlgorithmListForResponse } from "./getChecksumAlgorithmListForResponse"; import { getChecksumLocationName } from "./getChecksumLocationName"; +import { isStreaming } from "./isStreaming"; import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction"; export interface ValidateChecksumFromResponseOptions { @@ -29,9 +32,20 @@ export const validateChecksumFromResponse = async ( const checksumFromResponse = responseHeaders[responseHeader]; if (checksumFromResponse) { const checksumAlgorithmFn = selectChecksumAlgorithmFunction(algorithm as ChecksumAlgorithm, config); - const { streamHasher, base64Encoder } = config; - const checksum = await getChecksum(responseBody, { streamHasher, checksumAlgorithmFn, base64Encoder }); + const { base64Encoder } = config; + if (isStreaming(responseBody)) { + response.body = createChecksumStream({ + expectedChecksum: checksumFromResponse, + checksumSourceLocation: responseHeader, + checksum: new checksumAlgorithmFn() as Checksum, + source: responseBody, + base64Encoder, + }); + return; + } + + const checksum = await getChecksum(responseBody, { checksumAlgorithmFn, base64Encoder }); if (checksum === checksumFromResponse) { // The checksum for response payload is valid. break;