// models.ts — forked from e2b-dev/fragments.
import { createAnthropic } from '@ai-sdk/anthropic'
import { createOpenAI } from '@ai-sdk/openai'
import { createGoogleGenerativeAI } from '@ai-sdk/google'
import { createMistral } from '@ai-sdk/mistral'
import { createOllama } from 'ollama-ai-provider'
/**
 * A model entry that `getModelClient` can turn into a provider client.
 */
export type LLMModel = {
  /** Provider key; `getModelClient` uses it to choose the SDK factory. */
  providerId: string
  /** Identifier handed to the provider SDK as the model name. */
  id: string
  /** Presumably a human-readable model label — not read by the functions in this file. */
  name: string
  /** Presumably a human-readable provider label — not read by the functions in this file. */
  provider: string
}
/**
 * Optional per-request overrides for model selection, connection and sampling.
 *
 * Only `apiKey` and `baseURL` are read in this file (by `getModelClient`).
 * NOTE(review): the remaining fields are presumably forwarded to the AI SDK
 * generation call by the caller — confirm at the call site.
 */
export type LLMModelConfig = {
  // Selection / connection
  model?: string
  apiKey?: string
  baseURL?: string
  // Sampling / generation limits
  maxTokens?: number
  temperature?: number
  topP?: number
  topK?: number
  frequencyPenalty?: number
  presencePenalty?: number
}
/**
 * Builds an AI SDK model client for `model` using the overrides in `config`.
 *
 * `model.providerId` selects the provider SDK factory, and `model.id` is passed
 * to it as the model name. Only `config.apiKey` and `config.baseURL` are read.
 * For the OpenAI-compatible hosts (`groq`, `togetherai`, `fireworks`), a missing
 * or empty `apiKey` falls back to that provider's environment variable. A missing
 * or empty `baseURL` falls back to the provider's public endpoint. `ollama`
 * receives only `baseURL`.
 *
 * @param model - Model to instantiate; `providerId` must be one of the keys of `providerConfigs` below.
 * @param config - Connection overrides (`apiKey`, `baseURL`).
 * @returns The client produced by the selected provider factory.
 * @throws Error `Unsupported provider: <providerId>` when `providerId` is not a supported key.
 */
export function getModelClient(model: LLMModel, config: LLMModelConfig) {
  const { id: modelNameString, providerId } = model
  const { apiKey, baseURL } = config

  // Groq, Together AI and Fireworks are reached through the OpenAI SDK with a
  // different key and endpoint. `||` (not `??`) is deliberate: an empty-string
  // override is treated as "not set" and falls back to the provider default.
  const openAICompatible = (fallbackApiKey: string | undefined, defaultBaseURL: string) => () =>
    createOpenAI({
      apiKey: apiKey || fallbackApiKey,
      baseURL: baseURL || defaultBaseURL,
    })(modelNameString)

  const providerConfigs = {
    anthropic: () => createAnthropic({ apiKey, baseURL })(modelNameString),
    openai: () => createOpenAI({ apiKey, baseURL })(modelNameString),
    google: () => createGoogleGenerativeAI({ apiKey, baseURL })(modelNameString),
    mistral: () => createMistral({ apiKey, baseURL })(modelNameString),
    groq: openAICompatible(process.env.GROQ_API_KEY, 'https://api.groq.com/openai/v1'),
    togetherai: openAICompatible(process.env.TOGETHER_AI_API_KEY, 'https://api.together.xyz/v1'),
    ollama: () => createOllama({ baseURL })(modelNameString),
    fireworks: openAICompatible(process.env.FIREWORKS_API_KEY, 'https://api.fireworks.ai/inference/v1'),
  }

  // Own-property check: a plain index lookup would also find inherited
  // Object.prototype members (e.g. 'toString', 'constructor') and call them
  // instead of rejecting the unknown provider.
  if (!Object.prototype.hasOwnProperty.call(providerConfigs, providerId)) {
    throw new Error(`Unsupported provider: ${providerId}`)
  }
  const createClient = providerConfigs[providerId as keyof typeof providerConfigs]
  return createClient()
}
/**
 * Returns the default generation mode for `model`.
 *
 * Fireworks models are pinned to `'json'`; every other provider gets `'auto'`.
 * NOTE(review): presumably this feeds the AI SDK `mode` option, and Fireworks
 * needs `'json'` because `'auto'` misbehaves there — confirm at the call site.
 *
 * @param model - Model whose `providerId` is inspected; no other field is read.
 * @returns `'json'` for the `fireworks` provider, otherwise `'auto'`.
 */
export function getDefaultMode(model: LLMModel): 'json' | 'auto' {
  // Provider-specific override rather than a runtime patch.
  return model.providerId === 'fireworks' ? 'json' : 'auto'
}