Skip to content

Commit

Permalink
fix: inaccurate context length calculation when using timestamps and names
Browse files Browse the repository at this point in the history
  • Loading branch information
Vali-98 committed Aug 6, 2024
1 parent 350b52e commit c30dfe5
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 11 deletions.
4 changes: 2 additions & 2 deletions app/components/Endpoint/Local.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Llama, LlamaPreset } from 'app/constants/LlamaLocal'
import { AppSettings, Global, Logger, Style } from '@globals'
import { Llama, LlamaPreset } from 'app/constants/LlamaLocal'
import { useEffect, useState } from 'react'
import {
View,
Expand Down Expand Up @@ -227,7 +227,7 @@ const Local = () => {
varname="context_length"
min={512}
max={32768}
step={32}
step={512}
/>
<SliderItem
name="Threads"
Expand Down
54 changes: 45 additions & 9 deletions app/constants/APIState/BaseAPI.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import { AppSettings, Global } from '@constants/GlobalValues'
import { Llama } from '@constants/LlamaLocal'
import { Tokenizer } from '@constants/Tokenizer'
import { API } from '@globals'
import { Characters } from 'app/constants/Characters'
import { Chats, useInference } from 'app/constants/Chat'
import { InstructType, Instructs } from 'app/constants/Instructs'
Expand Down Expand Up @@ -55,6 +58,12 @@ export abstract class APIBase implements IAPIBase {

buildTextCompletionContext = (max_length: number) => {
const delta = performance.now()

const tokenizer =
mmkv.getString(Global.APIType) === API.LOCAL
? Llama.useLlama.getState().tokenLength
: Tokenizer.useTokenizer.getState().getTokenCount

const messages = [...(Chats.useChat.getState().data?.messages ?? [])]

const currentInstruct = Instructs.useInstruct.getState().replacedMacros()
Expand All @@ -71,6 +80,7 @@ export abstract class APIBase implements IAPIBase {
const user_card_data = (userCard?.data?.description ?? '').trim()
const char_card_data = (currentCard?.data?.description ?? '').trim()
let payload = ``

// set suffix length as its always added
let payload_length = instructCache.system_suffix_length
if (currentInstruct.system_prefix) {
Expand All @@ -95,41 +105,67 @@ export abstract class APIBase implements IAPIBase {
let message_acc_length = 0
let is_last = true
let index = messages.length - 1

const wrap_string = `\n`
const wrap_length = currentInstruct.wrap ? tokenizer(wrap_string) : 0

// we require lengths for names if use_names is enabled
for (const message of messages?.reverse() ?? []) {
const swipe_len = Chats.useChat.getState().getTokenCount(index)
// for last message, we want to skip the end token to allow the LLM to generate
const swipe_data = message.swipes[message.swipe_id]

/** Accumulate total string length
* The context builder MUST retain context length below the
* context limit, especially for local gens to prevent truncation
* **/

let instruct_len = message.is_user
? instructCache.input_prefix_length
: instructCache.output_suffix_length

// for last message, we want to skip the end token to allow the LLM to generate

if (!is_last)
instruct_len += message.is_user
? instructCache.input_suffix_length
: instructCache.output_suffix_length
const shard_length = swipe_len + instruct_len

const timestamp_string = `[${swipe_data.send_date.toString().split(' ')[0]} ${swipe_data.send_date.toLocaleTimeString()}]\n`
const timestamp_length = currentInstruct.timestamp ? tokenizer(timestamp_string) : 0

const name_string = `${message.name} :`
const name_length = currentInstruct.names ? tokenizer(name_string) : 0

const shard_length =
swipe_len + instruct_len + name_length + timestamp_length + wrap_length

// check if within context window

if (message_acc_length + payload_length + shard_length > max_length) {
break
}

// apply strings

let message_shard = `${message.is_user ? currentInstruct.input_prefix : currentInstruct.output_prefix}`

const swipe_data = message.swipes[message.swipe_id]
if (currentInstruct.timestamp) message_shard += timestamp_string

if (currentInstruct.timestamp)
message_shard += `[${swipe_data.send_date.toString().split(' ')[0]} ${swipe_data.send_date.toLocaleTimeString()}]\n`
if (currentInstruct.names) message_shard += message.name + ': '
if (currentInstruct.names) message_shard += name_string

message_shard += swipe_data.swipe

if (!is_last) {
message_shard += `${message.is_user ? currentInstruct.input_suffix : currentInstruct.output_suffix}`
}
// ensure no more is_last checks after this
is_last = false

if (currentInstruct.wrap) {
message_shard += `\n`
message_shard += wrap_string
}

// ensure no more is_last checks after this
is_last = false

message_acc_length += shard_length
message_acc = message_shard + message_acc
index--
Expand Down

0 comments on commit c30dfe5

Please sign in to comment.