Skip to content

Commit

Permalink
Merge pull request #26502 from brave/ai-chat-keep-conversation-alive
Browse files Browse the repository at this point in the history
keep an AI Chat conversation alive if it has chat messages and its associated content is still alive
  • Loading branch information
petemill committed Dec 9, 2024
1 parent 7228742 commit 465300d
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 38 deletions.
33 changes: 28 additions & 5 deletions components/ai_chat/core/browser/ai_chat_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -268,18 +268,35 @@ void AIChatService::OnPremiumStatusReceived(GetPremiumStatusCallback callback,
}

void AIChatService::MaybeEraseConversation(
ConversationHandler* conversation_handler) {
if (!conversation_handler->IsAnyClientConnected() &&
(!features::IsAIChatHistoryEnabled() ||
!conversation_handler->HasAnyHistory())) {
ConversationHandler *conversation_handler) {
// Don't unload if there is active UI for the conversation
if (conversation_handler->IsAnyClientConnected()) {
return;
}

bool has_history = conversation_handler->HasAnyHistory();

// We can keep a conversation with history in memory until there is no active
// content.
// TODO(petemill): With the history feature enabled, we should unload (if
// there is no request in progress). However, we can only do this when
// GetOrCreateConversationHandlerForContent allows a callback so that it
// can provide an answer after loading the conversation content from storage.
if (conversation_handler->IsAssociatedContentAlive() && has_history) {
return;
}

// AIChatHistory feature doesn't yet have persistant storage, so keep
// handlers and data around if it's enabled.
if (!features::IsAIChatHistoryEnabled() || has_history) {
// Can erase because no active UI and no history, so it's
// not a real / persistable conversation
auto uuid = conversation_handler->get_conversation_uuid();
conversation_observations_.RemoveObservation(conversation_handler);
conversation_handlers_.erase(uuid);
conversations_.erase(uuid);
std::erase_if(content_conversations_,
[&uuid](const auto& kv) { return kv.second == uuid; });
[&uuid](const auto &kv) { return kv.second == uuid; });
DVLOG(1) << "Erased conversation (" << uuid << "). Now have "
<< conversations_.size() << " Conversation metadata items and "
<< conversation_handlers_.size()
Expand Down Expand Up @@ -328,6 +345,12 @@ void AIChatService::OnConversationTitleChanged(ConversationHandler* handler,
OnConversationListChanged();
}

void AIChatService::OnAssociatedContentDestroyed(ConversationHandler* handler,
int content_id) {
content_conversations_.erase(content_id);
MaybeEraseConversation(handler);
}

void AIChatService::GetVisibleConversations(
GetVisibleConversationsCallback callback) {
std::vector<mojom::ConversationPtr> conversations;
Expand Down
2 changes: 2 additions & 0 deletions components/ai_chat/core/browser/ai_chat_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ class AIChatService : public KeyedService,
void OnClientConnectionChanged(ConversationHandler* handler) override;
void OnConversationTitleChanged(ConversationHandler* handler,
std::string title) override;
void OnAssociatedContentDestroyed(ConversationHandler* handler,
int content_id) override;

// Adds new conversation and returns the handler
ConversationHandler* CreateConversation();
Expand Down
103 changes: 85 additions & 18 deletions components/ai_chat/core/browser/ai_chat_service_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ class MockConversationHandlerClient : public mojom::ConversationUI {
explicit MockConversationHandlerClient(ConversationHandler* driver) {
driver->Bind(conversation_handler_remote_.BindNewPipeAndPassReceiver(),
conversation_ui_receiver_.BindNewPipeAndPassRemote());
conversation_handler_ = driver;
}

~MockConversationHandlerClient() override = default;
Expand All @@ -141,9 +140,6 @@ class MockConversationHandlerClient : public mojom::ConversationUI {
conversation_ui_receiver_.reset();
}

ConversationHandler* GetConversationHandler() {
return conversation_handler_;
}

MOCK_METHOD(void, OnConversationHistoryUpdate, (), (override));

Expand Down Expand Up @@ -174,7 +170,6 @@ class MockConversationHandlerClient : public mojom::ConversationUI {
private:
mojo::Receiver<mojom::ConversationUI> conversation_ui_receiver_{this};
mojo::Remote<mojom::ConversationHandler> conversation_handler_remote_;
raw_ptr<ConversationHandler, DanglingUntriaged> conversation_handler_;
};

class MockAssociatedContent
Expand Down Expand Up @@ -203,6 +198,30 @@ class MockAssociatedContent
(override));
MOCK_METHOD(bool, HasOpenAIChatPermission, (), (const, override));

void AddRelatedConversation(ConversationHandler* conversation) override {
related_conversations_.insert(conversation);
}

void OnRelatedConversationDisassociated(
ConversationHandler* conversation) override {
related_conversations_.erase(conversation);
}

void DisassociateWithConversations(std::string archived_text_content,
bool archived_is_video) {
std::vector<base::WeakPtr<ConversationHandler>> related_conversations;
for (auto& conversation : related_conversations_) {
related_conversations.push_back(conversation->GetWeakPtr());
}

for (auto& conversation : related_conversations) {
if (conversation) {
conversation->OnAssociatedContentDestroyed(archived_text_content,
archived_is_video);
}
}
}

base::WeakPtr<ConversationHandler::AssociatedContentDelegate> GetWeakPtr() {
return weak_ptr_factory_.GetWeakPtr();
}
Expand All @@ -211,6 +230,7 @@ class MockAssociatedContent
base::WeakPtrFactory<ConversationHandler::AssociatedContentDelegate>
weak_ptr_factory_{this};
int content_id_ = 0;
std::set<raw_ptr<ConversationHandler>> related_conversations_;
};

} // namespace
Expand Down Expand Up @@ -404,6 +424,56 @@ TEST_P(AIChatServiceUnitTest, ConversationLifecycle_WithMessages) {
task_environment_.RunUntilIdle();
}

TEST_P(AIChatServiceUnitTest, ConversationLifecycle_WithContent) {
NiceMock<MockAssociatedContent> associated_content{};
ON_CALL(associated_content, GetURL())
.WillByDefault(testing::Return(GURL("https://example.com")));
associated_content.SetContentId(1);
ConversationHandler* conversation_with_content_no_messages =
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr());
EXPECT_TRUE(conversation_with_content_no_messages);
// Asking again for same content ID gets same conversation
EXPECT_EQ(
conversation_with_content_no_messages,
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr()));
// Shouldn't be visible without messages
ExpectVisibleConversationsSize(FROM_HERE, 0u);
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u);
// Disconnecting the client should unload the handler and delete the
// conversation.
auto client1 =
CreateConversationClient(conversation_with_content_no_messages);
DisconnectConversationClient(client1.get());
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 0u);
ExpectVisibleConversationsSize(FROM_HERE, 0u);

