Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AI Chat: don't unload conversations that are associated with non-content-sending pages, e.g. chrome:// or where content is disabled #26952

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions browser/brave_tab_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ void AttachTabHelpers(content::WebContents* web_contents) {
if (ai_chat::IsAllowedForContext(context)) {
ai_chat::AIChatTabHelper::CreateForWebContents(
web_contents,
ai_chat::AIChatServiceFactory::GetForBrowserContext(context),
#if BUILDFLAG(ENABLE_PRINT_PREVIEW)
std::make_unique<ai_chat::PrintPreviewExtractor>(web_contents)
#else
Expand Down
7 changes: 4 additions & 3 deletions browser/ui/ai_chat/ai_chat_tab_helper_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,10 @@ class AIChatTabHelperUnitTest : public content::RenderViewHostTestHarness,
favicon::ContentFaviconDriver::CreateForWebContents(web_contents(),
&favicon_service_);
AIChatTabHelper::CreateForWebContents(
web_contents(), is_print_preview_supported_
? std::make_unique<MockPrintPreviewExtractor>()
: nullptr);
web_contents(), nullptr /*ai_chat_service*/,
is_print_preview_supported_
? std::make_unique<MockPrintPreviewExtractor>()
: nullptr);
helper_ = AIChatTabHelper::FromWebContents(web_contents());
helper_->SetPageContentFetcherDelegateForTesting(
std::make_unique<MockPageContentFetcher>());
Expand Down
4 changes: 3 additions & 1 deletion components/ai_chat/content/browser/ai_chat_tab_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,13 @@ void AIChatTabHelper::BindPageContentExtractorHost(
}

AIChatTabHelper::AIChatTabHelper(content::WebContents* web_contents,
AIChatService* ai_chat_service,
std::unique_ptr<PrintPreviewExtractionDelegate>
print_preview_extraction_delegate)
: content::WebContentsObserver(web_contents),
content::WebContentsUserData<AIChatTabHelper>(*web_contents),
AssociatedContentDriver(web_contents->GetBrowserContext()
AssociatedContentDriver(ai_chat_service,
web_contents->GetBrowserContext()
->GetDefaultStoragePartition()
->GetURLLoaderFactoryForBrowserProcess()),
print_preview_extraction_delegate_(
Expand Down
1 change: 1 addition & 0 deletions components/ai_chat/content/browser/ai_chat_tab_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ class AIChatTabHelper : public content::WebContentsObserver,
// PrintPreviewExtractionDelegate is provided as it's implementation is
// in a different layer.
AIChatTabHelper(content::WebContents* web_contents,
AIChatService* ai_chat_service,
std::unique_ptr<PrintPreviewExtractionDelegate>
print_preview_extraction_delegate);

Expand Down
72 changes: 61 additions & 11 deletions components/ai_chat/core/browser/ai_chat_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,10 @@ ConversationHandler* AIChatService::GetOrCreateConversationHandlerForContent(
if (!conversation) {
// New conversation needed
conversation = CreateConversation();
// Provide the content delegate, if allowed
MaybeAssociateContentWithConversation(conversation, associated_content_id,
associated_content);
}
// Provide the content delegate, if allowed
MaybeAssociateContentWithConversation(conversation, associated_content_id,
associated_content);

return conversation;
}
Expand Down Expand Up @@ -545,6 +545,21 @@ void AIChatService::MaybeAssociateContentWithConversation(
// if we don't call SetAssociatedContentDelegate, the conversation still
// has a default Tab's navigation for which is is associated. The Conversation
// won't use that Tab's Page for context.
if (!IsAIChatHistoryEnabled()) {
// First, if we're replacing a conversation and history is not enabled,
// then there'll be no way to get back to that conversation so we can
// delete it.
auto conversation_uuid_it =
content_conversations_.find(associated_content_id);
if (conversation_uuid_it != content_conversations_.end()) {
auto conversation_uuid = conversation_uuid_it->second;
ConversationHandler* previous_conversation =
GetConversation(conversation_uuid);
if (previous_conversation) {
MaybeUnloadConversation(previous_conversation);
}
}
}
content_conversations_.insert_or_assign(
associated_content_id, conversation->get_conversation_uuid());
}
Expand Down Expand Up @@ -631,23 +646,39 @@ void AIChatService::OnPremiumStatusReceived(GetPremiumStatusCallback callback,
void AIChatService::MaybeUnloadConversation(
ConversationHandler* conversation_handler) {
// Don't unload if there is active UI for the conversation
if (conversation_handler->IsAnyClientConnected()) {
if (conversation_handler->IsAnyClientConnected() ||
conversation_handler->IsRequestInProgress()) {
return;
}

bool has_history = conversation_handler->HasAnyHistory();
std::string uuid = conversation_handler->get_conversation_uuid();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason get_conversation_uuid doesn't return a const std::string&? Seems like we might be doing some unnecessary copying


// We can keep a conversation with history in memory until there is no active
// content.
// Some conversations are only associated with content via this Service
// but not via the ConversationHandler. This is only for lookup, but is
// important if the content is still alive to keep the conversation alive
// so that when UI is opened for that content, the conversation still exists.
// TODO(petemill): With the history feature enabled, we should unload (if
// there is no request in progress). However, we can only do this when
// there is no request in progress). However, we can only do that when
// GetOrCreateConversationHandlerForContent allows a callback so that it
// can provide an answer after loading the conversation content from storage.
if (conversation_handler->IsAssociatedContentAlive() && has_history) {
return;
// We can also only do that when re-loading the conversation for still-alive
// content rejoins to that live content instead of creating a content archive.
if (has_history) {
bool is_content_alive = base::ranges::any_of(
content_conversations_,
[&uuid](const auto& kv) { return kv.second == uuid; });

if (IsAIChatHistoryEnabled()) {
is_content_alive =
is_content_alive || conversation_handler->IsAssociatedContentAlive();
}

if (is_content_alive) {
return;
}
}

auto uuid = conversation_handler->get_conversation_uuid();
conversation_observations_.RemoveObservation(conversation_handler);
conversation_handlers_.erase(uuid);
DVLOG(1) << "Unloaded conversation (" << uuid << ") from memory. Now have "
Expand Down Expand Up @@ -828,10 +859,29 @@ void AIChatService::OnConversationTitleChanged(ConversationHandler* handler,

void AIChatService::OnAssociatedContentDestroyed(ConversationHandler* handler,
int content_id) {
content_conversations_.erase(content_id);
// This will fire for conversations which are directly associated with the
// content, including non-default conversations. They need to be unloaded
// too.
MaybeUnloadConversation(handler);
}

void AIChatService::OnAssociatedContentDestroyed(int content_id) {
// This will fire also for conversations which don't know they are associated
// with content (e.g. for chrome:// content) and only content_conversations_
// has that knowledge). So we handle calling MaybeUnloadConversation here.
auto conversation_uuid_it = content_conversations_.find(content_id);
if (conversation_uuid_it == content_conversations_.end()) {
return;
}
std::string conversation_uuid = conversation_uuid_it->second;
content_conversations_.erase(conversation_uuid_it);
// Maybe unload the conversation when its associated content is destroyed
ConversationHandler* conversation = GetConversation(conversation_uuid);
if (conversation) {
MaybeUnloadConversation(conversation);
}
}

void AIChatService::GetVisibleConversations(
GetVisibleConversationsCallback callback) {
LoadConversationsLazy(base::BindOnce(
Expand Down
19 changes: 12 additions & 7 deletions components/ai_chat/core/browser/ai_chat_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ class AIChatService : public KeyedService,
void OnAssociatedContentDestroyed(ConversationHandler* handler,
int content_id) override;

// Content destroyed that may or may not be directly associated with a
// ConversationHandler, but may be associated via the content_conversations_
// map.
void OnAssociatedContentDestroyed(int content_id);

// Adds new conversation and returns the handler
ConversationHandler* CreateConversation();

Expand Down Expand Up @@ -158,6 +163,7 @@ class AIChatService : public KeyedService,
BindObserverCallback callback) override;

bool HasUserOptedIn();
bool IsAIChatHistoryEnabled();
bool IsPremiumStatus();

std::unique_ptr<EngineConsumer> GetDefaultAIEngine();
Expand Down Expand Up @@ -214,8 +220,6 @@ class AIChatService : public KeyedService,
mojom::ServiceStatePtr BuildState();
void OnStateChanged();

bool IsAIChatHistoryEnabled();

raw_ptr<ModelService> model_service_;
raw_ptr<PrefService> profile_prefs_;
raw_ptr<AIChatMetrics> ai_chat_metrics_;
Expand Down Expand Up @@ -246,11 +250,12 @@ class AIChatService : public KeyedService,
std::map<std::string, std::unique_ptr<ConversationHandler>>
conversation_handlers_;

// Map associated content id (a.k.a navigation id) to conversation uuid. This
// acts as a cache for back-navigation to find the most recent conversation
// for that navigation. This should be periodically cleaned up by removing any
// keys where the ConversationHandler has had a destroyed
// associated_content_delegate_ for some time.
// Map associated content id (a.k.a navigation id) to defaultconversation
// uuid. This is not an exhaustive list because associated content can have
// multiple conversations if the history feature is enabled.
// TODO(petemill): Use NavigationEntry::SetUserData to store conversation
// relationship with specific navigations so that we can support
// back-navigation without needing to keep all entries in this map in memory.
std::map<int, std::string> content_conversations_;

base::ScopedMultiSourceObservation<ConversationHandler,
Expand Down
79 changes: 50 additions & 29 deletions components/ai_chat/core/browser/ai_chat_service_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ class MockAssociatedContent
related_conversations_.erase(conversation);
}

void DisassociateWithConversations(std::string archived_text_content,
void DisassociateWithConversations(AIChatService* service,
std::string archived_text_content,
bool archived_is_video) {
std::vector<base::WeakPtr<ConversationHandler>> related_conversations;
for (auto& conversation : related_conversations_) {
Expand All @@ -199,6 +200,7 @@ class MockAssociatedContent
archived_is_video);
}
}
service->OnAssociatedContentDestroyed(content_id_);
}

base::WeakPtr<ConversationHandler::AssociatedContentDelegate> GetWeakPtr() {
Expand Down Expand Up @@ -467,7 +469,8 @@ TEST_P(AIChatServiceUnitTest, ConversationLifecycle_WithContent) {
// the content is destroyed.
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u);
ExpectVisibleConversationsSize(FROM_HERE, 1u);
associated_content.DisassociateWithConversations("", false);
associated_content.DisassociateWithConversations(ai_chat_service_.get(), "",
false);

if (IsAIChatHistoryEnabled()) {
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 0u);
Expand All @@ -486,11 +489,14 @@ TEST_P(AIChatServiceUnitTest, GetOrCreateConversationHandlerForContent) {
ON_CALL(associated_content, GetURL())
.WillByDefault(testing::Return(GURL("https://example.com")));
associated_content.SetContentId(1);
ConversationHandler* conversation_with_content =
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr());
base::WeakPtr<ConversationHandler> conversation_with_content =
ai_chat_service_
->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(),
associated_content.GetWeakPtr())
->GetWeakPtr();
EXPECT_TRUE(conversation_with_content);
EXPECT_NE(conversation_without_content, conversation_with_content);
EXPECT_NE(conversation_without_content, conversation_with_content.get());
EXPECT_NE(conversation_without_content->get_conversation_uuid(),
conversation_with_content->get_conversation_uuid());
base::RunLoop run_loop;
Expand All @@ -508,29 +514,40 @@ TEST_P(AIChatServiceUnitTest, GetOrCreateConversationHandlerForContent) {
EXPECT_EQ(
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr()),
conversation_with_content);
conversation_with_content.get());

// Creating a second conversation with the same associated content should
// make the second conversation the default for that content, but leave
// the first still associated with the content.
ConversationHandler* conversation2 =
ai_chat_service_->CreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr());
EXPECT_NE(conversation_with_content, conversation2);
EXPECT_NE(conversation_with_content->get_conversation_uuid(),
conversation2->get_conversation_uuid());
base::WeakPtr<ConversationHandler> conversation2 =
ai_chat_service_
->CreateConversationHandlerForContent(
associated_content.GetContentId(),
associated_content.GetWeakPtr())
->GetWeakPtr();
if (IsAIChatHistoryEnabled()) {
EXPECT_NE(conversation_with_content.get(), conversation2.get());
EXPECT_NE(conversation_with_content->get_conversation_uuid(),
conversation2->get_conversation_uuid());
EXPECT_EQ(
conversation_with_content->GetAssociatedContentDelegateForTesting(),
conversation2->GetAssociatedContentDelegateForTesting());
} else {
// With no history feature, the previous conversation will be deleted
// because there is no client connected to it.
EXPECT_EQ(nullptr, conversation_with_content);
}
EXPECT_EQ(conversation2->GetAssociatedContentDelegateForTesting(),
&associated_content);
EXPECT_EQ(conversation_with_content->GetAssociatedContentDelegateForTesting(),
conversation2->GetAssociatedContentDelegateForTesting());

// Check the second conversation is the default for that content ID
EXPECT_EQ(
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr()),
conversation2);
conversation2.get());
// Let the conversation be deleted
std::string conversation2_uuid = conversation2->get_conversation_uuid();
auto client1 = CreateConversationClient(conversation2);
auto client1 = CreateConversationClient(conversation2.get());
DisconnectConversationClient(client1.get());
ConversationHandler* conversation_with_content3 =
ai_chat_service_->GetOrCreateConversationHandlerForContent(
Expand Down Expand Up @@ -674,28 +691,30 @@ TEST_P(AIChatServiceUnitTest, OpenConversationWithStagedEntries_NoPermission) {

TEST_P(AIChatServiceUnitTest, OpenConversationWithStagedEntries) {
NiceMock<MockAssociatedContent> associated_content{};
ConversationHandler* conversation =
ai_chat_service_->CreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr());
auto conversation_client = CreateConversationClient(conversation);
ON_CALL(associated_content, HasOpenAIChatPermission)
.WillByDefault(testing::Return(true));

// Allowed scheme to be associated with a conversation
ON_CALL(associated_content, GetURL())
.WillByDefault(testing::Return(GURL("https://example.com")));
associated_content.SetContentId(1);

ON_CALL(associated_content, GetStagedEntriesFromContent)
.WillByDefault(
[](ConversationHandler::GetStagedEntriesCallback callback) {
std::move(callback).Run(std::vector<SearchQuerySummary>{
SearchQuerySummary("query", "summary")});
});
ON_CALL(associated_content, HasOpenAIChatPermission)
.WillByDefault(testing::Return(true));

// Allowed scheme to be associated with a conversation
ON_CALL(associated_content, GetURL())
.WillByDefault(testing::Return(GURL("https://example.com")));

// One from setting up a connected client, one from
// OpenConversationWithStagedEntries.
EXPECT_CALL(associated_content, GetStagedEntriesFromContent).Times(2);

ConversationHandler* conversation =
ai_chat_service_->CreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr());
auto conversation_client = CreateConversationClient(conversation);

bool opened = false;
ai_chat_service_->OpenConversationWithStagedEntries(
associated_content.GetWeakPtr(),
Expand Down Expand Up @@ -827,8 +846,10 @@ TEST_P(AIChatServiceUnitTest, DeleteAssociatedWebContent) {
}

// Archive content for conversations 2 and 3
data[1].associated_content.DisassociateWithConversations(page_content, false);
data[2].associated_content.DisassociateWithConversations(page_content, false);
data[1].associated_content.DisassociateWithConversations(
ai_chat_service_.get(), page_content, false);
data[2].associated_content.DisassociateWithConversations(
ai_chat_service_.get(), page_content, false);

// Delete associated content from conversations between 1 hours ago and 3
// hours ago.
Expand Down
10 changes: 9 additions & 1 deletion components/ai_chat/core/browser/associated_content_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "base/one_shot_event.h"
#include "base/strings/strcat.h"
#include "brave/brave_domains/service_domains.h"
#include "brave/components/ai_chat/core/browser/ai_chat_service.h"
#include "brave/components/ai_chat/core/browser/brave_search_responses.h"
#include "brave/components/ai_chat/core/browser/conversation_handler.h"
#include "brave/components/ai_chat/core/browser/utils.h"
Expand Down Expand Up @@ -63,8 +64,10 @@ GetSearchQuerySummaryNetworkTrafficAnnotationTag() {
} // namespace

AssociatedContentDriver::AssociatedContentDriver(
AIChatService* ai_chat_service,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory)
: url_loader_factory_(url_loader_factory) {}
: url_loader_factory_(url_loader_factory),
ai_chat_service_(ai_chat_service) {}

AssociatedContentDriver::~AssociatedContentDriver() {
DisassociateWithConversations();
Expand Down Expand Up @@ -303,6 +306,11 @@ void AssociatedContentDriver::DisassociateWithConversations() {
is_video_);
}
}
// We also notify the AIChatService directly for conversations that aren't
// directly related.
if (ai_chat_service_) {
ai_chat_service_->OnAssociatedContentDestroyed(current_navigation_id_);
}
}

} // namespace ai_chat
Loading
Loading