Skip to content

Commit

Permalink
improve readability
Browse files (browse the repository at this point in the history)
  • Loading branch information
eastriverlee committed Jan 30, 2024
1 parent 6628e7e commit 9929934
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions Sources/LLM/LLM.swift
Original file line number | Diff line number | Diff line change
Expand Up @@ -87,7 +87,7 @@ open class LLM: ObservableObject {
self.model = model
self.history = history
self.totalTokenCount = Int(llama_n_vocab(model))
self.newlineToken = llama_token_nl(model)
self.newlineToken = model.newLineToken
self.stopSequence = stopSequence?.utf8CString
self.stopSequenceLength = (self.stopSequence?.count ?? 1) - 1
batch = llama_batch_init(Int32(self.maxTokenCount), 0, 1)
Expand Down Expand Up @@ -179,9 +179,9 @@ open class LLM: ObservableObject {

@InferenceActor
private func predictNextToken() async -> Token {
guard shouldContinuePredicting else { return llama_token_eos(model) }
guard shouldContinuePredicting else { return model.endToken }
let logits = llama_get_logits_ith(context.pointer, batch.n_tokens - 1)!
var candidates: [llama_token_data] = (0..<totalTokenCount).map { token in
var candidates = (0..<totalTokenCount).map { token in
llama_token_data(id: Int32(token), logit: logits[token], p: 0.0)
}
var token: llama_token!
Expand Down Expand Up @@ -257,26 +257,26 @@ open class LLM: ObservableObject {

private func process(_ token: Token, to output: borrowing AsyncStream<String>.Continuation) -> Bool {
struct saved {
static var endIndex = 0
static var stopSequenceEndIndex = 0
static var letters: [CChar] = []
}
guard token != llama_token_eos(model) else { return false }
guard token != model.endToken else { return false }
var word = decode(token)
guard let stopSequence else { output.yield(word); return true }
var found = 0 < saved.endIndex
var found = 0 < saved.stopSequenceEndIndex
var letters: [CChar] = []
for letter in word.utf8CString {
guard letter != 0 else { break }
if letter == stopSequence[saved.endIndex] {
saved.endIndex += 1
if letter == stopSequence[saved.stopSequenceEndIndex] {
saved.stopSequenceEndIndex += 1
found = true
saved.letters.append(letter)
guard saved.endIndex == stopSequenceLength else { continue }
saved.endIndex = 0
guard saved.stopSequenceEndIndex == stopSequenceLength else { continue }
saved.stopSequenceEndIndex = 0
saved.letters.removeAll()
return false
} else if found {
saved.endIndex = 0
saved.stopSequenceEndIndex = 0
if !saved.letters.isEmpty {
word = String(cString: saved.letters + [0]) + word
saved.letters.removeAll()
Expand Down Expand Up @@ -364,6 +364,9 @@ open class LLM: ObservableObject {
}

extension Model {
public var endToken: Token { llama_token_eos(self) }
public var newLineToken: Token { llama_token_nl(self) }

public func shouldAddBOS() -> Bool {
let addBOS = llama_add_bos_token(self);
guard addBOS != -1 else {
Expand Down

0 comments on commit 9929934

Please sign in to comment.