-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Perform deep check in StructHandle::operator== (#2811)
Previously, this method only compared the pointers held by two StructHandles. This PR changes it to check that the name, number of fields, and the DataType and value of each field match. #2714 (comment)
- Loading branch information
1 parent
4fa45ce
commit 90623fe
Showing
4 changed files
with
48 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// clang-format off | ||
/* | ||
* SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. | ||
* All rights reserved. | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
// clang-format on | ||
|
||
#include <polymorphic_value.h> | ||
#include <type.h> | ||
|
||
namespace nvfuser { | ||
|
||
bool StructHandle::operator==(const StructHandle& other) const { | ||
if (struct_ptr_ == other.struct_ptr_) { | ||
return true; | ||
} | ||
const StructType this_type = type(); | ||
const StructType other_type = other.type(); | ||
if (this_type.name != other_type.name) { | ||
return false; | ||
} | ||
if (this_type.fields.size() != other_type.fields.size()) { | ||
return false; | ||
} | ||
for (size_t i : c10::irange(this_type.fields.size())) { | ||
// Check that fields are in same position, have same type, and have same | ||
// value (recursive) | ||
const StructType::FieldInfo& fa = this_type.fields.at(i); | ||
const StructType::FieldInfo& fb = other_type.fields.at(i); | ||
PolymorphicValue a_val = (*this)->*(fa.name); | ||
PolymorphicValue b_val = other->*(fb.name); | ||
if (fa.name != fb.name || *fa.type != *fb.type || | ||
!PolymorphicValue_functions::isSame(a_val, b_val)) { | ||
return false; | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
} // namespace nvfuser |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters