Skip to content

Commit

Permalink
Merge pull request #220 from NexaAI/swift-group
Browse files Browse the repository at this point in the history
add a send button for ios demo app; update correct chat format
  • Loading branch information
zhiyuan8 authored Nov 8, 2024
2 parents 4eee461 + c3c959f commit 62a3bd4
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 26 deletions.
13 changes: 13 additions & 0 deletions examples/swift-test/Shared/ContentView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import SwiftUI
struct ContentView: View {
@State private var viewModel = ViewModel()
@State private var prompt = ""
@FocusState private var isInputActive: Bool

var body: some View {
VStack {
Expand All @@ -21,6 +22,18 @@ struct ContentView: View {
guard !prompt.isEmpty else { return }
viewModel.run(for: prompt)
}
.focused($isInputActive)

Button(action: {
guard !prompt.isEmpty else { return }
viewModel.run(for: prompt)
isInputActive = false
}) {
Text("Send")
.frame(maxWidth: .infinity)
}
.buttonStyle(.borderedProminent)
.padding(.bottom)

ScrollView {
Text(viewModel.result)
Expand Down
72 changes: 46 additions & 26 deletions swift/Sources/NexaSwift/Models/ChatCompletionMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,12 @@ class OctopusV2Formatter: ChatFormatter {

//https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/
class LlamaFormatter: ChatFormatter {
private let systemTemplate = "[INST] <<SYS>>\n{system_message}\n<</SYS>>"
private let systemTemplate = "<<SYS>>\n{system_message}\n<</SYS>>\n\n"
private let roles: [String: String] = [
"user": "<s>[INST]",
"assistant": "[/INST]"
"user": "<s>[INST] ",
"assistant": " [/INST] "
]
private let endToken = "</s>"

func format(messages: [ChatCompletionRequestMessage]) -> ChatFormatterResponse {
let formattedMessages = mapRoles(messages: messages)
Expand All @@ -233,7 +234,7 @@ class LlamaFormatter: ChatFormatter {
systemTemplate.replacingOccurrences(of: "{system_message}", with: msg)
}
let prompt = formatPrompt(systemMessage: formattedSystemMessage, messages: formattedMessages)
return ChatFormatterResponse(prompt: prompt + "[/INST]", stop: ["</s>"])
return ChatFormatterResponse(prompt: prompt, stop: [endToken])
}

private func getSystemMessage(_ messages: [ChatCompletionRequestMessage]) -> String? {
Expand Down Expand Up @@ -268,24 +269,36 @@ class LlamaFormatter: ChatFormatter {
}

private func formatPrompt(systemMessage: String?, messages: [(String, String?)]) -> String {
var prompt = ""
var conversations: [String] = []
var currentConversation = ""

if let (firstRole, firstContent) = messages.first,
let content = firstContent {
if let sysMsg = systemMessage {
prompt += "\(firstRole) \(sysMsg)\n\(content)"
} else {
prompt += "\(firstRole) \(content)"
for (index, (role, content)) in messages.enumerated() {
if index % 2 == 0 { // User message
if !currentConversation.isEmpty {
conversations.append(currentConversation + " " + endToken)
}
currentConversation = role // <s>[INST]
if index == 0 && systemMessage != nil {
currentConversation += systemMessage! + content!
} else {
currentConversation += content ?? ""
}
} else { // Assistant message
if let content = content {
currentConversation += role + content // [/INST] response
}
}
}

for (role, content) in messages.dropFirst() {
if let content = content {
prompt += " \(role) \(content)"
}
// Add the last conversation if it's a user message without response
if messages.count % 2 != 0 {
currentConversation += roles["assistant"]!
conversations.append(currentConversation)
} else if !currentConversation.isEmpty {
conversations.append(currentConversation + endToken)
}

return prompt.trimmingCharacters(in: .whitespacesAndNewlines)
return conversations.joined(separator: "\n")
}
}

Expand All @@ -296,8 +309,7 @@ class Llama3Formatter: ChatFormatter {
"user": "<|start_header_id|>user<|end_header_id|>\n\n",
"assistant": "<|start_header_id|>assistant<|end_header_id|>\n\n"
]

private let separator = "<|eot_id|>\n"
private let endToken = "<|eot_id|>"

func format(messages: [ChatCompletionRequestMessage]) -> ChatFormatterResponse {
var formattedMessages = mapRoles(messages: messages)
Expand All @@ -306,7 +318,7 @@ class Llama3Formatter: ChatFormatter {

let prompt = formatPrompt(formattedMessages)

return ChatFormatterResponse(prompt: prompt, stop: [separator])
return ChatFormatterResponse(prompt: prompt, stop: [endToken])
}

private func mapRoles(messages: [ChatCompletionRequestMessage]) -> [(String, String?)] {
Expand Down Expand Up @@ -345,7 +357,7 @@ class Llama3Formatter: ChatFormatter {
var prompt = "<|begin_of_text|>"
for (role, content) in formattedMessages {
if let content = content {
prompt += "\(role)\(content.trimmingCharacters(in: .whitespacesAndNewlines))\(separator)"
prompt += "\(role)\(content.trimmingCharacters(in: .whitespacesAndNewlines))\(endToken)"
} else {
prompt += "\(role) "
}
Expand All @@ -362,14 +374,15 @@ class GemmaFormatter: ChatFormatter {
"assistant": "<start_of_turn>model\n"
]

private let endToken = "<end_of_turn>"
private let separator = "<end_of_turn>\n"

func format(messages: [ChatCompletionRequestMessage]) -> ChatFormatterResponse {
var formattedMessages = mapRoles(messages: messages)
formattedMessages.append((roles["assistant"]!, nil))
let prompt = formatPrompt(formattedMessages)

return ChatFormatterResponse(prompt: prompt, stop: [separator])
return ChatFormatterResponse(prompt: prompt, stop: [endToken])
}

private func mapRoles(messages: [ChatCompletionRequestMessage]) -> [(String, String?)] {
Expand Down Expand Up @@ -408,6 +421,7 @@ class GemmaFormatter: ChatFormatter {
}
}

// https://qwen.readthedocs.io/zh-cn/latest/getting_started/concepts.html#control-tokens-chat-template
class QwenFormatter: ChatFormatter {
private let roles: [String: String] = [
"user": "<|im_start|>user",
Expand Down Expand Up @@ -464,16 +478,17 @@ class QwenFormatter: ChatFormatter {
}
}

//https://www.promptingguide.ai/models/mistral-7b
// https://www.promptingguide.ai/models/mistral-7b#chat-template-for-mistral-7b-instruct
class MistralFormatter: ChatFormatter {
private let endToken = "</s>"
private let conversationStart = "<s>"
private let instructStart = "[INST] "
private let instructEnd = " [/INST]"
private let instructEnd = " [/INST] "

func format(messages: [ChatCompletionRequestMessage]) -> ChatFormatterResponse {
var prompt = ""
var prompt = conversationStart // Add <s> only once at the start

for message in messages {
for (index, message) in messages.enumerated() {
switch message {
case .user(let userMessage):
switch userMessage.content {
Expand All @@ -491,7 +506,12 @@ class MistralFormatter: ChatFormatter {
continue
}
}
prompt += instructEnd

// Add instructEnd if the last message was from user (waiting for AI response)
if messages.last.map({ if case .user = $0 { return true } else { return false } }) ?? false {
prompt += instructEnd
}

return ChatFormatterResponse(prompt: prompt, stop: [endToken])
}
}

0 comments on commit 62a3bd4

Please sign in to comment.