Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

keep an AI Chat conversation alive if it has chat messages and its associated content is still alive #26502

Merged
merged 3 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 38 additions & 27 deletions components/ai_chat/core/browser/ai_chat_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -630,36 +630,41 @@ void AIChatService::OnPremiumStatusReceived(GetPremiumStatusCallback callback,

void AIChatService::MaybeUnloadConversation(
ConversationHandler* conversation_handler) {
if (!conversation_handler->IsAnyClientConnected() &&
!conversation_handler->IsRequestInProgress()) {
// Can erase handler because no active UI
bool has_history = conversation_handler->HasAnyHistory();
auto uuid = conversation_handler->get_conversation_uuid();
conversation_observations_.RemoveObservation(conversation_handler);
conversation_handlers_.erase(uuid);
DVLOG(1) << "Unloaded conversation (" << uuid << ") from memory. Now have "
// Don't unload if there is active UI for the conversation
if (conversation_handler->IsAnyClientConnected()) {
return;
}

bool has_history = conversation_handler->HasAnyHistory();

// We can keep a conversation with history in memory until there is no active
// content.
// TODO(petemill): With the history feature enabled, we should unload (if
// there is no request in progress). However, we can only do this when
// GetOrCreateConversationHandlerForContent allows a callback so that it
// can provide an answer after loading the conversation content from storage.
if (conversation_handler->IsAssociatedContentAlive() && has_history) {
return;
}

auto uuid = conversation_handler->get_conversation_uuid();
conversation_observations_.RemoveObservation(conversation_handler);
conversation_handlers_.erase(uuid);
DVLOG(1) << "Unloaded conversation (" << uuid << ") from memory. Now have "
<< conversations_.size() << " Conversation metadata items and "
<< conversation_handlers_.size()
<< " ConversationHandler instances.";
if (!IsAIChatHistoryEnabled() || !has_history) {
// Can erase because no active UI and no history, so it's
// not a real / persistable conversation
conversations_.erase(uuid);
std::erase_if(content_conversations_,
[&uuid](const auto& kv) { return kv.second == uuid; });
DVLOG(1) << "Erased conversation (" << uuid << "). Now have "
<< conversations_.size() << " Conversation metadata items and "
<< conversation_handlers_.size()
<< " ConversationHandler instances.";
if (!IsAIChatHistoryEnabled() || !has_history) {
// Can erase because no active UI and no history, so it's
// not a real / persistable conversation
conversations_.erase(uuid);
std::erase_if(content_conversations_,
[&uuid](const auto& kv) { return kv.second == uuid; });
DVLOG(1) << "Erased conversation (" << uuid << "). Now have "
<< conversations_.size() << " Conversation metadata items and "
<< conversation_handlers_.size()
<< " ConversationHandler instances.";
OnConversationListChanged();
}
} else {
DVLOG(4) << "Not unloading conversation ("
<< conversation_handler->get_conversation_uuid()
<< ") from memory. Has active clients: "
<< (conversation_handler->IsAnyClientConnected() ? "yes" : "no")
<< " Request is in progress: "
<< (conversation_handler->IsRequestInProgress() ? "yes" : "no");
OnConversationListChanged();
}
}

Expand Down Expand Up @@ -821,6 +826,12 @@ void AIChatService::OnConversationTitleChanged(ConversationHandler* handler,
}
}

void AIChatService::OnAssociatedContentDestroyed(ConversationHandler* handler,
int content_id) {
content_conversations_.erase(content_id);
MaybeUnloadConversation(handler);
}

void AIChatService::GetVisibleConversations(
GetVisibleConversationsCallback callback) {
LoadConversationsLazy(base::BindOnce(
Expand Down
2 changes: 2 additions & 0 deletions components/ai_chat/core/browser/ai_chat_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class AIChatService : public KeyedService,
void OnClientConnectionChanged(ConversationHandler* handler) override;
void OnConversationTitleChanged(ConversationHandler* handler,
std::string title) override;
void OnAssociatedContentDestroyed(ConversationHandler* handler,
int content_id) override;

// Adds new conversation and returns the handler
ConversationHandler* CreateConversation();
Expand Down
83 changes: 65 additions & 18 deletions components/ai_chat/core/browser/ai_chat_service_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ class MockConversationHandlerClient : public mojom::ConversationUI {
explicit MockConversationHandlerClient(ConversationHandler* driver) {
driver->Bind(conversation_handler_remote_.BindNewPipeAndPassReceiver(),
conversation_ui_receiver_.BindNewPipeAndPassRemote());
conversation_handler_ = driver;
}

~MockConversationHandlerClient() override = default;
Expand All @@ -114,9 +113,6 @@ class MockConversationHandlerClient : public mojom::ConversationUI {
conversation_ui_receiver_.reset();
}

ConversationHandler* GetConversationHandler() {
return conversation_handler_;
}

MOCK_METHOD(void, OnConversationHistoryUpdate, (), (override));

Expand Down Expand Up @@ -147,7 +143,6 @@ class MockConversationHandlerClient : public mojom::ConversationUI {
private:
mojo::Receiver<mojom::ConversationUI> conversation_ui_receiver_{this};
mojo::Remote<mojom::ConversationHandler> conversation_handler_remote_;
raw_ptr<ConversationHandler, DanglingUntriaged> conversation_handler_;
};

class MockAssociatedContent
Expand Down Expand Up @@ -193,7 +188,12 @@ class MockAssociatedContent

void DisassociateWithConversations(std::string archived_text_content,
bool archived_is_video) {
std::vector<base::WeakPtr<ConversationHandler>> related_conversations;
for (auto& conversation : related_conversations_) {
related_conversations.push_back(conversation->GetWeakPtr());
}

for (auto& conversation : related_conversations) {
if (conversation) {
conversation->OnAssociatedContentDestroyed(archived_text_content,
archived_is_video);
Expand Down Expand Up @@ -428,6 +428,56 @@ TEST_P(AIChatServiceUnitTest, ConversationLifecycle_WithMessages) {
task_environment_.RunUntilIdle();
}

TEST_P(AIChatServiceUnitTest, ConversationLifecycle_WithContent) {
NiceMock<MockAssociatedContent> associated_content{};
ON_CALL(associated_content, GetURL())
.WillByDefault(testing::Return(GURL("https://example.com")));
associated_content.SetContentId(1);
ConversationHandler* conversation_with_content_no_messages =
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr());
EXPECT_TRUE(conversation_with_content_no_messages);
// Asking again for same content ID gets same conversation
EXPECT_EQ(
conversation_with_content_no_messages,
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr()));
// Shouldn't be visible without messages
ExpectVisibleConversationsSize(FROM_HERE, 0u);
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u);
// Disconnecting the client should unload the handler and delete the
// conversation.
auto client1 =
CreateConversationClient(conversation_with_content_no_messages);
DisconnectConversationClient(client1.get());
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 0u);
ExpectVisibleConversationsSize(FROM_HERE, 0u);

// Create a new conversation for same content, with messages this time
ConversationHandler* conversation_with_content =
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr());
conversation_with_content->SetChatHistoryForTesting(
CreateSampleChatHistory(1u));
ExpectVisibleConversationsSize(FROM_HERE, 1u);
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u);
auto client2 = CreateConversationClient(conversation_with_content);
DisconnectConversationClient(client2.get());
// Disconnecting all clients should keep the handler in memory until
// the content is destroyed.
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u);
ExpectVisibleConversationsSize(FROM_HERE, 1u);
associated_content.DisassociateWithConversations("", false);

if (IsAIChatHistoryEnabled()) {
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 0u);
ExpectVisibleConversationsSize(FROM_HERE, 1u);
} else {
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 0u);
ExpectVisibleConversationsSize(FROM_HERE, 0u);
}
}

