Skip to content

Commit

Permalink
util/llm: refactoring to keep my sanity (and fixing circular imports)
Browse files Browse the repository at this point in the history
  • Loading branch information
haraldschilly committed Feb 23, 2024
1 parent b6f4c87 commit 71140d3
Show file tree
Hide file tree
Showing 30 changed files with 421 additions and 384 deletions.
2 changes: 1 addition & 1 deletion src/packages/frontend/account/avatar/avatar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import { LanguageModelVendorAvatar } from "@cocalc/frontend/components/language-
import { ProjectTitle } from "@cocalc/frontend/projects/project-title";
import { DEFAULT_COLOR } from "@cocalc/frontend/users/store";
import { webapp_client } from "@cocalc/frontend/webapp-client";
import { service2model } from "@cocalc/util/db-schema/openai";
import { service2model } from "@cocalc/util/db-schema/llm";
import { ensure_bound, startswith, trunc_middle } from "@cocalc/util/misc";
import { avatar_fontcolor } from "./font-color";

Expand Down
2 changes: 1 addition & 1 deletion src/packages/frontend/account/chatbot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import {
MODELS,
Vendor,
model2vendor,
} from "@cocalc/util/db-schema/openai";
} from "@cocalc/util/db-schema/llm";

// we either check if the prefix is one of the known ones (used in some circumstances)
// or if the account id is exactly one of the language models (more precise)
Expand Down
2 changes: 1 addition & 1 deletion src/packages/frontend/account/other-settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import {
getValidLanguageModelName,
isFreeModel,
model2vendor,
} from "@cocalc/util/db-schema/openai";
} from "@cocalc/util/db-schema/llm";
import {
VBAR_EXPLANATION,
VBAR_KEY,
Expand Down
2 changes: 1 addition & 1 deletion src/packages/frontend/account/useLanguageModelSetting.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
fromOllamaModel,
getValidLanguageModelName,
isOllamaLLM,
} from "@cocalc/util/db-schema/openai";
} from "@cocalc/util/db-schema/llm";

export const SETTINGS_LANGUAGE_MODEL_KEY = "language_model";

Expand Down
2 changes: 1 addition & 1 deletion src/packages/frontend/chat/actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import {
model2vendor,
type LanguageModel,
LANGUAGE_MODEL_PREFIXES,
} from "@cocalc/util/db-schema/openai";
} from "@cocalc/util/db-schema/llm";
import { cmp, isValidUUID, parse_hashtags, uuid } from "@cocalc/util/misc";
import { getSortedDates } from "./chat-log";
import { message_to_markdown } from "./message";
Expand Down
12 changes: 5 additions & 7 deletions src/packages/frontend/client/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,17 @@ import { EventEmitter } from "events";

import { redux } from "@cocalc/frontend/app-framework";
import type { History } from "@cocalc/frontend/misc/openai"; // do not import until needed -- it is HUGE!
import type {
EmbeddingData,
LanguageModel,
} from "@cocalc/util/db-schema/openai";
import type { EmbeddingData } from "@cocalc/util/db-schema/openai";
import {
MAX_EMBEDDINGS_TOKENS,
MAX_REMOVE_LIMIT,
MAX_SAVE_LIMIT,
MAX_SEARCH_LIMIT,
isFreeModel,
model2service,
} from "@cocalc/util/db-schema/openai";
import * as message from "@cocalc/util/message";
import type { WebappClient } from "./client";
import { LanguageModel, LanguageService } from "@cocalc/util/db-schema/llm";
import { isFreeModel, model2service } from "@cocalc/util/db-schema/llm";

const DEFAULT_SYSTEM_PROMPT =
"ASSUME THAT I HAVE FULL ACCESS TO COCALC AND I AM USING COCALC RIGHT NOW. ENCLOSE ALL MATH IN $. INCLUDE THE LANGUAGE DIRECTLY AFTER THE TRIPLE BACKTICKS IN ALL MARKDOWN CODE BLOCKS. BE BRIEF.";
Expand Down Expand Up @@ -98,7 +95,8 @@ export class LLMClient {
}