// Create a new conversation for same content, with messages this time
ConversationHandler* conversation_with_content =
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr());
conversation_with_content->SetChatHistoryForTesting(
CreateSampleChatHistory(1u));
ExpectVisibleConversationsSize(FROM_HERE, 1u);
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u);
auto client2 = CreateConversationClient(conversation_with_content);
DisconnectConversationClient(client2.get());
// Disconnecting all clients should keep the handler in memory until
// the content is destroyed.
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u);
ExpectVisibleConversationsSize(FROM_HERE, 1u);
associated_content.DisassociateWithConversations("", false);

if (IsAIChatHistoryEnabled()) {
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u);
ExpectVisibleConversationsSize(FROM_HERE, 1u);
} else {
EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 0u);
ExpectVisibleConversationsSize(FROM_HERE, 0u);
}
}

TEST_P(AIChatServiceUnitTest, GetOrCreateConversationHandlerForContent) {
ConversationHandler* conversation_without_content = CreateConversation();

Expand Down Expand Up @@ -439,27 +509,24 @@ TEST_P(AIChatServiceUnitTest, GetOrCreateConversationHandlerForContent) {
// Creating a second conversation with the same associated content should
// make the second conversation the default for that content, but leave
// the first still associated with the content.
ConversationHandler* conversation_with_content2 =
ConversationHandler* conversation2 =
ai_chat_service_->CreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr());
EXPECT_NE(conversation_with_content, conversation_with_content2);
EXPECT_NE(conversation_with_content, conversation2);
EXPECT_NE(conversation_with_content->get_conversation_uuid(),
conversation_with_content2->get_conversation_uuid());
EXPECT_EQ(
conversation_with_content2->GetAssociatedContentDelegateForTesting(),
&associated_content);
EXPECT_EQ(
conversation_with_content->GetAssociatedContentDelegateForTesting(),
conversation_with_content2->GetAssociatedContentDelegateForTesting());
conversation2->get_conversation_uuid());
EXPECT_EQ(conversation2->GetAssociatedContentDelegateForTesting(),
&associated_content);
EXPECT_EQ(conversation_with_content->GetAssociatedContentDelegateForTesting(),
conversation2->GetAssociatedContentDelegateForTesting());
// Check the second conversation is the default for that content ID
EXPECT_EQ(
ai_chat_service_->GetOrCreateConversationHandlerForContent(
associated_content.GetContentId(), associated_content.GetWeakPtr()),
conversation_with_content2);
conversation2);
// Let the conversation be deleted
std::string conversation2_uuid =
conversation_with_content2->get_conversation_uuid();
auto client1 = CreateConversationClient(conversation_with_content2);
std::string conversation2_uuid = conversation2->get_conversation_uuid();
auto client1 = CreateConversationClient(conversation2);
DisconnectConversationClient(client1.get());
ConversationHandler* conversation_with_content3 =
ai_chat_service_->GetOrCreateConversationHandlerForContent(
Expand Down
27 changes: 17 additions & 10 deletions components/ai_chat/core/browser/associated_content_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,7 @@ AssociatedContentDriver::AssociatedContentDriver(
: url_loader_factory_(url_loader_factory) {}

AssociatedContentDriver::~AssociatedContentDriver() {
for (auto& conversation : associated_conversations_) {
if (conversation) {
conversation->OnAssociatedContentDestroyed(cached_text_content_,
is_video_);
}
}
DisassociateWithConversations();
}

void AssociatedContentDriver::AddRelatedConversation(
Expand Down Expand Up @@ -273,10 +268,10 @@ void AssociatedContentDriver::OnFaviconImageDataChanged() {
}

void AssociatedContentDriver::OnNewPage(int64_t navigation_id) {
// Tell the associated_conversations_ that we're breaking up
for (auto& conversation : associated_conversations_) {
conversation->OnAssociatedContentDestroyed(cached_text_content_, is_video_);
}
// This instance will now be used for different content so existing
// conversations need to be disassociated.
DisassociateWithConversations();

// Tell the observer how to find the next conversation
for (auto& observer : observers_) {
observer.OnAssociatedContentNavigated(navigation_id);
Expand All @@ -292,4 +287,16 @@ void AssociatedContentDriver::OnNewPage(int64_t navigation_id) {
ConversationHandler::AssociatedContentDelegate::OnNewPage(navigation_id);
}

void AssociatedContentDriver::DisassociateWithConversations() {
// Iterator might be invalidated by destruction, so copy the items
std::vector<ConversationHandler*> conversations{
associated_conversations_.begin(), associated_conversations_.end()};
for (auto& conversation : conversations) {
if (conversation) {
conversation->OnAssociatedContentDestroyed(cached_text_content_,
is_video_);
}
}
}

} // namespace ai_chat
6 changes: 6 additions & 0 deletions components/ai_chat/core/browser/associated_content_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ class AssociatedContentDriver
ConversationHandler::GetStagedEntriesCallback callback,
int64_t navigation_id,
api_request_helper::APIRequestResult result);

// Let all conversations using this content know that the content
// has been destroyed or changed to represent different content (e.g. a
// navigation).
void DisassociateWithConversations();

static std::optional<std::vector<SearchQuerySummary>>
ParseSearchQuerySummaryResponse(const base::Value& value);

Expand Down
18 changes: 13 additions & 5 deletions components/ai_chat/core/browser/conversation_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ bool ConversationHandler::HasAnyHistory() {
});
}

bool ConversationHandler::IsAssociatedContentAlive() {
return associated_content_delegate_ && !archive_content_;
}

void ConversationHandler::OnConversationDeleted() {
for (auto& client : conversation_ui_handlers_) {
client->OnConversationDeleted();
Expand Down Expand Up @@ -256,11 +260,11 @@ void ConversationHandler::InitEngine() {
void ConversationHandler::OnAssociatedContentDestroyed(
std::string last_text_content,
bool is_video) {
// The associated content delegate is destroyed, so we should not try to
// fetch. It may be populated later, e.g. through back navigation.
// If this conversation is allowed to be associated with content, we can keep
// using our current cached content.
associated_content_delegate_ = nullptr;
// The associated content delegate is already or about to be destroyed.
auto content_id = associated_content_delegate_
? associated_content_delegate_->GetContentId()
: -1;
DisassociateContentDelegate();
if (!chat_history_.empty() && should_send_page_contents_ &&
associated_content_info_ && associated_content_info_->url.has_value()) {
// Get the latest version of article text and
Expand All @@ -276,6 +280,10 @@ void ConversationHandler::OnAssociatedContentDestroyed(
archive_content_ = std::move(archive_content);
}
OnAssociatedContentInfoChanged();
// Notify observers
for (auto& observer : observers_) {
observer.OnAssociatedContentDestroyed(this, content_id);
}
}

void ConversationHandler::SetAssociatedContentDelegate(
Expand Down
5 changes: 5 additions & 0 deletions components/ai_chat/core/browser/conversation_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ class ConversationHandler : public mojom::ConversationHandler,
virtual void OnSelectedLanguageChanged(
ConversationHandler* handler,
const std::string& selected_language) {}
virtual void OnAssociatedContentDestroyed(ConversationHandler* handler,
int content_id) {}
};

ConversationHandler(
Expand All @@ -167,6 +169,9 @@ class ConversationHandler : public mojom::ConversationHandler,
bool HasAnyHistory();
void OnConversationDeleted();

// Returns true if the conversation has associated content that is non-archive
bool IsAssociatedContentAlive();

// Called when the associated content is destroyed or navigated away. If
// it's a navigation, the AssociatedContentDelegate will set itself to a new
// ConversationHandler.
Expand Down

0 comments on commit 465300d

Please sign in to comment.