Skip to content

Commit

Permalink
keep an AI Chat conversation alive if it has chat messages and its as…
Browse files Browse the repository at this point in the history
…sociated content is still alive (uplift to 1.74.x) (#26942)

Merge pull request #26502 from brave/ai-chat-keep-conversation-alive

keep an AI Chat conversation alive if it has chat messages and its associated content is still alive
  • Loading branch information
petemill authored Dec 10, 2024
1 parent 3f746b7 commit 0515869
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 36 deletions.
29 changes: 26 additions & 3 deletions components/ai_chat/core/browser/ai_chat_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,26 @@ void AIChatService::OnPremiumStatusReceived(GetPremiumStatusCallback callback,

void AIChatService::MaybeEraseConversation(
ConversationHandler* conversation_handler) {
if (!conversation_handler->IsAnyClientConnected() &&
(!features::IsAIChatHistoryEnabled() ||
!conversation_handler->HasAnyHistory())) {
// Don't unload if there is active UI for the conversation
if (conversation_handler->IsAnyClientConnected()) {
return;
}

bool has_history = conversation_handler->HasAnyHistory();

// We can keep a conversation with history in memory until there is no active
// content.
// TODO(petemill): With the history feature enabled, we should unload (if
// there is no request in progress). However, we can only do this when
// GetOrCreateConversationHandlerForContent allows a callback so that it
// can provide an answer after loading the conversation content from storage.
if (conversation_handler->IsAssociatedContentAlive() && has_history) {
return;
}

// AIChatHistory feature doesn't yet have persistant storage, so keep
// handlers and data around if it's enabled.
if (!features::IsAIChatHistoryEnabled() || !has_history) {
// Can erase because no active UI and no history, so it's
// not a real / persistable conversation
auto uuid = conversation_handler->get_conversation_uuid();
Expand Down Expand Up @@ -328,6 +345,12 @@ void AIChatService::OnConversationTitleChanged(ConversationHandler* handler,
OnConversationListChanged();
}

void AIChatService::OnAssociatedContentDestroyed(ConversationHandler* handler,
int content_id) {
content_conversations_.erase(content_id);
MaybeEraseConversation(handler);
}

void AIChatService::GetVisibleConversations(
GetVisibleConversationsCallback callback) {
std::vector<mojom::ConversationPtr> conversations;
Expand Down
2 changes: 2 additions & 0 deletions components/ai_chat/core/browser/ai_chat_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ class AIChatService : public KeyedService,
void OnClientConnectionChanged(ConversationHandler* handler) override;
void OnConversationTitleChanged(ConversationHandler* handler,
std::string title) override;
void OnAssociatedContentDestroyed(ConversationHandler* handler,
int content_id) override;

// Adds new conversation and returns the handler
ConversationHandler* CreateConversation();
Expand Down
103 changes: 85 additions & 18 deletions components/ai_chat/core/browser/ai_chat_service_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <cstddef>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <string_view>
#include <utility>
Expand Down Expand Up @@ -131,7 +132,6 @@ class MockConversationHandlerClient : public mojom::ConversationUI {
explicit MockConversationHandlerClient(ConversationHandler* driver) {
driver->Bind(conversation_handler_remote_.BindNewPipeAndPassReceiver(),
conversation_ui_receiver_.BindNewPipeAndPassRemote());
conversation_handler_ = driver;
}

~MockConversationHandlerClient() override = default;
Expand All @@ -141,9 +141,6 @@ class MockConversationHandlerClient : public mojom::ConversationUI {
conversation_ui_receiver_.reset();
}

ConversationHandler* GetConversationHandler() {
return conversation_handler_;
}

MOCK_METHOD(void, OnConversationHistoryUpdate, (), (override));

Expand Down Expand Up @@ -174,7 +171,6 @@ class MockConversationHandlerClient : public mojom::ConversationUI {
private:
mojo::Receiver<mojom::ConversationUI> conversation_ui_receiver_{this};
mojo::Remote<mojom::ConversationHandler> conversation_handler_remote_;
raw_ptr<ConversationHandler, DanglingUntriaged> conversation_handler_;
};

class MockAssociatedContent
Expand Down Expand Up @@ -203,6 +199,30 @@ class MockAssociatedContent
(override));
MOCK_METHOD(bool, HasOpenAIChatPermission, (), (const, override));

void AddRelatedConversation(ConversationHandler* conversation) override {
related_conversations_.insert(conversation);
}

void OnRelatedConversationDisassociated(
ConversationHandler* conversation) override {
related_conversations_.erase(conversation);
}

void DisassociateWithConversations(std::string archived_text_content,
bool archived_is_video) {
std::vector<base::WeakPtr<ConversationHandler>> related_conversations;
for (auto& conversation : related_conversations_) {
related_conversations.push_back(conversation->GetWeakPtr());
}

for (auto& conversation : related_conversations) {
if (conversation) {
conversation->OnAssociatedContentDestroyed(archived_text_content,
archived_is_video);
}
}
}

base::WeakPtr<ConversationHandler::AssociatedContentDelegate> GetWeakPtr() {
return weak_ptr_factory_.GetWeakPtr();
}
Expand All @@ -211,6 +231,7 @@ class MockAssociatedContent
base::WeakPtrFactory<ConversationHandler::AssociatedContentDelegate>
weak_ptr_factory_{this};
int content_id_ = 0;
std::set<raw_ptr<ConversationHandler>> related_conversations_;
};

} // namespace
Expand Down Expand Up @@ -404,6 +425,55 @@ TEST_P(AIChatServiceUnitTest, ConversationLifecycle_WithMessages) {
task_environment_.RunUntilIdle();
}

TEST_P(AIChatServiceUnitTest, ConversationLifecycle_WithContent) {
NiceMock<MockAssociatedContent> associated_content{};
ON_CALL(associated_content, GetURL())
.WillByDefault(testing::Return(GURL("https://example.com")));
associated_content.SetContentId(1);
ConversationHandler* conversation_with_content_no_messages =
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr());
EXPECT_TRUE(conversation_with_content_no_messages);
// Asking again for same content ID gets same conversation
EXPECT_EQ(
conversation_with_content_no_messages,
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr()));
// Shouldn't be visible without messages
ExpectVisibleConversationsSize(FROM_HERE, 0u);
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u);
// Disconnecting the client should unload the handler and delete the
// conversation.
auto client1 =
CreateConversationClient(conversation_with_content_no_messages);
DisconnectConversationClient(client1.get());
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 0u);
ExpectVisibleConversationsSize(FROM_HERE, 0u);

// Create a new conversation for same content, with messages this time
ConversationHandler* conversation_with_content =
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr());
conversation_with_content->SetChatHistoryForTesting(CreateSampleHistory());
ExpectVisibleConversationsSize(FROM_HERE, 1u);
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u);
auto client2 = CreateConversationClient(conversation_with_content);
DisconnectConversationClient(client2.get());
// Disconnecting all clients should keep the handler in memory until
// the content is destroyed.
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u);
ExpectVisibleConversationsSize(FROM_HERE, 1u);
associated_content.DisassociateWithConversations("", false);

