diff --git a/Cargo.lock b/Cargo.lock index bd71825..faf29e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -658,6 +658,29 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "perplexity" +version = "0.1.0" +dependencies = [ + "futures-util", + "pretty_assertions", + "reqwest", + "serde", + "serde_json", + "thiserror", + "tokio", +] + +[[package]] +name = "perplexity-chat-completion" +version = "0.1.0" +dependencies = [ + "futures-util", + "perplexity", + "serde_json", + "tokio", +] + [[package]] name = "pin-project" version = "1.1.5" diff --git a/Cargo.toml b/Cargo.toml index e268cca..98dd178 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["anthropic", "mesh", "openai", "replicate", "examples/*"] +members = ["anthropic", "mesh", "openai", "perplexity", "replicate", "examples/*"] resolver = "2" [workspace.package] diff --git a/examples/perplexity-chat-completion/Cargo.toml b/examples/perplexity-chat-completion/Cargo.toml new file mode 100644 index 0000000..d2a924a --- /dev/null +++ b/examples/perplexity-chat-completion/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "perplexity-chat-completion" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +perplexity = { path = "../../perplexity" } +futures-util = "0.3.30" +tokio = { version = "1.39.2", features = ["full"] } +serde_json = "1.0.128" diff --git a/examples/perplexity-chat-completion/README.md b/examples/perplexity-chat-completion/README.md new file mode 100644 index 0000000..740d4b6 --- /dev/null +++ b/examples/perplexity-chat-completion/README.md @@ -0,0 +1 @@ +# Perplexity chat completion Example diff --git a/examples/perplexity-chat-completion/src/main.rs b/examples/perplexity-chat-completion/src/main.rs new file mode 100644 index 0000000..0096d96 --- /dev/null +++ b/examples/perplexity-chat-completion/src/main.rs @@ -0,0 +1,17 @@ 
+use perplexity::{ + client::{Client, CreateChatCompletion, Model}, + config::Config, +}; + +#[tokio::main] +async fn main() { + let api_key = std::env::var("PERPLEXITY_API_KEY") + .expect("environment variable PERPLEXITY_API_KEY should be defined"); + + let config = Config::new(api_key); + let client = Client::new(config).unwrap(); + + let message = CreateChatCompletion::new(Model::Llama31SonarLargeOnline); + let result = client.create_completion(message).await.unwrap(); + println!("{:?}", result); +} diff --git a/perplexity/Cargo.toml b/perplexity/Cargo.toml new file mode 100644 index 0000000..a9eeac1 --- /dev/null +++ b/perplexity/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "perplexity" +version = "0.1.0" +edition.workspace = true +authors.workspace = true +description = "Perplexity Rust SDK" +homepage.workspace = true +documentation.workspace = true +repository = "https://github.com/roushou/mesh" +readme = "README.md" +license.workspace = true +keywords = ["ai", "chat", "perplexity"] + +[dependencies] +futures-util = "0.3.30" +reqwest = { version = "0.12.5", features = ["json", "stream"] } +serde = { version = "1.0.206", features = ["derive"] } +serde_json = "1.0.124" +thiserror = "1.0.63" +tokio = { version = "1.39.2", features = ["full"] } + +[dev-dependencies] +pretty_assertions = "1.4.0" diff --git a/perplexity/README.md b/perplexity/README.md new file mode 100644 index 0000000..0ba29df --- /dev/null +++ b/perplexity/README.md @@ -0,0 +1,55 @@ +# Perplexity Rust SDK + +[![Crates.io][crates-badge]][crates-url] +[![MIT licensed][mit-badge]][mit-url] +[![APACHE-2.0 licensed][apache-badge]][apache-url] +[![Build Status][actions-badge]][actions-url] + +[crates-badge]: https://img.shields.io/crates/v/perplexity.svg +[crates-url]: https://crates.io/crates/perplexity +[mit-badge]: https://img.shields.io/badge/license-MIT-blue.svg +[mit-url]: https://github.com/roushou/mesh/blob/master/LICENSE-MIT +[apache-badge]: 
https://img.shields.io/badge/license-apache-blue.svg +[apache-url]: https://github.com/roushou/mesh/blob/master/LICENSE-APACHE +[actions-badge]: https://github.com/roushou/mesh/workflows/CI/badge.svg +[actions-url]: https://github.com/roushou/mesh/actions?query=workflow%3ACI+branch%3Amaster + +This is an unofficial Rust SDK for the Perplexity API. + +More information about this crate can be found in the [crate documentation](https://crates.io/crates/perplexity). + +## Installation + +Add `perplexity` as a dependency to your `Cargo.toml` + +```sh +$ cargo add perplexity +``` + +## Usage + +An example to create a completion. + +```rust,ignore +use perplexity::{ + client::{Client, CreateChatCompletion, Model}, + config::Config, +}; + +#[tokio::main] +async fn main() { + let api_key = std::env::var("PERPLEXITY_API_KEY") + .expect("environment variable PERPLEXITY_API_KEY should be defined"); + + let config = Config::new(api_key); + let client = Client::new(config).unwrap(); + + let message = CreateChatCompletion::new(Model::Llama31SonarLargeOnline); + let result = client.create_completion(message.clone()).await.unwrap(); + println!("{:?}", result); +} +``` + +## License + +This project is licensed under the [MIT license](../LICENSE-MIT) and [Apache-2.0](../LICENSE-APACHE) license. 
diff --git a/perplexity/src/client.rs b/perplexity/src/client.rs new file mode 100644 index 0000000..c7ea09f --- /dev/null +++ b/perplexity/src/client.rs @@ -0,0 +1,341 @@ +use reqwest::{ + header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE}, + Client as ReqwestClient, Method, RequestBuilder, Url, +}; +use serde::{Deserialize, Serialize}; + +use crate::{config::Config, error::Error}; + +pub struct Client { + base_url: Url, + http_client: ReqwestClient, +} + +impl Client { + pub fn new(config: Config) -> Result<Self, Error> { + let mut headers = HeaderMap::new(); + headers.insert( + AUTHORIZATION, + HeaderValue::from_str(format!("Bearer {}", config.api_key.as_str()).as_str())?, + ); + headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + + let http_client = reqwest::Client::builder() + .default_headers(headers) + .build()?; + + let base_url = + Url::parse(&config.base_url).map_err(|err| Error::UrlParse(err.to_string()))?; + + Ok(Self { + base_url, + http_client, + }) + } + + pub async fn create_completion( + &self, + payload: CreateChatCompletion, + ) -> Result<ChatCompletion, Error> { + let completion = self + .request(Method::POST, "chat/completions")? + .json(&payload) + .send() + .await? + .json::<ChatCompletion>() + .await?; + Ok(completion) + } + + fn request(&self, method: Method, path: &str) -> Result<RequestBuilder, Error> { + let url = self + .base_url + .join(path) + .map_err(|err| Error::UrlParse(err.to_string()))?; + Ok(self.http_client.request(method, url)) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChatCompletion { + /// An ID generated uniquely for each response. + pub id: String, + + /// The model used to generate the response. + pub model: Model, + + /// The object type, which always equals **chat.completion**. + pub object: String, + + /// The Unix timestamp (in seconds) of when the completion was created. + pub created: u64, + + /// The list of completion choices the model generated for the input prompt.
+ pub choices: Vec<CompletionChoice>, + + /// Usage statistics for the completion request. + pub usage: CompletionUsage, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CompletionChoice { + pub index: u64, + + /// The reason the model stopped generating tokens. Possible values include stop if the model hit a natural stopping point, or length if the maximum number of tokens specified in the request was reached. + pub finish_reason: FinishReason, + + /// The message generated by the model. + pub message: Message, + + /// The incrementally streamed next tokens. Only meaningful when **stream = true**. + pub delta: CompletionDelta, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum FinishReason { + Stop, + Length, +} + +/// The message generated by the model. +#[derive(Debug, Serialize, Deserialize)] +pub struct Message { + /// The contents of the message in this turn of conversation. + pub content: String, + + /// The role of the speaker in this turn of conversation. After the (optional) system message, user and assistant roles should alternate with user then assistant, ending in user. + pub role: Role, +} + +/// The incrementally streamed next tokens. Only meaningful when **stream = true**. + #[derive(Debug, Serialize, Deserialize)] +pub struct CompletionDelta { + /// The contents of the message in this turn of conversation. + pub content: String, + + /// The role of the speaker in this turn of conversation. After the (optional) system message, user and assistant roles should alternate with user then assistant, ending in user. + pub role: Role, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Role { + System, + User, + Assistant, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CompletionUsage { + /// The number of tokens provided in the request prompt. + pub prompt_tokens: u64, + + /// The number of tokens generated in the response output.
+ pub completion_tokens: u64, + + /// The total number of tokens used in the chat completion (prompt + completion). + pub total_tokens: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CreateChatCompletion { + /// The name of the model that will complete your prompt. + pub model: Model, + + /// The maximum number of completion tokens returned by the API. The total number of tokens requested in **max_tokens** plus the number of prompt tokens sent in messages must not exceed the context window token limit of model requested. If left unspecified, then the model will generate tokens until either it reaches its stop token or the end of its context window. + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option<u64>, + + /// The amount of randomness in the response, valued between 0 inclusive and 2 exclusive. Higher values are more random, and lower values are more deterministic. + /// + /// Defaults to **0.2**. + pub temperature: f32, + + /// The nucleus sampling threshold, valued between 0 and 1 inclusive. For each subsequent token, the model considers the results of the tokens with **top_p** probability mass. + /// + /// Defaults to **0.9**. + pub top_p: f32, + + /// Determines whether or not a request to an online model should return citations. + /// + /// Defaults to **false**. + pub return_citations: bool, + + /// Given a list of domains, limit the citations used by the online model to URLs from the specified domains. Currently limited to only 3 domains for whitelisting and blacklisting. For blacklisting add a **-** to the beginning of the domain string. + pub search_domain_filter: Option<Vec<String>>, + + /// Determines whether or not a request to an online model should return images. + /// + /// Defaults to **false**. + pub return_images: bool, + + /// Determines whether or not a request to an online model should return related questions. + /// + /// Defaults to **false**.
+ pub return_related_questions: bool, + + /// Returns search results within the specified time interval - does not apply to images. + pub search_recency_filter: RecencyFilter, + + /// The number of tokens to keep for highest **top-k** filtering, specified as an integer between **0** and **2048** inclusive. If set to **0**, **top-k** filtering is disabled. + /// + /// Defaults to **0**. + pub top_k: u16, + + /// Determines whether or not to incrementally stream the response with server-sent events. + /// + /// Defaults to **false**. + pub stream: bool, + + /// A value between **-2.0** and **2.0**. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. Incompatible with **frequency_penalty**. + /// + /// Defaults to **0**. + pub presence_penalty: u8, + + /// A multiplicative penalty greater than **0**. Values greater than **1.0** penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. A value of **1.0** means no penalty. Incompatible with **presence_penalty**. + /// + /// Defaults to **1**. 
+ pub frequency_penalty: f32, +} + +impl CreateChatCompletion { + pub fn new(model: Model) -> Self { + Self { + model, + ..Default::default() + } + } + + pub fn with_model(mut self, model: Model) -> Self { + self.model = model; + self + } + + pub fn with_stream(mut self, stream: bool) -> Self { + self.stream = stream; + self + } + + pub fn with_max_tokens(mut self, max_tokens: u64) -> Self { + self.max_tokens = Some(max_tokens); + self + } + + pub fn with_temperature(mut self, temperature: f32) -> Self { + self.temperature = temperature; + self + } + + pub fn with_top_k(mut self, top_k: u16) -> Self { + self.top_k = top_k; + self + } + + pub fn with_top_p(mut self, top_p: f32) -> Self { + self.top_p = top_p; + self + } + + pub fn with_frequency_penalty(mut self, frequency_penalty: f32) -> Self { + self.frequency_penalty = frequency_penalty; + self + } + + pub fn with_presence_penalty(mut self, presence_penalty: u8) -> Self { + self.presence_penalty = presence_penalty; + self + } + + pub fn with_images(mut self, return_images: bool) -> Self { + self.return_images = return_images; + self + } + + pub fn with_citations(mut self, return_citations: bool) -> Self { + self.return_citations = return_citations; + self + } + + pub fn with_return_related_questions(mut self, return_related_questions: bool) -> Self { + self.return_related_questions = return_related_questions; + self + } + + pub fn with_search_domain_filter(mut self, search_domain_filter: Vec<String>) -> Self { + self.search_domain_filter = Some(search_domain_filter); + self + } + + pub fn with_recency_filter(mut self, recency_filter: RecencyFilter) -> Self { + self.search_recency_filter = recency_filter; + self + } +} + +impl Default for CreateChatCompletion { + fn default() -> Self { + Self { + model: Model::Llama31SonarLargeOnline, + stream: false, + max_tokens: None, + temperature: 0.2, + top_p: 0.9, + top_k: 0, + frequency_penalty: 1.0, + presence_penalty: 0, + return_images: false, + return_citations: false, + 
return_related_questions: false, + search_domain_filter: None, + search_recency_filter: RecencyFilter::default(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Model { + /// **8 Billion** parameters and **127,072** context length model. + #[serde(rename = "llama-3.1-sonar-small-128k-online")] + Llama31SonarSmallOnline, + + /// **70 Billion** parameters and **127,072** context length model. + #[serde(rename = "llama-3.1-sonar-large-128k-online")] + Llama31SonarLargeOnline, + + /// **405 Billion** parameters and **127,072** context length model. + #[serde(rename = "llama-3.1-sonar-huge-128k-online")] + Llama31SonarHugeOnline, + + /// **8 Billion** parameters and **127,072** context length model. + #[serde(rename = "llama-3.1-sonar-small-128k-chat")] + Llama31SonarSmallChat, + + /// **70 Billion** parameters and **127,072** context length model. + #[serde(rename = "llama-3.1-sonar-large-128k-chat")] + Llama31SonarLargeChat, + + /// **8 Billion** parameters and **131,072** context length model. + #[serde(rename = "llama-3.1-8b-instruct")] + Llama31InstructSmall, + + /// **70 Billion** parameters and **131,072** context length model. 
+ #[serde(rename = "llama-3.1-70b-instruct")] + Llama31InstructLarge, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum RecencyFilter { + Hour, + Day, + Week, + Month, +} + +impl Default for RecencyFilter { + fn default() -> Self { + Self::Month + } +} diff --git a/perplexity/src/config.rs b/perplexity/src/config.rs new file mode 100644 index 0000000..4afd675 --- /dev/null +++ b/perplexity/src/config.rs @@ -0,0 +1,55 @@ +use crate::error::Error; + +const DEFAULT_API_BASE_URL: &str = "https://api.perplexity.ai/"; +const API_KEY_ENV_VAR: &str = "PERPLEXITY_API_KEY"; + +#[derive(Debug, Clone)] +pub struct Config { + pub api_key: String, + pub base_url: String, +} + +impl Config { + pub fn new(api_key: impl Into) -> Self { + Self { + api_key: api_key.into(), + base_url: DEFAULT_API_BASE_URL.to_string(), + } + } + + /// Set the base url + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into(); + self + } + + /// Read Perplexity API key from **PERPLEXITY_API_KEY** environment variable. 
+ pub fn from_env() -> Result<Self, Error> { + let api_key = + std::env::var(API_KEY_ENV_VAR).map_err(|_| Error::MissingApiKey(API_KEY_ENV_VAR))?; + Ok(Self::new(api_key)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn should_use_default_values() { + let api_key = "perplexity-api-key"; + let config = Config::new(api_key); + + assert_eq!(config.api_key, api_key); + assert_eq!(config.base_url, DEFAULT_API_BASE_URL); + } + + #[test] + fn should_set_custom_url() { + let api_key = "perplexity-api-key"; + + let config = Config::new(api_key).with_base_url("https://custom-api.perplexity.ai"); + assert_eq!(config.base_url, "https://custom-api.perplexity.ai"); + } +} diff --git a/perplexity/src/error.rs b/perplexity/src/error.rs new file mode 100644 index 0000000..02967a4 --- /dev/null +++ b/perplexity/src/error.rs @@ -0,0 +1,113 @@ +use serde::Deserialize; +use std::str::Utf8Error; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("API error: {0}")] + Api(ApiErrorResponse), + + #[error("HTTP client error: {0}")] + Network(#[from] reqwest::Error), + + #[error("URL parse error: {0}")] + UrlParse(String), + + #[error("Failed to deserialize: {0}")] + JsonDeserialize(#[from] serde_json::Error), + + #[error("Invalid header value: {0}")] + InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue), + + #[error("Model not supported: {0}")] + ModelNotSupported(String), + + #[error("Missing API key {0}")] + MissingApiKey(&'static str), + + #[error("Invalid Stream Event")] + InvalidStreamEvent, + + #[error("UTF8 Error: {0}")] + Utf8Error(#[from] Utf8Error), + + #[error("Unexpected error: {0}")] + Unexpected(String), +} + +#[derive(Debug, Deserialize, PartialEq, Eq, thiserror::Error)] +#[error("Error response: {error_type} {error}")] +pub struct ApiErrorResponse { + #[serde(rename = "type")] + pub error_type: String, + pub error: ApiErrorDetail, +} + +#[derive(Debug, Deserialize, PartialEq, Eq, thiserror::Error)] +#[error("Api error: 
{error_type} {message}")] +pub struct ApiErrorDetail { + #[serde(rename = "type")] + pub error_type: ApiErrorType, + pub message: String, +} + +#[derive(Debug, Deserialize, PartialEq, Eq, thiserror::Error)] +pub enum ApiErrorType { + #[error("invalid_request_error")] + #[serde(rename = "invalid_request_error")] + InvalidRequest, + + #[error("authentication_error")] + #[serde(rename = "authentication_error")] + Authentication, + + #[error("permission_error")] + #[serde(rename = "permission_error")] + Permission, + + #[error("not_found_error")] + #[serde(rename = "not_found_error")] + NotFound, + + #[error("request_too_large")] + #[serde(rename = "request_too_large")] + RequestTooLarge, + + #[error("rate_limit_error")] + #[serde(rename = "rate_limit_error")] + RateLimit, + + #[error("api_error")] + #[serde(rename = "api_error")] + Unexpected, + + #[error("overloaded_error")] + #[serde(rename = "overloaded_error")] + Overloaded, +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn should_serialize_to_correct_error_values() { + assert_eq!( + &ApiErrorType::InvalidRequest.to_string(), + "invalid_request_error" + ); + assert_eq!( + &ApiErrorType::Authentication.to_string(), + "authentication_error" + ); + assert_eq!(&ApiErrorType::Permission.to_string(), "permission_error"); + assert_eq!(&ApiErrorType::NotFound.to_string(), "not_found_error"); + assert_eq!( + &ApiErrorType::RequestTooLarge.to_string(), + "request_too_large" + ); + assert_eq!(&ApiErrorType::RateLimit.to_string(), "rate_limit_error"); + assert_eq!(&ApiErrorType::Unexpected.to_string(), "api_error"); + assert_eq!(&ApiErrorType::Overloaded.to_string(), "overloaded_error"); + } +} diff --git a/perplexity/src/lib.rs b/perplexity/src/lib.rs new file mode 100644 index 0000000..e2f7460 --- /dev/null +++ b/perplexity/src/lib.rs @@ -0,0 +1,5 @@ +#![doc = include_str!("../README.md")] + +pub mod client; +pub mod config; +pub mod error;