Skip to content

Commit

Permalink
Add Amazon Bedrock LLM endpoint integration (langchain-ai#2210)
Browse files Browse the repository at this point in the history
* Add Amazon Bedrock LLM endpoint integration (#1)

A contribution from Prompt Security

* Kept it as close to the Python implementation as possible
* Followed the guidelines from https://github.com/hwchase17/langchainjs/blob/main/CONTRIBUTING.md and https://github.com/hwchase17/langchainjs/blob/main/.github/contributing/INTEGRATIONS.md
* Supplied with unit test coverage
* Added documentation

* Fix bedrock documentation .mdx

* Fix Bedrock mdx documentation to show embedded example

* Fix Bedrock mdx documentation to show embedded example

* Rename LLMInputOutputAdapter to BedrockLLMInputOutputAdapter as requested in PR review

* Addressed @jacoblee93's suggestions from the PR review

* Added ability to specify credentials override via Bedrock class constructor.
* Removed all dependency on `aws-sigv4-fetch` library and re-implemented the functionality using direct AWS api calls, directly within the `Bedrock._call()` function.
* Rewrote unit tests to mock fetch() instead of the now-removed `aws-sigv4-fetch` library.
* Fixed documentation for Bedrock, hopefully better now.

* Commit changes to yarn.lock after 'yarn install' in an attempt to fix build error 'YN0028: │ The lockfile would have been modified by this install, which is explicitly forbidden.'

* Move hard deps to peer and optional deps, update docs

---------

Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
vitaly-ps and jacoblee93 authored Aug 15, 2023
1 parent df741af commit dc1fbe4
Show file tree
Hide file tree
Showing 12 changed files with 1,357 additions and 0 deletions.
19 changes: 19 additions & 0 deletions docs/extras/modules/model_io/models/llms/integrations/bedrock.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Bedrock

> [Amazon Bedrock](https://aws.amazon.com/bedrock/) is a fully managed service that makes Foundation Models (FMs)
> from leading AI startups and Amazon available via an API. You can choose from a wide range of FMs to find the model that is best suited for your use case.

## Setup

You'll need to install a few official AWS packages as peer dependencies:

```bash npm2yarn
npm install @aws-crypto/sha256-js @aws-sdk/credential-provider-node @aws-sdk/protocol-http @aws-sdk/signature-v4
```

## Usage

import CodeBlock from "@theme/CodeBlock";
import BedrockExample from "@examples/models/llm/bedrock.ts";

<CodeBlock language="typescript">{BedrockExample}</CodeBlock>
8 changes: 8 additions & 0 deletions examples/src/llms/bedrock.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import { Bedrock } from "langchain/llms/bedrock";

async function test() {
  // The model id must start with a supported provider prefix
  // ("amazon", "anthropic" or "ai21"); an arbitrary placeholder such as
  // "bedrock-model-name" is rejected by the Bedrock constructor's
  // provider validation and would throw before any request is made.
  const model = new Bedrock({
    model: "amazon.titan-tg1-large",
    region: "us-west-2",
  });
  const res = await model.call(
    "Question: What would be a good company name for a company that makes colorful socks?\nAnswer:"
  );
  console.log(res);
}

test();
18 changes: 18 additions & 0 deletions examples/src/models/llm/bedrock.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { Bedrock } from "langchain/llms/bedrock";

const run = async () => {
  // When the `credentials` option is omitted, the default provider chain
  // from `@aws-sdk/credential-provider-node` is used to resolve AWS
  // credentials automatically.
  const model = new Bedrock({
    model: "ai21",
    region: "us-west-2",
    // credentials: {
    //   accessKeyId: "YOUR_AWS_ACCESS_KEY",
    //   secretAccessKey: "YOUR_SECRET_ACCESS_KEY"
    // }
  });

  const response = await model.call("Tell me a joke");
  console.log(response);
};

run();
3 changes: 3 additions & 0 deletions langchain/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ llms/googlepalm.d.ts
llms/sagemaker_endpoint.cjs
llms/sagemaker_endpoint.js
llms/sagemaker_endpoint.d.ts
llms/bedrock.cjs
llms/bedrock.js
llms/bedrock.d.ts
prompts.cjs
prompts.js
prompts.d.ts
Expand Down
29 changes: 29 additions & 0 deletions langchain/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@
"llms/sagemaker_endpoint.cjs",
"llms/sagemaker_endpoint.js",
"llms/sagemaker_endpoint.d.ts",
"llms/bedrock.cjs",
"llms/bedrock.js",
"llms/bedrock.d.ts",
"prompts.cjs",
"prompts.js",
"prompts.d.ts",
Expand Down Expand Up @@ -564,12 +567,17 @@
"author": "LangChain",
"license": "MIT",
"devDependencies": {
"@aws-crypto/sha256-js": "^5.0.0",
"@aws-sdk/client-dynamodb": "^3.310.0",
"@aws-sdk/client-kendra": "^3.352.0",
"@aws-sdk/client-lambda": "^3.310.0",
"@aws-sdk/client-s3": "^3.310.0",
"@aws-sdk/client-sagemaker-runtime": "^3.310.0",
"@aws-sdk/client-sfn": "^3.362.0",
"@aws-sdk/credential-provider-node": "^3.388.0",
"@aws-sdk/protocol-http": "^3.374.0",
"@aws-sdk/signature-v4": "^3.374.0",
"@aws-sdk/types": "^3.357.0",
"@azure/storage-blob": "^12.15.0",
"@clickhouse/client": "^0.0.14",
"@elastic/elasticsearch": "^8.4.0",
Expand Down Expand Up @@ -668,12 +676,16 @@
"weaviate-ts-client": "^1.4.0"
},
"peerDependencies": {
"@aws-crypto/sha256-js": "^5.0.0",
"@aws-sdk/client-dynamodb": "^3.310.0",
"@aws-sdk/client-kendra": "^3.352.0",
"@aws-sdk/client-lambda": "^3.310.0",
"@aws-sdk/client-s3": "^3.310.0",
"@aws-sdk/client-sagemaker-runtime": "^3.310.0",
"@aws-sdk/client-sfn": "^3.310.0",
"@aws-sdk/credential-provider-node": "^3.388.0",
"@aws-sdk/protocol-http": "^3.374.0",
"@aws-sdk/signature-v4": "^3.374.0",
"@azure/storage-blob": "^12.15.0",
"@clickhouse/client": "^0.0.14",
"@elastic/elasticsearch": "^8.4.0",
Expand Down Expand Up @@ -736,6 +748,9 @@
"weaviate-ts-client": "^1.4.0"
},
"peerDependenciesMeta": {
"@aws-crypto/sha256-js": {
"optional": true
},
"@aws-sdk/client-dynamodb": {
"optional": true
},
Expand All @@ -754,6 +769,15 @@
"@aws-sdk/client-sfn": {
"optional": true
},
"@aws-sdk/credential-provider-node": {
"optional": true
},
"@aws-sdk/protocol-http": {
"optional": true
},
"@aws-sdk/signature-v4": {
"optional": true
},
"@azure/storage-blob": {
"optional": true
},
Expand Down Expand Up @@ -1194,6 +1218,11 @@
"import": "./llms/sagemaker_endpoint.js",
"require": "./llms/sagemaker_endpoint.cjs"
},
"./llms/bedrock": {
"types": "./llms/bedrock.d.ts",
"import": "./llms/bedrock.js",
"require": "./llms/bedrock.cjs"
},
"./prompts": {
"types": "./prompts.d.ts",
"import": "./prompts.js",
Expand Down
2 changes: 2 additions & 0 deletions langchain/scripts/create-entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ const entrypoints = {
"llms/googlevertexai": "llms/googlevertexai",
"llms/googlepalm": "llms/googlepalm",
"llms/sagemaker_endpoint": "llms/sagemaker_endpoint",
"llms/bedrock": "llms/bedrock",
// prompts
prompts: "prompts/index",
"prompts/load": "prompts/load",
Expand Down Expand Up @@ -255,6 +256,7 @@ const requiresOptionalDependency = [
"llms/raycast",
"llms/replicate",
"llms/sagemaker_endpoint",
"llms/bedrock",
"prompts/load",
"vectorstores/analyticdb",
"vectorstores/chroma",
Expand Down
195 changes: 195 additions & 0 deletions langchain/src/llms/bedrock.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
import { SignatureV4 } from "@aws-sdk/signature-v4";
import { defaultProvider } from "@aws-sdk/credential-provider-node";
import { HttpRequest } from "@aws-sdk/protocol-http";
import { Sha256 } from "@aws-crypto/sha256-js";
import type { AwsCredentialIdentity, Provider } from "@aws-sdk/types";
import { getEnvironmentVariable } from "../util/env.js";
import { LLM, BaseLLMParams } from "./base.js";

type Dict = { [key: string]: unknown };
type CredentialType = AwsCredentialIdentity | Provider<AwsCredentialIdentity>;

/**
 * Adapter that converts LangChain prompts into the provider-specific
 * request body expected by Bedrock models, and extracts the generated
 * text from each provider's response shape. Kept close to the Python
 * implementation's LLMInputOutputAdapter.
 */
class BedrockLLMInputOutputAdapter {
  /** Build the JSON request body for the given model provider. */
  static prepareInput(provider: string, prompt: string): Dict {
    switch (provider) {
      case "anthropic":
        // Anthropic requires an explicit sampling budget; default to 50 tokens.
        return { prompt, max_tokens_to_sample: 50 };
      case "ai21":
        return { prompt };
      case "amazon":
        return { inputText: prompt, textGenerationConfig: {} };
      default:
        return { inputText: prompt };
    }
  }

  /** Pull the completion text out of a provider-specific response body. */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  static prepareOutput(provider: string, responseBody: any): string {
    if (provider === "anthropic") {
      return responseBody.completion;
    }
    if (provider === "ai21") {
      return responseBody.completions[0].data.text;
    }
    return responseBody.results[0].outputText;
  }
}

/**
 * Input options for the Bedrock LLM.
 *
 * To authenticate, the AWS client uses the default credential provider chain
 * to automatically load credentials (environment variables, shared
 * credentials file, instance metadata, ...); see the equivalent description:
 * https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
 * If a specific credential profile should be used, you must pass the name of
 * the profile from the ~/.aws/credentials file that is to be used.
 * Make sure the credentials / roles used have the required policies to
 * access the Bedrock service.
 */
export interface BedrockInput {
  /** Model to use.
   * For example, "amazon.titan-tg1-large", this is equivalent to the modelId property in the list-foundation-models api.
   */
  model: string;

  /** The AWS region e.g. `us-west-2`.
   * Falls back to the AWS_DEFAULT_REGION env variable or the region specified in ~/.aws/config if not provided here.
   */
  region?: string;

  /** AWS Credentials.
   * If no credentials are provided, the default credentials from `@aws-sdk/credential-provider-node` will be used.
   */
  credentials?: CredentialType;

  /** Sampling temperature to use for generation. */
  temperature?: number;

  /** Maximum number of tokens to generate. */
  maxTokens?: number;

  /** A custom fetch function for low-level access to the AWS API. Defaults to fetch(). */
  fetchFn?: typeof fetch;
}

export class Bedrock extends LLM implements BedrockInput {
model = "amazon.titan-tg1-large";

region: string;

credentials: CredentialType;

temperature?: number | undefined = undefined;

maxTokens?: number | undefined = undefined;

fetchFn: typeof fetch;

get lc_secrets(): { [key: string]: string } | undefined {
return {};
}

_llmType() {
return "bedrock";
}

constructor(fields?: Partial<BedrockInput> & BaseLLMParams) {
super(fields ?? {});

this.model = fields?.model ?? this.model;
const allowedModels = ["ai21", "anthropic", "amazon"];
if (!allowedModels.includes(this.model.split(".")[0])) {
throw new Error(
`Unknown model: '${this.model}', only these are supported: ${allowedModels}`
);
}
const region =
fields?.region ?? getEnvironmentVariable("AWS_DEFAULT_REGION");
if (!region) {
throw new Error(
"Please set the AWS_DEFAULT_REGION environment variable or pass it to the constructor as the region field."
);
}
this.region = region;
this.credentials = fields?.credentials ?? defaultProvider();
this.temperature = fields?.temperature ?? this.temperature;
this.maxTokens = fields?.maxTokens ?? this.maxTokens;
this.fetchFn = fields?.fetchFn ?? fetch;
}

/** Call out to Bedrock service model.
Arguments:
prompt: The prompt to pass into the model.
Returns:
The string generated by the model.
Example:
response = model.call("Tell me a joke.")
*/
async _call(prompt: string): Promise<string> {
const provider = this.model.split(".")[0];
const service = "bedrock";

const inputBody = BedrockLLMInputOutputAdapter.prepareInput(
provider,
prompt
);

const url = new URL(
`https://${service}.${this.region}.amazonaws.com/model/${this.model}/invoke`
);

const request = new HttpRequest({
hostname: url.hostname,
path: url.pathname,
protocol: url.protocol,
method: "POST", // method must be uppercase
body: JSON.stringify(inputBody),
query: Object.fromEntries(url.searchParams.entries()),
headers: {
// host is required by AWS Signature V4: https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
host: url.host,
accept: "application/json",
"Content-Type": "application/json",
},
});

const signer = new SignatureV4({
credentials: this.credentials,
service,
region: this.region,
sha256: Sha256,
});

const signedRequest = await signer.sign(request);

// Send request to AWS using the low-level fetch API
const response = await this.fetchFn(url, {
headers: signedRequest.headers,
body: signedRequest.body,
method: signedRequest.method,
});

if (response.status < 200 || response.status >= 300) {
throw Error(
`Failed to access underlying url '${url}': got ${response.status} ${
response.statusText
}: ${await response.text()}`
);
}

const responseJson = await response.json();

const text = BedrockLLMInputOutputAdapter.prepareOutput(
provider,
responseJson
);

return text;
}
}
Loading

0 comments on commit dc1fbe4

Please sign in to comment.