openai[minor],core[minor]: Add support for passing strict in openai tools #13

Open · wants to merge 2 commits into base: cloned_main_c5fb8
9 changes: 9 additions & 0 deletions langchain-core/src/language_models/base.ts
@@ -233,6 +233,15 @@ export interface FunctionDefinition {
* how to call the function.
*/
description?: string;

/**
* Whether to enable strict schema adherence when generating the function call. If
* set to true, the model will follow the exact schema defined in the `parameters`
* field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn
* more about Structured Outputs in the
* [function calling guide](https://platform.openai.com/docs/guides/function-calling).
*/
strict?: boolean;
}

export interface ToolDefinition {
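The new optional `strict` field can be exercised with an ordinary tool definition. A minimal sketch (the tool name and schema are illustrative, not taken from this PR):

```typescript
// Illustrative tool definition using the new optional `strict` flag.
// Note: OpenAI's structured outputs require `additionalProperties: false`
// and every property listed in `required` when `strict` is true.
const getWeatherTool = {
  type: "function" as const,
  function: {
    name: "get_current_weather",
    description: "Get the current weather in a location",
    parameters: {
      type: "object",
      properties: { location: { type: "string" } },
      required: ["location"],
      additionalProperties: false,
    },
    strict: true,
  },
};
```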
19 changes: 17 additions & 2 deletions langchain-core/src/utils/function_calling.ts
@@ -34,14 +34,29 @@ export function convertToOpenAIFunction(
*/
export function convertToOpenAITool(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
tool: StructuredToolInterface | Record<string, any> | RunnableToolLike
tool: StructuredToolInterface | Record<string, any> | RunnableToolLike,
fields?: {
/**
* If `true`, model output is guaranteed to exactly match the JSON Schema
* provided in the function definition.
*/
strict?: boolean;
}
): ToolDefinition {
let toolDef: ToolDefinition | undefined;
if (isStructuredTool(tool) || isRunnableToolLike(tool)) {
return {
toolDef = {
type: "function",
function: convertToOpenAIFunction(tool),
};
} else {
toolDef = tool as ToolDefinition;
}

if (fields?.strict !== undefined) {
toolDef.function.strict = fields.strict;
}

Comment on lines +56 to +59

Fix the return value to include modifications.

The function is currently returning the original `tool` parameter instead of the modified `toolDef`. Ensure that the return statement reflects the intended changes, including the `strict` parameter:

-  return tool as ToolDefinition;
+  return toolDef;

return tool as ToolDefinition;
category Functionality · severity potentially major

The `convertToOpenAITool` function is not returning the modified `toolDef` object. Instead, it returns the original `tool` parameter, which means any modifications made to `toolDef` (including setting the `strict` field) are lost. To fix this, replace the last line `return tool as ToolDefinition;` with `return toolDef;`.

}

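As both reviewers note, the function must return `toolDef` rather than the original `tool`. A self-contained sketch of the intended behavior (simplified to plain `ToolDefinition` inputs, and copying rather than mutating the argument; the real function also accepts `StructuredToolInterface` and `RunnableToolLike`):

```typescript
interface FunctionDefinition {
  name: string;
  description?: string;
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  parameters?: Record<string, any>;
  strict?: boolean;
}

interface ToolDefinition {
  type: "function";
  function: FunctionDefinition;
}

function convertToOpenAIToolSketch(
  tool: ToolDefinition,
  fields?: { strict?: boolean }
): ToolDefinition {
  // Copy so the caller's object is not mutated.
  const toolDef: ToolDefinition = { ...tool, function: { ...tool.function } };
  if (fields?.strict !== undefined) {
    toolDef.function.strict = fields.strict;
  }
  // Return the modified copy -- not the original `tool`.
  return toolDef;
}
```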
2 changes: 1 addition & 1 deletion libs/langchain-openai/package.json
@@ -37,7 +37,7 @@
"dependencies": {
"@langchain/core": ">=0.2.16 <0.3.0",
"js-tiktoken": "^1.0.12",
"openai": "^4.49.1",
"openai": "^4.55.0",
"zod": "^3.22.4",
"zod-to-json-schema": "^3.22.3"
},
37 changes: 29 additions & 8 deletions libs/langchain-openai/src/chat_models.ts
@@ -1,4 +1,4 @@
import { type ClientOptions, OpenAI as OpenAIClient } from "openai";
import { type ClientOptions, OpenAI as OpenAIClient, } from "openai";

import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
@@ -299,6 +299,16 @@ export interface ChatOpenAICallOptions
* call multiple tools in one response.
*/
parallel_tool_calls?: boolean;
/**
* If `true`, model output is guaranteed to exactly match the JSON Schema
* provided in the tool definition.
* Enabled by default for `"gpt-"` models.
*/
strict?: boolean;
}

export interface ChatOpenAIFields extends Partial<OpenAIChatInput>, Partial<AzureOpenAIInput>, BaseChatModelParams {
configuration?: ClientOptions & LegacyOpenAIInput;
}

@@ -441,12 +451,15 @@ export class ChatOpenAI<

protected clientConfig: ClientOptions;

/**
* Whether the model supports the 'strict' argument when passing in tools.
* Defaults to `true` if `modelName`/`model` starts with 'gpt-' otherwise
* defaults to `false`.
*/
supportsStrictToolCalling?: boolean;

constructor(
fields?: Partial<OpenAIChatInput> &
Partial<AzureOpenAIInput> &
BaseChatModelParams & {
configuration?: ClientOptions & LegacyOpenAIInput;
},
fields?: ChatOpenAIFields,
/** @deprecated */
configuration?: ClientOptions & LegacyOpenAIInput
) {
Expand Down Expand Up @@ -541,6 +554,12 @@ export class ChatOpenAI<
...configuration,
...fields?.configuration,
};

// Assume only "gpt-..." models support strict tool calling as of 08/06/24.
this.supportsStrictToolCalling =
fields?.supportsStrictToolCalling !== undefined
? fields.supportsStrictToolCalling
: this.modelName.startsWith("gpt-");
}

getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
@@ -563,8 +582,9 @@
)[],
kwargs?: Partial<CallOptions>
): Runnable<BaseLanguageModelInput, AIMessageChunk, CallOptions> {
const strict = kwargs?.strict !== undefined ? kwargs.strict : this.supportsStrictToolCalling;
return this.bind({
tools: tools.map(convertToOpenAITool),
tools: tools.map((tool) => convertToOpenAITool(tool, { strict })),
...kwargs,
} as Partial<CallOptions>);
}
@@ -578,6 +598,7 @@
streaming?: boolean;
}
): Omit<OpenAIClient.Chat.ChatCompletionCreateParams, "messages"> {
const strict = options?.strict !== undefined ? options.strict : this.supportsStrictToolCalling;
function isStructuredToolArray(
tools?: unknown[]
): tools is StructuredToolInterface[] {
@@ -615,7 +636,7 @@
functions: options?.functions,
function_call: options?.function_call,
tools: isStructuredToolArray(options?.tools)
? options?.tools.map(convertToOpenAITool)
? options?.tools.map((tool) => convertToOpenAITool(tool, { strict }))
: options?.tools,
tool_choice: formatToOpenAIToolChoice(options?.tool_choice),
response_format: options?.response_format,
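The precedence implemented above — an explicit call-time `strict` wins, otherwise fall back to `supportsStrictToolCalling`, which itself defaults from the model name — can be sketched in isolation (the function name here is illustrative, not part of the PR):

```typescript
// Sketch of how `strict` is resolved for a request.
function resolveStrict(
  callStrict: boolean | undefined,
  modelName: string,
  supportsStrictToolCalling?: boolean
): boolean | undefined {
  // Constructor default: assume only "gpt-..." models support strict tool calling.
  const supports =
    supportsStrictToolCalling !== undefined
      ? supportsStrictToolCalling
      : modelName.startsWith("gpt-");
  // A call-time option takes precedence over the model-level default.
  return callStrict !== undefined ? callStrict : supports;
}
```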
212 changes: 212 additions & 0 deletions libs/langchain-openai/src/tests/chat_models.test.ts
@@ -0,0 +1,212 @@
import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";
import { it, expect, describe, beforeAll, afterAll, jest } from "@jest/globals";
import { ChatOpenAI } from "../chat_models.js";


describe("strict tool calling", () => {
const weatherTool = {
type: "function" as const,
function: {
name: "get_current_weather",
description: "Get the current weather in a location",
parameters: zodToJsonSchema(z.object({
location: z.string().describe("The location to get the weather for"),
}))
}
}

// Store the original value of LANGCHAIN_TRACING_V2
let oldLangChainTracingValue: string | undefined;
// Before all tests, save the current LANGCHAIN_TRACING_V2 value
beforeAll(() => {
oldLangChainTracingValue = process.env.LANGCHAIN_TRACING_V2;
})
// After all tests, restore the original LANGCHAIN_TRACING_V2 value
afterAll(() => {
if (oldLangChainTracingValue !== undefined) {
process.env.LANGCHAIN_TRACING_V2 = oldLangChainTracingValue;
} else {
// If it was undefined, remove the environment variable
delete process.env.LANGCHAIN_TRACING_V2;
}
})

it("Can accept strict as a call arg via .bindTools", async () => {
const mockFetch = jest.fn<(url: any, init?: any) => Promise<any>>();
mockFetch.mockImplementation((url, options): Promise<any> => {
// Store the request details for later inspection
mockFetch.mock.calls.push({ url, options } as any);

// Return a mock response
return Promise.resolve({
ok: true,
json: () => Promise.resolve({}),
}) as Promise<any>;
});

const model = new ChatOpenAI({
model: "gpt-4",
configuration: {
fetch: mockFetch,
},
maxRetries: 0,
});

const modelWithTools = model.bindTools([weatherTool], { strict: true });

// This will fail since we're not returning a valid response in our mocked fetch function.
await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow();
Comment on lines +58 to +59

category Functionality · severity potentially major

The test cases are currently expecting the `modelWithTools.invoke()` call to throw an error, but the mock implementation is returning a successful response. This mismatch could lead to false positives in your tests. Consider either:

1. Modifying the mock to return an error response:
   `mockFetch.mockRejectedValue(new Error('Mock API error'));`
2. Or, if the intention is to test successful responses, updating the test expectation:
   `await expect(modelWithTools.invoke("What's the weather like?")).resolves.not.toThrow();`

Additionally, add more specific error assertions to ensure you're catching the expected errors. This will make your tests more robust and informative.


expect(mockFetch).toHaveBeenCalled();
const [_url, options] = mockFetch.mock.calls[0];

if (options && options.body) {
expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({
type: "function",
function: {
...weatherTool.function,
// This should be added to the function call because `strict` was passed to `bindTools`
strict: true,
}
})]);
} else {
throw new Error("Body not found in request.")
}
});

it("Can accept strict as a call arg via .bind", async () => {
const mockFetch = jest.fn<(url: any, init?: any) => Promise<any>>();
mockFetch.mockImplementation((url, options): Promise<any> => {
// Store the request details for later inspection
mockFetch.mock.calls.push({ url, options } as any);

// Return a mock response
return Promise.resolve({
ok: true,
json: () => Promise.resolve({}),
}) as Promise<any>;
});

const model = new ChatOpenAI({
model: "gpt-4",
configuration: {
fetch: mockFetch,
},
maxRetries: 0,
});

const modelWithTools = model.bind({
tools: [weatherTool],
strict: true
});

// This will fail since we're not returning a valid response in our mocked fetch function.
await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow();

expect(mockFetch).toHaveBeenCalled();
const [_url, options] = mockFetch.mock.calls[0];

if (options && options.body) {
expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({
type: "function",
function: {
...weatherTool.function,
// This should be added to the function call because `strict` was passed to `bind`
strict: true,
}
})]);
} else {
throw new Error("Body not found in request.")
}
});

it("Sets strict to true if the model name starts with 'gpt-'", async () => {
const mockFetch = jest.fn<(url: any, init?: any) => Promise<any>>();
mockFetch.mockImplementation((url, options): Promise<any> => {
// Store the request details for later inspection
mockFetch.mock.calls.push({ url, options } as any);

// Return a mock response
return Promise.resolve({
ok: true,
json: () => Promise.resolve({}),
}) as Promise<any>;
});

const model = new ChatOpenAI({
model: "gpt-4",
configuration: {
fetch: mockFetch,
},
maxRetries: 0,
});

// Do NOT pass `strict` here since we're checking that it's set to true by default
const modelWithTools = model.bindTools([weatherTool]);

// This will fail since we're not returning a valid response in our mocked fetch function.
await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow();

expect(mockFetch).toHaveBeenCalled();
const [_url, options] = mockFetch.mock.calls[0];

if (options && options.body) {
expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({
type: "function",
function: {
...weatherTool.function,
// This should be added to the function call because the model name starts with "gpt-"
strict: true,
}
})]);
} else {
throw new Error("Body not found in request.")
}
});

it("Strict is false if supportsStrictToolCalling is false", async () => {
const mockFetch = jest.fn<(url: any, init?: any) => Promise<any>>();
mockFetch.mockImplementation((url, options): Promise<any> => {
// Store the request details for later inspection
mockFetch.mock.calls.push({ url, options } as any);

// Return a mock response
return Promise.resolve({
ok: true,
json: () => Promise.resolve({}),
}) as Promise<any>;
});

const model = new ChatOpenAI({
model: "gpt-4",
configuration: {
fetch: mockFetch,
},
maxRetries: 0,
supportsStrictToolCalling: false,
});

// Do NOT pass `strict` here since we're checking that it defaults to `false`
const modelWithTools = model.bindTools([weatherTool]);

// This will fail since we're not returning a valid response in our mocked fetch function.
await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow();

expect(mockFetch).toHaveBeenCalled();
const [_url, options] = mockFetch.mock.calls[0];

if (options && options.body) {
expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({
type: "function",
function: {
...weatherTool.function,
// This should be `false` because `supportsStrictToolCalling` is `false`
strict: false,
}
})]);
} else {
throw new Error("Body not found in request.")
}
});
})
@@ -3,6 +3,7 @@ import { zodToJsonSchema } from "zod-to-json-schema";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { AIMessage } from "@langchain/core/messages";
import { ChatOpenAI } from "../chat_models.js";
import { test, expect } from "@jest/globals";

test("withStructuredOutput zod schema function calling", async () => {
const model = new ChatOpenAI({
7 changes: 7 additions & 0 deletions libs/langchain-openai/src/types.ts
@@ -155,6 +155,13 @@ export interface OpenAIChatInput extends OpenAIBaseInput {
* Currently in experimental beta.
*/
__includeRawResponse?: boolean;

/**
* Whether the model supports the 'strict' argument when passing in tools.
* Defaults to `true` if `modelName`/`model` starts with 'gpt-' otherwise
* defaults to `false`.
*/
supportsStrictToolCalling?: boolean;
Comment on lines +159 to +164

category Functionality

The new `supportsStrictToolCalling` property has been added, which is great for extending functionality. However, it would be beneficial to add a comment explaining when a user might want to override the default behavior. For example: "Set this to `true` for models that support strict tool calling, or `false` for those that don't. Only override if you're certain about the model's capabilities."
}

export declare interface AzureOpenAIInput {