TEST_P(AIChatServiceUnitTest, GetOrCreateConversationHandlerForContent) {
ConversationHandler* conversation_without_content = CreateConversation();

Expand Down Expand Up @@ -463,27 +513,24 @@ TEST_P(AIChatServiceUnitTest, GetOrCreateConversationHandlerForContent) {
// Creating a second conversation with the same associated content should
// make the second conversation the default for that content, but leave
// the first still associated with the content.
ConversationHandler* conversation_with_content2 =
ConversationHandler* conversation2 =
ai_chat_service_->CreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr());
EXPECT_NE(conversation_with_content, conversation_with_content2);
EXPECT_NE(conversation_with_content, conversation2);
EXPECT_NE(conversation_with_content->get_conversation_uuid(),
conversation_with_content2->get_conversation_uuid());
EXPECT_EQ(
conversation_with_content2->GetAssociatedContentDelegateForTesting(),
&associated_content);
EXPECT_EQ(
conversation_with_content->GetAssociatedContentDelegateForTesting(),
conversation_with_content2->GetAssociatedContentDelegateForTesting());
conversation2->get_conversation_uuid());
EXPECT_EQ(conversation2->GetAssociatedContentDelegateForTesting(),
&associated_content);
EXPECT_EQ(conversation_with_content->GetAssociatedContentDelegateForTesting(),
conversation2->GetAssociatedContentDelegateForTesting());
// Check the second conversation is the default for that content ID
EXPECT_EQ(
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr()),
conversation_with_content2);
conversation2);
// Let the conversation be deleted
std::string conversation2_uuid =
conversation_with_content2->get_conversation_uuid();
auto client1 = CreateConversationClient(conversation_with_content2);
std::string conversation2_uuid = conversation2->get_conversation_uuid();
auto client1 = CreateConversationClient(conversation2);
DisconnectConversationClient(client1.get());
ConversationHandler* conversation_with_content3 =
ai_chat_service_->GetOrCreateConversationHandlerForContent(
Expand Down
27 changes: 17 additions & 10 deletions components/ai_chat/core/browser/associated_content_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,7 @@ AssociatedContentDriver::AssociatedContentDriver(
: url_loader_factory_(url_loader_factory) {}

AssociatedContentDriver::~AssociatedContentDriver() {
for (auto& conversation : associated_conversations_) {
if (conversation) {
conversation->OnAssociatedContentDestroyed(cached_text_content_,
is_video_);
}
}
DisassociateWithConversations();
}

void AssociatedContentDriver::AddRelatedConversation(
Expand Down Expand Up @@ -279,10 +274,10 @@ void AssociatedContentDriver::OnFaviconImageDataChanged() {
}

void AssociatedContentDriver::OnNewPage(int64_t navigation_id) {
// Tell the associated_conversations_ that we're breaking up
for (auto& conversation : associated_conversations_) {
conversation->OnAssociatedContentDestroyed(cached_text_content_, is_video_);
}
// This instance will now be used for different content so existing
// conversations need to be disassociated.
DisassociateWithConversations();

// Tell the observer how to find the next conversation
for (auto& observer : observers_) {
observer.OnAssociatedContentNavigated(navigation_id);
Expand All @@ -298,4 +293,16 @@ void AssociatedContentDriver::OnNewPage(int64_t navigation_id) {
ConversationHandler::AssociatedContentDelegate::OnNewPage(navigation_id);
}

void AssociatedContentDriver::DisassociateWithConversations() {
// Iterator might be invalidated by destruction, so copy the items
std::vector<ConversationHandler*> conversations{
associated_conversations_.begin(), associated_conversations_.end()};
for (auto& conversation : conversations) {
if (conversation) {
conversation->OnAssociatedContentDestroyed(cached_text_content_,
is_video_);
}
}
}

} // namespace ai_chat
6 changes: 6 additions & 0 deletions components/ai_chat/core/browser/associated_content_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ class AssociatedContentDriver
ConversationHandler::GetStagedEntriesCallback callback,
int64_t navigation_id,
api_request_helper::APIRequestResult result);

// Let all conversations using this content know that the content
// has been destroyed or changed to represent different content (e.g. a
// navigation).
void DisassociateWithConversations();

static std::optional<std::vector<SearchQuerySummary>>
ParseSearchQuerySummaryResponse(const base::Value& value);

Expand Down
18 changes: 13 additions & 5 deletions components/ai_chat/core/browser/conversation_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,10 @@ bool ConversationHandler::IsRequestInProgress() {
return is_request_in_progress_;
}

bool ConversationHandler::IsAssociatedContentAlive() {
return associated_content_delegate_ && !archive_content_;
}

void ConversationHandler::OnConversationDeleted() {
for (auto& client : conversation_ui_handlers_) {
client->OnConversationDeleted();
Expand Down Expand Up @@ -382,11 +386,11 @@ void ConversationHandler::InitEngine() {
void ConversationHandler::OnAssociatedContentDestroyed(
std::string last_text_content,
bool is_video) {
// The associated content delegate is destroyed, so we should not try to
// fetch. It may be populated later, e.g. through back navigation.
// If this conversation is allowed to be associated with content, we can keep
// using our current cached content.
associated_content_delegate_ = nullptr;
// The associated content delegate is already or about to be destroyed.
auto content_id = associated_content_delegate_
? associated_content_delegate_->GetContentId()
: -1;
DisassociateContentDelegate();
if (!chat_history_.empty() && should_send_page_contents_ &&
metadata_->associated_content &&
metadata_->associated_content->is_content_association_possible) {
Expand All @@ -397,6 +401,10 @@ void ConversationHandler::OnAssociatedContentDestroyed(
SetArchiveContent(std::move(last_text_content), is_video);
}
OnAssociatedContentInfoChanged();
// Notify observers
for (auto& observer : observers_) {
observer.OnAssociatedContentDestroyed(this, content_id);
}
}

void ConversationHandler::SetArchiveContent(std::string text_content,
Expand Down
5 changes: 5 additions & 0 deletions components/ai_chat/core/browser/conversation_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ class ConversationHandler : public mojom::ConversationHandler,
virtual void OnSelectedLanguageChanged(
ConversationHandler* handler,
const std::string& selected_language) {}
virtual void OnAssociatedContentDestroyed(ConversationHandler* handler,
int content_id) {}
};

ConversationHandler(
Expand Down Expand Up @@ -207,6 +209,9 @@ class ConversationHandler : public mojom::ConversationHandler,
bool HasAnyHistory();
bool IsRequestInProgress();

// Returns true if the conversation has associated content that is non-archive
bool IsAssociatedContentAlive();

// Called when the associated content is destroyed or navigated away. If
// it's a navigation, the AssociatedContentDelegate will set itself to a new
// ConversationHandler.
Expand Down
Loading