Skip to content

Commit

Permalink
add special token support like <s>
Browse files Browse the repository at this point in the history
fix santacoder(bin) crash
  • Loading branch information
guinmoon committed Oct 20, 2023
1 parent e4e2c92 commit 719d40e
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
8 changes: 6 additions & 2 deletions Sources/llmfarm_core/LLMBase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ public class LLMBase {
// return Array(UnsafeBufferPointer(start: embeddings, count: embeddingsCount))
// }

public func llm_tokenize(_ input: String, bos: Bool = false, eos: Bool = false) -> [ModelToken] {
public func llm_tokenize(_ input: String, bos: Bool = true, eos: Bool = false) -> [ModelToken] {
if input.count == 0 {
return []
}
Expand Down Expand Up @@ -620,7 +620,11 @@ public class LLMBase {
case .Custom:
var formated_input = self.custom_prompt_format.replacingOccurrences(of: "{{prompt}}", with: input)
formated_input = formated_input.replacingOccurrences(of: "\\n", with: "\n")
return llm_tokenize(formated_input, bos: true)
var bos = true
if formated_input.contains("<s>"){
bos = false
}
return llm_tokenize(formated_input, bos: bos)
case .ChatBase:
return llm_tokenize("<human>: " + input + "\n<bot>:")
case .OpenAssistant:
Expand Down
2 changes: 1 addition & 1 deletion Sources/llmfarm_core/LLaMa.swift
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public class LLaMa: LLMBase {
// bool add_bos)
let n_tokens = Int32(input.utf8.count) + (bos == true ? 1 : 0)
var embeddings: [llama_token] = Array<llama_token>(repeating: llama_token(), count: input.utf8.count)
let n = llama_tokenize(self.model, input, Int32(input.utf8.count), &embeddings, n_tokens, bos, false)
let n = llama_tokenize(self.model, input, Int32(input.utf8.count), &embeddings, n_tokens, bos, true)
if n<=0{
return []
}
Expand Down
2 changes: 1 addition & 1 deletion Sources/llmfarm_core/Starcoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public class Starcoder: LLMBase {
}

deinit {
gpt2_free(context)
starcoder_free(context)
}

public override func llm_eval(inputBatch:[ModelToken]) throws -> Bool{
Expand Down

0 comments on commit 719d40e

Please sign in to comment.