Report Hugging Face model download progress.

Adds an `updateProgress` callback to `LLM.init(from:)` and `HuggingFaceModel.download`, replaces the
in-memory `getData` download with a progress-observed `URLSession` download task, and shows the
progress in the README example.

Note: the `index` blob ids below are the ones recorded by the original commit.

diff --git a/README.md b/README.md
index 79def17..627fa46 100644
--- a/README.md
+++ b/README.md
@@ -61,10 +61,10 @@ struct ContentView: View {
 
 ```swift
 class Bot: LLM {
-    convenience init?() async {
+    convenience init?(_ update: @escaping (Double) -> Void) async {
         let systemPrompt = "You are a sentient AI with emotions."
         let model = HuggingFaceModel("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", .Q2_K, template: .chatML(systemPrompt))
-        try? await self.init(from: model)
+        try? await self.init(from: model) { progress in update(progress) }
     }
 }
 
@@ -72,16 +72,24 @@ class Bot: LLM {
 struct ContentView: View {
     @State var bot: Bot? = nil
 
+    @State var progress: CGFloat = 0
+    func updateProgress(_ progress: Double) {
+        Task { await MainActor.run { self.progress = CGFloat(progress) } }
+    }
     var body: some View {
         if let bot {
             BotView(bot)
         } else {
-            ProgressView().padding()
-            Text("(loading huggingface model...)").opacity(0.2)
-                .onAppear() { Task {
-                    let bot = await Bot()
-                    await MainActor.run { self.bot = bot }
-                } }
+            ProgressView(value: progress) {
+                Text("loading huggingface model...")
+            } currentValueLabel: {
+                Text(String(format: "%.2f%%", progress * 100))
+            }
+            .padding()
+            .onAppear() { Task {
+                let bot = await Bot(updateProgress)
+                await MainActor.run { self.bot = bot }
+            } }
         }
     }
 }
diff --git a/Sources/LLM/LLM.swift b/Sources/LLM/LLM.swift
index c7563de..fffd3b2 100644
--- a/Sources/LLM/LLM.swift
+++ b/Sources/LLM/LLM.swift
@@ -54,6 +54,7 @@ open class LLM: ObservableObject {
     private var stopSequenceLength: Int
     private var params: llama_context_params
     private var isFull = false
+    private var updateProgress: (Double) -> Void = { _ in }
 
     public init(
         from path: String,
@@ -131,9 +132,12 @@ open class LLM: ObservableObject {
         topP: Float = 0.95,
         temp: Float = 0.8,
         historyLimit: Int = 8,
-        maxTokenCount: Int32 = 2048
+        maxTokenCount: Int32 = 2048,
+        updateProgress: @escaping (Double) -> Void = { print(String(format: "downloaded(%.2f%%)", $0 * 100)) }
     ) async throws {
-        let url = try await huggingFaceModel.download(to: url, as: name)
+        let url = try await huggingFaceModel.download(to: url, as: name) { progress in
+            Task { await MainActor.run { updateProgress(progress) } }
+        }
         self.init(
             from: url,
             template: huggingFaceModel.template,
@@ -145,6 +149,7 @@ open class LLM: ObservableObject {
             historyLimit: historyLimit,
             maxTokenCount: maxTokenCount
         )
+        self.updateProgress = updateProgress
     }
 
     public convenience init(
@@ -578,6 +583,8 @@ public enum Quantization: String {
 public enum HuggingFaceError: Error {
     case network(statusCode: Int)
     case noFilteredURL
+    /// The download task finished without an error but did not provide a file URL.
+    case urlIsNilForSomeReason
 }
 
 public struct HuggingFaceModel {
@@ -616,17 +623,18 @@ public struct HuggingFaceModel {
         return nil
     }
 
-    public func download(to directory: URL = .documentsDirectory, as name: String? = nil) async throws -> URL {
+    /// Downloads the model file into `directory` unless it is already there, and returns its local URL.
+    /// - Parameter updateProgress: Receives the completed fraction; called with `1` when the file already exists.
+    public func download(to directory: URL = .documentsDirectory, as name: String? = nil, _ updateProgress: @escaping (Double) -> Void = { _ in }) async throws -> URL {
         var destination: URL
         if let name {
             destination = directory.appending(path: name)
-            guard !destination.exists else { return destination }
+            guard !destination.exists else { updateProgress(1); return destination }
         }
         guard let downloadURL = try await getDownloadURL() else { throw HuggingFaceError.noFilteredURL }
         destination = directory.appending(path: downloadURL.lastPathComponent)
-        guard !destination.exists else { return destination }
-        let data = try await downloadURL.getData()
-        try data.write(to: destination)
+        guard !destination.exists else { updateProgress(1); return destination }
+        try await downloadURL.downloadData(to: destination, updateProgress)
         return destination
     }
 
@@ -651,6 +659,37 @@ extension URL {
         guard statusCode / 100 == 2 else { throw HuggingFaceError.network(statusCode: statusCode) }
         return data
     }
+    /// Downloads the resource at this URL into `destination`, reporting progress while the transfer runs.
+    /// - Parameters:
+    ///   - destination: Location the downloaded file is moved to; callers check beforehand that nothing exists there.
+    ///   - updateProgress: Receives the download task's completed fraction; may be called off the main thread.
+    /// - Throws: The transport or file-system error, or a `HuggingFaceError` for a missing file or non-2xx status.
+    fileprivate func downloadData(to destination: URL, _ updateProgress: @escaping (Double) -> Void) async throws {
+        var observation: NSKeyValueObservation?
+        // Referencing the observation on exit keeps it alive for the whole transfer and stops it on every path.
+        defer { observation?.invalidate() }
+        try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
+            let task = URLSession.shared.downloadTask(with: self) { url, response, error in
+                if let error { return continuation.resume(throwing: error) }
+                guard let url else { return continuation.resume(throwing: HuggingFaceError.urlIsNilForSomeReason) }
+                // Only HTTP responses carry a status code; anything else has nothing to validate.
+                if let statusCode = (response as? HTTPURLResponse)?.statusCode, statusCode / 100 != 2 {
+                    return continuation.resume(throwing: HuggingFaceError.network(statusCode: statusCode))
+                }
+                // URLSession deletes the temporary file once this handler returns, so it has to be moved here.
+                do {
+                    try FileManager.default.moveItem(at: url, to: destination)
+                    continuation.resume()
+                } catch {
+                    continuation.resume(throwing: error)
+                }
+            }
+            observation = task.progress.observe(\.fractionCompleted) { progress, _ in
+                updateProgress(progress.fractionCompleted)
+            }
+            task.resume()
+        }
+    }
 }
 
 package extension String {
diff --git a/Tests/LLMTests/LLMTests.swift b/Tests/LLMTests/LLMTests.swift
index 0ec93e4..a426487 100644
--- a/Tests/LLMTests/LLMTests.swift
+++ b/Tests/LLMTests/LLMTests.swift
@@ -207,14 +207,14 @@ final class LLMTests: XCTestCase {
     }
 
     func testInferenceFromHuggingFaceModel() async throws {
-        var bot = try await LLM(from: model)
+        let bot = try await LLM(from: model)
         let input = "have you heard of this so-called LLM.swift library?"
         await bot.respond(to: input)
         #assert(!bot.output.isEmpty)
     }
 
     func testRecoveryFromLengtyInput() async throws {
-        var bot = try await LLM(from: model, maxTokenCount: 16)
+        let bot = try await LLM(from: model, maxTokenCount: 16)
         let input = "have you heard of this so-called LLM.swift library?"
         await bot.respond(to: input)
         #assert(bot.output == "tl;dr")