Skip to content

Commit

Permalink
Merge pull request #26502 from brave/ai-chat-keep-conversation-alive
Browse files Browse the repository at this point in the history
keep an AI Chat conversation alive if it has chat messages and its associated content is still alive
  • Loading branch information
petemill authored Dec 9, 2024
2 parents 054900a + 57499b0 commit ef452b7
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 60 deletions.
65 changes: 38 additions & 27 deletions components/ai_chat/core/browser/ai_chat_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -630,36 +630,41 @@ void AIChatService::OnPremiumStatusReceived(GetPremiumStatusCallback callback,

void AIChatService::MaybeUnloadConversation(
ConversationHandler* conversation_handler) {
if (!conversation_handler->IsAnyClientConnected() &&
!conversation_handler->IsRequestInProgress()) {
// Can erase handler because no active UI
bool has_history = conversation_handler->HasAnyHistory();
auto uuid = conversation_handler->get_conversation_uuid();
conversation_observations_.RemoveObservation(conversation_handler);
conversation_handlers_.erase(uuid);
DVLOG(1) << "Unloaded conversation (" << uuid << ") from memory. Now have "
// Don't unload if there is active UI for the conversation
if (conversation_handler->IsAnyClientConnected()) {
return;
}

bool has_history = conversation_handler->HasAnyHistory();

// We can keep a conversation with history in memory until there is no active
// content.
// TODO(petemill): With the history feature enabled, we should unload (if
// there is no request in progress). However, we can only do this when
// GetOrCreateConversationHandlerForContent allows a callback so that it
// can provide an answer after loading the conversation content from storage.
if (conversation_handler->IsAssociatedContentAlive() && has_history) {
return;
}

auto uuid = conversation_handler->get_conversation_uuid();
conversation_observations_.RemoveObservation(conversation_handler);
conversation_handlers_.erase(uuid);
DVLOG(1) << "Unloaded conversation (" << uuid << ") from memory. Now have "
<< conversations_.size() << " Conversation metadata items and "
<< conversation_handlers_.size()
<< " ConversationHandler instances.";
if (!IsAIChatHistoryEnabled() || !has_history) {
// Can erase because no active UI and no history, so it's
// not a real / persistable conversation
conversations_.erase(uuid);
std::erase_if(content_conversations_,
[&uuid](const auto& kv) { return kv.second == uuid; });
DVLOG(1) << "Erased conversation (" << uuid << "). Now have "
<< conversations_.size() << " Conversation metadata items and "
<< conversation_handlers_.size()
<< " ConversationHandler instances.";
if (!IsAIChatHistoryEnabled() || !has_history) {
// Can erase because no active UI and no history, so it's
// not a real / persistable conversation
conversations_.erase(uuid);
std::erase_if(content_conversations_,
[&uuid](const auto& kv) { return kv.second == uuid; });
DVLOG(1) << "Erased conversation (" << uuid << "). Now have "
<< conversations_.size() << " Conversation metadata items and "
<< conversation_handlers_.size()
<< " ConversationHandler instances.";
OnConversationListChanged();
}
} else {
DVLOG(4) << "Not unloading conversation ("
<< conversation_handler->get_conversation_uuid()
<< ") from memory. Has active clients: "
<< (conversation_handler->IsAnyClientConnected() ? "yes" : "no")
<< " Request is in progress: "
<< (conversation_handler->IsRequestInProgress() ? "yes" : "no");
OnConversationListChanged();
}
}

Expand Down Expand Up @@ -821,6 +826,12 @@ void AIChatService::OnConversationTitleChanged(ConversationHandler* handler,
}
}

void AIChatService::OnAssociatedContentDestroyed(ConversationHandler* handler,
int content_id) {
content_conversations_.erase(content_id);
MaybeUnloadConversation(handler);
}

void AIChatService::GetVisibleConversations(
GetVisibleConversationsCallback callback) {
LoadConversationsLazy(base::BindOnce(
Expand Down
2 changes: 2 additions & 0 deletions components/ai_chat/core/browser/ai_chat_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class AIChatService : public KeyedService,
void OnClientConnectionChanged(ConversationHandler* handler) override;
void OnConversationTitleChanged(ConversationHandler* handler,
std::string title) override;
void OnAssociatedContentDestroyed(ConversationHandler* handler,
int content_id) override;

// Adds new conversation and returns the handler
ConversationHandler* CreateConversation();
Expand Down
83 changes: 65 additions & 18 deletions components/ai_chat/core/browser/ai_chat_service_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ class MockConversationHandlerClient : public mojom::ConversationUI {
explicit MockConversationHandlerClient(ConversationHandler* driver) {
driver->Bind(conversation_handler_remote_.BindNewPipeAndPassReceiver(),
conversation_ui_receiver_.BindNewPipeAndPassRemote());
conversation_handler_ = driver;
}

~MockConversationHandlerClient() override = default;
Expand All @@ -114,9 +113,6 @@ class MockConversationHandlerClient : public mojom::ConversationUI {
conversation_ui_receiver_.reset();
}

ConversationHandler* GetConversationHandler() {
return conversation_handler_;
}

MOCK_METHOD(void, OnConversationHistoryUpdate, (), (override));

Expand Down Expand Up @@ -147,7 +143,6 @@ class MockConversationHandlerClient : public mojom::ConversationUI {
private:
mojo::Receiver<mojom::ConversationUI> conversation_ui_receiver_{this};
mojo::Remote<mojom::ConversationHandler> conversation_handler_remote_;
raw_ptr<ConversationHandler, DanglingUntriaged> conversation_handler_;
};

class MockAssociatedContent
Expand Down Expand Up @@ -193,7 +188,12 @@ class MockAssociatedContent

void DisassociateWithConversations(std::string archived_text_content,
bool archived_is_video) {
std::vector<base::WeakPtr<ConversationHandler>> related_conversations;
for (auto& conversation : related_conversations_) {
related_conversations.push_back(conversation->GetWeakPtr());
}

for (auto& conversation : related_conversations) {
if (conversation) {
conversation->OnAssociatedContentDestroyed(archived_text_content,
archived_is_video);
Expand Down Expand Up @@ -428,6 +428,56 @@ TEST_P(AIChatServiceUnitTest, ConversationLifecycle_WithMessages) {
task_environment_.RunUntilIdle();
}

TEST_P(AIChatServiceUnitTest, ConversationLifecycle_WithContent) {
NiceMock<MockAssociatedContent> associated_content{};
ON_CALL(associated_content, GetURL())
.WillByDefault(testing::Return(GURL("https://example.com")));
associated_content.SetContentId(1);
ConversationHandler* conversation_with_content_no_messages =
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr());
EXPECT_TRUE(conversation_with_content_no_messages);
// Asking again for same content ID gets same conversation
EXPECT_EQ(
conversation_with_content_no_messages,
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr()));
// Shouldn't be visible without messages
ExpectVisibleConversationsSize(FROM_HERE, 0u);
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u);
// Disconnecting the client should unload the handler and delete the
// conversation.
auto client1 =
CreateConversationClient(conversation_with_content_no_messages);
DisconnectConversationClient(client1.get());
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 0u);
ExpectVisibleConversationsSize(FROM_HERE, 0u);

// Create a new conversation for same content, with messages this time
ConversationHandler* conversation_with_content =
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr());
conversation_with_content->SetChatHistoryForTesting(
CreateSampleChatHistory(1u));
ExpectVisibleConversationsSize(FROM_HERE, 1u);
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u);
auto client2 = CreateConversationClient(conversation_with_content);
DisconnectConversationClient(client2.get());
// Disconnecting all clients should keep the handler in memory until
// the content is destroyed.
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u);
ExpectVisibleConversationsSize(FROM_HERE, 1u);
associated_content.DisassociateWithConversations("", false);

if (IsAIChatHistoryEnabled()) {
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 0u);
ExpectVisibleConversationsSize(FROM_HERE, 1u);
} else {
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 0u);
ExpectVisibleConversationsSize(FROM_HERE, 0u);
}
}

