diff --git a/tink/internal/BUILD.bazel b/tink/internal/BUILD.bazel index f1233563..13c9663c 100644 --- a/tink/internal/BUILD.bazel +++ b/tink/internal/BUILD.bazel @@ -1367,3 +1367,26 @@ cc_library( hdrs = ["dfsan_forwarders.h"], deps = ["@com_google_absl//absl/base"], ) + +cc_library( + name = "tpb_message_descriptor", + srcs = ["tpb_message_descriptor.cc"], + hdrs = ["tpb_message_descriptor.h"], + include_prefix = "tink/internal", + deps = [ + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "tpb_message_descriptor_test", + srcs = ["tpb_message_descriptor_test.cc"], + deps = [ + ":tpb_message_descriptor", + "//tink/util:test_matchers", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tink/internal/CMakeLists.txt b/tink/internal/CMakeLists.txt index 1fd246c6..0cdd56b1 100644 --- a/tink/internal/CMakeLists.txt +++ b/tink/internal/CMakeLists.txt @@ -1314,3 +1314,25 @@ tink_cc_test( tink::util::statusor tink::util::test_matchers ) + +tink_cc_library( + NAME tpb_message_descriptor + SRCS + tpb_message_descriptor.cc + tpb_message_descriptor.h + DEPS + absl::btree + absl::status + absl::statusor + absl::strings +) + +tink_cc_test( + NAME tpb_message_descriptor_test + SRCS + tpb_message_descriptor_test.cc + DEPS + tink::internal::tpb_message_descriptor + gmock + tink::util::test_matchers +) diff --git a/tink/internal/tpb_message_descriptor.cc b/tink/internal/tpb_message_descriptor.cc new file mode 100644 index 00000000..4be71ee1 --- /dev/null +++ b/tink/internal/tpb_message_descriptor.cc @@ -0,0 +1,93 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////////////// + +#include "tink/internal/tpb_message_descriptor.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" + +namespace crypto { +namespace tink { +namespace internal { + +absl::Status TpbMessageDescriptor::AddUint32(int tag) { + if (!types_.emplace(tag, Type::kUint32).second) { + return absl::InvalidArgumentError( + absl::StrCat("Tag ", tag, " already exists")); + } + return absl::OkStatus(); +} + +absl::Status TpbMessageDescriptor::AddBytes(int tag) { + if (!types_.emplace(tag, Type::kBytes).second) { + return absl::InvalidArgumentError( + absl::StrCat("Tag ", tag, " already exists")); + } + return absl::OkStatus(); +} + +absl::Status TpbMessageDescriptor::AddMessage( + int tag, const TpbMessageDescriptor& descriptor) { + if (!types_.emplace(tag, Type::kMessage).second) { + return absl::InvalidArgumentError( + absl::StrCat("Tag ", tag, " already exists")); + } + message_descriptors_[tag] = + std::make_shared(descriptor); + return absl::OkStatus(); +} + +const TpbMessageDescriptor* TpbMessageDescriptor::GetMessage( + int tag) const { + auto it = message_descriptors_.find(tag); + if (it == message_descriptors_.end()) { + return nullptr; + } + return it->second.get(); +} + +absl::StatusOr TpbMessageDescriptor::GetType( + int tag) const { + auto it = types_.find(tag); + if (it == types_.end()) { + return absl::InvalidArgumentError(absl::StrCat("Tag ", tag, " not found")); + } + return it->second; +} + +bool operator==(const TpbMessageDescriptor& lhs, + const TpbMessageDescriptor& rhs) { + if (lhs.types_ != rhs.types_) return false; + return lhs.message_descriptors_.size() == rhs.message_descriptors_.size() && + std::equal(lhs.message_descriptors_.begin(), + lhs.message_descriptors_.end(), + rhs.message_descriptors_.begin(), [](auto l, auto r) { + return l.first == r.first && *l.second == *r.second; + }); +} + +bool operator!=(const TpbMessageDescriptor& lhs, + const TpbMessageDescriptor& rhs) { + return !(lhs == rhs); +} + +} // namespace internal +} // namespace tink +} // namespace crypto diff --git a/tink/internal/tpb_message_descriptor.h b/tink/internal/tpb_message_descriptor.h new file mode 100644 index 00000000..dc93a7ed --- /dev/null +++ b/tink/internal/tpb_message_descriptor.h @@ -0,0 +1,78 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////////////// + +#ifndef TINK_INTERNAL_TPB_MESSAGE_DESCRIPTOR_H_ +#define TINK_INTERNAL_TPB_MESSAGE_DESCRIPTOR_H_ + +#include + +#include "absl/container/btree_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +namespace crypto { +namespace tink { +namespace internal { + +// A TinkProtoBuf message descriptor. This describes the format of a message. +// (i.e., it describes what's usually in a .proto file). +class TpbMessageDescriptor { + public: + TpbMessageDescriptor() = default; + // Movable and copyable. + TpbMessageDescriptor(const TpbMessageDescriptor&) = default; + TpbMessageDescriptor(TpbMessageDescriptor&&) noexcept = default; + TpbMessageDescriptor& operator=(const TpbMessageDescriptor&) = default; + TpbMessageDescriptor& operator=(TpbMessageDescriptor&&) noexcept = default; + + // The type of a field: + // https://protobuf.dev/reference/protobuf/proto3-spec/#fields + enum class Type { kUint32, kBytes, kMessage }; + + // Adds a uint32 field with the given tag. + absl::Status AddUint32(int tag); + + // Adds a bytes field with the given tag. + absl::Status AddBytes(int tag); + + // Adds a message field with the given tag. + // Note: recursive messages are not supported. + absl::Status AddMessage(int tag, const TpbMessageDescriptor& descriptor); + const TpbMessageDescriptor* GetMessage(int tag) const; + + // Returns the type of the field with the given tag. + absl::StatusOr GetType(int tag) const; + + friend bool operator==(const TpbMessageDescriptor& lhs, + const TpbMessageDescriptor& rhs); + friend bool operator!=(const TpbMessageDescriptor& lhs, + const TpbMessageDescriptor& rhs); + + private: + absl::btree_map types_; + // We cannot have a map because + // TpbMessageDescriptor is incomplete at this point. We hence use shared_ptr + // instead (with unique_ptr we would have to manually write the copy + // constructor of TpbMessageDescriptor). + absl::btree_map> + message_descriptors_; +}; + +} // namespace internal +} // namespace tink +} // namespace crypto + +#endif // TINK_INTERNAL_TPB_MESSAGE_DESCRIPTOR_H_ diff --git a/tink/internal/tpb_message_descriptor_test.cc b/tink/internal/tpb_message_descriptor_test.cc new file mode 100644 index 00000000..f72c81b2 --- /dev/null +++ b/tink/internal/tpb_message_descriptor_test.cc @@ -0,0 +1,221 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/////////////////////////////////////////////////////////////////////////////// +#include "tink/internal/tpb_message_descriptor.h" +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "tink/util/test_matchers.h" + +namespace crypto { +namespace tink { +namespace internal { +namespace { + +using ::crypto::tink::test::IsOk; +using ::crypto::tink::test::IsOkAndHolds; +using ::crypto::tink::test::StatusIs; +using ::crypto::tink::util::StatusOr; +using ::testing::HasSubstr; +using ::testing::IsNull; +using ::testing::Not; +using ::testing::Test; + +TEST(TpbMessageDescriptorTest, EmptyMessageGetTypeError) { + TpbMessageDescriptor descriptor; + EXPECT_THAT(descriptor.GetType(123), Not(IsOk())); +} + +TEST(TpbMessageDescriptorTest, SingleUInt32FieldWorks) { + TpbMessageDescriptor descriptor; + ASSERT_THAT(descriptor.AddUint32(123), IsOk()); + EXPECT_THAT(descriptor.GetType(123), + IsOkAndHolds(TpbMessageDescriptor::Type::kUint32)); +} + +TEST(TpbMessageDescriptorTest, SingleBytesFieldWorks) { + TpbMessageDescriptor descriptor; + ASSERT_THAT(descriptor.AddBytes(345), IsOk()); + EXPECT_THAT(descriptor.GetType(345), + IsOkAndHolds(TpbMessageDescriptor::Type::kBytes)); +} + +TEST(TpbMessageDescriptorTest, SingleMessageFieldWorks) { + TpbMessageDescriptor nested_descriptor; + ASSERT_THAT(nested_descriptor.AddUint32(123), IsOk()); + TpbMessageDescriptor descriptor; + ASSERT_THAT(descriptor.AddMessage(345, nested_descriptor), IsOk()); + EXPECT_THAT(descriptor.GetType(345), + IsOkAndHolds(TpbMessageDescriptor::Type::kMessage)); + const TpbMessageDescriptor* message_descriptor = + descriptor.GetMessage(345); + ASSERT_THAT(message_descriptor, Not(IsNull())); + EXPECT_THAT(message_descriptor->GetType(123), + IsOkAndHolds(TpbMessageDescriptor::Type::kUint32)); +} + +TEST(TpbMessageDescriptorTest, AddExsitingFieldFails) { + TpbMessageDescriptor descriptor; + ASSERT_THAT(descriptor.AddUint32(123), IsOk()); + + EXPECT_THAT(descriptor.AddUint32(123), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Tag 123 already exists"))); +} + + +TEST(TpbMessageDescriptorTest, EqualityEmpty) { + TpbMessageDescriptor m1; + TpbMessageDescriptor m2; + EXPECT_TRUE(m1 == m2); + EXPECT_FALSE(m1 != m2); +} + +TEST(TpbMessageDescriptorTest, EqualitySingleUInt32) { + TpbMessageDescriptor m1; + TpbMessageDescriptor m2; + ASSERT_THAT(m1.AddUint32(123), IsOk()); + ASSERT_THAT(m2.AddUint32(123), IsOk()); + EXPECT_TRUE(m1 == m2); + EXPECT_FALSE(m1 != m2); +} + +TEST(TpbMessageDescriptorTest, EqualitySingleBytes) { + TpbMessageDescriptor m1; + TpbMessageDescriptor m2; + ASSERT_THAT(m1.AddBytes(123), IsOk()); + ASSERT_THAT(m2.AddBytes(123), IsOk()); + EXPECT_TRUE(m1 == m2); + EXPECT_FALSE(m1 != m2); +} + +TEST(TpbMessageDescriptorTest, EqualitySingleNestedMessage) { + TpbMessageDescriptor i1; + TpbMessageDescriptor i2; + ASSERT_THAT(i1.AddBytes(234), IsOk()); + ASSERT_THAT(i2.AddBytes(234), IsOk()); + TpbMessageDescriptor m1; + TpbMessageDescriptor m2; + ASSERT_THAT(m1.AddMessage(123, i1), IsOk()); + ASSERT_THAT(m2.AddMessage(123, i2), IsOk()); + EXPECT_TRUE(m1 == m2); + EXPECT_FALSE(m1 != m2); +} + +TEST(TpbMessageDescriptorTest, NotEqualityToEmpty) { + TpbMessageDescriptor m1; + ASSERT_THAT(m1.AddUint32(123), IsOk()); + EXPECT_FALSE(m1 == TpbMessageDescriptor()); + EXPECT_TRUE(m1 != TpbMessageDescriptor()); + + TpbMessageDescriptor m2; + ASSERT_THAT(m2.AddBytes(123), IsOk()); + EXPECT_FALSE(m2 == TpbMessageDescriptor()); + EXPECT_TRUE(m2 != TpbMessageDescriptor()); + + TpbMessageDescriptor m3; + ASSERT_THAT(m3.AddMessage(123, TpbMessageDescriptor()), IsOk()); + EXPECT_FALSE(m3 == TpbMessageDescriptor()); + EXPECT_TRUE(m3 != TpbMessageDescriptor()); +} + +TEST(TpbMessageDescriptorTest, NotEqualityDifferentTypes) { + TpbMessageDescriptor m1; + ASSERT_THAT(m1.AddUint32(123), IsOk()); + TpbMessageDescriptor m2; + ASSERT_THAT(m2.AddBytes(123), IsOk()); + EXPECT_FALSE(m1 == m2); + EXPECT_TRUE(m1 != m2); +} + +TEST(TpbMessageDescriptorTest, NotEqualIfInnerMessageDiffers) { + TpbMessageDescriptor i1; + ASSERT_THAT(i1.AddUint32(123), IsOk()); + TpbMessageDescriptor i2; + ASSERT_THAT(i2.AddBytes(123), IsOk()); + + TpbMessageDescriptor m1; + TpbMessageDescriptor m2; + ASSERT_THAT(m1.AddMessage(123, i1), IsOk()); + ASSERT_THAT(m2.AddMessage(123, i2), IsOk()); + EXPECT_FALSE(m1 == m2); + EXPECT_TRUE(m1 != m2); +} + +TEST(TpbMessageDescriptorTest, CopyConstructorWorks) { + TpbMessageDescriptor i1; + ASSERT_THAT(i1.AddUint32(123), IsOk()); + TpbMessageDescriptor i2; + ASSERT_THAT(i2.AddBytes(123), IsOk()); + + TpbMessageDescriptor m; + ASSERT_THAT(m.AddMessage(123, i1), IsOk()); + ASSERT_THAT(m.AddMessage(345, i2), IsOk()); + + TpbMessageDescriptor m2 = m; + EXPECT_TRUE(m == m2); +} + +TEST(TpbMessageDescriptorTest, CopyAssignmentWorks) { + TpbMessageDescriptor i1; + ASSERT_THAT(i1.AddUint32(123), IsOk()); + TpbMessageDescriptor i2; + ASSERT_THAT(i2.AddBytes(123), IsOk()); + + TpbMessageDescriptor m; + ASSERT_THAT(m.AddMessage(123, i1), IsOk()); + ASSERT_THAT(m.AddMessage(345, i2), IsOk()); + + TpbMessageDescriptor m2; + m2 = m; + EXPECT_TRUE(m == m2); +} + +TEST(TpbMessageDescriptorTest, MoveConstructorWorks) { + TpbMessageDescriptor i1; + ASSERT_THAT(i1.AddUint32(123), IsOk()); + TpbMessageDescriptor i2; + ASSERT_THAT(i2.AddBytes(123), IsOk()); + + TpbMessageDescriptor m; + ASSERT_THAT(m.AddMessage(123, i1), IsOk()); + ASSERT_THAT(m.AddMessage(345, i2), IsOk()); + TpbMessageDescriptor copy = m; + + TpbMessageDescriptor m2 = std::move(copy); + EXPECT_TRUE(m == m2); +} + +TEST(TpbMessageDescriptorTest, MoveAssignmentWorks) { + TpbMessageDescriptor i1; + ASSERT_THAT(i1.AddUint32(123), IsOk()); + TpbMessageDescriptor i2; + ASSERT_THAT(i2.AddBytes(123), IsOk()); + + TpbMessageDescriptor m; + ASSERT_THAT(m.AddMessage(123, i1), IsOk()); + ASSERT_THAT(m.AddMessage(345, i2), IsOk()); + TpbMessageDescriptor copy = m; + + TpbMessageDescriptor m2; + m2 = std::move(copy); + EXPECT_TRUE(m == m2); +} + +} // namespace +} // namespace internal +} // namespace tink +} // namespace crypto