Skip to content

Commit

Permalink
chore: Remove DataRecordStruct usage
Browse files Browse the repository at this point in the history
Bug: 385338000
Change-Id: I6dbeda5a9597b46b2f0fb81a6d808f2d6ee993ef
GitOrigin-RevId: 690022fc002613bd2b4b80f63e19f61a0d0a01c0
  • Loading branch information
lusayaa authored and copybara-github committed Dec 27, 2024
1 parent b780f74 commit 6d8fd09
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 90 deletions.
20 changes: 11 additions & 9 deletions components/tools/benchmarks/benchmark_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,19 @@ TEST(BenchmarkUtilTest, VerifyWriteRecords) {
auto status = WriteRecords(num_records, record_size, data_stream);
EXPECT_TRUE(status.ok()) << status;
DeltaRecordStreamReader record_reader(data_stream);
testing::MockFunction<absl::Status(DataRecordStruct)> record_callback;
testing::MockFunction<absl::Status(const DataRecord&)> record_callback;
EXPECT_CALL(record_callback, Call)
.Times(num_records)
.WillRepeatedly([record_size](DataRecordStruct data_record) {
if (std::holds_alternative<KeyValueMutationRecordStruct>(
data_record.record)) {
auto kv_record =
std::get<KeyValueMutationRecordStruct>(data_record.record);
EXPECT_EQ(std::get<std::string_view>(kv_record.value).size(),
record_size);
}
.WillRepeatedly([record_size](const DataRecord& data_record) {
DataRecordT data_record_struct;
data_record.UnPackTo(&data_record_struct);
EXPECT_EQ(data_record_struct.record.type,
Record::KeyValueMutationRecord);
const auto kv_record =
*data_record_struct.record.AsKeyValueMutationRecord();
EXPECT_EQ(kv_record.value.type, Value::StringValue);
const auto value = *kv_record.value.AsStringValue();
EXPECT_EQ(value.value.size(), record_size);
return absl::OkStatus();
});
status = record_reader.ReadRecords(record_callback.AsStdFunction());
Expand Down
13 changes: 7 additions & 6 deletions tools/data_cli/commands/generate_snapshot_command.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,13 @@ absl::Status WriteRecordsToSnapshotStream(
ShardingFunction sharding_function(/*seed=*/"");
return record_reader.ReadRecords(
[&params, &snapshot_writer,
&sharding_function](DataRecordStruct data_record) {
&sharding_function](const DataRecord& data_record) {
DataRecordT data_record_struct;
data_record.UnPackTo(&data_record_struct);
if (params.shard_number >= 0 &&
std::holds_alternative<KeyValueMutationRecordStruct>(
data_record.record)) {
KeyValueMutationRecordStruct record_struct =
std::get<KeyValueMutationRecordStruct>(data_record.record);
data_record_struct.record.type == Record::KeyValueMutationRecord) {
KeyValueMutationRecordT record_struct =
*data_record_struct.record.AsKeyValueMutationRecord();
auto record_shard_num = sharding_function.GetShardNumForKey(
record_struct.key, params.number_of_shards);
if (params.shard_number != record_shard_num) {
Expand All @@ -152,7 +153,7 @@ absl::Status WriteRecordsToSnapshotStream(
return absl::OkStatus();
}
}
return snapshot_writer.WriteRecord(data_record);
return snapshot_writer.WriteRecord(data_record_struct);
});
}

Expand Down
3 changes: 2 additions & 1 deletion tools/request_simulation/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ cc_library(
"//components/data/common:thread_manager",
"//components/data/realtime:realtime_notifier",
"//public/data_loading:data_loading_fbs",
"//public/data_loading:records_utils",
"//public/data_loading:record_utils",
"//public/data_loading/readers:riegeli_stream_io",
"//public/data_loading/readers:stream_record_reader_factory",
"@com_github_google_flatbuffers//:flatbuffers",
Expand Down Expand Up @@ -319,6 +319,7 @@ cc_test(
deps = [
":delta_based_request_generator",
"//components/data/common:mocks",
"//public/test_util:data_record",
"//public/test_util:mocks",
"@com_google_googletest//:gtest_main",
"@google_privacysandbox_servers_common//src/telemetry:mocks",
Expand Down
29 changes: 16 additions & 13 deletions tools/request_simulation/delta_based_request_generator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,20 @@
#include "gtest/gtest.h"
#include "public/constants.h"
#include "public/data_loading/filename_utils.h"
#include "public/data_loading/records_utils.h"
#include "public/data_loading/record_utils.h"
#include "public/test_util/data_record.h"
#include "public/test_util/mocks.h"
#include "src/telemetry/mocks.h"
#include "tools/request_simulation/request_generation_util.h"

using kv_server::BlobStorageChangeNotifier;
using kv_server::BlobStorageClient;
using kv_server::DataRecordStruct;
using kv_server::DataRecordT;
using kv_server::DeltaBasedRequestGenerator;
using kv_server::FilePrefix;
using kv_server::FileType;
using kv_server::KeyValueMutationRecordStruct;
using kv_server::GetSimpleStringValue;
using kv_server::KeyValueMutationRecordT;
using kv_server::KeyValueMutationType;
using kv_server::KVFileMetadata;
using kv_server::MessageQueue;
Expand All @@ -47,11 +49,8 @@ using kv_server::MockRealtimeNotifier;
using kv_server::MockStreamRecordReader;
using kv_server::MockStreamRecordReaderFactory;
using kv_server::Record;
using kv_server::Serialize;
using kv_server::ToDeltaFileName;
using kv_server::ToFlatBufferBuilder;
using kv_server::ToStringView;
using kv_server::UserDefinedFunctionsConfigStruct;
using kv_server::UserDefinedFunctionsLanguage;
using kv_server::Value;
using privacy_sandbox::server_common::MockMetricsRecorder;
using testing::_;
Expand Down Expand Up @@ -128,12 +127,16 @@ TEST_F(GenerateRequestsFromDeltaFilesTest, LoadingDataFromDeltaFiles) {
.Times(1)
.WillOnce(
[](const std::function<absl::Status(std::string_view)>& callback) {
callback(
ToStringView(ToFlatBufferBuilder(DataRecordStruct{
.record =
KeyValueMutationRecordStruct{
KeyValueMutationType::Update, 3, "key", "value"}})))
.IgnoreError();
KeyValueMutationRecordT kv_mutation_record = {
.mutation_type = KeyValueMutationType::Update,
.logical_commit_time = 3,
.key = "key",
};
kv_mutation_record.value.Set(GetSimpleStringValue("value"));
DataRecordT data_record =
GetNativeDataRecord(std::move(kv_mutation_record));
auto [fbs_buffer, serialized_string_view] = Serialize(data_record);
callback(serialized_string_view).IgnoreError();
return absl::OkStatus();
});
EXPECT_CALL(delta_stream_reader_factory_, CreateConcurrentReader)
Expand Down
2 changes: 1 addition & 1 deletion tools/serving_data_generator/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ cc_binary(
deps = [
"//public/data_loading:data_loading_fbs",
"//public/data_loading:filename_utils",
"//public/data_loading:records_utils",
"//public/data_loading:record_utils",
"//public/data_loading:riegeli_metadata_cc_proto",
"//public/sharding:sharding_function",
"@com_google_absl//absl/flags:flag",
Expand Down
49 changes: 31 additions & 18 deletions tools/serving_data_generator/test_serving_data_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "absl/log/log.h"
#include "public/data_loading/data_loading_generated.h"
#include "public/data_loading/filename_utils.h"
#include "public/data_loading/records_utils.h"
#include "public/data_loading/record_utils.h"
#include "public/data_loading/riegeli_metadata.pb.h"
#include "public/sharding/sharding_function.h"
#include "riegeli/bytes/ostream_writer.h"
Expand Down Expand Up @@ -52,24 +52,32 @@ ABSL_FLAG(int, num_set_records, 5, "Number of records to generate");
ABSL_FLAG(uint32_t, range_min, 0, "Minimum element in set records.");
ABSL_FLAG(uint32_t, range_max, 2147483647, "Maximum element in set records.");

using kv_server::DataRecordStruct;
using kv_server::KeyValueMutationRecordStruct;
using kv_server::DataRecordT;
using kv_server::KeyValueMutationRecordT;
using kv_server::KeyValueMutationType;
using kv_server::KVFileMetadata;
using kv_server::ShardingMetadata;
using kv_server::StringSetT;
using kv_server::StringValueT;
using kv_server::ToDeltaFileName;
using kv_server::ToFlatBufferBuilder;
using kv_server::ToStringView;
using kv_server::UInt32SetT;

const std::array<std::string, 3> kSetOps = {" - ", " | ", " & "};

void WriteKeyValueRecord(std::string_view key, std::string_view value,
int64_t logical_commit_time,
riegeli::RecordWriterBase& writer) {
auto kv_record = KeyValueMutationRecordStruct{
KeyValueMutationType::Update, logical_commit_time, key, value};
writer.WriteRecord(ToStringView(
ToFlatBufferBuilder(DataRecordStruct{.record = std::move(kv_record)})));
KeyValueMutationRecordT kv_record = {
.mutation_type = KeyValueMutationType::Update,
.logical_commit_time = logical_commit_time,
.key = std::string(key),
};
kv_record.value.Set(StringValueT{.value = std::string(value)});
DataRecordT data_record;
data_record.record.Set(std::move(kv_record));
auto [fbs_buffer, serialized_string_view] = Serialize(data_record);
writer.WriteRecord(serialized_string_view);
}

std::vector<std::string> WriteKeyValueRecords(
Expand Down Expand Up @@ -127,23 +135,28 @@ void WriteKeyValueSetRecords(const std::vector<std::string>& keys,
}
}
auto set_value_key = absl::StrCat(set_key_prefix, i);
KeyValueMutationRecordStruct record;
record.mutation_type = KeyValueMutationType::Update;
record.logical_commit_time = timestamp++;
record.key = set_value_key;
KeyValueMutationRecordT kv_record = {
.mutation_type = KeyValueMutationType::Update,
.logical_commit_time = timestamp++,
.key = set_value_key,
};
if (absl::GetFlag(FLAGS_generate_int_set_records)) {
record.value = uint32_set;
writer.WriteRecord(ToStringView(
ToFlatBufferBuilder(DataRecordStruct{.record = std::move(record)})));
kv_record.value.Set(UInt32SetT{.value = std::move(uint32_set)});
DataRecordT data_record;
data_record.record.Set(std::move(kv_record));
auto [fbs_buffer, serialized_string_view] = Serialize(data_record);
writer.WriteRecord(serialized_string_view);
}
if (absl::GetFlag(FLAGS_generate_string_set_records)) {
std::vector<std::string_view> string_set_view;
for (const auto& v : string_set) {
string_set_view.emplace_back(v);
}
record.value = string_set_view;
writer.WriteRecord(ToStringView(
ToFlatBufferBuilder(DataRecordStruct{.record = std::move(record)})));
kv_record.value.Set(StringSetT{.value = std::move(string_set)});
DataRecordT data_record;
data_record.record.Set(std::move(kv_record));
auto [fbs_buffer, serialized_string_view] = Serialize(data_record);
writer.WriteRecord(serialized_string_view);
}
absl::StrAppend(&query, set_value_key,
kSetOps[std::rand() % kSetOps.size()]);
Expand Down
93 changes: 51 additions & 42 deletions tools/udf/udf_tester/udf_delta_file_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,40 +57,49 @@ using google::protobuf::util::JsonStringToMessage;
// string_view to non-const string_view. Since this tool is for simple testing,
// the current solution is to pass by value.
absl::Status LoadCacheFromKVMutationRecord(
UDFDeltaFileTestLogContext& log_context,
KeyValueMutationRecordStruct record, Cache& cache) {
UDFDeltaFileTestLogContext& log_context, KeyValueMutationRecordT record,
Cache& cache) {
switch (record.mutation_type) {
case KeyValueMutationType::Update: {
LOG(INFO) << "Updating cache with key " << record.key
<< ", logical commit time " << record.logical_commit_time;
std::visit(
[&cache, &record, &log_context](auto& value) {
using VariantT = std::decay_t<decltype(value)>;
if constexpr (std::is_same_v<VariantT, std::string_view>) {
cache.UpdateKeyValue(log_context, record.key, value,
record.logical_commit_time);
return;
}
constexpr bool is_list =
(std::is_same_v<VariantT, std::vector<std::string_view>>);
if constexpr (is_list) {
cache.UpdateKeyValueSet(log_context, record.key,
absl::MakeSpan(value),
record.logical_commit_time);
return;
}
},
record.value);
break;
if (record.value.type == Value::StringValue) {
cache.UpdateKeyValue(log_context, record.key,
record.value.AsStringValue()->value,
record.logical_commit_time);
return absl::OkStatus();
}
if (record.value.type == Value::StringSet) {
std::vector<std::string> values_str = record.value.AsStringSet()->value;
std::vector<std::string_view> values(values_str.begin(),
values_str.end());
cache.UpdateKeyValueSet(log_context, record.key, absl::MakeSpan(values),
record.logical_commit_time);
return absl::OkStatus();
}
if (record.value.type == Value::UInt32Set) {
auto values = record.value.AsUInt32Set()->value;
cache.UpdateKeyValueSet(log_context, record.key, absl::MakeSpan(values),
record.logical_commit_time);
return absl::OkStatus();
}
if (record.value.type == Value::UInt64Set) {
auto values = record.value.AsUInt64Set()->value;
cache.UpdateKeyValueSet(log_context, record.key, absl::MakeSpan(values),
record.logical_commit_time);
return absl::OkStatus();
}
return absl::InvalidArgumentError(
absl::StrCat("Record with key: ", record.key,
" has unsupported value type: ", record.value.type));
}
case KeyValueMutationType::Delete: {
cache.DeleteKey(log_context, record.key, record.logical_commit_time);
break;
}
default:
return absl::InvalidArgumentError(
absl::StrCat("Invalid mutation type: ",
EnumNameKeyValueMutationType(record.mutation_type)));
absl::StrCat("Invalid mutation type: ", record.mutation_type));
}
return absl::OkStatus();
}
Expand All @@ -100,23 +109,22 @@ absl::Status LoadCacheFromFile(UDFDeltaFileTestLogContext& log_context,
std::ifstream delta_file(file_path);
DeltaRecordStreamReader record_reader(delta_file);
absl::Status status = record_reader.ReadRecords(
[&cache, &log_context](const DataRecordStruct& data_record) {
[&cache, &log_context](const DataRecord& data_record) {
DataRecordT data_record_struct;
data_record.UnPackTo(&data_record_struct);
// Only load KVMutationRecords into cache.
if (std::holds_alternative<KeyValueMutationRecordStruct>(
data_record.record)) {
if (data_record_struct.record.type == Record::KeyValueMutationRecord) {
return LoadCacheFromKVMutationRecord(
log_context,
std::get<KeyValueMutationRecordStruct>(data_record.record),
cache);
*data_record_struct.record.AsKeyValueMutationRecord(), cache);
}
return absl::OkStatus();
});
return status;
}

void ReadCodeConfigFromUdfConfig(
const UserDefinedFunctionsConfigStruct& udf_config,
CodeConfig& code_config) {
void ReadCodeConfigFromUdfConfig(const UserDefinedFunctionsConfigT& udf_config,
CodeConfig& code_config) {
code_config.js = udf_config.code_snippet;
code_config.logical_commit_time = udf_config.logical_commit_time;
code_config.udf_handler_name = udf_config.handler_name;
Expand All @@ -127,17 +135,18 @@ absl::Status ReadCodeConfigFromFile(std::string file_path,
CodeConfig& code_config) {
std::ifstream delta_file(file_path);
DeltaRecordStreamReader record_reader(delta_file);
return record_reader.ReadRecords(
[&code_config](const DataRecordStruct& data_record) {
if (std::holds_alternative<UserDefinedFunctionsConfigStruct>(
data_record.record)) {
ReadCodeConfigFromUdfConfig(
std::get<UserDefinedFunctionsConfigStruct>(data_record.record),
code_config);
return absl::OkStatus();
}
return absl::InvalidArgumentError("Invalid record type.");
});
return record_reader.ReadRecords([&code_config](
const DataRecord& data_record) {
DataRecordT data_record_struct;
data_record.UnPackTo(&data_record_struct);
if (data_record_struct.record.type == Record::UserDefinedFunctionsConfig) {
ReadCodeConfigFromUdfConfig(
*data_record_struct.record.AsUserDefinedFunctionsConfig(),
code_config);
return absl::OkStatus();
}
return absl::InvalidArgumentError("Invalid record type.");
});
}

void ShutdownUdf(UdfClient& udf_client) {
Expand Down

0 comments on commit 6d8fd09

Please sign in to comment.