[Uplift 1.62.x] AI Chat Mixtral fixes #21664

Merged · 2 commits · Jan 21, 2024
25 changes: 19 additions & 6 deletions components/ai_chat/core/browser/conversation_driver.cc
@@ -36,12 +36,19 @@ using ai_chat::mojom::CharacterType;
 using ai_chat::mojom::ConversationTurn;
 using ai_chat::mojom::ConversationTurnVisibility;
 
+namespace ai_chat {
+
 namespace {
+
 static const auto kAllowedSchemes = base::MakeFixedFlatSet<std::string_view>(
     {url::kHttpsScheme, url::kHttpScheme, url::kFileScheme, url::kDataScheme});
-}  // namespace
 
-namespace ai_chat {
+bool IsPremiumStatus(mojom::PremiumStatus status) {
+  return status == mojom::PremiumStatus::Active ||
+         status == mojom::PremiumStatus::ActiveDisconnected;
+}
+
+}  // namespace
 
 ConversationDriver::ConversationDriver(raw_ptr<PrefService> pref_service,
                                        raw_ptr<AIChatMetrics> ai_chat_metrics,
@@ -74,7 +74,7 @@ ConversationDriver::ConversationDriver(raw_ptr<PrefService> pref_service,
   credential_manager_->GetPremiumStatus(base::BindOnce(
       [](ConversationDriver* instance, mojom::PremiumStatus status) {
         instance->last_premium_status_ = status;
-        if (status == mojom::PremiumStatus::Inactive) {
+        if (!IsPremiumStatus(status)) {
           // Not premium
           return;
         }
@@ -721,9 +728,15 @@ void ConversationDriver::GetPremiumStatus(
 void ConversationDriver::OnPremiumStatusReceived(
     mojom::PageHandler::GetPremiumStatusCallback parent_callback,
     mojom::PremiumStatus premium_status) {
-  if (last_premium_status_ != premium_status &&
-      premium_status == mojom::PremiumStatus::Active) {
-    // Change model if we haven't already
+  // Maybe switch to premium model when user is newly premium and on a basic
+  // model
+  const bool should_switch_model =
+      // This isn't the first retrieval (that's handled in the constructor)
+      last_premium_status_ != mojom::PremiumStatus::Unknown &&
+      last_premium_status_ != premium_status &&
+      premium_status == mojom::PremiumStatus::Active &&
+      GetCurrentModel().access == mojom::ModelAccess::BASIC;
+  if (should_switch_model) {
     ChangeModel(features::kAIModelsPremiumDefaultKey.Get());
   }
   last_premium_status_ = premium_status;
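
Reviewer note: the guard above fixes a startup bug — last_premium_status_ previously defaulted to Inactive, so the very first status fetch for a premium user looked like an Inactive → Active transition and switched the model unconditionally. The new predicate only fires when (a) this isn't the first retrieval, (b) the status genuinely changed to Active, and (c) the current model is a basic one. A minimal standalone sketch of that predicate (enum values and logic mirrored from the diff; the free-function harness is illustrative only, not Brave's API):

#include <iostream>

enum class PremiumStatus { Unknown, Inactive, Active, ActiveDisconnected };
enum class ModelAccess { BASIC, PREMIUM };

// Mirrors the should_switch_model expression in OnPremiumStatusReceived.
bool ShouldSwitchModel(PremiumStatus last, PremiumStatus now,
                       ModelAccess current_model_access) {
  return last != PremiumStatus::Unknown &&  // not the first retrieval
         last != now &&                     // status actually changed
         now == PremiumStatus::Active &&    // user is newly premium
         current_model_access == ModelAccess::BASIC;  // still on a basic model
}

int main() {
  // First retrieval after construction: no spurious switch.
  std::cout << ShouldSwitchModel(PremiumStatus::Unknown, PremiumStatus::Active,
                                 ModelAccess::BASIC)  // 0 (false)
            << '\n';
  // Genuine upgrade while on a basic model: switch to the premium default.
  std::cout << ShouldSwitchModel(PremiumStatus::Inactive, PremiumStatus::Active,
                                 ModelAccess::BASIC)  // 1 (true)
            << '\n';
}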
2 changes: 1 addition & 1 deletion components/ai_chat/core/browser/conversation_driver.h
@@ -161,7 +161,7 @@ class ConversationDriver {
   int64_t current_navigation_id_;
   bool is_same_document_navigation_ = false;
   mojom::APIError current_error_ = mojom::APIError::None;
-  mojom::PremiumStatus last_premium_status_ = mojom::PremiumStatus::Inactive;
+  mojom::PremiumStatus last_premium_status_ = mojom::PremiumStatus::Unknown;
 
   std::unique_ptr<mojom::ConversationTurn> pending_conversation_entry_;
   bool pending_message_needs_page_content_ = false;
16 changes: 12 additions & 4 deletions components/ai_chat/core/browser/engine/engine_consumer_llama.cc
@@ -162,6 +162,7 @@ std::string BuildLlama2Prompt(
     const std::vector<ConversationTurn>& conversation_history,
     std::string page_content,
     const bool& is_video,
+    const bool& needs_general_seed,
     const std::string user_message) {
   // Always use a generic system message
   std::string system_message =
@@ -206,7 +207,9 @@ std::string BuildLlama2Prompt(
   if (conversation_history.empty() || conversation_history.size() <= 1) {
     return BuildLlama2FirstSequence(
         today_system_message, first_user_message, absl::nullopt,
-        l10n_util::GetStringUTF8(IDS_AI_CHAT_LLAMA2_GENERAL_SEED));
+        (needs_general_seed) ? std::optional(l10n_util::GetStringUTF8(
+                                   IDS_AI_CHAT_LLAMA2_GENERAL_SEED))
+                             : std::nullopt);
   }
 
   // Use the first two messages to build the first sequence,
@@ -227,7 +230,9 @@ std::string BuildLlama2Prompt(
   // Build the final subsequent exchange using the current turn.
   prompt += BuildLlama2SubsequentSequence(
       user_message, absl::nullopt,
-      l10n_util::GetStringUTF8(IDS_AI_CHAT_LLAMA2_GENERAL_SEED));
+      (needs_general_seed) ? std::optional(l10n_util::GetStringUTF8(
+                                 IDS_AI_CHAT_LLAMA2_GENERAL_SEED))
+                           : std::nullopt);
 
   // Trimming recommended by Meta
   // https://huggingface.co/meta-llama/Llama-2-13b-chat#intended-use
@@ -249,6 +254,8 @@ EngineConsumerLlamaRemote::EngineConsumerLlamaRemote(
   api_ = std::make_unique<RemoteCompletionClient>(
       model.name, stop_sequences, url_loader_factory, credential_manager);
 
+  needs_general_seed_ = base::StartsWith(model.name, "llama-2");
+
   max_page_content_length_ = model.max_page_content_length;
 }

@@ -332,8 +339,9 @@ void EngineConsumerLlamaRemote::GenerateAssistantResponse(
     GenerationCompletedCallback completed_callback) {
   const std::string& truncated_page_content =
       page_content.substr(0, max_page_content_length_);
-  std::string prompt = BuildLlama2Prompt(
-      conversation_history, truncated_page_content, is_video, human_input);
+  std::string prompt =
+      BuildLlama2Prompt(conversation_history, truncated_page_content, is_video,
+                        needs_general_seed_, human_input);
   DCHECK(api_);
   api_->QueryPrompt(prompt, {"</response>"}, std::move(completed_callback),
                     std::move(data_received_callback));
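
Reviewer note: the needs_general_seed_ flag reduces to a model-name prefix check, so Mixtral — which shares ModelEngineType::LLAMA_REMOTE with Llama 2 — no longer gets the IDS_AI_CHAT_LLAMA2_GENERAL_SEED suffix appended to its prompts. A standalone sketch of that gating (the constexpr helper is hypothetical; the diff itself uses base::StartsWith in the constructor):

#include <string_view>

// Standalone stand-in for the needs_general_seed_ initialization: only
// Llama 2 model names get the general seed appended to the prompt.
constexpr bool NeedsGeneralSeed(std::string_view model_name) {
  return model_name.substr(0, 7) == "llama-2";
}

// Model names as listed in models.cc:
static_assert(NeedsGeneralSeed("llama-2-13b-chat"));        // seed appended
static_assert(!NeedsGeneralSeed("mixtral-8x7b-instruct"));  // seed omitted

int main() {}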
1 change: 1 addition & 0 deletions components/ai_chat/core/browser/engine/engine_consumer_llama.h
@@ -61,6 +61,7 @@ class EngineConsumerLlamaRemote : public EngineConsumer {
 
   std::unique_ptr<RemoteCompletionClient> api_ = nullptr;
 
+  bool needs_general_seed_ = true;
   int max_page_content_length_ = 0;
 
   base::WeakPtrFactory<EngineConsumerLlamaRemote> weak_ptr_factory_{this};
6 changes: 3 additions & 3 deletions components/ai_chat/core/browser/models.cc
@@ -42,13 +42,13 @@ const std::vector<ai_chat::mojom::Model>& GetAllModels() {
   static const base::NoDestructor<std::vector<mojom::Model>> kModels({
       {"chat-leo-expanded", "mixtral-8x7b-instruct", "Mixtral", "Mistral AI",
        mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT,
-       kFreemiumAccess, 9000, 9700},
+       kFreemiumAccess, 8000, 9700},
       {"chat-claude-instant", "claude-instant-v1", "Claude Instant",
        "Anthropic", mojom::ModelEngineType::CLAUDE_REMOTE,
        mojom::ModelCategory::CHAT, kFreemiumAccess, 180000, 320000},
-      {"chat-basic", "llama-2-13b-chat", "llama2 13b", "Meta",
+      {"chat-basic", "llama-2-13b-chat", "Llama 2 13b", "Meta",
        mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT,
-       mojom::ModelAccess::BASIC, 9000, 9700},
+       mojom::ModelAccess::BASIC, 8000, 9700},
   });
   return *kModels;
 }
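
Reviewer note: reading the aggregate literals positionally, the first trailing integer is max_page_content_length, lowered from 9000 to 8000 for both LLAMA_REMOTE models; the meaning of the second (9700) is not visible in this diff, so only the first is confirmed here. The 8000 budget is consumed by the substr truncation in GenerateAssistantResponse above; a minimal sketch:

#include <string>

// Mirrors the truncation in GenerateAssistantResponse: page content beyond
// the model's max_page_content_length is dropped before the prompt is built.
std::string TruncateForModel(const std::string& page_content,
                             size_t max_page_content_length) {
  return page_content.substr(0, max_page_content_length);
}

int main() {
  std::string page(10000, 'x');  // oversized page text
  // 8000 is the new Mixtral budget from models.cc.
  return TruncateForModel(page, 8000).size() == 8000 ? 0 : 1;
}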
1 change: 1 addition & 0 deletions components/ai_chat/core/common/mojom/ai_chat.mojom
@@ -45,6 +45,7 @@ enum ModelAccess {
 };
 
 enum PremiumStatus {
+  Unknown,
   Inactive,
   Active,
   ActiveDisconnected,
2 changes: 1 addition & 1 deletion components/resources/ai_chat_ui_strings.grdp
@@ -55,7 +55,7 @@
     Hi, I'm Leo. I'm a fully hosted AI assistant by Brave. I'm powered by Llama 13B, a model created by Meta to be performant and applicable to many use cases.
   </message>
   <message name="IDS_CHAT_UI_INTRO_MESSAGE_CHAT_LEO_EXPANDED" desc="AI Chat intro message for the expanded model">
-    Hi, I'm Leo. I'm a fully hosted AI assistant by Brave. I'm powered by Mixtral 7B, a model created by Mistral AI to handle advanced tasks.
+    Hi, I'm Leo. I'm a fully hosted AI assistant by Brave. I'm powered by Mixtral 8x7B, a model created by Mistral AI to handle advanced tasks.
   </message>
   <message name="IDS_CHAT_UI_INTRO_MESSAGE_CHAT_LEO_CLAUDE_INSTANT" desc="AI Chat intro message for the Claude Instant model">
     Hi, I'm Leo. I'm proxied by Brave and powered by Claude Instant, a model created by Anthropic to power conversational and text processing tasks.