Skip to content

Commit

Permalink
Perform deep check in StructHandle::operator== (#2811)
Browse files Browse the repository at this point in the history
Previously, this method only compared the pointers held by two
StructHandles. This PR changes it to check that the name, number of
fields, and the DataType and value of each field match.

#2714 (comment)
  • Loading branch information
jacobhinkle authored Aug 20, 2024
1 parent 4fa45ce commit 90623fe
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 5 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/options.cpp
${NVFUSER_SRCS_DIR}/parallel_dimension_map.cpp
${NVFUSER_SRCS_DIR}/parallel_type_bitmap.cpp
${NVFUSER_SRCS_DIR}/polymorphic_value.cpp
${NVFUSER_SRCS_DIR}/predicate_compute.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/add_axioms.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/allocation_order_inference.cpp
Expand Down
41 changes: 41 additions & 0 deletions csrc/polymorphic_value.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on

#include <polymorphic_value.h>
#include <type.h>

namespace nvfuser {

bool StructHandle::operator==(const StructHandle& other) const {
if (struct_ptr_ == other.struct_ptr_) {
return true;
}
const StructType this_type = type();
const StructType other_type = other.type();
if (this_type.name != other_type.name) {
return false;
}
if (this_type.fields.size() != other_type.fields.size()) {
return false;
}
for (size_t i : c10::irange(this_type.fields.size())) {
// Check that fields are in same position, have same type, and have same
// value (recursive)
const StructType::FieldInfo& fa = this_type.fields.at(i);
const StructType::FieldInfo& fb = other_type.fields.at(i);
PolymorphicValue a_val = (*this)->*(fa.name);
PolymorphicValue b_val = other->*(fb.name);
if (fa.name != fb.name || *fa.type != *fb.type ||
!PolymorphicValue_functions::isSame(a_val, b_val)) {
return false;
}
}
return true;
}

} // namespace nvfuser
6 changes: 1 addition & 5 deletions csrc/polymorphic_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,7 @@ class StructHandle {
StructHandle& operator=(const StructHandle& other) = default;
StructHandle& operator=(StructHandle&& other) = default;

//! This is a shallow comparison operator that just checks whether we point to
//! the same exact Struct
bool operator==(const StructHandle& other) const {
return struct_ptr_ == other.struct_ptr_;
}
bool operator==(const StructHandle& other) const;

template <typename T>
bool is() const {
Expand Down
5 changes: 5 additions & 0 deletions tests/cpp/test_polymorphic_value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ TEST_F(PolymorphicValueTest, Struct) {
b->*"y" = 2.71828;
EXPECT_EQ(b->*"x", PolymorphicValue(2788));
EXPECT_EQ(b->*"y", PolymorphicValue(2.71828));

// At this point this struct and the one created earlier have the exact
// same structure and values, so they should compare as equal
EXPECT_EQ(a, b);

b->*"x" = 299792458;
b->*"y" = 3.1415926;
EXPECT_EQ(b->*"x", PolymorphicValue(299792458));
Expand Down

0 comments on commit 90623fe

Please sign in to comment.