diff --git a/components/ai_chat/core/browser/ai_chat_service.cc b/components/ai_chat/core/browser/ai_chat_service.cc index 6ca39f53c09c..46fac0ce68d2 100644 --- a/components/ai_chat/core/browser/ai_chat_service.cc +++ b/components/ai_chat/core/browser/ai_chat_service.cc @@ -268,10 +268,27 @@ void AIChatService::OnPremiumStatusReceived(GetPremiumStatusCallback callback, } void AIChatService::MaybeEraseConversation( - ConversationHandler* conversation_handler) { - if (!conversation_handler->IsAnyClientConnected() && - (!features::IsAIChatHistoryEnabled() || - !conversation_handler->HasAnyHistory())) { + ConversationHandler *conversation_handler) { + // Don't unload if there is active UI for the conversation + if (conversation_handler->IsAnyClientConnected()) { + return; + } + + bool has_history = conversation_handler->HasAnyHistory(); + + // We can keep a conversation with history in memory until there is no active + // content. + // TODO(petemill): With the history feature enabled, we should unload (if + // there is no request in progress). However, we can only do this when + // GetOrCreateConversationHandlerForContent allows a callback so that it + // can provide an answer after loading the conversation content from storage. + if (conversation_handler->IsAssociatedContentAlive() && has_history) { + return; + } + + // AIChatHistory feature doesn't yet have persistant storage, so keep + // handlers and data around if it's enabled. + if (!features::IsAIChatHistoryEnabled() || has_history) { // Can erase because no active UI and no history, so it's // not a real / persistable conversation auto uuid = conversation_handler->get_conversation_uuid(); @@ -279,7 +296,7 @@ void AIChatService::MaybeEraseConversation( conversation_handlers_.erase(uuid); conversations_.erase(uuid); std::erase_if(content_conversations_, - [&uuid](const auto& kv) { return kv.second == uuid; }); + [&uuid](const auto &kv) { return kv.second == uuid; }); DVLOG(1) << "Erased conversation (" << uuid << "). Now have " << conversations_.size() << " Conversation metadata items and " << conversation_handlers_.size() @@ -328,6 +345,12 @@ void AIChatService::OnConversationTitleChanged(ConversationHandler* handler, OnConversationListChanged(); } +void AIChatService::OnAssociatedContentDestroyed(ConversationHandler* handler, + int content_id) { + content_conversations_.erase(content_id); + MaybeEraseConversation(handler); +} + void AIChatService::GetVisibleConversations( GetVisibleConversationsCallback callback) { std::vector conversations; diff --git a/components/ai_chat/core/browser/ai_chat_service.h b/components/ai_chat/core/browser/ai_chat_service.h index 728b3e2e7439..1653046f988e 100644 --- a/components/ai_chat/core/browser/ai_chat_service.h +++ b/components/ai_chat/core/browser/ai_chat_service.h @@ -65,6 +65,8 @@ class AIChatService : public KeyedService, void OnClientConnectionChanged(ConversationHandler* handler) override; void OnConversationTitleChanged(ConversationHandler* handler, std::string title) override; + void OnAssociatedContentDestroyed(ConversationHandler* handler, + int content_id) override; // Adds new conversation and returns the handler ConversationHandler* CreateConversation(); diff --git a/components/ai_chat/core/browser/ai_chat_service_unittest.cc b/components/ai_chat/core/browser/ai_chat_service_unittest.cc index a67e287fbfac..8d324bfbd92e 100644 --- a/components/ai_chat/core/browser/ai_chat_service_unittest.cc +++ b/components/ai_chat/core/browser/ai_chat_service_unittest.cc @@ -131,7 +131,6 @@ class MockConversationHandlerClient : public mojom::ConversationUI { explicit MockConversationHandlerClient(ConversationHandler* driver) { driver->Bind(conversation_handler_remote_.BindNewPipeAndPassReceiver(), conversation_ui_receiver_.BindNewPipeAndPassRemote()); - conversation_handler_ = driver; } ~MockConversationHandlerClient() override = default; @@ -141,9 +140,6 @@ class MockConversationHandlerClient : public mojom::ConversationUI { conversation_ui_receiver_.reset(); } - ConversationHandler* GetConversationHandler() { - return conversation_handler_; - } MOCK_METHOD(void, OnConversationHistoryUpdate, (), (override)); @@ -174,7 +170,6 @@ class MockConversationHandlerClient : public mojom::ConversationUI { private: mojo::Receiver conversation_ui_receiver_{this}; mojo::Remote conversation_handler_remote_; - raw_ptr conversation_handler_; }; class MockAssociatedContent @@ -203,6 +198,30 @@ class MockAssociatedContent (override)); MOCK_METHOD(bool, HasOpenAIChatPermission, (), (const, override)); + void AddRelatedConversation(ConversationHandler* conversation) override { + related_conversations_.insert(conversation); + } + + void OnRelatedConversationDisassociated( + ConversationHandler* conversation) override { + related_conversations_.erase(conversation); + } + + void DisassociateWithConversations(std::string archived_text_content, + bool archived_is_video) { + std::vector> related_conversations; + for (auto& conversation : related_conversations_) { + related_conversations.push_back(conversation->GetWeakPtr()); + } + + for (auto& conversation : related_conversations) { + if (conversation) { + conversation->OnAssociatedContentDestroyed(archived_text_content, + archived_is_video); + } + } + } + base::WeakPtr GetWeakPtr() { return weak_ptr_factory_.GetWeakPtr(); } @@ -211,6 +230,7 @@ class MockAssociatedContent base::WeakPtrFactory weak_ptr_factory_{this}; int content_id_ = 0; + std::set> related_conversations_; }; } // namespace @@ -404,6 +424,56 @@ TEST_P(AIChatServiceUnitTest, ConversationLifecycle_WithMessages) { task_environment_.RunUntilIdle(); } +TEST_P(AIChatServiceUnitTest, ConversationLifecycle_WithContent) { + NiceMock associated_content{}; + ON_CALL(associated_content, GetURL()) + .WillByDefault(testing::Return(GURL("https://example.com"))); + associated_content.SetContentId(1); + ConversationHandler* conversation_with_content_no_messages = + ai_chat_service_->GetOrCreateConversationHandlerForContent( + associated_content.GetContentId(), associated_content.GetWeakPtr()); + EXPECT_TRUE(conversation_with_content_no_messages); + // Asking again for same content ID gets same conversation + EXPECT_EQ( + conversation_with_content_no_messages, + ai_chat_service_->GetOrCreateConversationHandlerForContent( + associated_content.GetContentId(), associated_content.GetWeakPtr())); + // Shouldn't be visible without messages + ExpectVisibleConversationsSize(FROM_HERE, 0u); + EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u); + // Disconnecting the client should unload the handler and delete the + // conversation. + auto client1 = + CreateConversationClient(conversation_with_content_no_messages); + DisconnectConversationClient(client1.get()); + EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 0u); + ExpectVisibleConversationsSize(FROM_HERE, 0u); + + // Create a new conversation for same content, with messages this time + ConversationHandler* conversation_with_content = + ai_chat_service_->GetOrCreateConversationHandlerForContent( + associated_content.GetContentId(), associated_content.GetWeakPtr()); + conversation_with_content->SetChatHistoryForTesting( + CreateSampleChatHistory(1u)); + ExpectVisibleConversationsSize(FROM_HERE, 1u); + EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u); + auto client2 = CreateConversationClient(conversation_with_content); + DisconnectConversationClient(client2.get()); + // Disconnecting all clients should keep the handler in memory until + // the content is destroyed. + EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u); + ExpectVisibleConversationsSize(FROM_HERE, 1u); + associated_content.DisassociateWithConversations("", false); + + if (IsAIChatHistoryEnabled()) { + EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u); + ExpectVisibleConversationsSize(FROM_HERE, 1u); + } else { + EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 0u); + ExpectVisibleConversationsSize(FROM_HERE, 0u); + } +} + TEST_P(AIChatServiceUnitTest, GetOrCreateConversationHandlerForContent) { ConversationHandler* conversation_without_content = CreateConversation(); @@ -439,27 +509,24 @@ TEST_P(AIChatServiceUnitTest, GetOrCreateConversationHandlerForContent) { // Creating a second conversation with the same associated content should // make the second conversation the default for that content, but leave // the first still associated with the content. - ConversationHandler* conversation_with_content2 = + ConversationHandler* conversation2 = ai_chat_service_->CreateConversationHandlerForContent( associated_content.GetContentId(), associated_content.GetWeakPtr()); - EXPECT_NE(conversation_with_content, conversation_with_content2); + EXPECT_NE(conversation_with_content, conversation2); EXPECT_NE(conversation_with_content->get_conversation_uuid(), - conversation_with_content2->get_conversation_uuid()); - EXPECT_EQ( - conversation_with_content2->GetAssociatedContentDelegateForTesting(), - &associated_content); - EXPECT_EQ( - conversation_with_content->GetAssociatedContentDelegateForTesting(), - conversation_with_content2->GetAssociatedContentDelegateForTesting()); + conversation2->get_conversation_uuid()); + EXPECT_EQ(conversation2->GetAssociatedContentDelegateForTesting(), + &associated_content); + EXPECT_EQ(conversation_with_content->GetAssociatedContentDelegateForTesting(), + conversation2->GetAssociatedContentDelegateForTesting()); // Check the second conversation is the default for that content ID EXPECT_EQ( ai_chat_service_->GetOrCreateConversationHandlerForContent( associated_content.GetContentId(), associated_content.GetWeakPtr()), - conversation_with_content2); + conversation2); // Let the conversation be deleted - std::string conversation2_uuid = - conversation_with_content2->get_conversation_uuid(); - auto client1 = CreateConversationClient(conversation_with_content2); + std::string conversation2_uuid = conversation2->get_conversation_uuid(); + auto client1 = CreateConversationClient(conversation2); DisconnectConversationClient(client1.get()); ConversationHandler* conversation_with_content3 = ai_chat_service_->GetOrCreateConversationHandlerForContent( diff --git a/components/ai_chat/core/browser/associated_content_driver.cc b/components/ai_chat/core/browser/associated_content_driver.cc index 905dbb46d993..c27cc4b09636 100644 --- a/components/ai_chat/core/browser/associated_content_driver.cc +++ b/components/ai_chat/core/browser/associated_content_driver.cc @@ -61,12 +61,7 @@ AssociatedContentDriver::AssociatedContentDriver( : url_loader_factory_(url_loader_factory) {} AssociatedContentDriver::~AssociatedContentDriver() { - for (auto& conversation : associated_conversations_) { - if (conversation) { - conversation->OnAssociatedContentDestroyed(cached_text_content_, - is_video_); - } - } + DisassociateWithConversations(); } void AssociatedContentDriver::AddRelatedConversation( @@ -273,10 +268,10 @@ void AssociatedContentDriver::OnFaviconImageDataChanged() { } void AssociatedContentDriver::OnNewPage(int64_t navigation_id) { - // Tell the associated_conversations_ that we're breaking up - for (auto& conversation : associated_conversations_) { - conversation->OnAssociatedContentDestroyed(cached_text_content_, is_video_); - } + // This instance will now be used for different content so existing + // conversations need to be disassociated. + DisassociateWithConversations(); + // Tell the observer how to find the next conversation for (auto& observer : observers_) { observer.OnAssociatedContentNavigated(navigation_id); @@ -292,4 +287,16 @@ void AssociatedContentDriver::OnNewPage(int64_t navigation_id) { ConversationHandler::AssociatedContentDelegate::OnNewPage(navigation_id); } +void AssociatedContentDriver::DisassociateWithConversations() { + // Iterator might be invalidated by destruction, so copy the items + std::vector conversations{ + associated_conversations_.begin(), associated_conversations_.end()}; + for (auto& conversation : conversations) { + if (conversation) { + conversation->OnAssociatedContentDestroyed(cached_text_content_, + is_video_); + } + } +} + } // namespace ai_chat diff --git a/components/ai_chat/core/browser/associated_content_driver.h b/components/ai_chat/core/browser/associated_content_driver.h index 2ea74ded4f88..a36343623a94 100644 --- a/components/ai_chat/core/browser/associated_content_driver.h +++ b/components/ai_chat/core/browser/associated_content_driver.h @@ -122,6 +122,12 @@ class AssociatedContentDriver ConversationHandler::GetStagedEntriesCallback callback, int64_t navigation_id, api_request_helper::APIRequestResult result); + + // Let all conversations using this content know that the content + // has been destroyed or changed to represent different content (e.g. a + // navigation). + void DisassociateWithConversations(); + static std::optional> ParseSearchQuerySummaryResponse(const base::Value& value); diff --git a/components/ai_chat/core/browser/conversation_handler.cc b/components/ai_chat/core/browser/conversation_handler.cc index 920b9211e8e2..9e721ec775b1 100644 --- a/components/ai_chat/core/browser/conversation_handler.cc +++ b/components/ai_chat/core/browser/conversation_handler.cc @@ -199,6 +199,10 @@ bool ConversationHandler::HasAnyHistory() { }); } +bool ConversationHandler::IsAssociatedContentAlive() { + return associated_content_delegate_ && !archive_content_; +} + void ConversationHandler::OnConversationDeleted() { for (auto& client : conversation_ui_handlers_) { client->OnConversationDeleted(); @@ -256,11 +260,11 @@ void ConversationHandler::InitEngine() { void ConversationHandler::OnAssociatedContentDestroyed( std::string last_text_content, bool is_video) { - // The associated content delegate is destroyed, so we should not try to - // fetch. It may be populated later, e.g. through back navigation. - // If this conversation is allowed to be associated with content, we can keep - // using our current cached content. - associated_content_delegate_ = nullptr; + // The associated content delegate is already or about to be destroyed. + auto content_id = associated_content_delegate_ + ? associated_content_delegate_->GetContentId() + : -1; + DisassociateContentDelegate(); if (!chat_history_.empty() && should_send_page_contents_ && associated_content_info_ && associated_content_info_->url.has_value()) { // Get the latest version of article text and @@ -276,6 +280,10 @@ void ConversationHandler::OnAssociatedContentDestroyed( archive_content_ = std::move(archive_content); } OnAssociatedContentInfoChanged(); + // Notify observers + for (auto& observer : observers_) { + observer.OnAssociatedContentDestroyed(this, content_id); + } } void ConversationHandler::SetAssociatedContentDelegate( diff --git a/components/ai_chat/core/browser/conversation_handler.h b/components/ai_chat/core/browser/conversation_handler.h index b44c629cb85d..e57e6d0a4244 100644 --- a/components/ai_chat/core/browser/conversation_handler.h +++ b/components/ai_chat/core/browser/conversation_handler.h @@ -142,6 +142,8 @@ class ConversationHandler : public mojom::ConversationHandler, virtual void OnSelectedLanguageChanged( ConversationHandler* handler, const std::string& selected_language) {} + virtual void OnAssociatedContentDestroyed(ConversationHandler* handler, + int content_id) {} }; ConversationHandler( @@ -167,6 +169,9 @@ class ConversationHandler : public mojom::ConversationHandler, bool HasAnyHistory(); void OnConversationDeleted(); + // Returns true if the conversation has associated content that is non-archive + bool IsAssociatedContentAlive(); + // Called when the associated content is destroyed or navigated away. If // it's a navigation, the AssociatedContentDelegate will set itself to a new // ConversationHandler.