TEST_P(AIChatServiceUnitTest, GetOrCreateConversationHandlerForContent) {
ConversationHandler* conversation_without_content = CreateConversation();

Expand Down Expand Up @@ -463,27 +513,24 @@ TEST_P(AIChatServiceUnitTest, GetOrCreateConversationHandlerForContent) {
// Creating a second conversation with the same associated content should
// make the second conversation the default for that content, but leave
// the first still associated with the content.
ConversationHandler* conversation_with_content2 =
ConversationHandler* conversation2 =
ai_chat_service_->CreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr());
EXPECT_NE(conversation_with_content, conversation_with_content2);
EXPECT_NE(conversation_with_content, conversation2);
EXPECT_NE(conversation_with_content->get_conversation_uuid(),
conversation_with_content2->get_conversation_uuid());
EXPECT_EQ(
conversation_with_content2->GetAssociatedContentDelegateForTesting(),
&associated_content);
EXPECT_EQ(
conversation_with_content->GetAssociatedContentDelegateForTesting(),
conversation_with_content2->GetAssociatedContentDelegateForTesting());
conversation2->get_conversation_uuid());
EXPECT_EQ(conversation2->GetAssociatedContentDelegateForTesting(),
&associated_content);
EXPECT_EQ(conversation_with_content->GetAssociatedContentDelegateForTesting(),
conversation2->GetAssociatedContentDelegateForTesting());
// Check the second conversation is the default for that content ID
EXPECT_EQ(
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr()),
conversation_with_content2);
conversation2);
// Let the conversation be deleted
std::string conversation2_uuid =
conversation_with_content2->get_conversation_uuid();
auto client1 = CreateConversationClient(conversation_with_content2);
std::string conversation2_uuid = conversation2->get_conversation_uuid();
auto client1 = CreateConversationClient(conversation2);
DisconnectConversationClient(client1.get());
ConversationHandler* conversation_with_content3 =
ai_chat_service_->GetOrCreateConversationHandlerForContent(
Expand Down
27 changes: 17 additions & 10 deletions components/ai_chat/core/browser/associated_content_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,7 @@ AssociatedContentDriver::AssociatedContentDriver(
: url_loader_factory_(url_loader_factory) {}

AssociatedContentDriver::~AssociatedContentDriver() {
for (auto& conversation : associated_conversations_) {
if (conversation) {
conversation->OnAssociatedContentDestroyed(cached_text_content_,
is_video_);
}
}
DisassociateWithConversations();
}

void AssociatedContentDriver::AddRelatedConversation(
Expand Down Expand Up @@ -279,10 +274,10 @@ void AssociatedContentDriver::OnFaviconImageDataChanged() {
}

void AssociatedContentDriver::OnNewPage(int64_t navigation_id) {
// Tell the associated_conversations_ that we're breaking up
for (auto& conversation : associated_conversations_) {
conversation->OnAssociatedContentDestroyed(cached_text_content_, is_video_);
}
// This instance will now be used for different content so existing
// conversations need to be disassociated.
DisassociateWithConversations();

// Tell the observer how to find the next conversation
for (auto& observer : observers_) {
observer.OnAssociatedContentNavigated(navigation_id);
Expand All @@ -298,4 +293,16 @@ void AssociatedContentDriver::OnNewPage(int64_t navigation_id) {
ConversationHandler::AssociatedContentDelegate::OnNewPage(navigation_id);
}

void AssociatedContentDriver::DisassociateWithConversations() {
// Iterator might be invalidated by destruction, so copy the items
std::vector<ConversationHandler*> conversations{
associated_conversations_.begin(), associated_conversations_.end()};
for (auto& conversation : conversations) {
if (conversation) {
conversation->OnAssociatedContentDestroyed(cached_text_content_,
is_video_);
}
}
}

} // namespace ai_chat
6 changes: 6 additions & 0 deletions components/ai_chat/core/browser/associated_content_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ class AssociatedContentDriver
ConversationHandler::GetStagedEntriesCallback callback,
int64_t navigation_id,
api_request_helper::APIRequestResult result);

// Let all conversations using this content know that the content
// has been destroyed or changed to represent different content (e.g. a
// navigation).
void DisassociateWithConversations();

static std::optional<std::vector<SearchQuerySummary>>
ParseSearchQuerySummaryResponse(const base::Value& value);

Expand Down
18 changes: 13 additions & 5 deletions components/ai_chat/core/browser/conversation_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,10 @@ bool ConversationHandler::IsRequestInProgress() {
return is_request_in_progress_;
}

bool ConversationHandler::IsAssociatedContentAlive() {
return associated_content_delegate_ && !archive_content_;
}

void ConversationHandler::OnConversationDeleted() {
for (auto& client : conversation_ui_handlers_) {
client->OnConversationDeleted();
Expand Down Expand Up @@ -382,11 +386,11 @@ void ConversationHandler::InitEngine() {
void ConversationHandler::OnAssociatedContentDestroyed(
std::string last_text_content,
bool is_video) {
// The associated content delegate is destroyed, so we should not try to
// fetch. It may be populated later, e.g. through back navigation.
// If this conversation is allowed to be associated with content, we can keep
// using our current cached content.
associated_content_delegate_ = nullptr;
// The associated content delegate is already or about to be destroyed.
auto content_id = associated_content_delegate_
? associated_content_delegate_->GetContentId()
: -1;
DisassociateContentDelegate();
if (!chat_history_.empty() && should_send_page_contents_ &&
metadata_->associated_content &&
metadata_->associated_content->is_content_association_possible) {
Expand All @@ -397,6 +401,10 @@ void ConversationHandler::OnAssociatedContentDestroyed(
SetArchiveContent(std::move(last_text_content), is_video);
}
OnAssociatedContentInfoChanged();
// Notify observers
for (auto& observer : observers_) {
observer.OnAssociatedContentDestroyed(this, content_id);
}
}

void ConversationHandler::SetArchiveContent(std::string text_content,
Expand Down
5 changes: 5 additions & 0 deletions components/ai_chat/core/browser/conversation_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ class ConversationHandler : public mojom::ConversationHandler,
virtual void OnSelectedLanguageChanged(
ConversationHandler* handler,
const std::string& selected_language) {}
virtual void OnAssociatedContentDestroyed(ConversationHandler* handler,
int content_id) {}
};

ConversationHandler(
Expand Down Expand Up @@ -207,6 +209,9 @@ class ConversationHandler : public mojom::ConversationHandler,
bool HasAnyHistory();
bool IsRequestInProgress();

// Returns true if the conversation has associated content that is non-archive
bool IsAssociatedContentAlive();

// Called when the associated content is destroyed or navigated away. If
// it's a navigation, the AssociatedContentDelegate will set itself to a new
// ConversationHandler.
Expand Down

0 comments on commit ef452b7

Please sign in to comment.