diff --git a/.cspell.yaml b/.cspell.yaml
index 51bccdc..0961d49 100644
--- a/.cspell.yaml
+++ b/.cspell.yaml
@@ -28,6 +28,7 @@ words:
   - longcontext
   - openai
   - openpose
+  - permissioning
   - qwen
   - qwenai
   - sdxl
diff --git a/.env.example b/.env.example
index 738d059..a0eefcc 100644
--- a/.env.example
+++ b/.env.example
@@ -26,3 +26,9 @@ VYRO_API_KEY="vk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
 # See https://api.minimax.chat/user-center/basic-information/interface-key
 MINIMAX_API_ORG="xxxxxxxx"
 MINIMAX_API_KEY="eyJhxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
+
+# Google Gemini AI
+# Documentation:
+# https://ai.google.dev/tutorials/ai-studio_quickstart
+# https://makersuite.google.com/app/apikey
+GEMINI_API_KEY="xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
diff --git a/package.json b/package.json
index ae6acd4..02994a4 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
 {
   "name": "@zhengxs/ai",
-  "version": "0.4.0",
+  "version": "0.5.0",
   "description": "llm sdk",
   "keywords": [
     "ai",
diff --git a/samples/gemini/chat/chat.ts b/samples/gemini/chat/chat.ts
new file mode 100644
index 0000000..8db49fd
--- /dev/null
+++ b/samples/gemini/chat/chat.ts
@@ -0,0 +1,14 @@
+import { GeminiAI } from '../../../src';
+
+const ai = new GeminiAI();
+
+async function main() {
+  const chatCompletion = await ai.chat.completions.create({
+    model: 'gemini-pro',
+    messages: [{ role: 'user', content: 'Say this is a test' }],
+  });
+
+  console.log(chatCompletion.choices[0].message.content);
+}
+
+main();
diff --git a/samples/gemini/chat/stream.ts b/samples/gemini/chat/stream.ts
new file mode 100644
index 0000000..7748f84
--- /dev/null
+++ b/samples/gemini/chat/stream.ts
@@ -0,0 +1,17 @@
+import { GeminiAI } from '../../../src';
+
+const client = new GeminiAI();
+
+async function main() {
+  const stream = await client.chat.completions.create({
+    stream: true,
+    model: 'gemini-pro',
+    messages: [{ role: 'user', content: 'Say this is a test' }],
+  });
+
+  for await (const chunk of stream) {
+    console.log(chunk.choices[0]?.delta?.content || '');
+  }
+}
+
+main();
diff --git a/samples/gemini/models/list.ts b/samples/gemini/models/list.ts
new file mode 100644
index 0000000..d92e462
--- /dev/null
+++ b/samples/gemini/models/list.ts
@@ -0,0 +1,13 @@
+import { GeminiAI } from '../../../src';
+
+const ai = new GeminiAI();
+
+async function main() {
+  const list = await ai.models.list();
+
+  for await (const model of list) {
+    console.log(model);
+  }
+}
+
+main();
diff --git a/samples/gemini/models/retrieve.ts b/samples/gemini/models/retrieve.ts
new file mode 100644
index 0000000..f099d3f
--- /dev/null
+++ b/samples/gemini/models/retrieve.ts
@@ -0,0 +1,11 @@
+import { GeminiAI } from '../../../src';
+
+const ai = new GeminiAI();
+
+async function main() {
+  const model = await ai.models.retrieve('gemini-pro');
+
+  console.log(model);
+}
+
+main();
diff --git a/src/gemini/index.ts b/src/gemini/index.ts
new file mode 100644
index 0000000..7faa69b
--- /dev/null
+++ b/src/gemini/index.ts
@@ -0,0 +1,73 @@
+import type { Agent } from 'node:http';
+
+import { APIClient, type DefaultQuery, type Fetch, type FinalRequestOptions, type Headers } from 'openai/core';
+
+import * as API from './resources';
+
+const BASE_URL = 'https://generativelanguage.googleapis.com/v1';
+
+export interface GeminiAIOptions {
+  baseURL?: string;
+  apiKey?: string;
+  timeout?: number | undefined;
+  httpAgent?: Agent;
+  fetch?: Fetch | undefined;
+  defaultHeaders?: Headers;
+  defaultQuery?: DefaultQuery;
+}
+
+export class GeminiAI extends APIClient {
+  apiKey: string;
+
+  private _options: GeminiAIOptions;
+
+  constructor(options: GeminiAIOptions = {}) {
+    const {
+      apiKey = process.env.GEMINI_API_KEY || '',
+      baseURL = process.env.GEMINI_BASE_URL || BASE_URL,
+      timeout = 30000,
+      fetch = globalThis.fetch,
+      httpAgent = undefined,
+      ...rest
+    } = options;
+
+    super({
+      baseURL,
+      timeout,
+      fetch,
+      httpAgent,
+      ...rest,
+    });
+
+    this._options = options;
+
+    this.apiKey = apiKey;
+  }
+
+  chat = new API.Chat(this);
+
+  models = new API.Models(this);
+
+  protected override defaultHeaders(opts: FinalRequestOptions): Headers {
+    return {
+      ...super.defaultHeaders(opts),
+      ...this._options.defaultHeaders,
+    };
+  }
+
+  protected override defaultQuery(): DefaultQuery | undefined {
+    return {
+      ...this._options.defaultQuery,
+      key: this.apiKey,
+    };
+  }
+}
+
+export namespace GeminiAI {
+  export type ChatModel = API.ChatModel;
+  export type ChatCompletionCreateParams = API.ChatCompletionCreateParams;
+  export type ChatCompletionCreateParamsStreaming = API.ChatCompletionCreateParamsStreaming;
+  export type ChatCompletionCreateParamsNonStreaming = API.ChatCompletionCreateParamsNonStreaming;
+}
+
+export default GeminiAI;
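Unlike OpenAI's `Authorization: Bearer` header, the Gemini v1 REST API authenticates with a `key` query parameter, which is why `defaultQuery()` above merges `key: this.apiKey` into every request. A minimal sketch of the equivalent raw call, with no SDK involved (the endpoint and body shape follow the public Gemini v1 REST API; the prompt is illustrative):

```ts
async function main() {
  const apiKey = process.env.GEMINI_API_KEY || '';
  const url = `https://generativelanguage.googleapis.com/v1/models/gemini-pro:generateContent?key=${apiKey}`;

  const response = await fetch(url, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    // Gemini's native request shape, which Completions.buildCreateParams
    // (below) produces from OpenAI-style messages.
    body: JSON.stringify({
      contents: [{ role: 'user', parts: [{ text: 'Say this is a test' }] }],
    }),
  });

  console.log(await response.json());
}

main();
```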
diff --git a/src/gemini/resource.ts b/src/gemini/resource.ts
new file mode 100644
index 0000000..2add8ac
--- /dev/null
+++ b/src/gemini/resource.ts
@@ -0,0 +1,9 @@
+import { GeminiAI } from './index';
+
+export class APIResource {
+  protected _client: GeminiAI;
+
+  constructor(client: GeminiAI) {
+    this._client = client;
+  }
+}
diff --git a/src/gemini/resources/chat/chat.ts b/src/gemini/resources/chat/chat.ts
new file mode 100644
index 0000000..214edf9
--- /dev/null
+++ b/src/gemini/resources/chat/chat.ts
@@ -0,0 +1,6 @@
+import { APIResource } from '../../resource';
+import { Completions } from './completions';
+
+export class Chat extends APIResource {
+  completions = new Completions(this._client);
+}
diff --git a/src/gemini/resources/chat/completions.ts b/src/gemini/resources/chat/completions.ts
new file mode 100644
index 0000000..0f3e1af
--- /dev/null
+++ b/src/gemini/resources/chat/completions.ts
@@ -0,0 +1,303 @@
+import { randomUUID } from 'crypto';
+import OpenAI from 'openai';
+import { Stream } from 'openai/streaming';
+
+import { ensureArray } from '../../../util';
+import { APIResource } from '../../resource';
+
+export class Completions extends APIResource {
+  /**
+   * Creates a model response for the given chat conversation.
+   */
+  create(body: ChatCompletionCreateParamsNonStreaming, options?: OpenAI.RequestOptions): Promise<OpenAI.ChatCompletion>;
+  create(
+    body: ChatCompletionCreateParamsStreaming,
+    options?: OpenAI.RequestOptions,
+  ): Promise<Stream<OpenAI.ChatCompletionChunk>>;
+
+  async create(
+    params: ChatCompletionCreateParams,
+    options?: OpenAI.RequestOptions,
+  ): Promise<OpenAI.ChatCompletion | Stream<OpenAI.ChatCompletionChunk>> {
+    const { stream, model } = params;
+    const body = this.buildCreateParams(params);
+    const path = stream ? `/models/${model}:streamGenerateContent` : `/models/${model}:generateContent`;
+
+    const response: Response = await this._client.post(path, {
+      ...options,
+      query: stream ? { alt: 'sse' } : {},
+      body: body as unknown as Record<string, unknown>,
+      stream: false,
+      __binaryResponse: true,
+    });
+
+    if (stream) {
+      const controller = new AbortController();
+
+      options?.signal?.addEventListener('abort', () => {
+        controller.abort();
+      });
+
+      return this.afterSSEResponse(model, response, controller);
+    }
+
+    return this.afterResponse(model, response);
+  }
+
+  protected buildCreateParams(params: ChatCompletionCreateParams) {
+    const { messages = [], max_tokens, top_p, top_k, stop, temperature } = params;
+
+    function formatContentParts(content: string | OpenAI.ChatCompletionContentPart[]) {
+      const parts: GeminiChat.Part[] = [];
+
+      if (typeof content === 'string') {
+        parts.push({ text: content });
+        return parts;
+      }
+
+      for (const part of content) {
+        if (part.type === 'text') {
+          parts.push({ text: part.text });
+        } else {
+          // TODO: Handle images
+          // parts.push({
+          //   inline_data: {
+          //     "mime_type": "image/jpeg",
+          //     "data": "'$(base64 -w0 image.jpg)'"
+          //   }
+          // });
+        }
+      }
+
+      return parts;
+    }
+
+    function formatRole(role: string): 'user' | 'model' {
+      return role === 'user' ? 'user' : 'model';
+    }
+
+    const generationConfig: GeminiChat.GenerationConfig = {};
+
+    const data: GeminiChat.GenerateContentRequest = {
+      contents: messages.map(item => {
+        return {
+          role: formatRole(item.role),
+          parts: formatContentParts(item.content!),
+        };
+      }),
+      generationConfig,
+    };
+
+    if (temperature != null) {
+      generationConfig.temperature = temperature;
+    }
+
+    if (top_k != null) {
+      generationConfig.topK = top_k;
+    }
+
+    if (top_p != null) {
+      generationConfig.topP = top_p;
+    }
+
+    if (stop != null) {
+      generationConfig.stopSequences = ensureArray(stop);
+    }
+
+    if (max_tokens != null) {
+      generationConfig.maxOutputTokens = max_tokens;
+    }
+
+    return data;
+  }
+
+  protected async afterResponse(model: string, response: Response): Promise<OpenAI.ChatCompletion> {
+    const data: GeminiChat.GenerateContentResponse = await response.json();
+    const choices: OpenAI.ChatCompletion.Choice[] = data.candidates!.map(item => {
+      const [part] = item.content.parts;
+
+      const choice: OpenAI.ChatCompletion.Choice = {
+        index: item.index,
+        message: {
+          role: 'assistant',
+          content: part.text!,
+        },
+        finish_reason: 'stop',
+      };
+
+      switch (item.finishReason) {
+        case 'MAX_TOKENS':
+          choice.finish_reason = 'length';
+          break;
+        case 'SAFETY':
+        case 'RECITATION':
+          choice.finish_reason = 'content_filter';
+          break;
+        default:
+          choice.finish_reason = 'stop';
+      }
+
+      return choice;
+    });
+
+    return {
+      id: randomUUID(),
+      model: model,
+      choices,
+      object: 'chat.completion',
+      created: Math.floor(Date.now() / 1000),
+      // TODO: usage reporting is not supported yet
+      usage: {
+        completion_tokens: 0,
+        prompt_tokens: 0,
+        total_tokens: 0,
+      },
+    };
+  }
+
+  protected afterSSEResponse(
+    model: string,
+    response: Response,
+    controller: AbortController,
+  ): Stream<OpenAI.ChatCompletionChunk> {
+    const stream = Stream.fromSSEResponse<GeminiChat.GenerateContentResponse>(response, controller);
+
+    const toChoices = (data: GeminiChat.GenerateContentResponse) => {
+      return data.candidates!.map(item => {
+        const [part] = item.content.parts;
+
+        const choice: OpenAI.ChatCompletionChunk.Choice = {
+          index: item.index,
+          delta: {
+            role: 'assistant',
+            content: part.text || '',
+          },
+          finish_reason: null,
+        };
+
+        switch (item.finishReason) {
+          case 'MAX_TOKENS':
+            choice.finish_reason = 'length';
+            break;
+          case 'SAFETY':
+          case 'RECITATION':
+            choice.finish_reason = 'content_filter';
+            break;
+          case 'STOP':
+            choice.finish_reason = 'stop';
+        }
+
+        return choice;
+      });
+    };
+
+    async function* iterator(): AsyncIterator<OpenAI.ChatCompletionChunk> {
+      for await (const chunk of stream) {
+        yield {
+          id: randomUUID(),
+          model,
+          choices: toChoices(chunk),
+          object: 'chat.completion.chunk',
+          created: Math.floor(Date.now() / 1000),
+        };
+      }
+    }
+
+    return new Stream(iterator, controller);
+  }
+}
+
+export type ChatCompletionCreateParamsNonStreaming = Chat.ChatCompletionCreateParamsNonStreaming;
+
+export type ChatCompletionCreateParamsStreaming = Chat.ChatCompletionCreateParamsStreaming;
+
+export type ChatCompletionCreateParams = Chat.ChatCompletionCreateParams;
+
+export type ChatModel = Chat.ChatModel;
+
+export namespace Chat {
+  export type ChatModel = (string & NonNullable<unknown>) | 'gemini-pro';
+  // gemini-pro-vision support is still problematic
+  // | 'gemini-pro-vision';
+
+  export interface ChatCompletionCreateParamsNonStreaming extends OpenAI.ChatCompletionCreateParamsNonStreaming {
+    model: ChatModel;
+    top_k?: number;
+  }
+
+  export interface ChatCompletionCreateParamsStreaming extends OpenAI.ChatCompletionCreateParamsStreaming {
+    model: ChatModel;
+    top_k?: number | null;
+  }
+
+  export type ChatCompletionCreateParams = ChatCompletionCreateParamsNonStreaming | ChatCompletionCreateParamsStreaming;
+}
+
+namespace GeminiChat {
+  export interface GenerationConfig {
+    candidateCount?: number;
+    stopSequences?: string[];
+    maxOutputTokens?: number;
+    temperature?: number;
+    topP?: number;
+    topK?: number;
+  }
+
+  export interface GenerateContentCandidate {
+    index: number;
+    content: Content;
+    finishReason?: 'FINISH_REASON_UNSPECIFIED' | 'STOP' | 'MAX_TOKENS' | 'SAFETY' | 'RECITATION' | 'OTHER';
+    finishMessage?: string;
+    citationMetadata?: CitationMetadata;
+  }
+
+  export interface GenerateContentResponse {
+    candidates?: GenerateContentCandidate[];
+    // promptFeedback?: PromptFeedback;
+  }
+
+  export interface CitationMetadata {
+    citationSources: CitationSource[];
+  }
+
+  export interface CitationSource {
+    startIndex?: number;
+    endIndex?: number;
+    uri?: string;
+    license?: string;
+  }
+
+  export interface InputContent {
+    parts: string | Array<string | Part>;
+    role: string;
+  }
+
+  export interface Content extends InputContent {
+    parts: Part[];
+  }
+
+  export type Part = TextPart | InlineDataPart;
+
+  export interface TextPart {
+    text: string;
+    inlineData?: never;
+  }
+
+  export interface InlineDataPart {
+    text?: never;
+    inlineData: GeminiContentBlob;
+  }
+
+  export interface GeminiContentBlob {
+    mimeType: string;
+    data: string;
+  }
+
+  export interface BaseParams {
+    generationConfig?: GenerationConfig;
+  }
+
+  export interface GenerateContentRequest extends BaseParams {
+    contents: Content[];
+  }
+}
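`formatRole` in `buildCreateParams` above folds every non-`user` message, including `system` messages, into Gemini's `model` role, because the v1 API only knows `user` and `model` turns. Gemini additionally expects a history that starts with a `user` turn and alternates, so OpenAI-style histories map over cleanly as long as they follow that shape. A minimal multi-turn sketch (package import taken from `package.json`; prompts illustrative):

```ts
import { GeminiAI } from '@zhengxs/ai';

const ai = new GeminiAI();

async function main() {
  const chatCompletion = await ai.chat.completions.create({
    model: 'gemini-pro',
    messages: [
      { role: 'user', content: 'Give me a two-word motto.' }, // sent as Gemini role 'user'
      { role: 'assistant', content: 'Ship early.' }, // sent as Gemini role 'model'
      { role: 'user', content: 'Now make it three words.' },
    ],
  });

  console.log(chatCompletion.choices[0].message.content);
}

main();
```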
diff --git a/src/gemini/resources/chat/index.ts b/src/gemini/resources/chat/index.ts
new file mode 100644
index 0000000..4805791
--- /dev/null
+++ b/src/gemini/resources/chat/index.ts
@@ -0,0 +1,8 @@
+export { Chat } from './chat';
+export {
+  type ChatModel,
+  type ChatCompletionCreateParams,
+  type ChatCompletionCreateParamsNonStreaming,
+  type ChatCompletionCreateParamsStreaming,
+  Completions,
+} from './completions';
diff --git a/src/gemini/resources/index.ts b/src/gemini/resources/index.ts
new file mode 100644
index 0000000..143b817
--- /dev/null
+++ b/src/gemini/resources/index.ts
@@ -0,0 +1,3 @@
+export * from './chat/index';
+
+export { Models, type Model, ModelsPage } from './models';
diff --git a/src/gemini/resources/models.ts b/src/gemini/resources/models.ts
new file mode 100644
index 0000000..3348c54
--- /dev/null
+++ b/src/gemini/resources/models.ts
@@ -0,0 +1,71 @@
+import OpenAI from 'openai';
+import { type FinalRequestOptions, type PagePromise, type RequestOptions } from 'openai/core';
+import { Page } from 'openai/pagination';
+
+import { type GeminiAI } from '../../index';
+import { APIResource } from '../resource';
+
+// TODO: expose the raw Gemini model objects
+export class Models extends APIResource {
+  /**
+   * Retrieves a model instance, providing basic information about the model such as
+   * the owner and permissioning.
+   */
+  async retrieve(model: string, options?: RequestOptions): Promise<Model> {
+    const item: GeminiModel = await this._client.get(`/models/${model}`, options);
+
+    return {
+      id: item.name,
+      created: 0,
+      object: 'model',
+      owned_by: 'google',
+    };
+  }
+
+  /**
+   * Lists the currently available models, and provides basic information about each
+   * one such as the owner and availability.
+   */
+  list(options?: RequestOptions): PagePromise<ModelsPage, Model> {
+    return this._client.getAPIList('/models', ModelsPage, options);
+  }
+}
+
+export class ModelsPage extends Page<Model> {
+  constructor(client: GeminiAI, response: Response, body: GeminiPageResponse, options: FinalRequestOptions) {
+    const data: Model[] = body.models.map(item => {
+      return {
+        id: item.name,
+        created: 0,
+        object: 'model',
+        owned_by: 'google',
+      };
+    });
+
+    super(client, response, { data, object: 'list' }, options);
+  }
+}
+
+interface GeminiModel {
+  name: string;
+  version: string;
+  displayName: string;
+  description: string;
+  inputTokenLimit: number;
+  outputTokenLimit: number;
+  supportedGenerationMethods: string[];
+}
+
+interface GeminiPageResponse {
+  models: GeminiModel[];
+}
+
+/**
+ * Describes an OpenAI model offering that can be used with the API.
+ */
+export type Model = OpenAI.Models.Model;
+
+export namespace Models {
+  export import Model = OpenAI.Models.Model;
+  export import ModelsPage = OpenAI.Models.ModelsPage;
+}
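Gemini's `ListModels` endpoint responds with `{ models: [...] }` and model names prefixed with `models/`, not OpenAI's `{ object: 'list', data: [...] }` page shape, so `ModelsPage` above re-wraps the payload before handing it to the pager. A sketch of that mapping (raw values illustrative):

```ts
// Raw shape returned by GET /v1/models (values illustrative).
const raw = {
  models: [{ name: 'models/gemini-pro', version: '001', displayName: 'Gemini Pro' }],
};

// The OpenAI-compatible page that ModelsPage builds from it.
const page = {
  object: 'list' as const,
  data: raw.models.map(item => ({
    id: item.name, // note: keeps the 'models/' prefix
    created: 0, // Gemini reports no creation timestamp
    object: 'model' as const,
    owned_by: 'google',
  })),
};

console.log(page.data[0].id); // => 'models/gemini-pro'
```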
diff --git a/src/index.ts b/src/index.ts
index 10542d5..4e366da 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -1,6 +1,7 @@
 import OpenAI from 'openai';
 
 import ErnieAI, { ErnieAIOptions } from './ernie';
+import GeminiAI, { GeminiAIOptions } from './gemini';
 import HunYuanAI, { HunYuanAIOptions } from './hunyuan';
 import MinimaxAI, { MinimaxAIOptions } from './minimax';
 import QWenAI, { QWenAIOptions } from './qwen';
@@ -10,6 +11,8 @@ import VYroAI, { VYroAIOptions } from './vyro';
 export {
   ErnieAI,
   type ErnieAIOptions,
+  GeminiAI,
+  type GeminiAIOptions,
   HunYuanAI,
   type HunYuanAIOptions,
   MinimaxAI,
diff --git a/src/resource.ts b/src/resource.ts
index 111df1c..476eb5a 100644
--- a/src/resource.ts
+++ b/src/resource.ts
@@ -1,9 +1,9 @@
 import { APIClient } from 'openai/core';
 
-export class APIResource {
-  protected _client: APIClient;
+export class APIResource<Client extends APIClient = APIClient> {
+  protected _client: Client;
 
-  constructor(client: APIClient) {
+  constructor(client: Client) {
     this._client = client;
   }
 }
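Making the shared `APIResource` generic lets each provider's resources see their own client type rather than the bare `APIClient`. A sketch of the pattern under stated assumptions: the `ExampleAI` class and its `apiKey` field are hypothetical, and the import path assumes a sibling of `src/resource.ts`.

```ts
import { APIClient } from 'openai/core';

import { APIResource } from './resource';

// Hypothetical provider client, only to show the type parameter in use.
class ExampleAI extends APIClient {
  apiKey = '';
}

class ExampleModels extends APIResource<ExampleAI> {
  keyLength(): number {
    // _client is typed as ExampleAI, so provider-specific members such as
    // apiKey are visible without a cast; the old non-generic base exposed
    // only the APIClient surface here.
    return this._client.apiKey.length;
  }
}
```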