From 6367e2f183ff6efffb2c9f31d78c13cdb8bdfca9 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Fri, 3 May 2024 20:46:36 -0700 Subject: [PATCH] fix(core): fix remote http api model name dispaly (#2047) --- crates/tabby/src/serve.rs | 10 +++++++-- crates/tabby/src/services/health.rs | 23 +++++++++++++++++++-- ee/tabby-ui/lib/hooks/use-health.tsx | 1 + ee/tabby-ui/lib/hooks/use-workers.ts | 31 +++++++++++++++------------- 4 files changed, 47 insertions(+), 18 deletions(-) diff --git a/crates/tabby/src/serve.rs b/crates/tabby/src/serve.rs index 3cb87c800ae8..f6bade3babce 100644 --- a/crates/tabby/src/serve.rs +++ b/crates/tabby/src/serve.rs @@ -222,8 +222,11 @@ async fn api_router( let health_state = Arc::new(health::HealthState::new( args.model.as_deref(), - args.chat_model.as_deref(), &args.device, + args.chat_model.as_deref(), + args.chat_model + .as_deref() + .map(|_| args.chat_device.as_ref().unwrap_or(&args.device)), webserver, )); @@ -322,8 +325,11 @@ async fn api_router( fn start_heartbeat(args: &ServeArgs, webserver: Option) { let state = health::HealthState::new( args.model.as_deref(), - args.chat_model.as_deref(), &args.device, + args.chat_model.as_deref(), + args.chat_model + .as_deref() + .map(|_| args.chat_device.as_ref().unwrap_or(&args.device)), webserver, ); tokio::spawn(async move { diff --git a/crates/tabby/src/services/health.rs b/crates/tabby/src/services/health.rs index df4855aa0fd6..140f6ee9f148 100644 --- a/crates/tabby/src/services/health.rs +++ b/crates/tabby/src/services/health.rs @@ -14,6 +14,8 @@ pub struct HealthState { model: Option, #[serde(skip_serializing_if = "Option::is_none")] chat_model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + chat_device: Option, device: String, arch: String, cpu_info: String, @@ -26,8 +28,9 @@ pub struct HealthState { impl HealthState { pub fn new( model: Option<&str>, - chat_model: Option<&str>, device: &Device, + chat_model: Option<&str>, + chat_device: Option<&Device>, webserver: Option, ) -> Self { let (cpu_info, cpu_count) = read_cpu_info(); @@ -37,9 +40,25 @@ impl HealthState { Err(_) => vec![], }; + let http_model_name = Some("Remote"); + let is_model_http = device == &Device::ExperimentalHttp; + let model = if is_model_http { + http_model_name + } else { + model + }; + + let is_chat_model_http = chat_device == Some(&Device::ExperimentalHttp); + let chat_model = if is_chat_model_http { + http_model_name + } else { + chat_model + }; + Self { - model: model.map(|x| x.to_owned()), + model: model.map(|x| x.to_string()), chat_model: chat_model.map(|x| x.to_owned()), + chat_device: chat_device.map(|x| x.to_string()), device: device.to_string(), arch: ARCH.to_string(), cpu_info, diff --git a/ee/tabby-ui/lib/hooks/use-health.tsx b/ee/tabby-ui/lib/hooks/use-health.tsx index 48dafaac6d5e..12ad138736ae 100644 --- a/ee/tabby-ui/lib/hooks/use-health.tsx +++ b/ee/tabby-ui/lib/hooks/use-health.tsx @@ -8,6 +8,7 @@ export interface HealthInfo { device: 'metal' | 'cpu' | 'cuda' model?: string chat_model?: string + chat_device?: string cpu_info: string cpu_count: number cuda_devices: string[] diff --git a/ee/tabby-ui/lib/hooks/use-workers.ts b/ee/tabby-ui/lib/hooks/use-workers.ts index 32d4a6023e53..2edef8825c5c 100644 --- a/ee/tabby-ui/lib/hooks/use-workers.ts +++ b/ee/tabby-ui/lib/hooks/use-workers.ts @@ -7,22 +7,27 @@ import { Worker, WorkerKind } from '@/lib/gql/generates/graphql' import { useHealth, type HealthInfo } from './use-health' -const modelNameMap: Record = { - [WorkerKind.Chat]: 'chat_model', - [WorkerKind.Completion]: 'model' +function transformHealthInfoToCompletionWorker(healthInfo: HealthInfo): Worker { + return { + kind: WorkerKind.Completion, + device: healthInfo.device, + addr: 'localhost', + arch: '', + cpuInfo: healthInfo.cpu_info, + name: healthInfo.model!, + cpuCount: healthInfo.cpu_count, + cudaDevices: healthInfo.cuda_devices + } } -function transformHealthInfoToWorker( - healthInfo: HealthInfo, - kind: WorkerKind -): Worker { +function transformHealthInfoToChatWorker(healthInfo: HealthInfo): Worker { return { - kind, - device: healthInfo.device, + kind: WorkerKind.Chat, + device: healthInfo.chat_device!, addr: 'localhost', arch: '', cpuInfo: healthInfo.cpu_info, - name: healthInfo?.[modelNameMap[kind]] ?? '', + name: healthInfo.chat_model!, cpuCount: healthInfo.cpu_count, cudaDevices: healthInfo.cuda_devices } @@ -56,12 +61,10 @@ function useWorkers() { findIndex(_workers, { kind: WorkerKind.Chat }) > -1 if (!haveRemoteCompletionWorkers && healthInfo?.model) { - _workers.push( - transformHealthInfoToWorker(healthInfo, WorkerKind.Completion) - ) + _workers.push(transformHealthInfoToCompletionWorker(healthInfo)) } if (!haveRemoteChatWorkers && healthInfo?.chat_model) { - _workers.push(transformHealthInfoToWorker(healthInfo, WorkerKind.Chat)) + _workers.push(transformHealthInfoToChatWorker(healthInfo)) } return groupBy(_workers, 'kind') }, [healthInfo, workers])