From ff1a60bafc3633e04fa3ef433c87a98d40eb6541 Mon Sep 17 00:00:00 2001
From: Yuki Watanabe <31463517+B-Step62@users.noreply.github.com>
Date: Fri, 18 Oct 2024 11:06:15 +0900
Subject: [PATCH] Support structured output in ChatDatabricks (#28)

Signed-off-by: B-Step62
---
 .../langchain_databricks/chat_models.py    | 229 +++++++++++++++++-
 .../integration_tests/test_chat_models.py  |  63 ++++-
 .../tests/unit_tests/test_chat_models.py   |  46 ++++
 3 files changed, 335 insertions(+), 3 deletions(-)

diff --git a/libs/databricks/langchain_databricks/chat_models.py b/libs/databricks/langchain_databricks/chat_models.py
index 1c6f609..d0c54a0 100644
--- a/libs/databricks/langchain_databricks/chat_models.py
+++ b/libs/databricks/langchain_databricks/chat_models.py
@@ -2,6 +2,7 @@
 
 import json
 import logging
+from operator import itemgetter
 from typing import (
     Any,
     Callable,
@@ -35,14 +36,19 @@
     ToolMessageChunk,
 )
 from langchain_core.messages.tool import tool_call_chunk
+from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
+from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.output_parsers.openai_tools import (
+    JsonOutputKeyToolsParser,
+    PydanticToolsParser,
     make_invalid_tool_call,
     parse_tool_call,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.runnables import Runnable
+from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
+from langchain_core.utils.pydantic import is_basemodel_subclass
 from mlflow.deployments import BaseDeploymentClient  # type: ignore
 from pydantic import BaseModel, Field
 
@@ -398,6 +404,227 @@ def bind_tools(
                 kwargs["tool_choice"] = tool_choice
         return super().bind(tools=formatted_tools, **kwargs)
 
+    def with_structured_output(
+        self,
+        schema: Optional[Union[Dict, Type]] = None,
+        *,
+        method: Literal["function_calling", "json_mode"] = "function_calling",
+        include_raw: bool = False,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
+        """Model wrapper that returns outputs formatted to match the given schema.
+
+        Assumes the model is compatible with the OpenAI tool-calling API.
+
+        Args:
+            schema: The output schema as a dict or a Pydantic class. If a Pydantic class
+                then the model output will be an object of that class. If a dict then
+                the model output will be a dict. With a Pydantic class the returned
+                attributes will be validated, whereas with a dict they will not be. If
+                `method` is "function_calling" and `schema` is a dict, then the dict
+                must match the OpenAI function-calling spec or be a valid JSON schema
+                with top-level 'title' and 'description' keys specified.
+            method: The method for steering model generation, either "function_calling"
+                or "json_mode". If "function_calling" then the schema will be converted
+                to an OpenAI function and the returned model will make use of the
+                function-calling API. If "json_mode" then OpenAI's JSON mode will be
+                used. Note that if using "json_mode" then you must include instructions
+                for formatting the output into the desired schema in the model call.
+            include_raw: If False then only the parsed structured output is returned. If
+                an error occurs during model output parsing it will be raised. If True
+                then both the raw model response (a BaseMessage) and the parsed model
+                response will be returned.
+                If an error occurs during output parsing it will be caught and
+                returned as well. The final output is always a dict with keys
+                "raw", "parsed", and "parsing_error".
+
+        Returns:
+            A Runnable that takes any ChatModel input and returns as output:
+
+            If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs
+            an instance of ``schema`` (i.e., a Pydantic object).
+
+            Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
+
+            If ``include_raw`` is True, then Runnable outputs a dict with keys:
+                - ``"raw"``: BaseMessage
+                - ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
+                - ``"parsing_error"``: Optional[BaseException]
+
+        Example: Function-calling, Pydantic schema (method="function_calling", include_raw=False):
+            .. code-block:: python
+
+                from langchain_databricks import ChatDatabricks
+                from pydantic import BaseModel
+
+
+                class AnswerWithJustification(BaseModel):
+                    '''An answer to the user question along with justification for the answer.'''
+
+                    answer: str
+                    justification: str
+
+
+                llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct")
+                structured_llm = llm.with_structured_output(AnswerWithJustification)
+
+                structured_llm.invoke(
+                    "What weighs more a pound of bricks or a pound of feathers"
+                )
+
+                # -> AnswerWithJustification(
+                #     answer='They weigh the same',
+                #     justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
+                # )
+
+        Example: Function-calling, Pydantic schema (method="function_calling", include_raw=True):
+            .. code-block:: python
+
+                from langchain_databricks import ChatDatabricks
+                from pydantic import BaseModel
+
+
+                class AnswerWithJustification(BaseModel):
+                    '''An answer to the user question along with justification for the answer.'''
+
+                    answer: str
+                    justification: str
+
+
+                llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct")
+                structured_llm = llm.with_structured_output(
+                    AnswerWithJustification, include_raw=True
+                )
+
+                structured_llm.invoke(
+                    "What weighs more a pound of bricks or a pound of feathers"
+                )
+                # -> {
+                #     'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
+                #     'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
+                #     'parsing_error': None
+                # }
+
+        Example: Function-calling, dict schema (method="function_calling", include_raw=False):
+            .. code-block:: python
+
+                from langchain_databricks import ChatDatabricks
+                from langchain_core.utils.function_calling import convert_to_openai_tool
+                from pydantic import BaseModel
+
+
+                class AnswerWithJustification(BaseModel):
+                    '''An answer to the user question along with justification for the answer.'''
+
+                    answer: str
+                    justification: str
+
+
+                dict_schema = convert_to_openai_tool(AnswerWithJustification)
+                llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct")
+                structured_llm = llm.with_structured_output(dict_schema)
+
+                structured_llm.invoke(
+                    "What weighs more a pound of bricks or a pound of feathers"
+                )
+                # -> {
+                #     'answer': 'They weigh the same',
+                #     'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
+                # }
+
+        Example: JSON mode, Pydantic schema (method="json_mode", include_raw=True):
+            .. code-block:: python
+
+                from langchain_databricks import ChatDatabricks
+                from pydantic import BaseModel
+
+                class AnswerWithJustification(BaseModel):
+                    answer: str
+                    justification: str
+
+                llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct")
+                structured_llm = llm.with_structured_output(
+                    AnswerWithJustification,
+                    method="json_mode",
+                    include_raw=True
+                )
+
+                structured_llm.invoke(
+                    "Answer the following question. "
+                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
+                    "What's heavier a pound of bricks or a pound of feathers?"
+                )
+                # -> {
+                #     'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
+                #     'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),
+                #     'parsing_error': None
+                # }
+
+        Example: JSON mode, no schema (schema=None, method="json_mode", include_raw=True):
+            .. code-block:: python
+
+                structured_llm = llm.with_structured_output(method="json_mode", include_raw=True)
+
+                structured_llm.invoke(
+                    "Answer the following question. "
+                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
+                    "What's heavier a pound of bricks or a pound of feathers?"
+                )
+                # -> {
+                #     'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
+                #     'parsed': {
+                #         'answer': 'They are both the same weight.',
+                #         'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'
+                #     },
+                #     'parsing_error': None
+                # }
+
+
+        """  # noqa: E501
+        if kwargs:
+            raise ValueError(f"Received unsupported arguments {kwargs}")
+        is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
+        if method == "function_calling":
+            if schema is None:
+                raise ValueError(
+                    "schema must be specified when method is 'function_calling'. "
+                    "Received None."
+                )
+            tool_name = convert_to_openai_tool(schema)["function"]["name"]
+            llm = self.bind_tools([schema], tool_choice=tool_name)
+            if is_pydantic_schema:
+                output_parser: OutputParserLike = PydanticToolsParser(
+                    tools=[schema],  # type: ignore[list-item]
+                    first_tool_only=True,  # type: ignore[list-item]
+                )
+            else:
+                output_parser = JsonOutputKeyToolsParser(
+                    key_name=tool_name, first_tool_only=True
+                )
+        elif method == "json_mode":
+            llm = self.bind(response_format={"type": "json_object"})
+            output_parser = (
+                PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
+                if is_pydantic_schema
+                else JsonOutputParser()
+            )
+        else:
+            raise ValueError(
+                f"Unrecognized method argument. Expected one of 'function_calling' or "
+                f"'json_mode'. Received: '{method}'"
+            )
+
+        if include_raw:
+            parser_assign = RunnablePassthrough.assign(
+                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
+            )
+            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
+            parser_with_fallback = parser_assign.with_fallbacks(
+                [parser_none], exception_key="parsing_error"
+            )
+            return RunnableMap(raw=llm) | parser_with_fallback
+        else:
+            return llm | output_parser
+
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         return self._default_params
diff --git a/libs/databricks/tests/integration_tests/test_chat_models.py b/libs/databricks/tests/integration_tests/test_chat_models.py
index f70ab6e..0a7f491 100644
--- a/libs/databricks/tests/integration_tests/test_chat_models.py
+++ b/libs/databricks/tests/integration_tests/test_chat_models.py
@@ -29,6 +29,7 @@
 from langgraph.graph import START, StateGraph
 from langgraph.graph.message import add_messages
 from langgraph.prebuilt import ToolNode, create_react_agent, tools_condition
+from pydantic import BaseModel, Field
 from typing_extensions import TypedDict
 
 from langchain_databricks.chat_models import ChatDatabricks
@@ -164,8 +165,6 @@ async def test_chat_databricks_abatch():
 
 @pytest.mark.parametrize("tool_choice", [None, "auto", "required", "any", "none"])
 def test_chat_databricks_tool_calls(tool_choice):
-    from pydantic import BaseModel, Field
-
     chat = ChatDatabricks(
         endpoint=_TEST_ENDPOINT,
         temperature=0,
@@ -219,6 +218,66 @@ class GetWeather(BaseModel):
     ]
 
 
+# Pydantic-based schema
+class AnswerWithJustification(BaseModel):
+    """An answer to the user question along with justification for the answer."""
+
+    answer: str = Field(description="The answer to the user question.")
+    justification: str = Field(description="The justification for the answer.")
+
+
+# Raw JSON schema
+JSON_SCHEMA = {
+    "title": "AnswerWithJustification",
+    "description": "An answer to the user question along with justification.",
+    "type": "object",
+    "properties": {
+        "answer": {
+            "type": "string",
+            "description": "The answer to the user question.",
+        },
+        "justification": {
+            "type": "string",
+            "description": "The justification for the answer.",
+        },
+    },
+    "required": ["answer", "justification"],
+}
+
+
+@pytest.mark.parametrize("schema", [AnswerWithJustification, JSON_SCHEMA, None])
+@pytest.mark.parametrize("method", ["function_calling", "json_mode"])
+def test_chat_databricks_with_structured_output(schema, method):
+    llm = ChatDatabricks(endpoint=_TEST_ENDPOINT)
+
+    if schema is None and method == "function_calling":
+        pytest.skip("Cannot use function_calling without schema")
+
+    structured_llm = llm.with_structured_output(schema, method=method)
+
+    if method == "function_calling":
+        prompt = "What day comes two days after Monday?"
+    else:
+        prompt = (
+            "What day comes two days after Monday? Return in JSON format with key "
+            "'answer' for the answer and 'justification' for the justification."
+        )
+
+    response = structured_llm.invoke(prompt)
+
+    if schema == AnswerWithJustification:
+        assert response.answer == "Wednesday"
+        assert response.justification is not None
+    else:
+        assert response["answer"] == "Wednesday"
+        assert response["justification"] is not None
+
+    # Invoke with raw output
+    structured_llm = llm.with_structured_output(schema, method=method, include_raw=True)
+    response_with_raw = structured_llm.invoke(prompt)
+    assert isinstance(response_with_raw["raw"], AIMessage)
+
+
 def test_chat_databricks_runnable_sequence():
     chat = ChatDatabricks(
         endpoint=_TEST_ENDPOINT,
diff --git a/libs/databricks/tests/unit_tests/test_chat_models.py b/libs/databricks/tests/unit_tests/test_chat_models.py
index 4848f11..37579da 100644
--- a/libs/databricks/tests/unit_tests/test_chat_models.py
+++ b/libs/databricks/tests/unit_tests/test_chat_models.py
@@ -20,6 +20,7 @@
     ToolMessageChunk,
 )
 from langchain_core.messages.tool import ToolCallChunk
+from langchain_core.runnables import RunnableMap
 from pydantic import BaseModel, Field
 
 from langchain_databricks.chat_models import (
@@ -241,6 +242,51 @@ def test_chat_model_bind_tolls_with_invalid_choices(llm: ChatDatabricks) -> None
     )
 
 
+# Pydantic-based schema
+class AnswerWithJustification(BaseModel):
+    """An answer to the user question along with justification for the answer."""
+
+    answer: str = Field(description="The answer to the user question.")
+    justification: str = Field(description="The justification for the answer.")
+
+
+# Raw JSON schema
+JSON_SCHEMA = {
+    "title": "AnswerWithJustification",
+    "description": "An answer to the user question along with justification.",
+    "type": "object",
+    "properties": {
+        "answer": {
+            "type": "string",
+            "description": "The answer to the user question.",
+        },
+        "justification": {
+            "type": "string",
+            "description": "The justification for the answer.",
+        },
+    },
+    "required": ["answer", "justification"],
+}
+
+
+@pytest.mark.parametrize("schema", [AnswerWithJustification, JSON_SCHEMA, None])
+@pytest.mark.parametrize("method", ["function_calling", "json_mode"])
+def test_chat_model_with_structured_output(llm, schema, method: str):
+    if schema is None and method == "function_calling":
+        pytest.skip("Cannot use function_calling without schema")
+
+    structured_llm = llm.with_structured_output(schema, method=method)
+
+    bind = structured_llm.first.kwargs
+    if method == "function_calling":
+        assert bind["tool_choice"]["function"]["name"] == "AnswerWithJustification"
+    else:
+        assert bind["response_format"] == {"type": "json_object"}
+
+    structured_llm = llm.with_structured_output(schema, include_raw=True, method=method)
+    assert isinstance(structured_llm.first, RunnableMap)
+
+
 ### Test data conversion functions ###
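
The include_raw=True branch of this patch combines RunnablePassthrough.assign
with with_fallbacks in a way that is easy to miss in review. Below is a
minimal, self-contained sketch of that same pattern, isolated from the patch:
the "llm" stub and the json.loads-based parser are hypothetical stand-ins, not
part of this change; only the chain shape mirrors the code being added.

    import json
    from operator import itemgetter

    from langchain_core.runnables import RunnableLambda, RunnableMap, RunnablePassthrough

    # Hypothetical stand-ins (not from the patch): a "model" that returns
    # deliberately malformed JSON, and a parser that raises on bad input.
    llm = RunnableLambda(lambda _prompt: '{"answer": "Wednesday"')
    output_parser = RunnableLambda(lambda text: json.loads(text))

    # Same chain shape as the include_raw=True branch in this patch: try to
    # parse the raw output and record no error; on failure, fall back to
    # parsed=None, with exception_key surfacing the exception in the result.
    parser_assign = RunnablePassthrough.assign(
        parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
    )
    parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
    parser_with_fallback = parser_assign.with_fallbacks(
        [parser_none], exception_key="parsing_error"
    )
    chain = RunnableMap(raw=llm) | parser_with_fallback

    result = chain.invoke("What day comes two days after Monday?")
    print(result["raw"])            # the unparsed model output, always present
    print(result["parsed"])         # None here, because the JSON is malformed
    print(result["parsing_error"])  # the JSONDecodeError captured by the fallback

Because the fallback still passes its input through, "raw" survives a parse
failure and the exception lands under "parsing_error", which is what lets the
docstring promise all three keys in every include_raw=True result.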