diff --git a/CMakeLists.txt b/CMakeLists.txt index cc5c25a0f4a..244ac816c87 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -183,6 +183,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/options.cpp ${NVFUSER_SRCS_DIR}/parallel_dimension_map.cpp ${NVFUSER_SRCS_DIR}/parallel_type_bitmap.cpp + ${NVFUSER_SRCS_DIR}/polymorphic_value.cpp ${NVFUSER_SRCS_DIR}/predicate_compute.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/add_axioms.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/allocation_order_inference.cpp diff --git a/csrc/polymorphic_value.cpp b/csrc/polymorphic_value.cpp new file mode 100644 index 00000000000..3ae0a9d731e --- /dev/null +++ b/csrc/polymorphic_value.cpp @@ -0,0 +1,41 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include +#include + +namespace nvfuser { + +bool StructHandle::operator==(const StructHandle& other) const { + if (struct_ptr_ == other.struct_ptr_) { + return true; + } + const StructType this_type = type(); + const StructType other_type = other.type(); + if (this_type.name != other_type.name) { + return false; + } + if (this_type.fields.size() != other_type.fields.size()) { + return false; + } + for (size_t i : c10::irange(this_type.fields.size())) { + // Check that fields are in same position, have same type, and have same + // value (recursive) + const StructType::FieldInfo& fa = this_type.fields.at(i); + const StructType::FieldInfo& fb = other_type.fields.at(i); + PolymorphicValue a_val = (*this)->*(fa.name); + PolymorphicValue b_val = other->*(fb.name); + if (fa.name != fb.name || *fa.type != *fb.type || + !PolymorphicValue_functions::isSame(a_val, b_val)) { + return false; + } + } + return true; +} + +} // namespace nvfuser diff --git a/csrc/polymorphic_value.h b/csrc/polymorphic_value.h index 0013c5f35e6..5775feb1173 100644 --- a/csrc/polymorphic_value.h +++ b/csrc/polymorphic_value.h @@ -185,11 +185,7 @@ class StructHandle { StructHandle& operator=(const StructHandle& other) = default; StructHandle& operator=(StructHandle&& other) = default; - //! This is a shallow comparison operator that just checks whether we point to - //! the same exact Struct - bool operator==(const StructHandle& other) const { - return struct_ptr_ == other.struct_ptr_; - } + bool operator==(const StructHandle& other) const; template bool is() const { diff --git a/tests/cpp/test_polymorphic_value.cpp b/tests/cpp/test_polymorphic_value.cpp index 73d90cab316..57e5403348f 100644 --- a/tests/cpp/test_polymorphic_value.cpp +++ b/tests/cpp/test_polymorphic_value.cpp @@ -127,6 +127,11 @@ TEST_F(PolymorphicValueTest, Struct) { b->*"y" = 2.71828; EXPECT_EQ(b->*"x", PolymorphicValue(2788)); EXPECT_EQ(b->*"y", PolymorphicValue(2.71828)); + + // At this point this struct and the one created earlier have the exact + // same structure and values, so they should compare as equal + EXPECT_EQ(a, b); + b->*"x" = 299792458; b->*"y" = 3.1415926; EXPECT_EQ(b->*"x", PolymorphicValue(299792458));