Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AI Chat]: Adds support for multiple tab context to the conversation handler #27167

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions browser/ai_chat/ai_chat_service_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ AIChatServiceFactory::BuildServiceInstanceForBrowserContext(

return std::make_unique<AIChatService>(
ModelServiceFactory::GetForBrowserContext(context),

std::move(credential_manager), user_prefs::UserPrefs::Get(context),
(g_brave_browser_process->process_misc_metrics())
? g_brave_browser_process->process_misc_metrics()->ai_chat_metrics()
Expand Down
80 changes: 78 additions & 2 deletions browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

#include "brave/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.h"

#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "base/strings/utf_string_conversions.h"
#include "brave/browser/ai_chat/ai_chat_service_factory.h"
#include "brave/browser/ai_chat/ai_chat_urls.h"
#include "brave/browser/ui/side_panel/ai_chat/ai_chat_side_panel_utils.h"
Expand All @@ -21,8 +23,11 @@
#include "brave/components/constants/webui_url_constants.h"
#include "chrome/browser/favicon/favicon_service_factory.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/ui/browser.h"
#include "chrome/browser/ui/browser_finder.h"
#include "chrome/browser/ui/browser_list.h"
#include "chrome/browser/ui/singleton_tabs.h"
#include "chrome/browser/ui/tabs/tab_strip_model.h"
#include "components/favicon/core/favicon_service.h"
#include "content/public/browser/browser_context.h"
#include "content/public/browser/storage_partition.h"
Expand Down Expand Up @@ -94,6 +99,44 @@ void AIChatUIPageHandler::HandleVoiceRecognition(
#endif
}

void AIChatUIPageHandler::AssociateTab(mojom::AvailableTabPtr tab,
const std::string& conversation_uuid) {
auto* content = content::WebContents::FromFrameTreeNodeId(
static_cast<content::FrameTreeNodeId>(tab->frame_tree_node_id));
if (!content) {
return;
}

// Ensure the content is loaded before associating
content->GetController().LoadIfNecessary();

auto* tab_helper = ai_chat::AIChatTabHelper::FromWebContents(content);
if (!tab_helper) {
return;
}

AIChatServiceFactory::GetForBrowserContext(profile_)->AssociateContent(
tab_helper, conversation_uuid);
}

void AIChatUIPageHandler::DisassociateTab(
mojom::AvailableTabPtr tab,
const std::string& conversation_uuid) {
auto* content = content::WebContents::FromFrameTreeNodeId(
static_cast<content::FrameTreeNodeId>(tab->frame_tree_node_id));
if (!content) {
return;
}

auto* tab_helper = ai_chat::AIChatTabHelper::FromWebContents(content);
if (!tab_helper) {
return;
}

AIChatServiceFactory::GetForBrowserContext(profile_)->DisassociateContent(
tab_helper, conversation_uuid);
}

void AIChatUIPageHandler::OpenAIChatSettings() {
content::WebContents* contents_to_navigate =
(active_chat_tab_helper_) ? active_chat_tab_helper_->web_contents()
Expand Down Expand Up @@ -153,6 +196,36 @@ void AIChatUIPageHandler::OpenStorageSupportUrl() {
OpenURL(GURL(kURLLearnMoreAboutStorage));
}

void AIChatUIPageHandler::GetAvailableTabs(GetAvailableTabsCallback callback) {
const BrowserList* browser_list = BrowserList::GetInstance();
std::vector<mojom::AvailableTabPtr> tabs;
for (const auto& browser : *browser_list) {
// TODO(fallaciousreasoning): Maybe we should consider other types of
// browsers as well?
if (!browser->is_type_normal()) {
continue;
}
if (browser->profile() != owner_web_contents_->GetBrowserContext()) {
continue;
}

for (int i = 0; i < browser->tab_strip_model()->count(); ++i) {
content::WebContents* contents =
browser->tab_strip_model()->GetWebContentsAt(i);
if (contents == owner_web_contents_.get()) {
continue;
}

auto id = static_cast<int>(
contents->GetPrimaryMainFrame()->GetFrameTreeNodeId());
tabs.push_back(
mojom::AvailableTab::New(id, contents->GetVisibleURL(),
base::UTF16ToUTF8(contents->GetTitle())));
}
}
std::move(callback).Run(std::move(tabs));
}

void AIChatUIPageHandler::GoPremium() {
#if !BUILDFLAG(IS_ANDROID)
OpenURL(GURL(kURLGoPremium));
Expand Down Expand Up @@ -276,8 +349,10 @@ void AIChatUIPageHandler::GetFaviconImageDataForAssociatedContent(
GetFaviconImageDataCallback callback,
mojom::SiteInfoPtr content_info,
bool should_send_page_contents) {
// TODO(fallaciousreasoning): We should look this up from the hostname
if (!content_info->is_content_association_possible ||
!content_info->url.has_value() || !content_info->url->is_valid()) {
content_info->details.empty() ||
!content_info->details[0]->url.is_valid()) {
std::move(callback).Run(std::nullopt);
return;
}
Expand All @@ -297,8 +372,9 @@ void AIChatUIPageHandler::GetFaviconImageDataForAssociatedContent(
std::move(callback).Run(std::move(bytes));
};

auto url = content_info->details[0]->url;
favicon_service_->GetRawFaviconForPageURL(
content_info->url.value(), icon_types, kDesiredFaviconSizePixels, true,
url, icon_types, kDesiredFaviconSizePixels, true,
base::BindOnce(on_favicon_available, std::move(callback)),
&favicon_task_tracker_);
}
Expand Down
5 changes: 5 additions & 0 deletions browser/ui/webui/ai_chat/ai_chat_ui_page_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,16 @@ class AIChatUIPageHandler : public mojom::AIChatUIHandler,
void OpenConversationFullPage(const std::string& conversation_uuid) override;
void OpenURL(const GURL& url) override;
void OpenStorageSupportUrl() override;
void GetAvailableTabs(GetAvailableTabsCallback callback) override;
void OpenModelSupportUrl() override;
void GoPremium() override;
void RefreshPremiumSession() override;
void ManagePremium() override;
void HandleVoiceRecognition(const std::string& conversation_uuid) override;
void AssociateTab(mojom::AvailableTabPtr tab,
const std::string& conversation_uuid) override;
void DisassociateTab(mojom::AvailableTabPtr tab,
const std::string& conversation_uuid) override;
void CloseUI() override;
void SetChatUI(mojo::PendingRemote<mojom::ChatUI> chat_ui,
SetChatUICallback callback) override;
Expand Down
2 changes: 2 additions & 0 deletions components/ai_chat/core/browser/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ static_library("browser") {
"model_service.h",
"model_validator.cc",
"model_validator.h",
"multi_associated_content_driver.cc",
"multi_associated_content_driver.h",
"types.h",
"utils.cc",
"utils.h",
Expand Down
25 changes: 15 additions & 10 deletions components/ai_chat/core/browser/ai_chat_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,20 @@ std::vector<mojom::ConversationPtr> AIChatDatabase::GetAllConversations() {
if (statement.GetColumnType(index) != sql::ColumnType::kNull) {
DVLOG(1) << __func__ << " got associated content";

// TODO(fallaciousreasoning): Support multiple associated content
conversation->associated_content->uuid = statement.ColumnString(index++);
conversation->associated_content->title =
DecryptOptionalColumnToString(statement, index++);

auto detail = mojom::SiteInfoDetail::New();
detail->title =
DecryptOptionalColumnToString(statement, index++).value_or("");
auto url_raw = DecryptOptionalColumnToString(statement, index++);
if (url_raw.has_value()) {
conversation->associated_content->url = GURL(url_raw.value());
detail->url = GURL(url_raw.value());
}
conversation->associated_content->content_type =
detail->content_type =
static_cast<mojom::ContentType>(statement.ColumnInt(index++));
conversation->associated_content->details.push_back(std::move(detail));

conversation->associated_content->content_used_percentage =
statement.ColumnInt(index++);
conversation->associated_content->is_content_refined =
Expand Down Expand Up @@ -394,7 +399,7 @@ bool AIChatDatabase::AddConversation(mojom::ConversationPtr conversation,
if (conversation->associated_content->is_content_association_possible) {
DVLOG(2) << "Adding associated content for conversation "
<< conversation->uuid << " with url "
<< conversation->associated_content->url->spec();
<< conversation->associated_content->details[0]->url.spec();
if (!AddOrUpdateAssociatedContent(
conversation->uuid, std::move(conversation->associated_content),
contents)) {
Expand Down Expand Up @@ -465,11 +470,11 @@ bool AIChatDatabase::AddOrUpdateAssociatedContent(
}
CHECK(statement.is_valid());
int index = 0;
BindAndEncryptOptionalString(statement, index++, associated_content->title);
BindAndEncryptOptionalString(statement, index++,
associated_content->url->spec());
statement.BindInt(index++,
base::to_underlying(associated_content->content_type));

auto& detail = associated_content->details[0];
BindAndEncryptOptionalString(statement, index++, detail->title);
BindAndEncryptOptionalString(statement, index++, detail->url.spec());
statement.BindInt(index++, base::to_underlying(detail->content_type));
BindAndEncryptOptionalString(statement, index++, contents);
statement.BindInt(index++, associated_content->content_used_percentage);
statement.BindBool(index++, associated_content->is_content_refined);
Expand Down
47 changes: 28 additions & 19 deletions components/ai_chat/core/browser/ai_chat_database_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,18 @@ TEST_P(AIChatDatabaseTest, AddAndGetConversationAndEntries) {
// recent entry.
const GURL page_url = GURL("https://example.com/page");
const std::string expected_contents = "Page contents";
auto details =
mojom::SiteInfoDetail::New(page_url, "page title", page_url.host(),
mojom::ContentType::PageContent);
std::vector<mojom::SiteInfoDetailPtr> details_vector;
details_vector.push_back(details->Clone());
mojom::SiteInfoPtr associated_content =
has_content
? mojom::SiteInfo::New(
content_uuid, mojom::ContentType::PageContent, "page title",
page_url.host(), page_url, 62, true, true)
: mojom::SiteInfo::New(
std::nullopt, mojom::ContentType::PageContent, std::nullopt,
std::nullopt, std::nullopt, 0, false, false);
? mojom::SiteInfo::New(content_uuid, std::move(details_vector), 62,
true, true)
: mojom::SiteInfo::New(std::nullopt,
std::vector<mojom::SiteInfoDetailPtr>(), 0,
false, false);
const mojom::ConversationPtr metadata =
mojom::Conversation::New(uuid, "title", now - base::Hours(2), true,
std::nullopt, std::move(associated_content));
Expand Down Expand Up @@ -265,8 +269,8 @@ TEST_P(AIChatDatabaseTest, UpdateConversationTitle) {
const std::string updated_title = "updated title";
mojom::ConversationPtr metadata = mojom::Conversation::New(
uuid, initial_title, base::Time::Now(), true, std::nullopt,
mojom::SiteInfo::New(std::nullopt, mojom::ContentType::PageContent,
std::nullopt, std::nullopt, std::nullopt, 0, false,
mojom::SiteInfo::New(std::nullopt,
std::vector<mojom::SiteInfoDetailPtr>(), 0, false,
false));

// Persist the first entry (and get the response ready)
Expand Down Expand Up @@ -295,10 +299,12 @@ TEST_P(AIChatDatabaseTest, AddOrUpdateAssociatedContent) {
const std::string uuid = "for_associated_content";
const std::string content_uuid = "content_uuid";
const GURL page_url = GURL("https://example.com/page");
auto details = mojom::SiteInfoDetail::New(
page_url, "page title", page_url.host(), mojom::ContentType::PageContent);
std::vector<mojom::SiteInfoDetailPtr> details_vector;
mojom::ConversationPtr metadata = mojom::Conversation::New(
uuid, "title", base::Time::Now() - base::Hours(2), true, std::nullopt,
mojom::SiteInfo::New(content_uuid, mojom::ContentType::PageContent,
"page title", page_url.host(), page_url, 62, true,
mojom::SiteInfo::New(content_uuid, std::move(details_vector), 62, true,
true));

auto history = CreateSampleChatHistory(1u);
Expand Down Expand Up @@ -340,8 +346,8 @@ TEST_P(AIChatDatabaseTest, DeleteAllData) {
const GURL page_url = GURL("https://example.com/page");
mojom::ConversationPtr metadata = mojom::Conversation::New(
uuid, "title", base::Time::Now() - base::Hours(2), true, std::nullopt,
mojom::SiteInfo::New(std::nullopt, mojom::ContentType::PageContent,
std::nullopt, std::nullopt, std::nullopt, 0, false,
mojom::SiteInfo::New(std::nullopt,
std::vector<mojom::SiteInfoDetailPtr>(), 0, false,
false));

auto history = CreateSampleChatHistory(1u);
Expand Down Expand Up @@ -377,16 +383,20 @@ TEST_P(AIChatDatabaseTest, DeleteAssociatedWebContent) {

// The times in the Conversation are irrelevant, only the times of the entries
// are persisted.
auto details = mojom::SiteInfoDetail::New(
page_url, "page title", page_url.host(), mojom::ContentType::PageContent);
std::vector<mojom::SiteInfoDetailPtr> details_vector;
details_vector.push_back(details->Clone());
mojom::ConversationPtr metadata_first = mojom::Conversation::New(
"first", "title", base::Time::Now() - base::Hours(2), true, std::nullopt,
mojom::SiteInfo::New("first-content", mojom::ContentType::PageContent,
"page title", page_url.host(), page_url, 62, true,
mojom::SiteInfo::New("first-content", std::move(details_vector), 62, true,
true));

details_vector.push_back(details->Clone());
mojom::ConversationPtr metadata_second = mojom::Conversation::New(
"second", "title", base::Time::Now() - base::Hours(1), true, "model-2",
mojom::SiteInfo::New("second-content", mojom::ContentType::PageContent,
"page title", page_url.host(), page_url, 62, true,
true));
mojom::SiteInfo::New("second-content", std::move(details_vector), 62,
true, true));

auto history_first = CreateSampleChatHistory(1u, -2);
auto history_second = CreateSampleChatHistory(1u, -1);
Expand Down Expand Up @@ -426,8 +436,7 @@ TEST_P(AIChatDatabaseTest, DeleteAssociatedWebContent) {
conversations = db_->GetAllConversations();
EXPECT_EQ(conversations.size(), 2u);
ExpectConversationEquals(FROM_HERE, conversations[0], metadata_first);
metadata_second->associated_content->url = std::nullopt;
metadata_second->associated_content->title = std::nullopt;
metadata_second->associated_content->details.clear();
ExpectConversationEquals(FROM_HERE, conversations[1], metadata_second);

archive_result = db_->GetConversationData("second");
Expand Down
17 changes: 15 additions & 2 deletions components/ai_chat/core/browser/ai_chat_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h"
#include "brave/components/ai_chat/core/browser/ai_chat_database.h"
#include "brave/components/ai_chat/core/browser/ai_chat_metrics.h"
#include "brave/components/ai_chat/core/browser/associated_content_driver.h"
#include "brave/components/ai_chat/core/browser/constants.h"
#include "brave/components/ai_chat/core/browser/conversation_handler.h"
#include "brave/components/ai_chat/core/browser/model_service.h"
Expand Down Expand Up @@ -159,8 +160,8 @@ ConversationHandler* AIChatService::CreateConversation() {
mojom::ConversationPtr conversation = mojom::Conversation::New(
conversation_uuid, "", base::Time::Now(), false, std::nullopt,
mojom::SiteInfo::New(base::Uuid::GenerateRandomV4().AsLowercaseString(),
mojom::ContentType::PageContent, std::nullopt,
std::nullopt, std::nullopt, 0, false, false));
std::vector<mojom::SiteInfoDetailPtr>(), 0, false,
false));
conversations_.insert_or_assign(conversation_uuid, std::move(conversation));
}
mojom::Conversation* conversation =
Expand Down Expand Up @@ -294,6 +295,18 @@ ConversationHandler* AIChatService::CreateConversationHandlerForContent(
return conversation;
}

void AIChatService::AssociateContent(AssociatedContentDriver* driver,
const std::string& conversation_uuid) {
DCHECK(base::Contains(conversation_handlers_, conversation_uuid));
conversation_handlers_.at(conversation_uuid)->AddAssociation(driver);
}

void AIChatService::DisassociateContent(AssociatedContentDriver* driver,
const std::string& conversation_uuid) {
DCHECK(base::Contains(conversation_handlers_, conversation_uuid));
conversation_handlers_.at(conversation_uuid)->RemoveAssociation(driver);
}

void AIChatService::DeleteConversations(std::optional<base::Time> begin_time,
std::optional<base::Time> end_time) {
if (!begin_time.has_value() && !end_time.has_value()) {
Expand Down
7 changes: 7 additions & 0 deletions components/ai_chat/core/browser/ai_chat_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ namespace ai_chat {

class ModelService;
class AIChatMetrics;
class AssociatedContentDriver;

// Main entry point for creating and consuming AI Chat conversations
class AIChatService : public KeyedService,
Expand Down Expand Up @@ -122,6 +123,11 @@ class AIChatService : public KeyedService,
base::WeakPtr<ConversationHandler::AssociatedContentDelegate>
associated_content);

void AssociateContent(AssociatedContentDriver* associated_content,
const std::string& conversation_uuid);
void DisassociateContent(AssociatedContentDriver* associated_content,
const std::string& conversation_uuid);

// Removes all in-memory and persisted data for all conversations
void DeleteConversations(std::optional<base::Time> begin_time = std::nullopt,
std::optional<base::Time> end_time = std::nullopt);
Expand Down Expand Up @@ -181,6 +187,7 @@ class AIChatService : public KeyedService,
using ConversationMapCallback = base::OnceCallback<void(ConversationMap&)>;

void MaybeInitStorage();

// Called when the database encryptor is ready.
void OnOsCryptAsyncReady(os_crypt_async::Encryptor encryptor, bool success);
void LoadConversationsLazy(ConversationMapCallback callback);
Expand Down
Loading
Loading