-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
openai[minor],core[minor]: Add support for passing strict in openai tools #13
base: cloned_main_c5fb8
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,14 +34,29 @@ export function convertToOpenAIFunction( | |
*/ | ||
export function convertToOpenAITool( | ||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
tool: StructuredToolInterface | Record<string, any> | RunnableToolLike | ||
tool: StructuredToolInterface | Record<string, any> | RunnableToolLike, | ||
fields?: { | ||
/** | ||
* If `true`, model output is guaranteed to exactly match the JSON Schema | ||
* provided in the function definition. | ||
*/ | ||
strict?: boolean; | ||
} | ||
): ToolDefinition { | ||
let toolDef: ToolDefinition | undefined; | ||
if (isStructuredTool(tool) || isRunnableToolLike(tool)) { | ||
return { | ||
toolDef = { | ||
type: "function", | ||
function: convertToOpenAIFunction(tool), | ||
}; | ||
} else { | ||
toolDef = tool as ToolDefinition; | ||
} | ||
|
||
if (fields?.strict !== undefined) { | ||
toolDef.function.strict = fields.strict; | ||
} | ||
|
||
return tool as ToolDefinition; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
|
||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
import { z } from "zod"; | ||
import { zodToJsonSchema } from "zod-to-json-schema"; | ||
import { it, expect, describe, beforeAll, afterAll, jest } from "@jest/globals"; | ||
import { ChatOpenAI } from "../chat_models.js"; | ||
|
||
|
||
describe("strict tool calling", () => { | ||
const weatherTool = { | ||
type: "function" as const, | ||
function: { | ||
name: "get_current_weather", | ||
description: "Get the current weather in a location", | ||
parameters: zodToJsonSchema(z.object({ | ||
location: z.string().describe("The location to get the weather for"), | ||
})) | ||
} | ||
} | ||
|
||
// Store the original value of LANGCHAIN_TRACING_V2 | ||
let oldLangChainTracingValue: string | undefined; | ||
// Before all tests, save the current LANGCHAIN_TRACING_V2 value | ||
beforeAll(() => { | ||
oldLangChainTracingValue = process.env.LANGCHAIN_TRACING_V2; | ||
}) | ||
// After all tests, restore the original LANGCHAIN_TRACING_V2 value | ||
afterAll(() => { | ||
if (oldLangChainTracingValue !== undefined) { | ||
process.env.LANGCHAIN_TRACING_V2 = oldLangChainTracingValue; | ||
} else { | ||
// If it was undefined, remove the environment variable | ||
delete process.env.LANGCHAIN_TRACING_V2; | ||
} | ||
}) | ||
|
||
it("Can accept strict as a call arg via .bindTools", async () => { | ||
const mockFetch = jest.fn<(url: any, init?: any) => Promise<any>>(); | ||
mockFetch.mockImplementation((url, options): Promise<any> => { | ||
// Store the request details for later inspection | ||
mockFetch.mock.calls.push({ url, options } as any); | ||
|
||
// Return a mock response | ||
return Promise.resolve({ | ||
ok: true, | ||
json: () => Promise.resolve({}), | ||
}) as Promise<any>; | ||
}); | ||
|
||
const model = new ChatOpenAI({ | ||
model: "gpt-4", | ||
configuration: { | ||
fetch: mockFetch, | ||
}, | ||
maxRetries: 0, | ||
}); | ||
|
||
const modelWithTools = model.bindTools([weatherTool], { strict: true }); | ||
|
||
// This will fail since we're not returning a valid response in our mocked fetch function. | ||
await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); | ||
Comment on lines
+58
to
+59
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The test cases are currently expecting the
mockFetch.mockRejectedValue(new Error('Mock API error'));
await expect(modelWithTools.invoke("What's the weather like?")).resolves.not.toThrow(); Additionally, add more specific error assertions to ensure you're catching the expected errors. This will make your tests more robust and informative.
|
||
|
||
expect(mockFetch).toHaveBeenCalled(); | ||
const [_url, options] = mockFetch.mock.calls[0]; | ||
|
||
if (options && options.body) { | ||
expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ | ||
type: "function", | ||
function: { | ||
...weatherTool.function, | ||
// This should be added to the function call because `strict` was passed to `bindTools` | ||
strict: true, | ||
} | ||
})]); | ||
} else { | ||
throw new Error("Body not found in request.") | ||
} | ||
}); | ||
|
||
it("Can accept strict as a call arg via .bind", async () => { | ||
const mockFetch = jest.fn<(url: any, init?: any) => Promise<any>>(); | ||
mockFetch.mockImplementation((url, options): Promise<any> => { | ||
// Store the request details for later inspection | ||
mockFetch.mock.calls.push({ url, options } as any); | ||
|
||
// Return a mock response | ||
return Promise.resolve({ | ||
ok: true, | ||
json: () => Promise.resolve({}), | ||
}) as Promise<any>; | ||
}); | ||
|
||
const model = new ChatOpenAI({ | ||
model: "gpt-4", | ||
configuration: { | ||
fetch: mockFetch, | ||
}, | ||
maxRetries: 0, | ||
}); | ||
|
||
const modelWithTools = model.bind({ | ||
tools: [weatherTool], | ||
strict: true | ||
}); | ||
|
||
// This will fail since we're not returning a valid response in our mocked fetch function. | ||
await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); | ||
|
||
expect(mockFetch).toHaveBeenCalled(); | ||
const [_url, options] = mockFetch.mock.calls[0]; | ||
|
||
if (options && options.body) { | ||
expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ | ||
type: "function", | ||
function: { | ||
...weatherTool.function, | ||
// This should be added to the function call because `strict` was passed to `bind` | ||
strict: true, | ||
} | ||
})]); | ||
} else { | ||
throw new Error("Body not found in request.") | ||
} | ||
}); | ||
|
||
it("Sets strict to true if the model name starts with 'gpt-'", async () => { | ||
const mockFetch = jest.fn<(url: any, init?: any) => Promise<any>>(); | ||
mockFetch.mockImplementation((url, options): Promise<any> => { | ||
// Store the request details for later inspection | ||
mockFetch.mock.calls.push({ url, options } as any); | ||
|
||
// Return a mock response | ||
return Promise.resolve({ | ||
ok: true, | ||
json: () => Promise.resolve({}), | ||
}) as Promise<any>; | ||
}); | ||
|
||
const model = new ChatOpenAI({ | ||
model: "gpt-4", | ||
configuration: { | ||
fetch: mockFetch, | ||
}, | ||
maxRetries: 0, | ||
}); | ||
|
||
// Do NOT pass `strict` here since we're checking that it's set to true by default | ||
const modelWithTools = model.bindTools([weatherTool]); | ||
|
||
// This will fail since we're not returning a valid response in our mocked fetch function. | ||
await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); | ||
|
||
expect(mockFetch).toHaveBeenCalled(); | ||
const [_url, options] = mockFetch.mock.calls[0]; | ||
|
||
if (options && options.body) { | ||
expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ | ||
type: "function", | ||
function: { | ||
...weatherTool.function, | ||
// This should be added to the function call because `strict` was passed to `bind` | ||
strict: true, | ||
} | ||
})]); | ||
} else { | ||
throw new Error("Body not found in request.") | ||
} | ||
}); | ||
|
||
it("Strict is false if supportsStrictToolCalling is false", async () => { | ||
const mockFetch = jest.fn<(url: any, init?: any) => Promise<any>>(); | ||
mockFetch.mockImplementation((url, options): Promise<any> => { | ||
// Store the request details for later inspection | ||
mockFetch.mock.calls.push({ url, options } as any); | ||
|
||
// Return a mock response | ||
return Promise.resolve({ | ||
ok: true, | ||
json: () => Promise.resolve({}), | ||
}) as Promise<any>; | ||
}); | ||
|
||
const model = new ChatOpenAI({ | ||
model: "gpt-4", | ||
configuration: { | ||
fetch: mockFetch, | ||
}, | ||
maxRetries: 0, | ||
supportsStrictToolCalling: false, | ||
}); | ||
|
||
// Do NOT pass `strict` here since we're checking that it's set to true by default | ||
const modelWithTools = model.bindTools([weatherTool]); | ||
|
||
// This will fail since we're not returning a valid response in our mocked fetch function. | ||
await expect(modelWithTools.invoke("What's the weather like?")).rejects.toThrow(); | ||
|
||
expect(mockFetch).toHaveBeenCalled(); | ||
const [_url, options] = mockFetch.mock.calls[0]; | ||
|
||
if (options && options.body) { | ||
expect(JSON.parse(options.body).tools).toEqual([expect.objectContaining({ | ||
type: "function", | ||
function: { | ||
...weatherTool.function, | ||
// This should be added to the function call because `strict` was passed to `bind` | ||
strict: false, | ||
} | ||
})]); | ||
} else { | ||
throw new Error("Body not found in request.") | ||
} | ||
}); | ||
}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -155,6 +155,13 @@ export interface OpenAIChatInput extends OpenAIBaseInput { | |
* Currently in experimental beta. | ||
*/ | ||
__includeRawResponse?: boolean; | ||
|
||
/** | ||
* Whether the model supports the 'strict' argument when passing in tools. | ||
* Defaults to `true` if `modelName`/`model` starts with 'gpt-' otherwise | ||
* defaults to `false`. | ||
*/ | ||
supportsStrictToolCalling?: boolean; | ||
Comment on lines
+159
to
+164
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new
|
||
} | ||
|
||
export declare interface AzureOpenAIInput { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix the return value to include modifications.
The function is currently returning the original
tool
parameter instead of the modifiedtoolDef
. Ensure that the return statement reflects the intended changes, including thestrict
parameter.Committable suggestion