From 49d85fdb114e9a91af5809675c5e206c8ca358e8 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Tue, 6 Aug 2024 12:40:51 -0700 Subject: [PATCH 1/2] openai[minor],core[minor]: Add support for passing strict in openai tools --- langchain-core/src/language_models/base.ts | 9 + langchain-core/src/utils/function_calling.ts | 19 +- libs/langchain-openai/package.json | 2 +- libs/langchain-openai/src/chat_models.ts | 37 ++- .../src/tests/chat_models.test.ts | 212 ++++++++++++++++++ .../chat_models_structured_output.int.test.ts | 1 + libs/langchain-openai/src/types.ts | 7 + yarn.lock | 24 +- 8 files changed, 299 insertions(+), 12 deletions(-) create mode 100644 libs/langchain-openai/src/tests/chat_models.test.ts diff --git a/langchain-core/src/language_models/base.ts b/langchain-core/src/language_models/base.ts index 0e8af1bc32..cea8ca2f9a 100644 --- a/langchain-core/src/language_models/base.ts +++ b/langchain-core/src/language_models/base.ts @@ -233,6 +233,15 @@ export interface FunctionDefinition { * how to call the function. */ description?: string; + + /** + * Whether to enable strict schema adherence when generating the function call. If + * set to true, the model will follow the exact schema defined in the `parameters` + * field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn + * more about Structured Outputs in the + * [function calling guide](https://platform.openai.com/docs/guides/function-calling). 
+ */ + strict?: boolean; } export interface ToolDefinition { diff --git a/langchain-core/src/utils/function_calling.ts b/langchain-core/src/utils/function_calling.ts index 3871ffc445..6155f3cd8f 100644 --- a/langchain-core/src/utils/function_calling.ts +++ b/langchain-core/src/utils/function_calling.ts @@ -34,14 +34,30 @@ export function convertToOpenAIFunction( */ export function convertToOpenAITool( // eslint-disable-next-line @typescript-eslint/no-explicit-any - tool: StructuredToolInterface | Record | RunnableToolLike + tool: StructuredToolInterface | Record | RunnableToolLike, + fields?: { + /** + * If `true`, model output is guaranteed to exactly match the JSON Schema + * provided in the function definition. + */ + strict?: boolean; + } ): ToolDefinition { + let toolDef: ToolDefinition | undefined; if (isStructuredTool(tool) || isRunnableToolLike(tool)) { - return { + toolDef = { type: "function", function: convertToOpenAIFunction(tool), }; + } else { + toolDef = tool as ToolDefinition; + } + + if (fields?.strict !== undefined) { + toolDef.function.strict = fields.strict; } + - return tool as ToolDefinition; + // Return toolDef, not tool: for structured/runnable tools only the converted definition carries `strict` (plain tool objects are aliased and updated in place). + return toolDef; } diff --git a/libs/langchain-openai/package.json b/libs/langchain-openai/package.json index 3115ef248c..7a56530841 100644 --- a/libs/langchain-openai/package.json +++ b/libs/langchain-openai/package.json @@ -37,7 +37,7 @@ "dependencies": { "@langchain/core": ">=0.2.16 <0.3.0", "js-tiktoken": "^1.0.12", - "openai": "^4.49.1", + "openai": "^4.55.0", "zod": "^3.22.4", "zod-to-json-schema": "^3.22.3" }, diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts index db86e6e919..af06da3471 100644 --- a/libs/langchain-openai/src/chat_models.ts +++ b/libs/langchain-openai/src/chat_models.ts @@ -1,4 +1,4 @@ -import { type ClientOptions, OpenAI as OpenAIClient } from "openai"; +import { type ClientOptions, OpenAI as OpenAIClient, } from "openai"; import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; 
import { @@ -299,6 +299,16 @@ export interface ChatOpenAICallOptions * call multiple tools in one response. */ parallel_tool_calls?: boolean; + /** + * If `true`, model output is guaranteed to exactly match the JSON Schema + * provided in the tool definition. + * Enabled by default for `"gpt-"` models. + */ + strict?: boolean; +} + +export interface ChatOpenAIFields extends Partial, Partial, BaseChatModelParams { + configuration?: ClientOptions & LegacyOpenAIInput; } /** @@ -441,12 +451,15 @@ export class ChatOpenAI< protected clientConfig: ClientOptions; + /** + * Whether the model supports the 'strict' argument when passing in tools. + * Defaults to `true` if `modelName`/`model` starts with 'gpt-' otherwise + * defaults to `false`. + */ + supportsStrictToolCalling?: boolean; + constructor( - fields?: Partial & - Partial & - BaseChatModelParams & { - configuration?: ClientOptions & LegacyOpenAIInput; - }, + fields?: ChatOpenAIFields, /** @deprecated */ configuration?: ClientOptions & LegacyOpenAIInput ) { @@ -541,6 +554,12 @@ export class ChatOpenAI< ...configuration, ...fields?.configuration, }; + + // Assume only "gpt-..." models support strict tool calling as of 08/06/24. + this.supportsStrictToolCalling = + fields?.supportsStrictToolCalling !== undefined + ? fields.supportsStrictToolCalling + : this.modelName.startsWith("gpt-"); } getLsParams(options: this["ParsedCallOptions"]): LangSmithParams { @@ -563,8 +582,9 @@ export class ChatOpenAI< )[], kwargs?: Partial ): Runnable { + const strict = kwargs?.strict !== undefined ? kwargs.strict : this.supportsStrictToolCalling; return this.bind({ - tools: tools.map(convertToOpenAITool), + tools: tools.map((tool) => convertToOpenAITool(tool, { strict })), ...kwargs, } as Partial); } @@ -578,6 +598,7 @@ export class ChatOpenAI< streaming?: boolean; } ): Omit { + const strict = options?.strict !== undefined ? 
options.strict : this.supportsStrictToolCalling; function isStructuredToolArray( tools?: unknown[] ): tools is StructuredToolInterface[] { @@ -615,7 +636,7 @@ export class ChatOpenAI< functions: options?.functions, function_call: options?.function_call, tools: isStructuredToolArray(options?.tools) - ? options?.tools.map(convertToOpenAITool) + ? options?.tools.map((tool) => convertToOpenAITool(tool, { strict })) : options?.tools, tool_choice: formatToOpenAIToolChoice(options?.tool_choice), response_format: options?.response_format, diff --git a/libs/langchain-openai/src/tests/chat_models.test.ts b/libs/langchain-openai/src/tests/chat_models.test.ts new file mode 100644 index 0000000000..3624ea148c --- /dev/null +++ b/libs/langchain-openai/src/tests/chat_models.test.ts @@ -0,0 +1,212 @@ +import { z } from "zod"; +import { zodToJsonSchema } from "zod-to-json-schema"; +import { it, expect, describe, beforeAll, afterAll, jest } from "@jest/globals"; +import { ChatOpenAI } from "../chat_models.js"; + + +describe("strict tool calling", () => { + const weatherTool = { + type: "function" as const, + function: { + name: "get_current_weather", + description: "Get the current weather in a location", + parameters: zodToJsonSchema(z.object({ + location: z.string().describe("The location to get the weather for"), + })) + } + } + + // Store the original value of LANGCHAIN_TRACING_V2 + let oldLangChainTracingValue: string | undefined; + // Before all tests, save the current LANGCHAIN_TRACING_V2 value + beforeAll(() => { + oldLangChainTracingValue = process.env.LANGCHAIN_TRACING_V2; + }) + // After all tests, restore the original LANGCHAIN_TRACING_V2 value + afterAll(() => { + if (oldLangChainTracingValue !== undefined) { + process.env.LANGCHAIN_TRACING_V2 = oldLangChainTracingValue; + } else { + // If it was undefined, remove the environment variable + delete process.env.LANGCHAIN_TRACING_V2; + } + }) + + it("Can accept strict as a call arg via .bindTools", async () => { + const 
mockFetch = jest.fn<(url: any, init?: any) => Promise>(); + mockFetch.mockImplementation((url, options): Promise => { + // NOTE(review): jest.fn already records each call in mock.calls and the assertions below read that entry (calls[0]), so this extra push looks redundant - TODO confirm and drop + mockFetch.mock.calls.push({ url, options } as any); + + // Return a mock response + return Promise.resolve({ + ok: true, + json: () => Promise.resolve({}), + }) as Promise; + }); + + const model = new ChatOpenAI({ + model: "gpt-4", + configuration: { + fetch: mockFetch, + }, + maxRetries: 0, + }); + + const modelWithTools = model.bindTools([weatherTool], { strict: true }); + + // This will fail since we're not returning a valid response in our mocked fetch function. + await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); + + expect(mockFetch).toHaveBeenCalled(); + const [_url, options] = mockFetch.mock.calls[0]; + + if (options && options.body) { + expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ + type: "function", + function: { + ...weatherTool.function, + // This should be added to the function call because `strict` was passed to `bindTools` + strict: true, + } + })]); + } else { + throw new Error("Body not found in request.") + } + }); + + it("Can accept strict as a call arg via .bind", async () => { + const mockFetch = jest.fn<(url: any, init?: any) => Promise>(); + mockFetch.mockImplementation((url, options): Promise => { + // NOTE(review): jest.fn already records each call in mock.calls and the assertions below read that entry (calls[0]), so this extra push looks redundant - TODO confirm and drop + mockFetch.mock.calls.push({ url, options } as any); + + // Return a mock response + return Promise.resolve({ + ok: true, + json: () => Promise.resolve({}), + }) as Promise; + }); + + const model = new ChatOpenAI({ + model: "gpt-4", + configuration: { + fetch: mockFetch, + }, + maxRetries: 0, + }); + + const modelWithTools = model.bind({ + tools: [weatherTool], + strict: true + }); + + // This will fail since we're not returning a valid response in our mocked fetch function. 
+ await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); + + expect(mockFetch).toHaveBeenCalled(); + const [_url, options] = mockFetch.mock.calls[0]; + + if (options && options.body) { + expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ + type: "function", + function: { + ...weatherTool.function, + // This should be added to the function call because `strict` was passed to `bind` + strict: true, + } + })]); + } else { + throw new Error("Body not found in request.") + } + }); + + it("Sets strict to true if the model name starts with 'gpt-'", async () => { + const mockFetch = jest.fn<(url: any, init?: any) => Promise>(); + mockFetch.mockImplementation((url, options): Promise => { + // NOTE(review): jest.fn already records each call in mock.calls and the assertions below read that entry (calls[0]), so this extra push looks redundant - TODO confirm and drop + mockFetch.mock.calls.push({ url, options } as any); + + // Return a mock response + return Promise.resolve({ + ok: true, + json: () => Promise.resolve({}), + }) as Promise; + }); + + const model = new ChatOpenAI({ + model: "gpt-4", + configuration: { + fetch: mockFetch, + }, + maxRetries: 0, + }); + + // Do NOT pass `strict` here since we're checking that it's set to true by default + const modelWithTools = model.bindTools([weatherTool]); + + // This will fail since we're not returning a valid response in our mocked fetch function. 
+ await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); + + expect(mockFetch).toHaveBeenCalled(); + const [_url, options] = mockFetch.mock.calls[0]; + + if (options && options.body) { + expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ + type: "function", + function: { + ...weatherTool.function, + // `strict` should be true here by default because the model name starts with 'gpt-' (no `strict` was passed) + strict: true, + } + })]); + } else { + throw new Error("Body not found in request.") + } + }); + + it("Strict is false if supportsStrictToolCalling is false", async () => { + const mockFetch = jest.fn<(url: any, init?: any) => Promise>(); + mockFetch.mockImplementation((url, options): Promise => { + // NOTE(review): jest.fn already records each call in mock.calls and the assertions below read that entry (calls[0]), so this extra push looks redundant - TODO confirm and drop + mockFetch.mock.calls.push({ url, options } as any); + + // Return a mock response + return Promise.resolve({ + ok: true, + json: () => Promise.resolve({}), + }) as Promise; + }); + + const model = new ChatOpenAI({ + model: "gpt-4", + configuration: { + fetch: mockFetch, + }, + maxRetries: 0, + supportsStrictToolCalling: false, + }); + + // Do NOT pass `strict` here since we're checking that it falls back to `supportsStrictToolCalling` (false) + const modelWithTools = model.bindTools([weatherTool]); + + // This will fail since we're not returning a valid response in our mocked fetch function. 
+ await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); + + expect(mockFetch).toHaveBeenCalled(); + const [_url, options] = mockFetch.mock.calls[0]; + + if (options && options.body) { + expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ + type: "function", + function: { + ...weatherTool.function, + // `strict` should be false here because `supportsStrictToolCalling: false` was set on the model and no `strict` was passed + strict: false, + } + })]); + } else { + throw new Error("Body not found in request.") + } + }); +}) diff --git a/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts b/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts index 86bf0247bd..95f379c869 100644 --- a/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts +++ b/libs/langchain-openai/src/tests/chat_models_structured_output.int.test.ts @@ -3,6 +3,7 @@ import { zodToJsonSchema } from "zod-to-json-schema"; import { ChatPromptTemplate } from "@langchain/core/prompts"; import { AIMessage } from "@langchain/core/messages"; import { ChatOpenAI } from "../chat_models.js"; +import { test, expect } from "@jest/globals"; test("withStructuredOutput zod schema function calling", async () => { const model = new ChatOpenAI({ diff --git a/libs/langchain-openai/src/types.ts b/libs/langchain-openai/src/types.ts index 19e6af483d..afd9fea262 100644 --- a/libs/langchain-openai/src/types.ts +++ b/libs/langchain-openai/src/types.ts @@ -155,6 +155,13 @@ export interface OpenAIChatInput extends OpenAIBaseInput { * Currently in experimental beta. */ __includeRawResponse?: boolean; + + /** + * Whether the model supports the 'strict' argument when passing in tools. + * Defaults to `true` if `modelName`/`model` starts with 'gpt-' otherwise + * defaults to `false`. 
+ */ + supportsStrictToolCalling?: boolean; } export declare interface AzureOpenAIInput { diff --git a/yarn.lock b/yarn.lock index c96e976056..075329f22f 100644 --- a/yarn.lock +++ b/yarn.lock @@ -12199,7 +12199,7 @@ __metadata: jest: ^29.5.0 jest-environment-node: ^29.6.4 js-tiktoken: ^1.0.12 - openai: ^4.49.1 + openai: ^4.55.0 prettier: ^2.8.3 release-it: ^17.6.0 rimraf: ^5.0.1 @@ -34040,6 +34040,28 @@ __metadata: languageName: node linkType: hard +"openai@npm:^4.55.0": + version: 4.55.0 + resolution: "openai@npm:4.55.0" + dependencies: + "@types/node": ^18.11.18 + "@types/node-fetch": ^2.6.4 + abort-controller: ^3.0.0 + agentkeepalive: ^4.2.1 + form-data-encoder: 1.7.2 + formdata-node: ^4.3.2 + node-fetch: ^2.6.7 + peerDependencies: + zod: ^3.23.8 + peerDependenciesMeta: + zod: + optional: true + bin: + openai: bin/cli + checksum: b2b1daa976516262e08e182ee982976a1dc615eebd250bbd71f4122740ebeeb207a20af6d35c718b67f1c3457196b524667a0c7fa417ab4e119020b5c1f5cd74 + languageName: node + linkType: hard + "openapi-types@npm:^12.1.3": version: 12.1.3 resolution: "openapi-types@npm:12.1.3" From 717a996f56d5c20c8cb4d08be61837cb8bbdc6bb Mon Sep 17 00:00:00 2001 From: "local-dev-korbit-ai-mentor[bot]" <130798245+local-dev-korbit-ai-mentor[bot]@users.noreply.github.com> Date: Thu, 15 Aug 2024 17:39:01 +0000 Subject: [PATCH 2/2] [skip ci]