rename preprocess and postprocess
eastriverlee committed Mar 27, 2024
1 parent 4cd6fa3 commit a86ad40
Showing 3 changed files with 39 additions and 39 deletions.
README.md (32 changes: 16 additions & 16 deletions)
@@ -11,15 +11,15 @@
if you've already bundled your model:
```swift
let bot = LLM(from: Bundle.main.url(forResource: "mistral-7b-instruct-v0.2.Q4_K_M", withExtension: "gguf")!, template: .mistral)
-let question = bot.preProcess("What's the meaning of life?", [])
+let question = bot.preprocess("What's the meaning of life?", [])
let answer = await bot.getCompletion(from: question)
print(answer)
```
if you want to initialize with huggingface model:
```swift
let systemPrompt = "You are a sentient AI with emotions."
let bot = await LLM(from: HuggingFaceModel("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", .Q2_K, template: .chatML(systemPrompt)))
-let question = bot.preProcess("What's the meaning of life?", [])
+let question = bot.preprocess("What's the meaning of life?", [])
let answer = await bot.getCompletion(from: question)
print(answer)
```
@@ -146,22 +146,22 @@ public enum Role {

```swift
public var history: [Chat]
-public var preProcess: (_ input: String, _ history: [Chat]) -> String = { input, _ in return input }
-public var postProcess: (_ output: String) -> Void = { print($0) }
+public var preprocess: (_ input: String, _ history: [Chat]) -> String = { input, _ in return input }
+public var postprocess: (_ output: String) -> Void = { print($0) }
public var update: (_ outputDelta: String?) -> Void = { _ in }

public func respond(to input: String, with makeOutputFrom: @escaping (AsyncStream<String>) async -> String) async {
guard isAvailable else { return }
isAvailable = false
self.input = input
-    let processedInput = preProcess(input, history)
+    let processedInput = preprocess(input, history)
let response = getResponse(from: processedInput)
let output = await makeOutputFrom(response)
history += [(.user, input), (.bot, output)]
if historyLimit < history.count {
history.removeFirst(2)
}
-    postProcess(output)
+    postprocess(output)
isAvailable = true
}

@@ -183,13 +183,13 @@ open func respond(to input: String) async {
> as you can see, `func respond(to input: String) async` has `open` access, so you can override it when your class inherits from `LLM` and you want fine-grained control.
there are three functions users can define when initializing the `LLM` class:
-* `var preProcess: (_ input: String, _ history: [Chat]) -> String`
-* `var postProcess: (_ output: String) -> Void`
+* `var preprocess: (_ input: String, _ history: [Chat]) -> String`
+* `var postprocess: (_ output: String) -> Void`
* `var update: (_ outputDelta: String?) -> Void`
they are used in the `respond` function.
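
for orientation, here is a minimal sketch of assigning all three closures on an existing `bot` instance (this example is illustrative and not part of this commit; treating a `nil` delta as end-of-stream is an assumption):

```swift
var transcript = ""

bot.preprocess = { input, _ in input }          // identity: send the raw input as the prompt
bot.update = { delta in
    // each delta is one streamed chunk of output; nil is assumed to mean the stream ended
    if let delta { transcript += delta }
}
bot.postprocess = { output in print(output) }   // runs once the full output is ready
```

note that setting the `template` property overwrites `preprocess` with the template's own, as the `didSet` observer shown later in this diff does.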

-### preProcess
-`preProcess` is commonly used for making the user input conform to a chat template. if you don't provide this, `LLM` will just work as a completion model.
+### preprocess
+`preprocess` is commonly used for making the user input conform to a chat template. if you don't provide this, `LLM` will just work as a completion model.

for example, this is the `ChatML` template, which is adopted by many chat models:
```
<|im_start|>system
SYSTEM PROMPT<|im_end|>
<|im_start|>user
USER PROMPT<|im_end|>
<|im_start|>assistant
```

to use this chat format, you should use a function that goes like this:
```swift
-self.preProcess = { input, history in
+self.preprocess = { input, history in
let systemPrompt = "You are a sentient AI with emotions."
var processed = "<|im_start|>system\n\(systemPrompt)<|im_end|>\n"
for chat in history {
    // … (rest of this closure is collapsed in this diff view)
```

@@ -218,10 +218,10 @@
but for convenience, you can use the `Template` struct that is specifically made for this:
```swift
// you can use the static function that is already available for this:

-self.preProcess = Template.chatML("You are a sentient AI with emotions.").preProcess
+self.preprocess = Template.chatML("You are a sentient AI with emotions.").preprocess

// or even better
-// you can set [template] property right away, so that it handles [preProcess] and [stopSequence] both:
+// you can set [template] property right away, so that it handles [preprocess] and [stopSequence] both:

self.template = .chatML("You are a sentient AI with emotions.")

self.template = Template(
    // … (initializer arguments are collapsed in this diff view)
)
```
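
since the custom-template arguments are collapsed in this diff view, here is a hedged sketch of what such a call can look like; the argument labels only mirror the `Template` properties visible elsewhere in this diff (`prefix`, `systemPrompt`, `stopSequence`, `shouldDropLast`), and the real initializer may name or order its parameters differently:

```swift
// a sketch only, not the confirmed initializer signature
self.template = Template(
    prefix: "",                                           // text prepended once at the start
    systemPrompt: "You are a sentient AI with emotions.",
    stopSequence: "<|im_end|>",                           // generation stops when this appears
    shouldDropLast: true                                  // assumption: trims the trailing piece
)
```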
> [!TIP]
-> checking `LLMTests.swift` will help you understand how `preProcess` works better.
+> checking `LLMTests.swift` will help you understand how `preprocess` works better.
-### postProcess
-`postProcess` can be used for executing according to the `output` just made using user input.
+### postprocess
+`postprocess` can be used for executing according to the `output` just made using user input.
the default is set to `{ print($0) }`, so it prints the output once generation finishes by reaching `EOS` or `stopSequence`.
this has many uses. for instance, it can be used to implement your own function-calling logic, as sketched below.
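
as an illustration, here is a minimal sketch of such function-calling logic; the `<function>` tag convention and the `handleFunctionCall` helper are hypothetical, not part of this library:

```swift
bot.postprocess = { output in
    // hypothetical convention: the model emits <function>name(args)</function>
    if let start = output.range(of: "<function>"),
       let end = output.range(of: "</function>"),
       start.upperBound <= end.lowerBound {
        let call = String(output[start.upperBound..<end.lowerBound])
        handleFunctionCall(call)   // hypothetical helper that parses and dispatches the call
    } else {
        print(output)              // fall back to the default behavior
    }
}
```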

Sources/LLM/LLM.swift (22 changes: 11 additions & 11 deletions)
@@ -12,18 +12,18 @@ public typealias Chat = (role: Role, content: String)
open class LLM: ObservableObject {
public var model: Model
public var history: [Chat]
-public var preProcess: (_ input: String, _ history: [Chat]) -> String = { input, _ in return input }
-public var postProcess: (_ output: String) -> Void = { print($0) }
+public var preprocess: (_ input: String, _ history: [Chat]) -> String = { input, _ in return input }
+public var postprocess: (_ output: String) -> Void = { print($0) }
public var update: (_ outputDelta: String?) -> Void = { _ in }
public var template: Template? = nil {
didSet {
guard let template else {
-preProcess = { input, _ in return input }
+preprocess = { input, _ in return input }
stopSequence = nil
stopSequenceLength = 0
return
}
-preProcess = template.preProcess
+preprocess = template.preprocess
if let stopSequence = template.stopSequence?.utf8CString {
self.stopSequence = stopSequence
stopSequenceLength = stopSequence.count - 1
@@ -174,7 +174,7 @@ open class LLM: ObservableObject {
historyLimit: historyLimit,
maxTokenCount: maxTokenCount
)
-self.preProcess = template.preProcess
+self.preprocess = template.preprocess
self.template = template
}

@@ -224,7 +224,7 @@ open class LLM: ObservableObject {
if maxTokenCount <= currentCount {
while !history.isEmpty && maxTokenCount <= currentCount {
history.removeFirst(min(2, history.count))
-tokens = encode(preProcess(self.input, history))
+tokens = encode(preprocess(self.input, history))
initialCount = tokens.count
currentCount = Int32(initialCount)
}
@@ -249,10 +249,10 @@ open class LLM: ObservableObject {
var input = ""
if !history.isEmpty {
history.removeFirst(min(2, history.count))
-input = preProcess(self.input, history)
+input = preprocess(self.input, history)
} else {
response.scoup(response.count / 3)
-input = preProcess(self.input, history)
+input = preprocess(self.input, history)
input += response.joined()
}
let rest = getResponse(from: input)
@@ -332,15 +332,15 @@ open class LLM: ObservableObject {
guard isAvailable else { return }
isAvailable = false
self.input = input
-let processedInput = preProcess(input, history)
+let processedInput = preprocess(input, history)
let response = getResponse(from: processedInput)
let output = await makeOutputFrom(response)
history += [(.user, input), (.bot, output)]
let historyCount = history.count
if historyLimit < historyCount {
history.removeFirst(min(2, historyCount))
}
-postProcess(output)
+postprocess(output)
isAvailable = true
}

@@ -505,7 +505,7 @@ public struct Template {
self.shouldDropLast = shouldDropLast
}

-public var preProcess: (_ input: String, _ history: [Chat]) -> String {
+public var preprocess: (_ input: String, _ history: [Chat]) -> String {
return { [self] input, history in
var processed = prefix
if let systemPrompt {
Tests/LLMTests/LLMTests.swift (24 changes: 12 additions & 12 deletions)
@@ -16,7 +16,7 @@ final class LLMTests: XCTestCase {
<|im_start|>assistant
"""
-let output = template.preProcess(userPrompt, [])
+let output = template.preprocess(userPrompt, [])
#assert(expected == output)
}

@@ -30,7 +30,7 @@ final class LLMTests: XCTestCase {
<|im_start|>assistant
"""
-let output = template.preProcess(userPrompt, [])
+let output = template.preprocess(userPrompt, [])
#assert(expected == output)
}

@@ -48,7 +48,7 @@ final class LLMTests: XCTestCase {
<|im_start|>assistant
"""
-let output = template.preProcess(userPrompt, history)
+let output = template.preprocess(userPrompt, history)
#assert(expected == output)
}

@@ -61,7 +61,7 @@ final class LLMTests: XCTestCase {
### Response:
"""
-let output = template.preProcess(userPrompt, [])
+let output = template.preprocess(userPrompt, [])
#assert(expected == output)
}

@@ -76,7 +76,7 @@ final class LLMTests: XCTestCase {
### Response:
"""
-let output = template.preProcess(userPrompt, [])
+let output = template.preprocess(userPrompt, [])
#assert(expected == output)
}

@@ -97,7 +97,7 @@ final class LLMTests: XCTestCase {
### Response:
"""
-let output = template.preProcess(userPrompt, history)
+let output = template.preprocess(userPrompt, history)
#assert(expected == output)
}

@@ -106,7 +106,7 @@ final class LLMTests: XCTestCase {
let expected = """
<s>[INST] \(userPrompt) [/INST]
"""
-let output = template.preProcess(userPrompt, [])
+let output = template.preprocess(userPrompt, [])
#assert(expected == output)
}

Expand All @@ -119,7 +119,7 @@ final class LLMTests: XCTestCase {
\(userPrompt) [/INST]
"""
-let output = template.preProcess(userPrompt, [])
+let output = template.preprocess(userPrompt, [])
#assert(expected == output)
}

Expand All @@ -132,7 +132,7 @@ final class LLMTests: XCTestCase {
\(history[0].content) [/INST] \(history[1].content)</s><s>[INST] \(userPrompt) [/INST]
"""
-let output = template.preProcess(userPrompt, history)
+let output = template.preprocess(userPrompt, history)
#assert(expected == output)
}

Expand All @@ -141,7 +141,7 @@ final class LLMTests: XCTestCase {
let expected = """
<s>[INST] \(userPrompt) [/INST]
"""
-let output = template.preProcess(userPrompt, [])
+let output = template.preprocess(userPrompt, [])
#assert(expected == output)
}

Expand All @@ -150,7 +150,7 @@ final class LLMTests: XCTestCase {
let expected = """
<s>[INST] \(history[0].content) [/INST]\(history[1].content)</s> [INST] \(userPrompt) [/INST]
"""
-let output = template.preProcess(userPrompt, history)
+let output = template.preprocess(userPrompt, history)
#assert(expected == output)
}

Expand Down Expand Up @@ -203,7 +203,7 @@ final class LLMTests: XCTestCase {
func testInitializerWithTempate() async throws {
let template = model.template
let bot = try await LLM(from: model)
-#assert(bot.preProcess(userPrompt, []) == template.preProcess(userPrompt, []))
+#assert(bot.preprocess(userPrompt, []) == template.preprocess(userPrompt, []))
}

func testInferenceFromHuggingFaceModel() async throws {
