diff --git a/Sources/LLM/LLM.swift b/Sources/LLM/LLM.swift
index 7610896..c7563de 100644
--- a/Sources/LLM/LLM.swift
+++ b/Sources/LLM/LLM.swift
@@ -87,7 +87,7 @@ open class LLM: ObservableObject {
         self.model = model
         self.history = history
         self.totalTokenCount = Int(llama_n_vocab(model))
-        self.newlineToken = llama_token_nl(model)
+        self.newlineToken = model.newLineToken
         self.stopSequence = stopSequence?.utf8CString
         self.stopSequenceLength = (self.stopSequence?.count ?? 1) - 1
         batch = llama_batch_init(Int32(self.maxTokenCount), 0, 1)
@@ -179,9 +179,9 @@ open class LLM: ObservableObject {
 
     @InferenceActor
     private func predictNextToken() async -> Token {
-        guard shouldContinuePredicting else { return llama_token_eos(model) }
+        guard shouldContinuePredicting else { return model.endToken }
         let logits = llama_get_logits_ith(context.pointer, batch.n_tokens - 1)!
-        var candidates: [llama_token_data] = (0..<totalTokenCount)…
…
     private func process(_ token: Token, to output: borrowing AsyncStream<String>.Continuation) -> Bool {
         struct saved {
-            static var endIndex = 0
+            static var stopSequenceEndIndex = 0
             static var letters: [CChar] = []
         }
-        guard token != llama_token_eos(model) else { return false }
+        guard token != model.endToken else { return false }
         var word = decode(token)
         guard let stopSequence else { output.yield(word); return true }
-        var found = 0 < saved.endIndex
+        var found = 0 < saved.stopSequenceEndIndex
         var letters: [CChar] = []
         for letter in word.utf8CString {
             guard letter != 0 else { break }
-            if letter == stopSequence[saved.endIndex] {
-                saved.endIndex += 1
+            if letter == stopSequence[saved.stopSequenceEndIndex] {
+                saved.stopSequenceEndIndex += 1
                 found = true
                 saved.letters.append(letter)
-                guard saved.endIndex == stopSequenceLength else { continue }
-                saved.endIndex = 0
+                guard saved.stopSequenceEndIndex == stopSequenceLength else { continue }
+                saved.stopSequenceEndIndex = 0
                 saved.letters.removeAll()
                 return false
             } else if found {
-                saved.endIndex = 0
+                saved.stopSequenceEndIndex = 0
                 if !saved.letters.isEmpty {
                     word = String(cString: saved.letters + [0]) + word
                     saved.letters.removeAll()
@@ -364,6 +364,9 @@ open class LLM: ObservableObject {
 }
 
 extension Model {
+    public var endToken: Token { llama_token_eos(self) }
+    public var newLineToken: Token { llama_token_nl(self) }
+
     public func shouldAddBOS() -> Bool {
         let addBOS = llama_add_bos_token(self);
         guard addBOS != -1 else {
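
The extension added at the bottom is the point of the change: call sites stop
invoking raw llama.cpp free functions such as llama_token_eos(model) and read
a computed property on the model handle instead. A minimal sketch of the
call-site ergonomics this buys, using a stand-in protocol and a toy
conformance rather than the library's real Model/Token types (TokenSource,
MockModel, and the token ids below are illustrative assumptions, not part of
the diff):

    typealias Token = Int32

    // Stand-in for the Model extension; the real properties wrap
    // llama_token_eos(self) and llama_token_nl(self).
    protocol TokenSource {
        var endToken: Token { get }
        var newLineToken: Token { get }
    }

    // A call site now reads as a property access instead of a C function call:
    func collect(_ tokens: [Token], from model: TokenSource) -> [Token] {
        var output: [Token] = []
        for token in tokens {
            guard token != model.endToken else { break } // was llama_token_eos(model)
            output.append(token)
        }
        return output
    }

    // Hypothetical conformance, purely for the sketch.
    struct MockModel: TokenSource {
        let endToken: Token = 2       // a common EOS id in llama vocabularies
        let newLineToken: Token = 13  // toy value, not a real vocabulary lookup
    }

    let generated = collect([72, 105, 33, 2, 99], from: MockModel())
    // generated == [72, 105, 33]: everything from endToken on is dropped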
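The endIndex -> stopSequenceEndIndex rename matters because this index is
static state that survives across calls to process(_:to:): a stop sequence can
arrive split across several decoded tokens, so the scanner remembers how many
characters it has matched so far and buffers them in saved.letters until the
match either completes or fails. A self-contained sketch of that cross-token
scan, simplified to Characters instead of CChar buffers and to instance
storage instead of a static struct (the type and method names are
illustrative, not the library's):

    // Streaming stop-sequence scanner. feed(_:) returns the text that is safe
    // to emit, or nil once the full stop sequence has been seen.
    final class StopSequenceScanner {
        private let stop: [Character]
        private var matchedCount = 0          // role of saved.stopSequenceEndIndex
        private var pending: [Character] = [] // role of saved.letters

        init(_ stopSequence: String) {
            precondition(!stopSequence.isEmpty)
            stop = Array(stopSequence)
        }

        func feed(_ word: String) -> String? {
            var emit: [Character] = []
            for character in word {
                if character == stop[matchedCount] {
                    matchedCount += 1
                    pending.append(character)
                    if matchedCount == stop.count {
                        matchedCount = 0
                        pending.removeAll()
                        return nil // full stop sequence matched
                    }
                } else {
                    // Partial match failed: flush what was held back, then
                    // retry the current character against the sequence start.
                    emit.append(contentsOf: pending)
                    pending.removeAll()
                    matchedCount = 0
                    if character == stop[0] {
                        matchedCount = 1
                        pending.append(character)
                    } else {
                        emit.append(character)
                    }
                }
            }
            return String(emit)
        }
    }

    let scanner = StopSequenceScanner("</s>")
    print(scanner.feed("Hello <") ?? "STOP") // "Hello " ("<" is held back)
    print(scanner.feed("/s>") ?? "STOP")     // "STOP": match completed mid-stream

Note that this sketch restarts from index zero after a failed partial match,
so overlapping candidates (a stop sequence like "aab" against input "aaab")
can be missed; that is fine for typical stop strings such as "</s>".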