From d3126cd7ff6b6b8cc65f9024bf1f2e330c28fc0f Mon Sep 17 00:00:00 2001 From: Harald Schilly Date: Fri, 23 Feb 2024 18:43:38 +0100 Subject: [PATCH] server/ollama: some progress ... --- src/packages/frontend/customize.tsx | 7 +- .../frame-editors/llm/model-switch.tsx | 2 - src/packages/hub/webapp-configuration.ts | 1 + src/packages/pnpm-lock.yaml | 290 +++++++++++++++++- src/packages/server/llm/abuse.ts | 4 +- src/packages/server/llm/client.ts | 16 +- src/packages/server/llm/index.ts | 11 +- src/packages/server/llm/ollama.ts | 64 +++- src/packages/server/package.json | 7 + src/packages/util/db-schema/llm.ts | 10 +- 10 files changed, 388 insertions(+), 24 deletions(-) diff --git a/src/packages/frontend/customize.tsx b/src/packages/frontend/customize.tsx index edfe62c02dc..07f0c5a9a76 100644 --- a/src/packages/frontend/customize.tsx +++ b/src/packages/frontend/customize.tsx @@ -259,10 +259,9 @@ async function init_customize() { init_customize(); -function process_ollama(ollama) { - if (ollama) { - actions.setState({ ollama: fromJS(ollama) }); - } +function process_ollama(ollama?) { + if (!ollama) return; + actions.setState({ ollama: fromJS(ollama) }); } function process_kucalc(obj) { diff --git a/src/packages/frontend/frame-editors/llm/model-switch.tsx b/src/packages/frontend/frame-editors/llm/model-switch.tsx index 5743ad309c6..046bf90aff2 100644 --- a/src/packages/frontend/frame-editors/llm/model-switch.tsx +++ b/src/packages/frontend/frame-editors/llm/model-switch.tsx @@ -120,8 +120,6 @@ export default function ModelSwitch({ }); } - console.log("model", model); - // all models selectable here must be in util/db-schema/openai::USER_SELECTABLE_LANGUAGE_MODELS return ( =6.9.0'} @@ -5671,7 +5693,6 @@ packages: resolution: {integrity: sha512-AMZ2UWx+woHNfM11PyAEQmfSxi05jm9OlkxczuHeEqmvwPkYj6MWv44gbzDPefYOLysTOFyI3ziiy2ONmUZfpA==} dependencies: undici-types: 5.26.5 - dev: true /@types/node@18.19.4: resolution: {integrity: sha512-xNzlUhzoHotIsnFoXmJB+yWmBvFZgKCI9TtPIEdYIMM1KWfwuY8zh7wvc1u1OAXlC7dlf6mZVx/s+Y5KfFz19A==} @@ -9467,6 +9488,10 @@ packages: jest-util: 29.7.0 dev: true + /expr-eval@2.0.2: + resolution: {integrity: sha512-4EMSHGOPSwAfBiibw3ndnP0AvjDWLsMvGOvWEZ2F96IGk0bIVdjQisOHxReSkE13mHcfbuCiXw+G4y0zv6N8Eg==} + dev: false + /express-rate-limit@6.7.0(express@4.18.2): resolution: {integrity: sha512-vhwIdRoqcYB/72TK3tRZI+0ttS8Ytrk24GfmsxDXK9o9IhHNO5bXRiXQSExPQ4GbaE5tvIS7j1SGrxsuWs+sGA==} engines: {node: '>= 12.9.0'} @@ -12582,7 +12607,6 @@ packages: hasBin: true dependencies: argparse: 2.0.1 - dev: true /jsesc@2.5.2: resolution: {integrity: sha512-OYu7XEzjkCQ3C5Ps3QIZsQfNpqoJyZZA99wd9aWd05NCtC5pWOkShK2mkL6HXQR6/Cy2lbNdPlZBpuQHXE63gA==} @@ -12658,6 +12682,11 @@ packages: resolution: {integrity: sha512-trvBk1ki43VZptdBI5rIlG4YOzyeH/WefQt5rj1grasPn4iiZWKet8nkgc4GlsAylaztn0qZfUYOiTsASJFdNA==} dev: false + /jsonpointer@5.0.1: + resolution: {integrity: sha512-p/nXbhSEcu3pZRdkW1OfJhpsVtW1gd4Wa1fnQc9YLiTfAjn0312eMKimbdIQzuZl9aa9xUGaRlP9T/CJE/ditQ==} + engines: {node: '>=0.10.0'} + dev: false + /jsonwebtoken@9.0.2: resolution: {integrity: sha512-PRp66vJ865SSqOlgqS8hujT5U4AOgMfhrwYIuIhfKaoSCZcirrmASQr8CX7cUg+RMih+hgznrjp99o+W4pJLHQ==} engines: {node: '>=12', npm: '>=6'} @@ -12826,6 +12855,254 @@ packages: engines: {node: '>=18.0.0'} dev: false + /langchain@0.1.21(@google-ai/generativelanguage@1.1.0)(axios@1.6.7)(encoding@0.1.13)(google-auth-library@9.4.1)(lodash@4.17.21): + resolution: {integrity: sha512-OOcCFIgx23WyyNS1VJBLbC3QL5plQBVfp2drXw1OJAarZ8yEY3cgJq8NbTY37sMnLoJ2olFEzMuAOdlTur4cwQ==} + engines: {node: '>=18'} + peerDependencies: + '@aws-sdk/client-s3': ^3.310.0 + '@aws-sdk/client-sagemaker-runtime': ^3.310.0 + '@aws-sdk/client-sfn': ^3.310.0 + '@aws-sdk/credential-provider-node': ^3.388.0 + '@azure/storage-blob': ^12.15.0 + '@gomomento/sdk': ^1.51.1 + '@gomomento/sdk-core': ^1.51.1 + '@gomomento/sdk-web': ^1.51.1 + '@google-ai/generativelanguage': ^0.2.1 + '@google-cloud/storage': ^6.10.1 || ^7.7.0 + '@notionhq/client': ^2.2.10 + '@pinecone-database/pinecone': '*' + '@supabase/supabase-js': ^2.10.0 + '@vercel/kv': ^0.2.3 + '@xata.io/client': ^0.28.0 + apify-client: ^2.7.1 + assemblyai: ^4.0.0 + axios: '*' + cheerio: ^1.0.0-rc.12 + chromadb: '*' + convex: ^1.3.1 + couchbase: ^4.2.10 + d3-dsv: ^2.0.0 + epub2: ^3.0.1 + faiss-node: '*' + fast-xml-parser: ^4.2.7 + google-auth-library: ^8.9.0 + handlebars: ^4.7.8 + html-to-text: ^9.0.5 + ignore: ^5.2.0 + ioredis: ^5.3.2 + jsdom: '*' + mammoth: ^1.6.0 + mongodb: '>=5.2.0' + node-llama-cpp: '*' + notion-to-md: ^3.1.0 + officeparser: ^4.0.4 + pdf-parse: 1.1.1 + peggy: ^3.0.2 + playwright: ^1.32.1 + puppeteer: ^19.7.2 + pyodide: ^0.24.1 + redis: ^4.6.4 + sonix-speech-recognition: ^2.1.1 + srt-parser-2: ^1.2.3 + typeorm: ^0.3.12 + weaviate-ts-client: '*' + web-auth-library: ^1.0.3 + ws: ^8.14.2 + youtube-transcript: ^1.0.6 + youtubei.js: ^5.8.0 + peerDependenciesMeta: + '@aws-sdk/client-s3': + optional: true + '@aws-sdk/client-sagemaker-runtime': + optional: true + '@aws-sdk/client-sfn': + optional: true + '@aws-sdk/credential-provider-node': + optional: true + '@azure/storage-blob': + optional: true + '@gomomento/sdk': + optional: true + '@gomomento/sdk-core': + optional: true + '@gomomento/sdk-web': + optional: true + '@google-ai/generativelanguage': + optional: true + '@google-cloud/storage': + optional: true + '@notionhq/client': + optional: true + '@pinecone-database/pinecone': + optional: true + '@supabase/supabase-js': + optional: true + '@vercel/kv': + optional: true + '@xata.io/client': + optional: true + apify-client: + optional: true + assemblyai: + optional: true + axios: + optional: true + cheerio: + optional: true + chromadb: + optional: true + convex: + optional: true + couchbase: + optional: true + d3-dsv: + optional: true + epub2: + optional: true + faiss-node: + optional: true + fast-xml-parser: + optional: true + google-auth-library: + optional: true + handlebars: + optional: true + html-to-text: + optional: true + ignore: + optional: true + ioredis: + optional: true + jsdom: + optional: true + mammoth: + optional: true + mongodb: + optional: true + node-llama-cpp: + optional: true + notion-to-md: + optional: true + officeparser: + optional: true + pdf-parse: + optional: true + peggy: + optional: true + playwright: + optional: true + puppeteer: + optional: true + pyodide: + optional: true + redis: + optional: true + sonix-speech-recognition: + optional: true + srt-parser-2: + optional: true + typeorm: + optional: true + weaviate-ts-client: + optional: true + web-auth-library: + optional: true + ws: + optional: true + youtube-transcript: + optional: true + youtubei.js: + optional: true + dependencies: + '@anthropic-ai/sdk': 0.9.1(encoding@0.1.13) + '@google-ai/generativelanguage': 1.1.0(encoding@0.1.13) + '@langchain/community': 0.0.32(@google-ai/generativelanguage@1.1.0)(encoding@0.1.13)(google-auth-library@9.4.1)(lodash@4.17.21) + '@langchain/core': 0.1.32 + '@langchain/openai': 0.0.14(encoding@0.1.13) + axios: 1.6.7 + binary-extensions: 2.2.0 + expr-eval: 2.0.2 + google-auth-library: 9.4.1(encoding@0.1.13) + js-tiktoken: 1.0.10 + js-yaml: 4.1.0 + jsonpointer: 5.0.1 + langchainhub: 0.0.8 + langsmith: 0.1.3 + ml-distance: 4.0.1 + openapi-types: 12.1.3 + p-retry: 4.6.2 + uuid: 9.0.1 + yaml: 2.3.4 + zod: 3.22.4 + zod-to-json-schema: 3.22.4(zod@3.22.4) + transitivePeerDependencies: + - '@aws-crypto/sha256-js' + - '@aws-sdk/client-bedrock-agent-runtime' + - '@aws-sdk/client-bedrock-runtime' + - '@aws-sdk/client-dynamodb' + - '@aws-sdk/client-kendra' + - '@aws-sdk/client-lambda' + - '@azure/search-documents' + - '@clickhouse/client' + - '@cloudflare/ai' + - '@datastax/astra-db-ts' + - '@elastic/elasticsearch' + - '@getmetal/metal-sdk' + - '@getzep/zep-js' + - '@gradientai/nodejs-sdk' + - '@huggingface/inference' + - '@mozilla/readability' + - '@opensearch-project/opensearch' + - '@planetscale/database' + - '@qdrant/js-client-rest' + - '@raycast/api' + - '@rockset/client' + - '@smithy/eventstream-codec' + - '@smithy/protocol-http' + - '@smithy/signature-v4' + - '@smithy/util-utf8' + - '@supabase/postgrest-js' + - '@tensorflow-models/universal-sentence-encoder' + - '@tensorflow/tfjs-converter' + - '@tensorflow/tfjs-core' + - '@upstash/redis' + - '@upstash/vector' + - '@vercel/postgres' + - '@writerai/writer-sdk' + - '@xenova/transformers' + - '@zilliz/milvus2-sdk-node' + - better-sqlite3 + - cassandra-driver + - closevector-common + - closevector-node + - closevector-web + - cohere-ai + - discord.js + - dria + - encoding + - firebase-admin + - googleapis + - hnswlib-node + - llmonitor + - lodash + - lunary + - mysql2 + - neo4j-driver + - pg + - pg-copy-streams + - pickleparser + - portkey-ai + - replicate + - typesense + - usearch + - vectordb + - voy-search + dev: false + + /langchainhub@0.0.8: + resolution: {integrity: sha512-Woyb8YDHgqqTOZvWIbm2CaFDGfZ4NTSyXV687AG4vXEfoNo7cGQp7nhl7wL3ehenKWmNEmcxCLgOZzW8jE6lOQ==} + dev: false + /langs@2.0.0: resolution: {integrity: sha512-v4pxOBEQVN1WBTfB1crhTtxzNLZU9HPWgadlwzWKISJtt6Ku/CnpBrwVy+jFv8StjxsPfwPFzO0CMwdZLJ0/BA==} dev: false @@ -14277,6 +14554,10 @@ packages: - encoding dev: false + /openapi-types@12.1.3: + resolution: {integrity: sha512-N4YtSYJqghVu4iek2ZUvcN/0aqH1kRDuNqzcycDxhOUpg7GdvLa2F3DgS6yBNhInhv2r/6I0Flkn7CqL8+nIcw==} + dev: false + /opener@1.5.2: resolution: {integrity: sha512-ur5UIdyw5Y7yEj9wLzhqXiy6GZ3Mwx0yGI+5sMn2r0N0v3cKJvUmFH5yPP+WXh9e0xfyzyJX95D8l088DNFj7A==} hasBin: true @@ -19395,6 +19676,11 @@ packages: resolution: {integrity: sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==} requiresBuild: true + /yaml@2.3.4: + resolution: {integrity: sha512-8aAvwVUSHpfEqTQ4w/KMlf3HcRdt50E5ODIQJBw1fQ5RL34xabzxtUlzTXVqc4rkZsPbvrXKWnABCD7kWSmocA==} + engines: {node: '>= 14'} + dev: false + /yargs-parser@18.1.3: resolution: {integrity: sha512-o50j0JeToy/4K6OZcaQmW6lyXXKhq7csREXcDwk2omFPJEwUNOVtJKvmDr9EI1fAJZUyZcRF7kxGBWmRXudrCQ==} engines: {node: '>=6'} diff --git a/src/packages/server/llm/abuse.ts b/src/packages/server/llm/abuse.ts index ae89cac4dae..fabe29c8c4f 100644 --- a/src/packages/server/llm/abuse.ts +++ b/src/packages/server/llm/abuse.ts @@ -26,6 +26,7 @@ import getPool from "@cocalc/database/pool"; import { assertPurchaseAllowed } from "@cocalc/server/purchases/is-purchase-allowed"; import { isFreeModel, + isOllamaLLM, LanguageModel, LanguageService, model2service, @@ -66,7 +67,8 @@ export async function checkForAbuse({ // at least some amount of tracking. throw Error("at least one of account_id or analytics_cookie must be set"); } - if (!MODELS.includes(model)) { + + if (!MODELS.includes(model) && !isOllamaLLM(model)) { throw Error(`invalid model "${model}"`); } diff --git a/src/packages/server/llm/client.ts b/src/packages/server/llm/client.ts index 4f0fa454ea6..74dddddaafb 100644 --- a/src/packages/server/llm/client.ts +++ b/src/packages/server/llm/client.ts @@ -11,7 +11,11 @@ import OpenAI from "openai"; import getLogger from "@cocalc/backend/logger"; import { getServerSettings } from "@cocalc/database/settings/server-settings"; -import { LanguageModel, model2vendor } from "@cocalc/util/db-schema/llm"; +import { + LanguageModel, + isOllamaLLM, + model2vendor, +} from "@cocalc/util/db-schema/llm"; import { unreachable } from "@cocalc/util/misc"; import { VertexAIClient } from "./vertex-ai-client"; @@ -78,9 +82,9 @@ const ollamaCache: { [key: string]: Ollama } = {}; * All other config parameters are passed to the Ollama constructor (e.g. topK, temperature, etc.). */ export async function getOllama(model: string) { - if (model.startsWith("ollama-")) { + if (isOllamaLLM(model)) { throw new Error( - `At this point, the model name should no longer have the "ollama-" prefix`, + `At this point, the model name should be one of Ollama, but it was ${model}`, ); } @@ -92,6 +96,10 @@ export async function getOllama(model: string) { ); } + if (config.cocalc?.disabled) { + throw new Error(`Ollama model ${model} is disabled`); + } + // the key is a hash of the model name and the specific config – such that changes in the config will invalidate the cache const key = `${model}:${jsonStable(config)}`; @@ -109,7 +117,7 @@ export async function getOllama(model: string) { ); } - const keepAlive = config.keepAlive ?? -1; + const keepAlive: string = config.keepAlive ?? "-1"; // extract all other properties from the config, except the url, model, keepAlive field and the "cocalc" field const other = _.omit(config, ["baseUrl", "model", "keepAlive", "cocalc"]); diff --git a/src/packages/server/llm/index.ts b/src/packages/server/llm/index.ts index 53bdb8aecf1..ac93ea4a963 100644 --- a/src/packages/server/llm/index.ts +++ b/src/packages/server/llm/index.ts @@ -25,6 +25,7 @@ import { OpenAIMessages, getLLMCost, isFreeModel, + isOllamaLLM, isValidModel, model2service, model2vendor, @@ -33,9 +34,9 @@ import { ChatOptions, ChatOutput, History } from "@cocalc/util/types/llm"; import { checkForAbuse } from "./abuse"; import { callChatGPTAPI } from "./call-llm"; import { getClient } from "./client"; +import { evaluateOllama } from "./ollama"; import { saveResponse } from "./save-response"; import { VertexAIClient } from "./vertex-ai-client"; -import { evaluateOllama } from "./ollama"; const log = getLogger("llm"); @@ -91,11 +92,9 @@ async function evaluateImpl({ const start = Date.now(); await checkForAbuse({ account_id, analytics_cookie, model }); - const client = await getClient(model); - const { output, total_tokens, prompt_tokens, completion_tokens } = await (async () => { - if (model.startsWith("ollama-")) { + if (isOllamaLLM(model)) { return await evaluateOllama({ system, history, @@ -109,7 +108,6 @@ async function evaluateImpl({ system, history, input, - client, model, maxTokens, stream, @@ -179,11 +177,12 @@ async function evaluteCall({ system, history, input, - client, model, maxTokens, stream, }) { + const client = await getClient(model); + if (client instanceof VertexAIClient) { return await evaluateVertexAI({ system, diff --git a/src/packages/server/llm/ollama.ts b/src/packages/server/llm/ollama.ts index 91ad6317f2e..dc059930a74 100644 --- a/src/packages/server/llm/ollama.ts +++ b/src/packages/server/llm/ollama.ts @@ -1,6 +1,15 @@ +import { + ChatPromptTemplate, + MessagesPlaceholder, +} from "@langchain/core/prompts"; +import { RunnableWithMessageHistory } from "@langchain/core/runnables"; +import { ChatMessageHistory } from "langchain/stores/message/in_memory"; + import getLogger from "@cocalc/backend/logger"; +import { fromOllamaModel, isOllamaLLM } from "@cocalc/util/db-schema/llm"; import { ChatOutput, History } from "@cocalc/util/types/llm"; import { getOllama } from "./client"; +import { AIMessage, HumanMessage } from "@langchain/core/messages"; const log = getLogger("llm:ollama"); @@ -17,10 +26,10 @@ interface OllamaOpts { export async function evaluateOllama( opts: Readonly, ): Promise { - if (!opts.model.startsWith("ollama-")) { + if (!isOllamaLLM(opts.model)) { throw new Error(`model ${opts.model} not supported`); } - const model = opts.model.slice("ollama-".length); + const model = fromOllamaModel(opts.model); const { system, history, input, maxTokens, stream } = opts; log.debug("evaluateOllama", { input, @@ -33,7 +42,56 @@ export async function evaluateOllama( const ollama = await getOllama(model); - const chunks = await ollama.stream(input); + const msgs: ["ai" | "human", string][] = []; + + if (history) { + let nextRole: "model" | "user" = "user"; + for (const { content } of history) { + if (nextRole === "user") { + msgs.push(["human", content]); + } else { + msgs.push(["ai", content]); + } + nextRole = nextRole === "user" ? "model" : "user"; + } + } + + const prompt = ChatPromptTemplate.fromMessages([ + ["system", system ?? ""], + new MessagesPlaceholder("chat_history"), + ["human", "{input}"], + ]); + + const chain = prompt.pipe(ollama); + + const chainWithHistory = new RunnableWithMessageHistory({ + runnable: chain, + inputMessagesKey: "input", + historyMessagesKey: "chat_history", + getMessageHistory: async (_) => { + const chatHistory = new ChatMessageHistory(); + // await history.addMessage(new HumanMessage("be brief")); + // await history.addMessage(new AIMessage("ok")); + if (history) { + let nextRole: "model" | "user" = "user"; + for (const { content } of history) { + if (nextRole === "user") { + await chatHistory.addMessage(new HumanMessage(content)); + } else { + await chatHistory.addMessage(new AIMessage(content)); + } + nextRole = nextRole === "user" ? "model" : "user"; + } + } + + return chatHistory; + }, + }); + + const chunks = await chainWithHistory.stream( + { input }, + { configurable: { sessionId: "ignored" } }, + ); let output = ""; for await (const chunk of chunks) { diff --git a/src/packages/server/package.json b/src/packages/server/package.json index a167fa941aa..a418403148f 100644 --- a/src/packages/server/package.json +++ b/src/packages/server/package.json @@ -47,6 +47,7 @@ "@google/generative-ai": "^0.1.3", "@isaacs/ttlcache": "^1.2.1", "@langchain/community": "^0.0.32", + "@langchain/core": "^0.1.32", "@node-saml/passport-saml": "^4.0.4", "@passport-js/passport-twitter": "^1.0.8", "@passport-next/passport-google-oauth2": "^1.0.0", @@ -81,6 +82,7 @@ "json-stable-stringify": "^1.0.1", "jwt-decode": "^3.1.2", "lambda-cloud-node-api": "^1.0.1", + "langchain": "^0.1.21", "lodash": "^4.17.21", "lru-cache": "^7.14.1", "ms": "2.1.2", @@ -114,5 +116,10 @@ "devDependencies": { "@types/node": "^18.16.14", "expect": "^26.6.2" + }, + "pnpm": { + "overrides": { + "@langchain/core": "^0.1.32" + } } } diff --git a/src/packages/util/db-schema/llm.ts b/src/packages/util/db-schema/llm.ts index 36a20a8e787..38e46f7c601 100644 --- a/src/packages/util/db-schema/llm.ts +++ b/src/packages/util/db-schema/llm.ts @@ -146,11 +146,17 @@ export function model2vendor(model: LanguageModel | string): Vendor { } export function toOllamaModel(model: string) { + if (isOllamaLLM(model)) { + throw new Error(`already an ollama model: ${model}`); + } return `ollama-${model}`; } export function fromOllamaModel(model: string) { - return model.replace(/^ollama-/, ""); + if (!isOllamaLLM(model)) { + throw new Error(`not an ollama model: ${model}`); + } + return model.slice("ollama-".length); } export function isOllamaLLM(model: string) { @@ -306,7 +312,7 @@ const LLM_COST: { [name in LanguageModel]: Cost } = { export function isValidModel(model?: string): boolean { if (model == null) return false; - if (model.startsWith("ollama-")) return true; + if (isOllamaLLM(model)) return true; return LLM_COST[model ?? ""] != null; }