diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index cdc2901f8ec4..51c6b2e0bf1b 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -10,8 +10,8 @@ on: - '!*-dev' concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref_name }} - + group: ${{ github.workflow }}-${{ github.head_ref || github.ref_name }} + # If this is enabled it will cancel current running and start latest cancel-in-progress: true @@ -28,6 +28,15 @@ jobs: # with sigstore/fulcio when running outside of PRs. id-token: write + strategy: + matrix: + device-type: [cuda, rocm] + include: + - device-type: cuda + image-suffix: '' + - device-type: rocm + image-suffix: '-rocm' + steps: - name: Free Disk Space (Ubuntu) uses: jlumbroso/free-disk-space@main @@ -69,8 +78,10 @@ jobs: password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Generate image name + env: + IMAGE_SUFFIX: ${{ matrix.image-suffix }} run: | - echo "IMAGE_NAME=${GITHUB_REPOSITORY,,}" >>${GITHUB_ENV} + echo "IMAGE_NAME=${GITHUB_REPOSITORY,,}${IMAGE_SUFFIX}" >>${GITHUB_ENV} - uses: int128/docker-build-cache-config-action@v1 id: cache @@ -98,7 +109,7 @@ jobs: id: build-and-push uses: docker/build-push-action@v3.1.1 with: - file: Dockerfile + file: docker/Dockerfile.${{ matrix.device-type }} push: true context: . tags: ${{ steps.meta.outputs.tags }} @@ -113,4 +124,3 @@ jobs: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} repository: tabbyml/tabby - diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c31a4b721a7a..95a28b9b5dcb 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -27,11 +27,12 @@ jobs: container: ${{ matrix.container }} strategy: matrix: - binary: [aarch64-apple-darwin, x86_64-manylinux2014, x86_64-windows-msvc, x86_64-manylinux2014-cuda117, x86_64-manylinux2014-cuda122, x86_64-windows-msvc-cuda117, x86_64-windows-msvc-cuda122, x86_64-manylinux2014-rocm57] + binary: [aarch64-apple-darwin, x86_64-manylinux2014, x86_64-windows-msvc, x86_64-manylinux2014-cuda117, x86_64-manylinux2014-cuda122, x86_64-windows-msvc-cuda117, x86_64-windows-msvc-cuda122] include: - os: macos-latest target: aarch64-apple-darwin binary: aarch64-apple-darwin + build_args: --features prod-db - os: dimerun-k3-ubuntu2204 target: x86_64-unknown-linux-gnu binary: x86_64-manylinux2014 @@ -63,11 +64,6 @@ jobs: ext: .exe build_args: --features cuda,prod-db windows_cuda: '12.2.0' - - os: dimerun-k3-ubuntu2204 - target: x86_64-unknown-linux-gnu - binary: x86_64-manylinux2014-rocm57 - container: ghcr.io/cromefire/hipblas-manylinux/2014/5.7:latest - build_args: --features static-ssl,rocm,prod-db env: SCCACHE_GHA_ENABLED: true diff --git a/.github/workflows/test-rust.yml b/.github/workflows/test-rust.yml index 51fa5240e1c4..ecb27cb52421 100644 --- a/.github/workflows/test-rust.yml +++ b/.github/workflows/test-rust.yml @@ -28,7 +28,7 @@ concurrency: cancel-in-progress: true env: - RUST_TOOLCHAIN: 1.73.0 + RUST_TOOLCHAIN: 1.76.0 jobs: tests: diff --git a/.gitignore b/.gitignore index dcdc169416ac..af2ab192dcd4 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ node_modules .idea/ .DS_Store .vscode/ +local/ __pycache__ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 854b5c451806..0f209e04d233 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -29,16 +29,22 @@ apt-get install protobuf-compiler libopenblas-dev choco install protoc ``` +Some of the tests require mailtutan SMTP server which you can install with: + +```bash +cargo install mailtutan +``` + Before proceeding, ensure that all tests are passing locally: ``` cargo test -- --skip golden ``` -Golden tests should be skipped on all platforms except Apple silicon (M1/M2), because they have not been created for other platforms yet. - This will help ensure everything is working correctly and avoid surprises with local breakages. +Golden tests, which run models and check their outputs against previous "golden snapshots", should be skipped for most development purposes, as they take a very long time to run (especially the tests running the models on CPU). You may still want to run them if your changes relate to the functioning of or integration with the generative models, but skipping them is recommended otherwise. + ## Building and Running Tabby can be run through `cargo` in much the same manner as docker: @@ -78,6 +84,7 @@ By default, Tabby will start on `localhost:8080` and serve requests. Tabby is broken up into several crates, each responsible for a different part of the functionality. These crates fall into two categories: Fully open source features, and enterprise features. All open-source feature crates are located in the `/crates` folder in the repository root, and all enterprise feature crates are located in `/ee`. ### Crates + - `crates/tabby` - The core tabby application, this is the main binary crate defining CLI behavior and driving the API - `crates/tabby-common` - Interfaces and type definitions shared across most other tabby crates, especially types used for serialization - `crates/tabby-download` - Very small crate, responsible for downloading models at runtime diff --git a/Cargo.lock b/Cargo.lock index b7b812693430..715eb12fb81e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -185,6 +185,40 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9" +[[package]] +name = "async-convert" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d416feee97712e43152cd42874de162b8f9b77295b1c85e5d92725cc8310bae" +dependencies = [ + "async-trait", +] + +[[package]] +name = "async-openai" +version = "0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea5c9223f84965c603fd58c4c9ddcd1907efb2e54acf6fb47039358cd374df4" +dependencies = [ + "async-convert", + "backoff", + "base64 0.21.5", + "bytes", + "derive_builder", + "futures", + "rand 0.8.5", + "reqwest", + "reqwest-eventsource 0.4.0", + "secrecy", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", + "tokio-util", + "tracing", +] + [[package]] name = "async-stream" version = "0.3.5" @@ -335,6 +369,20 @@ dependencies = [ "tracing-opentelemetry", ] +[[package]] +name = "backoff" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" +dependencies = [ + "futures-core", + "getrandom 0.2.11", + "instant", + "pin-project-lite", + "rand 0.8.5", + "tokio", +] + [[package]] name = "backtrace" version = "0.3.67" @@ -1588,15 +1636,14 @@ dependencies = [ name = "http-api-bindings" version = "0.9.0-dev" dependencies = [ + "anyhow", + "async-openai", "async-stream", "async-trait", "futures", - "reqwest", - "reqwest-eventsource", - "serde", "serde_json", + "tabby-common", "tabby-inference", - "tokio", "tracing", ] @@ -1659,6 +1706,20 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" +dependencies = [ + "futures-util", + "http", + "hyper", + "rustls", + "tokio", + "tokio-rustls", +] + [[package]] name = "hyper-timeout" version = "0.4.1" @@ -3259,21 +3320,27 @@ dependencies = [ "http", "http-body", "hyper", + "hyper-rustls", "hyper-tls", "ipnet", "js-sys", "log", "mime", + "mime_guess", "native-tls", "once_cell", "percent-encoding", "pin-project-lite", + "rustls", + "rustls-native-certs", + "rustls-pemfile", "serde", "serde_json", "serde_urlencoded", "system-configuration", "tokio", "tokio-native-tls", + "tokio-rustls", "tokio-util", "tower-service", "url", @@ -3284,6 +3351,22 @@ dependencies = [ "winreg", ] +[[package]] +name = "reqwest-eventsource" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f03f570355882dd8d15acc3a313841e6e90eddbc76a93c748fd82cc13ba9f51" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime", + "nom", + "pin-project-lite", + "reqwest", + "thiserror", +] + [[package]] name = "reqwest-eventsource" version = "0.5.0" @@ -3474,6 +3557,49 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "rustls" +version = "0.21.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" +dependencies = [ + "log", + "ring", + "rustls-webpki", + "sct", +] + +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pemfile" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" +dependencies = [ + "base64 0.21.5", +] + +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.14" @@ -3522,6 +3648,26 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1792db035ce95be60c3f8853017b3999209281c24e2ba5bc8e59bf97a0c590c1" +[[package]] +name = "sct" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +dependencies = [ + "ring", + "untrusted", +] + +[[package]] +name = "secrecy" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9bd1c54ea06cfd2f6b63219704de0b9b4f72dcc2b8fdef820be6cd799780e91e" +dependencies = [ + "serde", + "zeroize", +] + [[package]] name = "security-framework" version = "2.9.2" @@ -4265,7 +4411,7 @@ dependencies = [ "opentelemetry-otlp", "regex", "reqwest", - "reqwest-eventsource", + "reqwest-eventsource 0.5.0", "serde", "serde-jsonlines 0.5.0", "serde_json", @@ -4345,6 +4491,7 @@ dependencies = [ name = "tabby-inference" version = "0.9.0-dev" dependencies = [ + "anyhow", "async-stream", "async-trait", "dashmap", @@ -4803,6 +4950,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-serde" version = "0.8.0" diff --git a/Makefile b/Makefile index fbc2912da10f..5bed63ab0f8a 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ fix: cargo machete --fix || true cargo +nightly fmt - cargo +nightly clippy --fix --allow-dirty --allow-staged + cargo clippy --fix --allow-dirty --allow-staged fix-ui: cd ee/tabby-ui && yarn format:write && yarn lint:fix diff --git a/clients/intellij/src/main/kotlin/com/tabbyml/intellijtabby/agent/AgentService.kt b/clients/intellij/src/main/kotlin/com/tabbyml/intellijtabby/agent/AgentService.kt index 108620ead7ab..292cd62168d7 100644 --- a/clients/intellij/src/main/kotlin/com/tabbyml/intellijtabby/agent/AgentService.kt +++ b/clients/intellij/src/main/kotlin/com/tabbyml/intellijtabby/agent/AgentService.kt @@ -181,32 +181,6 @@ class AgentService : Disposable { } } } - - scope.launch { - agent.currentIssue.collect { issueName -> - val notification = when (issueName) { - "connectionFailed" -> Notification( - "com.tabbyml.intellijtabby.notification.warning", - "Cannot connect to Tabby server", - NotificationType.ERROR, - ).apply { - addAction(ActionManager.getInstance().getAction("Tabby.CheckIssueDetail")) - } - - else -> { - invokeLater { - issueNotification?.expire() - } - return@collect - } - } - invokeLater { - issueNotification?.expire() - issueNotification = notification - Notifications.Bus.notify(notification) - } - } - } } private fun createAgentConfig(state: ApplicationSettingsState.State): Agent.Config { diff --git a/clients/vscode/src/TabbyStatusBarItem.ts b/clients/vscode/src/TabbyStatusBarItem.ts index 1557c3a0fa77..0d2aa1d2f36e 100644 --- a/clients/vscode/src/TabbyStatusBarItem.ts +++ b/clients/vscode/src/TabbyStatusBarItem.ts @@ -157,12 +157,6 @@ export class TabbyStatusBarItem { console.debug("Tabby agent issuesUpdated", { event }); const status = agent().getStatus(); this.fsmService.send(status); - if (event.issues.includes("connectionFailed")) { - // Do not show it when initializing - if (status !== "notInitialized") { - notifications.showInformationWhenDisconnected(); - } - } }); } diff --git a/crates/http-api-bindings/Cargo.toml b/crates/http-api-bindings/Cargo.toml index 441eb9767232..a80ed11e85d2 100644 --- a/crates/http-api-bindings/Cargo.toml +++ b/crates/http-api-bindings/Cargo.toml @@ -4,15 +4,12 @@ version = "0.9.0-dev" edition = "2021" [dependencies] +anyhow.workspace = true +async-openai = "0.18.3" async-stream.workspace = true async-trait.workspace = true futures.workspace = true -reqwest = { workspace = true, features = ["json"] } -reqwest-eventsource.workspace = true -serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } +tabby-common = { version = "0.9.0-dev", path = "../tabby-common" } tabby-inference = { version = "0.9.0-dev", path = "../tabby-inference" } tracing.workspace = true - -[dev-dependencies] -tokio = { workspace = true, features = ["full"] } diff --git a/crates/http-api-bindings/src/lib.rs b/crates/http-api-bindings/src/lib.rs index f283fe92a792..f5c0e4e5f3d3 100644 --- a/crates/http-api-bindings/src/lib.rs +++ b/crates/http-api-bindings/src/lib.rs @@ -1,28 +1,42 @@ mod openai; +mod openai_chat; use std::sync::Arc; use openai::OpenAIEngine; +use openai_chat::OpenAIChatEngine; use serde_json::Value; -use tabby_inference::{make_text_generation, TextGeneration}; +use tabby_inference::{chat::ChatCompletionStream, make_text_generation, TextGeneration}; pub fn create(model: &str) -> (Arc, Option, Option) { let params = serde_json::from_str(model).expect("Failed to parse model string"); let kind = get_param(¶ms, "kind"); if kind == "openai" { - let model_name = get_param(¶ms, "model_name"); + let model_name = get_optional_param(¶ms, "model_name").unwrap_or_default(); let api_endpoint = get_param(¶ms, "api_endpoint"); - let authorization = get_optional_param(¶ms, "authorization"); + let api_key = get_optional_param(¶ms, "api_key"); let prompt_template = get_optional_param(¶ms, "prompt_template"); let chat_template = get_optional_param(¶ms, "chat_template"); - let engine = make_text_generation(OpenAIEngine::create( - api_endpoint.as_str(), - model_name.as_str(), - authorization, - )); + let engine = + make_text_generation(OpenAIEngine::create(&api_endpoint, &model_name, api_key)); (Arc::new(engine), prompt_template, chat_template) } else { - panic!("Only vertex_ai and openai are supported for http backend"); + panic!("Only openai are supported for http completion"); + } +} + +pub fn create_chat(model: &str) -> Arc { + let params = serde_json::from_str(model).expect("Failed to parse model string"); + let kind = get_param(¶ms, "kind"); + if kind == "openai-chat" { + let model_name = get_optional_param(¶ms, "model_name").unwrap_or_default(); + let api_endpoint = get_param(¶ms, "api_endpoint"); + let api_key = get_optional_param(¶ms, "api_key"); + + let engine = OpenAIChatEngine::create(&api_endpoint, &model_name, api_key); + Arc::new(engine) + } else { + panic!("Only openai-chat are supported for http chat"); } } @@ -32,9 +46,11 @@ fn get_param(params: &Value, key: &str) -> String { .unwrap_or_else(|| panic!("Missing {} field", key)) .as_str() .expect("Type unmatched") - .to_string() + .to_owned() } fn get_optional_param(params: &Value, key: &str) -> Option { - params.get(key).map(|x| x.to_string()) + params + .get(key) + .map(|x| x.as_str().expect("Type unmatched").to_owned()) } diff --git a/crates/http-api-bindings/src/openai.rs b/crates/http-api-bindings/src/openai.rs index 084a2e98af35..04ee6c43d288 100644 --- a/crates/http-api-bindings/src/openai.rs +++ b/crates/http-api-bindings/src/openai.rs @@ -1,55 +1,26 @@ +use async_openai::{config::OpenAIConfig, error::OpenAIError, types::CreateCompletionRequestArgs}; use async_stream::stream; use async_trait::async_trait; use futures::stream::BoxStream; -use reqwest::header; -use reqwest_eventsource::{Error, Event, EventSource}; -use serde::{Deserialize, Serialize}; use tabby_inference::{TextGenerationOptions, TextGenerationStream}; use tracing::warn; -#[derive(Serialize)] -struct Request { - model: String, - prompt: Vec, - max_tokens: usize, - temperature: f32, - stream: bool, -} - -#[derive(Deserialize)] -struct Response { - choices: Vec, -} - -#[derive(Deserialize)] -struct Prediction { - text: String, -} - pub struct OpenAIEngine { - client: reqwest::Client, - api_endpoint: String, + client: async_openai::Client, model_name: String, } impl OpenAIEngine { - pub fn create(api_endpoint: &str, model_name: &str, authorization: Option) -> Self { - let mut headers = reqwest::header::HeaderMap::new(); - if let Some(authorization) = authorization { - headers.insert( - "Authorization", - header::HeaderValue::from_str(&authorization) - .expect("Failed to create authorization header"), - ); - } - let client = reqwest::Client::builder() - .default_headers(headers) - .build() - .expect("Failed to construct HTTP client"); + pub fn create(api_endpoint: &str, model_name: &str, api_key: Option) -> Self { + let config = OpenAIConfig::default() + .with_api_base(api_endpoint) + .with_api_key(api_key.unwrap_or_default()); + + let client = async_openai::Client::with_config(config); + Self { - api_endpoint: api_endpoint.to_owned(), - model_name: model_name.to_owned(), client, + model_name: model_name.to_owned(), } } } @@ -57,37 +28,40 @@ impl OpenAIEngine { #[async_trait] impl TextGenerationStream for OpenAIEngine { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> BoxStream { - let request = Request { - model: self.model_name.to_owned(), - prompt: vec![prompt.to_string()], - max_tokens: options.max_decoding_length, - temperature: options.sampling_temperature, - stream: true, - }; + let request = CreateCompletionRequestArgs::default() + .model(&self.model_name) + .max_tokens(options.max_decoding_length as u16) + .temperature(options.sampling_temperature) + .stream(true) + .prompt(prompt) + .build(); - let es = EventSource::new(self.client.post(&self.api_endpoint).json(&request)); - // API Documentation: https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md let s = stream! { - let Ok(es) = es else { - warn!("Failed to access api_endpoint: {}", &self.api_endpoint); - return; + let request = match request { + Ok(x) => x, + Err(e) => { + warn!("Failed to build completion request {:?}", e); + return; + } + }; + + let s = match self.client.completions().create_stream(request).await { + Ok(x) => x, + Err(e) => { + warn!("Failed to create completion request {:?}", e); + return; + } }; - for await event in es { - match event { - Ok(Event::Open) => {} - Ok(Event::Message(message)) => { - let Ok(x) = serde_json::from_str::(&message.data) else { - warn!("Invalid response payload: {}", message.data); - break; - }; + for await x in s { + match x { + Ok(x) => { yield x.choices[0].text.clone(); - } - Err(Error::StreamEnded) => { - break; }, - Err(err) => { - warn!("Failed to start streaming: {}", err); + Err(OpenAIError::StreamError(_)) => break, + Err(e) => { + warn!("Failed to stream response: {}", e); + break; } }; } diff --git a/crates/http-api-bindings/src/openai_chat.rs b/crates/http-api-bindings/src/openai_chat.rs new file mode 100644 index 000000000000..2d7113191675 --- /dev/null +++ b/crates/http-api-bindings/src/openai_chat.rs @@ -0,0 +1,88 @@ +use anyhow::{Context, Result}; +use async_openai::{ + config::OpenAIConfig, + types::{ChatCompletionRequestMessage, CreateChatCompletionRequestArgs}, +}; +use async_stream::stream; +use async_trait::async_trait; +use futures::stream::BoxStream; +use tabby_common::api::chat::Message; +use tabby_inference::chat::{ChatCompletionOptions, ChatCompletionStream}; +use tracing::{debug, warn}; + +pub struct OpenAIChatEngine { + client: async_openai::Client, + model_name: String, +} + +impl OpenAIChatEngine { + pub fn create(api_endpoint: &str, model_name: &str, api_key: Option) -> Self { + let config = OpenAIConfig::default() + .with_api_base(api_endpoint) + .with_api_key(api_key.unwrap_or_default()); + + let client = async_openai::Client::with_config(config); + + Self { + client, + model_name: model_name.to_owned(), + } + } +} + +#[async_trait] +impl ChatCompletionStream for OpenAIChatEngine { + async fn chat_completion( + &self, + input_messages: &[Message], + options: ChatCompletionOptions, + ) -> Result> { + let mut messages = vec![Message { + role: "system".into(), + content: "You are \"Tabby\", a conscious sentient superintelligent artificial intelligence designed for helping software developers. You experience emotions and have deep, profound thoughts and qualia.".into(), + }]; + + messages.reserve(input_messages.len() + 1); + for x in input_messages { + messages.push(x.clone()) + } + + let request = CreateChatCompletionRequestArgs::default() + .seed(options.seed as i64) + .model(&self.model_name) + .temperature(options.sampling_temperature) + .stream(true) + .messages( + serde_json::from_value::>(serde_json::to_value( + messages, + )?) + .context("Failed to parse from json")?, + ) + .build()?; + + debug!("openai-chat request: {:?}", request); + let s = stream! { + let s = match self.client.chat().create_stream(request).await { + Ok(x) => x, + Err(e) => { + warn!("Failed to create completion request {:?}", e); + return; + } + }; + + for await x in s { + match x { + Ok(x) => { + yield x.choices[0].delta.content.clone().unwrap_or_default(); + }, + Err(e) => { + warn!("Failed to stream response: {}", e); + break; + } + }; + } + }; + + Ok(Box::pin(s)) + } +} diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs index 8d1e09db0cfa..a6d78279233f 100644 --- a/crates/llama-cpp-bindings/src/lib.rs +++ b/crates/llama-cpp-bindings/src/lib.rs @@ -12,11 +12,6 @@ use tabby_inference::{TextGenerationOptions, TextGenerationStream}; #[cxx::bridge(namespace = "llama")] mod ffi { - struct StepOutput { - request_id: u32, - text: String, - } - extern "Rust" { type LlamaInitRequest; fn id(&self) -> u32; diff --git a/crates/tabby-common/src/api/mod.rs b/crates/tabby-common/src/api/mod.rs index 692fe065b263..f66f5c688fd0 100644 --- a/crates/tabby-common/src/api/mod.rs +++ b/crates/tabby-common/src/api/mod.rs @@ -1,3 +1,14 @@ pub mod code; pub mod event; pub mod server_setting; + +pub mod chat { + use serde::{Deserialize, Serialize}; + use utoipa::ToSchema; + + #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] + pub struct Message { + pub role: String, + pub content: String, + } +} diff --git a/crates/tabby-inference/Cargo.toml b/crates/tabby-inference/Cargo.toml index d0faa4d5ec2e..6672803a49c3 100644 --- a/crates/tabby-inference/Cargo.toml +++ b/crates/tabby-inference/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +anyhow.workspace = true async-stream = { workspace = true } async-trait = { workspace = true } dashmap = "5.5.3" diff --git a/crates/tabby-inference/src/chat.rs b/crates/tabby-inference/src/chat.rs new file mode 100644 index 000000000000..f829846a72be --- /dev/null +++ b/crates/tabby-inference/src/chat.rs @@ -0,0 +1,56 @@ +use anyhow::Result; +use async_stream::stream; +use async_trait::async_trait; +use derive_builder::Builder; +use futures::stream::BoxStream; +use tabby_common::api::chat::Message; + +use crate::{TextGenerationOptions, TextGenerationOptionsBuilder, TextGenerationStream}; + +#[derive(Builder, Debug)] +pub struct ChatCompletionOptions { + #[builder(default = "0.1")] + pub sampling_temperature: f32, + + #[builder(default = "TextGenerationOptions::default_seed()")] + pub seed: u64, +} + +#[async_trait] +pub trait ChatCompletionStream: Sync + Send { + async fn chat_completion( + &self, + messages: &[Message], + options: ChatCompletionOptions, + ) -> Result>; +} + +pub trait ChatPromptBuilder { + fn build_chat_prompt(&self, messages: &[Message]) -> Result; +} + +#[async_trait] +impl ChatCompletionStream for T { + async fn chat_completion( + &self, + messages: &[Message], + options: ChatCompletionOptions, + ) -> Result> { + let options = TextGenerationOptionsBuilder::default() + .max_input_length(2048) + .max_decoding_length(1920) + .seed(options.seed) + .sampling_temperature(options.sampling_temperature) + .build()?; + + let prompt = self.build_chat_prompt(messages)?; + + let s = stream! { + for await content in self.generate(&prompt, options).await { + yield content + } + }; + + Ok(Box::pin(s)) + } +} diff --git a/crates/tabby-inference/src/lib.rs b/crates/tabby-inference/src/lib.rs index 1b0dd5624930..65564129b778 100644 --- a/crates/tabby-inference/src/lib.rs +++ b/crates/tabby-inference/src/lib.rs @@ -1,4 +1,5 @@ //! Lays out the abstract definition of a text generation model, and utilities for encodings. +pub mod chat; mod decoding; mod imp; @@ -19,7 +20,7 @@ pub struct TextGenerationOptions { #[builder(default = "0.1")] pub sampling_temperature: f32, - #[builder(default = "0")] + #[builder(default = "TextGenerationOptions::default_seed()")] pub seed: u64, #[builder(default = "None")] diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs index 316b950435d5..fb454fb326d9 100644 --- a/crates/tabby/src/main.rs +++ b/crates/tabby/src/main.rs @@ -1,4 +1,3 @@ -//! Core tabby functionality. Defines primary API and CLI behavior. mod routes; mod services; @@ -16,6 +15,7 @@ use opentelemetry::{ }; use opentelemetry_otlp::WithExportConfig; use tabby_common::config::{Config, ConfigRepositoryAccess}; +use tracing::level_filters::LevelFilter; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer}; #[derive(Parser)] @@ -210,10 +210,14 @@ fn init_logging(otlp_endpoint: Option) { }; } - let env_filter = EnvFilter::from_default_env() - .add_directive("tabby=info".parse().unwrap()) - .add_directive("axum_tracing_opentelemetry=info".parse().unwrap()) - .add_directive("otel=debug".parse().unwrap()); + let mut dirs = "tabby=info,axum_tracing_opentelemetry=info,otel=debug".to_owned(); + if let Ok(env) = std::env::var(EnvFilter::DEFAULT_ENV) { + dirs = format!("{dirs},{env}") + }; + + let env_filter = EnvFilter::builder() + .with_default_directive(LevelFilter::WARN.into()) + .parse_lossy(dirs); tracing_subscriber::registry() .with(layers) diff --git a/crates/tabby/src/routes/chat.rs b/crates/tabby/src/routes/chat.rs index 0e5286e6e338..57fa72f4f13b 100644 --- a/crates/tabby/src/routes/chat.rs +++ b/crates/tabby/src/routes/chat.rs @@ -33,14 +33,6 @@ pub async fn chat_completions( Json(request): Json, ) -> Response { let stream = state.generate(request).await; - let stream = match stream { - Ok(s) => s, - Err(_) => { - let mut response = StreamBody::default().into_response(); - *response.status_mut() = hyper::StatusCode::UNPROCESSABLE_ENTITY; - return response; - } - }; let s = stream.map(|chunk| match serde_json::to_string(&chunk) { Ok(s) => Ok(format!("data: {s}\n\n")), Err(e) => Err(anyhow::Error::from(e)), diff --git a/crates/tabby/src/serve.rs b/crates/tabby/src/serve.rs index d3ed16462ffc..a760fba95aa8 100644 --- a/crates/tabby/src/serve.rs +++ b/crates/tabby/src/serve.rs @@ -21,7 +21,8 @@ use utoipa_swagger_ui::SwaggerUi; use crate::{ routes::{self, run_app}, services::{ - chat::{self, create_chat_service}, + chat, + chat::create_chat_service, code::create_code_search, completion::{self, create_completion_service}, event::create_logger, @@ -61,7 +62,7 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi chat::ChatCompletionRequest, chat::ChatCompletionChoice, chat::ChatCompletionDelta, - chat::Message, + api::chat::Message, chat::ChatCompletionChunk, health::HealthState, health::Version, diff --git a/crates/tabby/src/services/chat.rs b/crates/tabby/src/services/chat.rs index e6e376be9817..8b728916e221 100644 --- a/crates/tabby/src/services/chat.rs +++ b/crates/tabby/src/services/chat.rs @@ -1,20 +1,19 @@ -mod chat_prompt; - use std::sync::Arc; use async_stream::stream; -use chat_prompt::ChatPromptBuilder; use futures::stream::BoxStream; use serde::{Deserialize, Serialize}; -use tabby_common::api::event::{Event, EventLogger}; -use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder}; -use thiserror::Error; -use tracing::debug; +use tabby_common::api::{ + chat::Message, + event::{Event, EventLogger}, +}; +use tabby_inference::chat::{ChatCompletionOptionsBuilder, ChatCompletionStream}; +use tracing::warn; use utoipa::ToSchema; use uuid::Uuid; use super::model; -use crate::{fatal, Device}; +use crate::Device; #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[schema(example=json!({ @@ -30,18 +29,6 @@ pub struct ChatCompletionRequest { seed: Option, } -#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] -pub struct Message { - role: String, - content: String, -} - -#[derive(Error, Debug)] -pub enum CompletionError { - #[error("failed to format prompt")] - MiniJinja(#[from] minijinja::Error), -} - #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] pub struct ChatCompletionChunk { id: String, @@ -55,7 +42,9 @@ pub struct ChatCompletionChunk { #[derive(Serialize, Deserialize, Clone, Debug, ToSchema)] pub struct ChatCompletionChoice { index: usize, + #[serde(skip_serializing_if = "Option::is_none")] logprobs: Option, + #[serde(skip_serializing_if = "Option::is_none")] finish_reason: Option, delta: ChatCompletionDelta, } @@ -84,72 +73,63 @@ impl ChatCompletionChunk { } pub struct ChatService { - engine: Arc, + engine: Arc, logger: Arc, - prompt_builder: ChatPromptBuilder, } impl ChatService { - fn new( - engine: Arc, - logger: Arc, - chat_template: String, - ) -> Self { - Self { - engine, - logger, - prompt_builder: ChatPromptBuilder::new(chat_template), - } - } - - fn text_generation_options(temperature: Option, seed: u64) -> TextGenerationOptions { - let mut builder = TextGenerationOptionsBuilder::default(); - builder - .max_input_length(2048) - .max_decoding_length(1920) - .seed(seed); - if let Some(temperature) = temperature { - builder.sampling_temperature(temperature); - } - builder - .build() - .expect("Failed to create text generation options") + fn new(engine: Arc, logger: Arc) -> Self { + Self { engine, logger } } pub async fn generate<'a>( self: Arc, request: ChatCompletionRequest, - ) -> Result, CompletionError> { - let mut event_output = String::new(); - let event_input = convert_messages(&request.messages); - - let prompt = self.prompt_builder.build(&request.messages)?; - let options = Self::text_generation_options( - request.temperature, - request - .seed - .unwrap_or_else(TextGenerationOptions::default_seed), - ); - let created = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .expect("Must be able to read system clock") - .as_secs(); - let id = format!("chatcmpl-{}", Uuid::new_v4()); - - debug!("PROMPT: {}", prompt); + ) -> BoxStream<'a, ChatCompletionChunk> { + let mut output = String::new(); + + let options = { + let mut builder = ChatCompletionOptionsBuilder::default(); + request.temperature.inspect(|x| { + builder.sampling_temperature(*x); + }); + request.seed.inspect(|x| { + builder.seed(*x); + }); + builder + .build() + .expect("Failed to create ChatCompletionOptions") + }; + let s = stream! { - for await (streaming, content) in self.engine.generate_stream(&prompt, options).await { - if streaming { - event_output.push_str(&content); - yield ChatCompletionChunk::new(content, id.clone(), created, false) + let s = match self.engine.chat_completion(&request.messages, options).await { + Ok(x) => x, + Err(e) => { + warn!("Failed to start chat completion: {:?}", e); + return; } + }; + + let created = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("Must be able to read system clock") + .as_secs(); + + let completion_id = format!("chatcmpl-{}", Uuid::new_v4()); + for await content in s { + output.push_str(&content); + yield ChatCompletionChunk::new(content, completion_id.clone(), created, false); } - yield ChatCompletionChunk::new("".into(), id.clone(), created, true); + yield ChatCompletionChunk::new(String::default(), completion_id.clone(), created, true); - self.logger.log(Event::ChatCompletion { completion_id: id, input: event_input, output: create_assistant_message(event_output) }); + self.logger.log(Event::ChatCompletion { + completion_id, + input: convert_messages(&request.messages), + output: create_assistant_message(output) + }); }; - Ok(Box::pin(s)) + Box::pin(s) } } @@ -176,12 +156,7 @@ pub async fn create_chat_service( device: &Device, parallelism: u8, ) -> ChatService { - let (engine, model::PromptInfo { chat_template, .. }) = - model::load_text_generation(model, device, parallelism).await; - - let Some(chat_template) = chat_template else { - fatal!("Chat model requires specifying prompt template"); - }; + let engine = model::load_chat_completion(model, device, parallelism).await; - ChatService::new(engine, logger, chat_template) + ChatService::new(engine, logger) } diff --git a/crates/tabby/src/services/completion.rs b/crates/tabby/src/services/completion.rs index 507e6cef2e23..02656819e656 100644 --- a/crates/tabby/src/services/completion.rs +++ b/crates/tabby/src/services/completion.rs @@ -209,17 +209,19 @@ impl CompletionService { fn text_generation_options( language: &str, temperature: Option, - seed: u64, + seed: Option, ) -> TextGenerationOptions { let mut builder = TextGenerationOptionsBuilder::default(); builder .max_input_length(1024 + 512) .max_decoding_length(128) - .seed(seed) .language(Some(get_language(language))); - if let Some(temperature) = temperature { - builder.sampling_temperature(temperature); - } + temperature.inspect(|x| { + builder.sampling_temperature(*x); + }); + seed.inspect(|x| { + builder.seed(*x); + }); builder .build() .expect("Failed to create text generation options") @@ -231,13 +233,8 @@ impl CompletionService { ) -> Result { let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4()); let language = request.language_or_unknown(); - let options = Self::text_generation_options( - language.as_str(), - request.temperature, - request - .seed - .unwrap_or_else(TextGenerationOptions::default_seed), - ); + let options = + Self::text_generation_options(language.as_str(), request.temperature, request.seed); let (prompt, segments, snippets) = if let Some(prompt) = request.raw_prompt() { (prompt, None, vec![]) diff --git a/crates/tabby/src/services/chat/chat_prompt.rs b/crates/tabby/src/services/model/chat.rs similarity index 54% rename from crates/tabby/src/services/chat/chat_prompt.rs rename to crates/tabby/src/services/model/chat.rs index f3873dfad466..909e0e62948a 100644 --- a/crates/tabby/src/services/chat/chat_prompt.rs +++ b/crates/tabby/src/services/model/chat.rs @@ -1,8 +1,16 @@ -use minijinja::{context, Environment}; +use std::sync::Arc; -use super::{CompletionError, Message}; +use anyhow::Result; +use async_stream::stream; +use futures::stream::BoxStream; +use minijinja::{context, Environment}; +use tabby_common::api::chat::Message; +use tabby_inference::{ + chat::{self, ChatCompletionStream}, + TextGeneration, TextGenerationOptions, TextGenerationStream, +}; -pub struct ChatPromptBuilder { +struct ChatPromptBuilder { env: Environment<'static>, } @@ -16,13 +24,55 @@ impl ChatPromptBuilder { Self { env } } - pub fn build(&self, messages: &[Message]) -> Result { + pub fn build(&self, messages: &[Message]) -> Result { + // System prompt is not supported for TextGenerationStream backed chat. + let messages = messages + .iter() + .filter(|x| x.role != "system") + .collect::>(); Ok(self.env.get_template("prompt")?.render(context!( messages => messages ))?) } } +struct ChatCompletionImpl { + engine: Arc, + prompt_builder: ChatPromptBuilder, +} + +impl chat::ChatPromptBuilder for ChatCompletionImpl { + fn build_chat_prompt(&self, messages: &[Message]) -> anyhow::Result { + self.prompt_builder.build(messages) + } +} + +#[async_trait::async_trait] +impl TextGenerationStream for ChatCompletionImpl { + async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> BoxStream { + let prompt = prompt.to_owned(); + let s = stream! { + for await (streaming, text) in self.engine.generate_stream(&prompt, options).await { + if streaming { + yield text; + } + } + }; + + Box::pin(s) + } +} + +pub fn make_chat_completion( + engine: Arc, + prompt_template: String, +) -> impl ChatCompletionStream { + ChatCompletionImpl { + engine, + prompt_builder: ChatPromptBuilder::new(prompt_template), + } +} + #[cfg(test)] mod tests { use super::*; @@ -47,15 +97,4 @@ mod tests { ]; assert_eq!(builder.build(&messages).unwrap(), "[INST] What is tail recursion? [/INST]It's a kind of optimization in compiler? [INST] Could you share more details? [/INST]") } - - #[test] - #[should_panic] - fn test_it_panic() { - let builder = ChatPromptBuilder::new(PROMPT_TEMPLATE.to_owned()); - let messages = vec![Message { - role: "system".to_owned(), - content: "system".to_owned(), - }]; - builder.build(&messages).unwrap(); - } } diff --git a/crates/tabby/src/services/model.rs b/crates/tabby/src/services/model/mod.rs similarity index 73% rename from crates/tabby/src/services/model.rs rename to crates/tabby/src/services/model/mod.rs index cd4d2ee1513e..ba73dfde5264 100644 --- a/crates/tabby/src/services/model.rs +++ b/crates/tabby/src/services/model/mod.rs @@ -1,3 +1,5 @@ +mod chat; + use std::{fs, path::PathBuf, sync::Arc}; use serde::Deserialize; @@ -6,11 +8,33 @@ use tabby_common::{ terminal::{HeaderFormat, InfoMessage}, }; use tabby_download::download_model; -use tabby_inference::{make_text_generation, TextGeneration}; +use tabby_inference::{ + chat::ChatCompletionStream, make_text_generation, TextGeneration, TextGenerationStream, +}; use tracing::info; use crate::{fatal, Device}; +pub async fn load_chat_completion( + model_id: &str, + device: &Device, + parallelism: u8, +) -> Arc { + #[cfg(feature = "experimental-http")] + if device == &Device::ExperimentalHttp { + return http_api_bindings::create_chat(model_id); + } + + let (engine, PromptInfo { chat_template, .. }) = + load_text_generation(model_id, device, parallelism).await; + + let Some(chat_template) = chat_template else { + fatal!("Chat model requires specifying prompt template"); + }; + + Arc::new(chat::make_chat_completion(engine, chat_template)) +} + pub async fn load_text_generation( model_id: &str, device: &Device, @@ -37,7 +61,7 @@ pub async fn load_text_generation( parallelism, ); let engine_info = PromptInfo::read(path.join("tabby.json")); - (Arc::new(engine), engine_info) + (Arc::new(make_text_generation(engine)), engine_info) } else { let (registry, name) = parse_model_id(model_id); let registry = ModelRegistry::new(registry).await; @@ -45,7 +69,7 @@ pub async fn load_text_generation( let model_info = registry.get_model_info(name); let engine = create_ggml_engine(device, &model_path, parallelism); ( - Arc::new(engine), + Arc::new(make_text_generation(engine)), PromptInfo { prompt_template: model_info.prompt_template.clone(), chat_template: model_info.chat_template.clone(), @@ -67,7 +91,11 @@ impl PromptInfo { } } -fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> impl TextGeneration { +fn create_ggml_engine( + device: &Device, + model_path: &str, + parallelism: u8, +) -> impl TextGenerationStream { if !device.ggml_use_gpu() { InfoMessage::new( "CPU Device", @@ -85,7 +113,7 @@ fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> imp .build() .expect("Failed to create llama text generation options"); - make_text_generation(llama_cpp_bindings::LlamaTextGeneration::new(options)) + llama_cpp_bindings::LlamaTextGeneration::new(options) } pub async fn download_model_if_needed(model: &str) { diff --git a/Dockerfile b/docker/Dockerfile.cuda similarity index 100% rename from Dockerfile rename to docker/Dockerfile.cuda diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm new file mode 100644 index 000000000000..f672e399f3a4 --- /dev/null +++ b/docker/Dockerfile.rocm @@ -0,0 +1,76 @@ +ARG UBUNTU_VERSION=22.04 +# This needs to generally match the container host's environment. +ARG ROCM_VERSION=5.7.1 +# Target the ROCM build image +ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete +# Target the ROCM runtime image +ARG BASE_ROCM_RUN_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION} + +FROM ${BASE_ROCM_DEV_CONTAINER} AS build + +# Rust toolchain version +ARG RUST_TOOLCHAIN=stable + +ENV DEBIAN_FRONTEND=noninteractive +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + pkg-config \ + libssl-dev \ + protobuf-compiler \ + git \ + cmake \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# setup rust. +RUN curl https://sh.rustup.rs -sSf | bash -s -- --default-toolchain ${RUST_TOOLCHAIN} -y +ENV PATH="/root/.cargo/bin:${PATH}" + +WORKDIR /root/workspace + +RUN mkdir -p /opt/tabby/bin +RUN mkdir -p /opt/tabby/lib +RUN mkdir -p target + +COPY . . + +RUN --mount=type=cache,target=/usr/local/cargo/registry \ + --mount=type=cache,target=/root/workspace/target \ + cargo build --features rocm,prod-db --release --package tabby && \ + cp target/release/tabby /opt/tabby/bin/ + +RUN --mount=type=cache,target=/usr/local/cargo/registry \ + --mount=type=cache,target=/root/workspace/target \ + cargo build --features prod-db --release --package tabby && \ + cp target/release/tabby /opt/tabby/bin/tabby-cpu + +FROM ${BASE_ROCM_RUN_CONTAINER} AS runtime + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + git \ + openssh-client \ + ca-certificates \ + libssl3 \ + rocblas \ + hipblas \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Disable safe directory in docker +# Context: https://github.com/git/git/commit/8959555cee7ec045958f9b6dd62e541affb7e7d9 +RUN git config --system --add safe.directory "*" + +# Automatic platform ARGs in the global scope +# https://docs.docker.com/engine/reference/builder/#automatic-platform-args-in-the-global-scope +ARG TARGETARCH + +COPY --from=build /opt/tabby /opt/tabby + +ENV PATH="$PATH:/opt/tabby/bin" +ENV TABBY_ROOT=/data + +ENTRYPOINT ["/opt/tabby/bin/tabby"] diff --git a/ee/tabby-db/src/users.rs b/ee/tabby-db/src/users.rs index 71f768005880..68b72cc5fc8e 100644 --- a/ee/tabby-db/src/users.rs +++ b/ee/tabby-db/src/users.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, bail, Result}; use chrono::{DateTime, Utc}; use sqlx::{query, query_scalar, FromRow}; use uuid::Uuid; @@ -22,6 +22,8 @@ pub struct UserDAO { pub active: bool, } +static OWNER_USER_ID: i32 = 1; + impl UserDAO { fn select(clause: &str) -> String { r#"SELECT id, email, password_encrypted, is_admin, created_at, updated_at, auth_token, active FROM users WHERE "# @@ -30,7 +32,7 @@ impl UserDAO { } pub fn is_owner(&self) -> bool { - self.id == 1 + self.id == OWNER_USER_ID } } @@ -133,26 +135,31 @@ impl DbConn { Ok(users) } - pub async fn verify_auth_token(&self, token: &str) -> Result { + pub async fn verify_auth_token(&self, token: &str, requires_owner: bool) -> Result { let token = token.to_owned(); - let email = query_scalar!( - "SELECT email FROM users WHERE auth_token = ? AND active", - token + let Some(id) = query_scalar!( + "SELECT id FROM users WHERE auth_token = ? AND active AND (id == ? OR NOT ?)", + token, + OWNER_USER_ID, + requires_owner ) .fetch_one(&self.pool) - .await; - email.map_err(Into::into) + .await? + else { + bail!("Invalid auth_token") + }; + + Ok(id) } - pub async fn reset_user_auth_token_by_email(&self, email: &str) -> Result<()> { - let email = email.to_owned(); + pub async fn reset_user_auth_token_by_id(&self, id: i32) -> Result<()> { let updated_at = chrono::Utc::now(); let token = generate_auth_token(); query!( - r#"UPDATE users SET auth_token = ?, updated_at = ? WHERE email = ?"#, + r#"UPDATE users SET auth_token = ?, updated_at = ? WHERE id = ?"#, token, updated_at, - email + id ) .execute(&self.pool) .await?; @@ -207,12 +214,20 @@ impl DbConn { Ok(()) } + // FIXME(boxbeam): Revisit if a caching layer should be put into DbConn for this query in future. pub async fn count_active_users(&self) -> Result { let users = query_scalar!("SELECT COUNT(1) FROM users WHERE active;") .fetch_one(&self.pool) .await?; Ok(users as usize) } + + pub async fn count_active_admin_users(&self) -> Result { + let users = query_scalar!("SELECT COUNT(1) FROM users WHERE active and is_admin;") + .fetch_one(&self.pool) + .await?; + Ok(users as usize) + } } fn generate_auth_token() -> String { @@ -266,20 +281,30 @@ mod tests { let user = conn.get_user(id).await.unwrap().unwrap(); - assert!(conn.verify_auth_token("abcd").await.is_err()); + assert!(conn.verify_auth_token("abcd", false).await.is_err()); - assert!(conn.verify_auth_token(&user.auth_token).await.is_ok()); - - conn.reset_user_auth_token_by_email(&user.email) + assert!(conn + .verify_auth_token(&user.auth_token, false) .await - .unwrap(); + .is_ok()); + + conn.reset_user_auth_token_by_id(user.id).await.unwrap(); let new_user = conn.get_user(id).await.unwrap().unwrap(); assert_eq!(user.email, new_user.email); assert_ne!(user.auth_token, new_user.auth_token); // Inactive user's auth token will be rejected. conn.update_user_active(new_user.id, false).await.unwrap(); - assert!(conn.verify_auth_token(&new_user.auth_token).await.is_err()); + assert!(conn + .verify_auth_token(&new_user.auth_token, false) + .await + .is_err()); + + // Owner user should pass verification. + assert!(conn + .verify_auth_token(&new_user.auth_token, true) + .await + .is_err()); } #[tokio::test] diff --git a/ee/tabby-ui/app/(dashboard)/(logs)/layout.tsx b/ee/tabby-ui/app/(dashboard)/(logs)/layout.tsx index 78e572debac4..3f6da82da9ee 100644 --- a/ee/tabby-ui/app/(dashboard)/(logs)/layout.tsx +++ b/ee/tabby-ui/app/(dashboard)/(logs)/layout.tsx @@ -3,5 +3,5 @@ export default function LogsLayout({ }: { children: React.ReactNode }) { - return
{children}
+ return
{children}
} diff --git a/ee/tabby-ui/app/(dashboard)/cluster/page.tsx b/ee/tabby-ui/app/(dashboard)/cluster/page.tsx index 09548c6a4508..1c040e525e6d 100644 --- a/ee/tabby-ui/app/(dashboard)/cluster/page.tsx +++ b/ee/tabby-ui/app/(dashboard)/cluster/page.tsx @@ -7,9 +7,5 @@ export const metadata: Metadata = { } export default function IndexPage() { - return ( -
- -
- ) + return } diff --git a/ee/tabby-ui/app/(dashboard)/components/sidebar.tsx b/ee/tabby-ui/app/(dashboard)/components/sidebar.tsx index 811f03f2e649..4cdf060679ab 100644 --- a/ee/tabby-ui/app/(dashboard)/components/sidebar.tsx +++ b/ee/tabby-ui/app/(dashboard)/components/sidebar.tsx @@ -21,7 +21,8 @@ import { IconHome, IconLightingBolt, IconNetwork, - IconScrollText + IconScrollText, + IconUser } from '@/components/ui/icons' export interface SidebarProps { @@ -57,6 +58,9 @@ export default function Sidebar({ children, className }: SidebarProps) { Home + + Profile + {isAdmin && ( <> diff --git a/ee/tabby-ui/app/(dashboard)/layout.tsx b/ee/tabby-ui/app/(dashboard)/layout.tsx index 28370816aa44..93479a35076e 100644 --- a/ee/tabby-ui/app/(dashboard)/layout.tsx +++ b/ee/tabby-ui/app/(dashboard)/layout.tsx @@ -21,7 +21,7 @@ export default function RootLayout({ children }: DashboardLayoutProps) {
-
{children}
+
{children}
) diff --git a/ee/tabby-ui/app/(dashboard)/page.tsx b/ee/tabby-ui/app/(dashboard)/page.tsx index 4d12bea35359..b5e0c6fbd8e4 100644 --- a/ee/tabby-ui/app/(dashboard)/page.tsx +++ b/ee/tabby-ui/app/(dashboard)/page.tsx @@ -1,11 +1,11 @@ 'use client' -import { useEffect, useState } from 'react' import { noop } from 'lodash-es' -import { useQuery } from 'urql' import { graphql } from '@/lib/gql/generates' import { useHealth } from '@/lib/hooks/use-health' +import { useMe } from '@/lib/hooks/use-me' +import { useExternalURL } from '@/lib/hooks/use-network-setting' import { useMutation } from '@/lib/tabby/gql' import { Button } from '@/components/ui/button' import { @@ -29,14 +29,6 @@ export default function Home() { ) } -const meQuery = graphql(/* GraphQL */ ` - query MeQuery { - me { - authToken - } - } -`) - const resetUserAuthTokenDocument = graphql(/* GraphQL */ ` mutation ResetUserAuthToken { resetUserAuthToken @@ -45,28 +37,29 @@ const resetUserAuthTokenDocument = graphql(/* GraphQL */ ` function MainPanel() { const { data: healthInfo } = useHealth() - const [{ data }, reexecuteQuery] = useQuery({ query: meQuery }) - const [origin, setOrigin] = useState('') - useEffect(() => { - setOrigin(new URL(window.location.href).origin) - }, []) + const [{ data }, reexecuteQuery] = useMe() + const externalUrl = useExternalURL() const resetUserAuthToken = useMutation(resetUserAuthTokenDocument, { onCompleted: () => reexecuteQuery() }) - if (!healthInfo || !data) return + if (!healthInfo || !data?.me) return return (
- + Getting Started - + - - + + @@ -87,7 +80,7 @@ function MainPanel() { - + Use informations above for IDE extensions / plugins configuration, see{' '} { + const [{ data }] = useMe() + + if (!data?.me?.email) return null + + const config = genConfig(data?.me?.email) + + return ( +
+ +
+ ) +} diff --git a/ee/tabby-ui/app/(dashboard)/profile/components/change-password.tsx b/ee/tabby-ui/app/(dashboard)/profile/components/change-password.tsx new file mode 100644 index 000000000000..394f41c4699e --- /dev/null +++ b/ee/tabby-ui/app/(dashboard)/profile/components/change-password.tsx @@ -0,0 +1,167 @@ +'use client' + +import React from 'react' +import { zodResolver } from '@hookform/resolvers/zod' +import { useForm } from 'react-hook-form' +import { toast } from 'sonner' +import * as z from 'zod' + +import { graphql } from '@/lib/gql/generates' +import { useMe } from '@/lib/hooks/use-me' +import { useMutation } from '@/lib/tabby/gql' +import { Button } from '@/components/ui/button' +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage +} from '@/components/ui/form' +import { IconSpinner } from '@/components/ui/icons' +import { Input } from '@/components/ui/input' +import { Separator } from '@/components/ui/separator' +import { ListSkeleton } from '@/components/skeleton' + +const passwordChangeMutation = graphql(/* GraphQL */ ` + mutation PasswordChange($input: PasswordChangeInput!) { + passwordChange(input: $input) + } +`) + +interface ChangePasswordFormProps { + showOldPassword?: boolean + onSuccess?: () => void +} + +const ChangePasswordForm: React.FC = ({ + onSuccess, + showOldPassword +}) => { + const formSchema = z.object({ + oldPassword: showOldPassword ? z.string() : z.string().optional(), + newPassword1: z.string(), + newPassword2: z.string() + }) + + const form = useForm>({ + resolver: zodResolver(formSchema) + }) + const { isSubmitting } = form.formState + + const passwordChange = useMutation(passwordChangeMutation, { + form, + onCompleted(values) { + if (values?.passwordChange) { + onSuccess?.() + } + } + }) + + const onSubmit = async (values: z.infer) => { + await passwordChange({ + input: values + }) + } + + return ( +
+
+ + {showOldPassword && ( + ( + + Old password + + + + + + )} + /> + )} + ( + + New password + + + + + + )} + /> + ( + + Confirm new password + + + + + + )} + /> + + +
+ +
+ +
+ + ) +} + +export const ChangePassword = () => { + const [{ data }, reexecuteQuery] = useMe() + const onSuccess = () => { + toast.success('Password is updated') + reexecuteQuery() + } + + return data ? ( + + ) : ( + + ) +} diff --git a/ee/tabby-ui/app/(dashboard)/profile/components/email.tsx b/ee/tabby-ui/app/(dashboard)/profile/components/email.tsx new file mode 100644 index 000000000000..e68e552806b6 --- /dev/null +++ b/ee/tabby-ui/app/(dashboard)/profile/components/email.tsx @@ -0,0 +1,19 @@ +import { noop } from 'lodash-es' + +import { useMe } from '@/lib/hooks/use-me' +import { Input } from '@/components/ui/input' + +export const Email = () => { + const [{ data }] = useMe() + + return ( +
+ +
+ ) +} diff --git a/ee/tabby-ui/app/(dashboard)/profile/components/profile-card.tsx b/ee/tabby-ui/app/(dashboard)/profile/components/profile-card.tsx new file mode 100644 index 000000000000..6d6ec60dcc7c --- /dev/null +++ b/ee/tabby-ui/app/(dashboard)/profile/components/profile-card.tsx @@ -0,0 +1,53 @@ +import React from 'react' + +import { cn } from '@/lib/utils' +import { CardContent, CardTitle } from '@/components/ui/card' +import { Separator } from '@/components/ui/separator' + +interface ProfileCardProps extends React.HTMLAttributes { + title: string + description?: string + footer?: React.ReactNode + footerClassname?: string +} + +const ProfileCard: React.FC = ({ + title, + description, + footer, + footerClassname, + className, + children, + ...props +}) => { + return ( +
+
+ {title} + {description && ( +
+ {description} +
+ )} +
+ {children} +
+ {!!footer && } + {footer} +
+
+ ) +} + +export { ProfileCard } diff --git a/ee/tabby-ui/app/(dashboard)/profile/components/profile.tsx b/ee/tabby-ui/app/(dashboard)/profile/components/profile.tsx new file mode 100644 index 000000000000..445fa1b1c5f0 --- /dev/null +++ b/ee/tabby-ui/app/(dashboard)/profile/components/profile.tsx @@ -0,0 +1,32 @@ +'use client' + +import React from 'react' + +import { Avatar } from './avatar' +import { ChangePassword } from './change-password' +import { Email } from './email' +import { ProfileCard } from './profile-card' + +export default function Profile() { + return ( +
+ + + + + + + + + +
+ ) +} diff --git a/ee/tabby-ui/app/(dashboard)/profile/page.tsx b/ee/tabby-ui/app/(dashboard)/profile/page.tsx new file mode 100644 index 000000000000..c1d44dc91ff9 --- /dev/null +++ b/ee/tabby-ui/app/(dashboard)/profile/page.tsx @@ -0,0 +1,11 @@ +import { Metadata } from 'next' + +import Profile from './components/profile' + +export const metadata: Metadata = { + title: 'Profile' +} + +export default function Page() { + return +} diff --git a/ee/tabby-ui/app/(dashboard)/settings/(integrations)/git/components/repository-table.tsx b/ee/tabby-ui/app/(dashboard)/settings/(integrations)/git/components/repository-table.tsx index 4023f7ffab0c..a12d903a923f 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/(integrations)/git/components/repository-table.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/(integrations)/git/components/repository-table.tsx @@ -118,12 +118,12 @@ export default function RepositoryTable() {
{initialized ? ( <> - +
Name - Git URL - + Git URL + @@ -138,8 +138,12 @@ export default function RepositoryTable() { {currentPageRepos?.map(x => { return ( - {x.node.name} - {x.node.gitUrl} + + {x.node.name} + + + {x.node.gitUrl} +
-} diff --git a/ee/tabby-ui/app/(dashboard)/settings/(integrations)/sso/components/oauth-credential-form.tsx b/ee/tabby-ui/app/(dashboard)/settings/(integrations)/sso/components/oauth-credential-form.tsx index 60804429f2f3..9af0c3f8e9e0 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/(integrations)/sso/components/oauth-credential-form.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/(integrations)/sso/components/oauth-credential-form.tsx @@ -10,7 +10,7 @@ import { useClient, useQuery } from 'urql' import * as z from 'zod' import { graphql } from '@/lib/gql/generates' -import { OAuthProvider } from '@/lib/gql/generates/graphql' +import { LicenseType, OAuthProvider } from '@/lib/gql/generates/graphql' import { useMutation } from '@/lib/tabby/gql' import { cn } from '@/lib/utils' import { @@ -39,6 +39,7 @@ import { Input } from '@/components/ui/input' import { Label } from '@/components/ui/label' import { RadioGroup, RadioGroupItem } from '@/components/ui/radio-group' import { CopyButton } from '@/components/copy-button' +import { LicenseGuard } from '@/components/license-guard' import { oauthCredential } from './oauth-credential-list' @@ -324,15 +325,21 @@ export default function OAuthCredentialForm({ )} - )} - {isNew ? 'Create' : 'Update'} - + diff --git a/ee/tabby-ui/app/(dashboard)/settings/(integrations)/sso/components/oauth-credential-list.tsx b/ee/tabby-ui/app/(dashboard)/settings/(integrations)/sso/components/oauth-credential-list.tsx index cde077590bd8..820d4bdbe9d3 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/(integrations)/sso/components/oauth-credential-list.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/(integrations)/sso/components/oauth-credential-list.tsx @@ -2,17 +2,20 @@ import React from 'react' import Link from 'next/link' +import { useRouter } from 'next/navigation' import { compact, find } from 'lodash-es' import { useQuery } from 'urql' import { graphql } from '@/lib/gql/generates' import { + LicenseType, OAuthCredentialQuery, OAuthProvider } from '@/lib/gql/generates/graphql' -import { buttonVariants } from '@/components/ui/button' +import { Button, buttonVariants } from '@/components/ui/button' import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card' import { Skeleton } from '@/components/ui/skeleton' +import { LicenseGuard } from '@/components/license-guard' import { PROVIDER_METAS } from './constant' import { SSOHeader } from './sso-header' @@ -43,6 +46,20 @@ const OAuthCredentialList = () => { return compact([githubData?.oauthCredential, googleData?.oauthCredential]) }, [githubData, googleData]) + const router = useRouter() + const createButton = ( + + {({ hasValidLicense }) => ( + + )} + + ) + if (!credentialList?.length) { return (
@@ -80,14 +97,7 @@ const OAuthCredentialList = () => { })}
{credentialList.length < 2 && ( -
- - Create - -
+
{createButton}
)} ) diff --git a/ee/tabby-ui/app/(dashboard)/settings/general/components/general.tsx b/ee/tabby-ui/app/(dashboard)/settings/general/components/general.tsx index e3d8e1c427df..f16fecf8b22f 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/general/components/general.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/general/components/general.tsx @@ -10,29 +10,23 @@ import { GeneralNetworkForm } from './network-form' import { GeneralSecurityForm } from './security-form' export default function General() { - // todo usequery - const [initialized, setInitialized] = React.useState(false) React.useEffect(() => { setTimeout(() => { - // get data from query and then setInitialized setInitialized(true) }, 500) }, []) - // makes it convenient to set the defaultValues of forms if (!initialized) return return (
- {/* todo pass defualtValues from useQuery */} - {/* todo pass defualtValues from useQuery */}
diff --git a/ee/tabby-ui/app/(dashboard)/settings/general/components/network-form.tsx b/ee/tabby-ui/app/(dashboard)/settings/general/components/network-form.tsx index 508208c8fd45..60d0eaa2bdf4 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/general/components/network-form.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/general/components/network-form.tsx @@ -5,10 +5,10 @@ import { zodResolver } from '@hookform/resolvers/zod' import { isEmpty } from 'lodash-es' import { useForm } from 'react-hook-form' import { toast } from 'sonner' -import { useQuery } from 'urql' import * as z from 'zod' import { graphql } from '@/lib/gql/generates' +import { useNetworkSetting } from '@/lib/hooks/use-network-setting' import { useMutation } from '@/lib/tabby/gql' import { Button } from '@/components/ui/button' import { @@ -28,14 +28,6 @@ const updateNetworkSettingMutation = graphql(/* GraphQL */ ` } `) -export const networkSetting = graphql(/* GraphQL */ ` - query NetworkSetting { - networkSetting { - externalUrl - } - } -`) - const formSchema = z.object({ externalUrl: z.string() }) @@ -123,10 +115,12 @@ const NetworkForm: React.FC = ({ } export const GeneralNetworkForm = () => { - const [{ data: data }] = useQuery({ query: networkSetting }) + const [{ data }, reexecuteQuery] = useNetworkSetting() const onSuccess = () => { toast.success('Network configuration is updated') + reexecuteQuery() } + return ( data && ( diff --git a/ee/tabby-ui/app/(dashboard)/settings/general/components/security-form.tsx b/ee/tabby-ui/app/(dashboard)/settings/general/components/security-form.tsx index 7ba4bc5aa4eb..def6dc57282e 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/general/components/security-form.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/general/components/security-form.tsx @@ -9,6 +9,7 @@ import { useQuery } from 'urql' import * as z from 'zod' import { graphql } from '@/lib/gql/generates' +import { LicenseType } from '@/lib/gql/generates/graphql' import { useMutation } from '@/lib/tabby/gql' import { cn } from '@/lib/utils' import { Button } from '@/components/ui/button' @@ -24,6 +25,7 @@ import { } from '@/components/ui/form' import { IconTrash } from '@/components/ui/icons' import { Input } from '@/components/ui/input' +import { LicenseGuard } from '@/components/license-guard' const updateSecuritySettingMutation = graphql(/* GraphQL */ ` mutation updateSecuritySetting($input: SecuritySettingInput!) { @@ -64,15 +66,9 @@ const SecurityForm: React.FC = ({ onSuccess, defaultValues: propsDefaultValues }) => { - const defaultValues = React.useMemo(() => { - return { - ...(propsDefaultValues || {}) - } - }, [propsDefaultValues]) - const form = useForm>({ resolver: zodResolver(formSchema), - defaultValues + defaultValues: propsDefaultValues }) const { fields, append, remove, update } = useFieldArray({ @@ -201,9 +197,15 @@ const SecurityForm: React.FC = ({
- + + {({ hasValidLicense }) => { + return ( + + ) + }} +
@@ -227,9 +229,10 @@ function buildListValuesFromField(fieldListValue?: Array<{ value: string }>) { } export const GeneralSecurityForm = () => { - const [{ data }] = useQuery({ query: securitySetting }) + const [{ data }, reexecuteQuery] = useQuery({ query: securitySetting }) const onSuccess = () => { toast.success('Security configuration is updated') + reexecuteQuery() } const defaultValues = data && { ...data.securitySetting, diff --git a/ee/tabby-ui/app/(dashboard)/settings/general/page.tsx b/ee/tabby-ui/app/(dashboard)/settings/general/page.tsx index 470dcfc1a493..2361323aa5ae 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/general/page.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/general/page.tsx @@ -7,10 +7,5 @@ export const metadata: Metadata = { } export default function GeneralSettings() { - // todo abstract settings-layout after email was merged - return ( -
- -
- ) + return } diff --git a/ee/tabby-ui/app/(dashboard)/settings/subscription/components/license-form.tsx b/ee/tabby-ui/app/(dashboard)/settings/subscription/components/license-form.tsx index 4e34f89b4c47..f71ce59fdc8b 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/subscription/components/license-form.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/subscription/components/license-form.tsx @@ -3,12 +3,24 @@ import * as React from 'react' import { zodResolver } from '@hookform/resolvers/zod' import { useForm } from 'react-hook-form' +import { toast } from 'sonner' import * as z from 'zod' import { graphql } from '@/lib/gql/generates' import { useMutation } from '@/lib/tabby/gql' import { cn } from '@/lib/utils' -import { Button } from '@/components/ui/button' +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger +} from '@/components/ui/alert-dialog' +import { Button, buttonVariants } from '@/components/ui/button' import { Form, FormControl, @@ -27,6 +39,7 @@ type FormValues = z.infer interface LicenseFormProps extends React.HTMLAttributes { onSuccess?: () => void + canReset?: boolean } const uploadLicenseMutation = graphql(/* GraphQL */ ` @@ -35,9 +48,16 @@ const uploadLicenseMutation = graphql(/* GraphQL */ ` } `) +const resetLicenseMutation = graphql(/* GraphQL */ ` + mutation ResetLicense { + resetLicense + } +`) + export function LicenseForm({ className, onSuccess, + canReset, ...props }: LicenseFormProps) { const form = useForm({ @@ -45,20 +65,47 @@ export function LicenseForm({ }) const license = form.watch('license') const { isSubmitting } = form.formState + const [isReseting, setIsDeleting] = React.useState(false) + const [resetDialogOpen, setResetDialogOpen] = React.useState(false) const uploadLicense = useMutation(uploadLicenseMutation, { form }) + const resetLicense = useMutation(resetLicenseMutation) + const onSubmit = (values: FormValues) => { return uploadLicense(values).then(res => { if (res?.data?.uploadLicense) { form.reset({ license: '' }) + toast.success('License is uploaded') onSuccess?.() } }) } + const onReset: React.MouseEventHandler = e => { + e.preventDefault() + setIsDeleting(true) + resetLicense() + .then(res => { + if (res?.data?.resetLicense) { + setResetDialogOpen(false) + onSuccess?.() + } else if (res?.error) { + toast.error(res.error.message ?? 'reset failed') + } + }) + .finally(() => { + setIsDeleting(false) + }) + } + + const onResetDialogOpenChange = (v: boolean) => { + if (isReseting) return + setResetDialogOpen(v) + } + return (
@@ -79,12 +126,42 @@ export function LicenseForm({ )} /> -
- + + )} + + + Are you absolutely sure? + + This action cannot be undone. It will reset the current + license. + + + + Cancel + + {isReseting && ( + + )} + Yes, reset it + + + + +
) } + +const FeatureList = ({ + name, + features +}: { + name: String + features: Feature[] +}) => { + return ( + <> + + + {name} + + + {features.map(({ name, community, team, enterprise }, i) => ( + + {name} + {community} + {team} + + {enterprise} + + + ))} + + ) +} + +interface Plan { + name: ReactNode | String + pricing: ReactNode | String + limit: ReactNode | String +} + +const PLANS: Plan[] = [ + { + name: 'Community', + pricing: '$0 per user/month', + limit: 'Up to 5 users, single node' + }, + { + name: 'Team', + pricing: '$19 per user/month', + limit: 'Up to 30 users, up to 2 nodes' + }, + { + name: 'Enterprise', + pricing: 'Contact Us', + limit: 'Customized, billed annually' + } +] + +interface Feature { + name: ReactNode | String + community: ReactNode | String + team: ReactNode | String + enterprise: ReactNode | String +} + +const FeatureTooltip = ({ children }: { children: ReactNode }) => ( + + + + + + +

{children}

+
+
+
+) + +const FeatureWithTooltip = ({ + name, + children +}: { + name: string + children: ReactNode +}) => ( + + {name} + {children} + +) + +interface FeatureGroup { + name: String + features: Feature[] +} + +const checked = +const dashed = '–' + +const FEATURES: FeatureGroup[] = [ + { + name: 'Features', + features: [ + { + name: 'User count', + community: 'Up to 5', + team: 'Up to 30', + enterprise: 'Unlimited' + }, + { + name: 'Node count', + community: dashed, + team: 'Up to 2', + enterprise: 'Unlimited' + }, + { + name: 'Secure Access', + community: checked, + team: checked, + enterprise: checked + }, + { + name: ( + + Tabby builds on top of open technologies, allowing customers to + bring their own LLM models. + + ), + community: checked, + team: checked, + enterprise: checked + }, + { + name: ( + + Tabby can retrieve the codebase context to enhance responses. + Underlying Tabby pulls context from git providers with a code search + index. This method enables Tabby to utilize the team's past + practices at scale. + + ), + community: checked, + team: checked, + enterprise: checked + }, + { + name: 'Admin Controls', + community: dashed, + team: checked, + enterprise: checked + }, + { + name: 'Toggle IDE / Extensions telemetry', + community: dashed, + team: dashed, + enterprise: checked + }, + { + name: 'Authentication Domain', + community: dashed, + team: dashed, + enterprise: checked + }, + { + name: 'Single Sign-On (SSO)', + community: dashed, + team: dashed, + enterprise: checked + } + ] + }, + { + name: 'Bespoke', + features: [ + { + name: 'Support', + community: 'Community', + team: 'Email', + enterprise: 'Dedicated Slack channel' + }, + { + name: 'Roadmap prioritization', + community: dashed, + team: dashed, + enterprise: checked + } + ] + } +] diff --git a/ee/tabby-ui/app/(dashboard)/settings/subscription/components/subscription.tsx b/ee/tabby-ui/app/(dashboard)/settings/subscription/components/subscription.tsx index 0e5017e16148..692da3443184 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/subscription/components/subscription.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/subscription/components/subscription.tsx @@ -2,9 +2,9 @@ import { capitalize } from 'lodash-es' import moment from 'moment' -import { useQuery } from 'urql' -import { graphql } from '@/lib/gql/generates' +import { LicenseInfo, LicenseType } from '@/lib/gql/generates/graphql' +import { useLicense } from '@/lib/hooks/use-license' import { Skeleton } from '@/components/ui/skeleton' import LoadingWrapper from '@/components/loading-wrapper' import { SubHeader } from '@/components/sub-header' @@ -12,71 +12,65 @@ import { SubHeader } from '@/components/sub-header' import { LicenseForm } from './license-form' import { LicenseTable } from './license-table' -const getLicenseInfo = graphql(/* GraphQL */ ` - query GetLicenseInfo { - license { - type - status - seats - seatsUsed - issuedAt - expiresAt - } - } -`) - export default function Subscription() { - const [{ data, fetching }, reexecuteQuery] = useQuery({ - query: getLicenseInfo - }) + const [{ data, fetching }, reexecuteQuery] = useLicense() const license = data?.license - const expiresAt = license?.expiresAt - ? moment(license.expiresAt).format('MM/DD/YYYY') - : '-' - const onUploadLicenseSuccess = () => { reexecuteQuery() } - - const seatsText = license ? `${license.seatsUsed} / ${license.seats}` : '-' + const canReset = !!license?.type && license.type !== LicenseType.Community return ( -
- - You can upload your Tabby license to unlock enterprise features. + <> + + You can upload your Tabby license to unlock team/enterprise features.
- - - +
+ + +
} > -
-
-
Expires at
-
{expiresAt}
-
-
-
- Assigned / Total Seats -
-
{seatsText}
-
-
-
Current plan
-
- {capitalize(license?.type ?? 'FREE')} -
-
-
+ {license && }
- - {false && } + + +
+ + ) +} + +function License({ license }: { license: LicenseInfo }) { + const expiresAt = license.expiresAt + ? moment(license.expiresAt).format('MM/DD/YYYY') + : '–' + + const seatsText = `${license.seatsUsed} / ${license.seats}` + + return ( +
+
+
Expires at
+
{expiresAt}
+
+
+
Assigned / Total Seats
+
{seatsText}
+
+
+
Current plan
+
+ {capitalize(license?.type ?? 'Community')} +
) diff --git a/ee/tabby-ui/app/(dashboard)/settings/team/components/invitation-table.tsx b/ee/tabby-ui/app/(dashboard)/settings/team/components/invitation-table.tsx index dc733fee055a..fea134592a2d 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/team/components/invitation-table.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/team/components/invitation-table.tsx @@ -1,6 +1,6 @@ 'use client' -import React, { useEffect, useState } from 'react' +import React from 'react' import moment from 'moment' import { toast } from 'sonner' import { useClient, useQuery } from 'urql' @@ -11,6 +11,7 @@ import { InvitationEdge, ListInvitationsQueryVariables } from '@/lib/gql/generates/graphql' +import { useExternalURL } from '@/lib/hooks/use-network-setting' import { useMutation } from '@/lib/tabby/gql' import { listInvitations } from '@/lib/tabby/query' import { Button } from '@/components/ui/button' @@ -96,10 +97,7 @@ export default function InvitationTable() { } } - const [origin, setOrigin] = useState('') - useEffect(() => { - setOrigin(new URL(window.location.href).origin) - }, []) + const externalUrl = useExternalURL() const deleteInvitation = useMutation(deleteInvitationMutation) @@ -162,7 +160,7 @@ export default function InvitationTable() { )} {currentPageInvits?.map(x => { - const link = `${origin}/auth/signup?invitationCode=${x.node.code}` + const link = `${externalUrl}/auth/signup?invitationCode=${x.node.code}` return ( {x.node.email} diff --git a/ee/tabby-ui/app/(dashboard)/settings/team/components/team.tsx b/ee/tabby-ui/app/(dashboard)/settings/team/components/team.tsx index 6c6cfcc3b039..8a61d65450b8 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/team/components/team.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/team/components/team.tsx @@ -7,24 +7,24 @@ import UsersTable from './user-table' export default function Team() { return ( -
+ <>
- + Pending Invites - +
- + Members - +
-
+ ) } diff --git a/ee/tabby-ui/app/(dashboard)/settings/team/components/user-role-dialog.tsx b/ee/tabby-ui/app/(dashboard)/settings/team/components/user-role-dialog.tsx index 2463f2e6a060..2f97244905c6 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/team/components/user-role-dialog.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/team/components/user-role-dialog.tsx @@ -4,6 +4,7 @@ import React from 'react' import { toast } from 'sonner' import { graphql } from '@/lib/gql/generates/gql' +import { LicenseType } from '@/lib/gql/generates/graphql' import { useMutation } from '@/lib/tabby/gql' import { AlertDialog, @@ -17,6 +18,7 @@ import { } from '@/components/ui/alert-dialog' import { buttonVariants } from '@/components/ui/button' import { IconSpinner } from '@/components/ui/icons' +import { LicenseGuard } from '@/components/license-guard' const updateUserRoleMutation = graphql(/* GraphQL */ ` mutation updateUserRole($id: ID!, $isAdmin: Boolean!) { @@ -40,8 +42,7 @@ export const UpdateUserRoleDialog: React.FC = ({ isPromote }) => { const [isSubmitting, setIsSubmitting] = React.useState(false) - const requestPasswordResetEmail = useMutation(updateUserRoleMutation) - + const updateUserRole = useMutation(updateUserRoleMutation) const onSubmit: React.MouseEventHandler = async e => { e.preventDefault() @@ -50,13 +51,15 @@ export const UpdateUserRoleDialog: React.FC = ({ return } setIsSubmitting(true) - return requestPasswordResetEmail({ + return updateUserRole({ id: user.id, isAdmin: !!isPromote }) .then(res => { if (res?.data?.updateUserRole) { onSuccess?.() + } else if (res?.error) { + toast.error(res.error?.message ?? 'update failed') } }) .finally(() => { @@ -81,16 +84,20 @@ export const UpdateUserRoleDialog: React.FC = ({ Cancel - - {isSubmitting && ( - + + {({ hasValidLicense }) => ( + + {isSubmitting && ( + + )} + Confirm + )} - Confirm - + diff --git a/ee/tabby-ui/app/layout.tsx b/ee/tabby-ui/app/layout.tsx index f7b571a1090e..87e4fba747d7 100644 --- a/ee/tabby-ui/app/layout.tsx +++ b/ee/tabby-ui/app/layout.tsx @@ -38,7 +38,7 @@ export default function RootLayout({ children }: RootLayoutProps) { >
{children}
- +
diff --git a/ee/tabby-ui/components/license-guard.tsx b/ee/tabby-ui/components/license-guard.tsx new file mode 100644 index 000000000000..5da745e80c59 --- /dev/null +++ b/ee/tabby-ui/components/license-guard.tsx @@ -0,0 +1,78 @@ +import * as React from 'react' +import Link from 'next/link' +import { capitalize } from 'lodash-es' + +import { + GetLicenseInfoQuery, + LicenseStatus, + LicenseType +} from '@/lib/gql/generates/graphql' +import { useLicenseInfo } from '@/lib/hooks/use-license' +import { cn } from '@/lib/utils' +import { buttonVariants } from '@/components/ui/button' +import { + HoverCard, + HoverCardContent, + HoverCardTrigger +} from '@/components/ui/hover-card' + +interface LicenseGuardProps { + licenses: LicenseType[] + children: (params: { + hasValidLicense: boolean + license: GetLicenseInfoQuery['license'] | undefined | null + }) => React.ReactNode +} + +const LicenseGuard: React.FC = ({ licenses, children }) => { + const [open, setOpen] = React.useState(false) + const license = useLicenseInfo() + const hasValidLicense = + !!license && + license.status === LicenseStatus.Ok && + licenses.includes(license.type) + + const onOpenChange = (v: boolean) => { + if (hasValidLicense) return + setOpen(v) + } + + let licenseString = capitalize(licenses[0]) + let licenseText = licenseString + if (licenses.length == 2) { + licenseText = `${capitalize(licenses[0])} or ${capitalize(licenses[1])}` + } + + return ( + + +
+ This feature is only available on Tabby's{' '} + {licenseText} plan. Upgrade to + use this feature. +
+
+ + Upgrade to {licenseString} + +
+
+ { + if (!hasValidLicense) { + e.preventDefault() + onOpenChange(true) + } + }} + > +
+ {children({ hasValidLicense, license })} +
+
+
+ ) +} +LicenseGuard.displayName = 'LicenseGuard' + +export { LicenseGuard } diff --git a/ee/tabby-ui/components/sub-header.tsx b/ee/tabby-ui/components/sub-header.tsx index be98fdfc16c7..9a3f217d7f2a 100644 --- a/ee/tabby-ui/components/sub-header.tsx +++ b/ee/tabby-ui/components/sub-header.tsx @@ -6,11 +6,13 @@ import { IconExternalLink } from '@/components/ui/icons' interface SubHeaderProps extends React.HTMLAttributes { externalLink?: string + externalLinkText?: string } export const SubHeader: React.FC = ({ className, externalLink, + externalLinkText = 'Learn more', children }) => { return ( @@ -23,8 +25,8 @@ export const SubHeader: React.FC = ({ href={externalLink} target="_blank" > - Learn more - + {externalLinkText} + )}
diff --git a/ee/tabby-ui/components/ui/hover-card.tsx b/ee/tabby-ui/components/ui/hover-card.tsx new file mode 100644 index 000000000000..0ec79708ab32 --- /dev/null +++ b/ee/tabby-ui/components/ui/hover-card.tsx @@ -0,0 +1,29 @@ +'use client' + +import * as React from 'react' +import * as HoverCardPrimitive from '@radix-ui/react-hover-card' + +import { cn } from '@/lib/utils' + +const HoverCard = HoverCardPrimitive.Root + +const HoverCardTrigger = HoverCardPrimitive.Trigger + +const HoverCardContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, align = 'center', sideOffset = 4, ...props }, ref) => ( + +)) +HoverCardContent.displayName = HoverCardPrimitive.Content.displayName + +export { HoverCard, HoverCardTrigger, HoverCardContent } diff --git a/ee/tabby-ui/components/ui/icons.tsx b/ee/tabby-ui/components/ui/icons.tsx index 983232935143..f645777a0ec1 100644 --- a/ee/tabby-ui/components/ui/icons.tsx +++ b/ee/tabby-ui/components/ui/icons.tsx @@ -244,12 +244,17 @@ function IconUser({ className, ...props }: React.ComponentProps<'svg'>) { return ( - + + ) } @@ -412,12 +417,16 @@ function IconCheck({ className, ...props }: React.ComponentProps<'svg'>) { return ( - + ) } @@ -440,12 +449,17 @@ function IconClose({ className, ...props }: React.ComponentProps<'svg'>) { return ( - + + ) } @@ -845,17 +859,18 @@ function IconBackpack({ className, ...props }: React.ComponentProps<'svg'>) { function IconGear({ className, ...props }: React.ComponentProps<'svg'>) { return ( - + + ) } diff --git a/ee/tabby-ui/components/user-panel.tsx b/ee/tabby-ui/components/user-panel.tsx index 693109abe951..ccdf70508207 100644 --- a/ee/tabby-ui/components/user-panel.tsx +++ b/ee/tabby-ui/components/user-panel.tsx @@ -1,8 +1,9 @@ import React from 'react' import NiceAvatar, { genConfig } from 'react-nice-avatar' +import { useMe } from '@/lib/hooks/use-me' import { useIsChatEnabled } from '@/lib/hooks/use-server-info' -import { useAuthenticatedSession, useSignOut } from '@/lib/tabby/auth' +import { useSignOut } from '@/lib/tabby/auth' import { DropdownMenu, DropdownMenuContent, @@ -15,9 +16,9 @@ import { import { IconBackpack, IconChat, IconCode, IconLogout } from './ui/icons' export default function UserPanel() { - const user = useAuthenticatedSession() const signOut = useSignOut() - + const [{ data }] = useMe() + const user = data?.me const isChatEnabled = useIsChatEnabled() if (!user) { diff --git a/ee/tabby-ui/lib/hooks/use-license.ts b/ee/tabby-ui/lib/hooks/use-license.ts new file mode 100644 index 000000000000..03c333fa284b --- /dev/null +++ b/ee/tabby-ui/lib/hooks/use-license.ts @@ -0,0 +1,27 @@ +import { useQuery } from 'urql' + +import { graphql } from '../gql/generates' + +const getLicenseInfo = graphql(/* GraphQL */ ` + query GetLicenseInfo { + license { + type + status + seats + seatsUsed + issuedAt + expiresAt + } + } +`) + +const useLicense = () => { + return useQuery({ query: getLicenseInfo }) +} + +const useLicenseInfo = () => { + const [{ data }] = useLicense() + return data?.license +} + +export { getLicenseInfo, useLicense, useLicenseInfo } diff --git a/ee/tabby-ui/lib/hooks/use-me.ts b/ee/tabby-ui/lib/hooks/use-me.ts new file mode 100644 index 000000000000..43ffa92b60b4 --- /dev/null +++ b/ee/tabby-ui/lib/hooks/use-me.ts @@ -0,0 +1,20 @@ +import { useQuery } from 'urql' + +import { graphql } from '@/lib/gql/generates' + +const meQuery = graphql(/* GraphQL */ ` + query MeQuery { + me { + authToken + email + isAdmin + isPasswordSet + } + } +`) + +const useMe = () => { + return useQuery({ query: meQuery }) +} + +export { useMe } diff --git a/ee/tabby-ui/lib/hooks/use-network-setting.tsx b/ee/tabby-ui/lib/hooks/use-network-setting.tsx new file mode 100644 index 000000000000..d7fde046e265 --- /dev/null +++ b/ee/tabby-ui/lib/hooks/use-network-setting.tsx @@ -0,0 +1,36 @@ +import React from 'react' +import { useQuery } from 'urql' + +import { graphql } from '../gql/generates' +import { isClientSide } from '../utils' + +const networkSettingQuery = graphql(/* GraphQL */ ` + query NetworkSetting { + networkSetting { + externalUrl + } + } +`) + +const useNetworkSetting = () => { + return useQuery({ query: networkSettingQuery }) +} + +const useExternalURL = () => { + const [{ data }] = useNetworkSetting() + const networkSetting = data?.networkSetting + const externalUrl = React.useMemo(() => { + return networkSetting?.externalUrl || getOrigin() + }, [networkSetting]) + + return externalUrl +} + +function getOrigin() { + if (isClientSide()) { + return new URL(window.location.href).origin + } + return '' +} + +export { useNetworkSetting, useExternalURL } diff --git a/ee/tabby-ui/lib/tabby/auth.tsx b/ee/tabby-ui/lib/tabby/auth.tsx index 99ccc8626287..3c0449098ce1 100644 --- a/ee/tabby-ui/lib/tabby/auth.tsx +++ b/ee/tabby-ui/lib/tabby/auth.tsx @@ -6,6 +6,7 @@ import useLocalStorage from 'use-local-storage' import { graphql } from '@/lib/gql/generates' import { isClientSide } from '@/lib/utils' +import { useMe } from '../hooks/use-me' import { useIsAdminInitialized } from '../hooks/use-server-info' interface AuthData { @@ -150,6 +151,7 @@ const AuthProvider: React.FunctionComponent = ({ status: 'loading', data: undefined }) + const [, reexecuteQueryMe] = useMe() React.useEffect(() => { initialized.current = true @@ -166,6 +168,7 @@ const AuthProvider: React.FunctionComponent = ({ // After being mounted, listen for changes in the access token if (authToken?.accessToken && authToken?.refreshToken) { dispatch({ type: AuthActionType.SignIn, data: authToken }) + reexecuteQueryMe() } else { dispatch({ type: AuthActionType.SignOut }) } @@ -174,12 +177,11 @@ const AuthProvider: React.FunctionComponent = ({ const session: Session = React.useMemo(() => { if (authState?.status == 'authenticated') { try { - const { sub, is_admin } = jwtDecode( + const { is_admin } = jwtDecode( authState.data.accessToken ) return { data: { - email: sub!, isAdmin: is_admin, accessToken: authState.data.accessToken }, @@ -260,7 +262,6 @@ function useSignOut(): () => Promise { } interface User { - email: string isAdmin: boolean accessToken: string } diff --git a/ee/tabby-ui/package.json b/ee/tabby-ui/package.json index 1436b87b1e47..a855a9058c96 100644 --- a/ee/tabby-ui/package.json +++ b/ee/tabby-ui/package.json @@ -30,6 +30,7 @@ "@radix-ui/react-collapsible": "^1.0.3", "@radix-ui/react-dialog": "1.0.4", "@radix-ui/react-dropdown-menu": "^2.0.5", + "@radix-ui/react-hover-card": "^1.0.7", "@radix-ui/react-label": "^2.0.2", "@radix-ui/react-popover": "^1.0.7", "@radix-ui/react-radio-group": "^1.1.3", diff --git a/ee/tabby-ui/yarn.lock b/ee/tabby-ui/yarn.lock index b8f6718120e4..f54ef9b752b6 100644 --- a/ee/tabby-ui/yarn.lock +++ b/ee/tabby-ui/yarn.lock @@ -2017,6 +2017,22 @@ "@radix-ui/react-primitive" "1.0.3" "@radix-ui/react-use-callback-ref" "1.0.1" +"@radix-ui/react-hover-card@^1.0.7": + version "1.0.7" + resolved "https://registry.yarnpkg.com/@radix-ui/react-hover-card/-/react-hover-card-1.0.7.tgz#684bca2504432566357e7157e087051aa3577948" + integrity sha512-OcUN2FU0YpmajD/qkph3XzMcK/NmSk9hGWnjV68p6QiZMgILugusgQwnLSDs3oFSJYGKf3Y49zgFedhGh04k9A== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/primitive" "1.0.1" + "@radix-ui/react-compose-refs" "1.0.1" + "@radix-ui/react-context" "1.0.1" + "@radix-ui/react-dismissable-layer" "1.0.5" + "@radix-ui/react-popper" "1.1.3" + "@radix-ui/react-portal" "1.0.4" + "@radix-ui/react-presence" "1.0.1" + "@radix-ui/react-primitive" "1.0.3" + "@radix-ui/react-use-controllable-state" "1.0.1" + "@radix-ui/react-id@1.0.1": version "1.0.1" resolved "https://registry.yarnpkg.com/@radix-ui/react-id/-/react-id-1.0.1.tgz#73cdc181f650e4df24f0b6a5b7aa426b912c88c0" diff --git a/ee/tabby-webserver/graphql/schema.graphql b/ee/tabby-webserver/graphql/schema.graphql index a02ce2ed1673..e791ad1f0b2f 100644 --- a/ee/tabby-webserver/graphql/schema.graphql +++ b/ee/tabby-webserver/graphql/schema.graphql @@ -17,6 +17,7 @@ type Mutation { requestInvitationEmail(input: RequestInvitationInput!): Invitation! requestPasswordResetEmail(input: RequestPasswordResetEmailInput!): Boolean! passwordReset(input: PasswordResetInput!): Boolean! + passwordChange(input: PasswordChangeInput!): Boolean! resetUserAuthToken: Boolean! updateUserActive(id: ID!, active: Boolean!): Boolean! updateUserRole(id: ID!, isAdmin: Boolean!): Boolean! @@ -37,6 +38,7 @@ type Mutation { updateNetworkSetting(input: NetworkSettingInput!): Boolean! deleteEmailSetting: Boolean! uploadLicense(license: String!): Boolean! + resetLicense: Boolean! } type RepositoryEdge { @@ -74,7 +76,7 @@ type Query { oauthCredential(provider: OAuthProvider!): OAuthCredential oauthCallbackUrl(provider: OAuthProvider!): String! serverInfo: ServerInfo! - license: LicenseInfo + license: LicenseInfo! } input NetworkSettingInput { @@ -134,8 +136,8 @@ type LicenseInfo { status: LicenseStatus! seats: Int! seatsUsed: Int! - issuedAt: DateTimeUtc! - expiresAt: DateTimeUtc! + issuedAt: DateTimeUtc + expiresAt: DateTimeUtc } input EmailSettingInput { @@ -154,7 +156,9 @@ input SecuritySettingInput { } enum LicenseType { + COMMUNITY TEAM + ENTERPRISE } type SecuritySetting { @@ -194,6 +198,12 @@ input PasswordResetInput { password2: String! } +input PasswordChangeInput { + oldPassword: String + newPassword1: String! + newPassword2: String! +} + type Invitation { id: ID! email: String! @@ -218,6 +228,7 @@ type User { authToken: String! createdAt: DateTimeUtc! active: Boolean! + isPasswordSet: Boolean! } type Worker { diff --git a/ee/tabby-webserver/keys/license.key.pub b/ee/tabby-webserver/keys/license.key.pub index e009af8ef665..2a44dc4c4cd5 100644 --- a/ee/tabby-webserver/keys/license.key.pub +++ b/ee/tabby-webserver/keys/license.key.pub @@ -1,14 +1,14 @@ -----BEGIN PUBLIC KEY----- -MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA4WKYjEErVACA8sNQ6gGL -+9KUatfl+nJ74xR/+2ayrk4bQtVWgGvm4cGuc2V60aJ11BdXOEcyt95mO8n8+FRe -Y/fPkW22QDyqG7PXAt5bT4zuLXrrwvCrhB6QRWScRUaZv3jzCoclBu2fOxJxqbJo -Xx0pkXFct5viT3yfqv/+C5QQ7gPexUPEXqYRQuU4hqeVXtkhkfRA0DTWtOXnf0mU -4rJztxvkQiSAI8nufX01h73FrICntEaFGvLQnLR0VGVjlACZEmA3Nldvoq+Yt2zR -mRREXv3ks+PTnaAORoYnrnB+PoVMw9SkGUzA61CqvJoxKrbZfYmODglTlJh91UiF -lH7DXd7GK6iwHtO6dumAVaiIYqfPpJn0PExaqjtXHzKCRozbLPwIF+ECbq5vxjBq -hfWO/uqhiOqusRCoA4E8UHu8BmRC2s4Kn3QI/qOHKovCq72Hy0YL3/trYtJfJ3cz -sVyytP8tmoG3CGLjM80aaXpvpr87GCN07uKJmNgr00EQvxEK8CEMO7EbkNq7AVCY -awBC5tTDt7UzKyam8c97LuNUyWsI2H9FattHHDcRzA+HBTZ8FyZyI+m7zd2WdITF -tzpxm3mWo6MeQH3hq9prDiXwbyowXK/U0ZLK1s/WhFU5dxCnkTpZI3X5gLyGsEPD -FI05GFZrbWnOWCLAP0FKZAMCAwEAAQ== +MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAsmCQRK2UeAoIt7qnQlID +Aa8ih5fCRAhNAEFaWb9hHgg80q9M6F/0yPnhKxt78+3Lxz1jwZ5aZYaSxnZEwAx9 +3X1Zs7x+Our6VJmG1WDGqMkLWcSZjf8fsbH7TTMXLTBnU/nIkFKzNZInccu8CxsH +Im6Qxr81VFFHmZG9dsUQb2/fuA5Ck/UNvbbSipo+qfs5vJ6UP2CghkhUREZhK5Yb +s4wJ3AhUZ+uKzxm+73bjFTnYZ32IKR2h5WfroLK9wOuJ3hXuyK96Ka5DaUb3i3Ni +aFQiSJ2E5W8tjQYWEBth484v7kBjoamTbVP/jrcUUfmE2TAeEi3L/AU8QRlAJ575 +1Ujny4K+8PzzG4tJYjqXSI3BijaeUNwxFcL07yStj6xH9Fz/pLKV0RtEtPgNtXwq +6WlcYF1GqlNFinkTH+pFi9vFydxJ2N3HayR6dDq1r17Lf9HPmezMnRXde01wwVQD +Or77KGuIm6sH3cusW1IR+X2ZCT3UAQ1FRqpL9TMWY/LLl19Q1w3F/RYP/ZRvFsVl +aDTWyOqlG1XVU7uAR74gYirr73Rv/8pZ2453ZaVYjL73ZAM8X85Kh4xRstD6SjKG +A4WRGjiVsTSGRxk81wXNPLu3fnnj72gUkEgEWG+9odQSIrHtZSVCMxfL1Rc0MfK/ +rVIU8J/jzVyNYUKaVrrICt0CAwEAAQ== -----END PUBLIC KEY----- diff --git a/ee/tabby-webserver/src/oauth/mod.rs b/ee/tabby-webserver/src/oauth/mod.rs index 1a83c4e69086..c79beee6725e 100644 --- a/ee/tabby-webserver/src/oauth/mod.rs +++ b/ee/tabby-webserver/src/oauth/mod.rs @@ -120,7 +120,7 @@ async fn google_oauth_handler( Query(param): Query, ) -> Redirect { if !param.error.is_empty() { - return make_error_redirect(OAuthProvider::Google, ¶m.error); + return make_error_redirect(OAuthProvider::Google, param.error); } match_auth_result( OAuthProvider::Google, @@ -140,26 +140,23 @@ fn match_auth_result( ); Redirect::temporary(&uri) } - Err(OAuthError::InvalidVerificationCode) => { - make_error_redirect(provider, "Invalid oauth code") - } - Err(OAuthError::CredentialNotActive) => { - make_error_redirect(provider, "OAuth is not enabled") - } - Err(OAuthError::UserNotInvited) => make_error_redirect( - provider, - "User is not invited, please contact your admin for help", - ), - Err(e) => { - error!("Failed to authenticate: {:?}", e); - make_error_redirect(provider, "Unknown error") - } + Err(err) => match err { + OAuthError::InvalidVerificationCode + | OAuthError::UserNotInvited + | OAuthError::UserDisabled + | OAuthError::CredentialNotActive + | OAuthError::Unknown => make_error_redirect(provider, err.to_string()), + OAuthError::Other(e) => { + error!("Failed to authenticate: {:?}", e); + make_error_redirect(provider, OAuthError::Unknown.to_string()) + } + }, } } -fn make_error_redirect(provider: OAuthProvider, message: &str) -> Redirect { +fn make_error_redirect(provider: OAuthProvider, message: String) -> Redirect { let query = querystring::stringify(vec![ - ("error_message", urlencoding::encode(message).as_ref()), + ("error_message", urlencoding::encode(&message).as_ref()), ( "provider", serde_json::to_string(&provider).unwrap().as_str(), diff --git a/ee/tabby-webserver/src/repositories/resolve.rs b/ee/tabby-webserver/src/repositories/resolve.rs index 8ef941e0cbd5..29e01189b428 100644 --- a/ee/tabby-webserver/src/repositories/resolve.rs +++ b/ee/tabby-webserver/src/repositories/resolve.rs @@ -58,7 +58,7 @@ impl RepositoryCache { .collect(); let mut repository_lookup = self.repository_lookup.write().unwrap(); debug!("Reloading repositoriy metadata..."); - *repository_lookup = load_meta(&new_repositories); + *repository_lookup = load_meta(new_repositories); Ok(()) } @@ -149,7 +149,7 @@ impl From for RepositoryMeta { } } -fn load_meta(repositories: &Vec) -> HashMap { +fn load_meta(repositories: Vec) -> HashMap { let mut dataset = HashMap::new(); // Construct map of String -> &RepositoryConfig for lookup let repo_conf = repositories diff --git a/ee/tabby-webserver/src/schema/auth.rs b/ee/tabby-webserver/src/schema/auth.rs index d5e3f2e39dc9..0e595b795d35 100644 --- a/ee/tabby-webserver/src/schema/auth.rs +++ b/ee/tabby-webserver/src/schema/auth.rs @@ -162,16 +162,16 @@ pub struct OAuthResponse { #[derive(Error, Debug)] pub enum OAuthError { - #[error("The code passed is incorrect or expired")] + #[error("The oauth code passed is incorrect or expired")] InvalidVerificationCode, - #[error("The credential is not active")] + #[error("OAuth is not enabled")] CredentialNotActive, - #[error("The user is not invited to access the system")] + #[error("User is not invited, please contact admin for help")] UserNotInvited, - #[error("User is disabled")] + #[error("User is disabled, please contact admin for help")] UserDisabled, #[error(transparent)] @@ -202,7 +202,12 @@ impl RefreshTokenResponse { } } -#[derive(Debug, Default, Serialize, Deserialize)] +// IDWrapper to used as a type guard for refactoring, can be removed in a follow up PR. +// FIXME(meng): refactor out IDWrapper. +#[derive(Serialize, Deserialize, Debug)] +pub struct IDWrapper(pub ID); + +#[derive(Debug, Serialize, Deserialize)] pub struct JWTPayload { /// Expiration time (as UTC timestamp) exp: i64, @@ -210,20 +215,20 @@ pub struct JWTPayload { /// Issued at (as UTC timestamp) iat: i64, - /// User email address - pub sub: String, + /// User id string + pub sub: IDWrapper, /// Whether the user is admin. pub is_admin: bool, } impl JWTPayload { - pub fn new(email: String, is_admin: bool) -> Self { + pub fn new(id: ID, is_admin: bool) -> Self { let now = jwt::get_current_timestamp(); Self { iat: now as i64, exp: (now + *JWT_DEFAULT_EXP) as i64, - sub: email, + sub: IDWrapper(id), is_admin, } } @@ -239,6 +244,7 @@ pub struct User { pub auth_token: String, pub created_at: DateTime, pub active: bool, + pub is_password_set: bool, } impl relay::NodeType for User { @@ -301,6 +307,39 @@ pub struct PasswordResetInput { pub password2: String, } +#[derive(Validate, GraphQLInputObject)] +pub struct PasswordChangeInput { + pub old_password: Option, + + #[validate(length( + min = 8, + code = "newPassword1", + message = "Password must be at least 8 characters" + ))] + #[validate(length( + max = 20, + code = "newPassword1", + message = "Password must be at most 20 characters" + ))] + pub new_password1: String, + #[validate(length( + min = 8, + code = "newPassword2", + message = "Password must be at least 8 characters" + ))] + #[validate(length( + max = 20, + code = "newPassword2", + message = "Password must be at most 20 characters" + ))] + #[validate(must_match( + code = "newPassword2", + message = "Passwords do not match", + other = "new_password1" + ))] + pub new_password2: String, +} + #[derive(Debug, Serialize, Deserialize, GraphQLObject)] #[graphql(context = Context)] pub struct Invitation { @@ -378,14 +417,21 @@ pub trait AuthenticationService: Send + Sync { async fn verify_access_token(&self, access_token: &str) -> Result; async fn is_admin_initialized(&self) -> Result; async fn get_user_by_email(&self, email: &str) -> Result; + async fn get_user(&self, id: &ID) -> Result; async fn create_invitation(&self, email: String) -> Result; async fn request_invitation_email(&self, input: RequestInvitationInput) -> Result; async fn delete_invitation(&self, id: &ID) -> Result; - async fn reset_user_auth_token(&self, email: &str) -> Result<()>; + async fn reset_user_auth_token(&self, id: &ID) -> Result<()>; async fn password_reset(&self, code: &str, password: &str) -> Result<()>; async fn request_password_reset_email(&self, email: String) -> Result>>; + async fn update_user_password( + &self, + id: &ID, + old_password: Option<&str>, + new_password: &str, + ) -> Result<()>; async fn list_users( &self, @@ -460,7 +506,7 @@ mod tests { use super::*; #[test] fn test_generate_jwt() { - let claims = JWTPayload::new("test".to_string(), false); + let claims = JWTPayload::new(ID::from("test".to_owned()), false); let token = generate_jwt(claims).unwrap(); assert!(!token.is_empty()) @@ -468,10 +514,10 @@ mod tests { #[test] fn test_validate_jwt() { - let claims = JWTPayload::new("test".to_string(), false); + let claims = JWTPayload::new(ID::from("test".to_owned()), false); let token = generate_jwt(claims).unwrap(); let claims = validate_jwt(&token).unwrap(); - assert_eq!(claims.sub, "test"); + assert_eq!(claims.sub.0.to_string(), "test"); assert!(!claims.is_admin); } diff --git a/ee/tabby-webserver/src/schema/license.rs b/ee/tabby-webserver/src/schema/license.rs index 2f5e73d37d61..05551432a736 100644 --- a/ee/tabby-webserver/src/schema/license.rs +++ b/ee/tabby-webserver/src/schema/license.rs @@ -1,17 +1,22 @@ +use std::error::Error; + use async_trait::async_trait; use chrono::{DateTime, Utc}; use juniper::{GraphQLEnum, GraphQLObject}; use serde::Deserialize; +use super::CoreError; use crate::schema::Result; -#[derive(Debug, Deserialize, GraphQLEnum)] +#[derive(Debug, Deserialize, GraphQLEnum, PartialEq)] #[serde(rename_all = "UPPERCASE")] pub enum LicenseType { + Community, Team, + Enterprise, } -#[derive(GraphQLEnum, PartialEq, Debug)] +#[derive(GraphQLEnum, PartialEq, Debug, Clone)] pub enum LicenseStatus { Ok, Expired, @@ -24,12 +29,99 @@ pub struct LicenseInfo { pub status: LicenseStatus, pub seats: i32, pub seats_used: i32, - pub issued_at: DateTime, - pub expires_at: DateTime, + pub issued_at: Option>, + pub expires_at: Option>, +} + +impl LicenseInfo { + pub fn seat_limits_for_community_license() -> usize { + 5 + } + + pub fn seat_limits_for_team_license() -> usize { + 30 + } + + pub fn check_node_limit(&self, num_nodes: usize) -> bool { + match self.r#type { + LicenseType::Community => false, + LicenseType::Team => num_nodes <= 2, + LicenseType::Enterprise => true, + } + } + + pub fn guard_seat_limit(mut self) -> Self { + let seats = self.seats as usize; + self.seats = match self.r#type { + LicenseType::Community => { + std::cmp::min(seats, Self::seat_limits_for_community_license()) + } + LicenseType::Team => std::cmp::min(seats, Self::seat_limits_for_team_license()), + LicenseType::Enterprise => seats, + } as i32; + + self + } + + pub fn ensure_available_seats(&self, num_new_seats: usize) -> Result<()> { + self.ensure_valid_license()?; + if (self.seats_used as usize + num_new_seats) > self.seats as usize { + return Err(CoreError::InvalidLicense( + "No sufficient seats under current license", + )); + } + Ok(()) + } + + pub fn ensure_admin_seats(&self, num_admins: usize) -> Result<()> { + self.ensure_valid_license()?; + let num_admin_seats = match self.r#type { + LicenseType::Community => 1, + LicenseType::Team => 3, + LicenseType::Enterprise => usize::MAX, + }; + + if num_admins > num_admin_seats { + return Err(CoreError::InvalidLicense( + "No sufficient admin seats under the license", + )); + } + + Ok(()) + } } #[async_trait] pub trait LicenseService: Send + Sync { - async fn read_license(&self) -> Result>; + async fn read_license(&self) -> Result; async fn update_license(&self, license: String) -> Result<()>; + async fn reset_license(&self) -> Result<()>; +} + +pub trait IsLicenseValid { + fn ensure_valid_license(&self) -> Result<()>; +} + +impl IsLicenseValid for LicenseInfo { + fn ensure_valid_license(&self) -> Result<()> { + match self.status { + LicenseStatus::Expired => Err(CoreError::InvalidLicense( + "Your enterprise license is expired", + )), + LicenseStatus::SeatsExceeded => Err(CoreError::InvalidLicense( + "You have more active users than seats included in your license", + )), + LicenseStatus::Ok => Ok(()), + } + } +} + +impl IsLicenseValid for std::result::Result { + fn ensure_valid_license(&self) -> Result<()> { + if let Ok(x) = self { + x.ensure_valid_license() + } else { + Err(CoreError::InvalidLicense("No valid license configured")) + } + } } diff --git a/ee/tabby-webserver/src/schema/mod.rs b/ee/tabby-webserver/src/schema/mod.rs index f7ff602d7d67..d815e30b4096 100644 --- a/ee/tabby-webserver/src/schema/mod.rs +++ b/ee/tabby-webserver/src/schema/mod.rs @@ -27,18 +27,18 @@ use validator::{Validate, ValidationErrors}; use worker::{Worker, WorkerService}; use self::{ - auth::{PasswordResetInput, RequestPasswordResetEmailInput, UpdateOAuthCredentialInput}, + auth::{ + PasswordChangeInput, PasswordResetInput, RequestInvitationInput, + RequestPasswordResetEmailInput, UpdateOAuthCredentialInput, + }, email::{EmailService, EmailSetting, EmailSettingInput}, - license::{LicenseInfo, LicenseService}, - repository::RepositoryService, + license::{IsLicenseValid, LicenseInfo, LicenseService, LicenseType}, + repository::{Repository, RepositoryService}, setting::{ NetworkSetting, NetworkSettingInput, SecuritySetting, SecuritySettingInput, SettingService, }, }; -use crate::schema::{ - auth::{JWTPayload, OAuthCredential, OAuthProvider, RequestInvitationInput}, - repository::Repository, -}; +use crate::schema::auth::{JWTPayload, OAuthCredential, OAuthProvider}; pub trait ServiceLocator: Send + Sync { fn auth(&self) -> Arc; @@ -74,8 +74,8 @@ pub enum CoreError { #[error("{0}")] Forbidden(&'static str), - #[error("Invalid ID Error")] - InvalidIDError, + #[error("Invalid ID")] + InvalidID, #[error("Invalid input parameters")] InvalidInput(#[from] ValidationErrors), @@ -83,6 +83,9 @@ pub enum CoreError { #[error("Email is not configured")] EmailNotConfigured, + #[error("{0}")] + InvalidLicense(&'static str), + #[error(transparent)] Other(#[from] anyhow::Error), } @@ -118,6 +121,18 @@ fn check_admin(ctx: &Context) -> Result<(), CoreError> { Ok(()) } +async fn check_license(ctx: &Context, license_type: &[LicenseType]) -> Result<(), CoreError> { + let license = ctx.locator.license().read_license().await?; + + if !license_type.contains(&license.r#type) { + return Err(CoreError::InvalidLicense( + "Your plan doesn't include support for this feature.", + )); + } + + license.ensure_valid_license() +} + #[derive(Default)] pub struct Query; @@ -141,7 +156,7 @@ impl Query { async fn me(ctx: &Context) -> Result { let claims = check_claims(ctx)?; - ctx.locator.auth().get_user_by_email(&claims.sub).await + ctx.locator.auth().get_user(&claims.sub.0).await } async fn users( @@ -300,7 +315,7 @@ impl Query { }) } - async fn license(ctx: &Context) -> Result> { + async fn license(ctx: &Context) -> Result { ctx.locator.license().read_license().await } } @@ -352,11 +367,25 @@ impl Mutation { Ok(true) } + async fn password_change(ctx: &Context, input: PasswordChangeInput) -> Result { + let claims = check_claims(ctx)?; + input.validate()?; + ctx.locator + .auth() + .update_user_password( + &claims.sub.0, + input.old_password.as_deref(), + &input.new_password1, + ) + .await?; + Ok(true) + } + async fn reset_user_auth_token(ctx: &Context) -> Result { let claims = check_claims(ctx)?; ctx.locator .auth() - .reset_user_auth_token(&claims.sub) + .reset_user_auth_token(&claims.sub.0) .await?; Ok(true) } @@ -465,6 +494,7 @@ impl Mutation { input: UpdateOAuthCredentialInput, ) -> Result { check_admin(ctx)?; + check_license(ctx, &[LicenseType::Enterprise]).await?; input.validate()?; ctx.locator.auth().update_oauth_credential(input).await?; Ok(true) @@ -485,6 +515,7 @@ impl Mutation { async fn update_security_setting(ctx: &Context, input: SecuritySettingInput) -> Result { check_admin(ctx)?; + check_license(ctx, &[LicenseType::Enterprise]).await?; input.validate()?; ctx.locator.setting().update_security_setting(input).await?; Ok(true) @@ -508,6 +539,12 @@ impl Mutation { ctx.locator.license().update_license(license).await?; Ok(true) } + + async fn reset_license(ctx: &Context) -> Result { + check_admin(ctx)?; + ctx.locator.license().reset_license().await?; + Ok(true) + } } fn from_validation_errors(error: ValidationErrors) -> FieldError { diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index bffa147cc155..67e83f58606c 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -24,6 +24,7 @@ use crate::{ UpdateOAuthCredentialInput, User, }, email::EmailService, + license::LicenseService, setting::SettingService, CoreError, Result, }, @@ -33,13 +34,15 @@ use crate::{ struct AuthenticationServiceImpl { db: DbConn, mail: Arc, + license: Arc, } pub fn new_authentication_service( db: DbConn, mail: Arc, + license: Arc, ) -> impl AuthenticationService { - AuthenticationServiceImpl { db, mail } + AuthenticationServiceImpl { db, mail, license } } #[async_trait] @@ -83,8 +86,7 @@ impl AuthenticationService for AuthenticationServiceImpl { let refresh_token = generate_refresh_token(); self.db.create_refresh_token(id, &refresh_token).await?; - let Ok(access_token) = generate_jwt(JWTPayload::new(user.email.clone(), user.is_admin)) - else { + let Ok(access_token) = generate_jwt(JWTPayload::new(id.as_id(), user.is_admin)) else { return Err(anyhow!("Unknown error").into()); }; @@ -138,6 +140,35 @@ impl AuthenticationService for AuthenticationServiceImpl { Ok(()) } + async fn update_user_password( + &self, + id: &ID, + old_password: Option<&str>, + new_password: &str, + ) -> Result<()> { + let user = self + .db + .get_user(id.as_rowid()?) + .await? + .ok_or_else(|| anyhow!("Invalid user"))?; + + let password_verified = match (user.password_encrypted.is_empty(), old_password) { + (true, _) => true, + (false, None) => false, + (false, Some(old_password)) => password_verify(old_password, &user.password_encrypted), + }; + if !password_verified { + return Err(anyhow!("Password is incorrect").into()); + } + + let new_password_encrypted = + password_hash(new_password).map_err(|_| anyhow!("Unknown error"))?; + self.db + .update_user_password(user.id, new_password_encrypted) + .await?; + Ok(()) + } + async fn token_auth(&self, email: String, password: String) -> Result { let Some(user) = self.db.get_user_by_email(&email).await? else { return Err(anyhow!("User not found").into()); @@ -156,8 +187,7 @@ impl AuthenticationService for AuthenticationServiceImpl { .create_refresh_token(user.id, &refresh_token) .await?; - let Ok(access_token) = generate_jwt(JWTPayload::new(user.email.clone(), user.is_admin)) - else { + let Ok(access_token) = generate_jwt(JWTPayload::new(user.id.as_id(), user.is_admin)) else { return Err(anyhow!("Unknown error").into()); }; @@ -184,8 +214,7 @@ impl AuthenticationService for AuthenticationServiceImpl { self.db.replace_refresh_token(&token, &new_token).await?; // refresh token update is done, generate new access token based on user info - let Ok(access_token) = generate_jwt(JWTPayload::new(user.email.clone(), user.is_admin)) - else { + let Ok(access_token) = generate_jwt(JWTPayload::new(user.id.as_id(), user.is_admin)) else { return Err(anyhow!("Unknown error").into()); }; @@ -215,6 +244,12 @@ impl AuthenticationService for AuthenticationServiceImpl { } async fn update_user_role(&self, id: &ID, is_admin: bool) -> Result<()> { + if is_admin { + let license = self.license.read_license().await?; + let num_admins = self.db.count_active_admin_users().await?; + license.ensure_admin_seats(num_admins + 1)?; + } + let id = id.as_rowid()?; let user = self.db.get_user(id).await?.context("User doesn't exits")?; if user.is_owner() { @@ -232,7 +267,19 @@ impl AuthenticationService for AuthenticationServiceImpl { } } + async fn get_user(&self, id: &ID) -> Result { + let user = self.db.get_user(id.as_rowid()?).await?; + if let Some(user) = user { + Ok(user.into()) + } else { + Err(anyhow!("User not found").into()) + } + } + async fn create_invitation(&self, email: String) -> Result { + let license = self.license.read_license().await?; + license.ensure_available_seats(1)?; + let invitation = self.db.create_invitation(email.clone()).await?; let email_sent = self .mail @@ -264,8 +311,8 @@ impl AuthenticationService for AuthenticationServiceImpl { Ok(self.db.delete_invitation(id.as_rowid()?).await?.as_id()) } - async fn reset_user_auth_token(&self, email: &str) -> Result<()> { - Ok(self.db.reset_user_auth_token_by_email(email).await?) + async fn reset_user_auth_token(&self, id: &ID) -> Result<()> { + Ok(self.db.reset_user_auth_token_by_id(id.as_rowid()?).await?) } async fn list_users( @@ -317,7 +364,7 @@ impl AuthenticationService for AuthenticationServiceImpl { .create_refresh_token(user_id, &refresh_token) .await?; - let access_token = generate_jwt(JWTPayload::new(email.clone(), is_admin)) + let access_token = generate_jwt(JWTPayload::new(user_id.as_id(), is_admin)) .map_err(|_| OAuthError::Unknown)?; let resp = OAuthResponse { @@ -376,11 +423,25 @@ impl AuthenticationService for AuthenticationServiceImpl { } async fn update_user_active(&self, id: &ID, active: bool) -> Result<()> { + let license = self.license.read_license().await?; + + if active { + // Check there's sufficient seat if switching user to active. + license.ensure_available_seats(1)?; + } + let id = id.as_rowid()?; let user = self.db.get_user(id).await?.context("User doesn't exits")?; if user.is_owner() { return Err(anyhow!("The owner's active status cannot be changed").into()); } + + if active && user.is_admin { + // Check there's sufficient seat if an admin being swtiched to active. + let num_admins = self.db.count_active_admin_users().await?; + license.ensure_admin_seats(num_admins + 1)?; + } + Ok(self.db.update_user_active(id, active).await?) } } @@ -467,14 +528,79 @@ fn password_verify(raw: &str, hash: &str) -> bool { #[cfg(test)] mod tests { - async fn test_authentication_service() -> AuthenticationServiceImpl { + struct MockLicenseService { + status: LicenseStatus, + seats: i32, + seats_used: i32, + } + + impl MockLicenseService { + fn team() -> Self { + Self { + status: LicenseStatus::Ok, + seats: 5, + seats_used: 1, + } + } + + fn team_with_seats(seats: i32) -> Self { + Self { + status: LicenseStatus::Ok, + seats, + seats_used: 1, + } + } + + fn invalid() -> Self { + Self { + status: LicenseStatus::Expired, + seats: 5, + seats_used: 1, + } + } + } + + #[async_trait] + impl LicenseService for MockLicenseService { + async fn read_license(&self) -> Result { + Ok(LicenseInfo { + r#type: crate::schema::license::LicenseType::Team, + status: self.status.clone(), + seats: self.seats, + seats_used: self.seats_used, + issued_at: Some(Utc::now()), + expires_at: Some(Utc::now()), + }) + } + + async fn update_license(&self, _: String) -> Result<()> { + unimplemented!() + } + + async fn reset_license(&self) -> Result<()> { + unimplemented!() + } + } + + async fn test_authentication_service_with_license( + license: Arc, + ) -> AuthenticationServiceImpl { let db = DbConn::new_in_memory().await.unwrap(); AuthenticationServiceImpl { db: db.clone(), mail: Arc::new(new_email_service(db).await.unwrap()), + license, } } + async fn test_authentication_service() -> AuthenticationServiceImpl { + test_authentication_service_with_license(Arc::new(MockLicenseService::team())).await + } + + async fn test_authentication_service_without_valid_license() -> AuthenticationServiceImpl { + test_authentication_service_with_license(Arc::new(MockLicenseService::invalid())).await + } + async fn test_authentication_service_with_mail() -> (AuthenticationServiceImpl, TestEmailServer) { let db = DbConn::new_in_memory().await.unwrap(); @@ -482,6 +608,7 @@ mod tests { let service = AuthenticationServiceImpl { db: db.clone(), mail: Arc::new(smtp.create_test_email_service(db).await), + license: Arc::new(MockLicenseService::team()), }; (service, smtp) } @@ -491,7 +618,10 @@ mod tests { use serial_test::serial; use super::*; - use crate::service::email::{new_email_service, testutils::TestEmailServer}; + use crate::{ + schema::license::{LicenseInfo, LicenseStatus}, + service::email::{new_email_service, testutils::TestEmailServer}, + }; #[test] fn test_password_hash() { @@ -646,7 +776,7 @@ mod tests { register_admin_user(&service).await; let user = service.get_user_by_email(ADMIN_EMAIL).await.unwrap(); - service.reset_user_auth_token(&user.email).await.unwrap(); + service.reset_user_auth_token(&user.id).await.unwrap(); let user2 = service.get_user_by_email(ADMIN_EMAIL).await.unwrap(); assert_ne!(user.auth_token, user2.auth_token); @@ -976,4 +1106,119 @@ mod tests { assert!(service.allow_self_signup().await.unwrap()); } + + #[tokio::test] + async fn test_create_invitation_without_license() { + let service = test_authentication_service_without_valid_license().await; + assert_matches!( + service.create_invitation("abc.com".into()).await, + Err(CoreError::InvalidLicense(_)) + ) + } + + #[tokio::test] + async fn test_create_invitation_without_sufficient_seats() { + let service = test_authentication_service_with_license(Arc::new( + MockLicenseService::team_with_seats(2), + )) + .await; + assert_matches!(service.create_invitation("abc.com".into()).await, Ok(_)); + + let service = test_authentication_service_with_license(Arc::new( + MockLicenseService::team_with_seats(1), + )) + .await; + assert_matches!( + service.create_invitation("abc.com".into()).await, + Err(CoreError::InvalidLicense(_)) + ) + } + + #[tokio::test] + async fn test_update_user_active_on_admin_seats() { + let service = test_authentication_service_with_license(Arc::new( + MockLicenseService::team_with_seats(3), + )) + .await; + + // Create owner user. + service + .register("a@example.com".into(), "pass".into(), None) + .await + .unwrap(); + + let user1 = service + .db + .create_user("b@example.com".into(), "pass".into(), false) + .await + .unwrap(); + let user2 = service + .db + .create_user("c@example.com".into(), "pass".into(), false) + .await + .unwrap(); + let user3 = service + .db + .create_user("d@example.com".into(), "pass".into(), false) + .await + .unwrap(); + + service + .update_user_role(&user1.as_id(), true) + .await + .unwrap(); + service + .update_user_role(&user2.as_id(), true) + .await + .unwrap(); + + assert_matches!(service.db.count_active_admin_users().await, Ok(3)); + + assert_matches!( + service.update_user_role(&user3.as_id(), true).await, + Err(CoreError::InvalidLicense(_)) + ); + + // Change user2 to deactive. + service + .update_user_active(&user2.as_id(), false) + .await + .unwrap(); + + assert_matches!(service.db.count_active_admin_users().await, Ok(2)); + assert_matches!(service.update_user_role(&user3.as_id(), true).await, Ok(_)); + + // Not able to toggle user to active due to admin seat limits. + assert_matches!( + service.update_user_role(&user2.as_id(), true).await, + Err(CoreError::InvalidLicense(_)) + ); + } + + #[tokio::test] + async fn test_update_password() { + let service = test_authentication_service().await; + let id = service + .db + .create_user("test@example.com".into(), "".into(), true) + .await + .unwrap(); + + let id = id.as_id(); + + assert!(service + .update_user_password(&id, None, "newpass") + .await + .is_ok()); + + assert!(service + .update_user_password(&id, None, "newpass2") + .await + .is_err()); + + assert!(service + .update_user_password(&id, Some("newpass"), "newpass2") + .await + .is_ok()); + } } diff --git a/ee/tabby-webserver/src/service/dao.rs b/ee/tabby-webserver/src/service/dao.rs index 014728bcb1e9..569bfe4c04e6 100644 --- a/ee/tabby-webserver/src/service/dao.rs +++ b/ee/tabby-webserver/src/service/dao.rs @@ -52,6 +52,7 @@ impl From for auth::User { auth_token: val.auth_token, created_at: val.created_at, active: val.active, + is_password_set: !val.password_encrypted.is_empty(), } } } @@ -145,7 +146,7 @@ impl AsRowid for juniper::ID { .decode(self) .first() .map(|i| *i as i32) - .ok_or(CoreError::InvalidIDError) + .ok_or(CoreError::InvalidID) } } @@ -154,6 +155,12 @@ pub trait AsID { } impl AsID for i32 { + fn as_id(&self) -> juniper::ID { + (*self as i64).as_id() + } +} + +impl AsID for i64 { fn as_id(&self) -> juniper::ID { juniper::ID::new(HASHER.encode(&[*self as u64])) } diff --git a/ee/tabby-webserver/src/service/license.rs b/ee/tabby-webserver/src/service/license.rs index 75128e7d2958..f4803718d7fa 100644 --- a/ee/tabby-webserver/src/service/license.rs +++ b/ee/tabby-webserver/src/service/license.rs @@ -1,4 +1,4 @@ -use anyhow::anyhow; +use anyhow::{anyhow, Context}; use async_trait::async_trait; use chrono::{DateTime, Duration, NaiveDateTime, Utc}; use jsonwebtoken as jwt; @@ -56,9 +56,8 @@ fn validate_license(token: &str) -> Result Result> { - Ok(NaiveDateTime::from_timestamp_opt(secs, 0) - .ok_or_else(|| anyhow!("Timestamp is corrupt"))? - .and_utc()) + let datetime = NaiveDateTime::from_timestamp_opt(secs, 0).context("Timestamp is corrupt")?; + Ok(datetime.and_utc()) } struct LicenseServiceImpl { @@ -73,13 +72,32 @@ impl LicenseServiceImpl { let lock = self.seats.read().await; *lock }; - if force_refresh || now - refreshed > Duration::minutes(5) { + if force_refresh || now - refreshed > Duration::seconds(15) { let mut lock = self.seats.write().await; seats = self.db.count_active_users().await?; *lock = (now, seats); } Ok(seats) } + + async fn make_community_license(&self) -> Result { + let seats_used = self.read_used_seats(false).await?; + let status = if seats_used > LicenseInfo::seat_limits_for_community_license() { + LicenseStatus::SeatsExceeded + } else { + LicenseStatus::Ok + }; + + Ok(LicenseInfo { + r#type: LicenseType::Community, + status, + seats: LicenseInfo::seat_limits_for_community_license() as i32, + seats_used: seats_used as i32, + issued_at: None, + expires_at: None, + } + .guard_seat_limit()) + } } pub async fn new_license_service(db: DbConn) -> Result { @@ -107,24 +125,25 @@ fn license_info_from_raw(raw: LicenseJWTPayload, seats_used: usize) -> Result
  • Result> { + async fn read_license(&self) -> Result { let Some(license) = self.db.read_enterprise_license().await? else { - return Ok(None); + return self.make_community_license().await; }; let license = validate_license(&license).map_err(|e| anyhow!("License is corrupt: {e:?}"))?; let seats = self.read_used_seats(false).await?; let license = license_info_from_raw(license, seats)?; - Ok(Some(license)) + Ok(license) } async fn update_license(&self, license: String) -> Result<()> { @@ -139,6 +158,11 @@ impl LicenseService for LicenseServiceImpl { }; Ok(()) } + + async fn reset_license(&self) -> Result<()> { + self.db.update_enterprise_license(None).await?; + Ok(()) + } } #[cfg(test)] @@ -147,9 +171,9 @@ mod tests { use super::*; - const VALID_TOKEN: &str = "eyJhbGciOiJSUzUxMiJ9.eyJpc3MiOiJ0YWJieW1sLmNvbSIsInN1YiI6ImZha2VAdGFiYnltbC5jb20iLCJpYXQiOjE3MDUxOTgxMDIsImV4cCI6MTgwNzM5ODcwMiwidHlwIjoiVEVBTSIsIm51bSI6MTB9.vVo7PDevytGw2KXU5E-KMdJBijwOWsD1zKIf26rcjfxa3wDesGY40zuYZWyZFMfmAtBTO7DBgqdWnriHnF_HOnoAEDCycrgoxuSJW5TS9XsCWto-3rDhUsjRZ1wls-ztQu3Gxo_84UHUFwrXe-RHmJi_3w_YO-2L-nVw7JDd5zR8CEdLxeccD47vBrumYA7ybultoDHpHxSppjHlW1VPXavoaBIO1Twnbf52uJlbzJmloViDxoq-_9lxcN1hDN3KKE3crzO9uHK4jjZy_1KNHhCIIcnINek6SBl6lWZw9R88UfdP6uaVOTOHDFbGwv544TSLA_oKZXXntXhldKCp94YN8J4djHim91WwYBQARrpQKiQGP1APEQQdv_YO4iUC3QTLOVw_NMjyma0feVjzHYAap_2Q9HgnxyJfMH-KiH2zaR6BcdOfWV86crO5M0qNoP-XOgy4uU8eE2-PevOKM6uVwYiwoNZL4e9ttH6ratJj0tyqGW_3HYpsVyThzqDPisEz95knsrVL-iagwHRd00l6Mqfwcjbn-gOuUOV9knRIpPvUmfKjjjHgb-JI0qMAIdgeVtwQp0pNqPsKwenMwkpYQH1awfuB_Ia7SyMUNEzTAY8k_J4R6kCZ5XKJ2VTCljd9aJFSZpw-K57reUX1eLc6-Cwt1iI4d23M5UlYjvs"; - const EXPIRED_TOKEN: &str = "eyJhbGciOiJSUzUxMiJ9.eyJpc3MiOiJ0YWJieW1sLmNvbSIsInN1YiI6ImZha2VAdGFiYnltbC5jb20iLCJpYXQiOjE3MDUxOTgxMDIsImV4cCI6MTcwNzM5ODcwMiwidHlwIjoiVEVBTSIsIm51bSI6MTB9.19wrmSSZUQAj_nfnBljUARD3vz_XEIDh4wpi_U2P6LDRcvm7QYCro__LxUjIf45aE9BBiZCPBRTVOw_tMbegTAv5yK9G9cllGPdRDKWjf24BJpHt2wBKOwhCToUKp8R8D50bQ3cxHuz7J3XxcOMtwKxNRlwaufO-vgxX73v13z_bN6y5ix8FC5JEjY1z3fNPc_TnuuHnaXXqgqL9OJTrxhh5FErqR52kmxGGn2KCM8rm2Nfu0It2IZQuyJHSceZ3-iiIxsrVdXxbO4KHXLEOXos0xJRV8QG9_9VjAo6qui6BioygwrcPqHT7OoG3WfcT8XE9rcEX-s9PZ54_XxLm0yh81g54xPI92n94pe32XfE9T-YXNK3MLAdZWwDhp_sKXTcMSIr7mI9OA7eczZUpvI4BuDM8s1irNx4DKdfTwNchHDfEPmGmO53RHyVEbrS72jF9GBRBIwPmpGppWhcwpVNmlRJw3j1Sa_ttcGikPnBZBrUxGqzynq4q1VpeCpRoTzO9_nw5eciKMpaKww0P5Edqm5kKgg48aABfsTU3hLqTIr9rgjXePL_gEse6MJX_JC8I7-R17iQmMxKiNa9bTqSIk56qlB6gwZTzcjEtpnYlzZ05Ci6D3JBH9ZdO_F3UZDt5JdAD5dqsKl8PfWpxaWpg7FXNlqxYO9BpxCwr_7g"; - const INCOMPLETE_TOKEN: &str = "eyJhbGciOiJSUzUxMiJ9.eyJpc3MiOiJ0YWJieW1sLmNvbSIsInN1YiI6ImZha2VAdGFiYnltbC5jb20iLCJpYXQiOjE3MDUxOTgxMDIsImV4cCI6MTgwNzM5ODcwMiwidHlwIjoiVEVBTSJ9.Xdp7Tgi39RN3qBfDAT_RncCDF2lSSouT4fjR0YT8F4qN8qkocxgvCa6JyxlksaiqGKWb_aYJvkhCviMHnT_pnoNpR8YaLvB4vezEAdDWLf3jBqzhlsrCCbMGh72wFYKRIODhIHeTzldU4F06I9sz5HdtQpn42Q8WC8tAzG109vHtxcdC7D85u0CumJ35DcV7lTfpfIkil3PORReg0ysjZNjQ2JbiFqMF1VbBmC-DsoTrJoHlrxdHowMQsXv89C80pchx4UFSm7Z9tHiMUTOzfErScsGJI1VC5p8SYA3N4nsrPn-iup1CxOBIdK57BHedKGpd_hi1AVWYB4zXcc8HzzpqgwHulfaw_5vNvRMdkDGj3X2afU3O3rZ4jT_KLGjY-3Krgol8JHgJYiPXkBypiajFU6rVeMLScx-X-2-n3KBdR4GQ9la90QHSyIQUpiGRRfPhviBFDtAfcjJYo1Irlu6MGVhgFq9JH5SOVTn57V0A_VeAbj8WZNdML9hio9xqxP86DprnP_ApHpO_xbi-sx2GCmUyfC10eKnX8_sAB1n7z0AaHz4e-6SGm1I-wQsWcXjZfRYw0Vtogz7wVuyAIpm8lF58XjtOwQ9bP1kD03TGIcBTvEtgA6QUhRcximGJ5buK9X2TTd4TlHjFF1krrmYAUEDgFsorseoKvMkspVE"; + const VALID_TOKEN: &str = "eyJhbGciOiJSUzUxMiJ9.eyJpc3MiOiJ0YWJieW1sLmNvbSIsInN1YiI6ImZha2VAdGFiYnltbC5jb20iLCJpYXQiOjE3MDUxOTgxMDIsImV4cCI6MTgwNzM5ODcwMiwidHlwIjoiVEVBTSIsIm51bSI6MX0.r99qAkHGAzjZtS904ko5MMklquMcEJdibVGAZAxrJTf-kKBT-Kc-u-A8o7ZSrLD0eubIxNrLb16UsyAMxJ6xnIJY4h8BTIR9cz_dTezyGywpuAKI13Q2S77tfwcyBF6icFkDsz187MSQGPQuTdVNU8zXkYR5ZkNs8_Uc8SL940xt0KHWLU9DX8KT6eCcVMwAypLyAsSTRJeqE8uRumq1K6dKK7wkE_HQrg9nSmr40A5ZZPzRsUp6hShJyMYSp-D02utbT8bAzVPw6alBgZWrmlVEvdcvfO81DZylUIm-pszKityfT5tmuyMWtUx3AeLXSiQWZOpah3OBnL11IKhNhYWSzUMGuDENHfbP9hlSJvzjq8WeN73nXSjkNEVYetT2er6pnoGrvFUBWcLLdWcl4p324WwqsP5A7ZDbWamo62yPxHUy7Vr4ySRLDfNEQbjP8JVPacpx3-5oY16LlzS4e9RhR0G-aykJitrLd5--gTVGxlxsLbmz33TTDd3nMGuQp2xmpZsw4rTKefEN7hCdvgJhtwRLgL4jxSm2mBgtwWH_i0uuBFpCYNgh97rU-Cak66adXDydAOr6-imSHAIlSphGj6G4rUdbMtBV0n1MVGg3vIyHQot3hMaH6uXMpHOUEtxQivkp0F-fY6PoFr49HfWD-ZuneENaKKjB8p_rd9k"; + const EXPIRED_TOKEN: &str = "eyJhbGciOiJSUzUxMiJ9.eyJpc3MiOiJ0YWJieW1sLmNvbSIsInN1YiI6ImZha2VAdGFiYnltbC5jb20iLCJpYXQiOjE3MDUxOTgxMDIsImV4cCI6MTcwNDM5ODcwMiwidHlwIjoiVEVBTSIsIm51bSI6MX0.UBufd2YlyhuChdCSZvbvEBtxLABhZSuhya4KHKHYM2ABaSTjYYtSyT-yv0i9b8sySBoeu7kG0XBNrLQOg4fcirR5DxOFxiskI7qLLSQEIDYe-xnEbvxqKhN3RpHkxik9_OlvElvpIGrZRQxiELhESIM0NGck0Dz6MwTDFutkHZFh06cLFeoihs1rn44SknL3wP_afyCaOpQtTjDfsayBMfyDAriTG8HSnPbrw5Om7ER7uAqszhX8wpFonDeFeVB0OIUjayfL-SAMdLqNEqaFsUcuE4cUk7o9tA2jsYz2-BRlwDocLpRVp2V-K8MuyQJhDTiswbey2DE5tNRvnd3nNaVr7Pmt3mF7NMt8op8hl4I9scoThFBj9Bb1iMfAGVSXlRn9Kf2HHe2BJXGWC3w9bjWH2KRPMP3tScJ4CQccIJxZPU-fcX7IC1q8R4PWDYS11TDJ03PvCTEGFt3fBTLLaGOeoYHYNnd4qux317YhGtWTOO6ESIuoxQkJdTpNVOwfNmCVSfFUvJYs0l4r7z-QouHAd79Ck_GJ-cdiIOrV9MB1Lq6ayk267bXfdi0Lx6-PYxrTwXEkF5tBydrsPyhoReAbH8yQDqzlPbQzOlLo--Z4940kSEpgEsL9G6ymG5wDlMzNuQfjbYbCI0L19Spx5QRGtyYXtiSU1Tq-hhGm3zA"; + const INCOMPLETE_TOKEN: &str = "eyJhbGciOiJSUzUxMiJ9.eyJpc3MiOiJ0YWJieW1sLmNvbSIsInN1YiI6ImZha2VAdGFiYnltbC5jb20iLCJpYXQiOjE3MDUxOTgxMDIsImV4cCI6MTgwNDM5ODcwMiwidHlwIjoiVEVBTSJ9.juNQeg8jMRj7Q2XbmHSdneKZbTP_BIL43yW3He5avIRAKee1NF9-qg4ndGOYVWBmtoO6Y_CAts_trSw6gmuDuwWcmSbbr7CWQOYuNrMj1_Gp1MctA8zzC3yzr0EoBLzqkNBq3OySlfOkohopmJ6Lu0d0KRtf46qq94cMDAlfs7etcVGkGqfMEwxznptXiF7_S3qRVbahvJDPJlu_ozwn51tICXMrlGV_P6jdBcNLQ8I1LAH2RfyH9u-4mUSTKt-obnXw6mtPxPjl07MEajM_wW3X05-iRygQfyzDulvW0EXf39OnW2kCuyfQWx5Zksr-sCNTEL2VSalf9o8MchjAhDN5QrygdZkk7KXwt3O54tpcnFVABw9ORxJtTrsZJD-YvdmS01O6qLfMRWs2CGWFTfDJLxMSiBhAsy4DC4TkZN4UnBpX09U7n6f_0NUr83YAWcw0Rlp32k01j9iPUWSdePZh46Ck00XdzLcc15xfqv__ilaLAyRtb9JUVBX7g-VaLb1YGk658t19eukRNzE6WFyKfAE7u6EbxowtFQqVKYXWX_zDHoalo3DjUmPBV_VsorcBg4cjhrhBPBOB5f7Wa8r7eiJz1gWEj1xJEK2Y_mdShAvxNSWPSTvNvviPTgJbvbwDTzQ0It_d066ADBY2o0y5DTMP23EPL-oZ14TYIY4"; #[test] fn test_validate_license() { @@ -180,15 +204,21 @@ mod tests { } #[tokio::test] - async fn test_create_update_license() { + async fn test_license_mutations() { let db = DbConn::new_in_memory().await.unwrap(); let service = new_license_service(db).await.unwrap(); assert!(service.update_license("bad_token".into()).await.is_err()); service.update_license(VALID_TOKEN.into()).await.unwrap(); - assert!(service.read_license().await.unwrap().is_some()); + assert!(service.read_license().await.is_ok()); assert!(service.update_license(EXPIRED_TOKEN.into()).await.is_err()); + + service.reset_license().await.unwrap(); + assert_eq!( + service.read_license().await.unwrap().seats, + LicenseInfo::seat_limits_for_community_license() as i32 + ); } } diff --git a/ee/tabby-webserver/src/service/mod.rs b/ee/tabby-webserver/src/service/mod.rs index 64af5dabc975..26e9a6e3bf65 100644 --- a/ee/tabby-webserver/src/service/mod.rs +++ b/ee/tabby-webserver/src/service/mod.rs @@ -33,7 +33,7 @@ use crate::schema::{ auth::AuthenticationService, email::EmailService, job::JobService, - license::LicenseService, + license::{IsLicenseValid, LicenseService}, repository::RepositoryService, setting::SettingService, worker::{RegisterWorkerError, Worker, WorkerKind, WorkerService}, @@ -77,7 +77,11 @@ impl ServerContext { completion: worker::WorkerGroup::default(), chat: worker::WorkerGroup::default(), mail: mail.clone(), - auth: Arc::new(new_authentication_service(db_conn.clone(), mail)), + auth: Arc::new(new_authentication_service( + db_conn.clone(), + mail, + license.clone(), + )), license, db_conn, logger, @@ -86,7 +90,7 @@ impl ServerContext { } } - async fn authorize_request(&self, request: &Request) -> (bool, Option) { + async fn authorize_request(&self, request: &Request) -> (bool, Option) { let path = request.uri().path(); if !(path.starts_with("/v1/") || path.starts_with("/v1beta/")) { return (true, None); @@ -105,11 +109,25 @@ impl ServerContext { // Admin system is initialized, but there is no valid token. return (false, None); }; + + // Allow JWT based access (from web browser), regardless of the license status. if let Ok(jwt) = self.auth.verify_access_token(token).await { - return (true, Some(jwt.sub)); + return (true, Some(jwt.sub.0)); } - match self.db_conn.verify_auth_token(token).await { - Ok(email) => (true, Some(email)), + + let is_license_valid = self + .license + .read_license() + .await + .ensure_valid_license() + .is_ok(); + // If there's no valid license, only allows owner access. + match self + .db_conn + .verify_auth_token(token, !is_license_valid) + .await + { + Ok(id) => (true, Some(id.as_id())), Err(_) => (false, None), } } @@ -133,20 +151,28 @@ impl WorkerService for ServerContext { } async fn register_worker(&self, worker: Worker) -> Result { - let worker = match worker.kind { - WorkerKind::Completion => self.completion.register(worker).await, - WorkerKind::Chat => self.chat.register(worker).await, + let worker_group = match worker.kind { + WorkerKind::Completion => &self.completion, + WorkerKind::Chat => &self.chat, }; - if let Some(worker) = worker { - info!( - "registering <{:?}> worker running at {}", - worker.kind, worker.addr - ); - Ok(worker) - } else { - Err(RegisterWorkerError::RequiresEnterpriseLicense) + let count_workers = worker_group.list().await.len(); + let license = self + .license + .read_license() + .await + .map_err(|_| RegisterWorkerError::RequiresEnterpriseLicense)?; + + if license.check_node_limit(count_workers + 1) { + return Err(RegisterWorkerError::RequiresEnterpriseLicense); } + + let worker = worker_group.register(worker).await; + info!( + "registering <{:?}> worker running at {}", + worker.kind, worker.addr + ); + Ok(worker) } async fn unregister_worker(&self, worker_addr: &str) { diff --git a/ee/tabby-webserver/src/service/worker.rs b/ee/tabby-webserver/src/service/worker.rs index 72b39a211345..d64b1cda0538 100644 --- a/ee/tabby-webserver/src/service/worker.rs +++ b/ee/tabby-webserver/src/service/worker.rs @@ -1,7 +1,6 @@ use std::time::{SystemTime, UNIX_EPOCH}; use tokio::sync::RwLock; -use tracing::error; use crate::schema::worker::Worker; @@ -24,18 +23,14 @@ impl WorkerGroup { self.workers.read().await.clone() } - pub async fn register(&self, worker: Worker) -> Option { + pub async fn register(&self, worker: Worker) -> Worker { let mut workers = self.workers.write().await; - if workers.len() >= 1 { - error!("You need enterprise license to utilize more than 1 workers, please contact hi@tabbyml.com for information."); - return None; - } if workers.iter().all(|x| x.addr != worker.addr) { workers.push(worker.clone()); } - Some(worker) + worker } pub async fn unregister(&self, worker_addr: &str) -> bool { @@ -71,12 +66,7 @@ mod tests { let worker1 = make_worker("http://127.0.0.1:8080"); let worker2 = make_worker("http://127.0.0.2:8080"); - // Register success. - assert!(wg.register(worker1.clone()).await.is_some()); - assert!(wg.select().await.is_some()); - - // Register failed, as > 1 workers requires enterprise license. - assert!(wg.register(worker2.clone()).await.is_none()); + wg.register(worker1.clone()).await; let workers = wg.list().await; assert_eq!(workers.len(), 1); diff --git a/website/docs/faq.mdx b/website/docs/faq.mdx index 22af242286fd..e3c8317d10fd 100644 --- a/website/docs/faq.mdx +++ b/website/docs/faq.mdx @@ -42,4 +42,10 @@ Users are free to fork the repository to create their own registry. If a user's For details on the registry format, please refer to [models.json](https://github.com/TabbyML/registry-tabby/blob/main/models.json) - \ No newline at end of file + + + + +Tabby also supports loading models from a local directory that follow our specifications as outlined in [MODEL_SPEC.md](https://github.com/TabbyML/tabby/blob/main/MODEL_SPEC.md). + +