Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support LangChain 0.3.0 #14

Merged
merged 2 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions libs/databricks/langchain_databricks/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,11 @@
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
PrivateAttr,
)
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from mlflow.deployments import BaseDeploymentClient # type: ignore
from pydantic import BaseModel, Field

from langchain_databricks.utils import get_deployment_client

Expand Down Expand Up @@ -180,7 +177,7 @@ class ChatDatabricks(BaseChatModel):
Tool calling:
.. code-block:: python

from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field

class GetWeather(BaseModel):
'''Get the current weather in a given location'''
Expand Down Expand Up @@ -225,13 +222,16 @@ class GetPopulation(BaseModel):
"""List of strings to stop generation at."""
max_tokens: Optional[int] = None
"""The maximum number of tokens to generate."""
extra_params: dict = Field(default_factory=dict)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: After updating to pydantic v2 import, this field starts to return FieldInfo object as default value instead of factory. Might be pydantic issue, but I couldn't figure it out the root cause so just using None as a workaround.

extra_params: Optional[Dict[str, Any]] = None
"""Any extra parameters to pass to the endpoint."""
_client: Any = PrivateAttr()
client: Optional[BaseDeploymentClient] = Field(
default=None, exclude=True
) #: :meta private:

def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._client = get_deployment_client(self.target_uri)
self.client = get_deployment_client(self.target_uri)
self.extra_params = self.extra_params or {}

@property
def _default_params(self) -> Dict[str, Any]:
Expand All @@ -254,7 +254,7 @@ def _generate(
**kwargs: Any,
) -> ChatResult:
data = self._prepare_inputs(messages, stop, **kwargs)
resp = self._client.predict(endpoint=self.endpoint, inputs=data)
resp = self.client.predict(endpoint=self.endpoint, inputs=data) # type: ignore
return self._convert_response_to_chat_result(resp)

def _prepare_inputs(
Expand All @@ -267,7 +267,7 @@ def _prepare_inputs(
"messages": [_convert_message_to_dict(msg) for msg in messages],
"temperature": self.temperature,
"n": self.n,
**self.extra_params,
**self.extra_params, # type: ignore
**kwargs,
}
if stop := self.stop or stop:
Expand Down Expand Up @@ -299,7 +299,7 @@ def _stream(
) -> Iterator[ChatGenerationChunk]:
data = self._prepare_inputs(messages, stop, **kwargs)
first_chunk_role = None
for chunk in self._client.predict_stream(endpoint=self.endpoint, inputs=data):
for chunk in self.client.predict_stream(endpoint=self.endpoint, inputs=data): # type: ignore
if chunk["choices"]:
choice = chunk["choices"][0]

Expand Down
2 changes: 1 addition & 1 deletion libs/databricks/langchain_databricks/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict, Iterator, List

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
from pydantic import BaseModel, PrivateAttr

from langchain_databricks.utils import get_deployment_client

Expand Down
Loading
Loading