Skip to content

Commit

Permalink
fix: apply authorization header in /hub websocket api (#1043)
Browse files Browse the repository at this point in the history
* fix: apply authorization header in /hub websocket api

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
wsxiaoys and autofix-ci[bot] authored Dec 14, 2023
1 parent db421a5 commit bd4d812
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 127 deletions.
27 changes: 27 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

46 changes: 13 additions & 33 deletions crates/tabby/src/worker.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use std::{env::consts::ARCH, net::IpAddr, sync::Arc};

use anyhow::Result;
use axum::{routing, Router};
use clap::Args;
use tabby_webserver::api::{tracing_context, HubClient, WorkerKind};
use tracing::{info, warn};
use tabby_webserver::api::{HubClient, WorkerKind};
use tracing::info;

use crate::{
routes::{self, run_app},
Expand Down Expand Up @@ -47,9 +46,7 @@ pub struct WorkerArgs {
parallelism: u8,
}

async fn make_chat_route(context: WorkerContext, args: &WorkerArgs) -> Router {
context.register(WorkerKind::Chat, args).await;

async fn make_chat_route(args: &WorkerArgs) -> Router {
let chat_state =
Arc::new(create_chat_service(&args.model, &args.device, args.parallelism).await);

Expand All @@ -60,8 +57,6 @@ async fn make_chat_route(context: WorkerContext, args: &WorkerArgs) -> Router {
}

async fn make_completion_route(context: WorkerContext, args: &WorkerArgs) -> Router {
context.register(WorkerKind::Completion, args).await;

let code = Arc::new(context.client.clone());
let logger = Arc::new(context.client);
let completion_state = Arc::new(
Expand All @@ -79,11 +74,11 @@ pub async fn main(kind: WorkerKind, args: &WorkerArgs) {

info!("Starting worker, this might take a few minutes...");

let context = WorkerContext::new(&args.url).await;
let context = WorkerContext::new(kind.clone(), args).await;

let app = match kind {
WorkerKind::Completion => make_completion_route(context, args).await,
WorkerKind::Chat => make_chat_route(context, args).await,
WorkerKind::Chat => make_chat_route(args).await,
};

run_app(app, None, args.host, args.port).await
Expand All @@ -94,25 +89,14 @@ struct WorkerContext {
}

impl WorkerContext {
async fn new(url: &str) -> Self {
Self {
client: tabby_webserver::api::create_client(url).await,
}
}

async fn register(&self, kind: WorkerKind, args: &WorkerArgs) {
if let Err(err) = self.register_impl(kind, args).await {
warn!("Failed to register worker: {}", err)
}
}

async fn register_impl(&self, kind: WorkerKind, args: &WorkerArgs) -> Result<()> {
async fn new(kind: WorkerKind, args: &WorkerArgs) -> Self {
let (cpu_info, cpu_count) = read_cpu_info();
let cuda_devices = read_cuda_devices().unwrap_or_default();
let worker = self
.client
.register_worker(
tracing_context(),

Self {
client: tabby_webserver::api::create_client(
&args.url,
&args.token,
kind,
args.port as i32,
args.model.to_owned(),
Expand All @@ -121,12 +105,8 @@ impl WorkerContext {
cpu_info,
cpu_count as i32,
cuda_devices,
args.token.clone(),
)
.await??;

info!("Worker alive at {}", worker.addr);

Ok(())
.await,
}
}
}
3 changes: 2 additions & 1 deletion ee/tabby-webserver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ homepage.workspace = true
anyhow.workspace = true
argon2 = "0.5.1"
async-trait.workspace = true
axum = { workspace = true, features = ["ws"] }
axum = { workspace = true, features = ["ws", "headers"] }
bincode = "1.3.3"
chrono = "0.4"
futures.workspace = true
Expand All @@ -26,6 +26,7 @@ rusqlite = { version = "0.30.0", features = ["bundled", "chrono"] }
rusqlite_migration = { version = "1.1.0-beta.1", features = ["alpha-async-tokio-rusqlite", "from-directory"] }
rust-embed = "8.0.0"
serde.workspace = true
serde_json.workspace = true
tabby-common = { path = "../../crates/tabby-common" }
tarpc = { version = "0.33.0", features = ["serde-transport"] }
thiserror.workspace = true
Expand Down
57 changes: 41 additions & 16 deletions ee/tabby-webserver/src/api.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,16 @@
use async_trait::async_trait;
use hyper::Request;
use tabby_common::api::{
code::{CodeSearch, CodeSearchError, SearchResponse},
event::RawEventLogger,
};
use tokio_tungstenite::connect_async;

pub use crate::schema::worker::{RegisterWorkerError, Worker, WorkerKind};
use crate::websocket::WebSocketTransport;
use crate::{websocket::WebSocketTransport, RegisterWorkerRequest, REGISTER_WORKER_HEADER};

#[tarpc::service]
pub trait Hub {
async fn register_worker(
kind: WorkerKind,
port: i32,
name: String,
device: String,
arch: String,
cpu_info: String,
cpu_count: i32,
cuda_devices: Vec<String>,
token: String,
) -> Result<Worker, RegisterWorkerError>;

async fn log_event(content: String);

async fn search(q: String, limit: usize, offset: usize) -> SearchResponse;
Expand All @@ -38,9 +27,45 @@ pub fn tracing_context() -> tarpc::context::Context {
tarpc::context::current()
}

pub async fn create_client(addr: &str) -> HubClient {
let addr = format!("ws://{}/hub", addr);
let (socket, _) = connect_async(&addr).await.unwrap();
pub async fn create_client(
addr: &str,
token: &str,
kind: WorkerKind,
port: i32,
name: String,
device: String,
arch: String,
cpu_info: String,
cpu_count: i32,
cuda_devices: Vec<String>,
) -> HubClient {
let request = Request::builder()
.uri(format!("ws://{}/hub", addr))
.header("Host", addr)
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Key", "unused")
.header("Authorization", format!("Bearer {}", token))
.header("Content-Type", "application/json")
.header(
&REGISTER_WORKER_HEADER,
serde_json::to_string(&RegisterWorkerRequest {
kind,
port,
name,
device,
arch,
cpu_info,
cpu_count,
cuda_devices,
})
.unwrap(),
)
.body(())
.unwrap();

let (socket, _) = connect_async(request).await.unwrap();
HubClient::new(Default::default(), WebSocketTransport::from(socket)).spawn()
}

Expand Down
Loading

0 comments on commit bd4d812

Please sign in to comment.