diff --git a/components/ai_chat/core/browser/conversation_driver.cc b/components/ai_chat/core/browser/conversation_driver.cc
index e51f6793f9f0..1b25688247e8 100644
--- a/components/ai_chat/core/browser/conversation_driver.cc
+++ b/components/ai_chat/core/browser/conversation_driver.cc
@@ -36,12 +36,19 @@ using ai_chat::mojom::CharacterType;
 using ai_chat::mojom::ConversationTurn;
 using ai_chat::mojom::ConversationTurnVisibility;
 
+namespace ai_chat {
+
 namespace {
+
 static const auto kAllowedSchemes = base::MakeFixedFlatSet<base::StringPiece>(
     {url::kHttpsScheme, url::kHttpScheme, url::kFileScheme, url::kDataScheme});
-}  // namespace
 
-namespace ai_chat {
+bool IsPremiumStatus(mojom::PremiumStatus status) {
+  return status == mojom::PremiumStatus::Active ||
+         status == mojom::PremiumStatus::ActiveDisconnected;
+}
+
+}  // namespace
 
 ConversationDriver::ConversationDriver(raw_ptr<PrefService> pref_service,
                                        raw_ptr<AIChatMetrics> ai_chat_metrics,
@@ -74,7 +81,7 @@ ConversationDriver::ConversationDriver(raw_ptr<PrefService> pref_service,
   credential_manager_->GetPremiumStatus(base::BindOnce(
       [](ConversationDriver* instance, mojom::PremiumStatus status) {
         instance->last_premium_status_ = status;
-        if (status == mojom::PremiumStatus::Inactive) {
+        if (!IsPremiumStatus(status)) {
           // Not premium
           return;
         }
@@ -721,9 +728,15 @@ void ConversationDriver::GetPremiumStatus(
 void ConversationDriver::OnPremiumStatusReceived(
     mojom::PageHandler::GetPremiumStatusCallback parent_callback,
     mojom::PremiumStatus premium_status) {
-  if (last_premium_status_ != premium_status &&
-      premium_status == mojom::PremiumStatus::Active) {
-    // Change model if we haven't already
+  // Maybe switch to the premium model when the user is newly premium and on
+  // a basic model.
+  const bool should_switch_model =
+      // This isn't the first retrieval (that's handled in the constructor)
+      last_premium_status_ != mojom::PremiumStatus::Unknown &&
+      last_premium_status_ != premium_status &&
+      premium_status == mojom::PremiumStatus::Active &&
+      GetCurrentModel().access == mojom::ModelAccess::BASIC;
+  if (should_switch_model) {
     ChangeModel(features::kAIModelsPremiumDefaultKey.Get());
   }
   last_premium_status_ = premium_status;
diff --git a/components/ai_chat/core/browser/conversation_driver.h b/components/ai_chat/core/browser/conversation_driver.h
index 76009ecdd396..4abaa27ac1c8 100644
--- a/components/ai_chat/core/browser/conversation_driver.h
+++ b/components/ai_chat/core/browser/conversation_driver.h
@@ -161,7 +161,7 @@ class ConversationDriver {
   int64_t current_navigation_id_;
   bool is_same_document_navigation_ = false;
   mojom::APIError current_error_ = mojom::APIError::None;
-  mojom::PremiumStatus last_premium_status_ = mojom::PremiumStatus::Inactive;
+  mojom::PremiumStatus last_premium_status_ = mojom::PremiumStatus::Unknown;
 
   std::unique_ptr<mojom::ConversationTurn> pending_conversation_entry_;
   bool pending_message_needs_page_content_ = false;
diff --git a/components/ai_chat/core/browser/engine/engine_consumer_llama.cc b/components/ai_chat/core/browser/engine/engine_consumer_llama.cc
index d1805cf5ff9e..294c53451dd5 100644
--- a/components/ai_chat/core/browser/engine/engine_consumer_llama.cc
+++ b/components/ai_chat/core/browser/engine/engine_consumer_llama.cc
@@ -162,6 +162,7 @@ std::string BuildLlama2Prompt(
     const std::vector<mojom::ConversationTurn>& conversation_history,
     std::string page_content,
     const bool& is_video,
+    const bool& needs_general_seed,
     const std::string user_message) {
   // Always use a generic system message
   std::string system_message =
@@ -206,7 +207,9 @@
   if (conversation_history.empty() || conversation_history.size() <= 1) {
     return BuildLlama2FirstSequence(
         today_system_message, first_user_message, absl::nullopt,
-        l10n_util::GetStringUTF8(IDS_AI_CHAT_LLAMA2_GENERAL_SEED));
+        (needs_general_seed) ? std::optional(l10n_util::GetStringUTF8(
+                                   IDS_AI_CHAT_LLAMA2_GENERAL_SEED))
+                             : std::nullopt);
   }
 
   // Use the first two messages to build the first sequence,
@@ -227,7 +230,9 @@
   // Build the final subsequent exchange using the current turn.
   prompt += BuildLlama2SubsequentSequence(
       user_message, absl::nullopt,
-      l10n_util::GetStringUTF8(IDS_AI_CHAT_LLAMA2_GENERAL_SEED));
+      (needs_general_seed) ? std::optional(l10n_util::GetStringUTF8(
+                                 IDS_AI_CHAT_LLAMA2_GENERAL_SEED))
+                           : std::nullopt);
 
   // Trimming recommended by Meta
   // https://huggingface.co/meta-llama/Llama-2-13b-chat#intended-use
@@ -249,6 +254,8 @@ EngineConsumerLlamaRemote::EngineConsumerLlamaRemote(
   api_ = std::make_unique<RemoteCompletionClient>(
       model.name, stop_sequences, url_loader_factory, credential_manager);
 
+  needs_general_seed_ = base::StartsWith(model.name, "llama-2");
+
   max_page_content_length_ = model.max_page_content_length;
 }
 
@@ -332,8 +339,9 @@ void EngineConsumerLlamaRemote::GenerateAssistantResponse(
     GenerationCompletedCallback completed_callback) {
   const std::string& truncated_page_content =
       page_content.substr(0, max_page_content_length_);
-  std::string prompt = BuildLlama2Prompt(
-      conversation_history, truncated_page_content, is_video, human_input);
+  std::string prompt =
+      BuildLlama2Prompt(conversation_history, truncated_page_content, is_video,
+                        needs_general_seed_, human_input);
   DCHECK(api_);
   api_->QueryPrompt(prompt, {"</response>"}, std::move(completed_callback),
                     std::move(data_received_callback));
diff --git a/components/ai_chat/core/browser/engine/engine_consumer_llama.h b/components/ai_chat/core/browser/engine/engine_consumer_llama.h
index b0dae492388d..2ad156f971d8 100644
--- a/components/ai_chat/core/browser/engine/engine_consumer_llama.h
+++ b/components/ai_chat/core/browser/engine/engine_consumer_llama.h
@@ -61,6 +61,7 @@ class EngineConsumerLlamaRemote : public EngineConsumer {
 
   std::unique_ptr<RemoteCompletionClient> api_ = nullptr;
 
+  bool needs_general_seed_ = true;
   int max_page_content_length_ = 0;
 
   base::WeakPtrFactory<EngineConsumerLlamaRemote> weak_ptr_factory_{this};
diff --git a/components/ai_chat/core/browser/models.cc b/components/ai_chat/core/browser/models.cc
index fa84fb272c14..06b286bf077b 100644
--- a/components/ai_chat/core/browser/models.cc
+++ b/components/ai_chat/core/browser/models.cc
@@ -42,13 +42,13 @@ const std::vector<mojom::Model>& GetAllModels() {
   static const base::NoDestructor<std::vector<mojom::Model>> kModels({
       {"chat-leo-expanded", "mixtral-8x7b-instruct", "Mixtral", "Mistral AI",
        mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT,
-       kFreemiumAccess, 9000, 9700},
+       kFreemiumAccess, 8000, 9700},
       {"chat-claude-instant", "claude-instant-v1", "Claude Instant",
        "Anthropic", mojom::ModelEngineType::CLAUDE_REMOTE,
        mojom::ModelCategory::CHAT, kFreemiumAccess, 180000, 320000},
-      {"chat-basic", "llama-2-13b-chat", "llama2 13b", "Meta",
+      {"chat-basic", "llama-2-13b-chat", "Llama 2 13b", "Meta",
        mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT,
-       mojom::ModelAccess::BASIC, 9000, 9700},
+       mojom::ModelAccess::BASIC, 8000, 9700},
   });
   return *kModels;
 }
diff --git a/components/ai_chat/core/common/mojom/ai_chat.mojom b/components/ai_chat/core/common/mojom/ai_chat.mojom
index 52eb5147e22e..6cf6682f77ef 100644
--- a/components/ai_chat/core/common/mojom/ai_chat.mojom
+++ b/components/ai_chat/core/common/mojom/ai_chat.mojom
@@ -45,6 +45,7 @@ enum ModelAccess {
 };
 
 enum PremiumStatus {
+  Unknown,
   Inactive,
   Active,
   ActiveDisconnected,
diff --git a/components/resources/ai_chat_ui_strings.grdp b/components/resources/ai_chat_ui_strings.grdp
index d62ba7b7b904..2f322520ce2f 100644
--- a/components/resources/ai_chat_ui_strings.grdp
+++ b/components/resources/ai_chat_ui_strings.grdp
@@ -55,7 +55,7 @@
     Hi, I'm Leo. I'm a fully hosted AI assistant by Brave. I'm powered by Llama 13B, a model created by Meta to be performant and applicable to many use cases.
-    Hi, I'm Leo. I'm a fully hosted AI assistant by Brave. I'm powered by Mixtral 7B, a model created by Mistral AI to handle advanced tasks.
+    Hi, I'm Leo. I'm a fully hosted AI assistant by Brave. I'm powered by Mixtral 8x7B, a model created by Mistral AI to handle advanced tasks.
     Hi, I'm Leo. I'm proxied by Brave and powered by Claude Instant, a model created by Anthropic to power conversational and text processing tasks.