Skip to content

Commit

Permalink
chore(middleware-flexible-checksums): delay checksum validation until…
Browse files Browse the repository at this point in the history
… stream read (#6629)
  • Loading branch information
trivikr authored Nov 6, 2024
1 parent 0670605 commit 43ee3c1
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 41 deletions.
29 changes: 22 additions & 7 deletions clients/client-s3/test/e2e/S3.e2e.spec.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import "@aws-sdk/signature-v4-crt";

import { S3, SelectObjectContentEventStream } from "@aws-sdk/client-s3";
import { ChecksumAlgorithm, S3, SelectObjectContentEventStream } from "@aws-sdk/client-s3";
import { afterAll, afterEach, beforeAll, describe, expect, test as it } from "vitest";

import { getIntegTestResources } from "../../../../tests/e2e/get-integ-test-resources";
Expand All @@ -24,9 +24,7 @@ describe("@aws-sdk/client-s3", () => {

Key = ``;

client = new S3({
region,
});
client = new S3({ region });
});

describe("PutObject", () => {
Expand Down Expand Up @@ -74,26 +72,43 @@ describe("@aws-sdk/client-s3", () => {
await client.deleteObject({ Bucket, Key });
});

it("should succeed with valid body payload", async () => {
it("should succeed with valid body payload with checksums", async () => {
// prepare the object.
const body = createBuffer("1MB");
let bodyChecksum = "";

const bodyChecksumReader = (next) => async (args) => {
const checksumValue = args.request.headers["x-amz-checksum-crc32"];
if (checksumValue) {
bodyChecksum = checksumValue;
}
return next(args);
};
client.middlewareStack.addRelativeTo(bodyChecksumReader, {
name: "bodyChecksumReader",
relation: "before",
toMiddleware: "deserializerMiddleware",
});

try {
await client.putObject({ Bucket, Key, Body: body });
await client.putObject({ Bucket, Key, Body: body, ChecksumAlgorithm: ChecksumAlgorithm.CRC32 });
} catch (e) {
console.error("failed to put");
throw e;
}

expect(bodyChecksum).not.toEqual("");

try {
// eslint-disable-next-line no-var
var result = await client.getObject({ Bucket, Key });
var result = await client.getObject({ Bucket, Key, ChecksumMode: "ENABLED" });
} catch (e) {
console.error("failed to get");
throw e;
}

expect(result.$metadata.httpStatusCode).toEqual(200);
expect(result.ChecksumCRC32).toEqual(bodyChecksum);
const { Readable } = require("stream");
expect(result.Body).toBeInstanceOf(Readable);
});
Expand Down
1 change: 1 addition & 0 deletions packages/middleware-flexible-checksums/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"@smithy/types": "^3.6.0",
"@smithy/util-middleware": "^3.0.8",
"@smithy/util-utf8": "^3.0.0",
"@smithy/util-stream": "^3.2.1",
"tslib": "^2.6.2"
},
"devDependencies": {
Expand Down
16 changes: 1 addition & 15 deletions packages/middleware-flexible-checksums/src/getChecksum.spec.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import { afterEach, beforeEach, describe, expect, test as it, vi } from "vitest";

import { getChecksum } from "./getChecksum";
import { isStreaming } from "./isStreaming";
import { stringHasher } from "./stringHasher";

vi.mock("./isStreaming");
vi.mock("./stringHasher");

describe(getChecksum.name, () => {
const mockOptions = {
streamHasher: vi.fn(),
checksumAlgorithmFn: vi.fn(),
base64Encoder: vi.fn(),
};
Expand All @@ -26,21 +23,10 @@ describe(getChecksum.name, () => {
vi.clearAllMocks();
});

it("gets checksum from streamHasher if body is streaming", async () => {
vi.mocked(isStreaming).mockReturnValue(true);
mockOptions.streamHasher.mockResolvedValue(mockRawOutput);
const checksum = await getChecksum(mockBody, mockOptions);
expect(checksum).toEqual(mockOutput);
expect(stringHasher).not.toHaveBeenCalled();
expect(mockOptions.streamHasher).toHaveBeenCalledWith(mockOptions.checksumAlgorithmFn, mockBody);
});

it("gets checksum from stringHasher if body is not streaming", async () => {
vi.mocked(isStreaming).mockReturnValue(false);
it("gets checksum from stringHasher", async () => {
vi.mocked(stringHasher).mockResolvedValue(mockRawOutput);
const checksum = await getChecksum(mockBody, mockOptions);
expect(checksum).toEqual(mockOutput);
expect(mockOptions.streamHasher).not.toHaveBeenCalled();
expect(stringHasher).toHaveBeenCalledWith(mockOptions.checksumAlgorithmFn, mockBody);
});
});
13 changes: 3 additions & 10 deletions packages/middleware-flexible-checksums/src/getChecksum.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
import { ChecksumConstructor, Encoder, HashConstructor, StreamHasher } from "@smithy/types";
import { ChecksumConstructor, Encoder, HashConstructor } from "@smithy/types";

import { isStreaming } from "./isStreaming";
import { stringHasher } from "./stringHasher";

export interface GetChecksumDigestOptions {
streamHasher: StreamHasher<any>;
checksumAlgorithmFn: ChecksumConstructor | HashConstructor;
base64Encoder: Encoder;
}

export const getChecksum = async (
body: unknown,
{ streamHasher, checksumAlgorithmFn, base64Encoder }: GetChecksumDigestOptions
) => {
const digest = isStreaming(body) ? streamHasher(checksumAlgorithmFn, body) : stringHasher(checksumAlgorithmFn, body);
return base64Encoder(await digest);
};
export const getChecksum = async (body: unknown, { checksumAlgorithmFn, base64Encoder }: GetChecksumDigestOptions) =>
base64Encoder(await stringHasher(checksumAlgorithmFn, body));
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
import { HttpResponse } from "@smithy/protocol-http";
import { createChecksumStream } from "@smithy/util-stream";
import { afterEach, beforeEach, describe, expect, test as it, vi } from "vitest";

import { PreviouslyResolved } from "./configuration";
import { ChecksumAlgorithm } from "./constants";
import { getChecksum } from "./getChecksum";
import { getChecksumAlgorithmListForResponse } from "./getChecksumAlgorithmListForResponse";
import { getChecksumLocationName } from "./getChecksumLocationName";
import { isStreaming } from "./isStreaming";
import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction";
import { validateChecksumFromResponse } from "./validateChecksumFromResponse";

vi.mock("@smithy/util-stream");
vi.mock("./getChecksum");
vi.mock("./getChecksumLocationName");
vi.mock("./getChecksumAlgorithmListForResponse");
vi.mock("./isStreaming");
vi.mock("./selectChecksumAlgorithmFunction");

describe(validateChecksumFromResponse.name, () => {
const mockConfig = {
streamHasher: vi.fn(),
base64Encoder: vi.fn(),
} as unknown as PreviouslyResolved;

const mockBody = {};
const mockBodyStream = { isStream: true };
const mockHeaders = {};
const mockResponse = {
body: mockBody,
Expand Down Expand Up @@ -50,6 +54,7 @@ describe(validateChecksumFromResponse.name, () => {
vi.mocked(getChecksumAlgorithmListForResponse).mockImplementation((responseAlgorithms) => responseAlgorithms);
vi.mocked(selectChecksumAlgorithmFunction).mockReturnValue(mockChecksumAlgorithmFn);
vi.mocked(getChecksum).mockResolvedValue(mockChecksum);
vi.mocked(createChecksumStream).mockReturnValue(mockBodyStream);
});

afterEach(() => {
Expand Down Expand Up @@ -85,31 +90,56 @@ describe(validateChecksumFromResponse.name, () => {
});

describe("successful validation", () => {
afterEach(() => {
const validateCalls = (isStream: boolean, checksumAlgoFn: ChecksumAlgorithm) => {
expect(getChecksumAlgorithmListForResponse).toHaveBeenCalledWith(mockResponseAlgorithms);
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
expect(getChecksum).toHaveBeenCalledTimes(1);
});

it("when checksum is populated for first algorithm", async () => {
if (isStream) {
expect(getChecksum).not.toHaveBeenCalled();
expect(createChecksumStream).toHaveBeenCalledTimes(1);
expect(createChecksumStream).toHaveBeenCalledWith({
expectedChecksum: mockChecksum,
checksumSourceLocation: checksumAlgoFn,
checksum: new mockChecksumAlgorithmFn(),
source: mockBody,
base64Encoder: mockConfig.base64Encoder,
});
} else {
expect(getChecksum).toHaveBeenCalledTimes(1);
expect(getChecksum).toHaveBeenCalledWith(mockBody, {
checksumAlgorithmFn: mockChecksumAlgorithmFn,
base64Encoder: mockConfig.base64Encoder,
});
expect(createChecksumStream).not.toHaveBeenCalled();
}
};

it.each([false, true])("when checksum is populated for first algorithm when streaming: %s", async (isStream) => {
vi.mocked(isStreaming).mockReturnValue(isStream);
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[0], mockChecksum);
await validateChecksumFromResponse(responseWithChecksum, mockOptions);
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
expect(getChecksumLocationName).toHaveBeenCalledWith(mockResponseAlgorithms[0]);
validateCalls(isStream, mockResponseAlgorithms[0]);
});

it("when checksum is populated for second algorithm", async () => {
it.each([false, true])("when checksum is populated for second algorithm when streaming: %s", async (isStream) => {
vi.mocked(isStreaming).mockReturnValue(isStream);
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[1], mockChecksum);
await validateChecksumFromResponse(responseWithChecksum, mockOptions);
expect(getChecksumLocationName).toHaveBeenCalledTimes(2);
expect(getChecksumLocationName).toHaveBeenNthCalledWith(1, mockResponseAlgorithms[0]);
expect(getChecksumLocationName).toHaveBeenNthCalledWith(2, mockResponseAlgorithms[1]);
validateCalls(isStream, mockResponseAlgorithms[1]);
});
});

it("throw error if checksum value is not accurate", async () => {
it("throw error if checksum value is not accurate when not streaming", async () => {
vi.mocked(isStreaming).mockReturnValue(false);

const incorrectChecksum = "incorrectChecksum";
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[0], incorrectChecksum);

try {
await validateChecksumFromResponse(responseWithChecksum, mockOptions);
fail("should throw checksum mismatch error");
Expand All @@ -119,9 +149,28 @@ describe(validateChecksumFromResponse.name, () => {
` in response header "${mockResponseAlgorithms[0]}".`
);
}

expect(getChecksumAlgorithmListForResponse).toHaveBeenCalledWith(mockResponseAlgorithms);
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
expect(getChecksum).toHaveBeenCalledTimes(1);
expect(createChecksumStream).not.toHaveBeenCalled();
});

it("return if checksum value is not accurate when streaming, as error will be thrown when stream is consumed", async () => {
vi.mocked(isStreaming).mockReturnValue(true);

// This override does not matter for the purpose of unit test, but is kept for completeness.
const incorrectChecksum = "incorrectChecksum";
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[0], incorrectChecksum);

await validateChecksumFromResponse(responseWithChecksum, mockOptions);

expect(getChecksumAlgorithmListForResponse).toHaveBeenCalledWith(mockResponseAlgorithms);
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
expect(getChecksum).not.toHaveBeenCalled();
expect(createChecksumStream).toHaveBeenCalledTimes(1);
expect(responseWithChecksum.body).toBe(mockBodyStream);
});
});
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import { HttpResponse } from "@smithy/protocol-http";
import { Checksum } from "@smithy/types";
import { createChecksumStream } from "@smithy/util-stream";

import { PreviouslyResolved } from "./configuration";
import { ChecksumAlgorithm } from "./constants";
import { getChecksum } from "./getChecksum";
import { getChecksumAlgorithmListForResponse } from "./getChecksumAlgorithmListForResponse";
import { getChecksumLocationName } from "./getChecksumLocationName";
import { isStreaming } from "./isStreaming";
import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction";

export interface ValidateChecksumFromResponseOptions {
Expand All @@ -29,9 +32,20 @@ export const validateChecksumFromResponse = async (
const checksumFromResponse = responseHeaders[responseHeader];
if (checksumFromResponse) {
const checksumAlgorithmFn = selectChecksumAlgorithmFunction(algorithm as ChecksumAlgorithm, config);
const { streamHasher, base64Encoder } = config;
const checksum = await getChecksum(responseBody, { streamHasher, checksumAlgorithmFn, base64Encoder });
const { base64Encoder } = config;

if (isStreaming(responseBody)) {
response.body = createChecksumStream({
expectedChecksum: checksumFromResponse,
checksumSourceLocation: responseHeader,
checksum: new checksumAlgorithmFn() as Checksum,
source: responseBody,
base64Encoder,
});
return;
}

const checksum = await getChecksum(responseBody, { checksumAlgorithmFn, base64Encoder });
if (checksum === checksumFromResponse) {
// The checksum for response payload is valid.
break;
Expand Down

0 comments on commit 43ee3c1

Please sign in to comment.