if (IsAIChatHistoryEnabled()) {
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u);
ExpectVisibleConversationsSize(FROM_HERE, 1u);
} else {
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 0u);
ExpectVisibleConversationsSize(FROM_HERE, 0u);
}
}

TEST_P(AIChatServiceUnitTest, GetOrCreateConversationHandlerForContent) {
ConversationHandler* conversation_without_content = CreateConversation();

Expand Down Expand Up @@ -439,27 +509,24 @@ TEST_P(AIChatServiceUnitTest, GetOrCreateConversationHandlerForContent) {
// Creating a second conversation with the same associated content should
// make the second conversation the default for that content, but leave
// the first still associated with the content.
ConversationHandler* conversation_with_content2 =
ConversationHandler* conversation2 =
ai_chat_service_->CreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr());
EXPECT_NE(conversation_with_content, conversation_with_content2);
EXPECT_NE(conversation_with_content, conversation2);
EXPECT_NE(conversation_with_content->get_conversation_uuid(),
conversation_with_content2->get_conversation_uuid());
EXPECT_EQ(
conversation_with_content2->GetAssociatedContentDelegateForTesting(),
&associated_content);
EXPECT_EQ(
conversation_with_content->GetAssociatedContentDelegateForTesting(),
conversation_with_content2->GetAssociatedContentDelegateForTesting());
conversation2->get_conversation_uuid());
EXPECT_EQ(conversation2->GetAssociatedContentDelegateForTesting(),
&associated_content);
EXPECT_EQ(conversation_with_content->GetAssociatedContentDelegateForTesting(),
conversation2->GetAssociatedContentDelegateForTesting());
// Check the second conversation is the default for that content ID
EXPECT_EQ(
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr()),
conversation_with_content2);
conversation2);
// Let the conversation be deleted
std::string conversation2_uuid =
conversation_with_content2->get_conversation_uuid();
auto client1 = CreateConversationClient(conversation_with_content2);
std::string conversation2_uuid = conversation2->get_conversation_uuid();
auto client1 = CreateConversationClient(conversation2);
DisconnectConversationClient(client1.get());
ConversationHandler* conversation_with_content3 =
ai_chat_service_->GetOrCreateConversationHandlerForContent(
Expand Down
27 changes: 17 additions & 10 deletions components/ai_chat/core/browser/associated_content_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,7 @@ AssociatedContentDriver::AssociatedContentDriver(
: url_loader_factory_(url_loader_factory) {}

AssociatedContentDriver::~AssociatedContentDriver() {
for (auto& conversation : associated_conversations_) {
if (conversation) {
conversation->OnAssociatedContentDestroyed(cached_text_content_,
is_video_);
}
}
DisassociateWithConversations();
}

void AssociatedContentDriver::AddRelatedConversation(
Expand Down Expand Up @@ -273,10 +268,10 @@ void AssociatedContentDriver::OnFaviconImageDataChanged() {
}

void AssociatedContentDriver::OnNewPage(int64_t navigation_id) {
// Tell the associated_conversations_ that we're breaking up
for (auto& conversation : associated_conversations_) {
conversation->OnAssociatedContentDestroyed(cached_text_content_, is_video_);
}
// This instance will now be used for different content so existing
// conversations need to be disassociated.
DisassociateWithConversations();

// Tell the observer how to find the next conversation
for (auto& observer : observers_) {
observer.OnAssociatedContentNavigated(navigation_id);
Expand All @@ -292,4 +287,16 @@ void AssociatedContentDriver::OnNewPage(int64_t navigation_id) {
ConversationHandler::AssociatedContentDelegate::OnNewPage(navigation_id);
}

void AssociatedContentDriver::DisassociateWithConversations() {
// Iterator might be invalidated by destruction, so copy the items
std::vector<ConversationHandler*> conversations{
associated_conversations_.begin(), associated_conversations_.end()};
for (auto& conversation : conversations) {
if (conversation) {
conversation->OnAssociatedContentDestroyed(cached_text_content_,
is_video_);
}
}
}

} // namespace ai_chat
6 changes: 6 additions & 0 deletions components/ai_chat/core/browser/associated_content_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ class AssociatedContentDriver
ConversationHandler::GetStagedEntriesCallback callback,
int64_t navigation_id,
api_request_helper::APIRequestResult result);

// Let all conversations using this content know that the content
// has been destroyed or changed to represent different content (e.g. a
// navigation).
void DisassociateWithConversations();

static std::optional<std::vector<SearchQuerySummary>>
ParseSearchQuerySummaryResponse(const base::Value& value);

Expand Down
18 changes: 13 additions & 5 deletions components/ai_chat/core/browser/conversation_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ bool ConversationHandler::HasAnyHistory() {
});
}

bool ConversationHandler::IsAssociatedContentAlive() {
return associated_content_delegate_ && !archive_content_;
}

void ConversationHandler::OnConversationDeleted() {
for (auto& client : conversation_ui_handlers_) {
client->OnConversationDeleted();
Expand Down Expand Up @@ -256,11 +260,11 @@ void ConversationHandler::InitEngine() {
void ConversationHandler::OnAssociatedContentDestroyed(
std::string last_text_content,
bool is_video) {
// The associated content delegate is destroyed, so we should not try to
// fetch. It may be populated later, e.g. through back navigation.
// If this conversation is allowed to be associated with content, we can keep
// using our current cached content.
associated_content_delegate_ = nullptr;
// The associated content delegate is already or about to be destroyed.
auto content_id = associated_content_delegate_
? associated_content_delegate_->GetContentId()
: -1;
DisassociateContentDelegate();
if (!chat_history_.empty() && should_send_page_contents_ &&
associated_content_info_ && associated_content_info_->url.has_value()) {
// Get the latest version of article text and
Expand All @@ -276,6 +280,10 @@ void ConversationHandler::OnAssociatedContentDestroyed(
archive_content_ = std::move(archive_content);
}
OnAssociatedContentInfoChanged();
// Notify observers
for (auto& observer : observers_) {
observer.OnAssociatedContentDestroyed(this, content_id);
}
}

void ConversationHandler::SetAssociatedContentDelegate(
Expand Down
5 changes: 5 additions & 0 deletions components/ai_chat/core/browser/conversation_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ class ConversationHandler : public mojom::ConversationHandler,
virtual void OnSelectedLanguageChanged(
ConversationHandler* handler,
const std::string& selected_language) {}
virtual void OnAssociatedContentDestroyed(ConversationHandler* handler,
int content_id) {}
};

ConversationHandler(
Expand All @@ -167,6 +169,9 @@ class ConversationHandler : public mojom::ConversationHandler,
bool HasAnyHistory();
void OnConversationDeleted();

// Returns true if the conversation has associated content that is non-archive
bool IsAssociatedContentAlive();

// Called when the associated content is destroyed or navigated away. If
// it's a navigation, the AssociatedContentDelegate will set itself to a new
// ConversationHandler.
Expand Down

0 comments on commit 0515869

Please sign in to comment.