diff --git a/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py b/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py
index b9a6cee87cd93..c980bfe83e534 100644
--- a/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py
@@ -99,6 +99,10 @@ class Ollama(FunctionCallingLLM):
         default=True,
         description="Whether the model is a function calling model.",
     )
+    keep_alive: Optional[Union[float, str]] = Field(
+        default="5m",
+        description="Controls how long the model will stay loaded into memory following the request (default: 5m).",
+    )
 
     _client: Optional[Client] = PrivateAttr()
     _async_client: Optional[AsyncClient] = PrivateAttr()
@@ -116,6 +120,7 @@ def __init__(
         client: Optional[Client] = None,
         async_client: Optional[AsyncClient] = None,
         is_function_calling_model: bool = True,
+        keep_alive: Optional[Union[float, str]] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(
@@ -128,6 +133,7 @@ def __init__(
             json_mode=json_mode,
             additional_kwargs=additional_kwargs,
             is_function_calling_model=is_function_calling_model,
+            keep_alive=keep_alive,
             **kwargs,
         )
 
@@ -279,6 +285,7 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
             format="json" if self.json_mode else "",
             tools=tools,
             options=self._model_kwargs,
+            keep_alive=self.keep_alive,
         )
 
         tool_calls = response["message"].get("tool_calls", [])
@@ -311,6 +318,7 @@ def gen() -> ChatResponseGen:
                 format="json" if self.json_mode else "",
                 tools=tools,
                 options=self._model_kwargs,
+                keep_alive=self.keep_alive,
             )
 
             response_txt = ""
@@ -354,6 +362,7 @@ async def gen() -> ChatResponseAsyncGen:
                 format="json" if self.json_mode else "",
                 tools=tools,
                 options=self._model_kwargs,
+                keep_alive=self.keep_alive,
             )
 
             response_txt = ""
@@ -396,6 +405,7 @@ async def achat(
             format="json" if self.json_mode else "",
             tools=tools,
             options=self._model_kwargs,
+            keep_alive=self.keep_alive,
         )
 
         tool_calls = response["message"].get("tool_calls", [])
diff --git a/llama-index-integrations/llms/llama-index-llms-ollama/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-ollama/pyproject.toml
index 43973bfa61d9d..17182735835a5 100644
--- a/llama-index-integrations/llms/llama-index-llms-ollama/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-ollama/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-ollama"
 readme = "README.md"
-version = "0.3.2"
+version = "0.3.3"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
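For reference, a minimal usage sketch of the new parameter. It assumes a local Ollama server with a llama3.1 model pulled; the model name and the "10m" duration are illustrative, not part of this change:

    from llama_index.core.llms import ChatMessage
    from llama_index.llms.ollama import Ollama

    # keep_alive accepts an Ollama duration string such as "10m" or "24h",
    # or a number of seconds; 0 unloads the model right after the request
    # and a negative value keeps it loaded indefinitely. Left unset, the
    # effective default is 5 minutes.
    llm = Ollama(model="llama3.1", request_timeout=120.0, keep_alive="10m")

    response = llm.chat([ChatMessage(role="user", content="Why is the sky blue?")])
    print(response.message.content)

The value is forwarded on every chat, stream_chat, achat, and astream_chat call, so one constructor argument controls model residency for the lifetime of the LLM instance.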