Skip to content

Commit

Permalink
rudimentary integration with bigmode window
Browse files Browse the repository at this point in the history
  • Loading branch information
a5huynh committed Nov 23, 2024
1 parent 4351289 commit 40ca776
Show file tree
Hide file tree
Showing 11 changed files with 119 additions and 77 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

64 changes: 43 additions & 21 deletions apps/desktop-client/src/pages/bigmode/BigMode.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ function ChatLogItem({
"flex", "flex-none", "border", "border-cyan-600", "w-[48px]", "h-[48px]", "rounded-full", "items-center",
{ "order-1": isUser }
)}>
<div className="text-lg mx-auto">{icon}</div>
<div className="text-lg mx-auto">{isStreaming ? (
<ArrowPathIcon className="w-4 animate-spin" />
) : icon}</div>
</div>
<div className={classNames("grow", { "text-left": !isUser, "text-right": isUser })}>
{chat.content}
Expand All @@ -215,26 +217,22 @@ function AskClippy() {
const clippyInput = useRef<HTMLTextAreaElement>(null);

const [isStreaming, setIsStreaming] = useState<boolean>(false);
const [stream, setStream] = useState<string>('');
const [tokens, setTokens] = useState<string[]>([]);

const [history, setHistory] = useState<ChatMessage[]>([
{ role: "user", content: "hi what's your name?" },
{ role: "assistant", content: "test" }
{ role: "assistant", content: "My name is Clippy." }
]);
const [status, setStatus] = useState<string>('');

const handleChatEvent = (event: ChatStream) => {
if (event.type == "LoadingPrompt") {
if (event.type === "LoadingPrompt") {
setStatus("Loading prompt...");
} else if (event.type == "Token") {
setStream(str => str + event.content);
} else if (event.type == "ChatDone") {
} else if (event.type === "Token") {
setTokens(toks => [...toks, event.content]);
} else if (event.type === "ChatDone") {
setIsStreaming(false);
setHistory(hist => ([...hist, {
role: "assistant",
content: stream,
}]));
setStream('');
setStatus('');
}
};

Expand All @@ -248,27 +246,51 @@ function AskClippy() {
await invoke("ask_clippy", { session: { messages: currentCtxt }});
};

const handleQuerySubmission = () => {};
const handleQuerySubmission = () => {
if (clippyInput.current) {
handleAskClippy(clippyInput.current.value.trim());
clippyInput.current.value = '';
}
};

const clearHistory = () => {
setHistory([]);
};

useEffect(() => {
if (isStreaming == false && tokens.length > 0) {
setHistory(hist => ([...hist, {
role: "assistant",
content: tokens.join(''),
}]));
setTokens([]);
}
}, [isStreaming, tokens]);

useEffect(() => {
const init = async () => {
await listen<ChatStream>("ChatEvent", (event) => handleChatEvent(event.payload));
return await listen<ChatStream>(
"ChatEvent",
(event) => {
handleChatEvent(event.payload);
},
);
};

init();
}, []);
let unlisten = init();
return () => {

Check failure on line 280 in apps/desktop-client/src/pages/bigmode/BigMode.tsx

View workflow job for this annotation

GitHub Actions / frontend-check

'unlisten' is never reassigned. Use 'const' instead
(async () => {
await unlisten.then(fn => fn());
})();
};
}, [])

