Skip to content

Commit

Permalink
[AI Chat] cleanup suggestion action types and persist to database whe…
Browse files Browse the repository at this point in the history
…n prompt differs from user-visible conversation entry text

- ActionType::CONVERSATION_STARTER
- ActionType::SUGGESTION
- ActionType::SUMMARIZE_{PAGE,VIDEO}
  • Loading branch information
petemill committed Dec 20, 2024
1 parent 2c7caf1 commit 8878883
Show file tree
Hide file tree
Showing 33 changed files with 717 additions and 420 deletions.
4 changes: 2 additions & 2 deletions browser/ai_chat/android/ai_chat_utils_android.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ static void JNI_BraveLeoUtils_OpenLeoQuery(
conversation->MaybeUnlinkAssociatedContent();
mojom::ConversationTurnPtr turn = mojom::ConversationTurn::New(
std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY,
mojom::ConversationTurnVisibility::VISIBLE,
base::android::ConvertJavaStringToUTF8(query), std::nullopt, std::nullopt,
base::android::ConvertJavaStringToUTF8(query), std::nullopt /* prompt */,
std::nullopt /* selected_text */, std::nullopt /* events */,
base::Time::Now(), std::nullopt, false);
conversation->SubmitHumanConversationEntry(std::move(turn));

Expand Down
1 change: 0 additions & 1 deletion browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ namespace ai_chat {

using mojom::CharacterType;
using mojom::ConversationTurn;
using mojom::ConversationTurnVisibility;

AIChatUIPageHandler::ChatContextObserver::ChatContextObserver(
content::WebContents* web_contents,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ void ChromeAutocompleteProviderClient::OpenLeo(const std::u16string& query) {
ai_chat::mojom::ConversationTurn::New(
std::nullopt, ai_chat::mojom::CharacterType::HUMAN,
ai_chat::mojom::ActionType::QUERY,
ai_chat::mojom::ConversationTurnVisibility::VISIBLE,
base::UTF16ToUTF8(query) /* text */, std::nullopt /* selected_text */,
std::nullopt /* events */, base::Time::Now(),
std::nullopt /* edits */, false /* from_brave_search_SERP */);
base::UTF16ToUTF8(query) /* text */, std::nullopt /* prompt */,
std::nullopt /* selected_text */, std::nullopt /* events */,
base::Time::Now(), std::nullopt /* edits */,
false /* from_brave_search_SERP */);

conversation_handler->SubmitHumanConversationEntry(std::move(turn));

Expand Down
80 changes: 59 additions & 21 deletions components/ai_chat/core/browser/ai_chat_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,6 @@

namespace {

// These database versions should roll together unless we develop migrations.
// Lowest version we support migrations from - existing database will be deleted
// if lower.
constexpr int kLowestSupportedDatabaseVersion = 1;
// Current version of the database. Increase if breaking changes are made.
constexpr int kCurrentDatabaseVersion = 1;

constexpr char kSearchQueriesSeparator[] = "|||";

std::optional<std::string> GetOptionalString(sql::Statement& statement,
Expand All @@ -48,10 +41,31 @@ void BindOptionalString(sql::Statement& statement,
}
}

bool MigrateFrom1To2(sql::Database* db) {
// Add a new column to the associated_content table to store the content type.
static constexpr char kAddPromptColumnQuery[] =
"ALTER TABLE conversation_entry ADD COLUMN prompt BLOB";
sql::Statement statement(db->GetUniqueStatement(kAddPromptColumnQuery));

return statement.is_valid() && statement.Run();
}

} // namespace

namespace ai_chat {

// These database versions should roll together unless we develop migrations.
// Lowest version we support migrations from - existing database will be deleted
// if lower.
constexpr int kLowestSupportedDatabaseVersion = 1;

// The oldest version of the schema such that a legacy Brave client using that
// version can still read/write the current database.
constexpr int kCompatibleDatabaseVersionNumber = 1;

// Current version of the database. Increase if breaking changes are made.
constexpr int kCurrentDatabaseVersion = 2;

AIChatDatabase::AIChatDatabase(const base::FilePath& db_file_path,
os_crypt_async::Encryptor encryptor)
: db_file_path_(db_file_path),
Expand Down Expand Up @@ -87,7 +101,7 @@ sql::InitStatus AIChatDatabase::InitInternal() {

sql::MetaTable meta_table;
if (!meta_table.Init(&GetDB(), kCurrentDatabaseVersion,
/*compatible_version=*/kCurrentDatabaseVersion)) {
kCompatibleDatabaseVersionNumber)) {
DVLOG(0) << "Failed to init meta table";
return sql::InitStatus::INIT_FAILURE;
}
Expand All @@ -102,6 +116,23 @@ sql::InitStatus AIChatDatabase::InitInternal() {
return sql::InitStatus::INIT_FAILURE;
}

if (meta_table.GetVersionNumber() < kCurrentDatabaseVersion) {
bool migration_success = true;
if (meta_table.GetVersionNumber() == 1) {
migration_success = MigrateFrom1To2(&GetDB());
migration_success = meta_table.SetCompatibleVersionNumber(
kCompatibleDatabaseVersionNumber) &&
meta_table.SetVersionNumber(kCurrentDatabaseVersion);
}
// Migration unsuccessful, raze the database and re-init
if (!migration_success) {
if (db_.Raze()) {
return InitInternal();
}
return sql::InitStatus::INIT_FAILURE;
}
}

if (!transaction.Commit()) {
return sql::InitStatus::INIT_FAILURE;
}
Expand Down Expand Up @@ -211,7 +242,8 @@ std::vector<mojom::ConversationTurnPtr> AIChatDatabase::GetConversationEntries(
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

static constexpr char kEntriesQuery[] =
"SELECT uuid, date, entry_text, character_type, editing_entry_uuid, "
"SELECT uuid, date, entry_text, prompt, character_type, "
"editing_entry_uuid, "
"action_type, selected_text"
" FROM conversation_entry"
" WHERE conversation_uuid=?"
Expand All @@ -233,17 +265,19 @@ std::vector<mojom::ConversationTurnPtr> AIChatDatabase::GetConversationEntries(
std::string entry_uuid = statement.ColumnString(0);
DVLOG(4) << "Found entry row for conversation " << conversation_uuid
<< " with id " << entry_uuid;
auto date = statement.ColumnTime(1);
auto text = DecryptOptionalColumnToString(statement, 2).value_or("");
int index = 1;
auto date = statement.ColumnTime(index++);
auto text = DecryptOptionalColumnToString(statement, index++).value_or("");
auto prompt = DecryptOptionalColumnToString(statement, index++);
auto character_type =
static_cast<mojom::CharacterType>(statement.ColumnInt(3));
auto editing_entry_id = GetOptionalString(statement, 4);
auto action_type = static_cast<mojom::ActionType>(statement.ColumnInt(5));
auto selected_text = DecryptOptionalColumnToString(statement, 6);
static_cast<mojom::CharacterType>(statement.ColumnInt(index++));
auto editing_entry_id = GetOptionalString(statement, index++);
auto action_type =
static_cast<mojom::ActionType>(statement.ColumnInt(index++));
auto selected_text = DecryptOptionalColumnToString(statement, index++);

auto entry = mojom::ConversationTurn::New(
entry_uuid, character_type, action_type,
mojom::ConversationTurnVisibility::VISIBLE, text, selected_text,
entry_uuid, character_type, action_type, text, prompt, selected_text,
std::nullopt, date, std::nullopt, false);

// events
Expand Down Expand Up @@ -548,16 +582,16 @@ bool AIChatDatabase::AddConversationEntry(
if (editing_id.has_value()) {
static constexpr char kInsertEditingConversationEntryQuery[] =
"INSERT INTO conversation_entry(editing_entry_uuid, uuid,"
" conversation_uuid, date, entry_text,"
" conversation_uuid, date, entry_text, prompt,"
" character_type, action_type, selected_text)"
" VALUES(?, ?, ?, ?, ?, ?, ?, ?)";
" VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?)";
insert_conversation_entry_statement.Assign(
GetDB().GetUniqueStatement(kInsertEditingConversationEntryQuery));
} else {
static constexpr char kInsertConversationEntryQuery[] =
"INSERT INTO conversation_entry(uuid, conversation_uuid, date,"
" entry_text, character_type, action_type, selected_text)"
" VALUES(?, ?, ?, ?, ?, ?, ?)";
" entry_text, prompt, character_type, action_type, selected_text)"
" VALUES(?, ?, ?, ?, ?, ?, ?, ?)";
insert_conversation_entry_statement.Assign(
GetDB().GetUniqueStatement(kInsertConversationEntryQuery));
}
Expand All @@ -572,6 +606,8 @@ bool AIChatDatabase::AddConversationEntry(
insert_conversation_entry_statement.BindTime(index++, entry->created_time);
BindAndEncryptOptionalString(insert_conversation_entry_statement, index++,
entry->text);
BindAndEncryptOptionalString(insert_conversation_entry_statement, index++,
entry->prompt);
insert_conversation_entry_statement.BindInt(
index++, base::to_underlying(entry->character_type));
insert_conversation_entry_statement.BindInt(
Expand Down Expand Up @@ -1005,6 +1041,8 @@ bool AIChatDatabase::CreateSchema() {
// Encrypted text string
// TODO(petemill): move to event only
"entry_text BLOB,"
// Encrypted optional user-invisible override prompt
"prompt BLOB,"
"character_type INTEGER NOT NULL,"
// editing_entry points to the ConversationEntry row that is being edited.
// Edits can be sorted by date.
Expand Down
5 changes: 5 additions & 0 deletions components/ai_chat/core/browser/ai_chat_database.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@

namespace ai_chat {

extern const int kLowestSupportedDatabaseVersion;
extern const int kCompatibleDatabaseVersionNumber;
extern const int kCurrentDatabaseVersion;

// Persists AI Chat conversations and associated content. Conversations are
// mainly formed of their conversation entries. Edits to conversation entries
// should be handled with removal and re-adding so that other classes can make
Expand Down Expand Up @@ -76,6 +80,7 @@ class AIChatDatabase {

private:
friend class AIChatDatabaseTest;
friend class AIChatDatabaseMigrationTest;

sql::Database& GetDB();

Expand Down
146 changes: 139 additions & 7 deletions components/ai_chat/core/browser/ai_chat_database_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h"
#include "components/os_crypt/async/browser/test_utils.h"
#include "sql/init_status.h"
#include "sql/meta_table.h"
#include "sql/test/test_helpers.h"
#include "testing/gtest/include/gtest/gtest.h"

Expand Down Expand Up @@ -87,7 +88,6 @@ class AIChatDatabaseTest : public testing::Test,
std::unique_ptr<os_crypt_async::OSCryptAsync> os_crypt_;
base::CallbackListSubscription encryptor_ready_subscription_;
std::unique_ptr<AIChatDatabase> db_;
base::FilePath path_;
};

INSTANTIATE_TEST_SUITE_P(
Expand Down Expand Up @@ -139,6 +139,8 @@ TEST_P(AIChatDatabaseTest, AddAndGetConversationAndEntries) {

// Persist the first entry (and get the response ready)
auto history = CreateSampleChatHistory(1u);
// Edit the prompt to show that the prompt is persisted
history[0]->prompt = "first entry prompt";

EXPECT_TRUE(db_->AddConversation(
metadata->Clone(),
Expand Down Expand Up @@ -204,9 +206,8 @@ TEST_P(AIChatDatabaseTest, AddAndGetConversationAndEntries) {
last_query->edits->emplace_back(mojom::ConversationTurn::New(
base::Uuid::GenerateRandomV4().AsLowercaseString(),
mojom::CharacterType::HUMAN, mojom::ActionType::QUERY,
mojom::ConversationTurnVisibility::VISIBLE, "edited query 1",
std::nullopt, std::nullopt, base::Time::Now() + base::Minutes(121),
std::nullopt, false));
"edited query 1", std::nullopt, std::nullopt, std::nullopt,
base::Time::Now() + base::Minutes(121), std::nullopt, false));
EXPECT_TRUE(db_->DeleteConversationEntry(last_query->uuid.value()));
EXPECT_TRUE(db_->AddConversationEntry(uuid, last_query->Clone()));
}
Expand All @@ -221,9 +222,8 @@ TEST_P(AIChatDatabaseTest, AddAndGetConversationAndEntries) {
last_query->edits->emplace_back(mojom::ConversationTurn::New(
base::Uuid::GenerateRandomV4().AsLowercaseString(),
mojom::CharacterType::HUMAN, mojom::ActionType::QUERY,
mojom::ConversationTurnVisibility::VISIBLE, "edited query 2",
std::nullopt, std::nullopt, base::Time::Now() + base::Minutes(122),
std::nullopt, false));
"edited query 2", std::nullopt, std::nullopt, std::nullopt,
base::Time::Now() + base::Minutes(122), std::nullopt, false));
EXPECT_TRUE(db_->DeleteConversationEntry(last_query->uuid.value()));
EXPECT_TRUE(db_->AddConversationEntry(uuid, last_query->Clone()));
}
Expand Down Expand Up @@ -439,4 +439,136 @@ TEST_P(AIChatDatabaseTest, DeleteAssociatedWebContent) {
EXPECT_EQ(archive_result->associated_content[0]->content, expected_contents);
}

// Test the migration for each version upgrade
class AIChatDatabaseMigrationTest : public testing::Test,
public testing::WithParamInterface<int> {
public:
AIChatDatabaseMigrationTest() = default;

void SetUp() override {
CHECK(temp_directory_.CreateUniqueTempDir());
database_dump_location_ = database_dump_location_.AppendASCII("brave")
.AppendASCII("test")
.AppendASCII("data")
.AppendASCII("ai_chat");
os_crypt_ = os_crypt_async::GetTestOSCryptAsyncForTesting(
/*is_sync_for_unittests=*/true);

// Create database when os_crypt is ready
base::RunLoop run_loop;
encryptor_ready_subscription_ =
os_crypt_->GetInstance(base::BindLambdaForTesting(
[&](os_crypt_async::Encryptor encryptor, bool success) {
ASSERT_TRUE(success);
CreateDatabase(base::StringPrintf(
"aichat_database_dump_version_%d.sql", version()));
db_ = std::make_unique<AIChatDatabase>(db_file_path(),
std::move(encryptor));
run_loop.Quit();
}));
run_loop.Run();
}

void TearDown() override {
// Verify that the db was init successfully and not using default return
// values.
EXPECT_TRUE(IsInitOk());
db_.reset();
// Verify current version of database is latest
sql::Database db;
sql::MetaTable meta_table;
ASSERT_TRUE(db.Open(db_file_path()));
ASSERT_TRUE(meta_table.Init(&db, kCurrentDatabaseVersion,
kCompatibleDatabaseVersionNumber));
EXPECT_EQ(kCompatibleDatabaseVersionNumber,
meta_table.GetCompatibleVersionNumber());
EXPECT_EQ(kCurrentDatabaseVersion, meta_table.GetVersionNumber());

CHECK(temp_directory_.Delete());
}

// Creates the database from |sql_file|.
void CreateDatabase(std::string_view sql_file) {
base::FilePath database_dump =
base::PathService::CheckedGet(base::DIR_SRC_TEST_DATA_ROOT);
database_dump =
database_dump.Append(database_dump_location_).AppendASCII(sql_file);
ASSERT_TRUE(sql::test::CreateDatabaseFromSQL(db_file_path(), database_dump))
<< "Could not create database from sql dump file at: "
<< database_dump.value();
}

bool IsInitOk() {
return (db_->db_init_status_.has_value() &&
db_->db_init_status_.value() == sql::InitStatus::INIT_OK);
}

base::FilePath db_file_path() {
return temp_directory_.GetPath().AppendASCII("test_ai_chat.db");
}

// Returns the database version for the test.
int version() const { return GetParam(); }

protected:
base::test::TaskEnvironment task_environment_{
base::test::TaskEnvironment::TimeSource::MOCK_TIME};
base::FilePath database_dump_location_;
base::ScopedTempDir temp_directory_;
std::unique_ptr<os_crypt_async::OSCryptAsync> os_crypt_;
base::CallbackListSubscription encryptor_ready_subscription_;
std::unique_ptr<AIChatDatabase> db_;
};

INSTANTIATE_TEST_SUITE_P(,
AIChatDatabaseMigrationTest,
testing::Range(kLowestSupportedDatabaseVersion,
kCurrentDatabaseVersion));

// Tests the migration of the database from version() to kCurrentVersionNumber
TEST_P(AIChatDatabaseMigrationTest, MigrationToVCurrent) {
if (version() < 2) {
// Verify we have existing entries
auto conversations = db_->GetAllConversations();
EXPECT_GT(conversations.size(), 0u);
EXPECT_GT(db_->GetConversationData(conversations[0]->uuid)->entries.size(),
0u);
// ConversationEntry table changed, check it persists correctly
auto now = base::Time::Now();
const std::string uuid = "migrationtest";
mojom::SiteInfoPtr associated_content = mojom::SiteInfo::New(
std::nullopt, mojom::ContentType::PageContent, std::nullopt,
std::nullopt, std::nullopt, 0, false, false);
const mojom::ConversationPtr metadata =
mojom::Conversation::New(uuid, "title", now - base::Hours(2), true,
std::nullopt, std::move(associated_content));

// Persist the first entry (and get the response ready)
auto history = CreateSampleChatHistory(1u);
// Edit the prompt to show that the prompt is persisted
history[0]->prompt = "first entry prompt";

EXPECT_TRUE(db_->AddConversation(metadata->Clone(), std::nullopt,
history[0]->Clone()));
// Persist the response entry
EXPECT_TRUE(db_->AddConversationEntry(uuid, history[1]->Clone()));

// Test getting the conversation entries
mojom::ConversationArchivePtr result = db_->GetConversationData(uuid);
ExpectConversationHistoryEquals(FROM_HERE, result->entries, history);

// Add another pair of entries
auto next_history = CreateSampleChatHistory(1u, 1);
EXPECT_TRUE(db_->AddConversationEntry(uuid, next_history[0]->Clone()));
EXPECT_TRUE(db_->AddConversationEntry(uuid, next_history[1]->Clone()));

// Verify all entries are returned
mojom::ConversationArchivePtr result_2 = db_->GetConversationData(uuid);
for (auto& entry : next_history) {
history.push_back(std::move(entry));
}
ExpectConversationHistoryEquals(FROM_HERE, result_2->entries, history);
}
}

} // namespace ai_chat
Loading

0 comments on commit 8878883

Please sign in to comment.