diff --git a/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc b/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc index aedd32563d60..020553eb65ec 100644 --- a/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc +++ b/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc @@ -97,7 +97,7 @@ void AIChatUIPageHandler::SetClientPage( // ex. A user may ask a question from the location bar if (active_chat_tab_helper_ && active_chat_tab_helper_->HasPendingConversationEntry()) { - OnConversationEntryPending(); + OnHistoryUpdate(); } } @@ -124,10 +124,12 @@ void AIChatUIPageHandler::ChangeModel(const std::string& model_key) { void AIChatUIPageHandler::SubmitHumanConversationEntry( const std::string& input) { + DCHECK(!active_chat_tab_helper_->IsRequestInProgress()) + << "Should not be able to submit more" + << "than a single human conversation turn at a time."; mojom::ConversationTurn turn = {CharacterType::HUMAN, ConversationTurnVisibility::VISIBLE, input}; - active_chat_tab_helper_->MakeAPIRequestWithConversationHistoryUpdate( - std::move(turn)); + active_chat_tab_helper_->SubmitHumanConversationEntry(std::move(turn)); } void AIChatUIPageHandler::SubmitSummarizationRequest() { @@ -142,21 +144,9 @@ void AIChatUIPageHandler::GetConversationHistory( std::move(callback).Run({}); return; } - std::vector history = - active_chat_tab_helper_->GetConversationHistory(); - std::vector list; - - // Remove conversations that are meant to be hidden from the user - auto new_end_it = std::remove_if( - history.begin(), history.end(), [](const ConversationTurn& turn) { - return turn.visibility == ConversationTurnVisibility::HIDDEN; - }); - - std::transform(history.begin(), new_end_it, std::back_inserter(list), - [](const ConversationTurn& turn) { return turn.Clone(); }); - - std::move(callback).Run(std::move(list)); + std::move(callback).Run( + active_chat_tab_helper_->GetVisibleConversationHistory()); } void AIChatUIPageHandler::GetSuggestedQuestions( @@ -412,12 +402,6 @@ void AIChatUIPageHandler::OnPageHasContent(mojom::SiteInfoPtr site_info) { } } -void AIChatUIPageHandler::OnConversationEntryPending() { - if (page_.is_bound()) { - page_->OnConversationEntryPending(); - } -} - void AIChatUIPageHandler::GetFaviconImageData( GetFaviconImageDataCallback callback) { if (!active_chat_tab_helper_) { diff --git a/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.h b/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.h index 7e33ae2243c5..443f522d0371 100644 --- a/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.h +++ b/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.h @@ -93,7 +93,6 @@ class AIChatUIPageHandler : public ai_chat::mojom::PageHandler, mojom::SuggestionGenerationStatus suggestion_generation_status) override; void OnFaviconImageDataChanged() override; void OnPageHasContent(mojom::SiteInfoPtr site_info) override; - void OnConversationEntryPending() override; void GetFaviconImageData(GetFaviconImageDataCallback callback) override; diff --git a/chromium_src/chrome/browser/autocomplete/chrome_autocomplete_provider_client.cc b/chromium_src/chrome/browser/autocomplete/chrome_autocomplete_provider_client.cc index 92e45209e408..b0b72042fe61 100644 --- a/chromium_src/chrome/browser/autocomplete/chrome_autocomplete_provider_client.cc +++ b/chromium_src/chrome/browser/autocomplete/chrome_autocomplete_provider_client.cc @@ -69,7 +69,7 @@ void ChromeAutocompleteProviderClient::OpenLeo(const std::u16string& query) { ai_chat::mojom::CharacterType::HUMAN, ai_chat::mojom::ConversationTurnVisibility::VISIBLE, base::UTF16ToUTF8(query)}; - chat_tab_helper->MakeAPIRequestWithConversationHistoryUpdate(std::move(turn)); + chat_tab_helper->SubmitHumanConversationEntry(std::move(turn)); ai_chat::AIChatMetrics* metrics = g_brave_browser_process->process_misc_metrics()->ai_chat_metrics(); CHECK(metrics); diff --git a/components/ai_chat/content/browser/ai_chat_tab_helper.cc b/components/ai_chat/content/browser/ai_chat_tab_helper.cc index 8db98cdc2c93..d2621d85e31d 100644 --- a/components/ai_chat/content/browser/ai_chat_tab_helper.cc +++ b/components/ai_chat/content/browser/ai_chat_tab_helper.cc @@ -58,56 +58,45 @@ AIChatTabHelper::~AIChatTabHelper() = default; // content::WebContentsObserver -void AIChatTabHelper::DocumentOnLoadCompletedInPrimaryMainFrame() { - // We might have content here, so check. - // TODO(petemill): If there are other navigation events to also - // check if content is available at, then start a queue and make - // sure we don't have multiple async distills going on at the same time. - MaybeGeneratePageText(); -} - void AIChatTabHelper::WebContentsDestroyed() { - CleanUp(); favicon::ContentFaviconDriver::FromWebContents(web_contents()) ->RemoveObserver(this); } void AIChatTabHelper::DidFinishNavigation( content::NavigationHandle* navigation_handle) { - // Store current navigation ID of the main document - // so that we can ignore async responses against any navigated-away-from - // documents. if (!navigation_handle->IsInMainFrame()) { - DVLOG(3) << "FinishNavigation NOT in main frame"; return; } DVLOG(2) << __func__ << navigation_handle->GetNavigationId() << " url: " << navigation_handle->GetURL().spec() << " same document? " << navigation_handle->IsSameDocument(); - SetNavigationId(navigation_handle->GetNavigationId()); + // Allow same-document navigation, as content often changes as a result // of framgment / pushState / replaceState navigations. // Content won't be retrieved immediately and we don't have a similar // "DOM Content Loaded" event, so let's wait for something else such as - // page title changing, or a timer completing before calling - // |MaybeGeneratePageText|. - SetSameDocumentNavigation(navigation_handle->IsSameDocument()); - // Experimentally only call |CleanUp| _if_ a same-page navigation - // results in a page title change (see |TtileWasSet|). - if (!IsSameDocumentNavigation()) { - CleanUp(); + // page title changing before committing to starting a new conversation + // and treating it as a "fresh page". + is_same_document_navigation_ = navigation_handle->IsSameDocument(); + pending_navigation_id_ = navigation_handle->GetNavigationId(); + // Experimentally only call |OnNewPage| for same-page navigations _if_ + // it results in a page title change (see |TtileWasSet|). + if (!is_same_document_navigation_) { + OnNewPage(pending_navigation_id_); } } void AIChatTabHelper::TitleWasSet(content::NavigationEntry* entry) { DVLOG(3) << __func__ << entry->GetTitle(); if (is_same_document_navigation_) { - // Seems as good a time as any to check for content after a same-document - // navigation. - // We only perform CleanUp here in case it was a minor pushState / fragment - // navigation and didn't result in new content. - CleanUp(); - MaybeGeneratePageText(); + DVLOG(3) << "Same document navigation detected new \"page\" - calling " + "OnNewPage()"; + // Page title modification after same-document navigation seems as good a + // time as any to assume meaningful changes occured to the content. + OnNewPage(pending_navigation_id_); + // Don't respond to further TitleWasSet + is_same_document_navigation_ = false; } } @@ -129,16 +118,9 @@ GURL AIChatTabHelper::GetPageURL() const { } void AIChatTabHelper::GetPageContent( - base::OnceCallback callback) const { - FetchPageContent(web_contents(), std::move(callback)); -} - -bool AIChatTabHelper::HasPrimaryMainFrame() const { - return web_contents()->GetPrimaryMainFrame() != nullptr; -} - -bool AIChatTabHelper::IsDocumentOnLoadCompletedInPrimaryMainFrame() const { - return web_contents()->IsDocumentOnLoadCompletedInPrimaryMainFrame(); + GetPageContentCallback callback, + std::string_view invalidation_token) const { + FetchPageContent(web_contents(), invalidation_token, std::move(callback)); } std::u16string AIChatTabHelper::GetPageTitle() const { diff --git a/components/ai_chat/content/browser/ai_chat_tab_helper.h b/components/ai_chat/content/browser/ai_chat_tab_helper.h index 6190d2e72a06..59725cb89fe3 100644 --- a/components/ai_chat/content/browser/ai_chat_tab_helper.h +++ b/components/ai_chat/content/browser/ai_chat_tab_helper.h @@ -48,7 +48,6 @@ class AIChatTabHelper : public content::WebContentsObserver, PrefService* local_state_prefs); // content::WebContentsObserver - void DocumentOnLoadCompletedInPrimaryMainFrame() override; void WebContentsDestroyed() override; void DidFinishNavigation( content::NavigationHandle* navigation_handle) override; @@ -63,18 +62,14 @@ class AIChatTabHelper : public content::WebContentsObserver, // ai_chat::ConversationDriver GURL GetPageURL() const override; - void GetPageContent(base::OnceCallback - callback) const override; - bool HasPrimaryMainFrame() const override; - bool IsDocumentOnLoadCompletedInPrimaryMainFrame() const override; + void GetPageContent(GetPageContentCallback callback, + std::string_view invalidation_token) const override; std::u16string GetPageTitle() const override; raw_ptr ai_chat_metrics_; - // Store the unique ID for each navigation so that - // we can ignore API responses for previous navigations. - int64_t current_navigation_id_; bool is_same_document_navigation_ = false; + int64_t pending_navigation_id_; base::WeakPtrFactory weak_ptr_factory_{this}; WEB_CONTENTS_USER_DATA_KEY_DECL(); diff --git a/components/ai_chat/content/browser/page_content_fetcher.cc b/components/ai_chat/content/browser/page_content_fetcher.cc index bbf44c2662b6..01c885a59394 100644 --- a/components/ai_chat/content/browser/page_content_fetcher.cc +++ b/components/ai_chat/content/browser/page_content_fetcher.cc @@ -58,6 +58,7 @@ net::NetworkTrafficAnnotationTag GetNetworkTrafficAnnotationTag() { class PageContentFetcher { public: void Start(mojo::Remote content_extractor, + std::string_view invalidation_token, scoped_refptr url_loader_factory, FetchPageContentCallback callback) { url_loader_factory_ = url_loader_factory; @@ -70,9 +71,9 @@ class PageContentFetcher { // after it is destroyed. content_extractor_.set_disconnect_handler(base::BindOnce( &PageContentFetcher::DeleteSelf, base::Unretained(this))); - content_extractor_->ExtractPageContent( - base::BindOnce(&PageContentFetcher::OnTabContentResult, - base::Unretained(this), std::move(callback))); + content_extractor_->ExtractPageContent(base::BindOnce( + &PageContentFetcher::OnTabContentResult, base::Unretained(this), + std::move(callback), invalidation_token)); } private: @@ -80,12 +81,14 @@ class PageContentFetcher { void SendResultAndDeleteSelf(FetchPageContentCallback callback, std::string content = "", + std::string invalidation_token = "", bool is_video = false) { - std::move(callback).Run(content, is_video); + std::move(callback).Run(content, is_video, invalidation_token); delete this; } void OnTabContentResult(FetchPageContentCallback callback, + std::string_view invalidation_token, mojom::PageContentPtr data) { if (!data) { VLOG(1) << __func__ << " no data."; @@ -101,7 +104,7 @@ class PageContentFetcher { auto content = data->content->get_content(); DVLOG(1) << __func__ << ": Got content with char length of " << content.length(); - SendResultAndDeleteSelf(std::move(callback), content, false); + SendResultAndDeleteSelf(std::move(callback), content, "", false); return; } // If it's video, we expect content url @@ -110,7 +113,16 @@ class PageContentFetcher { if (content_url.is_empty() || !content_url.is_valid() || !content_url.SchemeIs(url::kHttpsScheme)) { VLOG(1) << "Invalid content_url"; - SendResultAndDeleteSelf(std::move(callback), "", true); + SendResultAndDeleteSelf(std::move(callback), "", "", true); + return; + } + // Subsequent calls do not need to re-fetch if the url stays the same + auto new_invalidation_token = content_url.spec(); + if (new_invalidation_token == invalidation_token) { + VLOG(2) << "Not fetching content since invalidation token matches: " + << invalidation_token; + SendResultAndDeleteSelf(std::move(callback), "", new_invalidation_token, + true); return; } DVLOG(1) << "Making video transcript fetch to " << content_url.spec(); @@ -132,13 +144,14 @@ class PageContentFetcher { auto on_response = base::BindOnce(&PageContentFetcher::OnTranscriptFetchResponse, weak_ptr_factory_.GetWeakPtr(), std::move(callback), - std::move(loader), is_youtube); + std::move(loader), is_youtube, new_invalidation_token); loader_ptr->DownloadToString(url_loader_factory_.get(), std::move(on_response), 2 * 1024 * 1024); } void OnYoutubeTranscriptXMLParsed( FetchPageContentCallback callback, + std::string invalidation_token, base::expected result) { // Example Youtube transcript XML: // @@ -182,13 +195,15 @@ class PageContentFetcher { transcript_text += text; } - SendResultAndDeleteSelf(std::move(callback), transcript_text, true); + SendResultAndDeleteSelf(std::move(callback), transcript_text, + invalidation_token, true); } void OnTranscriptFetchResponse( FetchPageContentCallback callback, std::unique_ptr loader, bool is_youtube, + std::string invalidation_token, std::unique_ptr response_body) { auto response_code = -1; base::flat_map headers; @@ -215,11 +230,13 @@ class PageContentFetcher { data_decoder::mojom::XmlParser::WhitespaceBehavior:: kPreserveSignificant, base::BindOnce(&PageContentFetcher::OnYoutubeTranscriptXMLParsed, - weak_ptr_factory_.GetWeakPtr(), std::move(callback))); + weak_ptr_factory_.GetWeakPtr(), std::move(callback), + invalidation_token)); return; } - SendResultAndDeleteSelf(std::move(callback), transcript_content, true); + SendResultAndDeleteSelf(std::move(callback), transcript_content, + invalidation_token, true); } scoped_refptr url_loader_factory_; @@ -230,6 +247,7 @@ class PageContentFetcher { } // namespace void FetchPageContent(content::WebContents* web_contents, + std::string_view invalidation_token, FetchPageContentCallback callback) { VLOG(2) << __func__ << " Extracting page content from renderer..."; @@ -240,7 +258,7 @@ void FetchPageContent(content::WebContents* web_contents, LOG(ERROR) << "Content extraction request submitted for a WebContents without " "a primary main frame"; - std::move(callback).Run("", false); + std::move(callback).Run("", false, ""); return; } @@ -255,7 +273,8 @@ void FetchPageContent(content::WebContents* web_contents, ->GetDefaultStoragePartition() ->GetURLLoaderFactoryForBrowserProcess() .get(); - fetcher->Start(std::move(extractor), loader, std::move(callback)); + fetcher->Start(std::move(extractor), invalidation_token, loader, + std::move(callback)); } } // namespace ai_chat diff --git a/components/ai_chat/content/browser/page_content_fetcher.h b/components/ai_chat/content/browser/page_content_fetcher.h index 5bcbc86aedd7..ce005ac6a847 100644 --- a/components/ai_chat/content/browser/page_content_fetcher.h +++ b/components/ai_chat/content/browser/page_content_fetcher.h @@ -17,9 +17,11 @@ class WebContents; namespace ai_chat { using FetchPageContentCallback = - base::OnceCallback; - + base::OnceCallback; void FetchPageContent(content::WebContents* web_contents, + std::string_view invalidation_token, FetchPageContentCallback callback); } // namespace ai_chat diff --git a/components/ai_chat/core/browser/conversation_driver.cc b/components/ai_chat/core/browser/conversation_driver.cc index 1b25688247e8..7dd4bcc440ba 100644 --- a/components/ai_chat/core/browser/conversation_driver.cc +++ b/components/ai_chat/core/browser/conversation_driver.cc @@ -15,6 +15,7 @@ #include "base/functional/bind.h" #include "base/memory/weak_ptr.h" #include "base/notreached.h" +#include "base/one_shot_event.h" #include "base/ranges/algorithm.h" #include "base/strings/string_util.h" #include "base/strings/utf_string_conversions.h" @@ -50,14 +51,16 @@ bool IsPremiumStatus(mojom::PremiumStatus status) { } // namespace -ConversationDriver::ConversationDriver(raw_ptr pref_service, - raw_ptr ai_chat_metrics, - std::unique_ptr credential_manager, - scoped_refptr url_loader_factory) : - pref_service_(pref_service), +ConversationDriver::ConversationDriver( + raw_ptr pref_service, + raw_ptr ai_chat_metrics, + std::unique_ptr credential_manager, + scoped_refptr url_loader_factory) + : pref_service_(pref_service), ai_chat_metrics_(ai_chat_metrics), credential_manager_(std::move(credential_manager)), - url_loader_factory_(url_loader_factory) { + url_loader_factory_(url_loader_factory), + on_page_text_fetch_complete_(new base::OneShotEvent()) { DCHECK(pref_service_); pref_change_registrar_.Init(pref_service_); @@ -147,13 +150,27 @@ const std::vector& ConversationDriver::GetConversationHistory( return chat_history_; } +std::vector +ConversationDriver::GetVisibleConversationHistory() { + // Remove conversations that are meant to be hidden from the user + std::vector list; + for (const auto& turn : GetConversationHistory()) { + if (turn.visibility != ConversationTurnVisibility::HIDDEN) { + list.push_back(turn.Clone()); + } + } + if (pending_conversation_entry_ && pending_conversation_entry_->visibility != + ConversationTurnVisibility::HIDDEN) { + list.push_back(pending_conversation_entry_->Clone()); + } + return list; +} + void ConversationDriver::OnConversationActiveChanged(bool is_conversation_active) { is_conversation_active_ = is_conversation_active; DVLOG(3) << "Conversation active changed: " << is_conversation_active; - if (MaybePopPendingRequests()) { - return; - } - MaybeGeneratePageText(); + MaybeSeedOrClearSuggestions(); + MaybePopPendingRequests(); } void ConversationDriver::InitEngine() { @@ -207,9 +224,7 @@ bool ConversationDriver::HasUserOptedIn() { } void ConversationDriver::OnUserOptedIn() { - if (!MaybePopPendingRequests()) { - MaybeGeneratePageText(); - } + MaybePopPendingRequests(); if (ai_chat_metrics_ != nullptr && HasUserOptedIn()) { ai_chat_metrics_->RecordEnabled(true); @@ -276,106 +291,160 @@ bool ConversationDriver::MaybePopPendingRequests() { // We don't discard requests related to summarization until we have the // article text. - if (article_text_.empty() && pending_message_needs_page_content_) { + if (is_page_text_fetch_in_progress_) { return false; } mojom::ConversationTurn request = std::move(*pending_conversation_entry_); pending_conversation_entry_.reset(); - pending_message_needs_page_content_ = false; - MakeAPIRequestWithConversationHistoryUpdate(std::move(request)); + SubmitHumanConversationEntry(std::move(request)); return true; } -void ConversationDriver::MaybeGeneratePageText() { - const GURL url = GetPageURL(); - - if (!base::Contains(kAllowedSchemes, url.scheme())) { +void ConversationDriver::MaybeSeedOrClearSuggestions() { + if (!is_conversation_active_) { return; } - // User might have already asked questions before the page is loaded. It'd be - // strange if we generate contents based on the page. - // TODO(sko) This makes it impossible to ask something like "Summarize this - // page" once a user already asked a question. But for now we'd like to keep - // it simple and not confuse users with the context changing. We'll see what - // users say. - if (!chat_history_.empty()) { + const bool is_page_associated = + IsContentAssociationPossible() && should_send_page_contents_; + + if (!is_page_associated && !suggestions_.empty()) { + suggestions_.clear(); + OnSuggestedQuestionsChanged(); return; } + if (is_page_associated && suggestions_.empty() && + suggestion_generation_status_ != + mojom::SuggestionGenerationStatus::IsGenerating && + suggestion_generation_status_ != + mojom::SuggestionGenerationStatus::HasGenerated) { + // TODO(petemill): ask content fetcher if it knows whether current page is a + // video. + suggestions_.emplace_back( + is_video_ ? l10n_util::GetStringUTF8(IDS_CHAT_UI_SUMMARIZE_VIDEO) + : l10n_util::GetStringUTF8(IDS_CHAT_UI_SUMMARIZE_PAGE)); + suggestion_generation_status_ = + mojom::SuggestionGenerationStatus::CanGenerate; + OnSuggestedQuestionsChanged(); + } +} + +void ConversationDriver::GeneratePageContent(GetPageContentCallback callback) { + VLOG(1) << __func__; + DCHECK(should_send_page_contents_); + DCHECK(IsContentAssociationPossible()) + << "Shouldn't have been asked to generate page text when " + << "|IsContentAssociationPossible()| is false."; + DCHECK(!is_page_text_fetch_in_progress_) + << "UI shouldn't allow multiple operations at the same time"; + // Make sure user is opted in since this may make a network request // for more page content (e.g. video transcript). + DCHECK(HasUserOptedIn()) + << "UI shouldn't allow operations before user has accepted agreement"; + // Perf: make sure we're not doing this when the feature - // won't be used (e.g. not opted in or no active conversation). - if (is_page_text_fetch_in_progress_ || !article_text_.empty() || - !HasUserOptedIn() || !is_conversation_active_ || - !IsDocumentOnLoadCompletedInPrimaryMainFrame()) { + // won't be used (e.g. no active conversation). + DCHECK(is_conversation_active_) + << "UI shouldn't allow operations for an inactive conversation"; + + // Only perform a fetch once at a time, and then use the results from + // an in-progress operation. + if (is_page_text_fetch_in_progress_) { + VLOG(1) << "A page content fetch is in progress, waiting for the existing " + "operation to complete"; + auto handle_existing_fetch_complete = base::BindOnce( + &ConversationDriver::OnExistingGeneratePageContentComplete, + weak_ptr_factory_.GetWeakPtr(), std::move(callback)); + on_page_text_fetch_complete_->Post( + FROM_HERE, std::move(handle_existing_fetch_complete)); return; } - if (!HasPrimaryMainFrame()) { - VLOG(1) << "Summary request submitted for a WebContents without a " - "primary main frame"; - return; - } + is_page_text_fetch_in_progress_ = true; + // Update fetching status + OnPageHasContentChanged(BuildSiteInfo()); - if (should_send_page_contents_) { - is_page_text_fetch_in_progress_ = true; - // Update fetching status - OnPageHasContentChanged(BuildSiteInfo()); - GetPageContent( - base::BindOnce(&ConversationDriver::OnPageContentRetrieved, - weak_ptr_factory_.GetWeakPtr(), current_navigation_id_)); - } + GetPageContent( + base::BindOnce(&ConversationDriver::OnGeneratePageContentComplete, + weak_ptr_factory_.GetWeakPtr(), current_navigation_id_, + std::move(callback)), + content_invalidation_token_); } -void ConversationDriver::OnPageContentRetrieved(int64_t navigation_id, - std::string contents_text, - bool is_video) { +void ConversationDriver::OnGeneratePageContentComplete( + int64_t navigation_id, + GetPageContentCallback callback, + std::string contents_text, + bool is_video, + std::string invalidation_token) { + VLOG(1) << "OnGeneratePageContentComplete"; + VLOG(4) << "Contents(is_video=" << is_video + << ", invalidation_token=" << invalidation_token + << "): " << contents_text; if (navigation_id != current_navigation_id_) { VLOG(1) << __func__ << " for a different navigation. Ignoring."; return; } is_page_text_fetch_in_progress_ = false; + + // If invalidation token matches existing token, then + // content was not re-fetched and we can use our existing cache. + if (!invalidation_token.empty() && + (invalidation_token == content_invalidation_token_)) { + contents_text = article_text_; + } else { + is_video_ = is_video; + // Cache page content on instance so we don't always have to re-fetch + // if the content fetcher knows the content won't have changed and the fetch + // operation is expensive (e.g. network). + article_text_ = contents_text; + content_invalidation_token_ = invalidation_token; + engine_->SanitizeInput(article_text_); + // Update completion status + OnPageHasContentChanged(BuildSiteInfo()); + } + + on_page_text_fetch_complete_->Signal(); + on_page_text_fetch_complete_ = std::make_unique(); + if (contents_text.empty()) { VLOG(1) << __func__ << ": No data"; - return; } - is_video_ = is_video; - article_text_ = contents_text; - engine_->SanitizeInput(article_text_); + VLOG(4) << "calling callback with text: " << article_text_; - // Update completion status - OnPageHasContentChanged(BuildSiteInfo()); + std::move(callback).Run(article_text_, is_video_, + content_invalidation_token_); +} - // Now that we have article text, we can suggest to summarize it - DCHECK(suggestions_.empty()) - << "Expected suggested questions to be clear when there has been no" - << " previous text content but there were " << suggestions_.size() - << " suggested questions: " << base::JoinString(suggestions_, ", "); - - // Now that we have content, we can provide a summary on-demand. Add that to - // suggested questions. - suggestions_.emplace_back( - is_video_ ? l10n_util::GetStringUTF8(IDS_CHAT_UI_SUMMARIZE_VIDEO) - : l10n_util::GetStringUTF8(IDS_CHAT_UI_SUMMARIZE_PAGE)); - suggestion_generation_status_ = - mojom::SuggestionGenerationStatus::CanGenerate; - OnSuggestedQuestionsChanged(); - // We check again to see if any page content related prompt is pending - MaybePopPendingRequests(); +void ConversationDriver::OnExistingGeneratePageContentComplete( + GetPageContentCallback callback) { + // Don't need to check navigation ID since existing event will be + // deleted when there's a new conversation. + VLOG(1) << "Existing page content fetch completed, proceeding with " + "the results of that operation."; + std::move(callback).Run(article_text_, is_video_, + content_invalidation_token_); +} + +void ConversationDriver::OnNewPage(int64_t navigation_id) { + current_navigation_id_ = navigation_id; + CleanUp(); } void ConversationDriver::CleanUp() { + DVLOG(1) << __func__; chat_history_.clear(); article_text_.clear(); + content_invalidation_token_.clear(); + on_page_text_fetch_complete_ = std::make_unique(); + is_video_ = false; suggestions_.clear(); pending_conversation_entry_.reset(); - pending_message_needs_page_content_ = false; - is_same_document_navigation_ = false; is_page_text_fetch_in_progress_ = false; is_request_in_progress_ = false; suggestion_generation_status_ = mojom::SuggestionGenerationStatus::None; @@ -384,6 +453,8 @@ void ConversationDriver::CleanUp() { SetAPIError(mojom::APIError::None); engine_->ClearAllQueries(); + MaybeSeedOrClearSuggestions(); + // Trigger an observer update to refresh the UI. for (auto& obs : observers_) { obs.OnHistoryUpdate(); @@ -392,22 +463,6 @@ void ConversationDriver::CleanUp() { } } -int64_t ConversationDriver::GetNavigationId() const { - return current_navigation_id_; -} - -void ConversationDriver::SetNavigationId(int64_t navigation_id) { - current_navigation_id_ = navigation_id; -} - -bool ConversationDriver::IsSameDocumentNavigation() const { - return is_same_document_navigation_; -} - -void ConversationDriver::SetSameDocumentNavigation(bool same_document_navigation) { - is_same_document_navigation_ = same_document_navigation; -} - std::vector ConversationDriver::GetSuggestedQuestions( mojom::SuggestionGenerationStatus& suggestion_status) { // Can we get suggested questions @@ -416,9 +471,11 @@ std::vector ConversationDriver::GetSuggestedQuestions( } void ConversationDriver::SetShouldSendPageContents(bool should_send) { + DCHECK(IsContentAssociationPossible()); DCHECK(should_send_page_contents_ != should_send); - should_send_page_contents_ = should_send; + + MaybeSeedOrClearSuggestions(); } bool ConversationDriver::GetShouldSendPageContents() { @@ -450,17 +507,16 @@ void ConversationDriver::GenerateQuestions() { << "opted in to AI Chat"; return; } + DCHECK(should_send_page_contents_) + << "Cannot get suggestions when not associated with content."; + DCHECK(IsContentAssociationPossible()) + << "Should not be associated with content when not allowed to be"; // We're not expecting to call this if the UI is not active for this // conversation. DCHECK(is_conversation_active_); // We're not expecting to already have generated suggestions DCHECK_LE(suggestions_.size(), 1u); - // Can't operate if we don't have an article text - if (article_text_.empty()) { - return; - } - if (suggestion_generation_status_ == mojom::SuggestionGenerationStatus::IsGenerating || suggestion_generation_status_ == @@ -473,15 +529,22 @@ void ConversationDriver::GenerateQuestions() { suggestion_generation_status_ = mojom::SuggestionGenerationStatus::IsGenerating; OnSuggestedQuestionsChanged(); - // Make API request for questions. + // Make API request for questions but first get page content. // Do not call SetRequestInProgress, this progress // does not need to be shown to the UI. - auto navigation_id_for_query = current_navigation_id_; - engine_->GenerateQuestionSuggestions( - is_video_, article_text_, - base::BindOnce(&ConversationDriver::OnSuggestedQuestionsResponse, - weak_ptr_factory_.GetWeakPtr(), - std::move(navigation_id_for_query))); + auto on_content_retrieved = [](ConversationDriver* instance, + int64_t navigation_id, + std::string page_content, bool is_video, + std::string invalidation_token) { + instance->engine_->GenerateQuestionSuggestions( + is_video, page_content, + base::BindOnce(&ConversationDriver::OnSuggestedQuestionsResponse, + instance->weak_ptr_factory_.GetWeakPtr(), + std::move(navigation_id))); + }; + GeneratePageContent(base::BindOnce(std::move(on_content_retrieved), + base::Unretained(this), + current_navigation_id_)); } void ConversationDriver::OnSuggestedQuestionsResponse( @@ -502,32 +565,44 @@ void ConversationDriver::OnSuggestedQuestionsResponse( DVLOG(2) << "Got questions:" << base::JoinString(suggestions_, "\n"); } -void ConversationDriver::MakeAPIRequestWithConversationHistoryUpdate( - mojom::ConversationTurn turn, - bool needs_page_content /* = false */) { +void ConversationDriver::SubmitHumanConversationEntry( + mojom::ConversationTurn turn) { + VLOG(1) << __func__; + DVLOG(4) << __func__ << ": " << turn.text; // Decide if this entry needs to wait for one of: // - user to be opted-in // - conversation to be active - // - content to be retrieved + // - is request in progress (should only be possible if regular entry is + // in-progress and another entry is submitted outside of regular UI, e.g. from + // location bar. if (!is_conversation_active_ || !HasUserOptedIn() || - (article_text_.empty() && needs_page_content)) { - // This function should not be presented in the UI if the user has not - // opted-in yet. + is_request_in_progress_) { + VLOG(1) << "Adding as a pending conversation entry"; + // This is possible (on desktop) if user submits multiple location bar + // messages before an entry is complete. But that should be obvious from the + // UI that the 1 in-progress + 1 pending message is the limit. + if (pending_conversation_entry_) { + VLOG(1) << "Should not be able to add a pending conversation entry " + << "when there is already a pending conversation entry."; + return; + } pending_conversation_entry_ = std::make_unique(std::move(turn)); - - if (article_text_.empty() && needs_page_content) { - pending_message_needs_page_content_ = true; + // Pending entry is added to conversation history when asked for + // so notify observers. + for (auto& obs : observers_) { + obs.OnHistoryUpdate(); } - - // Invoking this before the creation of the page handler means the pending - // request will not be reported. - OnConversationEntryPending(); return; } DCHECK(turn.character_type == CharacterType::HUMAN); + is_request_in_progress_ = true; + for (auto& obs : observers_) { + obs.OnAPIRequestInProgress(IsRequestInProgress()); + } + bool is_suggested_question = false; // If it's a suggested question, remove it @@ -561,33 +636,46 @@ void ConversationDriver::MakeAPIRequestWithConversationHistoryUpdate( is_suggested_question ? std::vector() : chat_history_; - auto data_received_callback = base::BindRepeating( - &ConversationDriver::OnEngineCompletionDataReceived, - weak_ptr_factory_.GetWeakPtr(), current_navigation_id_); + // Add the human part to the conversation + AddToConversationHistory(std::move(turn)); - auto data_completed_callback = - base::BindOnce(&ConversationDriver::OnEngineCompletionComplete, - weak_ptr_factory_.GetWeakPtr(), current_navigation_id_); + const bool is_page_associated = + IsContentAssociationPossible() && should_send_page_contents_; - // Now the conversation is committed, we can remove some unneccessary data if - // we're not associated with a page. - if (!should_send_page_contents_) { + if (is_page_associated) { + // Fetch updated page content before performing generation + GeneratePageContent( + base::BindOnce(&ConversationDriver::PerformAssistantGeneration, + weak_ptr_factory_.GetWeakPtr(), question_part, + std::move(history), current_navigation_id_)); + } else { + // Now the conversation is committed, we can remove some unneccessary data + // if we're not associated with a page. article_text_.clear(); suggestions_.clear(); OnSuggestedQuestionsChanged(); + // Perform generation immediately + PerformAssistantGeneration(question_part, history, current_navigation_id_); } +} - engine_->GenerateAssistantResponse( - is_video_, article_text_, history, question_part, - std::move(data_received_callback), std::move(data_completed_callback)); - - // Add the human part to the conversation - AddToConversationHistory(std::move(turn)); +void ConversationDriver::PerformAssistantGeneration( + std::string input, + std::vector history, + int64_t current_navigation_id, + std::string page_content, + bool is_video, + std::string invalidation_token) { + auto data_received_callback = base::BindRepeating( + &ConversationDriver::OnEngineCompletionDataReceived, + weak_ptr_factory_.GetWeakPtr(), current_navigation_id); - is_request_in_progress_ = true; - for (auto& obs : observers_) { - obs.OnAPIRequestInProgress(IsRequestInProgress()); - } + auto data_completed_callback = + base::BindOnce(&ConversationDriver::OnEngineCompletionComplete, + weak_ptr_factory_.GetWeakPtr(), current_navigation_id); + engine_->GenerateAssistantResponse(is_video, page_content, history, input, + std::move(data_received_callback), + std::move(data_completed_callback)); } void ConversationDriver::RetryAPIRequest() { @@ -602,7 +690,7 @@ void ConversationDriver::RetryAPIRequest() { auto turn = *std::make_move_iterator(rit); auto human_turn_iter = rit.base() - 1; chat_history_.erase(human_turn_iter, chat_history_.end()); - MakeAPIRequestWithConversationHistoryUpdate(turn); + SubmitHumanConversationEntry(turn); break; } } @@ -696,8 +784,7 @@ void ConversationDriver::SubmitSummarizationRequest() { mojom::ConversationTurn turn = { CharacterType::HUMAN, ConversationTurnVisibility::VISIBLE, l10n_util::GetStringUTF8(IDS_CHAT_UI_SUMMARIZE_PAGE)}; - MakeAPIRequestWithConversationHistoryUpdate(std::move(turn), - /*needs_page_content=*/true); + SubmitHumanConversationEntry(std::move(turn)); } mojom::SiteInfoPtr ConversationDriver::BuildSiteInfo() { @@ -743,10 +830,4 @@ void ConversationDriver::OnPremiumStatusReceived( std::move(parent_callback).Run(premium_status); } -void ConversationDriver::OnConversationEntryPending() { - for (auto& obs : observers_) { - obs.OnConversationEntryPending(); - } -} - } // namespace ai_chat diff --git a/components/ai_chat/core/browser/conversation_driver.h b/components/ai_chat/core/browser/conversation_driver.h index 4abaa27ac1c8..1ddd47a15355 100644 --- a/components/ai_chat/core/browser/conversation_driver.h +++ b/components/ai_chat/core/browser/conversation_driver.h @@ -13,6 +13,7 @@ #include "base/memory/raw_ptr.h" #include "base/memory/ref_counted.h" #include "base/observer_list.h" +#include "base/one_shot_event.h" #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" @@ -27,6 +28,15 @@ class AIChatMetrics; class ConversationDriver { public: + // |invalidation_token| is an optional parameter that will be passed back on + // the next call to |GetPageContent| so that the implementer may determine if + // the page content is static or if it needs to be fetched again. Most page + // content should be fetched again, but some pages are known to be static + // during their lifetime and may have expensive content fetching, e.g. videos + // with transcripts fetched over the network. + using GetPageContentCallback = base::OnceCallback< + void(std::string content, bool is_video, std::string invalidation_token)>; + class Observer : public base::CheckedObserver { public: ~Observer() override {} @@ -40,7 +50,6 @@ class ConversationDriver { mojom::SuggestionGenerationStatus suggestion_generation_status) {} virtual void OnFaviconImageDataChanged() {} virtual void OnPageHasContent(mojom::SiteInfoPtr site_info) {} - virtual void OnConversationEntryPending() {} }; ConversationDriver(raw_ptr pref_service, @@ -55,15 +64,14 @@ class ConversationDriver { void ChangeModel(const std::string& model_key); const mojom::Model& GetCurrentModel(); const std::vector& GetConversationHistory(); + std::vector GetVisibleConversationHistory(); // Whether the UI for this conversation is open or not. Determines // whether content is retrieved and queries are sent for the conversation // when the page changes. void OnConversationActiveChanged(bool is_conversation_active); void AddToConversationHistory(mojom::ConversationTurn turn); void UpdateOrCreateLastAssistantEntry(std::string text); - void MakeAPIRequestWithConversationHistoryUpdate( - mojom::ConversationTurn turn, - bool needs_page_content = false); + void SubmitHumanConversationEntry(mojom::ConversationTurn turn); void RetryAPIRequest(); bool IsRequestInProgress(); void AddObserver(Observer* observer); @@ -89,32 +97,45 @@ class ConversationDriver { protected: virtual GURL GetPageURL() const = 0; virtual std::u16string GetPageTitle() const = 0; - virtual void GetPageContent( - base::OnceCallback callback) const = 0; - virtual bool HasPrimaryMainFrame() const = 0; - virtual bool IsDocumentOnLoadCompletedInPrimaryMainFrame() const = 0; - - virtual void OnFaviconImageDataChanged(); - void MaybeGeneratePageText(); - void CleanUp(); + // Implementer should fetch content from the "page" associated with this + // conversation. + // |is_video| lets the conversation know that the content is focused on video + // content so that various UI language can be adapted. + // |invalidation_token| is an optional parameter received in a prior callback + // response of this function against the same page. See GetPageContentCallback + // for explanation. + virtual void GetPageContent(GetPageContentCallback callback, + std::string_view invalidation_token) const = 0; - int64_t GetNavigationId() const; - void SetNavigationId(int64_t navigation_id); + virtual void OnFaviconImageDataChanged(); - bool IsSameDocumentNavigation() const; - void SetSameDocumentNavigation(bool same_document_navigation); + // To be called when a page navigation is detected and a new conversation + // is expected. + void OnNewPage(int64_t navigation_id); private: void InitEngine(); bool HasUserOptedIn(); void OnUserOptedIn(); bool MaybePopPendingRequests(); - void MaybeGenerateQuestions(); + void MaybeSeedOrClearSuggestions(); + + void PerformAssistantGeneration(std::string input, + std::vector history, + int64_t current_navigation_id, + std::string page_content = "", + bool is_video = false, + std::string invalidation_token = ""); + + void GeneratePageContent(GetPageContentCallback callback); + void OnGeneratePageContentComplete(int64_t navigation_id, + GetPageContentCallback callback, + std::string contents_text, + bool is_video, + std::string invalidation_token); + void OnExistingGeneratePageContentComplete(GetPageContentCallback callback); - void OnPageContentRetrieved(int64_t navigation_id, - std::string contents_text, - bool is_video = false); void OnEngineCompletionDataReceived(int64_t navigation_id, std::string result); void OnEngineCompletionComplete(int64_t navigation_id, @@ -126,11 +147,12 @@ class ConversationDriver { void OnPremiumStatusReceived( mojom::PageHandler::GetPremiumStatusCallback parent_callback, mojom::PremiumStatus premium_status); - void OnConversationEntryPending(); void SetAPIError(const mojom::APIError& error); bool IsContentAssociationPossible(); + void CleanUp(); + raw_ptr pref_service_; raw_ptr ai_chat_metrics_; std::unique_ptr credential_manager_; @@ -143,9 +165,14 @@ class ConversationDriver { // TODO(nullhook): Abstract the data model std::string model_key_; std::vector chat_history_; - std::string article_text_; bool is_conversation_active_ = false; + + // Page content + std::string article_text_; + std::string content_invalidation_token_; bool is_page_text_fetch_in_progress_ = false; + std::unique_ptr on_page_text_fetch_complete_; + bool is_request_in_progress_ = false; std::vector suggestions_; // Keep track of whether we've generated suggested questions for the current @@ -156,15 +183,16 @@ class ConversationDriver { mojom::SuggestionGenerationStatus::None; bool is_video_ = false; bool should_send_page_contents_ = true; - // Store the unique ID for each navigation so that - // we can ignore API responses for previous navigations. + + // Store the unique ID for each "page" so that + // we can ignore API async responses against any navigated-away-from + // documents. int64_t current_navigation_id_; - bool is_same_document_navigation_ = false; + mojom::APIError current_error_ = mojom::APIError::None; mojom::PremiumStatus last_premium_status_ = mojom::PremiumStatus::Unknown; std::unique_ptr pending_conversation_entry_; - bool pending_message_needs_page_content_ = false; base::WeakPtrFactory weak_ptr_factory_{this}; }; diff --git a/components/ai_chat/core/common/mojom/ai_chat.mojom b/components/ai_chat/core/common/mojom/ai_chat.mojom index 6cf6682f77ef..3094f70dcd82 100644 --- a/components/ai_chat/core/common/mojom/ai_chat.mojom +++ b/components/ai_chat/core/common/mojom/ai_chat.mojom @@ -165,6 +165,4 @@ interface ChatUIPage { // |is_fetching| will be |true|. OnSiteInfoChanged(SiteInfo info); OnFaviconImageDataChanged(array favicon_image_data); - // This reports if there are any pending entries to be sent to the remote LLM. - OnConversationEntryPending(); }; diff --git a/components/ai_chat/resources/page/components/main/index.tsx b/components/ai_chat/resources/page/components/main/index.tsx index 02bb38c63137..419f2b3c040f 100644 --- a/components/ai_chat/resources/page/components/main/index.tsx +++ b/components/ai_chat/resources/page/components/main/index.tsx @@ -141,10 +141,12 @@ function Main() { onScroll={handleScroll} > - {context.hasAcceptedAgreement && } - + {context.hasAcceptedAgreement && <> + + + } {currentErrorElement && (
{currentErrorElement}
)} diff --git a/components/ai_chat/resources/page/state/data-context-provider.tsx b/components/ai_chat/resources/page/state/data-context-provider.tsx index a81d63e8baf3..e7ecd13b3ea9 100644 --- a/components/ai_chat/resources/page/state/data-context-provider.tsx +++ b/components/ai_chat/resources/page/state/data-context-provider.tsx @@ -78,6 +78,14 @@ function DataContextProvider (props: DataContextProviderProps) { .then((res) => setConversationHistory(res.conversationHistory)) } + // If a conversation entry is submitted but we haven't yet + // accepted the policy, show the policy. + React.useEffect(() => { + if (conversationHistory.length && !hasAcceptedAgreement) { + setShowAgreementModal(true) + } + }, [conversationHistory?.length, hasAcceptedAgreement]) + const getSuggestedQuestions = () => { getPageHandlerInstance() .pageHandler.getSuggestedQuestions() @@ -255,12 +263,6 @@ function DataContextProvider (props: DataContextProviderProps) { setCurrentModelKey(modelKey) }) - getPageHandlerInstance().callbackRouter.onConversationEntryPending.addListener(() => { - if (!hasAcceptedAgreement) { - setShowAgreementModal(true) - } - }) - // Since there is no server-side event for premium status changing, // we should check often. And since purchase or login is performed in // a separate WebContents, we can check when focus is returned here.