feat(tools/cogview): Updated cogview tool to support cogview-3 and the latest cogview-3-plus #8382

Merged 11 commits on Sep 22, 2024
@@ -1,7 +1,8 @@
from .__version__ import __version__
from ._client import ZhipuAI
from .core._errors import (
from .core import (
APIAuthenticationError,
APIConnectionError,
APIInternalError,
APIReachLimitError,
APIRequestFailedError,
@@ -1 +1 @@
__version__ = "v2.0.1"
__version__ = "v2.1.0"
@@ -9,15 +9,13 @@
from typing_extensions import override

from . import api_resource
from .core import _jwt_token
from .core._base_type import NOT_GIVEN, NotGiven
from .core._errors import ZhipuAIError
from .core._http_client import ZHIPUAI_DEFAULT_MAX_RETRIES, HttpClient
from .core import NOT_GIVEN, ZHIPUAI_DEFAULT_MAX_RETRIES, HttpClient, NotGiven, ZhipuAIError, _jwt_token


class ZhipuAI(HttpClient):
chat: api_resource.chat
chat: api_resource.chat.Chat
api_key: str
_disable_token_cache: bool = True

def __init__(
self,
@@ -28,10 +26,15 @@ def __init__(
max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
http_client: httpx.Client | None = None,
custom_headers: Mapping[str, str] | None = None,
disable_token_cache: bool = True,
_strict_response_validation: bool = False,
) -> None:
if api_key is None:
raise ZhipuAIError("No api_key provided, please provide it through parameters or environment variables")
api_key = os.environ.get("ZHIPUAI_API_KEY")
if api_key is None:
raise ZhipuAIError("未提供api_key,请通过参数或环境变量提供")
self.api_key = api_key
self._disable_token_cache = disable_token_cache

if base_url is None:
base_url = os.environ.get("ZHIPUAI_BASE_URL")
@@ -42,21 +45,31 @@
super().__init__(
version=__version__,
base_url=base_url,
max_retries=max_retries,
timeout=timeout,
custom_httpx_client=http_client,
custom_headers=custom_headers,
_strict_response_validation=_strict_response_validation,
)
self.chat = api_resource.chat.Chat(self)
self.images = api_resource.images.Images(self)
self.embeddings = api_resource.embeddings.Embeddings(self)
self.files = api_resource.files.Files(self)
self.fine_tuning = api_resource.fine_tuning.FineTuning(self)
self.batches = api_resource.Batches(self)
self.knowledge = api_resource.Knowledge(self)
self.tools = api_resource.Tools(self)
self.videos = api_resource.Videos(self)
self.assistant = api_resource.Assistant(self)

@property
@override
def _auth_headers(self) -> dict[str, str]:
def auth_headers(self) -> dict[str, str]:
api_key = self.api_key
return {"Authorization": f"{_jwt_token.generate_token(api_key)}"}
if self._disable_token_cache:
return {"Authorization": f"Bearer {api_key}"}
else:
return {"Authorization": f"Bearer {_jwt_token.generate_token(api_key)}"}

def __del__(self) -> None:
if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close") or not hasattr(self, "_client"):
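The hunk above changes how authentication headers are built: with the new `disable_token_cache` flag (default `True`) the client sends the raw API key as a Bearer token, and only falls back to the JWT produced by `_jwt_token.generate_token` when the flag is `False`. A minimal usage sketch, assuming the bundled SDK is importable as `zhipuai_sdk` (the import path and key values below are placeholders, not part of the diff):

```python
import os

# Placeholder import path; the real package lives wherever this PR vendors the SDK.
from zhipuai_sdk._client import ZhipuAI

# Default behaviour after this PR: the key itself is sent as
# "Authorization: Bearer <api_key>". If api_key is omitted, the client
# now reads the ZHIPUAI_API_KEY environment variable before raising.
client = ZhipuAI(api_key=os.environ.get("ZHIPUAI_API_KEY"))

# Opt back into JWT auth: auth_headers then returns
# "Bearer <token from _jwt_token.generate_token(api_key)>".
jwt_client = ZhipuAI(api_key=os.environ["ZHIPUAI_API_KEY"], disable_token_cache=False)
```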
@@ -1,5 +1,34 @@
from .chat import chat
from .assistant import (
Assistant,
)
from .batches import Batches
from .chat import (
AsyncCompletions,
Chat,
Completions,
)
from .embeddings import Embeddings
from .files import Files
from .fine_tuning import fine_tuning
from .files import Files, FilesWithRawResponse
from .fine_tuning import FineTuning
from .images import Images
from .knowledge import Knowledge
from .tools import Tools
from .videos import (
Videos,
)

__all__ = [
"Videos",
"AsyncCompletions",
"Chat",
"Completions",
"Images",
"Embeddings",
"Files",
"FilesWithRawResponse",
"FineTuning",
"Batches",
"Knowledge",
"Tools",
"Assistant",
]
@@ -0,0 +1,3 @@
from .assistant import Assistant

__all__ = ["Assistant"]
@@ -0,0 +1,122 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import httpx

from ...core import (
NOT_GIVEN,
BaseAPI,
Body,
Headers,
NotGiven,
StreamResponse,
deepcopy_minimal,
make_request_options,
maybe_transform,
)
from ...types.assistant import AssistantCompletion
from ...types.assistant.assistant_conversation_resp import ConversationUsageListResp
from ...types.assistant.assistant_support_resp import AssistantSupportResp

if TYPE_CHECKING:
from ..._client import ZhipuAI

from ...types.assistant import assistant_conversation_params, assistant_create_params

__all__ = ["Assistant"]


class Assistant(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)

def conversation(
self,
assistant_id: str,
model: str,
messages: list[assistant_create_params.ConversationMessage],
*,
stream: bool = True,
conversation_id: Optional[str] = None,
attachments: Optional[list[assistant_create_params.AssistantAttachments]] = None,
metadata: dict | None = None,
request_id: str = None,
user_id: str = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> StreamResponse[AssistantCompletion]:
body = deepcopy_minimal(
{
"assistant_id": assistant_id,
"model": model,
"messages": messages,
"stream": stream,
"conversation_id": conversation_id,
"attachments": attachments,
"metadata": metadata,
"request_id": request_id,
"user_id": user_id,
}
)
return self._post(
"/assistant",
body=maybe_transform(body, assistant_create_params.AssistantParameters),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=AssistantCompletion,
stream=stream or True,
stream_cls=StreamResponse[AssistantCompletion],
)

def query_support(
self,
*,
assistant_id_list: list[str] = None,
request_id: str = None,
user_id: str = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> AssistantSupportResp:
body = deepcopy_minimal(
{
"assistant_id_list": assistant_id_list,
"request_id": request_id,
"user_id": user_id,
}
)
return self._post(
"/assistant/list",
body=body,
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=AssistantSupportResp,
)

def query_conversation_usage(
self,
assistant_id: str,
page: int = 1,
page_size: int = 10,
*,
request_id: str = None,
user_id: str = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ConversationUsageListResp:
body = deepcopy_minimal(
{
"assistant_id": assistant_id,
"page": page,
"page_size": page_size,
"request_id": request_id,
"user_id": user_id,
}
)
return self._post(
"/assistant/conversation/list",
body=maybe_transform(body, assistant_conversation_params.ConversationParameters),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=ConversationUsageListResp,
)
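For reference, a rough sketch of how the new Assistant resource is meant to be called through the client attribute registered above (`client.assistant`); the assistant ID, model name, and message shape are illustrative assumptions, not values from this PR:

```python
client = ZhipuAI(api_key="your-api-key")  # placeholder key

# conversation() defaults to stream=True and yields AssistantCompletion chunks.
stream = client.assistant.conversation(
    assistant_id="your-assistant-id",  # placeholder
    model="glm-4-assistant",           # assumed model name
    messages=[{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],  # assumed shape
)
for completion in stream:
    print(completion)

# Non-streaming helpers added in the same file:
support = client.assistant.query_support(assistant_id_list=["your-assistant-id"])
usage = client.assistant.query_conversation_usage("your-assistant-id", page=1, page_size=10)
```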
@@ -0,0 +1,146 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Literal, Optional

import httpx

from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options, maybe_transform
from ..core.pagination import SyncCursorPage
from ..types import batch_create_params, batch_list_params
from ..types.batch import Batch

if TYPE_CHECKING:
from .._client import ZhipuAI


class Batches(BaseAPI):
def __init__(self, client: ZhipuAI) -> None:
super().__init__(client)

def create(
self,
*,
completion_window: str | None = None,
endpoint: Literal["/v1/chat/completions", "/v1/embeddings"],
input_file_id: str,
metadata: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
auto_delete_input_file: bool = True,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Batch:
return self._post(
"/batches",
body=maybe_transform(
{
"completion_window": completion_window,
"endpoint": endpoint,
"input_file_id": input_file_id,
"metadata": metadata,
"auto_delete_input_file": auto_delete_input_file,
},
batch_create_params.BatchCreateParams,
),
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=Batch,
)

def retrieve(
self,
batch_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Batch:
"""
Retrieves a batch.

Args:
extra_headers: Send extra headers

extra_body: Add additional JSON properties to the request

timeout: Override the client-level default timeout for this request, in seconds
"""
if not batch_id:
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
return self._get(
f"/batches/{batch_id}",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=Batch,
)

def list(
self,
*,
after: str | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> SyncCursorPage[Batch]:
"""List your organization's batches.

Args:
after: A cursor for use in pagination.

`after` is an object ID that defines your place
in the list. For instance, if you make a list request and receive 100 objects,
ending with obj_foo, your subsequent call can include after=obj_foo in order to
fetch the next page of the list.

limit: A limit on the number of objects to be returned. Limit can range between 1 and
100, and the default is 20.

extra_headers: Send extra headers

extra_body: Add additional JSON properties to the request

timeout: Override the client-level default timeout for this request, in seconds
"""
return self._get_api_list(
"/batches",
page=SyncCursorPage[Batch],
options=make_request_options(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout,
query=maybe_transform(
{
"after": after,
"limit": limit,
},
batch_list_params.BatchListParams,
),
),
model=Batch,
)

def cancel(
self,
batch_id: str,
*,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Batch:
"""
Cancels an in-progress batch.

Args:
batch_id: The ID of the batch to cancel.
extra_headers: Send extra headers

extra_body: Add additional JSON properties to the request

timeout: Override the client-level default timeout for this request, in seconds

"""
if not batch_id:
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
return self._post(
f"/batches/{batch_id}/cancel",
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
cast_type=Batch,
)
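Likewise, a hedged sketch of the Batches workflow this file enables; the file ID, completion window, and the `id`/`status` fields read from the returned `Batch` objects are assumptions based on the method signatures, not confirmed by the diff:

```python
client = ZhipuAI(api_key="your-api-key")  # placeholder key

batch = client.batches.create(
    input_file_id="file-abc123",   # placeholder file ID
    endpoint="/v1/chat/completions",
    completion_window="24h",       # assumed window value
    metadata={"job": "nightly-eval"},
)

# Retrieve, page through, or cancel batches; list() returns a SyncCursorPage,
# which is assumed to be iterable like the cursor pages in similar SDKs.
batch = client.batches.retrieve(batch.id)
for b in client.batches.list(limit=20):
    print(b.id, b.status)
client.batches.cancel(batch.id)
```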