Skip to content

Commit

Permalink
[AI Chat]: Manage associations from PageHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
fallaciousreasoning committed Jan 9, 2025
1 parent d4fd784 commit 7ce3f2c
Show file tree
Hide file tree
Showing 14 changed files with 77 additions and 117 deletions.
2 changes: 0 additions & 2 deletions browser/ai_chat/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import("//printing/buildflags/buildflags.gni")

static_library("ai_chat") {
sources = [
"ai_chat_associated_tab_provider.cc",
"ai_chat_associated_tab_provider.h",
"ai_chat_service_factory.cc",
"ai_chat_service_factory.h",
"ai_chat_settings_helper.cc",
Expand Down
29 changes: 0 additions & 29 deletions browser/ai_chat/ai_chat_associated_tab_provider.cc

This file was deleted.

24 changes: 0 additions & 24 deletions browser/ai_chat/ai_chat_associated_tab_provider.h

This file was deleted.

5 changes: 1 addition & 4 deletions browser/ai_chat/ai_chat_service_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <utility>

#include "base/no_destructor.h"
#include "brave/browser/ai_chat/ai_chat_associated_tab_provider.h"
#include "brave/browser/ai_chat/ai_chat_utils.h"
#include "brave/browser/brave_browser_process.h"
#include "brave/browser/misc_metrics/process_misc_metrics.h"
Expand Down Expand Up @@ -75,9 +74,7 @@ AIChatServiceFactory::BuildServiceInstanceForBrowserContext(
return std::make_unique<AIChatService>(
ModelServiceFactory::GetForBrowserContext(context),

std::move(credential_manager),
std::make_unique<AIChatAssociatedTabProvider>(),
user_prefs::UserPrefs::Get(context),
std::move(credential_manager), user_prefs::UserPrefs::Get(context),
(g_brave_browser_process->process_misc_metrics())
? g_brave_browser_process->process_misc_metrics()->ai_chat_metrics()
: nullptr,
Expand Down
38 changes: 38 additions & 0 deletions browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,44 @@ void AIChatUIPageHandler::HandleVoiceRecognition(
#endif
}

void AIChatUIPageHandler::AssociateTab(mojom::AvailableTabPtr tab,
const std::string& conversation_uuid) {
auto* content = content::WebContents::FromFrameTreeNodeId(
static_cast<content::FrameTreeNodeId>(tab->frame_tree_node_id));
if (!content) {
return;
}

// Ensure the content is loaded before associating
content->GetController().LoadIfNecessary();

auto* tab_helper = ai_chat::AIChatTabHelper::FromWebContents(content);
if (!tab_helper) {
return;
}

AIChatServiceFactory::GetForBrowserContext(profile_)->AssociateContent(
tab_helper, conversation_uuid);
}

void AIChatUIPageHandler::DisassociateTab(
mojom::AvailableTabPtr tab,
const std::string& conversation_uuid) {
auto* content = content::WebContents::FromFrameTreeNodeId(
static_cast<content::FrameTreeNodeId>(tab->frame_tree_node_id));
if (!content) {
return;
}

auto* tab_helper = ai_chat::AIChatTabHelper::FromWebContents(content);
if (!tab_helper) {
return;
}

AIChatServiceFactory::GetForBrowserContext(profile_)->DisassociateContent(
tab_helper, conversation_uuid);
}

void AIChatUIPageHandler::OpenAIChatSettings() {
content::WebContents* contents_to_navigate =
(active_chat_tab_helper_) ? active_chat_tab_helper_->web_contents()
Expand Down
4 changes: 4 additions & 0 deletions browser/ui/webui/ai_chat/ai_chat_ui_page_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ class AIChatUIPageHandler : public mojom::AIChatUIHandler,
void RefreshPremiumSession() override;
void ManagePremium() override;
void HandleVoiceRecognition(const std::string& conversation_uuid) override;
void AssociateTab(mojom::AvailableTabPtr tab,
const std::string& conversation_uuid) override;
void DisassociateTab(mojom::AvailableTabPtr tab,
const std::string& conversation_uuid) override;
void CloseUI() override;
void SetChatUI(mojo::PendingRemote<mojom::ChatUI> chat_ui,
SetChatUICallback callback) override;
Expand Down
1 change: 0 additions & 1 deletion components/ai_chat/core/browser/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ static_library("browser") {
"associated_archive_content.h",
"associated_content_driver.cc",
"associated_content_driver.h",
"associated_tab_delegate.h",
"constants.cc",
"constants.h",
"conversation_handler.cc",
Expand Down
23 changes: 12 additions & 11 deletions components/ai_chat/core/browser/ai_chat_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ bool IsConversationUpdatedTimeWithinRange(
AIChatService::AIChatService(
ModelService* model_service,
std::unique_ptr<AIChatCredentialManager> ai_chat_credential_manager,
std::unique_ptr<AssociatedTabDelegate> associated_tab_delegate,
PrefService* profile_prefs,
AIChatMetrics* ai_chat_metrics,
os_crypt_async::OSCryptAsync* os_crypt_async,
Expand All @@ -109,7 +108,6 @@ AIChatService::AIChatService(
std::make_unique<AIChatFeedbackAPI>(url_loader_factory_,
std::string(channel_string))),
credential_manager_(std::move(ai_chat_credential_manager)),
associated_tab_delegate_(std::move(associated_tab_delegate)),
profile_path_(profile_path) {
DCHECK(profile_prefs_);
pref_change_registrar_.Init(profile_prefs_);
Expand Down Expand Up @@ -297,6 +295,18 @@ ConversationHandler* AIChatService::CreateConversationHandlerForContent(
return conversation;
}

void AIChatService::AssociateContent(AssociatedContentDriver* driver,
const std::string& conversation_uuid) {
DCHECK(base::Contains(conversation_handlers_, conversation_uuid));
conversation_handlers_.at(conversation_uuid)->AddAssociation(driver);
}

void AIChatService::DisassociateContent(AssociatedContentDriver* driver,
const std::string& conversation_uuid) {
DCHECK(base::Contains(conversation_handlers_, conversation_uuid));
conversation_handlers_.at(conversation_uuid)->RemoveAssociation(driver);
}

void AIChatService::DeleteConversations(std::optional<base::Time> begin_time,
std::optional<base::Time> end_time) {
if (!begin_time.has_value() && !end_time.has_value()) {
Expand Down Expand Up @@ -951,13 +961,4 @@ void AIChatService::OpenConversationWithStagedEntries(
conversation->MaybeFetchOrClearContentStagedConversation();
}

AssociatedContentDriver* AIChatService::GetAssociatedContent(
const mojom::AvailableTabPtr& tab) {
// Can be null in tests
if (!associated_tab_delegate_) {
return nullptr;
}
return associated_tab_delegate_->GetAssociatedContent(tab);
}

} // namespace ai_chat
12 changes: 5 additions & 7 deletions components/ai_chat/core/browser/ai_chat_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
#include "brave/components/ai_chat/core/browser/ai_chat_database.h"
#include "brave/components/ai_chat/core/browser/ai_chat_feedback_api.h"
#include "brave/components/ai_chat/core/browser/ai_chat_metrics.h"
#include "brave/components/ai_chat/core/browser/associated_tab_delegate.h"
#include "brave/components/ai_chat/core/browser/conversation_handler.h"
#include "brave/components/ai_chat/core/browser/engine/engine_consumer.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h"
Expand All @@ -56,7 +55,6 @@ namespace ai_chat {

class ModelService;
class AIChatMetrics;
class AvailableTabDelegate;
class AssociatedContentDriver;

// Main entry point for creating and consuming AI Chat conversations
Expand All @@ -70,7 +68,6 @@ class AIChatService : public KeyedService,
AIChatService(
ModelService* model_service,
std::unique_ptr<AIChatCredentialManager> ai_chat_credential_manager,
std::unique_ptr<AssociatedTabDelegate> associated_tab_delegate,
PrefService* profile_prefs,
AIChatMetrics* ai_chat_metrics,
os_crypt_async::OSCryptAsync* os_crypt_async,
Expand Down Expand Up @@ -126,6 +123,11 @@ class AIChatService : public KeyedService,
base::WeakPtr<ConversationHandler::AssociatedContentDelegate>
associated_content);

void AssociateContent(AssociatedContentDriver* associated_content,
const std::string& conversation_uuid);
void DisassociateContent(AssociatedContentDriver* associated_content,
const std::string& conversation_uuid);

// Removes all in-memory and persisted data for all conversations
void DeleteConversations(std::optional<base::Time> begin_time = std::nullopt,
std::optional<base::Time> end_time = std::nullopt);
Expand All @@ -141,9 +143,6 @@ class AIChatService : public KeyedService,
associated_content,
base::OnceClosure open_ai_chat);

AssociatedContentDriver* GetAssociatedContent(
const mojom::AvailableTabPtr& tab);

// mojom::Service
void MarkAgreementAccepted() override;
void EnableStoragePref() override;
Expand Down Expand Up @@ -237,7 +236,6 @@ class AIChatService : public KeyedService,

std::unique_ptr<AIChatFeedbackAPI> feedback_api_;
std::unique_ptr<AIChatCredentialManager> credential_manager_;
std::unique_ptr<AssociatedTabDelegate> associated_tab_delegate_;

base::FilePath profile_path_;

Expand Down
24 changes: 0 additions & 24 deletions components/ai_chat/core/browser/associated_tab_delegate.h

This file was deleted.

17 changes: 8 additions & 9 deletions components/ai_chat/core/browser/conversation_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "brave/components/ai_chat/core/browser/ai_chat_feedback_api.h"
#include "brave/components/ai_chat/core/browser/ai_chat_service.h"
#include "brave/components/ai_chat/core/browser/associated_archive_content.h"
#include "brave/components/ai_chat/core/browser/associated_content_driver.h"
#include "brave/components/ai_chat/core/browser/local_models_updater.h"
#include "brave/components/ai_chat/core/browser/model_service.h"
#include "brave/components/ai_chat/core/browser/model_validator.h"
Expand Down Expand Up @@ -541,13 +542,12 @@ void ConversationHandler::GetState(GetStateCallback callback) {
std::move(callback).Run(std::move(state));
}

void ConversationHandler::AddAssociatedTab(mojom::AvailableTabPtr tab) {
if (HasAnyHistory()) {
void ConversationHandler::AddAssociation(AssociatedContentDriver* delegate) {
if (!delegate) {
return;
}

auto* tab_delegate = ai_chat_service_->GetAssociatedContent(tab);
if (!tab_delegate) {
if (HasAnyHistory()) {
return;
}

Expand All @@ -557,20 +557,20 @@ void ConversationHandler::AddAssociatedTab(mojom::AvailableTabPtr tab) {
}

if (!multi_content_) {
std::vector<AssociatedContentDriver*> drivers{tab_delegate};
std::vector<AssociatedContentDriver*> drivers{delegate};
auto driver =
std::make_unique<MultiAssociatedContentDriver>(std::move(drivers));
SetMultiAssociatedContentDelegate(std::move(driver));
} else {
multi_content_->AddContent(tab_delegate);
multi_content_->AddContent(delegate);
}

OnAssociatedContentInfoChanged();
MaybeSeedOrClearSuggestions();
MaybeFetchOrClearContentStagedConversation();
}

void ConversationHandler::RemoveAssociatedTab(mojom::AvailableTabPtr tab) {
void ConversationHandler::RemoveAssociation(AssociatedContentDriver* delegate) {
if (HasAnyHistory()) {
return;
}
Expand All @@ -581,8 +581,7 @@ void ConversationHandler::RemoveAssociatedTab(mojom::AvailableTabPtr tab) {
}

if (multi_content_) {
auto* tab_delegate = ai_chat_service_->GetAssociatedContent(tab);
multi_content_->RemoveContent(tab_delegate);
multi_content_->RemoveContent(delegate);
}

OnAssociatedContentInfoChanged();
Expand Down
6 changes: 4 additions & 2 deletions components/ai_chat/core/browser/conversation_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class AIChatService;
class AssociatedArchiveContent;
class AIChatCredentialManager;
class MultiAssociatedContentDriver;
class AssociatedContentDriver;

// Performs all conversation-related operations, responsible for sending
// messages to the conversation engine, handling the responses, and owning
Expand Down Expand Up @@ -238,10 +239,11 @@ class ConversationHandler : public mojom::ConversationHandler,
const mojom::Model& GetCurrentModel();
const std::vector<mojom::ConversationTurnPtr>& GetConversationHistory() const;

void AddAssociation(AssociatedContentDriver* delegate);
void RemoveAssociation(AssociatedContentDriver* delegate);

// mojom::ConversationHandler
void GetState(GetStateCallback callback) override;
void AddAssociatedTab(mojom::AvailableTabPtr tab) override;
void RemoveAssociatedTab(mojom::AvailableTabPtr tab) override;
void GetConversationHistory(GetConversationHistoryCallback callback) override;
void RateMessage(bool is_liked,
const std::string& turn_uuid,
Expand Down
7 changes: 4 additions & 3 deletions components/ai_chat/core/common/mojom/ai_chat.mojom
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,10 @@ interface AIChatUIHandler {
ManagePremium();
HandleVoiceRecognition(string conversation_uuid);

// Methods for managing multi-tab content association
AssociateTab(AvailableTab tab, string conversation_uuid);
DisassociateTab(AvailableTab tab, string conversation_uuid);

// This might be a no-op if the UI isn't closeable
CloseUI();

Expand Down Expand Up @@ -449,9 +453,6 @@ struct ConversationEntriesState {
// Browser-side handler for a Conversation
interface ConversationHandler {
GetState() => (ConversationState conversation_state);
AddAssociatedTab(AvailableTab tab);
RemoveAssociatedTab(AvailableTab tab);

GetConversationUuid() => (string conversation_uuid);

// Get all the possible models for any conversation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ function SiteTitle(props: SiteTitleProps) {
{props.detail.title}
</p>
</div>
{removable && <Button fab kind='plain-faint' onClick={() => conversation.conversationHandler?.removeAssociatedTab(tab)}>
{removable && <Button fab kind='plain-faint' onClick={() => aiChat.uiHandler?.disassociateTab(tab, conversation.conversationUuid!)}>
<Icon name='trash' />
</Button>}
</div>
Expand Down

0 comments on commit 7ce3f2c

Please sign in to comment.