return (
<div className="flex flex-col grow bg-neutral-800 text-white">
<div className="flex flex-col grow place-content-end min-h-[128px]">
<div className="flex flex-col overflow-y-scroll">
<div className="flex flex-col bg-neutral-800 text-white h-full">
<div className="flex flex-col grow place-content-end">
<div className="flex flex-col place-content-end overflow-y-scroll">
<ChatLog history={history} />
{ isStreaming ? (
<ChatLogItem chat={{ role: "assistant", content: stream ?? status }} isStreaming={isStreaming} />
<ChatLogItem chat={{ role: "assistant", content: tokens.length > 0 ? tokens.join('') : status }} isStreaming={isStreaming} />
) : null}
</div>
</div>
Expand Down
11 changes: 5 additions & 6 deletions apps/tauri/src/cmd/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@ use crate::rpc;

#[tauri::command]
pub async fn ask_clippy(win: tauri::Window, session: LlmSession) -> Result<(), String> {
if let Some(rpc) = win.app_handle().try_state::<rpc::RpcMutex>() {
let rpc = rpc.lock().await;
if let Err(err) = rpc.client.chat_completion(session).await {
return Err(err.to_string());
tokio::spawn(async move {
if let Some(rpc) = win.app_handle().try_state::<rpc::RpcMutex>() {
let rpc = rpc.lock().await;
let _ = rpc.client.chat_completion(session).await;
}
}

});
Ok(())
}
70 changes: 42 additions & 28 deletions apps/tauri/src/plugins/notify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ pub struct NotificationHandler(JoinHandle<()>);

pub fn init(app: &AppHandle) {
log::info!("starting notify plugin");
let handle = tauri::async_runtime::spawn(setup_notification_handler(app.clone()));
let handle: JoinHandle<()> =
tauri::async_runtime::spawn(setup_notification_handler(app.clone()));
app.manage(NotificationHandler(handle));
}

Expand Down Expand Up @@ -77,45 +78,58 @@ async fn setup_notification_handler(app: AppHandle) {
log::debug!("received event: {:?}", event);
let notif: Option<(String, String)> = match &event.event_type {
RpcEventType::ChatStream => {
let _ = app.emit(ClientEvent::ChatEvent.as_ref(), event.payload);
if let Some(payload) = event.payload {
let _ = app.emit(ClientEvent::ChatEvent.as_ref(), payload);
}
None
},
RpcEventType::ConnectionSyncFinished => Some(("Sync Completed".into(), event.payload)),
RpcEventType::ConnectionSyncFinished => Some((
"Sync Completed".into(),
event.payload.map(|p| p.to_string()).unwrap_or_default()
)),
RpcEventType::LensInstalled => {
let _ = app.emit(ClientEvent::LensInstalled.as_ref(), event.payload.clone());
log::debug!("lens installed {}", &event.payload);
Some(("Lens Installed".into(), format!("{} was installed in your library", event.payload)))
log::debug!("lens installed {:?}", &event.payload);
Some((
"Lens Installed".into(),
format!("{} was installed in your library",
event.payload.map(|p| p.to_string()).unwrap_or_default()
)))
},
RpcEventType::LensUninstalled => {
let _ = app.emit(ClientEvent::LensUninstalled.as_ref(), event.payload.clone());
log::debug!("lens removed {}", &event.payload);
Some(("Lens Uninstalled".into(), format!("{} was removed from your library", event.payload)))
log::debug!("lens removed {:?}", &event.payload);
Some(("Lens Uninstalled".into(), format!("{} was removed from your library", event.payload.map(|p| p.to_string()).unwrap_or_default())))
},
RpcEventType::ModelDownloadStatus => {
if let Ok(status) = serde_json::de::from_str::<ModelDownloadStatusPayload>(&event.payload) {
match status {
ModelDownloadStatusPayload::Finished { model_name } => {
let window = crate::window::update_progress_window(&app, &model_name, 100);
let _ = window.close();
if let Some(payload) = event.payload {
if let Ok(status) = serde_json::from_value::<ModelDownloadStatusPayload>(payload) {
match status {
ModelDownloadStatusPayload::Finished { model_name } => {
let window = crate::window::update_progress_window(&app, &model_name, 100);
let _ = window.close();

Some((
"Model Installed".into(),
format!("Finished downloading {}", model_name)
))
},
ModelDownloadStatusPayload::Error { model_name, msg } => {
Some((
"Model Download Failed".into(),
format!("Unable to download {} model: {}", model_name, msg)
))
},
ModelDownloadStatusPayload::InProgress { model_name, percent } => {
log::info!("downloading: {} - {}", model_name, percent);
crate::window::update_progress_window(&app, &model_name, percent);
None
Some((
"Model Installed".into(),
format!("Finished downloading {}", model_name)
))
},
ModelDownloadStatusPayload::Error { model_name, msg } => {
Some((
"Model Download Failed".into(),
format!("Unable to download {} model: {}", model_name, msg)
))
},
ModelDownloadStatusPayload::InProgress { model_name, percent } => {
log::info!("downloading: {} - {}", model_name, percent);
crate::window::update_progress_window(&app, &model_name, percent);
None
}
}
} else {
None
}
} else {
} else {
None
}
}
Expand Down
2 changes: 1 addition & 1 deletion assets/templates/llm/llama3-instruct.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
<|start_header_id|>{{ msg.role }}<|end_header_id|>
{{ msg.content }}<|eot_id|>
{% endfor %}}
<|start_header_id|>assistant<|end_header_id|>
<|start_header_id|>assistant<|end_header_id|>
2 changes: 1 addition & 1 deletion assets/templates/llm/phi3.5-instruct.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
<|{{ msg.role }}|>
{{ msg.content }}<|end|>
{% endfor %}
<|assistant|>
<|assistant|>
3 changes: 2 additions & 1 deletion crates/spyglass-rpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ edition = "2021"

[dependencies]
serde = { workspace = true }
serde_json = { workspace = true }
jsonrpsee = { workspace = true, features = ["full"] }
shared = { path = "../shared" }

[lib]
name = "spyglass_rpc"
path = "src/lib.rs"
crate-type = ["lib"]
crate-type = ["lib"]
3 changes: 2 additions & 1 deletion crates/spyglass-rpc/src/events.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)]
pub enum RpcEventType {
Expand All @@ -14,7 +15,7 @@ pub struct RpcEvent {
/// Event Type
pub event_type: RpcEventType,
/// Payload serialized as JSON if applicable.
pub payload: String,
pub payload: Option<Value>,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
Expand Down
4 changes: 2 additions & 2 deletions crates/spyglass/src/api/handler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ pub async fn chat_completion(state: AppState, session: &LlmSession) -> RpcResult
state_clone
.publish_event(&RpcEvent {
event_type: RpcEventType::ChatStream,
payload: serde_json::to_string(&msg).expect("Unable to serialize ChatStream"),
payload: Some(serde_json::to_value(&msg).unwrap()),
})
.await;

Expand Down Expand Up @@ -623,7 +623,7 @@ pub async fn uninstall_lens(state: AppState, config: &Config, name: &str) -> Rpc
state
.publish_event(&RpcEvent {
event_type: RpcEventType::LensUninstalled,
payload: name.to_string(),
payload: Some(serde_json::to_value(name.to_string()).unwrap()),
})
.await;

Expand Down
2 changes: 1 addition & 1 deletion crates/spyglass/src/pipeline/cache_pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ pub async fn process_update(
state
.publish_event(&spyglass_rpc::RpcEvent {
event_type: spyglass_rpc::RpcEventType::LensInstalled,
payload: lens.name.to_string(),
payload: Some(serde_json::to_value(lens.name.to_string()).unwrap()),
})
.await;
}
34 changes: 19 additions & 15 deletions crates/spyglass/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -409,13 +409,13 @@ async fn download_model(
state
.publish_event(&RpcEvent {
event_type: RpcEventType::ModelDownloadStatus,
payload: serde_json::to_string(
&ModelDownloadStatusPayload::InProgress {
payload: Some(
serde_json::to_value(&ModelDownloadStatusPayload::InProgress {
model_name: model_name.into(),
percent,
},
)
.unwrap_or_default(),
})
.unwrap(),
),
})
.await;
last_update = std::time::Instant::now();
Expand All @@ -426,10 +426,12 @@ async fn download_model(
state
.publish_event(&RpcEvent {
event_type: RpcEventType::ModelDownloadStatus,
payload: serde_json::to_string(&ModelDownloadStatusPayload::Finished {
model_name: model_name.into(),
})
.unwrap_or_default(),
payload: Some(
serde_json::to_value(&ModelDownloadStatusPayload::Finished {
model_name: model_name.into(),
})
.unwrap_or_default(),
),
})
.await;
Ok(())
Expand All @@ -438,11 +440,13 @@ async fn download_model(
state
.publish_event(&RpcEvent {
event_type: RpcEventType::ModelDownloadStatus,
payload: serde_json::to_string(&ModelDownloadStatusPayload::Error {
model_name: model_name.into(),
msg: err.to_string(),
})
.unwrap_or_default(),
payload: Some(
serde_json::to_value(&ModelDownloadStatusPayload::Error {
model_name: model_name.into(),
msg: err.to_string(),
})
.unwrap(),
),
})
.await;

Expand Down Expand Up @@ -546,7 +550,7 @@ pub async fn worker_task(

state.publish_event(&RpcEvent {
event_type: RpcEventType::ConnectionSyncFinished,
payload,
payload: Some(serde_json::to_value(&payload).unwrap())
}).await;
}
Err(err) => log::warn!("Unable to sync w/ connection: {account}@{api_id} - {err}"),
Expand Down

0 comments on commit 40ca776

Please sign in to comment.