diff --git a/Sources/llmfarm_core/LLMBase.swift b/Sources/llmfarm_core/LLMBase.swift index 17bd7c9..161bf70 100644 --- a/Sources/llmfarm_core/LLMBase.swift +++ b/Sources/llmfarm_core/LLMBase.swift @@ -592,7 +592,7 @@ public class LLMBase { // return Array(UnsafeBufferPointer(start: embeddings, count: embeddingsCount)) // } - public func llm_tokenize(_ input: String, bos: Bool = false, eos: Bool = false) -> [ModelToken] { + public func llm_tokenize(_ input: String, bos: Bool = true, eos: Bool = false) -> [ModelToken] { if input.count == 0 { return [] } @@ -620,7 +620,11 @@ public class LLMBase { case .Custom: var formated_input = self.custom_prompt_format.replacingOccurrences(of: "{{prompt}}", with: input) formated_input = formated_input.replacingOccurrences(of: "\\n", with: "\n") - return llm_tokenize(formated_input, bos: true) + var bos = true + if formated_input.contains(""){ + bos = false + } + return llm_tokenize(formated_input, bos: bos) case .ChatBase: return llm_tokenize(": " + input + "\n:") case .OpenAssistant: diff --git a/Sources/llmfarm_core/LLaMa.swift b/Sources/llmfarm_core/LLaMa.swift index c12829d..b8494ab 100644 --- a/Sources/llmfarm_core/LLaMa.swift +++ b/Sources/llmfarm_core/LLaMa.swift @@ -129,7 +129,7 @@ public class LLaMa: LLMBase { // bool add_bos) let n_tokens = Int32(input.utf8.count) + (bos == true ? 1 : 0) var embeddings: [llama_token] = Array(repeating: llama_token(), count: input.utf8.count) - let n = llama_tokenize(self.model, input, Int32(input.utf8.count), &embeddings, n_tokens, bos, false) + let n = llama_tokenize(self.model, input, Int32(input.utf8.count), &embeddings, n_tokens, bos, true) if n<=0{ return [] } diff --git a/Sources/llmfarm_core/Starcoder.swift b/Sources/llmfarm_core/Starcoder.swift index bddc4b6..58d7349 100644 --- a/Sources/llmfarm_core/Starcoder.swift +++ b/Sources/llmfarm_core/Starcoder.swift @@ -20,7 +20,7 @@ public class Starcoder: LLMBase { } deinit { - gpt2_free(context) + starcoder_free(context) } public override func llm_eval(inputBatch:[ModelToken]) throws -> Bool{