if (!isFreeModel(model)) {
const service = model2service(model);
// Ollama and others are treated as "free"
const service = model2service(model) as LanguageService;
// when client gets non-free openai model request, check if allowed. If not, show quota modal.
const { allowed, reason } =
await this.client.purchases_client.isPurchaseAllowed(service);
Expand Down
2 changes: 1 addition & 1 deletion src/packages/frontend/codemirror/extensions/ai-formula.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import ModelSwitch, {
import { show_react_modal } from "@cocalc/frontend/misc";
import track from "@cocalc/frontend/user-tracking";
import { webapp_client } from "@cocalc/frontend/webapp-client";
import { isFreeModel, isLanguageModel } from "@cocalc/util/db-schema/openai";
import { isFreeModel, isLanguageModel } from "@cocalc/util/db-schema/llm";
import { unreachable } from "@cocalc/util/misc";

type Mode = "tex" | "md";
Expand Down
3 changes: 1 addition & 2 deletions src/packages/frontend/components/language-model-icon.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { isLanguageModel, model2vendor } from "@cocalc/util/db-schema/openai";

import { CSS } from "@cocalc/frontend/app-framework";
import { isLanguageModel, model2vendor } from "@cocalc/util/db-schema/llm";
import { unreachable } from "@cocalc/util/misc";
import AIAvatar from "./ai-avatar";
import GoogleGeminiLogo from "./google-gemini-avatar";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import {
LLM_USERNAMES,
USER_SELECTABLE_LANGUAGE_MODELS,
model2service,
} from "@cocalc/util/db-schema/openai";
} from "@cocalc/util/db-schema/llm";
import { cmp, timestamp_cmp, trunc_middle } from "@cocalc/util/misc";
import { Item } from "./complete";

Expand Down
2 changes: 1 addition & 1 deletion src/packages/frontend/frame-editors/llm/create-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export interface Options {
command: string;
allowEmpty?: boolean;
tag?: string;
model: LanguageModel;
model: LanguageModel | string;
}

export default async function createChat({
Expand Down
15 changes: 12 additions & 3 deletions src/packages/frontend/frame-editors/llm/model-switch.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import {
LLM_USERNAMES,
LanguageModel,
USER_SELECTABLE_LANGUAGE_MODELS,
fromOllamaModel,
isFreeModel,
isOllamaLLM,
model2service,
toOllamaModel,
} from "@cocalc/util/db-schema/openai";
} from "@cocalc/util/db-schema/llm";

export { DEFAULT_MODEL };
export type { LanguageModel };
Expand Down Expand Up @@ -139,11 +141,18 @@ export default function ModelSwitch({
);
}

export function modelToName(model: LanguageModel): string {
export function modelToName(model: LanguageModel | string): string {
if (isOllamaLLM(model)) {
const ollama = redux.getStore("customize").get("ollama")?.toJS() ?? {};
const om = ollama[fromOllamaModel(model)];
if (om) {
return om.display ?? `Ollama ${model}`;
}
}
return LLM_USERNAMES[model] ?? model;
}

export function modelToMention(model: LanguageModel): string {
export function modelToMention(model: LanguageModel | string): string {
return `<span class="user-mention" account-id=${model2service(
model,
)} >@${modelToName(model)}</span>`;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ async function updateInput(
actions: Actions,
id,
scope,
model: LanguageModel,
model: LanguageModel | string,
): Promise<{ input: string; inputOrig: string }> {
if (scope == "none") {
return { input: "", inputOrig: "" };
Expand Down
2 changes: 1 addition & 1 deletion src/packages/frontend/jupyter/chatgpt/explain.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ async function getExplanation({
actions: JupyterActions;
project_id: string;
path: string;
model: LanguageModel;
model: LanguageModel | string;
}) {
const message = createMessage({ id, actions, model, open: false });
if (!message) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import {
LanguageModel,
getVendorStatusCheckMD,
model2vendor,
} from "@cocalc/util/db-schema/openai";
} from "@cocalc/util/db-schema/llm";
import { COLORS } from "@cocalc/util/theme";
import { JupyterActions } from "../browser-actions";
import { insertCell } from "./util";
Expand Down Expand Up @@ -168,7 +168,7 @@ interface QueryLanguageModelProps {
actions: JupyterActions;
frameActions: React.MutableRefObject<NotebookFrameActions | undefined>;
id: string;
model: LanguageModel;
model: LanguageModel | string;
path: string;
position: "above" | "below";
project_id: string;
Expand Down Expand Up @@ -316,7 +316,7 @@ interface GetInputProps {
actions: JupyterActions;
frameActions: React.MutableRefObject<NotebookFrameActions | undefined>;
id: string;
model: LanguageModel;
model: LanguageModel | string;
position: "above" | "below";
prompt: string;
}
Expand Down
10 changes: 5 additions & 5 deletions src/packages/frontend/misc/openai.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// NOTE! This gpt-3-tokenizer is LARGE, e.g., 1.6MB, so be
// sure to async load it by clients of this code.
import GPT3Tokenizer from "gpt3-tokenizer";
import type { Model } from "@cocalc/util/db-schema/openai";
import { getMaxTokens } from "@cocalc/util/db-schema/openai";
import type { Model } from "@cocalc/util/db-schema/llm";
import { getMaxTokens } from "@cocalc/util/db-schema/llm";

export { getMaxTokens };

Expand All @@ -25,7 +25,7 @@ const tokenizer = new GPT3Tokenizer({ type: "gpt3" });

export function numTokensUpperBound(
content: string,
maxTokens: number
maxTokens: number,
): number {
return (
tokenizer.encode(content.slice(0, maxTokens * APPROX_CHARACTERS_PER_TOKEN))
Expand Down Expand Up @@ -64,7 +64,7 @@ export function truncateMessage(content: string, maxTokens: number): string {
export function truncateHistory(
history: History,
maxTokens: number,
model: Model
model: Model,
): History {
if (maxTokens <= 0) {
return [];
Expand Down Expand Up @@ -101,7 +101,7 @@ export function truncateHistory(
const before = tokens[largestIndex].length;
const toRemove = Math.max(
1,
Math.min(maxTokens - total, Math.ceil(tokens[largestIndex].length / 5))
Math.min(maxTokens - total, Math.ceil(tokens[largestIndex].length / 5)),
);
tokens[largestIndex] = tokens[largestIndex].slice(0, -toRemove);
const after = tokens[largestIndex].length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ import { once } from "@cocalc/util/async-utils";
import {
getVendorStatusCheckMD,
model2vendor,
} from "@cocalc/util/db-schema/openai";
} from "@cocalc/util/db-schema/llm";
import { field_cmp, to_iso_path } from "@cocalc/util/misc";
import { COLORS } from "@cocalc/util/theme";
import { ensure_project_running } from "../../project-start-warning";
Expand Down
2 changes: 1 addition & 1 deletion src/packages/frontend/sagews/chatgpt.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { redux } from "@cocalc/frontend/app-framework";
import { getHelp } from "@cocalc/frontend/frame-editors/llm/help-me-fix";
import { getValidLanguageModelName } from "@cocalc/util/db-schema/openai";
import { getValidLanguageModelName } from "@cocalc/util/db-schema/llm";
import { MARKERS } from "@cocalc/util/sagews";
import { SETTINGS_LANGUAGE_MODEL_KEY } from "../account/useLanguageModelSetting";

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Vendor } from "@cocalc/util/db-schema/openai";
import { Vendor } from "@cocalc/util/db-schema/llm";
import { unreachable } from "@cocalc/util/misc";
import A from "components/misc/A";

Expand Down
7 changes: 4 additions & 3 deletions src/packages/server/llm/abuse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ import { assertPurchaseAllowed } from "@cocalc/server/purchases/is-purchase-allo
import {
isFreeModel,
LanguageModel,
LanguageService,
model2service,
MODELS,
} from "@cocalc/util/db-schema/openai";
} from "@cocalc/util/db-schema/llm";
import { isValidUUID } from "@cocalc/util/misc";

const QUOTAS = {
Expand Down Expand Up @@ -73,7 +74,7 @@ export async function checkForAbuse({
// This is a for-pay product, so let's make sure user can purchase it.
await assertPurchaseAllowed({
account_id,
service: model2service(model),
service: model2service(model) as LanguageService,
});
// We always allow usage of for pay models, since the user is paying for
// them. Only free models need to be throttled.
Expand Down Expand Up @@ -113,7 +114,7 @@ export async function checkForAbuse({
// This is a for-pay product, so let's make sure user can purchase it.
await assertPurchaseAllowed({
account_id,
service: model2service(model),
service: model2service(model) as LanguageService,
});
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/packages/server/llm/call-llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { delay } from "awaiting";
import type OpenAI from "openai";

import getLogger from "@cocalc/backend/logger";
import { ModelOpenAI, OpenAIMessages } from "@cocalc/util/db-schema/openai";
import { ModelOpenAI, OpenAIMessages } from "@cocalc/util/db-schema/llm";
import { ChatOutput } from "@cocalc/util/types/llm";
import { Stream } from "openai/streaming";
import { totalNumTokens } from "./chatgpt-numtokens";
Expand Down
6 changes: 3 additions & 3 deletions src/packages/server/llm/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ Get the client for the given LanguageModel.
You do not have to worry too much about throwing an exception, because they're caught in ./index::evaluate
*/

import OpenAI from "openai";
import jsonStable from "json-stable-stringify";
import { Ollama } from "@langchain/community/llms/ollama";
import jsonStable from "json-stable-stringify";
import * as _ from "lodash";
import OpenAI from "openai";

import getLogger from "@cocalc/backend/logger";
import { getServerSettings } from "@cocalc/database/settings/server-settings";
import { LanguageModel, model2vendor } from "@cocalc/util/db-schema/openai";
import { LanguageModel, model2vendor } from "@cocalc/util/db-schema/llm";
import { unreachable } from "@cocalc/util/misc";
import { VertexAIClient } from "./vertex-ai-client";

Expand Down
7 changes: 4 additions & 3 deletions src/packages/server/llm/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@ import {
DEFAULT_MODEL,
LLM_USERNAMES,
LanguageModel,
LanguageService,
OpenAIMessages,
getLLMCost,
isFreeModel,
isValidModel,
model2service,
model2vendor,
} from "@cocalc/util/db-schema/openai";
} from "@cocalc/util/db-schema/llm";
import { ChatOptions, ChatOutput, History } from "@cocalc/util/types/llm";
import { checkForAbuse } from "./abuse";
import { callChatGPTAPI } from "./call-llm";
Expand Down Expand Up @@ -136,9 +137,9 @@ async function evaluateImpl({
account_id,
project_id,
cost,
service: model2service(model),
service: model2service(model) as LanguageService,
description: {
type: model2service(model),
type: model2service(model) as LanguageService,
prompt_tokens,
completion_tokens,
},
Expand Down
2 changes: 1 addition & 1 deletion src/packages/server/llm/vertex-ai-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
*/

import getLogger from "@cocalc/backend/logger";
import { LanguageModel } from "@cocalc/util/db-schema/openai";
import { LanguageModel } from "@cocalc/util/db-schema/llm";
import { ChatOutput, History } from "@cocalc/util/types/llm";
import {
DiscussServiceClient,
Expand Down
2 changes: 1 addition & 1 deletion src/packages/server/purchases/get-service-cost.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
getLLMCost,
isLanguageModelService,
service2model,
} from "@cocalc/util/db-schema/openai";
} from "@cocalc/util/db-schema/llm";
import type { Service } from "@cocalc/util/db-schema/purchases";
import { unreachable } from "@cocalc/util/misc";

Expand Down
2 changes: 1 addition & 1 deletion src/packages/server/purchases/is-purchase-allowed.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
getMaxCost,
isLanguageModelService,
service2model,
} from "@cocalc/util/db-schema/openai";
} from "@cocalc/util/db-schema/llm";
import { QUOTA_SPEC, Service } from "@cocalc/util/db-schema/purchase-quotas";
import { MAX_COST } from "@cocalc/util/db-schema/purchases";
import { currency, round2up, round2down } from "@cocalc/util/misc";
Expand Down
13 changes: 13 additions & 0 deletions src/packages/util/db-schema/llm.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Tests for ./llm (these helpers historically lived in the wrongly named
// openai.ts file and were moved to llm.ts in this refactoring).

import { isFreeModel } from "./llm";

describe("openai/llm", () => {
  test("isFreeModel", () => {
    // gpt-3 is treated as a free model; gpt-4 is a paid model.
    expect(isFreeModel("gpt-3")).toBe(true);
    expect(isFreeModel("gpt-4")).toBe(false);
    // WARNING: if the following breaks, and ollama becomes non-free, then a couple of assumptions are broken as well.
    // search for model2service(...) as LanguageService in the codebase!
    // NOTE(review): several call sites cast model2service(...) as LanguageService
    // on the assumption that Ollama models never reach the purchase path.
    expect(isFreeModel("ollama-1")).toBe(true);
  });
});
Loading

0 comments on commit 71140d3

Please sign in to comment.