diff --git a/chart/env/prod.yaml b/chart/env/prod.yaml
index b97fa064c33..d9649f440f3 100644
--- a/chart/env/prod.yaml
+++ b/chart/env/prod.yaml
@@ -144,6 +144,9 @@ envVars:
       "websiteUrl": "https://qwenlm.github.io/blog/qwq-32b-preview/",
       "logoUrl": "https://huggingface.co/datasets/huggingchat/models-logo/resolve/main/qwen-logo.png",
       "description": "QwQ is an experiment model from the Qwen Team with advanced reasoning capabilities.",
+      "reasoning": {
+        "type": "summarize"
+      },
       "parameters": {
         "stop": ["<|im_end|>"],
         "truncate": 12288,
diff --git a/src/lib/components/OpenWebSearchResults.svelte b/src/lib/components/OpenWebSearchResults.svelte
index 4591367405c..44a569a3fe2 100644
--- a/src/lib/components/OpenWebSearchResults.svelte
+++ b/src/lib/components/OpenWebSearchResults.svelte
@@ -9,7 +9,6 @@
 	import EosIconsLoading from "~icons/eos-icons/loading";
 	import IconInternet from "./icons/IconInternet.svelte";
 
-	export let classNames = "";
 	export let webSearchMessages: MessageWebSearchUpdate[] = [];
 
 	$: sources = webSearchMessages.find(isMessageWebSearchSourcesUpdate)?.sources;
@@ -23,7 +22,7 @@
type === "webSearch") ?? + $: searchUpdates = (message.updates?.filter(({ type }) => type === MessageUpdateType.WebSearch) ?? []) as MessageWebSearchUpdate[]; + $: reasoningUpdates = (message.updates?.filter( + ({ type }) => type === MessageUpdateType.Reasoning + ) ?? []) as MessageReasoningUpdate[]; + $: messageFinalAnswer = message.updates?.find( ({ type }) => type === MessageUpdateType.FinalAnswer ) as MessageFinalAnswerUpdate; @@ -208,9 +215,17 @@
{/if} {#if searchUpdates && searchUpdates.length > 0} - + {/if} + {#if reasoningUpdates && reasoningUpdates.length > 0} + {@const summaries = reasoningUpdates + .filter((u) => u.subtype === MessageReasoningUpdateType.Status) + .map((u) => u.status)} + + {/if} @@ -224,11 +239,19 @@ {/each} {/if} -
+
0 || searchUpdates.length > 0} + > {#if isLast && loading && $settings.disableStream} {/if} - + +
+ +
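For context, a rough TypeScript sketch (illustration only, not a hunk of this patch) of what the reactive statements above compute: the component pulls reasoning updates out of message.updates and keeps only their status lines, which become the summaries passed to the reasoning results panel. The helper name here is hypothetical.

import {
	MessageUpdateType,
	MessageReasoningUpdateType,
	type MessageUpdate,
	type MessageReasoningUpdate,
} from "$lib/types/MessageUpdate";

// Hypothetical helper mirroring the component's reactive statements:
// keep reasoning updates, then keep only their status summaries.
function reasoningSummaries(updates: MessageUpdate[] = []): string[] {
	const reasoningUpdates = updates.filter(
		({ type }) => type === MessageUpdateType.Reasoning
	) as MessageReasoningUpdate[];
	return reasoningUpdates
		.filter((u) => u.subtype === MessageReasoningUpdateType.Status)
		.map((u) => u.status);
}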
diff --git a/src/lib/components/chat/MarkdownRenderer.svelte b/src/lib/components/chat/MarkdownRenderer.svelte
index 34076dee4e2..357bfb39a8e 100644
--- a/src/lib/components/chat/MarkdownRenderer.svelte
+++ b/src/lib/components/chat/MarkdownRenderer.svelte
@@ -106,21 +106,17 @@
 	});
 
-
-	{#each tokens as token}
-		{#if token.type === "code"}
-
-		{:else}
-			{@const parsed = marked.parse(processLatex(escapeHTML(token.raw)), options)}
-			{#await parsed then parsed}
-
-				{@html DOMPurify.sanitize(parsed)}
-			{/await}
-		{/if}
-	{/each}
-
+{#each tokens as token}
+	{#if token.type === "code"}
+
+	{:else}
+		{@const parsed = marked.parse(processLatex(escapeHTML(token.raw)), options)}
+		{#await parsed then parsed}
+
+			{@html DOMPurify.sanitize(parsed)}
+		{/await}
+	{/if}
+{/each}
diff --git a/src/lib/server/generateFromDefaultEndpoint.ts b/src/lib/server/generateFromDefaultEndpoint.ts
index 4f798f90f51..48f0110b9e1 100644
--- a/src/lib/server/generateFromDefaultEndpoint.ts
+++ b/src/lib/server/generateFromDefaultEndpoint.ts
@@ -1,7 +1,8 @@
 import { smallModel } from "$lib/server/models";
+import { MessageUpdateType, type MessageUpdate } from "$lib/types/MessageUpdate";
 import type { EndpointMessage } from "./endpoints/endpoints";
 
-export async function generateFromDefaultEndpoint({
+export async function* generateFromDefaultEndpoint({
 	messages,
 	preprompt,
 	generateSettings,
@@ -9,7 +10,7 @@ export async function generateFromDefaultEndpoint({
 	messages: EndpointMessage[];
 	preprompt?: string;
 	generateSettings?: Record<string, unknown>;
-}): Promise<string> {
+}): AsyncGenerator<MessageUpdate, string> {
 	const endpoint = await smallModel.getEndpoint();
 
 	const tokenStream = await endpoint({ messages, preprompt, generateSettings });
@@ -25,6 +26,10 @@ export async function generateFromDefaultEndpoint({
 			}
 			return generated_text;
 		}
+		yield {
+			type: MessageUpdateType.Stream,
+			token: output.token.text,
+		};
 	}
 	throw new Error("Generation failed");
 }
diff --git a/src/lib/server/models.ts b/src/lib/server/models.ts
index 888dc6738f2..43c73dc4084 100644
--- a/src/lib/server/models.ts
+++ b/src/lib/server/models.ts
@@ -17,6 +17,21 @@ import { isHuggingChat } from "$lib/utils/isHuggingChat";
 
 type Optional<T, K extends keyof T> = Pick<Partial<T>, K> & Omit<T, K>;
 
+const reasoningSchema = z.union([
+	z.object({
+		type: z.literal("regex"), // everything is reasoning, extract the answer from the regex
+		regex: z.string(),
+	}),
+	z.object({
+		type: z.literal("tokens"), // use beginning and end tokens that define the reasoning portion of the answer
+		beginToken: z.string(),
+		endToken: z.string(),
+	}),
+	z.object({
+		type: z.literal("summarize"), // everything is reasoning, summarize the answer
+	}),
+]);
+
 const modelConfig = z.object({
 	/** Used as an identifier in DB */
 	id: z.string().optional(),
@@ -70,6 +85,7 @@ const modelConfig = z.object({
 	embeddingModel: validateEmbeddingModelByName(embeddingModels).optional(),
 	/** Used to enable/disable system prompt usage */
 	systemRoleSupported: z.boolean().default(true),
+	reasoning: reasoningSchema.optional(),
 });
 
 const modelsRaw = z.array(modelConfig).parse(JSON5.parse(env.MODELS));
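For illustration, here are the three shapes that reasoningSchema accepts, as a small self-contained sketch rather than real entries from this repo (the <think>/</think> markers and the regex are hypothetical values):

import { z } from "zod";

const reasoningSchema = z.union([
	z.object({ type: z.literal("regex"), regex: z.string() }),
	z.object({ type: z.literal("tokens"), beginToken: z.string(), endToken: z.string() }),
	z.object({ type: z.literal("summarize") }),
]);

// "summarize": the whole generation is reasoning; a second pass turns it into the answer
// (this is the mode the QwQ entry in prod.yaml enables).
const summarizeConfig = reasoningSchema.parse({ type: "summarize" });

// "tokens": reasoning is delimited by explicit begin/end tokens (hypothetical markers).
const tokensConfig = reasoningSchema.parse({
	type: "tokens",
	beginToken: "<think>",
	endToken: "</think>",
});

// "regex": everything is reasoning; the final answer is the first capture group.
const regexConfig = reasoningSchema.parse({
	type: "regex",
	regex: "Final answer:\\s*([\\s\\S]+)$",
});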
diff --git a/src/lib/server/textGeneration/generate.ts b/src/lib/server/textGeneration/generate.ts
index 6f3d13def1a..48e38d69cfd 100644
--- a/src/lib/server/textGeneration/generate.ts
+++ b/src/lib/server/textGeneration/generate.ts
@@ -1,8 +1,14 @@
 import type { ToolResult } from "$lib/types/Tool";
-import { MessageUpdateType, type MessageUpdate } from "$lib/types/MessageUpdate";
+import {
+	MessageReasoningUpdateType,
+	MessageUpdateType,
+	type MessageUpdate,
+} from "$lib/types/MessageUpdate";
 import { AbortedGenerations } from "../abortedGenerations";
 import type { TextGenerationContext } from "./types";
 import type { EndpointMessage } from "../endpoints/endpoints";
+import { generateFromDefaultEndpoint } from "../generateFromDefaultEndpoint";
+import { generateSummaryOfReasoning } from "./reasoning";
 
 type GenerateContext = Omit<TextGenerationContext, "messages"> & { messages: EndpointMessage[] };
 
@@ -11,6 +17,26 @@ export async function* generate(
 	toolResults: ToolResult[],
 	preprompt?: string
 ): AsyncIterable<MessageUpdate> {
+	// reasoning mode is false by default
+	let reasoning = false;
+	let reasoningBuffer = "";
+	let lastReasoningUpdate = new Date();
+	let status = "";
+	const startTime = new Date();
+
+	if (
+		model.reasoning &&
+		(model.reasoning.type === "regex" || model.reasoning.type === "summarize")
+	) {
+		// if the model has reasoning in regex or summarize mode, it starts in reasoning mode
+		// and we extract the answer from the reasoning
+		reasoning = true;
+		yield {
+			type: MessageUpdateType.Reasoning,
+			subtype: MessageReasoningUpdateType.Status,
+			status: "Started reasoning...",
+		};
+	}
+
 	for await (const output of await endpoint({
 		messages,
 		preprompt,
@@ -33,20 +59,102 @@
 				text = text.slice(0, text.length - stopToken.length);
 			}
 
+			let finalAnswer = text;
+			if (model.reasoning && model.reasoning.type === "regex") {
+				const regex = new RegExp(model.reasoning.regex);
+				finalAnswer = regex.exec(reasoningBuffer)?.[1] ?? text;
+			} else if (model.reasoning && model.reasoning.type === "summarize") {
+				yield {
+					type: MessageUpdateType.Reasoning,
+					subtype: MessageReasoningUpdateType.Status,
+					status: "Summarizing reasoning...",
+				};
+				const summary = yield* generateFromDefaultEndpoint({
+					messages: [
+						{
+							from: "user",
+							content: `Question: ${
+								messages[messages.length - 1].content
+							}\n\nReasoning: ${reasoningBuffer}`,
+						},
+					],
+					preprompt: `Your task is to summarize concisely all your reasoning steps and then give the final answer. Keep it short, one short paragraph at most. If the final solution includes code, make sure to include it in your answer.
+
+If the user is just having a casual conversation that doesn't require explanations, answer directly without explaining your steps, otherwise make sure to summarize step by step, skip dead-ends in your reasoning, and remove excess detail.
+
+Do not use prefixes such as Response: or Answer: when answering the user.`,
+					generateSettings: {
+						max_new_tokens: 1024,
+					},
+				});
+				finalAnswer = summary;
+				yield {
+					type: MessageUpdateType.Reasoning,
+					subtype: MessageReasoningUpdateType.Status,
+					status: `Done in ${Math.round((new Date().getTime() - startTime.getTime()) / 1000)}s.`,
+				};
+			}
+
 			yield {
 				type: MessageUpdateType.FinalAnswer,
-				text,
+				text: finalAnswer,
 				interrupted,
 				webSources: output.webSources,
 			};
 			continue;
 		}
 
+		if (model.reasoning && model.reasoning.type === "tokens") {
+			if (output.token.text === model.reasoning.beginToken) {
+				reasoning = true;
+				reasoningBuffer += output.token.text;
+				yield {
+					type: MessageUpdateType.Reasoning,
+					subtype: MessageReasoningUpdateType.Status,
+					status: "Started thinking...",
+				};
+			} else if (output.token.text === model.reasoning.endToken) {
+				reasoning = false;
+				reasoningBuffer += output.token.text;
+				yield {
+					type: MessageUpdateType.Reasoning,
+					subtype: MessageReasoningUpdateType.Status,
+					status: `Done in ${Math.round((new Date().getTime() - startTime.getTime()) / 1000)}s.`,
+				};
+			}
+		}
 		// ignore special tokens
 		if (output.token.special) continue;
 
 		// pass down normal token
-		yield { type: MessageUpdateType.Stream, token: output.token.text };
+		if (reasoning) {
+			reasoningBuffer += output.token.text;
+
+			// yield status update if it has changed
+			if (status !== "") {
+				yield {
+					type: MessageUpdateType.Reasoning,
+					subtype: MessageReasoningUpdateType.Status,
+					status,
+				};
+				status = "";
+			}
+
+			// create a new status every 4 seconds
+			if (new Date().getTime() - lastReasoningUpdate.getTime() > 4000) {
+				lastReasoningUpdate = new Date();
+				generateSummaryOfReasoning(reasoningBuffer).then((summary) => {
+					status = summary;
+				});
+			}
+			yield {
+				type: MessageUpdateType.Reasoning,
+				subtype: MessageReasoningUpdateType.Stream,
+				token: output.token.text,
+			};
+		} else {
+			yield { type: MessageUpdateType.Stream, token: output.token.text };
+		}
 
 		// abort check
 		const date = AbortedGenerations.getInstance().getList().get(conv._id.toString());
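To make the three modes concrete, here is a simplified, non-authoritative sketch of how the final answer relates to the accumulated reasoning buffer. The real code above works token by token, and for "summarize" the answer comes from a second generateFromDefaultEndpoint call with the summarization preprompt rather than from slicing the buffer:

type ReasoningConfig =
	| { type: "regex"; regex: string }
	| { type: "tokens"; beginToken: string; endToken: string }
	| { type: "summarize" };

// Derive the user-facing answer from a fully accumulated buffer (sketch only).
function extractFinalAnswer(reasoning: ReasoningConfig, buffer: string): string {
	if (reasoning.type === "regex") {
		// first capture group of the configured regex, falling back to the raw buffer
		return new RegExp(reasoning.regex).exec(buffer)?.[1] ?? buffer;
	}
	if (reasoning.type === "tokens") {
		// everything after the end token is the answer; everything before it is reasoning
		const end = buffer.indexOf(reasoning.endToken);
		return end === -1 ? buffer : buffer.slice(end + reasoning.endToken.length).trim();
	}
	// "summarize": the answer is produced by a second model call, not by slicing the buffer
	return buffer;
}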
+ Example: "Thinking about life...", "Summarizing the results...", "Processing the input..."`, + generateSettings: { + max_new_tokens: 50, + }, + }) + ).then((summary) => { + const parts = summary.split("..."); + return parts[0] + "..."; + }); + + return summary; +} diff --git a/src/lib/server/textGeneration/title.ts b/src/lib/server/textGeneration/title.ts index 24141d5d424..5502ac7a94e 100644 --- a/src/lib/server/textGeneration/title.ts +++ b/src/lib/server/textGeneration/title.ts @@ -4,6 +4,7 @@ import type { EndpointMessage } from "../endpoints/endpoints"; import { logger } from "$lib/server/logger"; import { MessageUpdateType, type MessageUpdate } from "$lib/types/MessageUpdate"; import type { Conversation } from "$lib/types/Conversation"; +import { getReturnFromGenerator } from "$lib/utils/getReturnFromGenerator"; export async function* generateTitleForConversation( conv: Conversation @@ -55,14 +56,16 @@ export async function generateTitle(prompt: string) { { from: "user", content: prompt }, ]; - return await generateFromDefaultEndpoint({ - messages, - preprompt: - "You are a summarization AI. Summarize the user's request into a single short sentence of four words or less. Do not try to answer it, only summarize the user's query. Always start your answer with an emoji relevant to the summary", - generateSettings: { - max_new_tokens: 15, - }, - }) + return await getReturnFromGenerator( + generateFromDefaultEndpoint({ + messages, + preprompt: + "You are a summarization AI. Summarize the user's request into a single short sentence of four words or less. Do not try to answer it, only summarize the user's query. Always start your answer with an emoji relevant to the summary", + generateSettings: { + max_new_tokens: 15, + }, + }) + ) .then((summary) => { // add an emoji if none is found in the first three characters if (!/\p{Emoji}/u.test(summary.slice(0, 3))) { diff --git a/src/lib/server/websearch/search/generateQuery.ts b/src/lib/server/websearch/search/generateQuery.ts index c71841a8c17..70cac567861 100644 --- a/src/lib/server/websearch/search/generateQuery.ts +++ b/src/lib/server/websearch/search/generateQuery.ts @@ -2,6 +2,7 @@ import type { Message } from "$lib/types/Message"; import { format } from "date-fns"; import type { EndpointMessage } from "../../endpoints/endpoints"; import { generateFromDefaultEndpoint } from "../../generateFromDefaultEndpoint"; +import { getReturnFromGenerator } from "$lib/utils/getReturnFromGenerator"; export async function generateQuery(messages: Message[]) { const currentDate = format(new Date(), "MMMM d, yyyy"); @@ -62,13 +63,15 @@ Current Question: Where is it being hosted?`, }, ]; - const webQuery = await generateFromDefaultEndpoint({ - messages: convQuery, - preprompt: `You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is ${currentDate}`, - generateSettings: { - max_new_tokens: 30, - }, - }); + const webQuery = await getReturnFromGenerator( + generateFromDefaultEndpoint({ + messages: convQuery, + preprompt: `You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. 
diff --git a/src/lib/server/websearch/search/generateQuery.ts b/src/lib/server/websearch/search/generateQuery.ts
index c71841a8c17..70cac567861 100644
--- a/src/lib/server/websearch/search/generateQuery.ts
+++ b/src/lib/server/websearch/search/generateQuery.ts
@@ -2,6 +2,7 @@ import type { Message } from "$lib/types/Message";
 import { format } from "date-fns";
 import type { EndpointMessage } from "../../endpoints/endpoints";
 import { generateFromDefaultEndpoint } from "../../generateFromDefaultEndpoint";
+import { getReturnFromGenerator } from "$lib/utils/getReturnFromGenerator";
 
 export async function generateQuery(messages: Message[]) {
 	const currentDate = format(new Date(), "MMMM d, yyyy");
@@ -62,13 +63,15 @@ Current Question: Where is it being hosted?`,
 		},
 	];
 
-	const webQuery = await generateFromDefaultEndpoint({
-		messages: convQuery,
-		preprompt: `You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is ${currentDate}`,
-		generateSettings: {
-			max_new_tokens: 30,
-		},
-	});
+	const webQuery = await getReturnFromGenerator(
+		generateFromDefaultEndpoint({
+			messages: convQuery,
+			preprompt: `You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is ${currentDate}`,
+			generateSettings: {
+				max_new_tokens: 30,
+			},
+		})
+	);
 
 	return webQuery.trim();
 }
diff --git a/src/lib/types/Message.ts b/src/lib/types/Message.ts
index 3b4582197e7..2b0f567dcc0 100644
--- a/src/lib/types/Message.ts
+++ b/src/lib/types/Message.ts
@@ -10,6 +10,8 @@ export type Message = Partial<Timestamps> & {
 	updates?: MessageUpdate[];
 	webSearchId?: WebSearch["_id"]; // legacy version
 	webSearch?: WebSearch;
+
+	reasoning?: string;
 	score?: -1 | 0 | 1;
 	/**
 	 * Either contains the base64 encoded image data
diff --git a/src/lib/types/MessageUpdate.ts b/src/lib/types/MessageUpdate.ts
index 5600b79bcaf..17a953289dc 100644
--- a/src/lib/types/MessageUpdate.ts
+++ b/src/lib/types/MessageUpdate.ts
@@ -8,7 +8,8 @@ export type MessageUpdate =
 	| MessageWebSearchUpdate
 	| MessageStreamUpdate
 	| MessageFileUpdate
-	| MessageFinalAnswerUpdate;
+	| MessageFinalAnswerUpdate
+	| MessageReasoningUpdate;
 
 export enum MessageUpdateType {
 	Status = "status",
@@ -18,6 +19,7 @@
 	Stream = "stream",
 	File = "file",
 	FinalAnswer = "finalAnswer",
+	Reasoning = "reasoning",
 }
 
 // Status
@@ -114,6 +116,25 @@ export interface MessageStreamUpdate {
 	type: MessageUpdateType.Stream;
 	token: string;
 }
+
+export enum MessageReasoningUpdateType {
+	Stream = "stream",
+	Status = "status",
+}
+
+export type MessageReasoningUpdate = MessageReasoningStreamUpdate | MessageReasoningStatusUpdate;
+
+export interface MessageReasoningStreamUpdate {
+	type: MessageUpdateType.Reasoning;
+	subtype: MessageReasoningUpdateType.Stream;
+	token: string;
+}
+export interface MessageReasoningStatusUpdate {
+	type: MessageUpdateType.Reasoning;
+	subtype: MessageReasoningUpdateType.Status;
+	status: string;
+}
+
 export interface MessageFileUpdate {
 	type: MessageUpdateType.File;
 	name: string;
diff --git a/src/lib/utils/getReturnFromGenerator.ts b/src/lib/utils/getReturnFromGenerator.ts
new file mode 100644
index 00000000000..cfb3283cba5
--- /dev/null
+++ b/src/lib/utils/getReturnFromGenerator.ts
@@ -0,0 +1,7 @@
+export async function getReturnFromGenerator<T, R>(generator: AsyncGenerator<T, R>): Promise<R> {
+	let result: IteratorResult<T, R>;
+	do {
+		result = await generator.next();
+	} while (!result.done); // Keep calling `next()` until `done` is true
+	return result.value; // Return the final value
+}
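A short usage sketch (illustrative, with a toy generator) of why this helper exists: generateFromDefaultEndpoint now yields stream updates as it goes and only returns the full text, so callers that just want the text drain the generator down to its return value.

// Toy generator with the same shape: yields progress, returns a final value.
async function* countToThree(): AsyncGenerator<string, number> {
	yield "one";
	yield "two";
	yield "three";
	return 3;
}

async function getReturnFromGenerator<T, R>(generator: AsyncGenerator<T, R>): Promise<R> {
	let result: IteratorResult<T, R>;
	do {
		result = await generator.next();
	} while (!result.done);
	return result.value;
}

// Ignores the yielded "one"/"two"/"three" and resolves to 3,
// the same way generateTitle and generateQuery ignore stream updates.
getReturnFromGenerator(countToThree()).then((n) => console.log(n));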
diff --git a/src/routes/conversation/[id]/+page.svelte b/src/routes/conversation/[id]/+page.svelte
index 4f5f8314f2f..0d82ee6168a 100644
--- a/src/routes/conversation/[id]/+page.svelte
+++ b/src/routes/conversation/[id]/+page.svelte
@@ -12,9 +12,9 @@
 	import { webSearchParameters } from "$lib/stores/webSearchParameters";
 	import type { Message } from "$lib/types/Message";
 	import {
+		MessageReasoningUpdateType,
 		MessageUpdateStatus,
 		MessageUpdateType,
-		type MessageUpdate,
 	} from "$lib/types/MessageUpdate";
 	import titleUpdate from "$lib/stores/titleUpdate";
 	import file2base64 from "$lib/utils/file2base64";
@@ -215,8 +215,6 @@
 			files = [];
 
-			const messageUpdates: MessageUpdate[] = [];
-
 			for await (const update of messageUpdatesIterator) {
 				if ($isAborted) {
 					messageUpdatesAbortController.abort();
@@ -229,7 +227,7 @@
 					update.token = update.token.replaceAll("\0", "");
 				}
 
-				messageUpdates.push(update);
+				messageToWriteTo.updates = [...(messageToWriteTo.updates ?? []), update];
 
 				if (update.type === MessageUpdateType.Stream && !$settings.disableStream) {
 					messageToWriteTo.content += update.token;
@@ -239,7 +237,6 @@
 					update.type === MessageUpdateType.WebSearch ||
 					update.type === MessageUpdateType.Tool
 				) {
-					messageToWriteTo.updates = [...(messageToWriteTo.updates ?? []), update];
 					messages = [...messages];
 				} else if (
 					update.type === MessageUpdateType.Status &&
@@ -262,10 +259,18 @@
 						{ type: "hash", value: update.sha, mime: update.mime, name: update.name },
 					];
 					messages = [...messages];
+				} else if (update.type === MessageUpdateType.Reasoning) {
+					if (!messageToWriteTo.reasoning) {
+						messageToWriteTo.reasoning = "";
+					}
+					if (update.subtype === MessageReasoningUpdateType.Stream) {
+						messageToWriteTo.reasoning += update.token;
+					} else {
+						messageToWriteTo.updates = [...(messageToWriteTo.updates ?? []), update];
+					}
+					messages = [...messages];
 				}
 			}
-
-			messageToWriteTo.updates = messageUpdates;
 		} catch (err) {
 			if (err instanceof Error && err.message.includes("overloaded")) {
 				$error = "Too much traffic, please try again.";
diff --git a/src/routes/conversation/[id]/+server.ts b/src/routes/conversation/[id]/+server.ts
index b73f4a136d1..a7ec44f6eb5 100644
--- a/src/routes/conversation/[id]/+server.ts
+++ b/src/routes/conversation/[id]/+server.ts
@@ -9,6 +9,7 @@ import { error } from "@sveltejs/kit";
 import { ObjectId } from "mongodb";
 import { z } from "zod";
 import {
+	MessageReasoningUpdateType,
 	MessageUpdateStatus,
 	MessageUpdateType,
 	type MessageUpdate,
 } from "$lib/types/MessageUpdate";
@@ -355,6 +356,12 @@ export async function POST({ request, locals, params, getClientAddress }) {
 					Date.now() - (lastTokenTimestamp ?? promptedAt).getTime()
 				);
 				lastTokenTimestamp = new Date();
+			} else if (
+				event.type === MessageUpdateType.Reasoning &&
+				event.subtype === MessageReasoningUpdateType.Stream
+			) {
+				messageToWriteTo.reasoning ??= "";
+				messageToWriteTo.reasoning += event.token;
 			}
 
 			// Set the title
@@ -392,6 +399,10 @@
 				!(
 					event.type === MessageUpdateType.Status &&
 					event.status === MessageUpdateStatus.KeepAlive
+				) &&
+				!(
+					event.type === MessageUpdateType.Reasoning &&
+					event.subtype === MessageReasoningUpdateType.Stream
 				)
 			) {
 				messageToWriteTo?.updates?.push(event);
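Taken together, the consumer side follows one pattern in both +page.svelte and +server.ts: reasoning stream tokens accumulate into message.reasoning, reasoning status updates are kept in message.updates, and ordinary stream tokens build message.content. A rough sketch that mirrors those branches (the PartialMessage shape and applyUpdate helper are illustrative only, not code from this patch):

import {
	MessageUpdateType,
	MessageReasoningUpdateType,
	type MessageUpdate,
} from "$lib/types/MessageUpdate";

interface PartialMessage {
	content: string;
	reasoning?: string;
	updates: MessageUpdate[];
}

function applyUpdate(message: PartialMessage, update: MessageUpdate): void {
	if (update.type === MessageUpdateType.Stream) {
		// ordinary tokens build the visible answer
		message.content += update.token;
	} else if (update.type === MessageUpdateType.Reasoning) {
		message.reasoning ??= "";
		if (update.subtype === MessageReasoningUpdateType.Stream) {
			// raw chain-of-thought tokens are kept separate from the answer
			message.reasoning += update.token;
		} else {
			// periodic status summaries ("Started reasoning...", "Done in 12s.") stay in updates
			message.updates.push(update);
		}
	} else if (update.type === MessageUpdateType.FinalAnswer) {
		// the final answer replaces whatever was streamed so far
		message.content = update.text;
	} else {
		message.updates.push(update);
	}
}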