diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 6ac52ba0564..2e70e36c679 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -131,7 +131,25 @@ class NVF_API IterDomain : public Val { IterDomain* in, Val* factor, bool inner_split, - bool rfactor_domain = false); + bool rfactor_domain); + + // Shortcut of the above split function. Previously, the above split function + // had a default value of false for the rfactor_domain flag. The + // default value is removed as it may implicitly drop + // the rfactor flag when the input is an rfactor iter domain, which + // may not be intended behavior. This + // shortcut is added to provide the backward compatible API with an + // assertion that the input iter domain is not indeed an rfactor domain. + static std::pair split( + IterDomain* in, + Val* factor, + bool inner_split) { + NVF_ERROR( + !in->isRFactorProduct(), + "Unexpected input iter domain. Input should not be an rfactor iter domain: ", + in->toString()); + return IterDomain::split(in, factor, inner_split, false); + } //! Resize an IterDomain by expanding both the left and right sides //! by given widths. The resulting IterDomain has an extent of