From 97454332020ad9559d3b41bc17d0e0ce70a1db37 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Wed, 1 Feb 2023 08:52:16 -0500 Subject: [PATCH 01/81] [Unity] Relax VM (#13878) This PR implements a flexible register-based VM to execute relax programs with dynamic shape and control flow. Design: https://github.com/tlc-pack/relax/wiki/Relax-VM-Design. Co-Authored-by: Ziheng Jiang Co-Authored-by: Ruihang Lai Co-Authored-by: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Co-Authored-by: Junru Shao Co-Authored-by: Prakalp Srivastava Co-Authored-by: Yong Wu Co-Authored-by: Steven S. Lyubomirsky Co-Authored-by: Tianqi Chen Co-Authored-by: Hongyi Jin <3231950289@qq.com> --- CMakeLists.txt | 2 + include/tvm/relax/exec_builder.h | 181 ++++ include/tvm/runtime/relax_vm/builtin.h | 89 ++ include/tvm/runtime/relax_vm/bytecode.h | 223 +++++ include/tvm/runtime/relax_vm/executable.h | 209 +++++ include/tvm/runtime/relax_vm/memory_manager.h | 142 +++ include/tvm/runtime/relax_vm/vm.h | 148 ++++ python/tvm/relax/__init__.py | 24 + python/tvm/relax/_ffi_api.py | 20 + python/tvm/relax/exec_builder.py | 147 ++++ python/tvm/relax/testing/vm.py | 85 ++ python/tvm/relax/vm.py | 609 +++++++++++++ src/relax/backend/vm/exec_builder.cc | 399 +++++++++ src/runtime/relax_vm/builtin.cc | 445 ++++++++++ src/runtime/relax_vm/bytecode.cc | 68 ++ src/runtime/relax_vm/executable.cc | 576 +++++++++++++ src/runtime/relax_vm/memory_manager.cc | 181 ++++ src/runtime/relax_vm/naive_allocator.h | 65 ++ src/runtime/relax_vm/pooled_allocator.h | 111 +++ src/runtime/relax_vm/vm.cc | 811 ++++++++++++++++++ tests/python/relax/test_vm_execbuilder.py | 262 ++++++ 21 files changed, 4797 insertions(+) create mode 100644 include/tvm/relax/exec_builder.h create mode 100644 include/tvm/runtime/relax_vm/builtin.h create mode 100644 include/tvm/runtime/relax_vm/bytecode.h create mode 100644 include/tvm/runtime/relax_vm/executable.h create mode 100644 include/tvm/runtime/relax_vm/memory_manager.h create mode 100644 include/tvm/runtime/relax_vm/vm.h create mode 100644 python/tvm/relax/__init__.py create mode 100644 python/tvm/relax/_ffi_api.py create mode 100644 python/tvm/relax/exec_builder.py create mode 100644 python/tvm/relax/testing/vm.py create mode 100644 python/tvm/relax/vm.py create mode 100644 src/relax/backend/vm/exec_builder.cc create mode 100644 src/runtime/relax_vm/builtin.cc create mode 100644 src/runtime/relax_vm/bytecode.cc create mode 100644 src/runtime/relax_vm/executable.cc create mode 100644 src/runtime/relax_vm/memory_manager.cc create mode 100644 src/runtime/relax_vm/naive_allocator.h create mode 100644 src/runtime/relax_vm/pooled_allocator.h create mode 100644 src/runtime/relax_vm/vm.cc create mode 100644 tests/python/relax/test_vm_execbuilder.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 818e8b50addb..ed2afc392067 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -289,6 +289,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/driver/*.cc src/support/*.cc src/script/*.cc + src/relax/backend/vm/*.cc ) tvm_file_glob(GLOB CODEGEN_SRCS @@ -335,6 +336,7 @@ tvm_file_glob(GLOB RUNTIME_SRCS src/runtime/*.cc src/runtime/vm/*.cc src/runtime/minrpc/*.cc + src/runtime/relax_vm/*.cc ) if(BUILD_FOR_HEXAGON) diff --git a/include/tvm/relax/exec_builder.h b/include/tvm/relax/exec_builder.h new file mode 100644 index 000000000000..03e58392c269 --- /dev/null +++ b/include/tvm/relax/exec_builder.h @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/exec_builder.h + */ +#ifndef TVM_RELAX_EXEC_BUILDER_H_ +#define TVM_RELAX_EXEC_BUILDER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace relax { + +namespace vm = tvm::runtime::relax_vm; + +class ExecBuilder; + +/*! + * \brief A builder provides api to build VM executable with instructions. + */ +class ExecBuilderNode : public Object { + public: + /*! + * \brief Declare a function, it is OK to have multiple declarations. + * \param func The function name. + * \param kind The kind of the function. + */ + void DeclareFunction(const std::string& func, vm::VMFuncInfo::FuncKind kind); + /*! + * \brief To annotate the start of a vm function. + * \param func The function name. + * \param num_inputs The number of inputs. + * \param param_names The function parameter names. + * \param kind The kind of the function. + * \param init_register_size Initial setting of register file size. + */ + void EmitFunction(const std::string& func, int64_t num_inputs, + Optional> param_names, + vm::VMFuncInfo::FuncKind kind = vm::VMFuncInfo::FuncKind::kVMFunc, + int64_t init_register_size = 0); + /*! + * \brief Annotate the end of a vm function. + * \param func The function name. + */ + void EndFunction(const std::string& func); + /*! + * \brief Emit a call instruction for a packed function. + * \param func The packed function name. + * \param args The arguments of the function. + * \param ret The return register. + */ + void EmitCall(const std::string& func, std::vector args, vm::RegName ret); + /*! + * \brief Emit a call instruction with func as argument. + * \param func The packed function index. + * \param args The arguments of the function. + * \param ret The return register. + */ + void EmitCall(vm::Instruction::Arg func, std::vector args, vm::RegName ret); + /*! + * \brief Emit a ret instruction. + * \param result The return result. + * \note result must be a register. + */ + void EmitRet(vm::Instruction::Arg result); + /*! + * \brief Emit a goto instruction. + * \param pc_offset The program counter offset as the jump offset. + */ + void EmitGoto(vm::Index pc_offset); + /*! + * \brief Emit an If instruction. + * \param cond The register containing the cond value. + * \param false_offset The program counter offset for the false branch. + * \note result must be a register. + */ + void EmitIf(vm::Instruction::Arg cond, vm::Index false_offset); + /*! + * \brief Get function index by its name. + * \param name The name of the function. + * \return The argument corresponding to the function index. + */ + vm::Instruction::Arg GetFunction(const std::string& name); + /*! + * \brief Convert a constant value something that exec builder can understand. + * + * This function may update the constant pool to include the obj value. + * + * \param value The input constant value + * \return An Arg that represents the result of constant argument. + */ + template + vm::Instruction::Arg ConvertConstant(T value) { + TVMRetValue rv; + rv = value; + return ConvertConstant_(rv); + } + /*! + * \brief Raw access to underlying executable build in progress. + */ + vm::Executable* exec() const; + /*! + * \brief Finalize the build, run formalize and get the final result. + * \note This function should not be called during construction. + */ + ObjectPtr Get(); + /*! + * \brief Create an ExecBuilder. + * \return The ExecBuilder. + */ + TVM_DLL static ExecBuilder Create(); + + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.ExecBuilder"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExecBuilderNode, Object); + + private: + /*! + * \brief Convert a constant value something that exec builder can understand. + * + * This function may update the constant pool to include the obj value. + * + * \param obj The constant value to be emitted + * \return An Arg that represents the result of constant argument. + */ + vm::Instruction::Arg ConvertConstant_(TVMRetValue obj); + + /*! + * \brief A helper function to check if an executable is legal by checking if registers are used + * properly + */ + void CheckExecutable(); + /*! + * \brief Formalize the executable. + */ + void Formalize(); + + /*! \brief The mutable internal executable. */ + ObjectPtr exec_; // mutable + /*! \brief internal dedup map when creating index for a new constant */ + std::unordered_map const_dedup_map_; +}; + +class ExecBuilder : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ExecBuilder, ObjectRef, ExecBuilderNode); +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_EXEC_BUILDER_H_ diff --git a/include/tvm/runtime/relax_vm/builtin.h b/include/tvm/runtime/relax_vm/builtin.h new file mode 100644 index 000000000000..b994e44ae88d --- /dev/null +++ b/include/tvm/runtime/relax_vm/builtin.h @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/relax_vm/builtin.h + * \brief Builtin runtime APIs. + */ +#ifndef TVM_RUNTIME_RELAX_VM_BUILTIN_H_ +#define TVM_RUNTIME_RELAX_VM_BUILTIN_H_ + +namespace tvm { +namespace runtime { +namespace relax_vm { + +/*! + * \brief Op code used in built-in match-shape function. + * + * The function takes the following signature: + + * MatchShape(input_shape, shape_heap, n, c[0], r[0], c[1], r[1], ... c[n], r[n], err_ctx) + * + * This function provides runtime shape population and checking support for match-cast. + * When a shape variable appears in the first time, we should load the shape and + * populate the variable. When a shape variable already appears, we should + * assert that it already equals an existing shape value. + * + * NOTE: It is OK to pass nullptr shape_heap if all code are AssertEqualToImm. + */ +enum class MatchShapeCode : int { + /*! + * \brief Perform an assertion that shape equals immediate. + * + * assert input_shape[i] == r[i] + */ + kAssertEqualToImm = 0, + /*! + * \brief This is the first time we see a symbolic shape variable, store to heap. + * + * shape_heap[r[i]] = input_shape[i] + */ + kStoreToHeap = 1, + /*! + * \brief skip and do not do anything. + */ + kNoOp = 2, + /*! + * \brief Peform an assertion that the shape equals a loaded value. + * + * assert input_shape[i] == shape_heap[r[i]] + */ + kAssertEqualToLoad = 3, +}; + +/*! + * \brief Op code used in builtin function MakeShape. + * + * MakeShape(shape_heap, n, c[0], r[0], c[1], r[1], ... c[n], r[n]). + * + * \note It is OK to pass nullptr to shape_heap if all code are UseImm. + */ +enum class MakeShapeCode : int { + /*! \brief Use the following r[i] as immediate shape value. */ + kUseImm = 0, + /*! + * \brief Load shape value from the shape_heap[[r[i]]. + */ + kLoadShape = 1, +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RELAX_VM_BUILTIN_H_ diff --git a/include/tvm/runtime/relax_vm/bytecode.h b/include/tvm/runtime/relax_vm/bytecode.h new file mode 100644 index 000000000000..91d182325886 --- /dev/null +++ b/include/tvm/runtime/relax_vm/bytecode.h @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/relax_vm/bytecode.h + * \brief The bytecode for the virtual machine. + */ +#ifndef TVM_RUNTIME_RELAX_VM_BYTECODE_H_ +#define TVM_RUNTIME_RELAX_VM_BYTECODE_H_ + +#include +#include + +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +/*! + * \brief The storage type for the bytecode in the VM. + */ +using ExecWord = int64_t; + +/*! \brief A register name. */ +using RegName = ExecWord; + +/*! + * \brief An alias for the integer type used ubiquitously in the VM. + */ +using Index = ExecWord; + +/*! + * \brief An enumeration of Relax's opcodes. + * + * The opcode is used to implement instruction + * as a tagged union. + */ +enum class Opcode { + Call = 1U, + Ret = 2U, + Goto = 3U, + If = 4U, +}; + +/*! \brief A single virtual machine instruction. + * + * The representation of the instruction is as + * a tagged union. + * + * The first field represents which instruction, + * and by extension which field of the union + * is active. + */ +struct Instruction { + /*! \brief The number of bit for storing value. */ + static constexpr ExecWord kKindBit = 8; + /*! \brief The number of bit for storing value. */ + static constexpr ExecWord kValueBit = sizeof(ExecWord) * 8 - kKindBit; + /*! \brief The bit mask of the value part. */ + static constexpr ExecWord kValueMask = (static_cast(1) << kValueBit) - 1; + /*! \brief Maximum possible value, use 1 bit for sign. */ + static constexpr ExecWord kValueMaxLimit = (static_cast(1) << (kValueBit - 1)) - 1; + /*! \brief Minimum possible value, remove 1 slot to keep things symmetric. */ + static constexpr ExecWord kValueMinLimit = -kValueMaxLimit; + /*! \brief Begining of special register section. */ + static constexpr RegName kBeginSpecialReg = static_cast(1) << 54; + /*! \brief Random magic number that represents void argument, indicate null value */ + static constexpr RegName kVoidRegister = kBeginSpecialReg + 0; + /*! \brief Random magic number that represents the VM context */ + static constexpr RegName kVMRegister = kBeginSpecialReg + 1; + /*! + * \brief The kind of instruction's argument. + */ + enum class ArgKind : int { kRegister = 0, kImmediate = 1, kConstIdx = 2, kFuncIdx = 3 }; + /*! + * \brief The auxiliary data structure for instruction argument. + */ + class Arg { + public: + /*! \brief Construct a void argument. */ + Arg() : data_(Instruction::kVoidRegister) {} + /*! + * \brief construct Arg from data. + * \param data The data repr. + */ + static Arg FromData(ExecWord data) { return Arg(data); } + /*! + * \brief construct a register Arg. + * \param reg The register number. + * \return The constructed arg. + */ + static Arg Register(RegName reg) { return Arg(ArgKind::kRegister, reg); } + /*! + * \brief construct a ConstIdx arg. + * \param index The constant index. + * \return The constructed arg. + */ + static Arg ConstIdx(Index index) { return Arg(ArgKind::kConstIdx, index); } + /*! + * \brief construct a immediate arg. + * \param imm_value The immediate value. + * \return The constructed arg. + */ + static Arg Immediate(int64_t imm_value) { return Arg(ArgKind::kImmediate, imm_value); } + /*! + * \brief construct a FuncIdx arg. + * \param index The func index in the function table. + * \return The constructed arg. + */ + static Arg FuncIdx(Index index) { return Arg(ArgKind::kFuncIdx, index); } + /*! + * \brief Get the kind of argument.. + * \return The kind of argument. + */ + ArgKind kind() const { + uint8_t kind = (data_ >> kValueBit) & 0xFF; + return Instruction::ArgKind(kind); + } + /*! + * \brief Get the value of argument. + * \return The value of argument. + * \note We store both positive and negative values by sign extension. + */ + ExecWord value() const { return ((data_ & kValueMask) << kKindBit) >> kKindBit; } + /*! + * \brief Get the raw data repr of the arg. + * \return The raw data. + */ + ExecWord data() const { return data_; } + + private: + /*! \brief Construct from the data. */ + explicit Arg(ExecWord data) : data_(data) {} + /*! \brief Construct from the kind and value. */ + Arg(ArgKind kind, Index value) { + ICHECK_LE(value, kValueMaxLimit); + ICHECK_GE(value, kValueMinLimit); + data_ = (static_cast(kind) << kValueBit) | (value & kValueMask); + } + /*! \brief The underlying stored data. */ + ExecWord data_; + }; + /*! \brief The instruction opcode. */ + Opcode op; + union { + struct /* Call */ { + /*! \brief The destination register. */ + RegName dst; + /*! \brief The index into the packed function table. */ + Index func_idx; + /*! \brief The number of arguments to the packed function. */ + Index num_args; + /*! \brief The arguments of the packed function. */ + Arg* args; + }; + struct /* Ret */ { + /*! \brief The return result. */ + RegName result; + }; + struct /* Goto */ { + /*! \brief The jump offset. */ + Index pc_offset; + }; + struct /* If */ { + /*! \brief The register containing the cond value. */ + RegName cond; + /*! \brief The program counter offset for the false branch. */ + Index false_offset; + }; + }; + /*! + * \brief Construct a Call instruction. + * \param func_idx The index of the function to call. + * \param num_args The number of arguments. + * \param args The input arguments. + * \param dst The destination register. + * \return The call instruction. + */ + static Instruction Call(Index func_idx, Index num_args, Arg* args, RegName dst); + /*! + * \brief Construct a return instruction. + * \param result The register containing the return value. + * \return The return instruction. + */ + static Instruction Ret(RegName result); + /*! + * \brief Construct a goto instruction. + * \param pc_offset The register containing the jump offset. + * \return The goto instruction. + */ + static Instruction Goto(RegName pc_offset); + /*! + * \brief Construct an If instruction. + * \param cond The register containing the cond value. + * \param false_offset The program counter offset for the false branch. + * \return The If instruction. + */ + static Instruction If(RegName cond, Index false_offset); +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_RELAX_VM_BYTECODE_H_ diff --git a/include/tvm/runtime/relax_vm/executable.h b/include/tvm/runtime/relax_vm/executable.h new file mode 100644 index 000000000000..316a03a3b8cf --- /dev/null +++ b/include/tvm/runtime/relax_vm/executable.h @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/relax_vm/executable.h + */ +#ifndef TVM_RUNTIME_RELAX_VM_EXECUTABLE_H_ +#define TVM_RUNTIME_RELAX_VM_EXECUTABLE_H_ + +#include +#include +#include + +#include +#include +#include + +#include "./bytecode.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +/*! + * \brief Information entry in executable function table. + * + * Contains metadata about the compiled function, as + * well as the compiled VM instructions. + */ +struct VMFuncInfo { + /*! \brief kind of the function. */ + enum class FuncKind : int { + /*! \brief system level packed function */ + kPackedFunc = 0, + /*! \brief VM function. */ + kVMFunc = 1, + /*! \brief VMTIR function. */ + kVMTIRFunc = 2, + }; + /*! \brief The kind of function. */ + FuncKind kind; + /*! \brief The function's name, global symbol */ + std::string name; + /*! \brief The start instruction index of the function. */ + Index start_instr = 0; + /*! \brief The end instruction index of the function. */ + Index end_instr = 0; + /*! \brief The number of arguments of the function. */ + Index num_args = 0; + /*! \brief The register file size of the function. */ + Index register_file_size = 0; + /*! \brief The function parameter names.*/ + std::vector param_names; + + // defined customized loader save + void Save(dmlc::Stream* writer) const; + bool Load(dmlc::Stream* reader); +}; + +/*! + * \brief The executable emitted by the VM compiler. + * + * The executable contains information (e.g. data in different memory regions) + * to run in a virtual machine. + */ +class Executable : public runtime::ModuleNode { + public: + /*! + * \brief Get a PackedFunc from the executable module. + * \param name the name of the function. + * \param sptr_to_self The shared_ptr that points to this module node. + * \return PackedFunc or nullptr when it is not available. + */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + /*! + * \brief Print the detailed statistics of the given code, i.e. number of + * globals and constants, etc. + * \return The statistics represented by a string. + */ + std::string Stats() const; + /*! + * \brief Get the i-th instruction from the executable. + * \param i The index of the instruction to be fetched. + * \return The instruction. + */ + Instruction GetInstruction(Index i) const; + /*! + * \brief Set j-th byte data of i-th instruction to val. + * \param i The index of the instruction to be updated. + * \param j The index of the byte data of the instruction to be updated. + * \param val The value to be set + */ + void SetInstructionData(Index i, Index j, ExecWord val); + /*! + * \brief Print the instructions as text format. + * \return The text format of the instructions. + */ + String AsText() const; + /*! + * \brief Print the instructions as python program. + * \return The python program of the instructions, represented by a string. + */ + String AsPython() const; + /*! + * \brief Write the Executable to the binary stream in serialized form. + * \param stream The binary stream to save the executable to. + */ + void SaveToBinary(dmlc::Stream* stream) final; + /*! + * \brief Load Executable from the binary stream in serialized form. + * \param stream The binary stream that load the executable from. + * \return The loaded executable, in the form of a `runtime::Module`. + */ + static Module LoadFromBinary(void* stream); + /*! + * \brief Write the Executable to the provided path as a file containing its serialized content. + * \param file_name The name of the file to write the serialized data to. + * \param format The target format of the saved file. + */ + void SaveToFile(const std::string& file_name, const std::string& format) final; + /*! + * \brief Load Executable from the file. + * \param file_name The path of the file that load the executable from. + * \return The loaded executable, in the form of a `runtime::Module`. + */ + static Module LoadFromFile(const std::string& file_name); + + /*! \brief The virtual machine's function table. */ + std::vector func_table; + /*! \brief A map from globals (as strings) to their index in the function map. */ + std::unordered_map func_map; + /*! \brief The global constant pool. */ + std::vector constants; + /*! \brief The offset of instruction. */ + std::vector instr_offset; + /*! \brief The byte data of instruction. */ + std::vector instr_data; + + virtual ~Executable() {} + + const char* type_key() const final { return "relax.Executable"; } + + private: + /*! + * \brief Save the globals. + * \param strm The input stream. + */ + void SaveGlobalSection(dmlc::Stream* strm); + /*! + * \brief Save the constant pool. + * \param strm The input stream. + */ + void SaveConstantSection(dmlc::Stream* strm); + /*! + * \brief Save the instructions. + * \param strm The input stream. + */ + void SaveCodeSection(dmlc::Stream* strm); + /*! + * \brief Save the packed functions. + * \param strm The input stream. + */ + void SavePackedFuncNames(dmlc::Stream* strm); + /*! + * \brief Load the globals. + * \param strm The input stream. + */ + void LoadGlobalSection(dmlc::Stream* strm); + /*! + * \brief Load the constant pool. + * \param strm The input stream. + */ + void LoadConstantSection(dmlc::Stream* strm); + /*! + * \brief Load the instructions. + * \param strm The input stream. + */ + void LoadCodeSection(dmlc::Stream* strm); + /*! + * \brief Save the packed functions. + * \param strm The input stream. + */ + void LoadPackedFuncNames(dmlc::Stream* strm); +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +namespace dmlc { +DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::relax_vm::VMFuncInfo, true); +} // namespace dmlc +#endif // TVM_RUNTIME_RELAX_VM_EXECUTABLE_H_ diff --git a/include/tvm/runtime/relax_vm/memory_manager.h b/include/tvm/runtime/relax_vm/memory_manager.h new file mode 100644 index 000000000000..e5ae8cfcfbaa --- /dev/null +++ b/include/tvm/runtime/relax_vm/memory_manager.h @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/relax_vm/memory_manager.h + * \brief Abstract device memory management API + */ +#ifndef TVM_RUNTIME_RELAX_VM_MEMORY_MANAGER_H_ +#define TVM_RUNTIME_RELAX_VM_MEMORY_MANAGER_H_ + +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +struct Buffer { + /*! \brief The pointer to the allocated block of memory. */ + void* data{nullptr}; + /*! \brief The size of the block. */ + size_t size{0}; + /*! \brief The device of the allocated buffers. */ + Device device; +}; + +enum AllocatorType { + kNaive = 1, + kPooled, +}; + +class Allocator { + public: + explicit Allocator(AllocatorType type) : type_(type) {} + virtual ~Allocator() = default; + /*! \brief Allocate an empty NDArray using from the allocator. + * \param shape The shape of the NDArray. + * \param dtype The datatype of the NDArray. + * \param dev The device where the array is allocated. + * \return The empty NDArray. + */ + runtime::NDArray Empty(std::vector shape, DLDataType dtype, Device dev); + /*! \brief Return the allocator type. */ + inline AllocatorType type() const { return type_; } + /*! \brief Allocate a buffer given a size, alignment and type. + * \param nbytes The size of the buffer. + * \param alignment The alignment of the buffer. + * \param type_hint A type hint to the allocator. + * \return A sized allocation in the form of a buffer. + */ + virtual Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) = 0; + /*! \brief Free a buffer allocated by the allocator. + * \param buffer The buffer to free. + */ + virtual void Free(const Buffer& buffer) = 0; + + private: + AllocatorType type_; +}; + +class MemoryManager { + public: + static MemoryManager* Global(); + /*! + * \brief Get or create an allocator given the device and allocator type. + * \param dev The TVM device + * \param type The allocator type + * \return The memory allocator. + */ + static Allocator* GetOrCreateAllocator(Device dev, AllocatorType type); + /*! + * \brief Get an allocator given the device. + * \param dev The TVM device + * \return The memory allocator. + */ + static Allocator* GetAllocator(Device dev); + + private: + MemoryManager() {} + + private: + std::mutex mutex_; + std::unordered_map> allocators_; +}; + +/*! \brief An object representing a storage allocation. */ +class StorageObj : public Object { + public: + /*! \brief The index into the VM function table. */ + Buffer buffer; + + /*! \brief Allocate an NDArray from a given piece of storage. */ + runtime::NDArray AllocNDArray(uint64_t offset, ShapeTuple shape, DLDataType dtype); + + /*! \brief The deleter for an NDArray when allocated from underlying storage. */ + static void Deleter(Object* ptr); + + ~StorageObj() { + auto alloc = MemoryManager::Global()->GetAllocator(buffer.device); + alloc->Free(buffer); + } + + static constexpr const uint32_t _type_index = runtime::TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.Storage"; + TVM_DECLARE_FINAL_OBJECT_INFO(StorageObj, Object); +}; + +/*! \brief reference to storage. */ +class Storage : public ObjectRef { + public: + explicit Storage(Buffer buffer); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Storage, ObjectRef, StorageObj); +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_RELAX_VM_MEMORY_MANAGER_H_ diff --git a/include/tvm/runtime/relax_vm/vm.h b/include/tvm/runtime/relax_vm/vm.h new file mode 100644 index 000000000000..cfe388090456 --- /dev/null +++ b/include/tvm/runtime/relax_vm/vm.h @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/relax_vm/vm.h + */ +#ifndef TVM_RUNTIME_RELAX_VM_VM_H_ +#define TVM_RUNTIME_RELAX_VM_VM_H_ + +#include +#include +#include + +#include "./bytecode.h" +#include "./executable.h" +#include "./memory_manager.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +/*! + * \brief An object representing a vm closure. + */ +class VMClosureObj : public ClosureObj { + public: + /*! + * \brief The function name. The function could be any + * function object that is compatible to the VM runtime. + */ + String func_name; + + /*! + * \brief The implementation of the Closure. + * \note This function takes context pointer(VirtualMachine*) + * as the first argument. The rest of arguments follows + * the same arguments as the normal function call. + */ + PackedFunc impl; + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.vm.Closure"; + TVM_DECLARE_FINAL_OBJECT_INFO(VMClosureObj, ClosureObj); +}; + +/*! \brief reference to closure. */ +class VMClosure : public Closure { + public: + VMClosure(String func_name, PackedFunc impl); + TVM_DEFINE_OBJECT_REF_METHODS(VMClosure, Closure, VMClosureObj); + + /*! + * \brief Create another PackedFunc with last arguments already bound to last_args. + * + * This is a helper function to create captured closures. + * \param func The input func, can be a VMClosure or PackedFunc. + * \param last_args The arguments to bound to in the end of the function. + * \note The new function takes in arguments and append the last_args in the end. + */ + static PackedFunc BindLastArgs(PackedFunc func, std::vector last_args); +}; + +/*! + * \brief The virtual machine. + * + * The virtual machine contains all the current execution state, + * as well as the executable. + * + * The goal is to have a single self-contained object, + * enabling one to easily pass around VMs, execute them on + * multiple threads, or serialize them to disk or over the + * wire. + */ +class VirtualMachine : public runtime::ModuleNode { + public: + /*! + * \brief Initialize the virtual machine for a set of devices. + * \param devices The set of TVM devices. + * \param alloc_types The allocator types for each device. + */ + virtual void Init(const std::vector& devices, + const std::vector& alloc_types) = 0; + /*! + * \brief Load the executable for the virtual machine. + * \param exec The executable. + */ + virtual void LoadExecutable(ObjectPtr exec) = 0; + /*! + * \brief Get global function in the VM. + * \param func_name The name of the function. + * \return The closure + */ + virtual VMClosure GetClosure(const String& func_name) = 0; + /*! + * \brief Invoke closure or packed function using PackedFunc convention. + * \param closure_or_packedfunc A VM closure or a packed_func. + * \param args The input arguments. + * \param rv The return value. + */ + virtual void InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs args, + TVMRetValue* rv) = 0; + /*! + * \brief Create a specific instance of VM. + * \return Created VM + */ + static ObjectPtr Create(); + /*! + * \brief Helper function for vm closure functions to get the context ptr + * \param arg The argument value. + */ + static VirtualMachine* GetContextPtr(TVMArgValue arg) { + return static_cast(arg.operator void*()); + } + + ~VirtualMachine() {} + + const char* type_key() const final { return "relax.VirtualMachine"; } + + //-------------------------------------------------------------------------- + // The following section contains states that other builtin can depend on + //-------------------------------------------------------------------------- + /*! \brief The memory allocators. */ + std::vector allocators; + /*! \brief Runtime physical device list. */ + std::vector devices; +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_RELAX_VM_VM_H_ diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py new file mode 100644 index 000000000000..c070fa479188 --- /dev/null +++ b/python/tvm/relax/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, wrong-import-position +"""The Relax IR namespace containing the IR, type, operator, builder, vm, etc.""" +from . import exec_builder +from . import vm + +# VM +from .exec_builder import ExecBuilder +from .vm import VirtualMachine diff --git a/python/tvm/relax/_ffi_api.py b/python/tvm/relax/_ffi_api.py new file mode 100644 index 000000000000..a127e1c81378 --- /dev/null +++ b/python/tvm/relax/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI API for Relax.""" +import tvm._ffi + +tvm._ffi._init_api("relax", __name__) diff --git a/python/tvm/relax/exec_builder.py b/python/tvm/relax/exec_builder.py new file mode 100644 index 000000000000..1e28c967d18f --- /dev/null +++ b/python/tvm/relax/exec_builder.py @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""A builder to build Relax VM executable.""" +from enum import IntEnum +from typing import Optional, Union, List +import tvm +from tvm.runtime import Object +from tvm.runtime.container import ShapeTuple +from .vm import Executable +from . import _ffi_api + + +class SpecialReg(IntEnum): + """Magic numbers that represent special registers in vm.""" + + VOID_ARG = (1 << 54) + 0 + VM_STATE = (1 << 54) + 1 + + +class VMFuncKind(IntEnum): + """VM function kind code.""" + + PACKED_FUNC = 0 + VM_FUNC = 1 + + +class VMFuncScope(object): + """An object corresponds to each VM function, working as a context manager.""" + + stack: List["VMFuncScope"] = [] + + def __init__(self, exit_callback): + self.exit_callback = exit_callback + + def __enter__(self): + VMFuncScope.stack.append(self) + return self + + def __exit__(self, ptype, value, trace): + VMFuncScope.stack.pop() + self.exit_callback() + + +@tvm._ffi.register_object("relax.ExecBuilder") +class ExecBuilder(Object): + """A builder to emit instructions and build executable for the virtual machine.""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__(_ffi_api.ExecBuilderCreate) # type: ignore + + def r(self, idx: int) -> int: + """set instruction's argument as a register.""" + return _ffi_api.ExecBuilderR(self, idx) # type: ignore + + def imm(self, value: int) -> int: + """set instruction's argument as an immediate.""" + return _ffi_api.ExecBuilderImm(self, value) # type: ignore + + def c(self, idx: int) -> int: + """set instruction's argument as a constant.""" + return _ffi_api.ExecBuilderC(self, idx) # type: ignore + + def f(self, name: str) -> int: + """set instruction's argument as a function.""" + return _ffi_api.ExecBuilderF(self, name) # type: ignore + + def void_arg(self) -> int: + return self.r(SpecialReg.VOID_ARG) + + def vm_state(self) -> int: + return self.r(SpecialReg.VM_STATE) + + def declare_function(self, func_name: str, kind: VMFuncKind = VMFuncKind.PACKED_FUNC) -> None: + """Declare a function""" + _ffi_api.ExecBuilderDecalreFunction(self, func_name, kind) # type: ignore + + def function( + self, func_name: str, num_inputs: Optional[int] = 0, param_names: List[str] = None + ) -> VMFuncScope: + """annotate a VM function.""" + _ffi_api.ExecBuilderEmitFunction(self, func_name, num_inputs, param_names) # type: ignore + return VMFuncScope(lambda: _ffi_api.ExecBuilderEndFunction(self, func_name)) # type: ignore + + def _check_scope(self) -> None: + if len(VMFuncScope.stack) == 0: + raise ValueError("emit should happen in a function scope") + + def convert_constant(self, const: object) -> int: + return _ffi_api.ExecBuilderConvertConstant(self, const) # type: ignore + + def emit_call( + self, + name: str, + args: Optional[List[Union[tvm.nd.NDArray, tvm.DataType]]] = None, + dst: int = None, + ) -> None: + """emit a call instruction which calls a packed function.""" + self._check_scope() + if dst is None: + dst = SpecialReg.VOID_ARG + args_ = [] + if args is not None: + for arg in args: + if isinstance(arg, tuple): + shape_tuple = ShapeTuple(arg) + new_arg = self.convert_constant(shape_tuple) + args_.append(new_arg) + elif isinstance(arg, (tvm.nd.NDArray, tvm.DataType, ShapeTuple)): + new_arg = self.convert_constant(arg) + args_.append(new_arg) + else: + args_.append(arg) + _ffi_api.ExecBuilderEmitCall(self, name, args_, dst) # type: ignore + + def emit_ret(self, result: int) -> None: + """emit a return instruction""" + self._check_scope() + _ffi_api.ExecBuilderEmitRet(self, result) # type: ignore + + def emit_goto(self, pc_offset): + """emit a goto instruction""" + self._check_scope() + _ffi_api.ExecBuilderEmitGoto(self, pc_offset) # type: ignore + + def emit_if(self, cond, false_offset): + """emit an if instruction""" + self._check_scope() + _ffi_api.ExecBuilderEmitIf(self, cond, false_offset) # type: ignore + + def get(self) -> Executable: + """return the executable""" + return Executable(_ffi_api.ExecBuilderGet(self)) # type: ignore diff --git a/python/tvm/relax/testing/vm.py b/python/tvm/relax/testing/vm.py new file mode 100644 index 000000000000..79da54be1010 --- /dev/null +++ b/python/tvm/relax/testing/vm.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Testing utilities for relax VM""" +from typing import Any, List +import numpy as np # type: ignore + +import tvm +from tvm import relax +from tvm.runtime.object import Object + + +@tvm.register_func("test.vm.move") +def move(src): + return src + + +@tvm.register_func("test.vm.add") +def add(a, b): + ret = a.numpy() + b.numpy() + return tvm.nd.array(ret) + + +@tvm.register_func("test.vm.mul") +def mul(a, b): + ret = a.numpy() * b.numpy() + return tvm.nd.array(ret) + + +@tvm.register_func("test.vm.equal_zero") +def equal_zero(a): + ret = np.all((a.numpy() == 0)) + return tvm.nd.array(ret) + + +@tvm.register_func("test.vm.subtract_one") +def subtract_one(a): + ret = np.subtract(a.numpy(), 1) + return tvm.nd.array(ret) + + +@tvm.register_func("test.vm.identity") +def identity_packed(a, b): + b[:] = tvm.nd.array(a.numpy()) + + +@tvm.register_func("test.vm.tile") +def tile_packed(a, b): + b[:] = tvm.nd.array(np.tile(a.numpy(), (1, 2))) + + +@tvm.register_func("test.vm.add_scalar") +def add_scalar(a, b): + return a + b + + +@tvm.register_func("test.vm.get_device_id") +def get_device_id(device): + return device.device_id + + +def check_saved_func(vm: relax.VirtualMachine, func_name: str, *inputs: List[Any]) -> Object: + # uses save_function to create a closure with the given inputs + # and ensure the result is the same + # (assumes the functions return tensors and that they're idempotent) + saved_name = f"{func_name}_saved" + vm.save_function(func_name, saved_name, *inputs) + res1 = vm[func_name](*inputs) + res2 = vm[saved_name]() + tvm.testing.assert_allclose(res1.numpy(), res2.numpy(), rtol=1e-7, atol=1e-7) + return res1 diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py new file mode 100644 index 000000000000..ba16dfb07985 --- /dev/null +++ b/python/tvm/relax/vm.py @@ -0,0 +1,609 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, redefined-builtin, no-else-return +"""The Relax virtual machine""" +from typing import Callable, List, Optional, Union, Dict, Tuple, Any +import numpy as np # type: ignore + +from tvm._ffi import base as _base +import tvm +from tvm import relax +from tvm.ir.module import IRModule +from tvm.runtime import Device, Module, PackedFunc, container +from tvm.runtime.object import Object +from tvm.tir.function import PrimFunc +from . import _ffi_api +from ..rpc.base import RPC_SESS_MASK + + +class Executable(object): + """The executable object emitted by the VM compiler or the ExecBuilder.""" + + def __init__(self, mod: Module): + self.mod = mod + self._stats = self.mod["stats"] + self._as_text = self.mod["as_text"] + self._as_python = self.mod["as_python"] + + def stats(self) -> str: + """print the detailed statistics of the executable.""" + return self._stats() + + def as_text(self) -> str: + """print the instructions as text format.""" + return self._as_text() + + def as_python(self) -> str: + """print the instructions as python program.""" + return self._as_python() + + +class VirtualMachine(object): + """Relax VM runtime.""" + + NAIVE_ALLOCATOR = 1 + POOLED_ALLOCATOR = 2 + + def __init__( + self, + exec: Union[Executable, Module], + device: Union[Device, List[Device]], + memory_cfg: Optional[Union[str, Dict[Device, str]]] = None, + ) -> None: + """ + Construct a VirtualMachine wrapper object. + + Parameters + ---------- + exec: Union[Executable, Module] + The VM executable or Runtime Module + + device : Union[Device, List[Device]] + The device to deploy the module. + + memory_cfg : Optional[Union[str, Dict[Device, str]]] + Config the type of memory allocator. The allocator type can be ["naive", + "pooled"]. If memory_cfg is None, all devices will use pooled allocator + by default. If memory_cfg is string, all devices will use the specified + allocator type. If memory_cfg is a dict, each device uses the allocator + type specified in the dict, or pooled allocator if not specified in the + dict. + """ + self.module = ( + exec.mod["vm_load_executable"]() + if isinstance(exec, Executable) + else exec["vm_load_executable"]() + ) + self._invoke_closure = self.module["invoke_closure"] + self._save_function = self.module["save_function"] + self._set_input = self.module["set_input"] + self._invoke_stateful = self.module["invoke_stateful"] + self._get_output = self.module["get_output"] + self._get_output_arity = self.module["get_output_arity"] + self._get_function_arity = self.module["get_function_arity"] + self._get_function_param_name = self.module["get_function_param_name"] + self._setup_device(device, memory_cfg) + + def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]]) -> None: + """init devices and allocators.""" + devs = dev + if not isinstance(dev, (list, tuple)): + if not isinstance(dev, tvm.runtime.Device): + raise TypeError( + "dev is expected to be Device or \ + List[Device]" + ) + devs = [dev] + + if any(dev.device_type % RPC_SESS_MASK == tvm.cpu().device_type for dev in devs[:-1]): + raise RuntimeError( + "CPU host is required to be the last element of the device list if provided." + ) + + # CPU is required for executing shape functions + if devs[-1].device_type % RPC_SESS_MASK != tvm.cpu().device_type: + devs.append(tvm.cpu()) + + default_alloc_type = VirtualMachine.POOLED_ALLOCATOR + if memory_cfg is None: + memory_cfg = {} + elif isinstance(memory_cfg, str): + assert memory_cfg in ["naive", "pooled"] + if memory_cfg == "naive": + default_alloc_type = VirtualMachine.NAIVE_ALLOCATOR + memory_cfg = {} + elif not isinstance(memory_cfg, dict): + raise TypeError( + "memory_cfg is expected be string or dictionary, " + + "but received {}".format(type(memory_cfg)) + ) + init_args = [] + for device in devs: + init_args.append(device.device_type % RPC_SESS_MASK) + init_args.append(device.device_id) + alloc_type = memory_cfg[device] if device in memory_cfg else default_alloc_type + init_args.append(alloc_type) + self.module["vm_initialization"](*init_args) + + def __getitem__(self, key: str) -> PackedFunc: + return self.module[key] + + def invoke_closure(self, closure: Object, *args: Any) -> Object: + """Invoke a closure. + + Parameters + ---------- + closure : Object + The VMClosure Object. + + args : list[tvm.runtime.NDArray] or list[np.ndarray] + The arguments to the closure. + + Returns + ------- + result : Object + The output. + """ + return self._invoke_closure(closure, *args) + + def save_function( + self, + func_name: str, + saved_name: str, + *args: List[Any], + include_return: bool = True, + **kwargs: Dict[str, Any], + ) -> None: + """ + Convenience function. Takes a function from the module and saves + a `PackedFunc` that, when called, will invoke the function with the given arguments. + The `PackedFunc` can be accessed from the module using `saved_name`. + This is included to facilitate timing trials: + Invoking the returned `PackedFunc` will have less overhead from dictionary lookups + than normally running through the VM. + + If the saved name is taken, it can be overridden, though it cannot override + the name of a function defined in the Relax source. + + This is really creating a closure, but the function has a different name + to avoid confusion with `invoke_closure` (they are not meant to be used together). + + Parameters + ---------- + func_name : str + The function that should be packaged up. + + saved_name : str + The name that the resulting closure should be saved under. + + include_return : bool + Whether the saved PackedFunc should return its output. + If timing over RPC, it may not be desirable to send output + between machines. + + args : List[Any] + The arguments to package up with the function. + + kwargs : Dict[str, Any] + Any named arguments to package up with the function + """ + cargs: List[Any] = [] + if kwargs: + args = self._convert_func_named_args(func_name, args, **kwargs) + for arg in args: + self._convert(arg, cargs) + self._save_function(func_name, saved_name, int(include_return), *cargs) + + def _convert(self, arg: Any, cargs: List) -> None: + """helper function to convert arguments to vm function.""" + + def _gettype(arg): + if isinstance(arg, np.float16): + return "float16" + elif isinstance(arg, (_base.integer_types, bool)): + return "int32" + else: + return "float32" + + if isinstance(arg, Object): + cargs.append(arg) + elif isinstance(arg, np.ndarray): + nd_arr = tvm.nd.array(arg, device=tvm.cpu(0)) + cargs.append(nd_arr) + elif isinstance(arg, tvm.runtime.NDArray): + cargs.append(arg) + elif isinstance(arg, (tuple, list)): + field_args: List[Any] = [] + for field in arg: + self._convert(field, field_args) + cargs.append(container.tuple_object(field_args)) + elif isinstance(arg, (_base.numeric_types, bool)): + dtype = _gettype(arg) + value = tvm.nd.array(np.array(arg, dtype=dtype), device=tvm.cpu(0)) + cargs.append(value) + elif isinstance(arg, str): + cargs.append(arg) + else: + raise TypeError("Unsupported type: %s" % (type(arg))) + + def _convert_func_named_args(self, func_name: str, args: Any, **kwargs: Any) -> Any: + """ + Takes named function parameters and returns a list of those needed, + in the order they should appear + """ + # kwargs can be a super set of the required function parameters. + # We only find the ones that are needed. + func_arity = self._get_function_arity(func_name) + func_params = [self._get_function_param_name(func_name, i) for i in range(func_arity)] + new_args = [None] * len(func_params) + cnt = 0 + for k in kwargs: + if k in func_params: + idx = func_params.index(k) + new_args[idx] = kwargs[k] + cnt += 1 + else: + print(f'Warning: Keyword argument "{k}" is unused in {func_name}') + assert len(args) + cnt == len(func_params) + idx = 0 + for i, arg in enumerate(new_args): + if arg is None: + new_args[i] = args[idx] + idx += 1 + return new_args + + def set_input(self, func_name: str, *args: Any, **kwargs: Any) -> None: + """Set the inputs to a function. + This interface works when using VM over RPC by internally converting NDArray in + the arguments to DLTensor, which is supported in RPC where remote could only + have a minimal C runtime. + + Note: If `set_input` is used, the function *must* be called using `invoke_stateful` + and the results must be obtained using `get_outputs`. + + Parameters + ---------- + func_name : str + The name of the function. + args: List[tvm.runtime.NDArray] or List[np.ndarray] + The arguments to the function. + kwargs: dict of str to tvm.runtime.NDArray or np.ndarray + Named arguments to the function. + """ + cargs: List[Any] = [] + + if kwargs: + args = self._convert_func_named_args(func_name, args, **kwargs) + + for arg in args: + self._convert(arg, cargs) + + self._set_input(func_name, *cargs) + + def invoke_stateful(self, func_name: str) -> None: + """ + Call the named function from the VM module using the arguments set using `set_input`. + It is an error to call `invoke_stateful` without using `set_input` first + (even if it's to set 0 inputs); conversely, if `set_input` has been called, + it is an error to call the function without using `invoke_stateful`. + + The results of the call can be obtained by calling `get_outputs`. + + Parameters + ---------- + func_name: str + The name of the function to call. + """ + self._invoke_stateful(func_name) + + def get_outputs(self, func_name: str) -> Union[tvm.Object, Tuple[Any]]: + """ + Get the value output by the function by the given name + after a call of `invoke_stateful`. + + It is an error to call this function without first calling `invoke_stateful`. + + Parameters + ---------- + func_name: str + The name of the function whose output should be fetched. + + Returns + ------- + ret: Union[tvm.Object, Tuple[Any]] + The result of the earlier call to the function via `invoke_stateful`. + If the result is a tuple, it returns a list of the fields. + The fields are potentially also tuples, so these can be arbitrily nested. + """ + # to deal with potentially nested tuples, we need to query for arity recursively + def get_output_rec(func_name, *idx): + arity = self._get_output_arity(func_name, *idx) + if arity == -1: + return self._get_output(func_name, *idx) + # otherwise we need to specify more indices + idx_list = list(idx) + return tuple(get_output_rec(func_name, *(idx_list + [i])) for i in range(arity)) + + return get_output_rec(func_name) + + def time_evaluator( + self, + func_name, + dev, + number=10, + repeat=1, + min_repeat_ms=0, + cooldown_interval_ms=0, + repeats_to_cooldown=1, + f_preproc="", + ) -> Callable[..., tvm.runtime.module.BenchmarkResult]: + """ + Returns an evaluator that times a function in the module. + This follows the same convention as time_evaluator in tvm.runtime.module. + This can be used in combination with save_function() so that the + timings avoid extra dictionary lookups. + + Parameters + ---------- + func_name: str + The name of the function in the module. + + dev: Device + The device we should run this function on. + + number: int + The number of times to run this function for taking average. + We call these runs as one `repeat` of measurement. + + repeat: int, optional + The number of times to repeat the measurement. + In total, the function will be invoked (1 + number x repeat) times, + where the first one is warm up and will be discarded. + The returned result contains `repeat` costs, + each of which is an average of `number` costs. + + min_repeat_ms: int, optional + The minimum duration of one `repeat` in milliseconds. + By default, one `repeat` contains `number` runs. If this parameter is set, + the parameters `number` will be dynamically adjusted to meet the + minimum duration requirement of one `repeat`. + i.e., When the run time of one `repeat` falls below this time, the `number` parameter + will be automatically increased. + + cooldown_interval_ms: int, optional + The cooldown interval in milliseconds between the number of repeats defined by + `repeats_to_cooldown`. + + repeats_to_cooldown: int, optional + The number of repeats before the cooldown is activated. + + f_preproc: str, optional + The preprocess function name we want to execute before executing the time evaluator. + + Note + ---- + The function will be invoked (1 + number x repeat) times, + with the first call discarded in case there is lazy initialization. + + Example + ------- + Normal use with a VM function (may not work over RPC if the function returns a tuple): + + .. code-block:: python + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(TestTimeEvaluator, target) + vm = relax.VirtualMachine(mod, tvm.cpu()) + timing_res = vm.time_evaluator("func_name", tvm.cpu())(arg0, arg1, ..., argn) + + Use with the stateful API: + + .. code-block:: python + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(TestTimeEvaluator, target) + vm = relax.VirtualMachine(mod, tvm.cpu()) + vm.set_input("func_name", arg0, arg1, ..., argn) + timing_res = vm.time_evaluator("invoke_stateful", tvm.cpu())("func_name") + + With saved closures via `save_function` (this results in + fewer dictionary lookups in the timed portion): + + .. code-block:: python + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(TestTimeEvaluator, target) + vm = relax.VirtualMachine(mod, tvm.cpu()) + vm.save_function("func_name", "func_name_saved", arg0, arg1, ..., argn) + timing_res = vm.time_evaluator("func_name_saved", tvm.cpu())() + + Returns + ------- + ftimer : function + The function that takes same argument as func and returns a BenchmarkResult. + The ProfileResult reports `repeat` time costs in seconds. + + """ + return self.module.time_evaluator( + func_name, + dev, + number=number, + repeat=repeat, + min_repeat_ms=min_repeat_ms, + cooldown_interval_ms=cooldown_interval_ms, + repeats_to_cooldown=repeats_to_cooldown, + f_preproc=f_preproc, + ) + + +def _vmcodegen( + builder: "relax.ExecBuilder", + mod: tvm.IRModule, + exec_mode: str = "bytecode", +) -> tvm.IRModule: + """Running VM codegen. + + Parameters + ---------- + builder: relax.ExecBuilder + ExecBuilder to collect the vm executable. + + mod: IRModule + The input IRModule to be built. + + exec_mode: {"bytecode", "compiled"} + The execution mode. + + Return + ------ + leftover: IRModule + Left over IRModule that may contain extra functions. + """ + + if exec_mode == "bytecode": + return _ffi_api.VMCodeGen(builder, mod) # type:ignore + if exec_mode == "compiled": + return _ffi_api.VMTIRCodeGen(builder, mod) # type: ignore + raise ValueError("Unknown exec_mode %s" % exec_mode) + + +def _vmlink( + builder: "relax.ExecBuilder", + target: Union[str, tvm.target.Target], + tir_mod: Optional[tvm.IRModule] = None, + ext_libs: List[tvm.runtime.Module] = None, + params: Optional[Dict[str, list]] = None, +): + """ + Internal codegen function to make executable. + + This function is only used for unit-testing purpoes. + + Use build instead. + + Parameters + ---------- + builder: relax.ExecBuilder + Builder used to collect executables. + + target : Union[str, tvm.target.Target] + A build target which can have optional host side compilation target. + + tir_mod: IRModule + The input TIR IRModule to be linked together. + + ext_libs: List[tvm.runtime.Module] + List of compiled external modules. + + params: Optional[Dict[str, list]] + Extra parameter mappings. + + Returns + ------- + ex: tvm.relax.vm.Executable + An executable that can be loaded by virtual machine. + """ + if isinstance(target, str): + target = tvm.target.Target(target) + if params is None: + params = {} + if ext_libs is None: + ext_libs = [] + lib = None + if tir_mod is not None: + lib = tvm.build(tir_mod, target=target) + return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params)) # type: ignore + + +def build( + mod: tvm.IRModule, + target: Union[str, tvm.target.Target], + params: Optional[Dict[str, list]] = None, + exec_mode: str = "bytecode", +) -> Executable: + """ + Build an IRModule to VM executable. + + Parameters + ---------- + mod: IRModule + The input IRModule to be built. + + target : Union[str, tvm.target.Target] + A build target which can have optional host side compilation target. + + When TVM compiles device specific program such as CUDA, + we also need host(CPU) side code to interact with the driver + to setup the dimensions and parameters correctly. + host is used to specify the host side codegen target. + By default, llvm is used if it is enabled, + otherwise a stackvm interpreter is used. + + params: Optional[Dict[str, list]] + Parameters for the input IRModule that will be bound. + + exec_mode: {"bytecode", "compiled"} + The execution mode. + + Returns + ------- + ex: tvm.relax.vm.Executable + An executable that can be loaded by virtual machine. + + Example + ------- + + .. code-block:: python + class InputModule: + @R.function + def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")): + z = R.add(x, y) + return z + + mod = InputModule + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + """ + if isinstance(target, str): + target = tvm.target.Target(target) + + passes = [relax.transform.ToNonDataflow()] + passes.append(relax.transform.CallTIRRewrite()) + passes.append(relax.transform.VMBuiltinLower()) + passes.append(relax.transform.VMShapeLower()) + passes.append(relax.transform.AttachGlobalSymbol()) + seq = tvm.transform.Sequential(passes) + new_mod = seq(mod) + + # Extract external runtime modules if exist. + ext_libs = [] + if mod.attrs and "external_mods" in mod.attrs: + ext_libs = mod.attrs["external_mods"] + + # builder collects the executable + builder = relax.ExecBuilder() + leftover_mod = _vmcodegen(builder, new_mod, exec_mode=exec_mode) + tir_mod = _filter_tir(leftover_mod) + return _vmlink(builder, target, tir_mod, ext_libs, params) + + +def _filter_tir(mod: tvm.IRModule) -> tvm.IRModule: + tir_mod = IRModule({}) + for gv in mod.get_global_vars(): + if isinstance(mod[gv], PrimFunc): + tir_mod[gv] = mod[gv] + return tir_mod diff --git a/src/relax/backend/vm/exec_builder.cc b/src/relax/backend/vm/exec_builder.cc new file mode 100644 index 000000000000..b5d932137be0 --- /dev/null +++ b/src/relax/backend/vm/exec_builder.cc @@ -0,0 +1,399 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/vm/exec_builder.cc + */ +#include + +#include + +namespace tvm { +namespace relax { + +using namespace vm; + +TVM_REGISTER_NODE_TYPE(ExecBuilderNode); + +ExecBuilder ExecBuilderNode::Create() { + ExecBuilder ret(make_object()); + ret->exec_ = make_object(); + return ret; +} + +Executable* ExecBuilderNode::exec() const { return exec_.get(); } + +ObjectPtr ExecBuilderNode::Get() { + this->Formalize(); + this->CheckExecutable(); + return exec_; +} + +vm::Instruction::Arg ExecBuilderNode::ConvertConstant_(TVMRetValue cvalue) { + // emit constant immediate as immediate. + if (cvalue.type_code() == kDLInt) { + int64_t val = cvalue.operator int64_t(); + if (val <= vm::Instruction::kValueMaxLimit && val >= vm::Instruction::kValueMinLimit) { + return vm::Instruction::Arg::Immediate(val); + } + } + // convert string to object string + if (cvalue.type_code() == kTVMStr) { + cvalue = cvalue.operator String(); + } + + // run dedup for object with structural equality + if (cvalue.IsObjectRef()) { + ObjectRef obj = cvalue.operator ObjectRef(); + auto it = const_dedup_map_.find(obj); + if (it != const_dedup_map_.end()) { + return vm::Instruction::Arg::ConstIdx(it->second); + } + vm::Index idx = exec_->constants.size(); + exec_->constants.push_back(cvalue); + const_dedup_map_[obj] = idx; + return vm::Instruction::Arg::ConstIdx(idx); + } else { + // emit normal constant + vm::Index idx = exec_->constants.size(); + exec_->constants.push_back(cvalue); + return vm::Instruction::Arg::ConstIdx(idx); + } +} + +void ExecBuilderNode::DeclareFunction(const std::string& func_name, VMFuncInfo::FuncKind kind) { + auto it = exec_->func_map.find(func_name); + if (it != exec_->func_map.end()) { + ICHECK(kind == exec_->func_table[it->second].kind) + << "Function " << func_name << "already declared in a different kind"; + return; + } + VMFuncInfo vmfunc; + vmfunc.kind = kind; + vmfunc.name = func_name; + // use num args to mark undefined. + vmfunc.start_instr = 0; + vmfunc.num_args = -2; + vmfunc.register_file_size = 0; + exec_->func_map[func_name] = exec_->func_table.size(); + exec_->func_table.push_back(vmfunc); +} + +vm::Instruction::Arg ExecBuilderNode::GetFunction(const std::string& func_name) { + auto it = exec_->func_map.find(func_name); + ICHECK(it != exec_->func_map.end()) << "Cannot find function " << func_name; + return vm::Instruction::Arg::FuncIdx(it->second); +} + +void ExecBuilderNode::EmitFunction(const std::string& func_name, int64_t num_inputs, + Optional> param_names, + vm::VMFuncInfo::FuncKind kind, int64_t init_register_size) { + auto it = exec_->func_map.find(func_name); + if (it == exec_->func_map.end()) { + this->DeclareFunction(func_name, kind); + } + auto& vmfunc = exec_->func_table.at(exec_->func_map.at(func_name)); + ICHECK_EQ(vmfunc.name, func_name); + ICHECK_EQ(vmfunc.num_args, -2) << "Function " << func_name << " already defined"; + vmfunc.num_args = num_inputs; + if (param_names.defined()) { + std::vector names; + for (auto name : param_names.value()) { + names.push_back(name); + } + vmfunc.param_names = names; + } + vmfunc.register_file_size = init_register_size; + if (kind == vm::VMFuncInfo::FuncKind::kVMFunc) { + vmfunc.start_instr = exec_->instr_offset.size(); + } +} + +void ExecBuilderNode::EndFunction(const std::string& func_name) { + auto it = exec_->func_map.find(func_name); + ICHECK(it != exec_->func_map.end()); + VMFuncInfo& vmfunc = exec_->func_table.at(it->second); + ICHECK_EQ(vmfunc.end_instr, 0) << "EndFuncton can only be called once"; + + if (vmfunc.kind == vm::VMFuncInfo::FuncKind::kVMFunc) { + vmfunc.end_instr = exec_->instr_offset.size(); + } +} + +void ExecBuilderNode::EmitCall(vm::Instruction::Arg func, std::vector args, + vm::RegName dst) { + ICHECK(func.kind() == vm::Instruction::ArgKind::kFuncIdx); + // store instruction + exec_->instr_offset.push_back(exec_->instr_data.size()); + exec_->instr_data.push_back(static_cast(Opcode::Call)); + exec_->instr_data.push_back(dst); + exec_->instr_data.push_back(func.value()); + exec_->instr_data.push_back(args.size()); + for (Instruction::Arg arg : args) { + exec_->instr_data.push_back(arg.data()); + } +} + +void ExecBuilderNode::EmitCall(const std::string& func, std::vector args, + RegName dst) { + auto it = exec_->func_map.find(func); + if (it == exec_->func_map.end()) { + this->DeclareFunction(func, VMFuncInfo::FuncKind::kPackedFunc); + } + Index func_idx = exec_->func_map.at(func); + EmitCall(vm::Instruction::Arg::FuncIdx(func_idx), args, dst); +} + +void ExecBuilderNode::EmitRet(vm::Instruction::Arg result) { + ICHECK(result.kind() == vm::Instruction::ArgKind::kRegister); + exec_->instr_offset.push_back(exec_->instr_data.size()); + exec_->instr_data.push_back(static_cast(Opcode::Ret)); + exec_->instr_data.push_back(result.value()); +} + +void ExecBuilderNode::EmitGoto(Index pc_offset) { + exec_->instr_offset.push_back(exec_->instr_data.size()); + exec_->instr_data.push_back(static_cast(Opcode::Goto)); + exec_->instr_data.push_back(pc_offset); +} + +void ExecBuilderNode::EmitIf(vm::Instruction::Arg cond, vm::Index false_offset) { + ICHECK(cond.kind() == vm::Instruction::ArgKind::kRegister); + exec_->instr_offset.push_back(exec_->instr_data.size()); + exec_->instr_data.push_back(static_cast(Opcode::If)); + exec_->instr_data.push_back(cond.value()); + exec_->instr_data.push_back(false_offset); +} + +void ExecBuilderNode::CheckExecutable() { + for (auto it = exec_->func_table.cbegin(); it != exec_->func_table.cend(); ++it) { + if (it->kind == VMFuncInfo::FuncKind::kPackedFunc) continue; + if (it->kind == VMFuncInfo::FuncKind::kVMTIRFunc) { + ICHECK_GE(it->register_file_size, it->num_args + 1) + << "Function " << it->name << " do not meet register file constraint."; + continue; + } + Index num_inputs = it->num_args; + std::unordered_set dst_registers; + std::unordered_set arg_registers; + size_t start_instr = it->start_instr; + size_t end_instr = it->end_instr; + + CHECK_LT(start_instr, end_instr) + << "Function " << it->name << " EndFunction has not be been called"; + + auto check_reg_defined = [&](Instruction::Arg arg) { + if (arg.kind() != Instruction::ArgKind::kRegister) return; + if (arg.value() >= Instruction::kBeginSpecialReg) return; + if (arg.value() < num_inputs) return; + + if (dst_registers.find(arg.value()) == dst_registers.end()) { + LOG(FATAL) << "register r(" << arg.value() << ") in VM function \"" << it->name + << "\" is used as input while it is never defined" + << " as a destination. Dump:\n" + << exec_->AsText(); + } + }; + + auto check_const_defined = [&](Instruction::Arg arg) { + if (arg.kind() != Instruction::ArgKind::kConstIdx) return; + CHECK_LT(arg.value(), exec_->constants.size()) + << "Constant index " << arg.value() << " exceed size of constant pool. Dump:\n" + << exec_->AsText(); + }; + + auto check_func_defined = [&](Instruction::Arg arg) { + if (arg.kind() != Instruction::ArgKind::kFuncIdx) return; + CHECK_LT(arg.value(), exec_->func_table.size()) + << "Func index " << arg.value() << " exceed size of fun_table. Dump:\n" + << exec_->AsText(); + }; + + for (size_t idx = start_instr; idx < end_instr; ++idx) { + Instruction instr = exec_->GetInstruction(idx); + switch (instr.op) { + case Opcode::Call: { + check_func_defined(Instruction::Arg::FuncIdx(instr.func_idx)); + for (int i = 0; i < instr.num_args; ++i) { + check_reg_defined(instr.args[i]); + check_const_defined(instr.args[i]); + check_func_defined(instr.args[i]); + arg_registers.emplace(instr.args[i].value()); + } + if (instr.dst != Instruction::kVoidRegister) { + dst_registers.emplace(instr.dst); + } + break; + } + case Opcode::Ret: { + arg_registers.emplace(instr.result); + check_reg_defined(Instruction::Arg::Register(instr.result)); + break; + } + case Opcode::Goto: { + ICHECK_NE(instr.pc_offset, 0); + break; + } + case Opcode::If: { + ICHECK_GT(instr.false_offset, 1); + check_reg_defined(Instruction::Arg::Register(instr.cond)); + arg_registers.emplace(instr.cond); + break; + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + break; + } + } + } +} + +void ExecBuilderNode::Formalize() { + // a pass to formalize user-specified register indexes in the order of use + // and decide the number of registers to allocate for each VMFunction in the Executable + for (auto it = this->exec_->func_table.begin(); it != this->exec_->func_table.end(); ++it) { + if (it->kind == VMFuncInfo::FuncKind::kPackedFunc) continue; + if (it->kind == VMFuncInfo::FuncKind::kVMTIRFunc) continue; + + Index num_inputs = it->num_args; + RegName register_idx = num_inputs; + std::unordered_map register_map; + size_t start_instr = it->start_instr; + size_t end_instr = it->end_instr; + + for (size_t idx = start_instr; idx < end_instr; ++idx) { + Instruction instr = this->exec_->GetInstruction(idx); + switch (instr.op) { + case Opcode::Call: { + // rewrite args + for (int i = 0; i < instr.num_args; ++i) { + if (instr.args[i].kind() == Instruction::ArgKind::kRegister && + instr.args[i].value() >= num_inputs && + instr.args[i].value() < Instruction::kBeginSpecialReg && + register_map.find(instr.args[i].value()) != register_map.end()) { + this->exec_->instr_data[this->exec_->instr_offset[idx] + 4 + i] = + register_map[instr.args[i].value()]; + } + } + if (instr.dst >= num_inputs && instr.dst < Instruction::kBeginSpecialReg) { + auto it = register_map.find(instr.dst); + if (it != register_map.end()) { + this->exec_->instr_data[this->exec_->instr_offset[idx] + 1] = it->second; + } else { + this->exec_->instr_data[this->exec_->instr_offset[idx] + 1] = register_idx; + register_map[instr.dst] = register_idx++; + } + } + break; + } + case Opcode::Ret: { + if (register_map.find(instr.result) != register_map.end()) { + this->exec_->instr_data[this->exec_->instr_offset[idx] + 1] = + register_map[instr.result]; + } + break; + } + case Opcode::Goto: { + break; + } + case Opcode::If: { + if (register_map.find(instr.cond) != register_map.end()) { + this->exec_->instr_data[this->exec_->instr_offset[idx] + 1] = register_map[instr.cond]; + } + break; + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + break; + } + } + it->register_file_size = register_idx; + } +} + +TVM_REGISTER_GLOBAL("relax.ExecBuilderCreate").set_body_typed(ExecBuilderNode::Create); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderConvertConstant") + .set_body([](TVMArgs args, TVMRetValue* ret) { + ExecBuilder builder = args[0]; + TVMRetValue rt; + rt = args[1]; + *ret = builder->ConvertConstant(rt).data(); + }); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitFunction") + .set_body_typed([](ExecBuilder builder, String func, int64_t num_inputs, + Optional> param_names) { + builder->EmitFunction(func, num_inputs, param_names); + }); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEndFunction") + .set_body_method(&ExecBuilderNode::EndFunction); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderDeclareFunction") + .set_body_typed([](ExecBuilder builder, String name, int32_t kind) { + builder->DeclareFunction(name, static_cast(kind)); + }); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitCall") + .set_body_typed([](ExecBuilder builder, String name, Array args, int64_t dst) { + std::vector args_; + for (size_t i = 0; i < args.size(); ++i) { + args_.push_back(Instruction::Arg::FromData(args[i]->value)); + } + auto dst_ = Instruction::Arg::Register(dst); + builder->EmitCall(name, args_, dst_.value()); + }); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitRet") + .set_body_typed([](ExecBuilder builder, int64_t data) { + builder->EmitRet(Instruction::Arg::FromData(data)); + }); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitGoto") + .set_body_method(&ExecBuilderNode::EmitGoto); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitIf") + .set_body_typed([](ExecBuilder builder, int64_t data, vm::Index false_offset) { + builder->EmitIf(Instruction::Arg::FromData(data), false_offset); + }); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderR").set_body_typed([](ExecBuilder builder, int64_t value) { + return Instruction::Arg::Register(value).data(); +}); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderImm").set_body_typed([](ExecBuilder builder, int64_t value) { + return Instruction::Arg::Immediate(value).data(); +}); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderC").set_body_typed([](ExecBuilder builder, int64_t value) { + return Instruction::Arg::ConstIdx(value).data(); +}); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderF").set_body_typed([](ExecBuilder builder, String value) { + return builder->GetFunction(value).data(); +}); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderGet").set_body_typed([](ExecBuilder builder) { + ObjectPtr p_exec = builder->Get(); + return runtime::Module(p_exec); +}); + +} // namespace relax +} // namespace tvm diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc new file mode 100644 index 000000000000..0ef63c8a4147 --- /dev/null +++ b/src/runtime/relax_vm/builtin.cc @@ -0,0 +1,445 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/runtime/relax_vm/builtin.cc + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../runtime_base.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +using tvm::runtime::NDArray; + +//------------------------------------------------- +// Shape/StructInfo handling. +//------------------------------------------------- +/*! + * \brief Builtin function to allocate shape heap. + * \param ctx_ptr The context module pointer. + * \param size the size of the heap. + * \return An allocate NDArray as shape heap. + */ +NDArray AllocShapeHeap(void* ctx_ptr, int64_t size) { + VirtualMachine* vm = static_cast(ctx_ptr); + // use host allocator, which is always last element. + size_t host_device_index = vm->devices.size() - 1; + // specialy handle hexagon on-device RT. + // TODO(relax-team): visit and consider other possible choices. + if (vm->devices[0].device_type == kDLHexagon) { + host_device_index = 0; + } + auto* alloc = vm->allocators[host_device_index]; + return alloc->Empty({size}, DLDataType{kDLInt, 64, 1}, vm->devices[host_device_index]); +} + +TVM_REGISTER_GLOBAL("vm.builtin.alloc_shape_heap").set_body_typed(AllocShapeHeap); + +/*! + * \brief Builtin match shape function. + * \param args The packed function arguments. + * \param rv The return value. + * + * \sa MatchShapeCode + */ +void MatchShape(TVMArgs args, TVMRetValue* rv) { + // input shape the first argument can take in tensor or shape. + ShapeTuple input_shape; + if (args[0].IsObjectRef()) { + input_shape = args[0].operator NDArray().Shape(); + } else { + input_shape = args[0]; + } + DLTensor* heap = args[1]; + int64_t* heap_data = heap == nullptr ? nullptr : static_cast(heap->data); + int64_t size = args[2]; + const int64_t kBeginCode = 3; + ICHECK_LE(kBeginCode + size * 2, args.size()); + // a function that lazily get context for error reporting + const int64_t kErrorContextOffset = kBeginCode + size * 2; + Optional err_ctx = args[kErrorContextOffset]; + + CHECK_EQ(input_shape.size(), size) + << "RuntimeError: " << err_ctx.value_or("") << " match_cast shape size mismatch."; + + for (int64_t i = 0; i < size; ++i) { + MatchShapeCode code = static_cast(args[kBeginCode + i * 2].operator int()); + int64_t reg = args[kBeginCode + i * 2 + 1]; + + if (code == MatchShapeCode::kAssertEqualToImm) { + CHECK_EQ(input_shape[i], reg) + << "RuntimeError: " << err_ctx.value_or("") << " match_cast error, " + << " shape[" << i << "]" + << " mismatch to specified constant."; + } else if (code == MatchShapeCode::kStoreToHeap) { + heap_data[reg] = input_shape[i]; + } else if (code == MatchShapeCode::kNoOp) { + } else { + ICHECK(code == MatchShapeCode::kAssertEqualToLoad); + CHECK_EQ(input_shape[i], heap_data[reg]) + << "RuntimeError: " << err_ctx.value_or("") << " match_cast error, " + << " shape[" << i << "]" + << " mismatch to a previous populated value."; + } + } +} + +TVM_REGISTER_GLOBAL("vm.builtin.match_shape").set_body(MatchShape); + +/*! + * \brief Builtin make shape function. + * \param args The packed function arguments. + * \param rv The return value. + * + * \sa MakeShapeCode + */ +void MakeShape(TVMArgs args, TVMRetValue* rv) { + // NOTE: heap can be nullptr + DLTensor* heap = args[0]; + int64_t* heap_data = heap == nullptr ? nullptr : static_cast(heap->data); + int64_t size = args[1]; + const int64_t kBeginCode = 2; + + std::vector shape(size); + + for (int64_t i = 0; i < size; ++i) { + MakeShapeCode code = static_cast(args[kBeginCode + i * 2].operator int()); + int64_t reg = args[kBeginCode + i * 2 + 1]; + if (code == MakeShapeCode::kUseImm) { + shape[i] = reg; + } else { + ICHECK(code == MakeShapeCode::kLoadShape); + shape[i] = heap_data[reg]; + } + } + *rv = ShapeTuple(std::move(shape)); +} + +TVM_REGISTER_GLOBAL("vm.builtin.make_shape").set_body(MakeShape); + +/*! + * \brief Builtin function to check if arg is Tensor(dtype, ndim) + * \param arg The input argument. + * \param ndim Expected ndim of the Tensor, can be -1 (indicate unknown). + * \param dtype The expected content data type. + * \param err_ctx Additional context if error occurs. + */ +void CheckTensorInfo(TVMArgs args, TVMRetValue* rv) { + ObjectRef arg = args[0]; + int ndim = args[1]; + DataType dtype; + Optional err_ctx; + + if (args.size() == 3) { + dtype = DataType::Void(); + err_ctx = args[2].operator Optional(); + } else { + dtype = args[2]; + err_ctx = args[3].operator Optional(); + } + + auto* ptr = arg.as(); + CHECK(ptr != nullptr) << "TypeError: " << err_ctx.value_or("") << " expect a Tensor but get " + << arg->GetTypeKey(); + + if (ndim != -1) { + CHECK(ptr->dl_tensor.ndim == ndim) + << "ValueError: " << err_ctx.value_or("") << " expect Tensor with ndim " << ndim + << " but get " << ptr->dl_tensor.ndim; + } + + if (dtype != DataType::Void()) { + CHECK(DataType(ptr->dl_tensor.dtype) == dtype) + << "ValueError: " << err_ctx.value_or("") << " expect Tensor with dtype " << dtype + << " but get " << ptr->dl_tensor.dtype; + } +} + +TVM_REGISTER_GLOBAL("vm.builtin.check_tensor_info").set_body(CheckTensorInfo); + +/*! + * \brief Builtin function to check if arg is Shape(ndim) + * \param arg The input argument. + * \param ndim Expected size of the shape, can be -1 (indicate unknown). + * \param err_ctx Additional context if error occurs. + */ +void CheckShapeInfo(ObjectRef arg, int ndim, Optional err_ctx) { + // a function that lazily get context for error reporting + auto* ptr = arg.as(); + CHECK(ptr != nullptr) << "TypeError: " << err_ctx.value_or("") << " expect a Shape but get " + << arg->GetTypeKey(); + if (ndim != -1) { + CHECK(ptr->size == static_cast(ndim)) + << "ValueError: " << err_ctx.value_or("") << " expect Shape with ndim " << ndim + << " but get " << ptr->size; + } +} + +TVM_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo); + +/*! + * \brief Builtin function to check if arg is Tuple with size elements. + * \param arg The input argument. + * \param size The expected size of the tuple. + * \param err_ctx Additional context if error occurs. + */ +void CheckTupleInfo(ObjectRef arg, int64_t size, Optional err_ctx) { + using Tuple = runtime::ADT; + // a function that lazily get context for error reporting + auto* ptr = arg.as(); + CHECK(ptr != nullptr) << "TypeError: " << err_ctx.value_or("") << " expect a Tuple but get " + << arg->GetTypeKey(); + CHECK(static_cast(ptr->size) == size) + << "ValueError: " << err_ctx.value_or("") << " expect a Tuple with " << size << " elements, " + << " but get a Tuple with " << ptr->size << " elements."; +} + +TVM_REGISTER_GLOBAL("vm.builtin.check_tuple_info").set_body_typed(CheckTupleInfo); + +/*! + * \brief Builtin function to check if arg is a callable function. + * \param arg The input argument. + * \param err_ctx Additional context if error occurs. + */ +void CheckFuncInfo(ObjectRef arg, Optional err_ctx) { + // a function that lazily get context for error reporting + bool is_func = arg.as() || arg.as(); + CHECK(is_func) << "TypeError: " << err_ctx.value_or("") << " expect a Function but get " + << arg->GetTypeKey(); +} + +TVM_REGISTER_GLOBAL("vm.builtin.check_func_info").set_body_typed(CheckFuncInfo); + +//------------------------------------------------- +// Storage management. +//------------------------------------------------- +Storage VMAllocStorage(void* ctx_ptr, ShapeTuple buffer_size, Index device_index, + DLDataType dtype_hint) { + VirtualMachine* vm = static_cast(ctx_ptr); + + ICHECK_EQ(buffer_size.size(), 1); + int alignment = runtime::kAllocAlignment; + ICHECK_LT(device_index, vm->devices.size()) + << "The device index is out of VM physical devices list"; + + if (device_index == -1) { + // Allocate on host. Host is always the last element of vm->devices. + device_index = vm->devices.size() - 1; + } + + int64_t size_imm = buffer_size[0]; + + auto storage_obj = runtime::SimpleObjAllocator().make_object(); + auto* alloc = vm->allocators[device_index]; + ICHECK(alloc) << "Did you forget to init the VirtualMachine with devices?"; + storage_obj->buffer = alloc->Alloc(size_imm, alignment, dtype_hint); + Storage storage(storage_obj); + return storage; +} + +TVM_REGISTER_GLOBAL("vm.builtin.alloc_storage").set_body_typed(VMAllocStorage); + +TVM_REGISTER_GLOBAL("vm.builtin.alloc_tensor").set_body_method(&StorageObj::AllocNDArray); + +//------------------------------------------------- +// Closure function handling, calling convention +//------------------------------------------------- +TVM_REGISTER_GLOBAL("vm.builtin.make_closure").set_body([](TVMArgs args, TVMRetValue* rv) { + VMClosure clo = args[0]; + std::vector saved_args; + saved_args.resize(args.size() - 1); + for (size_t i = 0; i < saved_args.size(); ++i) { + saved_args[i] = args[i + 1]; + } + auto impl = VMClosure::BindLastArgs(clo->impl, saved_args); + *rv = VMClosure(clo->func_name, impl); +}); + +TVM_REGISTER_GLOBAL("vm.builtin.invoke_closure").set_body([](TVMArgs args, TVMRetValue* rv) { + // args[0]: vm; args[1]: closure; args[2, 3, ...]: function arguments + VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); + ObjectRef vm_closure = args[1]; + vm->InvokeClosurePacked(vm_closure, + TVMArgs(args.values + 2, args.type_codes + 2, args.size() - 2), rv); +}); + +TVM_REGISTER_GLOBAL("vm.builtin.call_tir_dyn").set_body([](TVMArgs args, TVMRetValue* rv) { + PackedFunc func = args[0]; + ShapeTuple to_unpack = args[args.size() - 1]; + size_t num_tensor_args = args.size() - 2; + + std::vector values(num_tensor_args + to_unpack.size()); + std::vector tcodes(num_tensor_args + to_unpack.size()); + runtime::TVMArgsSetter setter(values.data(), tcodes.data()); + + std::copy(args.values + 1, args.values + args.size() - 1, values.data()); + std::copy(args.type_codes + 1, args.type_codes + args.size() - 1, tcodes.data()); + + for (size_t i = 0; i < to_unpack.size(); ++i) { + setter(i + num_tensor_args, to_unpack[i]); + } + TVMArgs func_args(values.data(), tcodes.data(), values.size()); + func.CallPacked(func_args, rv); +}); + +//------------------------------------- +// Builtin runtime operators. +//------------------------------------- +TVM_REGISTER_GLOBAL("vm.builtin.shape_of").set_body_method(&NDArray::Shape); + +TVM_REGISTER_GLOBAL("vm.builtin.copy").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = args[0]; +}); + +/*! + * \brief Load the scalar value in cond and return the result value. + * \param cond The condition + * \return Bool + */ +bool ReadIfCond(TVMArgValue cond) { + if (cond.type_code() == kDLInt) return cond.operator bool(); + NDArray arr = cond.operator tvm::runtime::NDArray(); + if (arr->device.device_type != kDLCPU) { + arr = arr.CopyTo(DLDevice{kDLCPU, 0}); + } + ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt); + int64_t result; + switch (arr->dtype.bits) { + case 1: { + result = reinterpret_cast(arr->data)[0]; + break; + } + case 8: { + result = reinterpret_cast(arr->data)[0]; + break; + } + case 16: { + result = reinterpret_cast(arr->data)[0]; + break; + } + case 32: { + result = reinterpret_cast(arr->data)[0]; + break; + } + case 64: { + result = reinterpret_cast(arr->data)[0]; + break; + } + default: + LOG(FATAL) << "Unknown scalar int type: " << DLDataType2String(arr->dtype); + throw; + } + return result != 0; +} + +TVM_REGISTER_GLOBAL("vm.builtin.read_if_cond").set_body_typed(ReadIfCond); + +//------------------------------------- +// Data structure API +//------------------------------------- +TVM_REGISTER_GLOBAL("vm.builtin.tuple_getitem").set_body_typed([](runtime::ADT arr, int64_t index) { + return arr[index]; +}); + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +//------------------------------------------------- +// AnyList C runtime API: keep in relax for now. +//-------------------------------------------------- +extern "C" { +/*! + * \brief Backend function to get anylist item and set into Packed Func call arg stack. + * + * \param anylist The handle to the anylist, backed by TVMRetValue* + * \param int The index. + * \param args The args stack. + * \param type_codes The type codes stack. + * \param arg_offset The offset of argument. + * \return 0 when no error is thrown, -1 when failure happens + */ +TVM_DLL int TVMBackendAnyListSetPackedArg(void* anylist, int index, TVMValue* args, int* type_codes, + int arg_offset); +/*! + * \brief Backend function to get anylist item and set into Packed Func call arg stack. + * + * \param anylist The handle to the anylist, backed by TVMRetValue* + * \param int The index. + */ +TVM_DLL int TVMBackendAnyListResetItem(void* anylist, int index); + +/*! + * \brief Backend function to set anylist item by moving from packed func return. + * + * \param anylist The handle to the anylist, backed by TVMRetValue* + * \param int The index. + * \param args The args stack. + * \param type_codes The type codes stack. + * \param arg_offset The offset of argument. + * \return 0 when no error is thrown, -1 when failure happens. + */ +TVM_DLL int TVMBackendAnyListMoveFromPackedReturn(void* anylist, int index, TVMValue* args, + int* type_codes, int ret_offset); + +int TVMBackendAnyListSetPackedArg(void* anylist, int index, TVMValue* args, int* type_codes, + int arg_offset) { + using namespace tvm::runtime; + API_BEGIN(); + auto* list = static_cast(anylist); + TVMArgsSetter setter(args, type_codes); + setter(arg_offset, list[index]); + API_END(); +} + +int TVMBackendAnyListResetItem(void* anylist, int index) { + using namespace tvm::runtime; + API_BEGIN(); + auto* list = static_cast(anylist); + list[index] = nullptr; + API_END(); +} + +int TVMBackendAnyListMoveFromPackedReturn(void* anylist, int index, TVMValue* args, int* type_codes, + int ret_offset) { + using namespace tvm::runtime; + API_BEGIN(); + auto* list = static_cast(anylist); + if (type_codes[ret_offset] == kTVMStr || type_codes[ret_offset] == kTVMBytes) { + list[index] = TVMArgValue(args[ret_offset], type_codes[ret_offset]); + } else { + list[index] = TVMRetValue::MoveFromCHost(args[ret_offset], type_codes[ret_offset]); + } + API_END(); +} +} // extern "C" diff --git a/src/runtime/relax_vm/bytecode.cc b/src/runtime/relax_vm/bytecode.cc new file mode 100644 index 000000000000..9084207848b5 --- /dev/null +++ b/src/runtime/relax_vm/bytecode.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/relax_vm/bytecode.cc + * \brief The bytecode for Relax virtual machine. + */ + +#include +#include + +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +Instruction Instruction::Call(Index func_idx, Index num_args, Instruction::Arg* args, RegName dst) { + Instruction instr; + instr.op = Opcode::Call; + instr.dst = dst; + instr.func_idx = func_idx; + instr.num_args = num_args; + instr.args = args; + return instr; +} + +Instruction Instruction::Ret(RegName result) { + Instruction instr; + instr.op = Opcode::Ret; + instr.result = result; + return instr; +} + +Instruction Instruction::Goto(Index pc_offset) { + Instruction instr; + instr.op = Opcode::Goto; + instr.pc_offset = pc_offset; + return instr; +} + +Instruction Instruction::If(RegName cond, Index false_offset) { + Instruction instr; + instr.op = Opcode::If; + instr.cond = cond; + instr.false_offset = false_offset; + return instr; +} +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/relax_vm/executable.cc b/src/runtime/relax_vm/executable.cc new file mode 100644 index 000000000000..b7915d7978aa --- /dev/null +++ b/src/runtime/relax_vm/executable.cc @@ -0,0 +1,576 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/relax_vm/executable.cc + */ + +#include +#include +#include +#include + +#include +#include + +#include "../file_utils.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +/*! \brief The magic number for the serialized VM bytecode file */ +constexpr uint64_t kTVMVMBytecodeMagic = 0xD225DE2F4214151D; + +/*! \brief Possible types in the constant pool */ +enum ConstantType : int { + kNDArray = 0, + kDLDataType = 1, + kShapeTuple = 2, + kString = 3, + kInt = 4, +}; + +#define STREAM_CHECK(val, section) \ + ICHECK(val) << "Invalid VM file format in the " << section << " section." \ + << "\n"; + +PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + if (name == "stats") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); }); + } else if (name == "as_text") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->AsText(); }); + } else if (name == "as_python") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->AsPython(); }); + } else if (name == "vm_load_executable") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ObjectPtr vm = VirtualMachine::Create(); + ICHECK(sptr_to_self.get() == this); + vm->LoadExecutable(GetObjectPtr(this)); + *rv = Module(vm); + }); + } + return nullptr; +} + +std::string Executable::Stats() const { + std::ostringstream oss; + oss << "Relax VM executable statistics:" << std::endl; + + // Get the number of constants. + // If the constant is an NDArray, get the shape of each of them. + // If the constant is an DLDataType, get the data type of each of them. + oss << " Constant pool (# " << constants.size() << "): ["; + for (const auto& it : constants) { + if (it.IsObjectRef()) { + const auto ndarray = it.operator tvm::runtime::NDArray(); + const auto& shape = ndarray.Shape(); + // Scalar + if (shape.empty()) { + oss << "scalar, "; + continue; + } + oss << "["; + for (auto s : shape) { + oss << s << ", "; + } + oss.seekp(-2, oss.cur); + oss << "], "; + } else if (it.IsObjectRef()) { + ShapeTuple shape = it.operator ShapeTuple(); + oss << "shapetuple["; + for (size_t i = 0; i < shape.size(); ++i) { + oss << shape.at(i) << ", "; + } + oss.seekp(-2, oss.cur); + oss << "], "; + } else if (it.IsObjectRef()) { + std::string f = it.AsObjectRef().operator std::string(); + oss << "\""; + oss << f; + oss << "\", "; + } else if (it.type_code() == kDLInt) { + oss << static_cast(it); + oss << ", "; + } else { + try { + DLDataType dtype = it.operator DLDataType(); + oss << dtype; + oss << ", "; + } catch (std::exception& exc) { + LOG(FATAL) << "Constant pool can only contain NDArray and DLDataType, but got " + << ArgTypeCode2Str(it.type_code()); + } + } + } + if (!constants.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + // Get the number of globals and the name of each of them. + oss << " Globals (#" << func_table.size() << "): ["; + for (const auto& it : func_table) { + oss << it.name << ", "; + } + if (!func_map.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + return oss.str(); +} + +void Executable::SetInstructionData(Index i, Index j, ExecWord val) { + ICHECK_LT(i, instr_offset.size()); + Index instr_idx = instr_offset[i]; + ICHECK_LT(instr_idx + j, instr_data.size()); + instr_data[instr_idx + j] = val; +} + +Instruction Executable::GetInstruction(Index i) const { + Index offset = instr_offset[i]; + Opcode op = static_cast(instr_data[offset]); + switch (op) { + case Opcode::Call: { + RegName dst = instr_data[offset + 1]; + Index func_idx = instr_data[offset + 2]; + Index num_args = instr_data[offset + 3]; + ExecWord* args = const_cast(&instr_data[offset + 4]); + return Instruction::Call(func_idx, num_args, reinterpret_cast(args), dst); + } + case Opcode::Ret: { + RegName result = instr_data[offset + 1]; + return Instruction::Ret(result); + } + case Opcode::Goto: { + Index pc_offset = instr_data[offset + 1]; + return Instruction::Goto(pc_offset); + } + case Opcode::If: { + RegName cond = instr_data[offset + 1]; + Index false_offset = instr_data[offset + 2]; + return Instruction::If(cond, false_offset); + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(op); + break; + } + return Instruction(); +} + +void SaveHeader(dmlc::Stream* strm) { + uint64_t header = kTVMVMBytecodeMagic; + strm->Write(header); + std::string version = TVM_VERSION; + strm->Write(version); +} + +void LoadHeader(dmlc::Stream* strm) { + // Check header. + uint64_t header; + STREAM_CHECK(strm->Read(&header), "header"); + STREAM_CHECK(header == kTVMVMBytecodeMagic, "header"); + + // Check version. + std::string version; + STREAM_CHECK(strm->Read(&version), "version"); + STREAM_CHECK(version == TVM_VERSION, "version"); +} + +void Executable::SaveToBinary(dmlc::Stream* stream) { + std::string code; + // Initialize the stream object. + dmlc::MemoryStringStream strm(&code); + + // Save header + SaveHeader(&strm); + + // Global section. + SaveGlobalSection(&strm); + + // Constant section. + SaveConstantSection(&strm); + + // Code section. + SaveCodeSection(&strm); + + stream->Write(code); +} + +void Executable::SaveToFile(const std::string& file_name, const std::string& format) { + std::string data; + dmlc::MemoryStringStream writer(&data); + dmlc::SeekStream* strm = &writer; + Executable::SaveToBinary(strm); + runtime::SaveBinaryToFile(file_name, data); +} + +Module Executable::LoadFromBinary(void* stream) { + std::string code; + static_cast(stream)->Read(&code); + dmlc::MemoryStringStream strm(&code); + + ObjectPtr exec = make_object(); + + // Load header. + LoadHeader(&strm); + + // Global section. + exec->LoadGlobalSection(&strm); + + // Constant section. + exec->LoadConstantSection(&strm); + + // Code section. + exec->LoadCodeSection(&strm); + + return Module(exec); +} + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_relax.Executable") + .set_body_typed(Executable::LoadFromBinary); + +Module Executable::LoadFromFile(const std::string& file_name) { + std::string data; + runtime::LoadBinaryFromFile(file_name, &data); + dmlc::MemoryStringStream reader(&data); + dmlc::Stream* strm = &reader; + return Executable::LoadFromBinary(reinterpret_cast(strm)); +} + +TVM_REGISTER_GLOBAL("runtime.module.loadfile_relax.Executable") + .set_body_typed(Executable::LoadFromFile); + +void VMFuncInfo::Save(dmlc::Stream* strm) const { + int32_t temp_kind = static_cast(kind); + strm->Write(temp_kind); + strm->Write(name); + strm->Write(start_instr); + strm->Write(end_instr); + strm->Write(num_args); + strm->Write(register_file_size); + strm->Write(param_names); +} + +bool VMFuncInfo::Load(dmlc::Stream* strm) { + int32_t temp_kind; + if (!strm->Read(&temp_kind)) return false; + this->kind = static_cast(temp_kind); + if (!strm->Read(&name)) return false; + if (!strm->Read(&start_instr)) return false; + if (!strm->Read(&end_instr)) return false; + if (!strm->Read(&num_args)) return false; + if (!strm->Read(®ister_file_size)) return false; + if (!strm->Read(¶m_names)) return false; + return true; +} + +void Executable::SaveGlobalSection(dmlc::Stream* strm) { strm->Write(func_table); } + +void Executable::SaveConstantSection(dmlc::Stream* strm) { + strm->Write(static_cast(this->constants.size())); + for (const auto& it : this->constants) { + if (it.IsObjectRef()) { + strm->Write(ConstantType::kNDArray); + runtime::SaveDLTensor(strm, it.operator DLTensor*()); + } else if (it.IsObjectRef()) { + ShapeTuple shape = it.operator ShapeTuple(); + strm->Write(ConstantType::kShapeTuple); + strm->Write(shape.size()); + for (size_t i = 0; i < shape.size(); ++i) { + strm->Write(shape.at(i)); + } + } else if (it.IsObjectRef()) { + String str = it.operator String(); + strm->Write(ConstantType::kString); + strm->Write(str.size()); + for (size_t i = 0; i < str.size(); ++i) { + strm->Write(str.at(i)); + } + } else if (it.type_code() == kDLInt) { + strm->Write(ConstantType::kInt); + strm->Write(it.value()); + } else { + try { + strm->Write(ConstantType::kDLDataType); + strm->Write(it.operator DLDataType()); + } catch (std::exception& exc) { + LOG(FATAL) << "Constant pool can only contain NDArray, DLDataType, and Integers but got " + << ArgTypeCode2Str(it.type_code()); + } + } + } +} + +void Executable::SaveCodeSection(dmlc::Stream* strm) { + strm->Write(instr_offset); + strm->Write(instr_data); +} + +void Executable::LoadGlobalSection(dmlc::Stream* strm) { + STREAM_CHECK(strm->Read(&func_table), "Global Section"); + // setup func map + for (size_t i = 0; i < func_table.size(); ++i) { + this->func_map[func_table[i].name] = i; + } +} + +void Executable::LoadConstantSection(dmlc::Stream* strm) { + uint64_t sz; + // Load the number of constants. + STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant"); + + size_t size = static_cast(sz); + runtime::NDArray ndarray; + DLDataType dtype; + // Load each of the constants. + for (size_t i = 0; i < size; i++) { + int constant_type; + STREAM_CHECK(strm->Read(&constant_type, sizeof(constant_type)), "constant"); + if (constant_type == ConstantType::kNDArray) { + ndarray.Load(strm); + TVMRetValue cell; + cell = ndarray; + this->constants.push_back(cell); + } else if (constant_type == ConstantType::kShapeTuple) { + uint64_t size; + strm->Read(&size); + std::vector data(size); + for (size_t i = 0; i < size; ++i) { + strm->Read(&(data[i])); + } + TVMRetValue cell; + cell = ShapeTuple(data); + this->constants.push_back(cell); + } else if (constant_type == ConstantType::kDLDataType) { + strm->Read(&dtype); + TVMRetValue cell; + cell = dtype; + this->constants.push_back(cell); + } else if (constant_type == ConstantType::kString) { + uint64_t size; + strm->Read(&size); + std::vector data(size); + for (size_t i = 0; i < size; ++i) { + strm->Read(&(data[i])); + } + TVMRetValue cell; + cell = String(std::string(data.begin(), data.end())); + this->constants.push_back(cell); + } else if (constant_type == ConstantType::kInt) { + int64_t value; + strm->Read(&value); + TVMRetValue cell; + cell = value; + this->constants.push_back(cell); + } else { + LOG(FATAL) << "Constant pool can only contain NDArray and DLDataType, but got " + << ArgTypeCode2Str(constant_type) << " when loading the VM constant pool."; + } + } +} + +void Executable::LoadCodeSection(dmlc::Stream* strm) { + STREAM_CHECK(strm->Read(&(this->instr_offset)), "instr offset"); + STREAM_CHECK(strm->Read(&(this->instr_data)), "instr data"); +} + +template +std::string StrJoin(T* items, int offset, int cnt, std::string delim = ", ", + std::function repr = std::to_string) { + if (cnt == 0) { + return ""; + } + std::ostringstream oss; + oss << repr(items[offset]); + for (int i = 1; i < cnt; ++i) { + oss << delim << repr(items[offset + i]); + } + return oss.str(); +} + +std::string RegNameToStr(RegName reg) { + if (reg == Instruction::kVoidRegister) { + return "%void"; + } + if (reg == Instruction::kVMRegister) { + return "%vm"; + } + return "%" + std::to_string(reg); +} + +String Executable::AsText() const { + auto get_func_name = [&](Index index) -> std::string { + if (static_cast(index) < func_table.size()) { + return func_table[index].name; + } else { + return "unknown_func_index(" + std::to_string(index) + ")"; + } + }; + + auto instr_to_str = [&](Instruction::Arg arg) -> std::string { + // only for argument + switch (arg.kind()) { + case Instruction::ArgKind::kRegister: + return RegNameToStr(arg.value()); + case Instruction::ArgKind::kImmediate: + return "i" + std::to_string(arg.value()); + case Instruction::ArgKind::kConstIdx: + return "c[" + std::to_string(arg.value()) + "]"; + case Instruction::ArgKind::kFuncIdx: + return "f[" + get_func_name(arg.value()) + "]"; + default: + LOG(FATAL) << "Wrong instruction kind: " << static_cast(arg.kind()); + return ""; + } + }; + + // print the text format + std::ostringstream os; + for (size_t fidx = 0; fidx < this->func_table.size(); ++fidx) { + const VMFuncInfo& gfunc = this->func_table[fidx]; + if (gfunc.kind == VMFuncInfo::FuncKind::kPackedFunc) { + os << "@" << gfunc.name << " packed_func;\n\n"; + continue; + } + if (gfunc.kind == VMFuncInfo::FuncKind::kVMTIRFunc) { + os << "@" << gfunc.name << " num_inputs=" << gfunc.num_args << " vm_tir_func;\n\n"; + continue; + } + ICHECK(gfunc.kind == VMFuncInfo::FuncKind::kVMFunc); + os << "@" << gfunc.name << ":\n"; + size_t start_instr = gfunc.start_instr; + size_t end_instr = gfunc.end_instr; + + for (size_t idx = start_instr; idx < end_instr; ++idx) { + os << " "; + Instruction instr = this->GetInstruction(idx); + switch (instr.op) { + case Opcode::Call: { + os << std::setw(6) << std::left << "call" << std::setw(16) << std::left + << get_func_name(instr.func_idx) << " in: " << std::setw(12) << std::left + << StrJoin(instr.args, 0, instr.num_args, ", ", instr_to_str) + << " dst: " << RegNameToStr(instr.dst) << "\n"; + break; + } + case Opcode::Ret: { + os << std::setw(6) << std::left << "ret " << RegNameToStr(instr.result) << "\n"; + break; + } + case Opcode::Goto: { + os << std::setw(6) << std::left << "goto" << instr.pc_offset << "\n"; + break; + } + case Opcode::If: { + os << std::setw(6) << std::left << "If" << RegNameToStr(instr.cond) << ", " + << instr.false_offset << "\n"; + break; + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + break; + } + } + os << "\n"; + } + return String(os.str()); +} + +String Executable::AsPython() const { + auto get_func_name = [&](Index index) -> std::string { + if (static_cast(index) < func_table.size()) { + return "\"" + func_table[index].name + "\""; + } else { + return "ib.unknown_func_index(" + std::to_string(index) + ")"; + } + }; + + auto arg_to_py_str = [&](Instruction::Arg arg) -> std::string { + switch (arg.kind()) { + case Instruction::ArgKind::kRegister: + if (arg.value() == Instruction::kVMRegister) { + return "ib.r(vm)"; + } + return "ib.r(" + std::to_string(arg.value()) + ")"; + case Instruction::ArgKind::kImmediate: + return "ib.imm(" + std::to_string(arg.value()) + ")"; + case Instruction::ArgKind::kConstIdx: + return "ib.c(" + std::to_string(arg.value()) + ")"; + case Instruction::ArgKind::kFuncIdx: { + return "ib.f(" + get_func_name(arg.value()) + ")"; + } + default: + LOG(FATAL) << "Wrong instruction kind: " << static_cast(arg.kind()); + return ""; + } + }; + + // print the python format + std::ostringstream os; + os << "ib = rx.Builder()\n"; + for (size_t fidx = 0; fidx < this->func_table.size(); ++fidx) { + const VMFuncInfo& gfunc = this->func_table[fidx]; + if (gfunc.kind == VMFuncInfo::FuncKind::kPackedFunc) { + continue; + } + if (gfunc.kind == VMFuncInfo::FuncKind::kVMTIRFunc) { + continue; + } + ICHECK(gfunc.kind == VMFuncInfo::FuncKind::kVMFunc); + + os << "with ib.function(\"" << gfunc.name << "\", num_inputs=" << gfunc.num_args << "):\n"; + size_t start_instr = gfunc.start_instr; + size_t end_instr = gfunc.end_instr; + + for (size_t idx = start_instr; idx < end_instr; ++idx) { + Instruction instr = this->GetInstruction(idx); + switch (instr.op) { + case Opcode::Call: { + os << " ib.emit_call(" << get_func_name(instr.func_idx) << ", args=[" + << StrJoin(instr.args, 0, instr.num_args, ", ", arg_to_py_str) + << "]"; + if (instr.dst != Instruction::kVoidRegister) os << ", dst=ib.r(" << instr.dst << ")"; + os << ")\n"; + break; + } + case Opcode::Ret: { + os << " ib.emit_ret(ib.r(" << instr.result << "))\n"; + break; + } + case Opcode::Goto: { + os << " ib.emit_goto(" << instr.pc_offset << ")\n"; + break; + } + case Opcode::If: { + os << " ib.emit_if(ib.r(" << instr.cond << "), " << instr.false_offset << ")\n"; + break; + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + break; + } + } + } + return String(os.str()); +} + +TVM_REGISTER_GLOBAL("relax.ExecutableLoadFromFile").set_body_typed(Executable::LoadFromFile); + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/relax_vm/memory_manager.cc b/src/runtime/relax_vm/memory_manager.cc new file mode 100644 index 000000000000..a017b9c6d944 --- /dev/null +++ b/src/runtime/relax_vm/memory_manager.cc @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/relax_vm/memory_manager.cc + * \brief Allocate and manage memory for the Relay VM. + */ +#include + +#include +#include + +#include "naive_allocator.h" +#include "pooled_allocator.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +static void BufferDeleter(Object* obj) { + auto* ptr = static_cast(obj); + ICHECK(ptr->manager_ctx != nullptr); + Buffer* buffer = reinterpret_cast(ptr->manager_ctx); + MemoryManager::GetAllocator(buffer->device)->Free(*(buffer)); + delete buffer; + delete ptr; +} + +void StorageObj::Deleter(Object* obj) { + auto* ptr = static_cast(obj); + // When invoking AllocNDArray we don't own the underlying allocation + // and should not delete the buffer, but instead let it be reclaimed + // by the storage object's destructor. + // + // We did bump the reference count by 1 to keep alive the StorageObj + // allocation in case this NDArray is the sole owner. + // + // We decrement the object allowing for the buffer to release our + // reference count from allocation. + StorageObj* storage = reinterpret_cast(ptr->manager_ctx); + storage->DecRef(); + delete ptr; +} + +inline void VerifyDataType(DLDataType dtype) { + ICHECK_GE(dtype.lanes, 1); + if (dtype.code == kDLFloat) { + ICHECK_EQ(dtype.bits % 8, 0); + } else { + // allow uint1 as a special flag for bool. + if (dtype.bits == 1 && dtype.code == kDLUInt) return; + ICHECK_EQ(dtype.bits % 8, 0); + } + ICHECK_EQ(dtype.bits & (dtype.bits - 1), 0); +} + +inline size_t GetDataAlignment(const DLTensor& arr) { + size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes; + if (align < runtime::kAllocAlignment) return runtime::kAllocAlignment; + return align; +} + +runtime::NDArray StorageObj::AllocNDArray(uint64_t offset, ShapeTuple shape, DLDataType dtype) { + VerifyDataType(dtype); + + // critical zone: allocate header, cannot throw + runtime::NDArray::Container* container = + new runtime::NDArray::Container(nullptr, shape, dtype, this->buffer.device); + + container->SetDeleter(StorageObj::Deleter); + size_t needed_size = runtime::GetDataSize(container->dl_tensor); + this->IncRef(); + // The manager context pointer must continue to point to the storage object + // which owns the backing memory, and keeps track of the reference count. + // + // When we free a container we extract the storage object, decrement its + // reference count, then destroy the container, but leave the underlying + // buffer intact. + container->manager_ctx = reinterpret_cast(this); + + // is this UB? + // The only change we make w.r.t offset is modifying the data pointer + // of the backing tensor to point into the buffer instead of its start. + auto offset_ptr = reinterpret_cast(this->buffer.data) + offset; + container->dl_tensor.data = reinterpret_cast(offset_ptr); + + runtime::NDArray ret(runtime::GetObjectPtr(container)); + // RAII in effect, now run the check. + + ICHECK(offset + needed_size <= this->buffer.size) + << "storage allocation failure, attempted to allocate " << needed_size << " at offset " + << offset << " in region that is " << this->buffer.size << "bytes"; + + return ret; +} + +MemoryManager* MemoryManager::Global() { + // NOTE: explicitly use new to avoid exit-time destruction of global state + // Global state will be recycled by OS as the process exits. + static auto* inst = new MemoryManager(); + return inst; +} + +Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) { + MemoryManager* m = MemoryManager::Global(); + std::lock_guard lock(m->mutex_); + if (m->allocators_.find(dev) == m->allocators_.end()) { + std::unique_ptr alloc; + switch (type) { + case kNaive: { + DLOG(INFO) << "New naive allocator for " << runtime::DeviceName(dev.device_type) << "(" + << dev.device_id << ")"; + alloc.reset(new NaiveAllocator(dev)); + break; + } + case kPooled: { + DLOG(INFO) << "New pooled allocator for " << runtime::DeviceName(dev.device_type) << "(" + << dev.device_id << ")"; + alloc.reset(new PooledAllocator(dev)); + break; + } + default: + LOG(FATAL) << "Unknown allocator type: " << type; + } + auto ret = alloc.get(); + m->allocators_.emplace(dev, std::move(alloc)); + return ret; + } + auto alloc = m->allocators_.at(dev).get(); + if (alloc->type() != type) { + LOG(WARNING) << "The type of existing allocator for " << runtime::DeviceName(dev.device_type) + << "(" << dev.device_id << ") is different from the request type (" + << alloc->type() << " vs " << type << ")"; + } + return alloc; +} + +Allocator* MemoryManager::GetAllocator(Device dev) { + MemoryManager* m = MemoryManager::Global(); + std::lock_guard lock(m->mutex_); + auto it = m->allocators_.find(dev); + if (it == m->allocators_.end()) { + LOG(FATAL) << "Allocator for " << runtime::DeviceName(dev.device_type) << "(" << dev.device_id + << ") has not been created yet."; + } + return it->second.get(); +} + +runtime::NDArray Allocator::Empty(std::vector shape, DLDataType dtype, DLDevice dev) { + VerifyDataType(dtype); + runtime::NDArray::Container* container = + new runtime::NDArray::Container(nullptr, shape, dtype, dev); + container->SetDeleter(BufferDeleter); + size_t size = runtime::GetDataSize(container->dl_tensor); + size_t alignment = GetDataAlignment(container->dl_tensor); + Buffer* buffer = new Buffer; + *buffer = this->Alloc(size, alignment, dtype); + container->manager_ctx = reinterpret_cast(buffer); + container->dl_tensor.data = buffer->data; + return runtime::NDArray(runtime::GetObjectPtr(container)); +} + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/relax_vm/naive_allocator.h b/src/runtime/relax_vm/naive_allocator.h new file mode 100644 index 000000000000..843a559602ab --- /dev/null +++ b/src/runtime/relax_vm/naive_allocator.h @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/relax_vm/naive_allocator.h + */ +#ifndef TVM_RUNTIME_RELAX_VM_NAIVE_ALLOCATOR_H_ +#define TVM_RUNTIME_RELAX_VM_NAIVE_ALLOCATOR_H_ + +#include +#include + +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +class NaiveAllocator final : public Allocator { + public: + explicit NaiveAllocator(Device dev) : Allocator(kNaive), used_memory_(0), device_(dev) {} + + Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) override { + Buffer buf; + buf.device = device_; + buf.size = nbytes; + buf.data = + runtime::DeviceAPI::Get(device_)->AllocDataSpace(device_, nbytes, alignment, type_hint); + used_memory_.fetch_add(nbytes, std::memory_order_relaxed); + DLOG(INFO) << "allocate " << nbytes << " B, used memory " << used_memory_ << " B"; + return buf; + } + + void Free(const Buffer& buffer) override { + runtime::DeviceAPI::Get(device_)->FreeDataSpace(buffer.device, buffer.data); + used_memory_.fetch_sub(buffer.size, std::memory_order_relaxed); + DLOG(INFO) << "free " << buffer.size << " B, used memory " << used_memory_ << " B"; + } + + private: + std::atomic used_memory_; + Device device_; +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_RELAX_VM_NAIVE_ALLOCATOR_H_ diff --git a/src/runtime/relax_vm/pooled_allocator.h b/src/runtime/relax_vm/pooled_allocator.h new file mode 100644 index 000000000000..0dd7d8b0277b --- /dev/null +++ b/src/runtime/relax_vm/pooled_allocator.h @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/relax_vm/pooled_allocator.h + */ +#ifndef TVM_RUNTIME_RELAX_VM_POOLED_ALLOCATOR_H_ +#define TVM_RUNTIME_RELAX_VM_POOLED_ALLOCATOR_H_ + +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +class PooledAllocator final : public Allocator { + public: + static constexpr size_t kDefaultPageSize = 4096; + + explicit PooledAllocator(Device dev, size_t page_size = kDefaultPageSize) + : Allocator(kPooled), page_size_(page_size), used_memory_(0), device_(dev) {} + + ~PooledAllocator() { ReleaseAll(); } + + Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) override { + std::lock_guard lock(mu_); + size_t size = ((nbytes + page_size_ - 1) / page_size_) * page_size_; + auto&& it = memory_pool_.find(size); + if (it != memory_pool_.end() && !it->second.empty()) { + auto&& pool = it->second; + auto ret = pool.back(); + pool.pop_back(); + return ret; + } + Buffer buf; + buf.device = device_; + buf.size = size; + try { + buf.data = + runtime::DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint); + } catch (InternalError& err) { + LOG(WARNING) << "PooledAllocator got InternalError during allocation: " << err.message(); + LOG(WARNING) << "Trying to release all unused memory and reallocate..."; + ReleaseAll(); + buf.data = + runtime::DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint); + } + + used_memory_.fetch_add(size, std::memory_order_relaxed); + DLOG(INFO) << "allocate " << size << " B, used memory " << used_memory_ << " B"; + return buf; + } + + void Free(const Buffer& buffer) override { + std::lock_guard lock(mu_); + if (memory_pool_.find(buffer.size) == memory_pool_.end()) { + memory_pool_.emplace(buffer.size, std::vector{}); + } + memory_pool_.at(buffer.size).push_back(buffer); + DLOG(INFO) << "reclaim buffer " << buffer.size; + } + + private: + void ReleaseAll() { + std::lock_guard lock(mu_); + for (auto const& it : memory_pool_) { + auto const& pool = it.second; + for (auto const& buf : pool) { + runtime::DeviceAPI::Get(buf.device)->FreeDataSpace(buf.device, buf.data); + } + } + memory_pool_.clear(); + used_memory_ = 0; + DLOG(INFO) << "release all buffers"; + } + + private: + size_t page_size_; + std::atomic used_memory_; + std::unordered_map > memory_pool_; + std::recursive_mutex mu_; + Device device_; +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_RELAX_VM_POOLED_ALLOCATOR_H_ diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc new file mode 100644 index 000000000000..3cf65faaa81a --- /dev/null +++ b/src/runtime/relax_vm/vm.cc @@ -0,0 +1,811 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/relax_vm/vm.cc + */ + +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +//--------------------------------------------- +// VM Closure object +//--------------------------------------------- +TVM_REGISTER_OBJECT_TYPE(VMClosureObj); + +VMClosure::VMClosure(String func_name, PackedFunc impl) { + auto ptr = make_object(); + ptr->func_name = func_name; + ptr->impl = std::move(impl); + data_ = std::move(ptr); +} + +/*! + * \brief Create another PackedFunc with last arguments already bound to last_args. + * \param func The input func, can be a VMClosure or PackedFunc. + * \param last_args The arguments to bound to in the end of the function. + * \note The new function takes in arguments and append the last_args in the end. + */ +PackedFunc VMClosure::BindLastArgs(PackedFunc func, std::vector last_args) { + return PackedFunc([func, last_args](TVMArgs args, TVMRetValue* rv) { + std::vector values(args.size() + last_args.size()); + std::vector tcodes(args.size() + last_args.size()); + runtime::TVMArgsSetter setter(values.data(), tcodes.data()); + std::copy(args.values, args.values + args.size(), values.data()); + std::copy(args.type_codes, args.type_codes + args.size(), tcodes.data()); + for (size_t i = 0; i < last_args.size(); ++i) { + setter(i + args.size(), last_args[i]); + } + func.CallPacked(TVMArgs(values.data(), tcodes.data(), values.size()), rv); + }); +} + +//----------------------------------------------------------- +// Utility functions. +//----------------------------------------------------------- +// Use the args after `starting_arg_idx` as a series of indices into `obj`, +// indexing into nested ADTs and returning the final indexed object. +ObjectRef IndexIntoNestedObject(ObjectRef obj, TVMArgs args, int starting_arg_idx) { + for (int i = starting_arg_idx; i < args.size(); i++) { + // the object must be an ADT to be able to index into it + if (!obj.as()) { + LOG(FATAL) << "ValueError: Attempted to index into an object that is not an ADT."; + } + int index = args[i]; + auto adt = Downcast(obj); + // make sure the index is in bounds + if (index >= static_cast(adt.size())) { + LOG(FATAL) << "IndexError: Invalid index (" << index << " >= " << adt.size() << ")."; + } + obj = adt[index]; + } + return obj; +} + +NDArray ConvertNDArrayToDevice(NDArray src, const DLDevice& dev) { + if (src->device.device_type == dev.device_type && src->device.device_id == dev.device_id) { + return src; + } + return src.CopyTo(dev); +} + +ObjectRef ConvertObjectToDevice(ObjectRef src, const Device& dev) { + if (src->IsInstance()) { + return ConvertNDArrayToDevice(Downcast(src), dev); + } else if (src->IsInstance()) { + std::vector ret; + ADT adt = Downcast(src); + for (size_t i = 0; i < adt.size(); i++) { + ret.push_back(ConvertObjectToDevice(adt[i], dev)); + } + return ADT(adt->tag, ret.begin(), ret.end()); + } else { + return src; + } +} + +TVMRetValue ConvertArgToDevice(TVMArgValue input, Device dev) { + // NOTE: NDArray::FromExternalDLTensor is not safe + // in terms of memory-behavior. + // To be extra careful, we copy DLTensor. + // The developer can still explicitly allocate NDArray + // in TVM Native API or NDArray::FromDLPack to regain zero copy behavior. + TVMRetValue ret; + + if (input.type_code() == kTVMDLTensorHandle) { + ret = NDArray::NewFromDLTensor(input, dev); + } else if (input.IsObjectRef()) { + ret = ConvertObjectToDevice(input.operator ObjectRef(), dev); + } else { + ret = input; + } + return ret; +} + +TVMRetValue ConvertRegToDevice(TVMRetValue input, Device dev) { + TVMRetValue ret; + if (input.IsObjectRef()) { + ret = ConvertObjectToDevice(input.operator ObjectRef(), dev); + } else { + ret = input; + } + return ret; +} + +//----------------------------------------------------------- +// VM implementations. +//----------------------------------------------------------- +/*! + * \brief The register type. + */ +using RegType = TVMRetValue; + +/*! + * \brief A representation of a stack frame. + * + * A stack frame is a record containing the information needed + * to restore the caller's virtual machine state after returning + * from a function call. + */ +struct VMFrame { + /*! \brief The return program counter. */ + Index return_pc; + /*! \brief Statically allocated space for objects */ + std::vector register_file; + /*! \brief Register in caller's frame to put return value */ + RegName caller_return_register; + // The following fields are used for PackedFunc call within + // a single function scope. The space is reused across multiple + // packed func calls to increase cache locality and avoid re-allocation + /*! \brief Temporary argument value stack for packed func call. */ + std::vector call_arg_values; + /*! \brief Temporary argument tcode stack for packed func call. */ + std::vector call_arg_tcodes; + + VMFrame(Index pc, Index register_file_size) + : return_pc(pc), register_file(register_file_size), caller_return_register(0) {} +}; + +class VirtualMachineImpl : public VirtualMachine { + public: + //--------------------------------------------------- + // Public facing functions overloading + //--------------------------------------------------- + void LoadExecutable(ObjectPtr exec) final; + + void Init(const std::vector& devices, + const std::vector& alloc_types) final; + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + + VMClosure GetClosure(const String& func_name) final; + + void InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs args, + TVMRetValue* rv) final; + + //-------------------------------------------------- + // Additional support arguments functions for VM + //-------------------------------------------------- + /*! + * \brief Set inputs to a function. + * \param func_name The function name. + * \param args args[offset:] are arguments to the function. If the arguments are not of the + * correct device for the function, they will be copied to the device. + * \param offset Starting offset of the arguments in \p args. + * \note This interface works when using VM over RPC by internally converting NDArray in + * the arguments to DLTensor, which is supported in RPC where remote could only have a minimal C + * runtime. + */ + void SetInput(std::string func_name, TVMArgs args, int offset); + + /*! + * \brief Look up whether the VM has a function by the given name. + * \param func_name the function's name + * \return The function, if it exists. Logs a fatal error if not. + */ + VMFuncInfo LookupVMFuncInfo(const std::string& func_name); + + /*! + * \brief Look up whether the VM has outputs for the given function. + * \param func_name the function's name + * \return The output, if it exists. Logs a fatal error if not. + */ + RegType LookupVMOutput(const std::string& func_name); + + /*! + * \brief Fully bind the argument of a global function and save it in the env. + * \param func_name The global function name to be saved. + * \param save_name The saved name of the function. + * \param include_return Whether forward the return value, set it to false allows + * us to ignore forwarding return value, which can be helpful to do benchmarking + * in RPC environment when return value is complicated ADT. + * + * \param args The arguments to bound to the function. + * \note This function is used by RPC server to help benchmarking. + */ + void SaveClosure(const String& func_name, const String& save_name, bool include_return, + TVMArgs args); + /*! + * \brief Internal function to invoke a closure. + * \param closure_or_packed The closure to be invoked. + * \param args The arguments to the function. + * \return The result value. + */ + RegType InvokeClosureInternal(const ObjectRef& closure_or_packed, + const std::vector& args); + /*! + * \brief Invoke a VM function by interpreting bytecode. + * \param fidx The function index. + * \param args The arguments to the function. + * \return The object representing the result. + */ + RegType InvokeBytecode(Index fidx, const std::vector& args); + + protected: + /*! + * \brief Get function by querying all of the current module's imports. + * \param name The name of the function. + * \return The result function, can return PackedFunc(nullptr) if nothing is found. + */ + PackedFunc GetFuncFromImports(const String& name) { + for (auto& lib : this->imports_) { + PackedFunc func = lib->GetFunction(name, true); + if (func.defined()) return func; + } + return PackedFunc(nullptr); + } + /*! + * \brief Initialize function pool. + */ + void InitFuncPool(); + //------------------------------------------------- + // Instruction interpretations. + //------------------------------------------------- + /*! + * \brief Push a call frame onto the call stack. + * \param ret_pc The program counter to return to. + * \param vm_func The function to be pushed to the call stack. + */ + void PushFrame(Index ret_pc, const VMFuncInfo& vm_func) { + frames_.emplace_back(std::make_unique(ret_pc, vm_func.register_file_size)); + } + /*! + * \brief Pop a frame off the call stack. + */ + void PopFrame() { + ICHECK_GT(frames_.size(), 0); + pc_ = frames_.back()->return_pc; + frames_.pop_back(); + } + /*! + * \brief Write to a VM register. + * \param frame current vm frame. + * \param reg The register to write to. + * \param obj The object to write to. + */ + void WriteRegister(VMFrame* frame, RegName reg, const RegType& obj) { + ICHECK_LT(reg, frame->register_file.size()); + frame->register_file[reg] = obj; + } + /*! + * \brief Read a VM register. + * \param frame current vm frame. + * \param reg The register to read from. + * \return The value of the register. + */ + RegType ReadRegister(VMFrame* frame, RegName reg) { + if (reg < Instruction::kBeginSpecialReg) { + return frame->register_file[reg]; + } + RegType ret; + if (reg == Instruction::kVoidRegister) { + ret = nullptr; + } else { + ICHECK_EQ(reg, Instruction::kVMRegister); + // per convention, ctx ptr must be VirtualMachine* casted to void. + // this and VirtualMachine* may or maynot be the same + // do first cast to VirtualMachine* then to void* + ret = static_cast(static_cast(this)); + } + return ret; + } + /*! + * \brief Run call instruction. + * \param curr_frame The current frame. + * \param inst The call instruction. + */ + inline void RunInstrCall(VMFrame* curr_frame, Instruction inst); + + /*! \brief Run VM dispatch loop. */ + void RunLoop(); + + private: + //-------------------------------------------------------- + // Internal states for execution. + //-------------------------------------------------------- + /*! \brief The loaded executable. */ + ObjectPtr exec_; + /*! \brief The global constant pool */ + std::vector const_pool_; + /*! + * \brief Function pool to cache functions in func_table + */ + std::vector func_pool_; + //-------------------------------------------------------- + // Executor interface support + //-------------------------------------------------------- + /*! \brief The function name to input register mapping. */ + std::unordered_map> inputs_; + /*! \brief The function name to output register. */ + std::unordered_map outputs_; + /*! \brief A store of closures created by `save_function`. */ + std::unordered_map saved_closures_; + //------------------------------------------------------------ + // VM Instruction execution. + //------------------------------------------------------------ + /*! + * \brief The current stack of call frames. + * \note: Use unique ptr to avoid re-allocation and copy when frames_ get resized. + */ + std::vector> frames_; + /*! \brief The virtual machine PC. */ + Index pc_{0}; + /*! \brief The special return register. */ + RegType return_value_; +}; + +void VirtualMachineImpl::LoadExecutable(ObjectPtr exec) { + this->exec_ = exec; + this->imports_ = exec_->imports(); +} + +void VirtualMachineImpl::Init(const std::vector& devices, + const std::vector& alloc_types) { + // TODO(@yuchen): support multi-device heterogeneous execution + ICHECK_LT(devices.size(), 3) + << "Currently relax vm only supports at most 2 devices (host + device)"; + ICHECK_EQ(devices.size(), alloc_types.size()); + + this->devices.reserve(devices.size()); + this->allocators.reserve(alloc_types.size()); + for (size_t i = 0; i < devices.size(); i++) { + auto alloc = MemoryManager::GetOrCreateAllocator(devices[i], alloc_types[i]); + this->devices.push_back(devices[i]); + this->allocators.push_back(alloc); + } + // Setup constant sections. + this->const_pool_.reserve(exec_->constants.size()); + for (const auto& constant : exec_->constants) { + if (constant.type_code() != kTVMNDArrayHandle) { + this->const_pool_.push_back(constant); + } else { + this->const_pool_.push_back(ConvertRegToDevice(constant, devices[0])); + } + } + // Setup function sections. + this->InitFuncPool(); +} + +VMFuncInfo VirtualMachineImpl::LookupVMFuncInfo(const std::string& func_name) { + ICHECK(exec_) << "The executable is not created yet."; + auto it = this->exec_->func_map.find(func_name); + CHECK(it != this->exec_->func_map.end()) << "ValueError: Unknown function: " << func_name; + + return exec_->func_table[it->second]; +} + +RegType VirtualMachineImpl::LookupVMOutput(const std::string& func_name) { + if (!outputs_.count(func_name)) { + LOG(FATAL) << "ValueError: No output saved for call of \"" << func_name + << "\"; use `invoke_stateful` to call it first."; + } + return outputs_[func_name]; +} + +PackedFunc VirtualMachineImpl::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + if (name == "vm_initialization") { + // initialize the VirtualMachine, takes variable-length arguments + // first argument is a runtime::Module, followed by one or more device_type, device_id, + // and the AllocatorType associated with the device. + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ICHECK_EQ(args.size() % 3, 0); + std::vector devices; + std::vector alloc_types; + for (int i = 0; i < args.size(); i += 3) { + Device dev; + int device_type = args[i]; + dev.device_type = DLDeviceType(device_type); + dev.device_id = args[i + 1]; + int type = args[i + 2]; + devices.push_back(dev); + alloc_types.push_back(AllocatorType(type)); + } + this->Init(devices, alloc_types); + }); + } else if (name == "save_function") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ICHECK_GE(args.size(), 3); + this->SaveClosure(args[0], args[1], args[2], + TVMArgs(args.values + 3, args.type_codes + 3, args.size() - 3)); + }); + } else if (name == "invoke_closure") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + VMClosure clo = args[0]; + this->InvokeClosurePacked(clo, TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1), + rv); + }); + } else if (name == "invoke_stateful") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string func_name = args[0]; + const auto& m = this->exec_->func_map; + if (m.find(func_name) == m.end()) { + LOG(FATAL) << "ValueError: Unknown function: " << func_name; + } + Index gf_idx = m.at(func_name); + if (!inputs_.count(func_name)) { + LOG(FATAL) << "ValueError: No inputs set for stateful call of " << func_name + << "; use `set_input` first."; + return; + } + outputs_[func_name] = this->InvokeClosureInternal(func_pool_[gf_idx], inputs_[func_name]); + }); + } else if (name == "get_output_arity") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string func_name = args[0]; + RegType out = LookupVMOutput(func_name); + // use remaining args as indices + ObjectRef obj = IndexIntoNestedObject(out.AsObjectRef(), args, 1); + // after chasing through the indices, examine the final object + if (const auto* adt = obj.as()) { + *rv = static_cast(adt->size); + } else { + *rv = -1; + } + }); + } else if (name == "get_output") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string func_name = args[0]; + RegType out = LookupVMOutput(func_name); + // use remaining args as indices + ObjectRef obj = IndexIntoNestedObject(out.AsObjectRef(), args, 1); + if (obj.as()) { + LOG(FATAL) << "ValueError: `get_output` cannot return a tuple for RPC compatibility. " + "Please specify another index argument."; + return; + } + *rv = obj; + }); + } else if (name == "set_input") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetInput(args[0], args, 1); }); + } else if (name == "get_function_arity") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string func_name = args[0]; + const VMFuncInfo& vm_func = LookupVMFuncInfo(func_name); + *rv = static_cast(vm_func.param_names.size()); + }); + } else if (name == "get_function_param_name") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string func_name = args[0]; + int index = args[1]; + const VMFuncInfo& vm_func = LookupVMFuncInfo(func_name); + if (static_cast(index) >= vm_func.param_names.size()) { + LOG(FATAL) << "ValueError: Invalid index for " << func_name << " (" << index << " out of " + << vm_func.param_names.size() << ")"; + } + *rv = vm_func.param_names[index]; + }); + } else { + // default case, look up closure in VM. + VMClosure clo = this->GetClosure(name); + return PackedFunc([sptr_to_self, this, clo](TVMArgs args, TVMRetValue* rv) { + this->InvokeClosurePacked(clo, args, rv); + }); + } +} + +void VirtualMachineImpl::SetInput(std::string func_name, TVMArgs args, int offset) { + const auto& m = exec_->func_map; + if (m.find(func_name) != m.end()) { + Index gf_idx = m.at(func_name); + const VMFuncInfo& vm_func = exec_->func_table[gf_idx]; + size_t params_num = vm_func.num_args; + ICHECK_EQ(args.size() - offset, params_num) + << "The number of provided parameters doesn't match the number of arguments for"; + std::vector func_args(params_num); + for (int i = offset; i < args.size(); ++i) { + int index = i - offset; + func_args[index] = ConvertArgToDevice(args[i], devices[0]); + } + inputs_.emplace(func_name, func_args); + } else { + LOG(FATAL) << "ValueError: Unknown function: " << func_name; + } +} + +//------------------------------------------ +// Closure handling +//------------------------------------------ +void VirtualMachineImpl::InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, TVMArgs args, + TVMRetValue* rv) { + // run packed call if it is a packed func. + if (auto* packed = closure_or_packedfunc.as()) { + packed->CallPacked(args, rv); + return; + } + // run closure call. + auto* clo = closure_or_packedfunc.as(); + ICHECK(clo != nullptr) << "Function expects a closure or PackedFunc "; + + std::vector values(args.size() + 1); + std::vector tcodes(args.size() + 1); + runtime::TVMArgsSetter setter(values.data(), tcodes.data()); + // per convention, ctx ptr must be VirtualMachine* casted to void. + // this and VirtualMachine* may or maynot be the same + // do first cast to VirtualMachine* then to void* + setter(0, static_cast(static_cast(this))); + std::copy(args.values, args.values + args.size(), values.begin() + 1); + std::copy(args.type_codes, args.type_codes + args.size(), tcodes.begin() + 1); + clo->impl.CallPacked(TVMArgs(values.data(), tcodes.data(), args.size() + 1), rv); +} + +// internal variant version of invoke closurepacked +RegType VirtualMachineImpl::InvokeClosureInternal(const ObjectRef& closure_or_packed, + const std::vector& args) { + RegType ret; + auto* packed = closure_or_packed.as(); + auto* clo = closure_or_packed.as(); + int clo_offset = clo != nullptr ? 1 : 0; + std::vector values(args.size() + clo_offset); + std::vector tcodes(args.size() + clo_offset); + runtime::TVMArgsSetter setter(values.data(), tcodes.data()); + + if (clo != nullptr) { + setter(0, static_cast(static_cast(this))); + } + for (size_t i = 0; i < args.size(); ++i) { + setter(i + clo_offset, args[i]); + } + + if (packed != nullptr) { + packed->CallPacked(TVMArgs(values.data(), tcodes.data(), values.size()), &ret); + } else { + ICHECK(clo != nullptr); + clo->impl.CallPacked(TVMArgs(values.data(), tcodes.data(), values.size()), &ret); + } + return ret; +} + +void VirtualMachineImpl::SaveClosure(const String& func_name, const String& save_name, + bool include_return, TVMArgs args) { + VMClosure clo = this->GetClosure(func_name); + std::vector inputs(args.size()); + for (int i = 0; i < args.size(); ++i) { + inputs[i] = ConvertArgToDevice(args[i], this->devices[0]); + } + PackedFunc impl = VMClosure::BindLastArgs(clo->impl, inputs); + if (!include_return) { + impl = PackedFunc([impl](TVMArgs args, TVMRetValue* rv) { + TVMRetValue temp; + impl.CallPacked(args, &temp); + }); + } + saved_closures_[save_name] = VMClosure(save_name, impl); +} + +VMClosure VirtualMachineImpl::GetClosure(const String& func_name) { + // look up saved closures. + auto saved_it = saved_closures_.find(func_name); + if (saved_it != saved_closures_.end()) { + return saved_it->second; + } + auto it = exec_->func_map.find(func_name); + CHECK(it != exec_->func_map.end()) << "ValueError: Unknown function: " << func_name; + + Index gf_idx = it->second; + const VMFuncInfo& finfo = exec_->func_table[gf_idx]; + + if (finfo.kind == VMFuncInfo::FuncKind::kVMFunc) { + // NOTE: should not capture strong ref to self and avoid cyclic ref. + auto impl = PackedFunc([gf_idx](TVMArgs args, TVMRetValue* rv) { + // Per convention, ctx ptr is a VirtualMachine* + VirtualMachine* ctx_ptr = static_cast(args[0].operator void*()); + + std::vector inputs(args.size() - 1); + for (size_t i = 0; i < inputs.size(); ++i) { + inputs[i] = args[i + 1]; + } + *rv = static_cast(ctx_ptr)->InvokeBytecode(gf_idx, inputs); + }); + return VMClosure(func_name, impl); + } else { + ICHECK(finfo.kind == VMFuncInfo::FuncKind::kVMTIRFunc) + << "Cannot support closure with function kind " << static_cast(finfo.kind); + PackedFunc tir_func = GetFuncFromImports("__vmtir__" + finfo.name); + ICHECK(tir_func != nullptr) << "Cannot find underlying compiled tir function of VMTIRFunc " + << finfo.name; + auto impl = PackedFunc([this, finfo, tir_func](TVMArgs args, TVMRetValue* rv) { + // Per convention, ctx ptr is a VirtualMachine* + VirtualMachine* ctx_ptr = static_cast(args[0].operator void*()); + ICHECK(ctx_ptr == this); + ICHECK_EQ(args.size() - 1, finfo.num_args) + << "Function " << finfo.name << " expects " << finfo.num_args << " arguments"; + ICHECK_GE(finfo.register_file_size, finfo.num_args + 1); + std::vector reg_file(finfo.register_file_size); + for (int64_t i = 0; i < finfo.num_args; ++i) { + reg_file[i] = args[i + 1]; + } + void* reg_anylist_handle = reg_file.data(); + void* const_anylist_handle = this->const_pool_.data(); + void* func_anylist_handle = this->func_pool_.data(); + tir_func(static_cast(ctx_ptr), reg_anylist_handle, const_anylist_handle, + func_anylist_handle); + // Return value always stored after inputs. + *rv = reg_file[finfo.num_args]; + }); + return VMClosure(func_name, impl); + } +} + +//-------------------------------------------------------------------- +// Instruction interpretations. +//-------------------------------------------------------------------- +RegType VirtualMachineImpl::InvokeBytecode(Index gf_idx, const std::vector& args) { + const VMFuncInfo& gfunc = exec_->func_table[gf_idx]; + ICHECK(gfunc.kind == VMFuncInfo::FuncKind::kVMFunc); + + // Get the curr instr which might be a potential caller. + Instruction curr_instr = exec_->GetInstruction(pc_); + PushFrame(this->pc_, gfunc); + // Get new frame and set the caller info. + VMFrame* curr_frame = frames_.back().get(); + if (curr_instr.op == Opcode::Call) { + curr_frame->caller_return_register = curr_instr.dst; + } + + // load arguments to the register file + ICHECK_EQ(static_cast(gfunc.num_args), args.size()) + << "ValueError: Invoking function " << gfunc.name << " requires " << gfunc.num_args + << " inputs but only " << args.size() << " inputs are provided."; + for (size_t i = 0; i < args.size(); ++i) { + WriteRegister(frames_.back().get(), i, args[i]); + } + // set program counter + pc_ = gfunc.start_instr; + RunLoop(); + return return_value_; +} + +void VirtualMachineImpl::InitFuncPool() { + func_pool_.resize(exec_->func_table.size()); + + for (size_t func_index = 0; func_index < exec_->func_table.size(); ++func_index) { + const VMFuncInfo& info = exec_->func_table[func_index]; + if (info.kind == VMFuncInfo::FuncKind::kPackedFunc) { + // only look through imports first + PackedFunc func = GetFuncFromImports(info.name); + if (!func.defined()) { + const PackedFunc* p_func = Registry::Get(info.name); + if (p_func != nullptr) func = *(p_func); + } + ICHECK(func.defined()) + << "Error: Cannot find PackedFunc " << info.name + << " in either Relax VM kernel library, or in TVM runtime PackedFunc registry, or in " + "global Relax functions of the VM executable"; + func_pool_[func_index] = func; + + } else { + ICHECK(info.kind == VMFuncInfo::FuncKind::kVMFunc || + info.kind == VMFuncInfo::FuncKind::kVMTIRFunc); + auto clo = this->GetClosure(info.name); + func_pool_[func_index] = clo; + } + } +} + +void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) { + DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << exec_->func_table[instr.func_idx].name; + + // Use the call arg stack from the current frame to increase reuse + // and avoid re-allocation + curr_frame->call_arg_values.resize(instr.num_args); + curr_frame->call_arg_tcodes.resize(instr.num_args); + + // NOTE: no changes and resize to those vector ref(otherwise can leads to segfault) + // in the remainder part of the function. + std::vector& values = curr_frame->call_arg_values; + std::vector& tcodes = curr_frame->call_arg_tcodes; + + runtime::TVMArgsSetter setter(values.data(), tcodes.data()); + for (Index i = 0; i < instr.num_args; ++i) { + Instruction::Arg arg = instr.args[i]; + switch (arg.kind()) { + case Instruction::ArgKind::kRegister: { + setter(i, ReadRegister(curr_frame, arg.value())); + break; + } + case Instruction::ArgKind::kImmediate: { + setter(i, arg.value()); + break; + } + case Instruction::ArgKind::kConstIdx: { + setter(i, this->const_pool_[arg.value()]); + break; + } + case Instruction::ArgKind::kFuncIdx: { + ICHECK_LT(static_cast(arg.value()), this->func_pool_.size()); + setter(i, this->func_pool_[arg.value()]); + break; + } + default: { + LOG(FATAL) << "ValueError: Unknown argument kind: " << int(arg.kind()); + } + } + } + TVMArgs args(values.data(), tcodes.data(), values.size()); + TVMRetValue ret; + + ICHECK_LT(static_cast(instr.func_idx), this->func_pool_.size()); + this->InvokeClosurePacked(func_pool_[instr.func_idx], args, &ret); + // save the return value to the register + // saving to special register is a NOP + if (instr.dst < Instruction::kBeginSpecialReg) { + WriteRegister(curr_frame, instr.dst, ret); + } + // increment pc + pc_++; +} + +void VirtualMachineImpl::RunLoop() { + VMFrame* curr_frame = frames_.back().get(); + + while (true) { + ICHECK_LT(static_cast(pc_), exec_->instr_offset.size()) << "run into invalide section"; + Instruction instr = exec_->GetInstruction(pc_); + switch (instr.op) { + case Opcode::Call: { + this->RunInstrCall(curr_frame, instr); + break; + } + case Opcode::Ret: { + // If we have hit the point from which we started + // running, we should return to the caller breaking + // the dispatch loop. + return_value_ = ReadRegister(curr_frame, instr.result); + RegName caller_return_register = curr_frame->caller_return_register; + PopFrame(); + if (frames_.size() == 0) { + // directly return if no frame in the call stack. + } else { + // return from a local call. + // Update the current frame to be the parent frame. + curr_frame = frames_.back().get(); + WriteRegister(curr_frame, caller_return_register, return_value_); + } + return; + } + case Opcode::Goto: { + pc_ += instr.pc_offset; + break; + } + case Opcode::If: { + int64_t cond_val = ReadRegister(curr_frame, instr.cond); + if (cond_val != 0) { + pc_++; + } else { + ICHECK_GT(instr.false_offset, 1); + pc_ += instr.false_offset; + } + break; + } + } + } +} + +ObjectPtr VirtualMachine::Create() { return make_object(); } + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/tests/python/relax/test_vm_execbuilder.py b/tests/python/relax/test_vm_execbuilder.py new file mode 100644 index 000000000000..9a7cd0c87938 --- /dev/null +++ b/tests/python/relax/test_vm_execbuilder.py @@ -0,0 +1,262 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Lowest level testing VM. Test execbuilder and execution.""" +import tvm +import pytest +import numpy as np +from tvm import relax, TVMError +from tvm.relax.testing.vm import check_saved_func + + +def test_vm_execute(): + ib = relax.ExecBuilder() + with ib.function("func0", num_inputs=2): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + a = tvm.nd.array( + np.random.rand( + 4, + ) + ) + b = tvm.nd.array( + np.random.rand( + 4, + ) + ) + + add_res = check_saved_func(vm, "func0", a, b) + tvm.testing.assert_allclose(add_res.numpy(), a.numpy() + b.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_multiple_func(): + ib = relax.ExecBuilder() + with ib.function("func0", num_inputs=2): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + with ib.function("func1", num_inputs=2): + ib.emit_call("test.vm.mul", args=[ib.r(0), ib.r(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + a = tvm.nd.array( + np.random.rand( + 4, + ) + ) + b = tvm.nd.array( + np.random.rand( + 4, + ) + ) + mul_res = check_saved_func(vm, "func1", a, b) + add_res = check_saved_func(vm, "func0", a, b) + tvm.testing.assert_allclose(add_res.numpy(), a.numpy() + b.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(mul_res.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_checker(): + ib = relax.ExecBuilder() + with pytest.raises(TVMError): + with ib.function("func0", num_inputs=2): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(2)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + ib.get() + + +def test_neg_imm(): + ib = relax.ExecBuilder() + + with ib.function("func0", num_inputs=1): + ib.emit_call("test.vm.add_scalar", args=[ib.imm(-3), ib.r(0)], dst=ib.r(1)) + ib.emit_ret(ib.r(1)) + ib.get() + + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + assert vm["func0"](1) == -2 + assert vm["func0"](-3) == -6 + + +def test_emit_cache(): + ib = relax.ExecBuilder() + + with ib.function("func0", num_inputs=1): + x0 = ib.convert_constant("str0") + x1 = ib.convert_constant("str0") + # cache constant str + assert x0 == x1 + s0 = ib.convert_constant(tvm.runtime.container.ShapeTuple([1, 2])) + s1 = ib.convert_constant(tvm.runtime.container.ShapeTuple([1, 2])) + s2 = ib.convert_constant(tvm.runtime.container.ShapeTuple([1, 3])) + assert s0 == s1 + assert s1 != s2 + y0 = ib.convert_constant(tvm.nd.array(np.array([1, 2, 3]).astype("int32"))) + y1 = ib.convert_constant(tvm.nd.array(np.array([1, 2, 3]).astype("int32"))) + assert y0 == y1 + ib.emit_ret(ib.r(0)) + + +def test_vm_formalize(): + ib0 = relax.ExecBuilder() + ib1 = relax.ExecBuilder() + with ib0.function("func0", num_inputs=2): + ib0.emit_call("test.vm.add", args=[ib0.r(0), ib0.r(1)], dst=ib0.r(100)) + ib0.emit_call("test.vm.mul", args=[ib0.r(1), ib0.r(100)], dst=ib0.r(50)) + ib0.emit_ret(ib0.r(50)) + with ib1.function("func0", num_inputs=2): + ib1.emit_call("test.vm.add", args=[ib1.r(0), ib1.r(1)], dst=ib1.r(2)) + ib1.emit_call("test.vm.mul", args=[ib1.r(1), ib1.r(2)], dst=ib1.r(3)) + ib1.emit_ret(ib1.r(3)) + exec0 = ib0.get() + exec1 = ib1.get() + assert exec0.as_text() == exec1.as_text() + + +def test_vm_operand(): + ib0 = relax.ExecBuilder() + with ib0.function("func0", num_inputs=2): + ib0.emit_call("test.vm.add_scalar", args=[ib0.r(0), ib0.r(1)], dst=ib0.r(2)) + ib0.emit_ret(ib0.r(2)) + exec0 = ib0.get() + vm = relax.VirtualMachine(exec0, tvm.cpu()) + res = vm["func0"](2, 3) + assert res == 5 + + ib1 = relax.ExecBuilder() + with ib1.function("func1", num_inputs=1): + ib1.emit_call("test.vm.get_device_id", args=[ib1.r(0)], dst=ib1.r(1)) + ib1.emit_ret(ib1.r(1)) + exec1 = ib1.get() + vm = relax.VirtualMachine(exec1, tvm.cpu()) + res = vm["func1"](tvm.cpu(3)) + assert res == 3 + + +def test_vm_shapeof(): + ib = relax.ExecBuilder() + shape = (32, 16) + arr = tvm.nd.array(np.random.rand(*shape)) + with ib.function("main", num_inputs=0): + ib.emit_call("vm.builtin.shape_of", args=[arr], dst=ib.r(0)) + ib.emit_ret(ib.r(0)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"]() + for i, s in enumerate(res): + assert s == shape[i] + + +def test_vm_storage(): + dtype = tvm.DataType("float32") + shape = (4, 6) + ib = relax.ExecBuilder() + with ib.function("main", num_inputs=0): + ib.emit_call( + "vm.builtin.alloc_storage", + args=[ib.vm_state(), (24,), ib.convert_constant(0), dtype], + dst=ib.r(1), + ) + ib.emit_call( + "vm.builtin.alloc_tensor", args=[ib.r(1), ib.imm(0), shape, dtype], dst=ib.r(2) + ) + ib.emit_ret(ib.r(2)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"]() + assert res.device == tvm.cpu() + assert res.shape == shape + + +def test_vm_goto(): + ib = relax.ExecBuilder() + with ib.function("main", num_inputs=2): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2)) + ib.emit_goto(2) + ib.emit_call("test.vm.mul", args=[ib.r(2), ib.r(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + a = tvm.nd.array( + np.random.rand( + 4, + ) + ) + b = tvm.nd.array( + np.random.rand( + 4, + ) + ) + res = check_saved_func(vm, "main", a, b) + tvm.testing.assert_allclose(res.numpy(), a.numpy() + b.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_if(): + ib = relax.ExecBuilder() + with ib.function("main", num_inputs=3): + ib.emit_if(ib.r(0), 3) + ib.emit_call("test.vm.add", args=[ib.r(1), ib.r(2)], dst=ib.r(3)) + ib.emit_goto(2) + ib.emit_call("test.vm.mul", args=[ib.r(1), ib.r(2)], dst=ib.r(3)) + ib.emit_ret(ib.r(3)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + a = tvm.nd.array( + np.random.rand( + 4, + ) + ) + b = tvm.nd.array( + np.random.rand( + 4, + ) + ) + res = vm["main"](0, a, b) + tvm.testing.assert_allclose(res.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7) + res = vm["main"](1, a, b) + tvm.testing.assert_allclose(res.numpy(), a.numpy() + b.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_invoke_closure(): + ib = relax.ExecBuilder() + with ib.function("lifted_func_1", num_inputs=4): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(4)) + ib.emit_call("test.vm.add", args=[ib.r(2), ib.r(4)], dst=ib.r(5)) + ib.emit_call("test.vm.add", args=[ib.r(3), ib.r(5)], dst=ib.r(6)) + ib.emit_ret(ib.r(6)) + with ib.function("main", num_inputs=2): + ib.emit_call( + "vm.builtin.make_closure", args=[ib.f("lifted_func_1"), ib.r(0), ib.r(1)], dst=ib.r(2) + ) + ib.emit_ret(ib.r(2)) + + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + w_inp = tvm.nd.array(np.random.rand(2, 3)) + x_inp = tvm.nd.array(np.random.rand(2, 3)) + y_inp = tvm.nd.array([[3.1, 4.0, 5.0], [6.0, 7.1, 9.0]]) + z_inp = tvm.nd.array(np.random.rand(2, 3)) + clo = vm["main"](w_inp, x_inp) + res = vm.invoke_closure(clo, y_inp, z_inp) + tvm.testing.assert_allclose( + res.numpy(), w_inp.numpy() + x_inp.numpy() + y_inp.numpy() + z_inp.numpy() + ) + + +if __name__ == "__main__": + tvm.testing.main() From 2f96da7e7a2730c986f1584b4c9a2801f8a65783 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Thu, 2 Feb 2023 09:35:41 -0500 Subject: [PATCH 02/81] [Unity] Relax expressions and types (#13901) --- include/tvm/ir/expr.h | 8 + include/tvm/relax/expr.h | 1003 ++++++++++++++++++++++++++++++++++++++ include/tvm/relax/type.h | 166 +++++++ 3 files changed, 1177 insertions(+) create mode 100644 include/tvm/relax/expr.h create mode 100644 include/tvm/relax/type.h diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index c8531c88465a..d4ba628d36cf 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -367,6 +367,14 @@ class RelayExprNode : public BaseExprNode { * This value is discarded during serialization. */ mutable Type checked_type_ = Type(nullptr); + + /*! + * \brief Stores the result of structure information of the + * expression that encapsulate both static shape and + * runtime information such as shape. + */ + mutable Optional struct_info_ = Optional(); + /*! * \return The checked_type */ diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h new file mode 100644 index 000000000000..8154b1dd86de --- /dev/null +++ b/include/tvm/relax/expr.h @@ -0,0 +1,1003 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RELAX_EXPR_H_ +#define TVM_RELAX_EXPR_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +using Expr = RelayExpr; +using ExprNode = RelayExprNode; +using relay::Id; + +/*! + * \brief Base type of all structure information. + * + * StructInfo stores possible structure information + * deduced during compile-time. It encapsulates + * both static type and runtime information such + * as shape. + * + * StructInfo of each non-primitive Expr can be + * deduced during compilation in a "best-effort" manner. + * + * When struct_info appears in function parameter and return + * signatures. They will imply a runtime check that matches + * the structure information with the value. + * + * When it appears in Expr, they follow "assume-semantics", + * which means the compiler will take the deduced information as it is + * and only do best effort prove and checks. + * + * Each struct info can be uniquely erased to a static-type. + * The compiler will still compile the code(with less information) + * when we erase to the static type. + * + * If an StructInfo contains an Expr field, then that field + * must be normalized already through NormalizeArg. + * This invariant will be checked in constructors + * and help us to simplify our assumption + * during struct info deduction. + */ +class StructInfoNode : public Object { + public: + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; + + static constexpr const char* _type_key = "StructInfo"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const uint32_t _type_child_slots = 5; + TVM_DECLARE_BASE_OBJECT_INFO(StructInfoNode, Object); +}; + +/*! + * \brief Managed reference to StructInfoNode. + * \sa StructInfoNode + */ +class StructInfo : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(StructInfo, ObjectRef, StructInfoNode); +}; + +/*! + * \brief Call corresponds to callable invocation. + * Corresponds to operation in computational graph terminology. + */ +class CallNode : public ExprNode { + public: + /*! + * \brief The operator(function) being invoked + * + * - It can be tvm::Op which corresponds to the primitive operators. + * - It can also be user defined functions (Function, GlobalVar, Var). + */ + Expr op; + + /*! \brief The arguments(inputs) of the call */ + tvm::Array args; + + /*! \brief The additional attributes */ + Attrs attrs; + + /*! + * \brief The structure info arguments of a CallNode. + * sinfo_args is designed to be non-empty only for intrinsic op (e.g., + * call_tir, call_builtin_with_ctx, etc.) and calls to ExternFuncs, with the main + * usage of structure info inference. + */ + Array sinfo_args; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("op", &op); + v->Visit("args", &args); + v->Visit("attrs", &attrs); + v->Visit("sinfo_args", &sinfo_args); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { + // skip sinfo_args check for primitive ops. + equal->MarkGraphNode(); + return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) && + (IsPrimitiveOp(op) || equal(sinfo_args, other->sinfo_args)) && + equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); + hash_reduce(op); + hash_reduce(args); + hash_reduce(attrs); + if (!IsPrimitiveOp(op)) { + hash_reduce(sinfo_args); + } + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.Call"; + TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode); +}; + +class Call : public Expr { + public: + /*! + * \brief The constructor + * \param op The operator to be invoked. + * \param args The arguments of the call. + * \param attrs The attributes of the call node. + * \param sinfo_args The structure info arguments passed to a function. + * \param span The source span of the expression. + */ + TVM_DLL Call(Expr op, Array args, Attrs attrs = Attrs(), + Array sinfo_args = Array(), Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(Call, Expr, CallNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode); +}; + +/*! + * \brief Returns \p call with the given properties. A null property denotes 'no change'. + * Returns \p call if all properties are unchanged. Otherwise, returns a copy with the new + * fields. + */ +Call WithFields(Call call, Optional opt_op = Optional(), + Optional> opt_args = Optional>(), + Optional opt_attrs = Optional(), + Optional> opt_sinfo_args = Optional>(), + Optional opt_span = Optional()); + +/*! + * \brief Condition expression + * + * Unlike traditional statement `if`s, the if evalutes + * to the result of the branch taken. + * + * x = if (true) { 1 } else { 0 }; // x is 1 + * y = if (false) { 1 } else { 0 }; // y is 0 + * + * \note This is similar to C's ternary operator. + */ +class IfNode : public ExprNode { + public: + /*! \brief The condition. */ + Expr cond; + /*! \brief The expression evaluated when condition is true. */ + Expr true_branch; + /*! \brief The expression evaluated when condition is false */ + Expr false_branch; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("cond", &cond); + v->Visit("true_branch", &true_branch); + v->Visit("false_branch", &false_branch); + v->Visit("_checked_type_", &checked_type_); + v->Visit("struct_info_", &struct_info_); + v->Visit("span", &span); + } + + bool SEqualReduce(const IfNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(cond, other->cond) && equal(true_branch, other->true_branch) && + equal(false_branch, other->false_branch) && equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); + hash_reduce(cond); + hash_reduce(true_branch); + hash_reduce(false_branch); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.If"; + TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode); +}; + +class If : public Expr { + public: + /*! + * \brief The constructor + * \param cond The condition of a if node. + * \param true_branch The fall through branch + * \param false_branch The branch for execution when condition is false. + * \param span The source span of the expression. + */ + TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode); +}; + +/*! + * \brief Returns \p if_expr with the given properties. A null property denotes 'no change'. + * Returns \p if_expr if all properties are unchanged. Otherwise, returns a copy with the new + * fields. + */ +If WithFields(If if_expr, Optional opt_cond = Optional(), + Optional opt_true_branch = Optional(), + Optional opt_false_branch = Optional(), + Optional opt_span = Optional()); + +/*! \brief Tuple container */ +class TupleNode : public ExprNode { + public: + /*! \brief the fields of the tuple */ + tvm::Array fields; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("fields", &fields); + v->Visit("_checked_type_", &checked_type_); + v->Visit("struct_info_", &struct_info_); + v->Visit("span", &span); + } + + bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from fields. + return equal(fields, other->fields); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); } + + static constexpr const char* _type_key = "relax.expr.Tuple"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode); +}; + +class Tuple : public Expr { + public: + /*! + * \brief The constructor + * \param fields The fields of a tuple. + * \param span The source span of the expression. + */ + TVM_DLL explicit Tuple(tvm::Array fields, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode); +}; + +/*! + * \brief Returns \p tuple with the given properties. A null property denotes 'no change'. + * Returns \p tuple if all properties are unchanged. Otherwise, returns a copy with the new + * fields. + */ +Tuple WithFields(Tuple tuple, Optional> opt_fields = Optional>(), + Optional opt_span = Optional()); + +/*! \brief Get index-th field out of a tuple. */ +class TupleGetItemNode : public ExprNode { + public: + /*! \brief The tuple Expression */ + Expr tuple; + /*! \brief which value to get */ + int index; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("tuple_value", &tuple); + v->Visit("index", &index); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const { + // struct info can be deterministically tuple and index. + return equal(tuple, other->tuple) && equal(index, other->index); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(tuple); + hash_reduce(index); + } + + static constexpr const char* _type_key = "relax.expr.TupleGetItem"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode); +}; + +class TupleGetItem : public Expr { + public: + /*! + * \brief The constructor + * \param tuple The tuple to get an element from. + * \param index The index for extracting a value in the tuple. + * \param span The source span of the expression. + */ + TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, Expr, TupleGetItemNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleGetItemNode); +}; + +/*! + * \brief Returns \p tuple_get_item with the given properties. A null property denotes 'no change'. + * Returns \p tuple_get_item if all properties are unchanged. Otherwise, returns a copy with the new + * fields. + */ +TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple = Optional(), + Optional opt_index = Optional(), + Optional opt_span = Optional()); + +/*! + * \brief Base type of all (non-function) leaf Exprs. + * \sa Expr + */ +class LeafExprNode : public ExprNode { + public: + static constexpr const char* _type_key = "relax.expr.LeafExpr"; + static constexpr const uint32_t _type_child_slots = 7; + TVM_DECLARE_BASE_OBJECT_INFO(LeafExprNode, ExprNode); +}; + +/*! + * \brief Managed reference to BaseExprNode. + * \sa LeafExprNode + */ +class LeafExpr : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(LeafExpr, Expr, LeafExprNode); +}; + +/*! \brief A shape expression which allows users to construct a shape containing PrimExpr. + */ +class ShapeExprNode : public LeafExprNode { + public: + /*! The values of the shape expression. */ + Array values; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("values", &values); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const ShapeExprNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from values. + return equal(values, other->values); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(values); } + + static constexpr const char* _type_key = "relax.expr.ShapeExpr"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapeExprNode, LeafExprNode); +}; + +class ShapeExpr : public LeafExpr { + public: + TVM_DLL explicit ShapeExpr(Array values, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(ShapeExpr, LeafExpr, ShapeExprNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ShapeExprNode); +}; + +/*! \brief The variable class for all Relax bindings. */ +class VarNode : public LeafExprNode { + public: + /*! \brief The identifier of the variable, which is used for comparing stable equality across + * transformations. */ + Id vid; + + /*! \return The name hint of the variable */ + const String& name_hint() const { return vid->name_hint; } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("vid", &vid); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(vid, other->vid) && equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(vid); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.Var"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const uint32_t _type_child_slots = 2; + TVM_DECLARE_BASE_OBJECT_INFO(VarNode, LeafExprNode); +}; + +class Var : public LeafExpr { + public: + TVM_DLL explicit Var(String name_hint, Optional struct_info_annotation, + Span span = Span()) + : Var(Id(name_hint), struct_info_annotation, span) {} + + TVM_DLL explicit Var(Id vid, Optional struct_info_annotation, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(Var, LeafExpr, VarNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode); +}; + +/*! \brief A sub-type of the variable node used to mark dataflow variables from + * normal visible "function local" bindings. + */ +class DataflowVarNode : public VarNode { + public: + void VisitAttrs(AttrVisitor* v) { + v->Visit("vid", &vid); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(vid, other->vid) && equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(vid); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.DataflowVar"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowVarNode, VarNode); +}; + +class DataflowVar : public Var { + public: + TVM_DLL explicit DataflowVar(String name_hint, Optional struct_info_annotation, + Span span = Span()) + : DataflowVar(Id(name_hint), struct_info_annotation, span) {} + + TVM_DLL explicit DataflowVar(Id vid, Optional struct_info_annotation, + Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowVarNode); +}; + +/*! + * \brief Constant tensor. + * + * \note Scalar constants are represented by ndim-0 constant tensors. + */ +class ConstantNode : public LeafExprNode { + public: + /*! \brief The data of the tensor */ + runtime::NDArray data; + + /*! \return The corresponding tensor type of the data */ + TensorType tensor_type() const; + + /*! \return Whether it is scalar(ndim-0 tensor) */ + bool is_scalar() const { return data->ndim == 0; } + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("data", &data); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from data. + return equal(data, other->data); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); } + + static constexpr const char* _type_key = "relax.expr.Constant"; + TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, LeafExprNode); +}; + +class Constant : public LeafExpr { + public: + /*! + * \brief The constructor + * \param data The data of the constant tensor. + * \param span The source span of the expression. + */ + TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(Constant, LeafExpr, ConstantNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode); +}; + +/*! + * \brief PrimValue. + * + * Expression representing a TIR POD expression. + */ +class PrimValueNode : public LeafExprNode { + public: + /*! \brief The prim expr representing the value */ + PrimExpr value; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("value", &value); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const PrimValueNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from data. + return equal(value, other->value); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } + + static constexpr const char* _type_key = "relax.expr.PrimValue"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimValueNode, LeafExprNode); +}; + +/*! + * \brief Managed reference to PrimValueNode + * \sa PrimValeNode + */ +class PrimValue : public LeafExpr { + public: + /*! + * \brief The constructor + * \param value The value input. + * \param span The source span of the expression. + */ + TVM_DLL explicit PrimValue(PrimExpr value, Span span = Span()); + + /*! + * \brief Create a int64 prim value. + * \param value The input value. + * \param span The source span of the expression. + * \return The created prim value. + */ + TVM_DLL static PrimValue Int64(int64_t value, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(PrimValue, LeafExpr, PrimValueNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimValueNode); +}; + +/*! + * \brief Represent a string literal constant. + */ +class StringImmNode : public LeafExprNode { + public: + /*! \brief The data value. */ + String value; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("value", &value); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from data. + return equal(value, other->value); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } + + static constexpr const char* _type_key = "relax.expr.StringImm"; + TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, LeafExprNode); +}; + +/*! + * \brief Managed reference to StringImm + * \sa StringImmNode + */ +class StringImm : public LeafExpr { + public: + /*! + * \brief The constructor + * \param value The value input. + * \param span The source span of the expression. + */ + TVM_DLL explicit StringImm(String value, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(StringImm, LeafExpr, StringImmNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode); +}; + +/*! + * \brief Represent a data type constant. + */ +class DataTypeImmNode : public LeafExprNode { + public: + /*! \brief The data value. */ + DataType value; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("value", &value); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const DataTypeImmNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from data. + return equal(value, other->value); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } + + static constexpr const char* _type_key = "relax.expr.DataTypeImm"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataTypeImmNode, LeafExprNode); +}; + +/*! + * \brief Managed reference to DataTypeImm + * \sa DataTypeImmNode + */ +class DataTypeImm : public LeafExpr { + public: + /*! + * \brief The constructor + * \param value The value input. + * \param span The source span of the expression. + */ + TVM_DLL explicit DataTypeImm(DataType value, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(DataTypeImm, LeafExpr, DataTypeImmNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(DataTypeImmNode); +}; + +/*! \brief The base class of a variable binding in Relax. */ +class BindingNode : public Object { + public: + /*! \brief The return variable to bound to. */ + Var var; + mutable Span span; + + static constexpr const char* _type_key = "relax.expr.Binding"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(BindingNode, Object); +}; + +class Binding : public ObjectRef { + protected: + Binding() = default; + + public: + explicit Binding(ObjectPtr n) : ObjectRef(n) {} + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Binding); + const BindingNode* operator->() const { return static_cast(data_.get()); } + const BindingNode* get() const { return operator->(); } + using ContainerType = BindingNode; +}; + +/*! + * \brief Runtime-match the value to the struct info. + * + * This operation does runtime check, populates the un-defined symbolic shape vars + * and vars in struct_info in first occurance, and insert equality assertions in + * other cases. + */ +class MatchCastNode : public BindingNode { + public: + /*! \brief The input value to match cast. */ + Expr value; + /*! \brief The struct info pattern to match to. */ + StructInfo struct_info; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("var", &var); + v->Visit("value", &value); + v->Visit("struct_info", &struct_info); + v->Visit("span", &span); + } + + bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const { + // NOTE: pattern can contain ShapeExpr which defines the vars + return equal.DefEqual(var, other->var) && equal.DefEqual(struct_info, other->struct_info) && + equal(value, other->value); + } + + void SHashReduce(SHashReducer hash_reduce) const { + // NOTE: pattern can contain ShapeExpr which defines the vars + hash_reduce.DefHash(var); + hash_reduce.DefHash(struct_info); + hash_reduce(value); + } + + static constexpr const char* _type_key = "relax.expr.MatchCast"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(MatchCastNode, BindingNode); +}; + +/*! + * \brief Managed reference to MatchCastNode. + * \sa MatchCastNode + */ +class MatchCast : public Binding { + public: + TVM_DLL explicit MatchCast(Var var, Expr value, StructInfo struct_info, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(MatchCast, Binding, MatchCastNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchCastNode); +}; + +class VarBindingNode : public BindingNode { + public: + /*! \brief The binding value. */ + Expr value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("var", &var); + v->Visit("value", &value); + v->Visit("span", &span); + } + + bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const { + return equal.DefEqual(var, other->var) && equal(value, other->value); + } + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce.DefHash(var); + hash_reduce(value); + } + static constexpr const char* _type_key = "relax.expr.VarBinding"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(VarBindingNode, BindingNode); +}; + +class VarBinding : public Binding { + public: + TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(VarBinding, Binding, VarBindingNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(VarBindingNode); +}; + +class BindingBlockNode : public Object { + public: + mutable Span span; + Array bindings; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("span", &span); + v->Visit("bindings", &bindings); + } + + bool SEqualReduce(const BindingBlockNode* other, SEqualReducer equal) const { + return equal(bindings, other->bindings); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); } + + static constexpr const char* _type_key = "relax.expr.BindingBlock"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(BindingBlockNode, Object); +}; + +class BindingBlock : public ObjectRef { + public: + TVM_DLL explicit BindingBlock(Array bindings, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BindingBlockNode); +}; + +class DataflowBlock; +class DataflowBlockNode : public BindingBlockNode { + public: + bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const { + return equal(bindings, other->bindings); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); } + + static constexpr const char* _type_key = "relax.expr.DataflowBlock"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockNode, BindingBlockNode); +}; + +class DataflowBlock : public BindingBlock { + public: + TVM_DLL explicit DataflowBlock(Array bindings, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlock, BindingBlock, DataflowBlockNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowBlockNode); +}; + +/*! \brief A sequence of blocks followed by an expression. + * + * The order of blocks enforces scoping and ordering. + */ +class SeqExprNode : public ExprNode { + public: + Array blocks; + Expr body; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("blocks", &blocks); + v->Visit("body", &body); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const SeqExprNode* other, SEqualReducer equal) const { + return equal(blocks, other->blocks) && equal(body, other->body) && + equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(blocks); + hash_reduce(body); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.SeqExpr"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(SeqExprNode, ExprNode); +}; + +class SeqExpr : public Expr { + public: + TVM_DLL explicit SeqExpr(Array blocks, Expr body, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode); +}; + +/*! \brief A Relax function. */ +class FunctionNode : public BaseFuncNode { + public: + /*! \brief The parameters to the function. */ + Array params; + /*! \brief The body of the function. */ + Expr body; + /*! \brief The return type of the function. */ + StructInfo ret_struct_info; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("params", ¶ms); + v->Visit("body", &body); + v->Visit("ret_struct_info", &ret_struct_info); + v->Visit("attrs", &attrs); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal.DefEqual(params, other->params) && equal(body, other->body) && + equal(ret_struct_info, other->ret_struct_info) && equal(attrs, other->attrs) && + equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); + hash_reduce.DefHash(params); + hash_reduce(body); + hash_reduce(ret_struct_info); + hash_reduce(attrs); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.Function"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode); +}; + +class Function : public BaseFunc { + public: + TVM_DLL explicit Function(Array params, Expr body, Optional ret_struct_info, + DictAttrs attrs = NullValue(), Span span = Span()); + + /*! + * \brief Mimics the constructor but without body Expr. + * \note ret_struct_info is required, since it can not deduced by the body + */ + TVM_DLL static Function CreateEmpty(Array params, StructInfo ret_struct_info, + DictAttrs attrs = NullValue(), Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode); +}; + +// TODO(@sunggg): Investigate the exact usage of kComposite, kPartitionedFromPattern, and +// kPrimitive. +namespace attr { +/*! \brief Mark the function as a primitive function. */ +constexpr const char* kPrimitive = "Primitive"; +/*! + * \brief Indicate the codegen that should be used for building this function. + * When this is unset or set to "default", the default compilation pipeline will be used. + */ +constexpr const char* kCodegen = "Codegen"; +/*! \brief Treat the function as a composite operator. */ +constexpr const char* kComposite = "Composite"; +/*! \brief Indicate the function was created by the Pattern Partitioning Pass. */ +constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern"; +} // namespace attr + +/*! \brief The extern function, which can represent packed function. */ +class ExternFuncNode : public BaseFuncNode { + public: + /*! \brief The name of global symbol. */ + String global_symbol; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("global_symbol", &global_symbol); + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const ExternFuncNode* other, SEqualReducer equal) const { + return equal(global_symbol, other->global_symbol) && equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(global_symbol); + hash_reduce(struct_info_); + } + + static constexpr const char* _type_key = "relax.expr.ExternFunc"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncNode, BaseFuncNode); +}; + +class ExternFunc : public BaseFunc { + public: + TVM_DLL ExternFunc(String global_symbol, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, BaseFunc, ExternFuncNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode); +}; + +/*! + * \brief Get the shape of Expr. + * \param expr The input expr. + * \return The corresonding shape. + * + * \note This function requires expr to be normalized. + * The function will report an error if expr's StructInfo is not TensorStructInfo. + * It will try to return symbolic function when possible. If the tensor do not + * have a compile-time symbolic shape, the function will then choose to return + * Call(relax.op.shape_of, [expr]). + */ +TVM_DLL Expr GetShapeOf(const Expr& expr); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_EXPR_H_ diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h new file mode 100644 index 000000000000..9c20a524353a --- /dev/null +++ b/include/tvm/relax/type.h @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/type.h + * \brief Relax Types. + */ +#ifndef TVM_RELAX_TYPE_H_ +#define TVM_RELAX_TYPE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +/*! \brief Indicates the number of dimensions of a tensor is unknown at compile time. */ +static constexpr int kUnknownNDim = -1; + +class ShapeTypeNode : public TypeNode { + public: + /*! \brief size of the shape. */ + int ndim; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("ndim", &ndim); + v->Visit("span", &span); + } + + bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const { + return equal(ndim, other->ndim); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(ndim); } + + static constexpr const char* _type_key = "relax.ShapeType"; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTypeNode, TypeNode); +}; + +class ShapeType : public Type { + public: + // TODO(relax-team): remove the default value later. + TVM_DLL ShapeType(int ndim = kUnknownNDim, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeType, Type, ShapeTypeNode); +}; + +class ObjectTypeNode : public TypeNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); } + + bool SEqualReduce(const ObjectTypeNode* other, SEqualReducer equal) const { return true; } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } + + static constexpr const char* _type_key = "relax.ObjectType"; + TVM_DECLARE_FINAL_OBJECT_INFO(ObjectTypeNode, TypeNode); +}; + +class ObjectType : public Type { + public: + TVM_DLL ObjectType(Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectType, Type, ObjectTypeNode); +}; + +class DynTensorTypeNode : public BaseTensorTypeNode { + public: + /*! + * \brief The number of dimensions of the tensor, use -1 to denote tensor with unknwon number of + * dimensions. + */ + int ndim; + /*! \brief The content data type, use void to denote the dtype is unknown. */ + DataType dtype; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("ndim", &ndim); + v->Visit("dtype", &dtype); + v->Visit("span", &span); + } + + bool SEqualReduce(const DynTensorTypeNode* other, SEqualReducer equal) const { + return equal(ndim, other->ndim) && equal(dtype, other->dtype); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(ndim); + hash_reduce(dtype); + } + + inline bool IsUnknownNdim() const { return ndim == kUnknownNDim; } + + inline bool IsUnknownDtype() const { return dtype.is_void(); } + + static constexpr const char* _type_key = "relax.DynTensorType"; + TVM_DECLARE_FINAL_OBJECT_INFO(DynTensorTypeNode, BaseTensorTypeNode); +}; + +/*! + * \brief Managed reference to DynTensorTypeNode. + * \sa DynTensorTypeNode. + */ +class DynTensorType : public Type { + public: + /*! + * \brief Constructor. + * \param ndim The number of dimensions of the tensor. + * \param dtype The runtime dtype of the tensor's elements. + * \param span The span. + */ + TVM_DLL DynTensorType(int ndim, DataType dtype, Span span = Span()); + + /*! + * \brief Create a DynTensorType with unknown ndim. + */ + TVM_DLL static DynTensorType CreateUnknownNDim(DataType dtype, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(DynTensorType, Type, DynTensorTypeNode); +}; + +class PackedFuncTypeNode : public TypeNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); } + + bool SEqualReduce(const PackedFuncTypeNode* other, SEqualReducer equal) const { return true; } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } + + static constexpr const char* _type_key = "relax.PackedFuncType"; + TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncTypeNode, TypeNode); +}; + +class PackedFuncType : public Type { + public: + TVM_DLL PackedFuncType(Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PackedFuncType, Type, PackedFuncTypeNode); +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_TYPE_H_ From ecbd0a41f889ccef4810b73b8a08145302141520 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Thu, 2 Feb 2023 23:32:15 -0500 Subject: [PATCH 03/81] [Unity][IR] First-class StructInfo (#13907) * [Unity][IR] First-class StructInfo Relax tracks structural information (such as tensor shape) via `StructInfo` about the values in Relax. * Fix rust build --------- Co-authored-by: Junru Shao --- CMakeLists.txt | 1 + include/tvm/relax/struct_info.h | 430 ++++++++++++++++++++++++++++++++ rust/tvm/src/ir/relay/mod.rs | 2 + src/relax/ir/struct_info.cc | 238 ++++++++++++++++++ 4 files changed, 671 insertions(+) create mode 100644 include/tvm/relax/struct_info.h create mode 100644 src/relax/ir/struct_info.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index ed2afc392067..19f37d06f315 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -289,6 +289,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/driver/*.cc src/support/*.cc src/script/*.cc + src/relax/ir/*.cc src/relax/backend/vm/*.cc ) diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h new file mode 100644 index 000000000000..d21c8db86b3f --- /dev/null +++ b/include/tvm/relax/struct_info.h @@ -0,0 +1,430 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RELAX_STRUCT_INFO_H_ +#define TVM_RELAX_STRUCT_INFO_H_ + +#include +#include +#include +// #include +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Opaque object. + */ +class ObjectStructInfoNode : public StructInfoNode { + public: + void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); } + + bool SEqualReduce(const ObjectStructInfoNode* other, SEqualReducer equal) const { return true; } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } + + static constexpr const char* _type_key = "relax.ObjectStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(ObjectStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to ObjectStructInfoNode. + * \sa ObjectStructInfoNode + */ +class ObjectStructInfo : public StructInfo { + public: + TVM_DLL ObjectStructInfo(Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectStructInfo, StructInfo, ObjectStructInfoNode); +}; + +/*! + * \brief Primitive value. + */ +class PrimStructInfoNode : public StructInfoNode { + public: + /*! \brief Underlying data type of the primitive value */ + DataType dtype; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &dtype); + v->Visit("span", &span); + } + + bool SEqualReduce(const PrimStructInfoNode* other, SEqualReducer equal) const { + return equal(dtype, other->dtype); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); } + + static constexpr const char* _type_key = "relax.PrimStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to PrimStructInfoNode. + * \sa PrimStructInfoNode + */ +class PrimStructInfo : public StructInfo { + public: + TVM_DLL PrimStructInfo(DataType dtype, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimStructInfo, StructInfo, PrimStructInfoNode); +}; + +/*! + * \brief StructInfo of shape value. + */ +class ShapeStructInfoNode : public StructInfoNode { + public: + /*! \brief optionally stores the symbolic value patterns of the shape */ + Optional> values; + /*! + * \brief The number of dimension of the shape, can be unknown. + * \sa kUnknownNDim + */ + int ndim; + + /*! \return Whether the struct info contains unknown ndim. */ + bool IsUnknownNdim() const { return ndim == kUnknownNDim; } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("values", &values); + v->Visit("ndim", &ndim); + v->Visit("span", &span); + } + + bool SEqualReduce(const ShapeStructInfoNode* other, SEqualReducer equal) const { + return equal(values, other->values) && equal(ndim, other->ndim); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(values); + hash_reduce(ndim); + } + + static constexpr const char* _type_key = "relax.ShapeStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapeStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to ShapeStructInfoNode. + * \sa ShapeStructInfoNode + */ +class ShapeStructInfo : public StructInfo { + public: + /*! + * \brief Construction with known symbolic shape patterns + * \param values The symbolic shape values + * \param span The span of the AST. + */ + TVM_DLL ShapeStructInfo(Array values, Span span = Span()); + /*! + * \brief Construction with known unknown symbolic shape patterns. + * \param ndim Number of dimensions -- can be kUnknownNDim + * \param span The span of the AST. + */ + TVM_DLL ShapeStructInfo(int ndim, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeStructInfo, StructInfo, ShapeStructInfoNode); +}; + +/*! + * \brief StructInfo of Tensor. + */ +class TensorStructInfoNode : public StructInfoNode { + public: + /*! + * \brief optionally store the shape expression of the tensor. + * \note shape must be normalized: it can only be NullOpt or ShapeExpr or Var. + */ + Optional shape; + /*! \brief The content data type, use void to denote the dtype is unknown. */ + DataType dtype; + /*! + * \brief The number of dimension of the tensor, can be unknown. + * \sa kUnknownNDim + */ + int ndim; + + /*! \return Whether the struct info contains unknown ndim. */ + bool IsUnknownNdim() const { return ndim == kUnknownNDim; } + + /*! \return Whether the struct info contains unknown dtype. */ + bool IsUnknownDtype() const { return dtype.is_void(); } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("shape", &shape); + v->Visit("dtype", &dtype); + v->Visit("ndim", &ndim); + v->Visit("span", &span); + } + + bool SEqualReduce(const TensorStructInfoNode* other, SEqualReducer equal) const { + return equal(shape, other->shape) && equal(ndim, other->ndim) && equal(dtype, other->dtype); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(shape); + hash_reduce(dtype); + hash_reduce(ndim); + } + + static constexpr const char* _type_key = "relax.TensorStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to TensorStructInfoNode. + * \sa TensorStructInfoNode + */ +class TensorStructInfo : public StructInfo { + public: + /*! + * \brief Construction with a known shape expression. + * \param shape The shape of the tensor. + * \param dtype The data type of tensor's elements. + * \param span The span of the AST. + * + * \note shape must already be normalized. + */ + TVM_DLL TensorStructInfo(Expr shape, DataType dtype, Span span = Span()); + + /*! + * \brief Construction with an unknown shape expression. + * \param dtype The data type of tensor's elements. + * \param ndim The number of dimensions + * \param span The span of the AST. + */ + TVM_DLL TensorStructInfo(DataType dtype, int ndim, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorStructInfo, StructInfo, TensorStructInfoNode); +}; + +/*! + * \brief StructInfo of Tuple. + */ +class TupleStructInfoNode : public StructInfoNode { + public: + /*! \brief The struct info of tuple fields. */ + Array fields; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("fields", &fields); + v->Visit("span", &span); + } + + bool SEqualReduce(const TupleStructInfoNode* other, SEqualReducer equal) const { + return equal(fields, other->fields); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); } + + static constexpr const char* _type_key = "relax.TupleStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to TupleStructInfoNode. + * \sa TupleStructInfoNode + */ +class TupleStructInfo : public StructInfo { + public: + /*! + * \brief Constructor + * \param fields Struct info of tuple fields. + * \param span The span of the AST. + */ + TVM_DLL TupleStructInfo(Array fields, Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleStructInfo, StructInfo, TupleStructInfoNode); +}; + +class BlockBuilder; + +/*! + * \brief custom-defined StructInfo derivation function. + * \param call The call expression to be derived. + * \param ctx The builder context. + * \return The derived struct info of the call. + */ +using StructInfoDeriveFunc = TypedEnvFunc; + +/*! + * \brief Structure information about function. + * + * This data structure contains enough information for us to + * do best-effort structure information deduction. + */ +class FuncStructInfoNode : public StructInfoNode { + public: + /*! + * \brief The parameter struct info of the function. + * \note When params is NullOpt means the function can take arbitrary number of arguments. + * We define such functions as Opaque function. + */ + Optional> params; + /*! + * \brief The struct info of the function's return value. + */ + StructInfo ret; + /*! + * \brief Derivation function of opaque functions that may take any number of parameters. + * \note When derive_func is not empty, then params should be NullOpt, + * ret should be ObjectStructInfo() + */ + Optional derive_func; + + /*! + * \return Whether the func struct info is opaque. + * \note We define a function as opaque we have no constraints on params. + */ + bool IsOpaque() const { return !params.defined(); } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("params", ¶ms); + v->Visit("ret", &ret); + v->Visit("derive_func", &derive_func); + v->Visit("span", &span); + } + + bool SEqualReduce(const FuncStructInfoNode* other, SEqualReducer equal) const { + return equal.DefEqual(params, other->params) && equal(ret, other->ret) && + equal(derive_func, other->derive_func); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce.DefHash(params); + hash_reduce(ret); + hash_reduce(derive_func); + } + + static constexpr const char* _type_key = "relax.FuncStructInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(FuncStructInfoNode, StructInfoNode); +}; + +/*! + * \brief Managed reference to FuncStructInfoNode. + * \sa FuncStructInfoNode + */ +class FuncStructInfo : public StructInfo { + public: + /*! + * \brief Constructor from parameter struct info and return value struct info. + * \param params The struct info of function parameters. + * \param ret The return value struct info. + * \param span The span of the AST. + * + * \note If the ret contains variables(tir::Var and relax::Var), they must be deducible from + * params. If you are unsure, you can always erase ret to static. + */ + TVM_DLL FuncStructInfo(Array params, StructInfo ret, Span span = Span()); + + /*! + * \brief Constructing an opaque function struct info using derive_func. + * + * \param derive_func Derivation function. + * \param span The span of the AST. + * + * \return The FuncStructInfo for opaque packedfunc. + * \note Defaults to an derive func that always return ObjectStructInfo if not specified. + */ + TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, Span span = Span()); + + /*! + * \brief Construct an opaque function using from return struct info. + * + * \param ret The struct info of the return value. + * \param span The span of the AST. + * + * \return The FuncStructInfo for opaque packedfunc. + * \note Defaults to an derive func that always return ObjectStructInfo if not specified. + */ + TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), Span span = Span()); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FuncStructInfo, StructInfo, FuncStructInfoNode); +}; + +/*! + * \brief Match and check if expr have StructInfo T and return it. + * + * \param expr The input expression. + * \return The result of match. + * \tparam T the underlying structure info type + */ +template +inline Optional MatchStructInfo(const Expr& expr) { + using TNode = typename T::ContainerType; + if (const TNode* ptr = expr->struct_info_.as()) { + return GetRef(ptr); + } else { + return NullOpt; + } +} + +/*! + * \brief Get the structure info of a given expr and try to cast it as const T*. + * + * \param expr The input expression. + * \return The pointer. Returns nullptr if the type does not match + * \tparam T the underlying structure info type + */ +template +inline const T* GetStructInfoAs(const Expr& expr) { + ICHECK(expr->struct_info_.defined()) + << "The struct_info is not populated, check if you have normalized the expr"; + return expr->struct_info_.as(); +} + +/*! + * \brief Get the underlying structure info of expr. + * + * \param expr The input expression. + * \return underlying struct info. + */ +inline StructInfo GetStructInfo(const Expr& expr) { + auto* ptr = expr->struct_info_.as(); + ICHECK(ptr) << "The struct_info is not populated, check if you have normalized the expr"; + return GetRef(ptr); +} + +/*! + * \brief Whether the expr has void struct info. + * + * \param expr The input expression. + * \return Whether the expr has void struct info. + */ +inline bool HasVoidStructInfo(const Expr& expr) { + auto* ptr = expr->struct_info_.as(); + return ptr != nullptr && ptr->fields.size() == 0; +} + +/*! + * \brief Update the struct info of an Expr. + * \param expr The Expr whose struct info to be updated. + * \param struct_info The struct_info assigned. + * \note We ensure idempotence, that is we can only update the struct_info of an Expr only + * if the original one is nullptr. + */ +TVM_DLL void UpdateStructInfo(Expr expr, StructInfo struct_info); + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_STRUCT_INFO_H_ diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index abc25e89c48c..08ce082c4586 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -40,6 +40,7 @@ pub mod attrs; pub struct ExprNode { pub base: BaseExprNode, pub checked_type: Type, + pub struct_info: ObjectRef, pub virtual_device: ObjectRef, } @@ -48,6 +49,7 @@ impl ExprNode { ExprNode { base: BaseExprNode::base::(span.clone()), checked_type: Type::null(), + struct_info: ObjectRef::null(), virtual_device: ObjectRef::null(), } } diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc new file mode 100644 index 000000000000..88046ed81f10 --- /dev/null +++ b/src/relax/ir/struct_info.cc @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/struct_info.cc + * \brief Relax struct info. + */ +#include +#include + +namespace tvm { +namespace relax { + +ObjectStructInfo::ObjectStructInfo(Span span) { + ObjectPtr n = make_object(); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ObjectStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) { + return ObjectStructInfo(span); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "ObjectStructInfo()"; + }); + +// Prim +PrimStructInfo::PrimStructInfo(DataType dtype, Span span) { + ObjectPtr n = make_object(); + n->dtype = dtype; + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(PrimStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.PrimStructInfo").set_body_typed([](DataType dtype, Span span) { + return PrimStructInfo(dtype, span); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast(ref.get()); + p->stream << "PrimStructInfo(" << node->dtype << ")"; + }); + +// Shape +ShapeStructInfo::ShapeStructInfo(Array values, Span span) { + ObjectPtr n = make_object(); + n->ndim = static_cast(values.size()); + n->values = values.Map([](PrimExpr value) { + if (value->IsInstance()) { + return tvm::cast(DataType::Int(64), value); + } + ICHECK(value.dtype() == DataType::Int(64)) + << "the value in ShapeStructInfo can only have dtype of int64"; + return value; + }); + n->span = span; + data_ = std::move(n); +} + +ShapeStructInfo::ShapeStructInfo(int ndim, Span span) { + ObjectPtr n = make_object(); + CHECK_GE(ndim, -1) << "ndim of ShapeStructInfo must be >= -1, but got " << ndim; + n->ndim = ndim; + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ShapeStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.ShapeStructInfo") + .set_body_typed([](Optional> values, int ndim, Span span) { + if (values.defined()) { + CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify values and ndim"; + return ShapeStructInfo(values.value(), span); + } else { + return ShapeStructInfo(ndim, span); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast(ref.get()); + if (node->values.defined()) { + p->stream << "ShapeStructInfo(" << node->values.value() << ")"; + } else { + p->stream << "ShapeStructInfo(ndim=" << node->ndim << ")"; + } + }); + +// Tensor +TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Span span) { + ObjectPtr n = make_object(); + // assign ndim before move + Optional sinfo = MatchStructInfo(shape); + ICHECK(sinfo) << "We expect shape to contain pre-set shape struct info"; + ICHECK(shape.defined()) << "Must provide a shape in this constructor"; + ICHECK(shape->IsInstance() || shape->IsInstance()) + << "We require shape to be normalized when constructing TensorStructInfo"; + n->ndim = sinfo.get()->ndim; + // assign rest of the fields. + n->shape = std::move(shape); + n->dtype = dtype; + n->span = span; + data_ = std::move(n); +} + +TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, Span span) { + ObjectPtr n = make_object(); + CHECK_GE(ndim, -1) << "ndim of TensorStructInfo must be >= -1, but got " << ndim; + n->ndim = ndim; + n->dtype = dtype; + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TensorStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.TensorStructInfo") + .set_body_typed([](Optional shape, DataType dtype, int ndim, Span span) { + if (shape.defined()) { + CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify shape and ndim"; + return TensorStructInfo(shape.value(), dtype, span); + } else { + return TensorStructInfo(dtype, ndim, span); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast(ref.get()); + if (node->shape.defined()) { + p->stream << "TensorStructInfo(" << node->shape.value() << ", " << node->dtype << ")"; + } else { + p->stream << "TensorStructInfo(" << node->dtype << ", ndim=" << node->ndim << ")"; + } + }); + +// Tuple +TupleStructInfo::TupleStructInfo(Array fields, Span span) { + ObjectPtr n = make_object(); + n->fields = std::move(fields); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TupleStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.TupleStructInfo") + .set_body_typed([](Array fields, Span span) { + return TupleStructInfo(fields, span); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast(ref.get()); + p->stream << "TupleStructInfo(" << node->fields << ")"; + }); + +// Func +FuncStructInfo::FuncStructInfo(Array params, StructInfo ret, Span span) { + ObjectPtr n = make_object(); + n->params = std::move(params); + n->ret = std::move(ret); + n->span = span; + data_ = std::move(n); +} + +FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfoDeriveFunc derive_func, Span span) { + ObjectPtr n = make_object(); + n->derive_func = std::move(derive_func); + n->ret = ObjectStructInfo(); + n->span = span; + return FuncStructInfo(n); +} + +FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfo ret, Span span) { + ObjectPtr n = make_object(); + n->ret = std::move(ret); + n->span = span; + return FuncStructInfo(n); +} + +TVM_REGISTER_NODE_TYPE(FuncStructInfoNode); + +TVM_REGISTER_GLOBAL("relax.FuncStructInfo") + .set_body_typed([](Array params, StructInfo ret, Span span) { + return FuncStructInfo(params, ret, span); + }); + +TVM_REGISTER_GLOBAL("relax.FuncStructInfoOpaqueFunc") + .set_body_typed([](Optional ret, Optional derive_func, + Span span) { + if (derive_func.defined()) { + ICHECK(!ret.defined()) << "ValueError: Cannot specify both ret and derive_func"; + return FuncStructInfo::OpaqueFunc(derive_func.value(), span); + } else { + return FuncStructInfo::OpaqueFunc(ret.value_or(ObjectStructInfo()), span); + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast(ref.get()); + p->stream << "FuncStructInfo(" << node->params << ", " << node->ret << ")"; + }); + +// Helper functions +// TODO(unity-team): add UpdateStructInfo once analysis.cc is upstreamed + +TVM_REGISTER_GLOBAL("ir.ExprStructInfo").set_body_typed([](Expr expr) { + return GetStructInfo(expr); +}); + +} // namespace relax +} // namespace tvm From 76cc9f7dc5eeb4ea3eb985c7a21fc4985625b419 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 2 Feb 2023 23:33:15 -0500 Subject: [PATCH 04/81] [Unity][CI] Unity specific jenkins setup (do not upstream to main) (#13910) This PR setup a unity specific jenkins with minimum jenkinsfile without sharding and disables most of the tests to reduce overall cost. We can add tests of unty branch by configuring the specific groovy file. --- ci/jenkins/generated/arm_jenkinsfile.groovy | 5 + .../generated/cortexm_jenkinsfile.groovy | 5 + ci/jenkins/generated/cpu_jenkinsfile.groovy | 5 + .../generated/docker_jenkinsfile.groovy | 5 + ci/jenkins/generated/gpu_jenkinsfile.groovy | 5 + .../generated/hexagon_jenkinsfile.groovy | 5 + ci/jenkins/generated/i386_jenkinsfile.groovy | 5 + ci/jenkins/generated/lint_jenkinsfile.groovy | 5 + .../minimal_cross_isa_jenkinsfile.groovy | 5 + .../generated/minimal_jenkinsfile.groovy | 5 + ci/jenkins/generated/riscv_jenkinsfile.groovy | 5 + ci/jenkins/generated/wasm_jenkinsfile.groovy | 5 + ci/jenkins/unity_jenkinsfile.groovy | 337 ++++++++++++++++++ tests/scripts/task_lint.sh | 4 +- tests/scripts/unity/README | 2 + tests/scripts/unity/task_extra_lint.sh | 23 ++ tests/scripts/unity/task_python_relax.sh | 37 ++ .../unity/task_python_relax_gpuonly.sh | 25 ++ 18 files changed, 486 insertions(+), 2 deletions(-) create mode 100644 ci/jenkins/unity_jenkinsfile.groovy create mode 100644 tests/scripts/unity/README create mode 100755 tests/scripts/unity/task_extra_lint.sh create mode 100755 tests/scripts/unity/task_python_relax.sh create mode 100755 tests/scripts/unity/task_python_relax_gpuonly.sh diff --git a/ci/jenkins/generated/arm_jenkinsfile.groovy b/ci/jenkins/generated/arm_jenkinsfile.groovy index 4c830dce2c30..ffcfa9b842d7 100644 --- a/ci/jenkins/generated/arm_jenkinsfile.groovy +++ b/ci/jenkins/generated/arm_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/cortexm_jenkinsfile.groovy b/ci/jenkins/generated/cortexm_jenkinsfile.groovy index d8a4d4671e86..c1a62736702b 100644 --- a/ci/jenkins/generated/cortexm_jenkinsfile.groovy +++ b/ci/jenkins/generated/cortexm_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/cpu_jenkinsfile.groovy b/ci/jenkins/generated/cpu_jenkinsfile.groovy index cdd2564e0591..e689cbb65583 100644 --- a/ci/jenkins/generated/cpu_jenkinsfile.groovy +++ b/ci/jenkins/generated/cpu_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/docker_jenkinsfile.groovy b/ci/jenkins/generated/docker_jenkinsfile.groovy index 32dec7863bcf..74e3ddfabeac 100644 --- a/ci/jenkins/generated/docker_jenkinsfile.groovy +++ b/ci/jenkins/generated/docker_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/gpu_jenkinsfile.groovy b/ci/jenkins/generated/gpu_jenkinsfile.groovy index 390c8ddc3dc2..f14e8f541b41 100644 --- a/ci/jenkins/generated/gpu_jenkinsfile.groovy +++ b/ci/jenkins/generated/gpu_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/hexagon_jenkinsfile.groovy b/ci/jenkins/generated/hexagon_jenkinsfile.groovy index 58fe4d14c969..7d5bd3309ee5 100644 --- a/ci/jenkins/generated/hexagon_jenkinsfile.groovy +++ b/ci/jenkins/generated/hexagon_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/i386_jenkinsfile.groovy b/ci/jenkins/generated/i386_jenkinsfile.groovy index b5bf5cb1fe40..98e09c393a69 100644 --- a/ci/jenkins/generated/i386_jenkinsfile.groovy +++ b/ci/jenkins/generated/i386_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/lint_jenkinsfile.groovy b/ci/jenkins/generated/lint_jenkinsfile.groovy index ed5aa8d67954..1a3120efb0e1 100644 --- a/ci/jenkins/generated/lint_jenkinsfile.groovy +++ b/ci/jenkins/generated/lint_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy b/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy index 4c748e3f20d7..08143791c68e 100644 --- a/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy +++ b/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/minimal_jenkinsfile.groovy b/ci/jenkins/generated/minimal_jenkinsfile.groovy index 72864ec4ca0f..ff10d01670ce 100644 --- a/ci/jenkins/generated/minimal_jenkinsfile.groovy +++ b/ci/jenkins/generated/minimal_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/riscv_jenkinsfile.groovy b/ci/jenkins/generated/riscv_jenkinsfile.groovy index 2dfeb3561281..df1160b3c1e5 100644 --- a/ci/jenkins/generated/riscv_jenkinsfile.groovy +++ b/ci/jenkins/generated/riscv_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/generated/wasm_jenkinsfile.groovy b/ci/jenkins/generated/wasm_jenkinsfile.groovy index 27e8f6570ed0..37b50f97ad17 100644 --- a/ci/jenkins/generated/wasm_jenkinsfile.groovy +++ b/ci/jenkins/generated/wasm_jenkinsfile.groovy @@ -54,6 +54,11 @@ // - Periodically cleanup the old versions on local workers // +// unity: Skip less relevant tests +// to reduce ci time and resource cost +// (DO NOT UPSTREAM TO MAIN) +return + // ============================= IMPORTANT NOTE ============================= // This file is generated by 'jenkins/generate.py'. Do not edit this file directly! // Make edits to 'jenkins/Jenkinsfile.j2' and regenerate this with diff --git a/ci/jenkins/unity_jenkinsfile.groovy b/ci/jenkins/unity_jenkinsfile.groovy new file mode 100644 index 000000000000..714260a28345 --- /dev/null +++ b/ci/jenkins/unity_jenkinsfile.groovy @@ -0,0 +1,337 @@ +#!groovy +// -*- mode: groovy -*- + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Jenkins pipeline +// See documents at https://jenkins.io/doc/book/pipeline/jenkinsfile/ + +// ============================= IMPORTANT NOTE ============================= +// To keep things simple +// This file is manually updated to maintain unity branch specific builds. +// Please do not send this file to main + + +import org.jenkinsci.plugins.pipeline.modeldefinition.Utils + +// NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. --> +ci_lint = 'tlcpack/ci-lint:20221025-182121-e41d0ed6e' +ci_gpu = 'tlcpack/ci-gpu:20221128-070141-ae4fd7df7' +ci_cpu = 'tlcpack/ci-cpu:20230110-070003-d00168ffb' +ci_wasm = 'tlcpack/ci-wasm:v0.72' +ci_i386 = 'tlcpack/ci-i386:v0.75' +ci_qemu = 'tlcpack/ci-qemu:v0.11' +ci_arm = 'tlcpack/ci-arm:v0.08' +ci_hexagon = 'tlcpack/ci-hexagon:20221025-182121-e41d0ed6e' +// <--- End of regex-scanned config. + +// Parameters to allow overriding (in Jenkins UI), the images +// to be used by a given build. When provided, they take precedence +// over default values above. +properties([ + parameters([ + string(name: 'ci_lint_param', defaultValue: ''), + string(name: 'ci_cpu_param', defaultValue: ''), + string(name: 'ci_gpu_param', defaultValue: ''), + string(name: 'ci_wasm_param', defaultValue: ''), + string(name: 'ci_i386_param', defaultValue: ''), + string(name: 'ci_qemu_param', defaultValue: ''), + string(name: 'ci_arm_param', defaultValue: ''), + string(name: 'ci_hexagon_param', defaultValue: '') + ]) +]) + +// tvm libraries +tvm_runtime = 'build/libtvm_runtime.so, build/config.cmake' +tvm_lib = 'build/libtvm.so, ' + tvm_runtime +// LLVM upstream lib +tvm_multilib = 'build/libtvm.so, ' + + 'build/libvta_fsim.so, ' + + tvm_runtime + +tvm_multilib_tsim = 'build/libvta_tsim.so, ' + + tvm_multilib + +// command to start a docker container +docker_run = 'docker/bash.sh' +// timeout in minutes +max_time = 240 + +def per_exec_ws(folder) { + return "workspace/exec_${env.EXECUTOR_NUMBER}/" + folder +} + +// initialize source codes +def init_git() { + checkout scm + // Add more info about job node + sh ( + script: './tests/scripts/task_show_node_info.sh', + label: 'Show executor node info', + ) + retry(5) { + timeout(time: 2, unit: 'MINUTES') { + sh (script: 'git submodule update --init -f', label: 'Update git submodules') + } + } +} + +def should_skip_slow_tests(pr_number) { + withCredentials([string( + credentialsId: 'tvm-bot-jenkins-reader', + variable: 'GITHUB_TOKEN', + )]) { + // Exit code of 1 means run slow tests, exit code of 0 means skip slow tests + result = sh ( + returnStatus: true, + script: "./tests/scripts/should_run_slow_tests.py --pr '${pr_number}'", + label: 'Check if CI should run slow tests', + ) + } + return result == 0 +} + +def cancel_previous_build() { + // cancel previous build if it is not on main. + if (env.BRANCH_NAME != 'main') { + def buildNumber = env.BUILD_NUMBER as int + // Milestone API allows us to cancel previous build + // with the same milestone number + if (buildNumber > 1) milestone(buildNumber - 1) + milestone(buildNumber) + } +} + +def should_skip_ci(pr_number) { + withCredentials([string( + credentialsId: 'tvm-bot-jenkins-reader', + variable: 'TOKEN', + )]) { + // Exit code of 1 means run full CI (or the script had an error, so run + // full CI just in case). Exit code of 0 means skip CI. + git_skip_ci_code = sh ( + returnStatus: true, + script: "./tests/scripts/git_skip_ci.py --pr '${pr_number}'", + label: 'Check if CI should be skipped', + ) + } + return git_skip_ci_code == 0 +} + +cancel_previous_build() + +def lint() { +stage('Prepare') { + node('CPU-SMALL') { + // When something is provided in ci_*_param, use it, otherwise default with ci_* + ci_lint = params.ci_lint_param ?: ci_lint + ci_cpu = params.ci_cpu_param ?: ci_cpu + ci_gpu = params.ci_gpu_param ?: ci_gpu + ci_wasm = params.ci_wasm_param ?: ci_wasm + ci_i386 = params.ci_i386_param ?: ci_i386 + ci_qemu = params.ci_qemu_param ?: ci_qemu + ci_arm = params.ci_arm_param ?: ci_arm + ci_hexagon = params.ci_hexagon_param ?: ci_hexagon + + sh (script: """ + echo "Docker images being used in this build:" + echo " ci_lint = ${ci_lint}" + echo " ci_cpu = ${ci_cpu}" + echo " ci_gpu = ${ci_gpu}" + echo " ci_wasm = ${ci_wasm}" + echo " ci_i386 = ${ci_i386}" + echo " ci_qemu = ${ci_qemu}" + echo " ci_arm = ${ci_arm}" + echo " ci_hexagon = ${ci_hexagon}" + """, label: 'Docker image names') + } +} + +stage('Sanity Check') { + timeout(time: max_time, unit: 'MINUTES') { + node('CPU') { + ws(per_exec_ws('tvm/sanity')) { + init_git() + is_docs_only_build = sh ( + returnStatus: true, + script: './tests/scripts/git_change_docs.sh', + label: 'Check for docs only changes', + ) + skip_ci = should_skip_ci(env.CHANGE_ID) + skip_slow_tests = should_skip_slow_tests(env.CHANGE_ID) + sh ( + script: "${docker_run} ${ci_lint} ./tests/scripts/task_lint.sh", + label: 'Run lint', + ) + sh ( + script: "${docker_run} ${ci_lint} ./tests/scripts/unity/task_extra_lint.sh", + label: 'Run extra lint', + ) + } + } + } +} +} + +lint() + +// Run make. First try to do an incremental make from a previous workspace in hope to +// accelerate the compilation. If something is wrong, clean the workspace and then +// build from scratch. +def make(docker_type, path, make_flag) { + timeout(time: max_time, unit: 'MINUTES') { + try { + cmake_build(docker_type, path, make_flag) + // always run cpp test when build + // sh "${docker_run} ${docker_type} ./tests/scripts/task_cpp_unittest.sh" + } catch (hudson.AbortException ae) { + // script exited due to user abort, directly throw instead of retry + if (ae.getMessage().contains('script returned exit code 143')) { + throw ae + } + echo 'Incremental compilation failed. Fall back to build from scratch' + sh ( + script: "${docker_run} ${docker_type} ./tests/scripts/task_clean.sh ${path}", + label: 'Clear old cmake workspace', + ) + cmake_build(docker_type, path, make_flag) + cpp_unittest(docker_type) + } + } +} + +// Specifications to Jenkins "stash" command for use with various pack_ and unpack_ functions. +tvm_runtime = 'build/libtvm_runtime.so, build/config.cmake' // use libtvm_runtime.so. +tvm_lib = 'build/libtvm.so, ' + tvm_runtime // use libtvm.so to run the full compiler. +// LLVM upstream lib +tvm_multilib = 'build/libtvm.so, ' + + 'build/libvta_fsim.so, ' + + tvm_runtime + +tvm_multilib_tsim = 'build/libvta_tsim.so, ' + + tvm_multilib + +microtvm_tar_gz = 'build/microtvm_template_projects.tar.gz' + +// pack libraries for later use +def pack_lib(name, libs) { + sh (script: """ + echo "Packing ${libs} into ${name}" + echo ${libs} | sed -e 's/,/ /g' | xargs md5sum + """, label: 'Stash libraries and show md5') + stash includes: libs, name: name +} + +// unpack libraries saved before +def unpack_lib(name, libs) { + unstash name + sh (script: """ + echo "Unpacked ${libs} from ${name}" + echo ${libs} | sed -e 's/,/ /g' | xargs md5sum + """, label: 'Unstash libraries and show md5') +} + +// compress microtvm template projects and pack the tar. +def pack_microtvm_template_projects(name) { + sh( + script: 'cd build && tar -czvf microtvm_template_projects.tar.gz microtvm_template_projects/', + label: 'Compress microtvm_template_projects' + ) + pack_lib(name + '-microtvm-libs', microtvm_tar_gz) +} + +def unpack_microtvm_template_projects(name) { + unpack_lib(name + '-microtvm-libs', microtvm_tar_gz) + sh( + script: 'cd build && tar -xzvf microtvm_template_projects.tar.gz', + label: 'Unpack microtvm_template_projects' + ) +} + +def ci_setup(image) { + sh ( + script: "${docker_run} ${image} ./tests/scripts/task_ci_setup.sh", + label: 'Set up CI environment', + ) +} + +def python_unittest(image) { + sh ( + script: "${docker_run} ${image} ./tests/scripts/task_python_unittest.sh", + label: 'Run Python unit tests', + ) +} + +def fsim_test(image) { + sh ( + script: "${docker_run} ${image} ./tests/scripts/task_python_vta_fsim.sh", + label: 'Run VTA tests in FSIM', + ) +} + +def cmake_build(image, path, make_flag) { + sh ( + script: "${docker_run} ${image} ./tests/scripts/task_build.py --sccache-bucket tvm-sccache-prod", + label: 'Run cmake build', + ) +} + +def cpp_unittest(image) { + sh ( + script: "${docker_run} ${image} ./tests/scripts/task_cpp_unittest.sh", + label: 'Build and run C++ tests', + ) +} + +def add_hexagon_permissions() { + sh( + script: 'find build/hexagon_api_output -type f | xargs chmod +x', + label: 'Add execute permissions for hexagon files', + ) +} + +// NOTE: limit tests to relax folder for now to allow us to skip some of the tests +// that are mostly related to changes in main. +// This helps to speedup CI time and reduce CI cost. +stage('Build and Test') { + if (is_docs_only_build != 1) { + parallel 'BUILD: GPU': { + node('GPU') { + ws(per_exec_ws('tvm/build-gpu')) { + init_git() + sh "${docker_run} ${ci_gpu} nvidia-smi" + sh "${docker_run} ${ci_gpu} ./tests/scripts/task_config_build_gpu.sh build" + make("${ci_gpu}", 'build', '-j2') + sh "${docker_run} ${ci_gpu} ./tests/scripts/unity/task_python_relax_gpuonly.sh" + } + } + }, + 'BUILD: CPU': { + node('CPU') { + ws(per_exec_ws('tvm/build-cpu')) { + init_git() + sh "${docker_run} ${ci_cpu} ./tests/scripts/task_config_build_cpu.sh build" + make(ci_cpu, 'build', '-j2') + sh "${docker_run} ${ci_cpu} ./tests/scripts/unity/task_python_relax.sh" + } + } + } + } else { + Utils.markStageSkippedForConditional('BUILD: CPU') + } +} diff --git a/tests/scripts/task_lint.sh b/tests/scripts/task_lint.sh index 83ea86ecccb8..9ca83ece5cd5 100755 --- a/tests/scripts/task_lint.sh +++ b/tests/scripts/task_lint.sh @@ -31,8 +31,8 @@ function shard1 { echo "Convert scripts to Python..." tests/scripts/task_convert_scripts_to_python.sh - echo "Check Jenkinsfile generation" - python3 ci/jenkins/generate.py --check + # echo "Check Jenkinsfile generation" + # python3 ci/jenkins/generate.py --check echo "Checking file types..." python3 tests/lint/check_file_type.py diff --git a/tests/scripts/unity/README b/tests/scripts/unity/README new file mode 100644 index 000000000000..42f8c3e040ea --- /dev/null +++ b/tests/scripts/unity/README @@ -0,0 +1,2 @@ +This folder contains CI task scripts that are specialized +to unity branch, please do not send to other places. diff --git a/tests/scripts/unity/task_extra_lint.sh b/tests/scripts/unity/task_extra_lint.sh new file mode 100755 index 000000000000..989f4df7389e --- /dev/null +++ b/tests/scripts/unity/task_extra_lint.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -euxo pipefail + +source tests/scripts/setup-pytest-env.sh + +# place extra lint here. diff --git a/tests/scripts/unity/task_python_relax.sh b/tests/scripts/unity/task_python_relax.sh new file mode 100755 index 000000000000..8869c318fab7 --- /dev/null +++ b/tests/scripts/unity/task_python_relax.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -euxo pipefail + +source tests/scripts/setup-pytest-env.sh +export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python +export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}" + +# to avoid CI CPU thread throttling. +export TVM_BIND_THREADS=0 +export TVM_NUM_THREADS=2 + +make cython3 + +# Run Relax tests +TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/relax + +# Run Relax examples +# python3 ./apps/relax_examples/mlp.py +# python3 ./apps/relax_examples/nn_module.py +# python3 ./apps/relax_examples/resnet.py diff --git a/tests/scripts/unity/task_python_relax_gpuonly.sh b/tests/scripts/unity/task_python_relax_gpuonly.sh new file mode 100755 index 000000000000..acbcce44f279 --- /dev/null +++ b/tests/scripts/unity/task_python_relax_gpuonly.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +export TVM_TEST_TARGETS="llvm;cuda" +export PYTEST_ADDOPTS="-m gpu $PYTEST_ADDOPTS" +export TVM_RELAY_TEST_TARGETS="cuda" +export TVM_INTEGRATION_TESTSUITE_NAME=python-integration-gpu +export TVM_INTEGRATION_GPU_ONLY=1 + +./tests/scripts/unity/task_python_relax.sh From fa561c816d40e3fe7e16028cbf93619bcf604eb3 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Sun, 5 Feb 2023 00:51:08 -0500 Subject: [PATCH 05/81] [Unity] Basic StructInfo Analysis and Expr construction (#13916) [Unity] Basic StructInfo Analysis and Expr construction. This PR adds struct info analysis and expr support. These are logics to construct the IR node and perform struct info related analysis. Testcases are added to cover the IR node construction and related struct info analysis checks. Co-authored-by: Tianqi Chen Co-authored-by: Altan Haan Co-authored-by: Andrew Liu Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Jiawei Liu Co-authored-by: Junru Shao Co-authored-by: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Co-authored-by: masahi Co-authored-by: Prakalp Srivastava Co-authored-by: Ruihang Lai Co-authored-by: Siyuan Feng Co-authored-by: Steven S. Co-authored-by: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Co-authored-by: Yixin Dong Co-authored-by: Yong Wu Co-authored-by: Ziheng Jiang --- CMakeLists.txt | 1 + include/tvm/ir/type.h | 3 +- include/tvm/relax/analysis.h | 252 ++++++ include/tvm/relax/expr.h | 43 +- include/tvm/relax/expr_functor.h | 415 ++++++++++ include/tvm/relax/struct_info.h | 7 +- include/tvm/relax/struct_info_functor.h | 151 ++++ python/tvm/ir/expr.py | 11 + python/tvm/relax/__init__.py | 48 ++ python/tvm/relax/analysis/__init__.py | 20 + python/tvm/relax/analysis/_ffi_api.py | 19 + python/tvm/relax/analysis/analysis.py | 135 ++++ python/tvm/relax/expr.py | 729 ++++++++++++++++++ python/tvm/relax/struct_info.py | 197 +++++ python/tvm/relax/ty.py | 75 ++ python/tvm/script/__init__.py | 1 + python/tvm/script/parser/relax/__init__.py | 21 + src/ir/function.cc | 14 +- src/ir/type.cc | 3 +- src/relax/analysis/shape_analysis.cc | 55 ++ src/relax/analysis/struct_info_analysis.cc | 716 +++++++++++++++++ src/relax/ir/expr.cc | 601 +++++++++++++++ src/relax/ir/expr_functor.cc | 546 +++++++++++++ src/relax/ir/struct_info.cc | 14 +- src/relax/ir/struct_info_functor.cc | 130 ++++ src/relax/ir/type.cc | 88 +++ .../test_analysis_struct_info_analysis.py | 418 ++++++++++ tests/python/relax/test_expr.py | 258 +++++++ tests/python/relax/test_struct_info.py | 241 ++++++ 29 files changed, 5198 insertions(+), 14 deletions(-) create mode 100644 include/tvm/relax/analysis.h create mode 100644 include/tvm/relax/expr_functor.h create mode 100644 include/tvm/relax/struct_info_functor.h create mode 100644 python/tvm/relax/analysis/__init__.py create mode 100644 python/tvm/relax/analysis/_ffi_api.py create mode 100644 python/tvm/relax/analysis/analysis.py create mode 100644 python/tvm/relax/expr.py create mode 100644 python/tvm/relax/struct_info.py create mode 100644 python/tvm/relax/ty.py create mode 100644 python/tvm/script/parser/relax/__init__.py create mode 100644 src/relax/analysis/shape_analysis.cc create mode 100644 src/relax/analysis/struct_info_analysis.cc create mode 100644 src/relax/ir/expr.cc create mode 100644 src/relax/ir/expr_functor.cc create mode 100644 src/relax/ir/struct_info_functor.cc create mode 100644 src/relax/ir/type.cc create mode 100644 tests/python/relax/test_analysis_struct_info_analysis.py create mode 100644 tests/python/relax/test_expr.py create mode 100644 tests/python/relax/test_struct_info.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 19f37d06f315..fa38ba6c6c8a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -290,6 +290,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/support/*.cc src/script/*.cc src/relax/ir/*.cc + src/relax/analysis/*.cc src/relax/backend/vm/*.cc ) diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index c6baf5e08be3..ec13635a2643 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -131,8 +131,9 @@ class PrimType : public Type { /*! * \brief Constructor * \param dtype The corresponding dtype. + * \param span The span */ - TVM_DLL explicit PrimType(runtime::DataType dtype); + TVM_DLL explicit PrimType(runtime::DataType dtype, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode); }; diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h new file mode 100644 index 000000000000..82145032f458 --- /dev/null +++ b/include/tvm/relax/analysis.h @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/analysis.h + * \brief The set of Relax specific analysis on IR. + */ +#ifndef TVM_RELAX_ANALYSIS_H_ +#define TVM_RELAX_ANALYSIS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { +//----------------------------------- +// Shape expression analysis +//---------------------------------- +/*! + * \brief Can prove the two symbolic shape arrays equals to each other. + * + * \param lhs The left operand. + * \param rhs The right operand. + * \param ana The analyzer used for integer analysis. + * \return The prove result. + * + * \note This function does best effort prove, which means + * if result is false, there is still possibility that + * two shapes equals to each other during runtime. + */ +TVM_DLL bool CanProveShapeEqual(const Array& lhs, const Array& rhs, + arith::Analyzer* ana); + +/*! + * \brief Can prove the two symbolic shape expressions equals to each other. + * + * \param lhs The left operand. + * \param rhs The right operand. + * \param ana The analyzer used for integer analysis. + * + * \note This function does best effort prove, which means + * if result is false, there is still possibility that + * two shapes equals to each other during runtime. + */ +TVM_DLL bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, arith::Analyzer* ana); + +//----------------------------------- +// Foundational StructInfo analysis +//----------------------------------- +/*! + * \brief Get the corresponding static type from a given struct info. + * \param info The struct info. + * \return the corresponding static type. + */ +TVM_DLL Type GetStaticType(const StructInfo& info); + +/*! + * \brief Get the corresponding struct info from static type. + * \param type The input type + * \return the corresponding struct info. + */ +TVM_DLL StructInfo StructInfoFromType(const Type& type); + +/*! + * \brief Erase the info to a corresponding more coarse grained + * struct info that is still well-defined(with all the vars in scope). + * + * When we are returning a StructInfo to another scope, + * it is important to remember that StructInfo may carry + * dependencies on var that is not defined the other scope. + * + * In such cases, it is important to call EraseToWellDefined to get + * another StructInfo that **only** contains the vars that are defined + * in the target scope. + * + * For example, consider the following function + * + * \code + * + * @R.function + * def f(x: R.Tensor[(n, m)]): + * k = tir.Var("k", "int64") + * v0 = opaque_fn(x) + * v1 = match_cast(v0, R.Tensor[(n, k)]) + * v2 : R.Tensor[(n + 1, k + 2)] = pad(v1) + * return v2 + * + * \endcode + * + * In the above code, the return value y have shape `(n + 1, k + 2)`, + * However, at the level of function signature, only n, m are defined, + * k is undefined here. + * + * When we call EraseToWellDefined(R.Tensor[(n + 1, k + 2)], fshape_var_map={n: n, m: m}), + * we will obtain R.Tensor(ndim=2), which is an erased info that does not depend + * on k(which is undefined from parameter signature). + * + * However, if we call EraseToWellDefined(R.Tensor[(n + 1, m)], fshape_var_map={n: n, m: m}), + * Then the return value will be R.Tensor[(n + 1, m)], because both n and m are defined. + * + * We can also make these var map to return a different expression. + * For example, EraseToWellDefined(R.Tensor[(n + 1, m)], fshape_var_map={n: 2, m: m}) + * will give us R.Tensor[(3, m)], where n get replaced by 2. + * + * Use this function in the following scenarios: + * - Decide the struct_info of expr with sub-scopes, such as If, SeqExpr + * - Decide the deduced return struct_info of a function that can be fully decided by params. + * + * \param info The struct info. + * \param f_shape_var_map callback function to specify + * whether a symbolic shape var is defined and the value it maps to, + * return nullopt if var is undefined. + * \param f_var_map callback function to specify + * whether a var is defined in the target scope and the value it maps to, + * return nullopt if var is undefined. + * \param ana Optional context analyzer to prove symbolic expression equality. + * + * \return the corresponding erased struct info. + */ +TVM_DLL StructInfo +EraseToWellDefined(const StructInfo& info, + std::function(const tir::Var& var)> f_shape_var_map = nullptr, + std::function(const Var& var)> f_var_map = nullptr, + arith::Analyzer* ana = nullptr); + +/*! + * \brief EraseToWellDefined variant with map. + * \param info The struct info. + * \param shape_var_map map to specify + * whether a symbolic shape var is defined and the value it maps to, + * return nullopt if var is undefined. + * \param var_map map to specify + * whether a var is defined in the target scope and the value it maps to, + * return nullopt if var is undefined. + * \param ana Optional context analyzer to prove symbolic expression equality. + * + * \return the corresponding erased struct info. + */ +TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, Map shape_var_map, + Map var_map, arith::Analyzer* ana = nullptr); + +/*! + * \brief Fine grained result of base check. + * + * This analysis comes with different levels of checking failures + * that can help to customize the compilation decisions. + * + * For a given pair of lhs_struct_info, rhs_struct_info. We adopt + * the following terminology: + * - LSet = {value | value matches lhs_struct_info} + * - RSet = {value | value matches rhs_struct_info} + * + * See the definition of each level below. + */ +enum class BaseCheckResult { + /*! + * \brief The two value sets have no intersection at all: Interset(LSet, RSet) = empty + */ + kFailL0 = 0, + /*! + * \brief LSet is not superset of RSet by only looking at static information. + * + * \note This level will trigger static type checking error when lhs is param and rhs is arg. + */ + kFailL1 = 1, + /*! + * \brief WLSet is not superset of RSet because of mismatch in value information. + * + * L1-level mismatches in params of FuncStructInfo is categorized as + * If lhs is FuncStructInfo, then L1-level mismatch in its params + * is categorized as L2-level mismatch for lhs. + * + * Design considerations for functions: + * - (a) We want to be able to erase type/value in function signature + * when we unify function struct info and preserve simpler representations. + * - (b) We automatically insert match_cast at function boundary, so + * we can erase (int)->int argument as (object)->int. + * The input shape/type mismatch will be detected by runtime checks at function boundary. + * This behavior is also consistent with the PackedFunc behavior. + * + * \note This level means there is no problem about static known information. + * It is OK for the checker to do best effort and return this value. + */ + kFailL2 = 2, + /*! \brief LSet is superset of RSet. */ + kPass = 3 +}; + +/*! + * \brief Run a base check to see if base subsumes derived. + * + * This function returns fine-grained base-check result on reasons of failure. + * + * \param base The base struct info. + * \param derived The derived struct info. + * \param ana Optional context analyzer to prove symbolic expression equality. + * \return Whether the relation holds. + * + * \sa BaseCheckResult + */ +TVM_DLL BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived, + arith::Analyzer* ana = nullptr); + +/*! + * \brief Check the relation of two struct info to see if one subsumes another one. + * + * \param base The base struct info. + * \param derived The derived struct info. + * \param ana Optional context analyzer to prove symbolic expression equality. + * \return Whether the relation holds. + */ +TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived, + arith::Analyzer* ana = nullptr); + +/*! + * \brief Unify the two struct info to their least common ancestor. + * + * \param lhs The left operand. + * \param rhs The right operand. + * \param ana Optional context analyzer to prove symbolic expression equality. + * \return The unified information. + */ +TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, + arith::Analyzer* ana = nullptr); +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ANALYSIS_H_ diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 8154b1dd86de..9e563c7061dc 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -23,7 +23,6 @@ #include #include #include -#include #include #include #include @@ -35,7 +34,47 @@ namespace relax { using Expr = RelayExpr; using ExprNode = RelayExprNode; -using relay::Id; +/*! + * \brief The unique identifier of variables. + * + * Id is like name to the variables, + * except that id is unique for each Var. + * + * \note Do not create Id directly, they are created in Var. + */ +class IdNode : public Object { + public: + /*! + * \brief The name of the variable, + * this only acts as a hint to the user, + * and is not used for equality. + */ + String name_hint; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); } + + bool SEqualReduce(const IdNode* other, SEqualReducer equal) const { + return equal.FreeVarEqualImpl(this, other); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.FreeVarHashImpl(this); } + + static constexpr const char* _type_key = "relax.Id"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object); +}; + +class Id : public ObjectRef { + public: + /*! + * \brief The constructor + * \param name_hint The name of the variable. + */ + TVM_DLL explicit Id(String name_hint); + + TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode); +}; /*! * \brief Base type of all structure information. diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h new file mode 100644 index 000000000000..5735e8661f6f --- /dev/null +++ b/include/tvm/relax/expr_functor.h @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/expr_functor.h + * \brief A more powerful visitor which enables defining arbitrary function + * signatures with type based dispatch on first argument. + */ +#ifndef TVM_RELAX_EXPR_FUNCTOR_H_ +#define TVM_RELAX_EXPR_FUNCTOR_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief A dynamical functor that dispatches on in the first Expr argument. + * You can use this as a more powerful Visitor, since it allows you to + * define function signatures of Visit Function. + * + * \sa tvm/ir_functor.h + * + * \tparam FType function signiture + * This type is only defined for FType with function signature R(const Expr&, + * Args...) + */ +template +class ExprFunctor; + +// functions to be overriden. +#define EXPR_FUNCTOR_DEFAULT \ + { return VisitExprDefault_(op, std::forward(args)...); } + +#define RELAX_EXPR_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.get()), std::forward(args)...); \ + }); + +#define PY_EXPR_VISITOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC) \ + { \ + if (PY_FUNC != nullptr) \ + PY_FUNC(N); \ + else \ + DEFAULT_FUNC; \ + } + +#define PY_EXPR_MUTATOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC, RET_TYPE) \ + { \ + if (PY_FUNC != nullptr) { \ + RET_TYPE ret = PY_FUNC(N); \ + return ret; \ + } else { \ + return DEFAULT_FUNC; \ + } \ + } + +#define PY_EXPR_VISITOR_DISPATCH(OP, PY_FUNC) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self) { \ + if (self->PY_FUNC != nullptr) \ + self->PY_FUNC(n); \ + else \ + self->VisitExpr_(static_cast(n.get())); \ + }); + +#define PY_EXPR_MUTATOR_DISPATCH(OP, PY_FUNC) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self) { \ + if (self->PY_FUNC != nullptr) { \ + Expr expr = self->PY_FUNC(n); \ + return expr; \ + } else { \ + return self->VisitExpr_(static_cast(n.get())); \ + } \ + }); + +#define PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(OP) \ + post_order_vtable.template set_dispatch([](const ObjectRef& n, TSelf* self) { \ + return self->VisitExprPostOrder_(static_cast(n.get())); \ + }); + +template +class ExprFunctor { + private: + using TSelf = ExprFunctor; + using FType = tvm::NodeFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~ExprFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward(args)...); } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitExpr(const Expr& n, Args... args) { + ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may " + "have generated invalid data."; + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + // NOTE: cross dialect calls are invoked through global var + // We do not expect inline PrimFunc to appear in relax IR. + virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const DataflowVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ShapeExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ExternFuncNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const SeqExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const PrimValueNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const DataTypeImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExprDefault_(const Object* op, Args...) { + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); + throw; + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + RELAX_EXPR_FUNCTOR_DISPATCH(ConstantNode); + RELAX_EXPR_FUNCTOR_DISPATCH(TupleNode); + RELAX_EXPR_FUNCTOR_DISPATCH(VarNode); + RELAX_EXPR_FUNCTOR_DISPATCH(DataflowVarNode); + RELAX_EXPR_FUNCTOR_DISPATCH(ShapeExprNode); + RELAX_EXPR_FUNCTOR_DISPATCH(ExternFuncNode); + RELAX_EXPR_FUNCTOR_DISPATCH(GlobalVarNode); + RELAX_EXPR_FUNCTOR_DISPATCH(FunctionNode); + RELAX_EXPR_FUNCTOR_DISPATCH(CallNode); + RELAX_EXPR_FUNCTOR_DISPATCH(SeqExprNode); + RELAX_EXPR_FUNCTOR_DISPATCH(IfNode); + RELAX_EXPR_FUNCTOR_DISPATCH(OpNode); + RELAX_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode); + RELAX_EXPR_FUNCTOR_DISPATCH(PrimValueNode); + RELAX_EXPR_FUNCTOR_DISPATCH(StringImmNode); + RELAX_EXPR_FUNCTOR_DISPATCH(DataTypeImmNode); + return vtable; + } +}; + +/*! + * \brief A simple visitor wrapper around ExprFunctor. + * Recursively visit the content. + */ +class ExprVisitor : public ExprFunctor { + public: + /*! + * \brief Generic dispatcher for Expr. + * \param expr The expr to be visited. + */ + void VisitExpr(const Expr& expr) override; + // specific leaf level visitor functions + void VisitExpr_(const ConstantNode* op) override; + void VisitExpr_(const TupleNode* op) override; + void VisitExpr_(const VarNode* op) override; + void VisitExpr_(const DataflowVarNode* op) override; + void VisitExpr_(const ShapeExprNode* op) override; + void VisitExpr_(const ExternFuncNode* op) override; + void VisitExpr_(const GlobalVarNode* op) override; + void VisitExpr_(const FunctionNode* op) override; + void VisitExpr_(const CallNode* op) override; + void VisitExpr_(const SeqExprNode* op) override; + void VisitExpr_(const IfNode* op) override; + void VisitExpr_(const OpNode* op) override; + void VisitExpr_(const TupleGetItemNode* op) override; + void VisitExpr_(const PrimValueNode* op) override; + void VisitExpr_(const StringImmNode* op) override; + void VisitExpr_(const DataTypeImmNode* op) override; + + /*! + * \brief Generic dispatcher for bindings. + * \param binding The binding to be visited. + */ + virtual void VisitBinding(const Binding& binding); + // specific leaf level visitor functions + virtual void VisitBinding_(const VarBindingNode* binding); + virtual void VisitBinding_(const MatchCastNode* binding); + // second level dispatching based on binding value type. + // these dispatching functions get called from first-level dispatch on VarBinding + virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const TupleNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const VarNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const ExternFuncNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const GlobalVarNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const CallNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const SeqExprNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const IfNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const OpNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const PrimValueNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const StringImmNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const DataTypeImmNode* val); + /*! + * \brief Generic dispatcher for binding blocks. + * \param block The binding block to be visited. + */ + virtual void VisitBindingBlock(const BindingBlock& block); + // specific leaf level visitor functions + virtual void VisitBindingBlock_(const BindingBlockNode* block); + virtual void VisitBindingBlock_(const DataflowBlockNode* block); + + /*! + * \brief Generic dispatcher for visiting the var definition site. + * \param var The var to be visited. + * \note VisitExpr_(const VarNode*) will only visit the usage site of an Var + */ + virtual void VisitVarDef(const Var& var); + + /*! + * \brief Visit struct_info may recursively contain Expr/PrimExpr. + * + * By default, this function recurse into struct info such as + * TensorStructInfo and ShapeStructInfo and call VisitExpr/VisitPrimExpr + * accordingly. It does not recurse into FunctionStructInfo as it does + * not contain Expr defined in the current scope. + * + * Pass writers can overload this function to change to other behaviors. + * For example, if we are not interested in Expr in StructInfo, we can + * override this function by a no-op. + * + * \param struct_info Input struct info field. + */ + virtual void VisitExprDepStructInfoField(const StructInfo& struct_info); + + // specific leaf level visitor functions + virtual void VisitVarDef_(const VarNode* var); + virtual void VisitVarDef_(const DataflowVarNode* var); + + virtual void VisitSpan(const Span& span); + virtual void VisitPrimExpr(const PrimExpr& expr); + + private: + using TSelf = ExprVisitor; + using VisitBindingVTable = + tvm::NodeFunctor; + // initialize the vtable. + static VisitBindingVTable InitVisitBindingVTable(); + /*! + * \brief Private internal struct info field visitor. + * + * Support default visiting of struct info field and recursive into + * their Expr fields. + * + * We use component instead of sub-classing so there can be other + * joint inheritance between ExprVisitor and StructInfoVisitor. + */ + class DefaultStructInfoFieldVisitor : public StructInfoVisitor { + public: + explicit DefaultStructInfoFieldVisitor(ExprVisitor* parent); + + // Override defaults in struct info visitor. + void VisitStructInfoExprField(const Expr& expr) final; + void VisitStructInfoExprField(const PrimExpr& expr) final; + void VisitStructInfo_(const FuncStructInfoNode* op) final; + + private: + ExprVisitor* parent_; + }; + // This visitor is not visible to child classes and only + // used to supportd default visiting behavior. + DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{this}; +}; + +void PostOrderVisit(const Expr& node, std::function fvisit); + +/*! + * \brief A mutator works in unnormalized form. + * + * ExprMutatorBase expects input AST to be in the unnormalized form, i.e., checked_type_ and shape_ + * of expressions can be nullptr, and the expressions may nest(and as a result the AST is not in + * ANF). + */ + +class ExprMutatorBase : public ExprFunctor { + public: + Expr VisitExpr(const Expr& expr) override; + Expr VisitExpr_(const ConstantNode* op) override; + Expr VisitExpr_(const TupleNode* op) override; + Expr VisitExpr_(const VarNode* op) override; + Expr VisitExpr_(const DataflowVarNode* op) override; + Expr VisitExpr_(const ShapeExprNode* op) override; + Expr VisitExpr_(const ExternFuncNode* op) override; + Expr VisitExpr_(const GlobalVarNode* op) override; + Expr VisitExpr_(const FunctionNode* op) override; + Expr VisitExpr_(const CallNode* op) override; + Expr VisitExpr_(const SeqExprNode* op) override; + Expr VisitExpr_(const IfNode* op) override; + Expr VisitExpr_(const OpNode* op) override; + Expr VisitExpr_(const TupleGetItemNode* op) override; + Expr VisitExpr_(const PrimValueNode* op) override; + Expr VisitExpr_(const StringImmNode* op) override; + Expr VisitExpr_(const DataTypeImmNode* op) override; + + /*! + * \brief Mutate BindingBlock. + * \param block The binding block to be visited. + * \return The binding block after transformation. + */ + virtual BindingBlock VisitBindingBlock(const BindingBlock& block); + + /*! + * \brief Used to visit the PrimExpr inside of expressions. + * + * Can be overloaded to transform the shape expressions. + */ + virtual PrimExpr VisitPrimExpr(const PrimExpr& expr); + + /*! + * \brief Visit struct_info that may recursively contain Expr/PrimExpr. + * + * By default, this function recurse into struct info such as + * TensorStructInfo and ShapeStructInfo and call VisitExpr/VisitPrimExpr + * accordingly. It does not recurse into FunctionStructInfo as it does + * not contain Expr defined in the current scope. + * + * Pass writers can overload this function to change to other behaviors. + * For example, if in Expr in StructInfo won't change, we can + * override this function by an identity function. + * + * \param struct_info Input struct info field. + * \return The updated struct info. + */ + virtual StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info); + + protected: + /*! + * \brief Check whether VisitExprDepStructInfoField change struct_info. + * \return Whether struct info changed. + * \note This function is used by mutator implementations to check if + * previous Expr update will trigger a change in struct_info. + * If change is detected, the implementation can generate a fresh + * node without struct_info, and trigger normalizer to re-derive. + */ + bool VisitAndCheckStructInfoFieldUnchanged(const ObjectRef& struct_info) { + if (const StructInfoNode* sinfo = struct_info.as()) { + return this->VisitExprDepStructInfoField(GetRef(sinfo)).same_as(struct_info); + } else { + return true; + } + } + + private: + /*! + * \brief Private internal struct info field visitor to support + * Default visiting of struct info field and recursive into their Expr fields. + * + * We use component instead of sub-classing so there can be other + * joint inheritance between ExprMutator and StructInfoMutator. + */ + class DefaultStructInfoFieldMutator : public StructInfoMutator { + public: + explicit DefaultStructInfoFieldMutator(ExprMutatorBase* parent); + + // Override defaults in struct info visitor. + Expr VisitStructInfoExprField(const Expr& expr) final; + PrimExpr VisitStructInfoExprField(const PrimExpr& expr) final; + StructInfo VisitStructInfo_(const FuncStructInfoNode* op) final; + + private: + ExprMutatorBase* parent_; + }; + // This visitor is not visible to child classes and only + // used to supportd default visiting behavior. + DefaultStructInfoFieldMutator default_struct_info_field_mutator_{this}; +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_EXPR_FUNCTOR_H_ diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h index d21c8db86b3f..f38a32f6bb83 100644 --- a/include/tvm/relax/struct_info.h +++ b/include/tvm/relax/struct_info.h @@ -22,13 +22,16 @@ #include #include #include -// #include #include #include namespace tvm { namespace relax { +// TODO(relax-team) replace with real BlockBuilder +// once it is ready. +using BlockBuilder = ObjectRef; + /*! * \brief Opaque object. */ @@ -257,8 +260,6 @@ class TupleStructInfo : public StructInfo { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleStructInfo, StructInfo, TupleStructInfoNode); }; -class BlockBuilder; - /*! * \brief custom-defined StructInfo derivation function. * \param call The call expression to be derived. diff --git a/include/tvm/relax/struct_info_functor.h b/include/tvm/relax/struct_info_functor.h new file mode 100644 index 000000000000..382b4ab2c936 --- /dev/null +++ b/include/tvm/relax/struct_info_functor.h @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/struct_info_functor.h + * \brief Functors and visitors for struct info. + */ +#ifndef TVM_RELAX_STRUCT_INFO_FUNCTOR_H_ +#define TVM_RELAX_STRUCT_INFO_FUNCTOR_H_ + +#include +#include + +#include + +namespace tvm { +namespace relax { + +template +class StructInfoFunctor; + +// functions to be overriden. +#define STRUCT_INFO_FUNCTOR_DEFAULT \ + { return VisitStructInfoDefault_(op, std::forward(args)...); } + +#define TVM_STRUCT_INFO_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitStructInfo_(static_cast(n.get()), std::forward(args)...); \ + }); + +template +class StructInfoFunctor { + private: + using TSelf = StructInfoFunctor; + using FStructInfo = tvm::NodeFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~StructInfoFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const StructInfo& n, Args... args) { + return VisitStructInfo(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitStructInfo(const StructInfo& n, Args... args) { + ICHECK(n.defined()); + static FStructInfo vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitStructInfo_(const ObjectStructInfoNode* op, + Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; + virtual R VisitStructInfo_(const PrimStructInfoNode* op, + Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; + virtual R VisitStructInfo_(const ShapeStructInfoNode* op, + Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; + virtual R VisitStructInfo_(const TensorStructInfoNode* op, + Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; + virtual R VisitStructInfo_(const TupleStructInfoNode* op, + Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; + virtual R VisitStructInfo_(const FuncStructInfoNode* op, + Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; + virtual R VisitStructInfoDefault_(const Object* op, Args...) { + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); + throw; // unreachable, written to stop compiler warning + } + + private: + // initialize the vtable. + static FStructInfo InitVTable() { + FStructInfo vtable; + // Set dispatch + TVM_STRUCT_INFO_FUNCTOR_DISPATCH(ObjectStructInfoNode); + TVM_STRUCT_INFO_FUNCTOR_DISPATCH(PrimStructInfoNode); + TVM_STRUCT_INFO_FUNCTOR_DISPATCH(ShapeStructInfoNode); + TVM_STRUCT_INFO_FUNCTOR_DISPATCH(TensorStructInfoNode); + TVM_STRUCT_INFO_FUNCTOR_DISPATCH(TupleStructInfoNode); + TVM_STRUCT_INFO_FUNCTOR_DISPATCH(FuncStructInfoNode); + return vtable; + } +}; + +#undef TVM_STRUCT_INFO_FUNCTOR_DISPATCH + +/*! + * \brief A struct info visitor. + */ +class TVM_DLL StructInfoVisitor : public StructInfoFunctor { + public: + void VisitStructInfo_(const ObjectStructInfoNode* op) override; + void VisitStructInfo_(const PrimStructInfoNode* op) override; + void VisitStructInfo_(const ShapeStructInfoNode* op) override; + void VisitStructInfo_(const TensorStructInfoNode* op) override; + void VisitStructInfo_(const TupleStructInfoNode* op) override; + void VisitStructInfo_(const FuncStructInfoNode* op) override; + + protected: + // two functions to override when visit expr fields in struct info. + virtual void VisitStructInfoExprField(const Expr& expr) {} + virtual void VisitStructInfoExprField(const PrimExpr& expr) {} +}; + +/*! + * \brief StructInfoMutator that mutates struct info. + */ +class TVM_DLL StructInfoMutator : public StructInfoFunctor { + public: + StructInfo VisitStructInfo_(const ObjectStructInfoNode* op) override; + StructInfo VisitStructInfo_(const PrimStructInfoNode* op) override; + StructInfo VisitStructInfo_(const ShapeStructInfoNode* op) override; + StructInfo VisitStructInfo_(const TensorStructInfoNode* op) override; + StructInfo VisitStructInfo_(const TupleStructInfoNode* op) override; + StructInfo VisitStructInfo_(const FuncStructInfoNode* op) override; + + protected: + // two functions to override when visit expr fields in struct info. + virtual Expr VisitStructInfoExprField(const Expr& expr) { return expr; } + virtual PrimExpr VisitStructInfoExprField(const PrimExpr& expr) { return expr; } +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_STRUCT_INFO_FUNCTOR_H_ diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 3c3fefb6d6c6..f90468de66c6 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -51,6 +51,17 @@ def checked_type(self): raise ValueError("The type checker has not populated" " the checked_type for this node") return ret + @property + def struct_info(self) -> "tvm.relax.StructInfo": + """Get the struct info field + + Returns + ------- + struct_info : tvm.relax.StructInfo + The struct info if available. + """ + return _ffi_api.ExprStructInfo(self) + @tvm._ffi.register_object("GlobalVar") class GlobalVar(RelayExpr): diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index c070fa479188..01310f6455dd 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -17,8 +17,56 @@ # pylint: disable=invalid-name, wrong-import-position """The Relax IR namespace containing the IR, type, operator, builder, vm, etc.""" from . import exec_builder +from . import expr +from . import ty +from . import analysis from . import vm +from . import struct_info + +# Expr +from .expr import ( + Expr, + Span, + SourceName, + Id, + GlobalVar, + Var, + DataflowVar, + Binding, + MatchCast, + VarBinding, + BindingBlock, + DataflowBlock, + SeqExpr, + ShapeExpr, + Tuple, + TupleGetItem, + Function, + ExternFunc, + Call, + If, + Constant, + PrimValue, + DataTypeImm, + StringImm, +) + +from .expr import const, extern, get_shape_of + +# Type +from .ty import Type, ObjectType, ShapeType, DynTensorType, TupleType, FuncType, PackedFuncType # VM from .exec_builder import ExecBuilder from .vm import VirtualMachine + +# StructInfo +from .struct_info import ( + StructInfo, + ObjectStructInfo, + PrimStructInfo, + ShapeStructInfo, + TensorStructInfo, + TupleStructInfo, + FuncStructInfo, +) diff --git a/python/tvm/relax/analysis/__init__.py b/python/tvm/relax/analysis/__init__.py new file mode 100644 index 000000000000..cc0089ff3134 --- /dev/null +++ b/python/tvm/relax/analysis/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax IR analysis. """ + +from .analysis import * diff --git a/python/tvm/relax/analysis/_ffi_api.py b/python/tvm/relax/analysis/_ffi_api.py new file mode 100644 index 000000000000..40ee05c3960d --- /dev/null +++ b/python/tvm/relax/analysis/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs""" +import tvm._ffi + +tvm._ffi._init_api("relax.analysis", __name__) diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py new file mode 100644 index 000000000000..301f3ecc7265 --- /dev/null +++ b/python/tvm/relax/analysis/analysis.py @@ -0,0 +1,135 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return +# pylint: disable=unidiomatic-typecheck +""" +This file contains the set of passes for Relax, which exposes an interface for +configuring the passes and scripting them in Python. +""" + +from typing import Dict +from enum import IntEnum + +from tvm import tir +from tvm.relax.ty import Type +from tvm.relax.struct_info import StructInfo +from tvm.relax.expr import Var, Expr +from . import _ffi_api + + +def get_static_type(sinfo: StructInfo) -> Type: + """Get the corresponding static type from a StructInfo. + + Parameters + ---------- + sinfo : StructInfo + The input struct info. + + Returns + ------- + ret : Type + The corresponding static type. + """ + return _ffi_api.GetStaticType(sinfo) # type: ignore + + +def erase_to_well_defined( + sinfo: StructInfo, + shape_var_map: Dict[tir.Var, tir.PrimExpr] = None, + var_map: Dict[Var, Expr] = None, +) -> StructInfo: + """Erase sinfo into a well defined form. + + This function removes the StructInfo's dependencies on shape and vars that + are not defined in given maps. + + Parameters + ---------- + sinfo : StructInfo + The input struct info. + + shape_var_map : Dict[tir.Var, tir.PrimExpr] + Specifies the defined shape vars and the values they should map to. + + var_map : Dict[Var, Expr] + Specifies the defined vars and the values they should map to. + + Returns + ------- + ret : StructInfo + The corresponding erased struct info. + """ + shape_var_map = {} if shape_var_map is None else shape_var_map + var_map = {} if var_map is None else var_map + + return _ffi_api.EraseToWellDefined(sinfo, shape_var_map, var_map) # type: ignore + + +class BaseCheckResult(IntEnum): + """Return result of fine-grained base check. + + Note + ---- + Base check comes with fine-grained fail levels. + + - FAIL_L0: The lhs and rhs have no intersection at all. + - FAIL_L1: We get the failure by looking at static information. + - FAIL_L2: We get the failure due to unknown symbolic variable relations. + """ + + FAIL_L0 = 0 + FAIL_L1 = 1 + FAIL_L2 = 2 + PASS = 3 + + +def struct_info_base_check(base: StructInfo, derived: StructInfo) -> BaseCheckResult: + """Run a base check to see if base subsumes derived. + + Parameters + ---------- + base: StructInfo + The base struct info. + + derived: StructInfo + The derived struct info. + + Returns + ------- + ret : StructInfo + The derived return value struct info. + """ + return _ffi_api.StructInfoBaseCheck(base, derived) # type: ignore + + +def struct_info_lca(lhs: StructInfo, rhs: StructInfo) -> StructInfo: + """Unify the two struct info to their least common ancestor. + + Parameters + ---------- + lhs: StructInfo + The left operand. + + rhs: StructInfo + The right operand. + + Returns + ------- + ret : StructInfo + The corresponding lca result. + """ + return _ffi_api.StructInfoLCA(lhs, rhs) # type: ignore diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py new file mode 100644 index 000000000000..138724ed0693 --- /dev/null +++ b/python/tvm/relax/expr.py @@ -0,0 +1,729 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-import, super-init-not-called +# pylint: disable=redefined-builtin +"""The expression nodes of Relax.""" +import typing +from numbers import Number +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as _np # type: ignore +import tvm +import tvm._ffi +import tvm.relax +import tvm.ir +from tvm import DataType +from tvm._ffi import base as _base +from tvm.runtime import ndarray as _nd, Object + +from ..ir import BaseFunc, Node, SourceName, Span +from ..runtime import String +from ..tir import PrimExpr +from . import _ffi_api + +# It is a workaround for mypy: https://github.com/python/mypy/issues/7866#issuecomment-549454370 +# This feature is not supported until python 3.10: +# https://docs.python.org/3.10/whatsnew/3.10.html#pep-613-typealias +Expr = Union[tvm.ir.RelayExpr] +Type = Union[tvm.ir.Type] +GlobalVar = Union[tvm.ir.GlobalVar] + + +@tvm._ffi.register_object("relax.Id") +class Id(Object): + """Unique identifier(name) used in Var. + Guaranteed to be stable across all passes. + """ + + def __init__(self): + raise RuntimeError("Cannot directly construct Id") + + +# NOTE: place base struct info in expr to avoid cyclic dep +# from expr to struct info. +class StructInfo(Node): + """The base class of all StructInfo. + + StructInfo contains both the static type + and runtime structural information. + """ + + def __eq__(self, other): + """Compare two struct info for structural equivalence.""" + return tvm.ir.structural_equal(self, other) + + def __ne__(self, other): + return not self.__eq__(other) + + def same_as(self, other): + """Overload with structural equality.""" + return super().__eq__(other) + + def is_base_of(self, derived: "StructInfo") -> bool: + """Check if self is base of another derived struct info. + + Parameters + ---------- + derived : StructInfo + The derived struct info to be checked. + + Returns + ------- + result : bool + The check result. + """ + return _ffi_api.StructInfoIsBaseOf(self, derived) # type: ignore + + +# will be registered afterwards in python/tvm/relax/op/init.py +_op_ffi_api = None + + +def _binary_op_helper(lhs: "ExprWithOp", rhs: "ExprWithOp", op: Callable) -> "ExprWithOp": + if not isinstance(lhs, Expr): # type: ignore + raise ValueError("lhs must be Expr") + if isinstance(rhs, Expr): # type: ignore + return op(lhs, rhs) + elif isinstance(rhs, Number): + raise TypeError(f"Please convert {rhs} with `const` first") + else: + raise TypeError(f"type {type(rhs)} not supported") + + +def _binary_rhs_helper(rhs: "ExprWithOp") -> "ExprWithOp": + if isinstance(rhs, Number): + raise TypeError(f"Please convert {rhs} with `const` first") + raise TypeError(f"type {type(rhs)} not supported") + + +class ExprWithOp(Expr): + """Basetype of all relax expressions that defines op overloading.""" + + def astype(self, dtype: Union[str, DataType]) -> "ExprWithOp": + """Cast the content type of the current data to dtype. + + Parameters + ---------- + dtype : str + The target data type. + + Note + ---- + This function only works for TensorType Exprs. + + Returns + ------- + result : ExprWithOp + The result expression. + """ + return _op_ffi_api.astype(self, dtype) # type: ignore + + def __neg__(self) -> "ExprWithOp": + raise ValueError("relax.negative is not supported yet.") + + def __lt__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.less) # type: ignore + + def __gt__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.greater) # type: ignore + + def __ge__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.greater_equal) # type: ignore + + def __le__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.less_equal) # type: ignore + + # NOTE: Cannot override __eq__ and __ne__, which will influence object equal + + def __add__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.add) # type: ignore + + def __radd__(self, other: Expr) -> "ExprWithOp": + return self.__add__(other) + + def __sub__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.subtract) # type: ignore + + def __rsub__(self, other: Expr) -> "ExprWithOp": + return _binary_rhs_helper(other) + + def __mul__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.multiply) # type: ignore + + def __rmul__(self, other: Expr) -> "ExprWithOp": + return self.__mul__(other) + + def __truediv__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.divide) # type: ignore + + def __rtruediv__(self, other: Expr) -> "ExprWithOp": + return _binary_rhs_helper(other) + + def __floordiv__(self, other: Expr) -> "ExprWithOp": + return _binary_op_helper(self, other, _op_ffi_api.floor_divide) # type: ignore + + def __rfloordiv__(self, other: Expr) -> "ExprWithOp": + return _binary_rhs_helper(other) + + def __mod__(self, other: Expr) -> "ExprWithOp": + # TODO(siyuan): Support it after mod operator is supported in relax + raise ValueError("relax.mod is not supported yet.") + + def __rmod__(self, other: Expr) -> "ExprWithOp": + return _binary_rhs_helper(other) + + def __call__(self, *args: List[Expr], attrs: Optional[Dict[str, Any]] = None) -> "ExprWithOp": + """Call the variable (if it represents a function). + + Parameters + ---------- + args: List[Expr] + The arguments to the call. + + attr: Optional[Dict[str, object]] + The additional attributes to the call. + + Returns + ------- + call: ExprWithOp + A call taking the variable as a function. + """ + return Call(self, args, attrs=attrs) + + def __getitem__(self, index: int) -> "ExprWithOp": + """Get the i-th element of the tuple or Expr with TupleType. + + Parameters + ---------- + index: int + The index of the element to be retrieved. + + Note + ---- + This function will be overridden by Tuple and ShapeExpr + + Returns + ------- + result: ExprWithOp + The result expression. + """ + return TupleGetItem(self, index) + + +@tvm._ffi.register_object("relax.expr.Call") +class Call(ExprWithOp): + """Function call node in Relax. + + Call node corresponds the operator application node + in computational graph terminology. + + Parameters + ---------- + op: tvm.ir.Op or any tvm.relax.Expr with function type. + The operation to be called. + + args: Union[List[Expr], typing.Tuple[Expr, ...]] + The arguments to the call. + + attrs: Optional[tvm.ir.Attrs] + Attributes to the call, can be None + + sinfo_args: Optional[Union[List[StructInfo], typing.Tuple[StructInfo, ...]]] + The structure info arguments of a CallNode. + sinfo_args is designed to be non-empty only for intrinsic op (e.g., + call_tir, call_builtin_with_ctx, etc.) and calls to ExternFuncs, with the main + usage of structure info inference. + + span: Optional[Span] + Span that points to original source code + """ + + def __init__( + self, + op: Union[Expr, tvm.ir.Op], + args: Union[List[Expr], typing.Tuple[Expr, ...]], + attrs: Optional[tvm.ir.Attrs] = None, + sinfo_args: Optional[Union[List[StructInfo], typing.Tuple[StructInfo, ...]]] = None, + span: Optional[Span] = None, + ): + if not sinfo_args: + sinfo_args = [] + self.__init_handle_by_constructor__( + _ffi_api.Call, op, args, attrs, sinfo_args, span # type: ignore + ) + + +@tvm._ffi.register_object("relax.expr.If") +class If(ExprWithOp): + """A conditional expression in Relax. + + Parameters + ---------- + cond: Expr + The condition. + + true_branch: Expr + The expression evaluated when condition is true. + + false_branch: Expr + The expression evaluated when condition is false. + """ + + def __init__(self, cond: Expr, true_branch: Expr, false_branch: Expr, span: Span = None): + self.__init_handle_by_constructor__( + _ffi_api.If, cond, true_branch, false_branch, span # type: ignore + ) + + +@tvm._ffi.register_object("relax.expr.Tuple") +class Tuple(ExprWithOp): + """Tuple expression that groups several fields together. + + Parameters + ---------- + fields : Union[List[Expr], typing.Tuple[Expr, ...]] + The fields in the tuple. + + span: Optional[Span] + Span that points to original source code + """ + + def __init__(self, fields: Union[List[Expr], typing.Tuple[Expr, ...]], span: Span = None): + self.__init_handle_by_constructor__(_ffi_api.Tuple, fields, span) # type: ignore + + def __getitem__(self, index: int) -> Expr: + if index >= len(self) or index < -len(self): + raise IndexError("Tuple index out of range") + return self.fields[index] + + def __len__(self) -> int: + return len(self.fields) + + +@tvm._ffi.register_object("relax.expr.TupleGetItem") +class TupleGetItem(ExprWithOp): + """Get index-th item from a tuple. + + Parameters + ---------- + tuple_value: Expr + The input tuple expression. + + index: int + The index. + """ + + def __init__(self, tuple_value: Expr, index: int): + self.__init_handle_by_constructor__( + _ffi_api.TupleGetItem, tuple_value, index # type: ignore + ) + + +@tvm._ffi.register_object("relax.expr.ShapeExpr") +class ShapeExpr(ExprWithOp): + """A shape expression which allows users to construct a shape containing PrimExpr.""" + + values: List[PrimExpr] + + def __init__( + self, + values: Union[List[PrimExpr], typing.Tuple[PrimExpr, ...], tvm.ir.Array], + span: Span = None, + ) -> None: + self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values, span) # type: ignore + + def __getitem__(self, index): + if index >= len(self) or index < -len(self): + raise IndexError("ShapeExpr index out of range") + return self.values[index] + + def __len__(self): + return len(self.values) + + +def make_shape(shape: Union[List[Any], typing.Tuple[Any, ...]]) -> ShapeExpr: + if isinstance(shape, (list, tuple)): + return ShapeExpr(shape) + raise ValueError("Wrong type") + + +@tvm._ffi.register_object("relax.expr.Constant") +class Constant(ExprWithOp): + def __init__(self, data: tvm.nd.NDArray, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.Constant, data, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.Var") +class Var(ExprWithOp): + """The variable class for all Relax bindings.""" + + vid: Id + struct_info: Optional[StructInfo] + + def __init__( + self, + name_hint: Union[str, Id], + struct_info: Optional[StructInfo] = None, + span: Span = None, + ) -> None: + if struct_info is not None: + struct_info = tvm.runtime.convert_to_object(struct_info) + if not isinstance(struct_info, StructInfo): + raise TypeError( + "struct_info needs to be an instance of StructInfo. " + "If you attempt to pass in shape, " + "use relax.TensorStructInfo(shape, dtype)." + ) + self.__init_handle_by_constructor__( + _ffi_api.Var if isinstance(name_hint, str) else _ffi_api.VarFromId, # type: ignore + name_hint, + struct_info, + span, + ) + + @property + def name_hint(self): + """Get name hint of the current var.""" + name = str(self.vid.name_hint) + return name + + +@tvm._ffi.register_object("relax.expr.DataflowVar") +class DataflowVar(Var): + """A sub-type of the variable node used to mark dataflow variables from + normal visible "function local" bindings.""" + + vid: Id + struct_info: Optional[StructInfo] + + def __init__( + self, + name_hint: Union[str, Id], + struct_info: Optional[StructInfo] = None, + span: Span = None, + ) -> None: + if struct_info is not None: + struct_info = tvm.runtime.convert_to_object(struct_info) + if not isinstance(struct_info, StructInfo): + raise TypeError( + "struct_info needs to be an instance of StructInfo. " + "If you attempt to pass in shape, " + "use relax.TensorStructInfo(shape, dtype)." + ) + + self.__init_handle_by_constructor__( + _ffi_api.DataflowVar # type: ignore + if isinstance(name_hint, str) + else _ffi_api.DataflowVarFromId, # type: ignore + name_hint, + struct_info, + span, + ) + + +@tvm._ffi.register_object("relax.expr.PrimValue") +class PrimValue(Expr): + """The prim expr representing the value.""" + + value: PrimExpr + + def __init__(self, value: Union[PrimExpr, int], span: Span = None) -> None: + if isinstance(value, int): + value = tvm.tir.IntImm("int64", value) + self.__init_handle_by_constructor__(_ffi_api.PrimValue, value, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.StringImm") +class StringImm(Expr): + """Represent a string literal constant.""" + + value: str + + def __init__(self, value: str, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.StringImm, value, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.DataTypeImm") +class DataTypeImm(Expr): + """Represent a data type constant.""" + + value: DataType + + def __init__(self, value: Union[DataType, str], span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.DataTypeImm, value, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.Binding") +class Binding(Node): + """The base class of a binding in Relax.""" + + ... + + +@tvm._ffi.register_object("relax.expr.MatchCast") +class MatchCast(Binding): + """Runtime-match the value to the struct info. + + This operation does runtime check, populates the un-defined symbolic shape vars + and vars in struct_info in the first occurrence, and insert equality assertions in + other cases. + + Parameters + ---------- + var: Var + The return variable that the match cast bind to. + + value: Expr + The input value expression. + + struct_info: tvm.relax.StructInfo + The struct info to match cast to. + """ + + var: Var + struct_info: "tvm.relax.StructInfo" + value: Expr + + def __init__( + self, var: Var, value: Expr, struct_info: "tvm.relax.StructInfo", span: Span = None + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.MatchCast, var, value, struct_info, span # type: ignore + ) + + +@tvm._ffi.register_object("relax.expr.VarBinding") +class VarBinding(Binding): + """Variable binding, bind he variable of the lhs with the rhs.""" + + var: Var + value: Expr + + def __init__(self, var: Var, value: Expr, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.VarBinding, var, value, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.BindingBlock") +class BindingBlock(Node): + """base class of binding block, bindings inside can be impure + (with side effect or control flow)""" + + bindings: List[Binding] + + def __init__(self, bindings: List[Binding], span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.BindingBlock, bindings, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.DataflowBlock") +class DataflowBlock(BindingBlock): + """dataflow block, bindings inside are pure (no side effect and no control flow)""" + + def __init__(self, bindings: List[Binding], span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.DataflowBlock, bindings, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.SeqExpr") +class SeqExpr(ExprWithOp): + """A sequence of binding blocks followed by an expression.""" + + blocks: List[BindingBlock] + body: Expr + + def __init__(self, blocks: List[BindingBlock], body: Expr, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.SeqExpr, blocks, body, span) # type: ignore + + +@tvm._ffi.register_object("relax.expr.Function") +class Function(BaseFunc): + """A Relax function.""" + + params: List[Var] + body: Expr + ret_struct_info: StructInfo + attrs: Optional[tvm.ir.DictAttrs] + + def __init__( + self, + params: List[Var], + body: Expr, + ret_struct_info: Optional[StructInfo] = None, + attrs: Optional[tvm.ir.DictAttrs] = None, + span: Optional[Span] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.Function, params, body, ret_struct_info, attrs, span # type: ignore + ) + + @staticmethod + def create_empty( + params: List[Var], + ret_struct_info: StructInfo, + attrs: Optional[tvm.ir.DictAttrs] = None, + span: Optional[Span] = None, + ): + """Construct a relax.Function but without body""" + return _ffi_api.FunctionCreateEmpty(params, ret_struct_info, attrs, span) # type: ignore + + def __call__(self, *args): + """Invoke the global function. + + Parameters + ---------- + args: List[relax.Expr] + Arguments. + """ + return Call(self, args, None, None) + + def script(self, show_meta: bool = False) -> str: + """Print relax.Function into TVMScript + + Parameters + ---------- + show_meta : bool + Whether to show meta information + + Returns + ------- + script : str + The TVM Script of the relax.Function + """ + return tvm._ffi.get_global_func("script.AsRelaxScript")(self, show_meta) # type: ignore + + def show(self, style: str = "light") -> None: + """ + A sugar for print highlighted TVM script. + + Parameters + ---------- + style : str, optional + Pygments styles extended by "light" (default) and "dark", by default "light" + """ + from tvm.script.highlight import cprint # pylint: disable=import-outside-toplevel + + # Use deferred import to avoid circular import while keeping cprint under tvm/script + cprint(self, style=style) + + +@tvm._ffi.register_object("relax.expr.ExternFunc") +class ExternFunc(BaseFunc): + """extern function, which can represent a TIR PrimFunc or a PackedFunc.""" + + global_symbol: String + + def __init__(self, global_symbol: String, span: Span = None) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ExternFunc, global_symbol, span # type: ignore + ) + + +def extern(name: str, span: Span = None): + """Create extern function.""" + return ExternFunc(name, span) + + +def const( + value: Union[bool, int, float, _np.ndarray, tvm.nd.NDArray], dtype: Optional[str] = None +) -> Constant: + """Create a constant value. + + Parameters + ---------- + value: Union[bool, int, float, numpy.ndarray, tvm.nd.NDArray] + The constant value. + + dtype: Optional[str] + The data type of the resulting constant. + + Note + ---- + When dtype is None, we use the following rule: + + - int maps to "int32" + - float maps to "float32" + - bool maps to "bool" + - other using the same default rule as numpy. + """ + if isinstance(value, (_base.numeric_types, (bool, list))): + value = _np.array(value, dtype=dtype) + + if not dtype: + # when dtype is None: int maps to "int32", float maps to "float32" + dtype = { # type: ignore + _np.dtype("int64"): _np.int32, # type: ignore + _np.dtype("float64"): _np.float32, # type: ignore + }.get( + value.dtype, None # type: ignore + ) + + if isinstance(value, (_np.ndarray, _np.generic)): + if dtype is not None: + value = value.astype(dtype) + value = _nd.array(value) + + if not isinstance(value, _nd.NDArray): + raise ValueError("value has to be scalar or NDArray") + + return Constant(value) + + +def te_tensor( + value: Expr, tir_var_map: Dict[tvm.tir.Var, tvm.tir.PrimExpr], name: str = "rxplaceholder" +): + """Create a TE tensor from relax expression, with TIR variables in the + tensor shape substituted by the given mapping + + Parameters + ---------- + value : Expr + The relax expression, which is required to have TensorStructInfo. + + tir_var_map : Dict[tvm.tir.Var, tvm.tir.PrimExpr] + The mapping to substitute the TIR variables appeared in the + shape of the input Expr. + + name : str + The name of the created tensor. + """ + return _ffi_api.TETensor(value, tir_var_map, name) # type: ignore + + +def get_shape_of(expr: Expr) -> Expr: + """Get shape of expr. + + Parameters + ---------- + expr: Expr + The input expr. + + Returns + ------- + shape: Expr + The shape expression + + Note + ---- + This function requires expr to be normalized. + The function will report an error if expr's StructInfo is not TensorStructInfo. + It will try to return symbolic function when possible. If the tensor do not + have a compile-time symbolic shape, the function will then choose to return + `Call(relax.op.shape_of, [expr])`. + """ + return _ffi_api.GetShapeOf(expr) # type: ignore + + +def _update_struct_info(expr: Expr, struct_info: Optional[StructInfo]) -> None: + _ffi_api.UpdateStructInfo(expr, struct_info) # type: ignore diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py new file mode 100644 index 000000000000..2ff027b22924 --- /dev/null +++ b/python/tvm/relax/struct_info.py @@ -0,0 +1,197 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-import +"""The struct info nodes of the Relax language.""" +from typing import List, Optional, Tuple, Union + +import tvm._ffi +import tvm + +from tvm.ir import Span, Node, EnvFunc, Array, Type +from tvm.tir import PrimExpr +from .expr import StructInfo, Var, Expr, ShapeExpr + +from . import _ffi_api, ty, expr + + +@tvm._ffi.register_object("relax.ObjectStructInfo") +class ObjectStructInfo(StructInfo): + """StructInfo of an Object.""" + + def __init__(self, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.ObjectStructInfo, span) # type: ignore + + +@tvm._ffi.register_object("relax.PrimStructInfo") +class PrimStructInfo(StructInfo): + """StructInfo of a primitive POD value. + + Parameters + ---------- + dtype : str + The data type of the prim value. + """ + + dtype: str + + def __init__(self, dtype: str, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.PrimStructInfo, dtype, span) # type: ignore + + +@tvm._ffi.register_object("relax.ShapeStructInfo") +class ShapeStructInfo(StructInfo): + """StructInfo of a shape value. + + Parameters + ---------- + values : Optional[List[PrimExpr]] + The symbolic shape values if known. + + ndim : Optional[int] + The size of the shape. + + Note + ---- + Do not specify values and ndim at the same time. + """ + + values: Optional[List[PrimExpr]] + ndim: int + span: Span + + def __init__( + self, values: Optional[List[PrimExpr]] = None, ndim: int = -1, span: Span = None + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ShapeStructInfo, values, ndim, span # type: ignore + ) + + +@tvm._ffi.register_object("relax.TensorStructInfo") +class TensorStructInfo(StructInfo): + """StructInfo of a Tensor value. + + Parameters + ---------- + shape : Optional[Expr] + The shape expression. + + dtype : Optional[str] + The content data type. + + ndim : Optional[int] + The number of dimensions of the tensor. + + Note + ---- + Do not specify shape and ndim at the same time. + """ + + shape: Optional[Expr] + dtype: str + ndim: int + span: Span + + def __init__( + self, + shape: Union[Optional[Expr], List[PrimExpr]] = None, + dtype: str = "float32", + ndim: int = -1, + span: Span = None, + ) -> None: + if isinstance(shape, (list, tuple, Array)): + shape = ShapeExpr(shape) + + self.__init_handle_by_constructor__( + _ffi_api.TensorStructInfo, shape, dtype, ndim, span # type: ignore + ) + + +@tvm._ffi.register_object("relax.TupleStructInfo") +class TupleStructInfo(StructInfo): + """StructInfo of a Tuple value. + + Parameters + ---------- + fields: List[StructInfo] + The struct info of the fields. + """ + + fields: List[StructInfo] + span: Span + + def __init__(self, fields: List[StructInfo], span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.TupleStructInfo, fields, span) # type: ignore + + +@tvm._ffi.register_object("relax.FuncStructInfo") +class FuncStructInfo(StructInfo): + """StructInfo of a function value. + + Parameters + ---------- + params: List[StructInfo] + The struct info of the fields. + + ret: StructInfo + The struct info of return value + """ + + params: Optional[List[StructInfo]] + ret: StructInfo + derive_func: Optional[EnvFunc] + span: Span + + def __init__(self, params: List[StructInfo], ret: StructInfo, span: Span = None) -> None: + self.__init_handle_by_constructor__( + _ffi_api.FuncStructInfo, params, ret, span # type: ignore + ) + + @staticmethod + def opaque_func( + *, + ret: Optional[StructInfo] = None, + derive_func: Optional[EnvFunc] = None, + span: Span = None, + ) -> "FuncStructInfo": + """ + Create an opaque FuncStructInfo. + + The opaque function takes either a ret + that specificies the struct info of the return value + or a derive_func that provides a customized derivation rule. + + Parameters + ---------- + ret: Optional[StructInfo] + The struct info of the the function return value. + + derive_func: Optional[EnvFunc] + The environment function used for derivation + + span: Optional[Span] + Optional span information of the ast. + + Returns + ------- + info: FuncStructInfo + + Note + ---- + We cannot specify ret and derive_func simultaneously. + """ + return _ffi_api.FuncStructInfoOpaqueFunc(ret, derive_func, span) # type: ignore diff --git a/python/tvm/relax/ty.py b/python/tvm/relax/ty.py new file mode 100644 index 000000000000..05492d6a9c34 --- /dev/null +++ b/python/tvm/relax/ty.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-import +"""The type nodes of the Relax language.""" +import tvm._ffi +from tvm.ir import Type, TensorType, TupleType, FuncType, Span + +from . import _ffi_api + + +@tvm._ffi.register_object("relax.ShapeType") +class ShapeType(Type): + """The type of shape in Relax. + + Parameters + ---------- + ndim : Optional[int] + The size of the shape. + """ + + # TODO(relax-team): consider make ndim mandatory + def __init__(self, ndim: int = -1, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.ShapeType, ndim, span) # type: ignore + + +@tvm._ffi.register_object("relax.ObjectType") +class ObjectType(Type): + """A type that corresponds to tvm::runtime::Object, is base of all possible object + values in TVM.""" + + def __init__(self, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.ObjectType, span) # type: ignore + + +@tvm._ffi.register_object("relax.DynTensorType") +class DynTensorType(Type): + """A dynamic tensor type in Relax. + + This is the type assigned to tensors with a known dtype and unknown shape. + + Parameters + ---------- + ndim : Optional[int] + The ndim of the Tensor + + dtype : Optional[str] + The content data type. + """ + + def __init__(self, ndim=-1, dtype="float32", span: Span = None) -> None: + self.__init_handle_by_constructor__( + _ffi_api.DynTensorType, ndim, dtype, span # type: ignore + ) + + +@tvm._ffi.register_object("relax.PackedFuncType") +class PackedFuncType(Type): + """The type of ExternFunc in Relax.""" + + def __init__(self, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.PackedFuncType, span) # type: ignore diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py index 9283727ad41a..6d92c68367b3 100644 --- a/python/tvm/script/__init__.py +++ b/python/tvm/script/__init__.py @@ -18,3 +18,4 @@ from .parser import ir, ir_module from .parser import parse as from_source from .parser import tir +from .parser import relax diff --git a/python/tvm/script/parser/relax/__init__.py b/python/tvm/script/parser/relax/__init__.py new file mode 100644 index 000000000000..feb8e683401c --- /dev/null +++ b/python/tvm/script/parser/relax/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Initial impl of relax parser for sugars""" +from tvm.relax import TensorStructInfo, ShapeStructInfo + +Tensor = TensorStructInfo +Shape = ShapeStructInfo diff --git a/src/ir/function.cc b/src/ir/function.cc index ce294708b2a9..69752f529a3c 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -22,6 +22,8 @@ * \brief The function data structure. */ #include +#include +#include #include #include @@ -35,13 +37,13 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> BaseFunc { if (func->IsInstance()) { return WithAttr(Downcast(std::move(func)), key, value); + } else if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } else if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } else { + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); } - if (const auto* f = runtime::Registry::Get("relay.ir.FuncWithAttr")) { - if (Optional ret = (*f)(func, key, value)) { - return ret.value(); - } - } - LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); }); } // namespace tvm diff --git a/src/ir/type.cc b/src/ir/type.cc index d965406e8bb0..b61a3df09107 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -25,9 +25,10 @@ #include namespace tvm { -PrimType::PrimType(runtime::DataType dtype) { +PrimType::PrimType(runtime::DataType dtype, Span span) { ObjectPtr n = make_object(); n->dtype = dtype; + n->span = std::move(span); data_ = std::move(n); } diff --git a/src/relax/analysis/shape_analysis.cc b/src/relax/analysis/shape_analysis.cc new file mode 100644 index 000000000000..70ce5ac06e90 --- /dev/null +++ b/src/relax/analysis/shape_analysis.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file shape_analysis.cc + * + * \brief Utilities for shape analysis. + */ + +#include +#include + +namespace tvm { +namespace relax { + +bool CanProveShapeEqual(const Array& lhs, const Array& rhs, + arith::Analyzer* ana) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); ++i) { + if (!ana->CanProveEqual(lhs[i], rhs[i])) return false; + } + return true; +} + +bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, arith::Analyzer* ana) { + if (lhs.same_as(rhs)) return true; + auto* lhs_shape = lhs.as(); + auto* rhs_shape = rhs.as(); + + if (lhs_shape && rhs_shape) { + return CanProveShapeEqual(lhs_shape->values, rhs_shape->values, ana); + } else { + return false; + } +} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc new file mode 100644 index 000000000000..d9b139753455 --- /dev/null +++ b/src/relax/analysis/struct_info_analysis.cc @@ -0,0 +1,716 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file struct_info_analysis.cc + * \brief Implementations of foundation struct info analysis + * + * \note Update this file when you added a new StructInfo. + */ +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +//-------------------------- +// GetStaticType +//-------------------------- +class StaticTypeDeriver : public StructInfoFunctor { + public: + Type VisitStructInfo_(const ObjectStructInfoNode* op) final { return ObjectType(op->span); } + + Type VisitStructInfo_(const PrimStructInfoNode* op) final { + return PrimType(op->dtype, op->span); + } + + Type VisitStructInfo_(const ShapeStructInfoNode* op) final { + return ShapeType(op->ndim, op->span); + } + + Type VisitStructInfo_(const TensorStructInfoNode* op) final { + return DynTensorType(op->ndim, op->dtype); + } + + Type VisitStructInfo_(const TupleStructInfoNode* op) final { + Array fields = + op->fields.Map([this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); + return TupleType(fields, op->span); + } + + Type VisitStructInfo_(const FuncStructInfoNode* op) final { + if (op->IsOpaque()) return PackedFuncType(op->span); + Array params = op->params.value().Map( + [this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); + Type ret = this->VisitStructInfo(op->ret); + return FuncType(params, ret, {}, {}, op->span); + } +}; + +Type GetStaticType(const StructInfo& info) { return StaticTypeDeriver()(info); } + +TVM_REGISTER_GLOBAL("relax.analysis.GetStaticType").set_body_typed([](const StructInfo& info) { + return GetStaticType(info); +}); + +//-------------------------- +// StructInfoFromType +//-------------------------- + +StructInfo StructInfoFromType(const Type& type) { + if (type.as()) { + return ObjectStructInfo(type->span); + } else if (const PrimTypeNode* prim_type = type.as()) { + return PrimStructInfo(prim_type->dtype, prim_type->span); + } else if (const ShapeTypeNode* shape_type = type.as()) { + return ShapeStructInfo(shape_type->ndim, type->span); + } else if (const DynTensorTypeNode* tensor_type = type.as()) { + return TensorStructInfo(tensor_type->dtype, tensor_type->ndim); + } else if (const TupleTypeNode* tuple_type = type.as()) { + Array fields; + for (const Type& field : tuple_type->fields) { + fields.push_back(StructInfoFromType(field)); + } + return TupleStructInfo(fields, type->span); + } else if (const FuncTypeNode* func_type = type.as()) { + Array params = + func_type->arg_types.Map([](const Type& param) { return StructInfoFromType(param); }); + StructInfo ret = StructInfoFromType(func_type->ret_type); + return FuncStructInfo(params, ret, func_type->span); + } else { + LOG(FATAL) << "Unsupported type: " << type; + return StructInfo(); + } +} + +//-------------------------- +// EraseToWellDefined +//-------------------------- +class WellDefinedEraser : public StructInfoMutator, + public ExprMutatorBase, + public tir::ExprMutator { + public: + WellDefinedEraser(std::function(const tir::Var& var)> f_shape_var_map, + std::function(const Var& var)> f_var_map, arith::Analyzer* ana) + : f_shape_var_map_(f_shape_var_map), f_var_map_(f_var_map), ana_(ana) {} + + StructInfo VisitStructInfo_(const ShapeStructInfoNode* op) final { + bool has_undefined = false; + Optional> values; + + if (op->values.defined()) { + std::swap(has_undefined_, has_undefined); + values = op->values.value().Map([&](PrimExpr val) { return this->VisitPrimExpr(val); }); + std::swap(has_undefined_, has_undefined); + } + // erase symbolic shape if we have undefined. + if (!has_undefined) { + if (values.same_as(op->values)) { + return GetRef(op); + } else { + return ShapeStructInfo(values.value(), op->span); + } + } else { + return ShapeStructInfo(op->ndim, op->span); + } + } + + StructInfo VisitStructInfo_(const TensorStructInfoNode* op) final { + bool has_undefined = false; + Optional shape; + + if (op->shape.defined()) { + std::swap(has_undefined_, has_undefined); + shape = relax::ExprMutatorBase::VisitExpr(op->shape.value()); + std::swap(has_undefined_, has_undefined); + } + + // erase symbolic shape if we have undefined. + if (!has_undefined) { + if (shape.same_as(op->shape)) { + return GetRef(op); + } else { + if (shape.defined()) { + return TensorStructInfo(shape.value(), op->dtype, op->span); + } else { + return TensorStructInfo(op->dtype, op->ndim, op->span); + } + } + } else { + return TensorStructInfo(op->dtype, op->ndim, op->span); + } + } + + StructInfo VisitStructInfo_(const FuncStructInfoNode* op) final { + // NOTE: we always require func struct info to be well-defined. + // + // All the occuring symbolic variables are defined in parameters' + // struct info annotations. So there is no needed to erase. + return GetRef(op); + } + + using relax::ExprMutatorBase::VisitExpr_; + using tir::ExprMutator::VisitExpr_; + + // connect things up + PrimExpr VisitPrimExpr(const PrimExpr& expr) { + // apply eager simplification + PrimExpr val = tir::ExprMutator::VisitExpr(expr); + if (!val.same_as(expr)) { + return ana_->Simplify(val); + } else { + return val; + } + } + + Expr VisitExpr_(const DataflowVarNode* var) final { + return VisitExpr_(static_cast(var)); + } + + Expr VisitExpr_(const VarNode* var) final { + Optional ret; + if (f_var_map_ != nullptr) { + ret = f_var_map_(GetRef(var)); + } + has_undefined_ = has_undefined_ || !ret.defined(); + if (ret.defined()) { + ICHECK(ret.as() || ret.as()) + << "Only allow Expr in StructInfo to be ShapeExpr or Var"; + } + return ret.value_or(GetRef(var)); + } + + PrimExpr VisitExpr_(const tir::VarNode* var) final { + Optional ret; + if (f_shape_var_map_ != nullptr) { + ret = f_shape_var_map_(GetRef(var)); + } + has_undefined_ = has_undefined_ || !ret.defined(); + + if (ret.defined()) { + PrimExpr value = ret.value(); + if (value->IsInstance()) { + return tvm::cast(DataType::Int(64), value); + } + ICHECK(value.dtype() == DataType::Int(64)) << "Can only provide i64 expressions in shape"; + return value; + } else { + return GetRef(var); + } + } + + private: + bool has_undefined_ = false; + std::function(const tir::Var& var)> f_shape_var_map_; + std::function(const Var& var)> f_var_map_; + arith::Analyzer* ana_; +}; + +StructInfo EraseToWellDefined( + const StructInfo& info, std::function(const tir::Var& var)> f_shape_var_map, + std::function(const Var& var)> f_var_map, arith::Analyzer* ana) { + if (ana == nullptr) { + arith::Analyzer inst; + return WellDefinedEraser(f_shape_var_map, f_var_map, &inst).VisitStructInfo(info); + } else { + return WellDefinedEraser(f_shape_var_map, f_var_map, ana).VisitStructInfo(info); + } +} + +StructInfo EraseToWellDefined(const StructInfo& info, Map shape_var_map, + Map var_map, arith::Analyzer* ana) { + std::function(const tir::Var& var)> f_shape_var_map = nullptr; + std::function(const Var& var)> f_var_map = nullptr; + + if (!shape_var_map.empty()) { + f_shape_var_map = [&](const tir::Var& var) -> Optional { + auto it = shape_var_map.find(var); + if (it != shape_var_map.end()) return (*it).second; + return NullOpt; + }; + } + + if (!var_map.empty()) { + f_var_map = [&](const Var& var) -> Optional { + auto it = var_map.find(var); + if (it != var_map.end()) return (*it).second; + return NullOpt; + }; + } + + return EraseToWellDefined(info, f_shape_var_map, f_var_map, ana); +} + +TVM_REGISTER_GLOBAL("relax.analysis.EraseToWellDefined") + .set_body_typed([](const StructInfo& info, Map shape_var_map, + Map var_map) { + return EraseToWellDefined(info, shape_var_map, var_map); + }); + +//-------------------------- +// IsBaseOf +//-------------------------- +class StructInfoBaseChecker + : public StructInfoFunctor { + public: + explicit StructInfoBaseChecker(arith::Analyzer* ana) : analyzer_(ana) {} + + BaseCheckResult VisitStructInfo(const StructInfo& lhs, const StructInfo& other) override { + // quick path + // Note: subclass may disable this quick path if we need to go over all struct info. + if (lhs.same_as(other)) return BaseCheckResult::kPass; + return StructInfoFunctor::VisitStructInfo(lhs, other); + } + + // Object is base of everything + BaseCheckResult VisitStructInfo_(const ObjectStructInfoNode* lhs, const StructInfo& other) final { + return BaseCheckResult::kPass; + } + + BaseCheckResult VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) { + if (other.as()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + return lhs->dtype == rhs->dtype ? BaseCheckResult::kPass : BaseCheckResult::kFailL0; + } + + BaseCheckResult VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) { + if (other.as()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + // lhs have unknown ndim + if (lhs->IsUnknownNdim()) return BaseCheckResult::kPass; + + // ndim must match + if (lhs->ndim != rhs->ndim) { + if (rhs->IsUnknownNdim()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + + // lhs does not have symbolic value + if (!lhs->values.defined()) return BaseCheckResult::kPass; + // rhs does not have symbolic value but lhs do. + if (!rhs->values.defined()) return BaseCheckResult::kFailL2; + + // shape match check + return ShapeMatchCheck(lhs->values.value(), rhs->values.value()); + } + + BaseCheckResult VisitStructInfo_(const TensorStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) { + if (other.as()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + // dtype mismatch + if (!lhs->IsUnknownDtype() && lhs->dtype != rhs->dtype) { + if (rhs->IsUnknownDtype()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + + // ndim msiamtch + if (!lhs->IsUnknownNdim() && lhs->ndim != rhs->ndim) { + if (rhs->IsUnknownNdim()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + // lhs does not have defined shape and everything else matches + if (!lhs->shape.defined()) return BaseCheckResult::kPass; + // rhs does not have symbolic value but lhs don't + if (!rhs->shape.defined()) return BaseCheckResult::kFailL2; + + // shape match check + return ShapeMatchCheck(lhs->shape.value(), rhs->shape.value()); + } + + BaseCheckResult VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) { + if (other.as()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + return ArrayCheck(lhs->fields, rhs->fields); + } + + BaseCheckResult VisitStructInfo_(const FuncStructInfoNode* lhs, + const StructInfo& other) override { + auto* rhs = other.as(); + if (rhs == nullptr) { + if (other.as()) return BaseCheckResult::kFailL1; + return BaseCheckResult::kFailL0; + } + + // lhs opaque handling + if (lhs->IsOpaque()) { + if (lhs->derive_func.defined()) { + // function proving is best effort. + return lhs->derive_func.same_as(rhs->derive_func) ? BaseCheckResult::kPass + : BaseCheckResult::kFailL2; + } + // no derivation function, only depends on ret + return this->VisitStructInfo(lhs->ret, rhs->ret); + } + + // Function check is best effort. + // rhs is opaque but lhs is not + if (rhs->IsOpaque()) return BaseCheckResult::kFailL2; + + // NOTE: lhs->params, rhs->params may contain different symbolic + // vars that needs to be re-mapped to each other. + // This can only be done through structural equality check and not ArrayCheck. + // + // So we check structural equality here and if two are structurally + // equal return true. + // + // otherwise we do best effort BaseArrayCheck. + // + // This still does not handle cases where some arguments are sub of another + // while other parameters needs to get remapped. + // + // Given we only do best effort checking in these cases, and such cases + // are likely not a primary concern atm, we take this approach here. + if (struct_equal_(GetRef(lhs), other)) return BaseCheckResult::kPass; + + auto param_check = FuncParamsCheck(lhs->params.value(), rhs->params.value()); + auto ret_check = this->VisitStructInfo(lhs->ret, rhs->ret); + return CombineCheck(param_check, ret_check); + } + + protected: + // analyzer + arith::Analyzer* analyzer_; + // struct equal checker + StructuralEqual struct_equal_; + + // customizable functions. + /*! + * \brief Check symbolic shape value equivalence. + * \param lhs The left hand shape. + * \param rhs The right hand shape. + * \return CheckResult. + */ + virtual BaseCheckResult PrimValueMatchCheck(const PrimExpr& lhs, const PrimExpr& rhs) { + // get static shape checking right. + auto* int_lhs = lhs.as(); + auto* int_rhs = rhs.as(); + if (int_lhs && int_rhs) { + if (int_lhs->value == int_rhs->value) { + return BaseCheckResult::kPass; + } else { + return BaseCheckResult::kFailL0; + } + } + return analyzer_->CanProveEqual(lhs, rhs) ? BaseCheckResult::kPass : BaseCheckResult::kFailL2; + } + /*! + * \brief CheckShape value. + * \param lhs The left hand shape. + * \param rhs The right hand shape. + * \return CheckResult. + */ + virtual BaseCheckResult ShapeMatchCheck(const Array& lhs, const Array& rhs) { + if (lhs.size() != rhs.size()) return BaseCheckResult::kFailL0; + + BaseCheckResult ret = BaseCheckResult::kPass; + for (size_t i = 0; i < lhs.size(); ++i) { + auto cmp_ret = PrimValueMatchCheck(lhs[i], rhs[i]); + if (ret == BaseCheckResult::kFailL0) return ret; + ret = CombineCheck(cmp_ret, ret); + } + return ret; + } + + /*! + * \brief CheckShape value. + * \param lhs The left hand shape. + * \param rhs The right hand shape. + * \return Check result. + */ + virtual BaseCheckResult ShapeMatchCheck(const Expr& lhs, const Expr& rhs) { + if (lhs.same_as(rhs)) return BaseCheckResult::kPass; + auto* lhs_shape = lhs.as(); + auto* rhs_shape = rhs.as(); + if (lhs_shape && rhs_shape) { + return ShapeMatchCheck(lhs_shape->values, rhs_shape->values); + } else { + return BaseCheckResult::kFailL2; + } + } + + /*! + * \brief CheckShape function parameters. + * \param lhs The left hand params. + * \param rhs The right hand params. + * \return Check result. + */ + virtual BaseCheckResult FuncParamsCheck(const Array& lhs, + const Array& rhs) { + auto res = ArrayCheck(lhs, rhs); + // treat L1 failures in params checking as L2. + if (res == BaseCheckResult::kFailL1) res = BaseCheckResult::kFailL2; + return res; + } + // helper functions + /*! + * \brief Combine check results. + * \param lhs The left operand. + * \param rhs The righr operand. + * \return The check result. + */ + static BaseCheckResult CombineCheck(BaseCheckResult lhs, BaseCheckResult rhs) { + if (lhs == BaseCheckResult::kFailL0 || rhs == BaseCheckResult::kFailL0) { + return BaseCheckResult::kFailL0; + } + if (lhs == BaseCheckResult::kFailL1 || rhs == BaseCheckResult::kFailL1) { + return BaseCheckResult::kFailL1; + } + if (lhs == BaseCheckResult::kFailL2 || rhs == BaseCheckResult::kFailL2) { + return BaseCheckResult::kFailL2; + } + return BaseCheckResult::kPass; + } + + /*! + * \brief Generic helper function to check arrays. + * \param lhs The left operand. + * \param rhs The right operand. + */ + BaseCheckResult ArrayCheck(const Array& lhs, const Array& rhs) { + if (lhs.size() != rhs.size()) return BaseCheckResult::kFailL0; + BaseCheckResult ret = BaseCheckResult::kPass; + + for (size_t i = 0; i < lhs.size(); ++i) { + auto cmp_ret = this->VisitStructInfo(lhs[i], rhs[i]); + if (ret == BaseCheckResult::kFailL0) return ret; + ret = CombineCheck(cmp_ret, ret); + } + return ret; + } +}; + +BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived, + arith::Analyzer* ana) { + if (ana == nullptr) { + arith::Analyzer inst; + return StructInfoBaseChecker(&inst)(base, derived); + } else { + return StructInfoBaseChecker(ana)(base, derived); + } +} + +TVM_REGISTER_GLOBAL("relax.analysis.StructInfoBaseCheck") + .set_body_typed([](const StructInfo& base, const StructInfo& derived) -> int { + return static_cast(StructInfoBaseCheck(base, derived)); + }); + +bool IsBaseOf(const StructInfo& base, const StructInfo& derived, arith::Analyzer* ana) { + return StructInfoBaseCheck(base, derived, ana) == BaseCheckResult::kPass; +} + +TVM_REGISTER_GLOBAL("relax.StructInfoIsBaseOf") + .set_body_typed([](const StructInfo& base, const StructInfo& derived) { + return IsBaseOf(base, derived); + }); + +//-------------------------- +// UnifyToLCA +//-------------------------- +class StructInfoLCAFinder + : public StructInfoFunctor { + public: + explicit StructInfoLCAFinder(arith::Analyzer* ana) : analyzer_(ana) {} + + StructInfo VisitStructInfo(const StructInfo& lhs, const StructInfo& other) final { + // quick path + if (lhs.same_as(other)) return lhs; + return StructInfoFunctor::VisitStructInfo(lhs, other); + } + + // Object is based of everything, unify to object. + StructInfo VisitStructInfo_(const ObjectStructInfoNode* lhs, const StructInfo& other) final { + return GetRef(lhs); + } + + StructInfo VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectStructInfo(lhs->span); + if (lhs->dtype == rhs->dtype) return GetRef(lhs); + // PrimType will be treated as their boxed(object) values + // as a result we can unify to object. + return ObjectStructInfo(lhs->span); + } + + StructInfo VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectStructInfo(lhs->span); + + int ndim = lhs->ndim == rhs->ndim ? lhs->ndim : kUnknownNDim; + if (lhs->ndim != rhs->ndim || !lhs->values.defined() || !rhs->values.defined() || + !CanProveShapeEqual(lhs->values.value(), rhs->values.value(), analyzer_)) { + // prefers return same when possible + if (!lhs->values.defined() && lhs->ndim == ndim) { + return GetRef(lhs); + } else { + return ShapeStructInfo(ndim, lhs->span); + } + } + // equals to each other + return GetRef(lhs); + } + + StructInfo VisitStructInfo_(const TensorStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectStructInfo(lhs->span); + + // find the target dtype and ndim. + DataType dtype = lhs->dtype == rhs->dtype ? lhs->dtype : DataType::Void(); + int ndim = lhs->ndim == rhs->ndim ? lhs->ndim : kUnknownNDim; + // if ndim mismatch or one side of shape is missing + // then we cannot keep in symbolic shape + if (lhs->ndim != rhs->ndim || !lhs->shape.defined() || !rhs->shape.defined() || + !CanProveShapeEqual(lhs->shape.value(), rhs->shape.value(), analyzer_)) { + // reuse lhs when possible + if (!lhs->shape.defined() && lhs->dtype == dtype && lhs->ndim == ndim) { + return GetRef(lhs); + } else { + return TensorStructInfo(dtype, ndim, lhs->span); + } + } + // symbolic shape match but dtype mismatch + if (lhs->dtype != dtype) { + return TensorStructInfo(lhs->shape.value(), dtype, lhs->span); + } else { + return GetRef(lhs); + } + } + + StructInfo VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectStructInfo(lhs->span); + Optional> fields = UnifyArray(lhs->fields, rhs->fields); + // tuple length not the same. + if (!fields.defined()) return ObjectStructInfo(lhs->span); + + // same length tuple. + if (!fields.same_as(lhs->fields)) { + return TupleStructInfo(fields.value(), lhs->span); + } else { + return GetRef(lhs); + } + } + + StructInfo VisitStructInfo_(const FuncStructInfoNode* lhs, const StructInfo& other) final { + auto* rhs = other.as(); + if (rhs == nullptr) return ObjectStructInfo(lhs->span); + + // lhs opaque handling + if (lhs->IsOpaque()) { + if (lhs->derive_func.defined()) { + if (lhs->derive_func.same_as(rhs->derive_func)) { + return GetRef(lhs); + } else { + // Create a new opaque with object return + return FuncStructInfo::OpaqueFunc(ObjectStructInfo(), lhs->span); + } + } else { + // no derivation function, only depends on ret + StructInfo ret = this->VisitStructInfo(lhs->ret, rhs->ret); + if (ret.same_as(lhs->ret)) return GetRef(lhs); + return FuncStructInfo::OpaqueFunc(ret, lhs->span); + } + } + // rhs is opaque, lhs is not + if (rhs->IsOpaque()) { + // unify ret value, note that rhs's ret is context free(because it is opaque) + // so result of the unify is also context-free. + StructInfo ret = this->VisitStructInfo(lhs->ret, rhs->ret); + return FuncStructInfo::OpaqueFunc(ret, lhs->span); + } + + // Both lhs and rhs are not opaque + // NOTE: lhs->params, rhs->params may contain different symbolic + // vars that needs to be re-mapped to each other. + // This can only be done through structural equality check. + // + // So we check structural equality here and if two are structurally + // equal return true. + // + // otherwise we do best effort of unify types without considering var remap. + // + // This still does not handle cases where some arguments are sub of another + // while other parameters needs to get remapped. + // + // Given we only do best effort checking in these cases, and such cases + // are likely not a primary concern atm, we take this approach here. + if (struct_equal_(GetRef(lhs), GetRef(rhs))) { + return GetRef(lhs); + } + + auto params = UnifyArray(lhs->params.value(), rhs->params.value()); + auto ret = this->VisitStructInfo(lhs->ret, rhs->ret); + + if (params.same_as(lhs->params) && ret.same_as(lhs->ret)) { + return GetRef(lhs); + } else { + // fail to unify the params + if (!params.defined()) { + return FuncStructInfo::OpaqueFunc(ret, lhs->span); + } else { + return FuncStructInfo(params.value(), ret, lhs->span); + } + } + } + + private: + // analyzer + arith::Analyzer* analyzer_; + // struct equal checker + StructuralEqual struct_equal_; + + // check arrays + Optional> UnifyArray(const Array& lhs, + const Array& rhs) { + if (lhs.same_as(rhs)) return lhs; + if (lhs.size() != rhs.size()) return NullOpt; + size_t index = 0; + return lhs.Map([&](const StructInfo& a) { return this->VisitStructInfo(a, rhs[index++]); }); + } +}; + +StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, arith::Analyzer* ana) { + if (ana == nullptr) { + arith::Analyzer inst; + return StructInfoLCAFinder(&inst)(lhs, rhs); + } else { + return StructInfoLCAFinder(ana)(lhs, rhs); + } +} + +TVM_REGISTER_GLOBAL("relax.analysis.StructInfoLCA") + .set_body_typed([](const StructInfo& lhs, const StructInfo& rhs) { + return StructInfoLCA(lhs, rhs); + }); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc new file mode 100644 index 000000000000..45868a488a36 --- /dev/null +++ b/src/relax/ir/expr.cc @@ -0,0 +1,601 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +using tvm::ReprPrinter; +using tvm::runtime::Optional; + +TVM_REGISTER_NODE_TYPE(IdNode); + +Id::Id(String name_hint) { + ObjectPtr n = make_object(); + n->name_hint = std::move(name_hint); + data_ = std::move(n); +} + +Call::Call(Expr op, Array args, Attrs attrs, Array sinfo_args, Span span) { + ObjectPtr n = make_object(); + n->op = std::move(op); + n->args = std::move(args); + n->attrs = std::move(attrs); + n->sinfo_args = std::move(sinfo_args); + n->span = std::move(span); + data_ = std::move(n); +} + +Call WithFields(Call call, Optional opt_op, Optional> opt_args, + Optional opt_attrs, Optional> opt_sinfo_args, + Optional opt_span) { + // Collect new values for fields. + Expr op = opt_op.value_or(call->op); + Array args = opt_args.value_or(call->args); + Attrs attrs = opt_attrs.value_or(call->attrs); + Array sinfo_args = opt_sinfo_args.value_or(call->sinfo_args); + Span span = opt_span.value_or(call->span); + + // Check if anything changed. + bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) && span.same_as(call->span); + if (unchanged) { + if (args.size() == call->args.size()) { + for (size_t i = 0; i < args.size(); i++) { + unchanged &= args[i].same_as(call->args[i]); + } + } else { + unchanged = false; + } + } + if (unchanged) { + if (sinfo_args.size() == call->sinfo_args.size()) { + for (size_t i = 0; i < sinfo_args.size(); i++) { + unchanged &= sinfo_args[i].same_as(call->sinfo_args[i]); + } + } else { + unchanged = false; + } + } + + if (!unchanged) { + // If call is only references, update it in place. Otherwise copy and update. + CallNode* cow_call_node = call.CopyOnWrite(); + cow_call_node->op = op; + cow_call_node->args = args; + cow_call_node->attrs = attrs; + cow_call_node->sinfo_args = sinfo_args; + cow_call_node->span = span; + } + return call; +} + +TVM_REGISTER_NODE_TYPE(CallNode); + +TVM_REGISTER_GLOBAL("relax.Call") + .set_body_typed([](Expr op, Array args, Attrs attrs, Array sinfo_args, + Span span) { return Call(op, args, attrs, sinfo_args, span); }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "CallNode(" << node->op << ", " << node->args << ", " << node->attrs << ", " + << node->sinfo_args << ")"; + }); + +If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { + ObjectPtr n = make_object(); + n->cond = std::move(cond); + n->true_branch = std::move(true_branch); + n->false_branch = std::move(false_branch); + n->span = std::move(span); + data_ = std::move(n); +} + +If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branch, + Optional opt_false_branch, Optional opt_span) { + Expr cond = opt_cond.value_or(if_expr->cond); + Expr true_branch = opt_true_branch.value_or(if_expr->true_branch); + Expr false_branch = opt_false_branch.value_or(if_expr->false_branch); + Span span = opt_span.value_or(if_expr->span); + + bool unchanged = cond.same_as(if_expr->cond) && true_branch.same_as(if_expr->true_branch) && + false_branch.same_as(if_expr->false_branch) && span.same_as(if_expr->span); + + if (!unchanged) { + IfNode* cow_if_node = if_expr.CopyOnWrite(); + cow_if_node->cond = cond; + cow_if_node->true_branch = true_branch; + cow_if_node->false_branch = false_branch; + cow_if_node->span = span; + } + return if_expr; +} + +TVM_REGISTER_NODE_TYPE(IfNode); + +TVM_REGISTER_GLOBAL("relax.If") + .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch, Span span) { + return If(cond, true_branch, false_branch, span); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "IfNode(" << node->cond << ", " << node->true_branch << ", " + << node->false_branch << ")"; + }); + +Tuple::Tuple(tvm::Array fields, Span span) { + ObjectPtr n = make_object(); + n->fields = std::move(fields); + n->span = std::move(span); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TupleNode); + +TVM_REGISTER_GLOBAL("relax.Tuple").set_body_typed([](tvm::Array fields, Span span) { + return Tuple(fields, span); +}); + +Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional opt_span) { + Array fields = opt_fields.value_or(tuple->fields); + Span span = opt_span.value_or(tuple->span); + + bool all_fields_unchanged = true; + if (fields.size() == tuple->fields.size()) { + for (size_t i = 0; i < fields.size(); i++) { + all_fields_unchanged &= fields[i].same_as(tuple->fields[i]); + } + } else { + all_fields_unchanged = false; + } + + all_fields_unchanged = all_fields_unchanged && span.same_as(tuple->span); + if (!all_fields_unchanged) { + TupleNode* cow_tuple_node = tuple.CopyOnWrite(); + cow_tuple_node->fields = fields; + cow_tuple_node->span = span; + } + return tuple; +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "Tuple(" << node->fields << ")"; + }); + +TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { + ObjectPtr n = make_object(); + n->tuple = std::move(tuple); + n->index = index; + n->span = std::move(span); + data_ = std::move(n); +} + +TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, + Optional opt_index, Optional opt_span) { + Expr tuple = opt_tuple.value_or(tuple_get_item->tuple); + Integer index = opt_index.value_or(tuple_get_item->index); + Span span = opt_span.value_or(tuple_get_item->span); + + bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) && + span.same_as(tuple_get_item->span); + if (!unchanged) { + TupleGetItemNode* cow_tuple_get_item_node = tuple_get_item.CopyOnWrite(); + cow_tuple_get_item_node->tuple = tuple; + cow_tuple_get_item_node->index = index.IntValue(); + cow_tuple_get_item_node->span = span; + } + return tuple_get_item; +} + +TVM_REGISTER_NODE_TYPE(TupleGetItemNode); + +TVM_REGISTER_GLOBAL("relax.TupleGetItem").set_body_typed([](Expr tuple, int index) { + return TupleGetItem(tuple, index); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; + }); + +TVM_REGISTER_NODE_TYPE(ShapeExprNode); + +ShapeExpr::ShapeExpr(Array values, Span span) { + ObjectPtr n = make_object(); + + n->values = values.Map([](PrimExpr value) { + if (value->IsInstance()) { + return tvm::cast(DataType::Int(64), value); + } + ICHECK(value.dtype() == DataType::Int(64)) + << "the value in ShapeStructInfo can only have dtype of int64"; + return value; + }); + n->span = span; + n->checked_type_ = ShapeType(values.size()); + n->struct_info_ = ShapeStructInfo(values, span); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.ShapeExpr").set_body_typed([](Array values, Span span) { + return ShapeExpr(values, span); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const ShapeExprNode* node = static_cast(ref.get()); + p->stream << "ShapeExpr("; + for (auto it = node->values.begin(); it != node->values.end(); it++) { + if (it != node->values.begin()) { + p->stream << ", "; + } + p->stream << *it; + } + p->stream << ")"; + }); + +TVM_REGISTER_NODE_TYPE(VarNode); + +Var::Var(Id vid, Optional struct_info_annotation, Span span) { + ObjectPtr n = make_object(); + n->vid = std::move(vid); + if (struct_info_annotation) { + n->checked_type_ = GetStaticType(struct_info_annotation.value()); + } + n->struct_info_ = std::move(struct_info_annotation); + n->span = std::move(span); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.Var") + .set_body_typed([](String name_hint, Optional struct_info_annotation, Span span) { + return Var(name_hint, struct_info_annotation, span); + }); + +TVM_REGISTER_GLOBAL("relax.VarFromId") + .set_body_typed([](Id vid, Optional struct_info_annotation, Span span) { + return Var(vid, struct_info_annotation, span); + }); + +TVM_REGISTER_NODE_TYPE(DataflowVarNode); + +DataflowVar::DataflowVar(Id vid, Optional struct_info_annotation, Span span) { + ObjectPtr n = make_object(); + n->vid = std::move(vid); + if (struct_info_annotation) { + n->checked_type_ = GetStaticType(struct_info_annotation.value()); + } + n->struct_info_ = std::move(struct_info_annotation); + n->span = std::move(span); + n->span = std::move(span); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.DataflowVar") + .set_body_typed([](String name_hint, Optional struct_info_annotation, Span span) { + return DataflowVar(name_hint, struct_info_annotation, span); + }); + +TVM_REGISTER_GLOBAL("relax.DataflowVarFromId") + .set_body_typed([](Id vid, Optional struct_info_annotation, Span span) { + return DataflowVar(vid, struct_info_annotation, span); + }); + +Constant::Constant(runtime::NDArray data, Span span) { + ObjectPtr n = make_object(); + n->data = std::move(data); + n->span = std::move(span); + + // set struct info. + Array values; + auto shape_tuple = n->data.Shape(); + for (size_t dim = 0; dim < shape_tuple.size(); ++dim) { + values.push_back(IntImm(DataType::Int(64), shape_tuple[dim])); + } + TensorStructInfo tinfo(ShapeExpr(values), n->data.DataType(), span); + + n->struct_info_ = tinfo; + n->checked_type_ = DynTensorType(tinfo->ndim, tinfo->dtype); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ConstantNode); + +TVM_REGISTER_GLOBAL("relax.Constant").set_body_typed([](runtime::NDArray data, Span span = Span()) { + return Constant(data, span); +}); + +PrimValue::PrimValue(PrimExpr value, Span span) { + ObjectPtr n = make_object(); + n->checked_type_ = PrimType(value.dtype()); + n->struct_info_ = PrimStructInfo(value.dtype()); + n->value = std::move(value); + n->span = std::move(span); + data_ = std::move(n); +} + +PrimValue PrimValue::Int64(int64_t value, Span span) { + return PrimValue(IntImm(DataType::Int(64), value), span); +} + +TVM_REGISTER_NODE_TYPE(PrimValueNode); + +TVM_REGISTER_GLOBAL("relax.PrimValue").set_body_typed([](PrimExpr value, Span span) { + return PrimValue(value, span); +}); + +StringImm::StringImm(String value, Span span) { + ObjectPtr n = make_object(); + n->value = std::move(value); + n->span = std::move(span); + // use the base structinfo for now + // we can choose to introduce more fine-grained struct info later if necessary. + n->checked_type_ = ObjectType(); + n->struct_info_ = ObjectStructInfo(); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(StringImmNode); + +TVM_REGISTER_GLOBAL("relax.StringImm").set_body_typed([](String value, Span span) { + return StringImm(value, span); +}); + +DataTypeImm::DataTypeImm(DataType value, Span span) { + ObjectPtr n = make_object(); + n->value = std::move(value); + n->span = std::move(span); + // use the base structinfo for now + // we can choose to introduce more fine-grained struct info later if necessary. + n->checked_type_ = ObjectType(); + n->struct_info_ = ObjectStructInfo(); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(DataTypeImmNode); + +TVM_REGISTER_GLOBAL("relax.DataTypeImm").set_body_typed([](DataType value, Span span) { + return DataTypeImm(value, span); +}); + +TVM_REGISTER_NODE_TYPE(MatchCastNode); + +MatchCast::MatchCast(Var var, Expr value, StructInfo struct_info, Span span) { + ObjectPtr n = make_object(); + ICHECK(var.defined()) << "MatchCast requires var to be defined"; + n->var = std::move(var); + n->value = std::move(value); + n->struct_info = std::move(struct_info); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.MatchCast") + .set_body_typed([](Var var, Expr value, StructInfo struct_info, Span span) { + return MatchCast(var, value, struct_info, span); + }); + +TVM_REGISTER_NODE_TYPE(VarBindingNode); + +VarBinding::VarBinding(Var var, Expr value, Span span) { + ObjectPtr n = make_object(); + n->var = std::move(var); + n->value = std::move(value); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.VarBinding").set_body_typed([](Var var, Expr value, Span span) { + return VarBinding(var, value, span); +}); + +TVM_REGISTER_NODE_TYPE(BindingBlockNode); + +BindingBlock::BindingBlock(Array bindings, Span span) { + ObjectPtr n = make_object(); + n->bindings = std::move(bindings); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.BindingBlock").set_body_typed([](Array bindings, Span span) { + return BindingBlock(bindings, span); +}); + +TVM_REGISTER_NODE_TYPE(DataflowBlockNode); + +DataflowBlock::DataflowBlock(Array bindings, Span span) { + ObjectPtr n = make_object(); + n->bindings = std::move(bindings); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.DataflowBlock").set_body_typed([](Array bindings, Span span) { + return DataflowBlock(bindings, span); +}); + +TVM_REGISTER_NODE_TYPE(SeqExprNode); + +SeqExpr::SeqExpr(Array blocks, Expr body, Span span) { + ObjectPtr n = make_object(); + n->blocks = std::move(blocks); + n->body = std::move(body); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.SeqExpr") + .set_body_typed([](Array blocks, Expr body, Span span) { + return SeqExpr(blocks, body, span); + }); + +TVM_REGISTER_NODE_TYPE(FunctionNode); + +Function::Function(Array params, Expr body, Optional ret_struct_info, + DictAttrs attrs, Span span) { + // Set the function type. + // For function, we take a conservative approach and require the function type + // to be known at construction time. + Array param_sinfo; + + for (const Var& param : params) { + CHECK(param->struct_info_.defined()) + << "relax.Function requires params to contain struct_info_"; + param_sinfo.push_back(GetStructInfo(param)); + } + + Optional body_sinfo; + + if (body->struct_info_.defined()) { + body_sinfo = GetStructInfo(body); + } + + if (ret_struct_info.defined()) { + // allow body to override ret if body is more fine-grained. + if (body_sinfo.defined()) { + if (IsBaseOf(ret_struct_info.value(), body_sinfo.value())) { + ret_struct_info = body_sinfo; + } + } + } else { + CHECK(body_sinfo.defined()) + << "Function do not have a return signature and body is not normalized"; + ret_struct_info = body_sinfo; + } + + FuncStructInfo func_sinfo(param_sinfo, ret_struct_info.value()); + + // set the fields + ObjectPtr n = make_object(); + n->params = std::move(params); + n->body = std::move(body); + n->ret_struct_info = std::move(ret_struct_info.value()); + n->checked_type_ = GetStaticType(func_sinfo); + n->struct_info_ = std::move(func_sinfo); + n->attrs = std::move(attrs); + n->span = std::move(span); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.Function") + .set_body_typed([](Array params, Expr body, Optional ret_struct_info, + DictAttrs attrs, + Span span) { return Function(params, body, ret_struct_info, attrs, span); }); + +Function Function::CreateEmpty(Array params, StructInfo ret_struct_info, DictAttrs attrs, + Span span) { + Array param_sinfo; + for (const Var& param : params) { + ICHECK(param->checked_type_.defined()) + << "relax.Function requires params to contain checked_type_."; + param_sinfo.push_back(GetStructInfo(param)); + } + FuncStructInfo finfo(param_sinfo, ret_struct_info); + + // set the fields + ObjectPtr n = make_object(); + n->params = std::move(params); + n->body = Expr(); + n->checked_type_ = GetStaticType(finfo); + n->struct_info_ = std::move(finfo); + n->ret_struct_info = std::move(ret_struct_info); + n->attrs = std::move(attrs); + n->span = std::move(span); + return Function(std::move(n)); +} + +TVM_REGISTER_GLOBAL("relax.FunctionCreateEmpty") + .set_body_typed([](Array params, StructInfo ret_struct_info, DictAttrs attrs, Span span) { + return Function::CreateEmpty(params, ret_struct_info, attrs, span); + }); + +// Special opaque derivation function for ExternFunc +// Take look at sinfo_args to figure out the return StructInfo. +TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_by_sinfo_args") + .set_body_typed([](const Call& call, const BlockBuilder& ctx) -> StructInfo { + ICHECK(call->sinfo_args.defined()) << "sinfo_args field of CallNode should always be defined"; + if (call->sinfo_args.empty()) { + return ObjectStructInfo(); + } else if (call->sinfo_args.size() == 1) { + return call->sinfo_args[0]; + } else { + return TupleStructInfo(call->sinfo_args); + } + }); + +// Get the derive function. +FuncStructInfo GetExternFuncStructInfo() { + EnvFunc fn = EnvFunc::Get("tvm.relax.struct_info.infer_by_sinfo_args"); + StructInfoDeriveFunc derive; + derive = fn; + return FuncStructInfo::OpaqueFunc(derive); +} + +TVM_REGISTER_NODE_TYPE(ExternFuncNode); + +ExternFunc::ExternFunc(String global_symbol, Span span) { + ObjectPtr n = make_object(); + n->global_symbol = std::move(global_symbol); + n->span = span; + static auto sinfo = GetExternFuncStructInfo(); + n->struct_info_ = sinfo; + n->checked_type_ = GetStaticType(sinfo); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.ExternFunc").set_body_typed([](String global_symbol, Span span) { + return ExternFunc(global_symbol, span); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const auto* node = static_cast(ref.get()); + p->stream << "ExternFunc(\"" << node->global_symbol << "\")"; + }); + +Expr GetShapeOf(const Expr& expr) { + // default case, to be normalized. + ICHECK(expr->struct_info_.defined()) << "GetShapeOf can only be applied to normalized expr"; + auto* tinfo = GetStructInfoAs(expr); + + ICHECK(tinfo != nullptr) << "ShapeOf can only be applied to expr with TensorStructInfo"; + if (tinfo->shape.defined()) return tinfo->shape.value(); + + static const Op& op = Op::Get("relax.shape_of"); + // default case, call shape of, eagerly normalize the expr. + relax::Call call_shape_of(op, {expr}, {}, {}); + UpdateStructInfo(call_shape_of, ShapeStructInfo(tinfo->ndim)); + return call_shape_of; +} + +TVM_REGISTER_GLOBAL("relax.GetShapeOf").set_body_typed([](const Expr& expr) { + return GetShapeOf(expr); +}); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc new file mode 100644 index 000000000000..048de7950f97 --- /dev/null +++ b/src/relax/ir/expr_functor.cc @@ -0,0 +1,546 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/expr_functor.cc + * \brief A wrapper around ExprFunctor which functionally updates the AST. + * + * ExprMutator uses memoization and self return in order to amortize + * the cost of using functional updates. + */ +#include +#include +#include +#include + +// functions to be overriden. +#define RELAX_VISIT_BINDING_DISPATCH(OP) \ + vtable.template set_dispatch( \ + [](const ObjectRef& n, TSelf* self, const VarBindingNode* binding) { \ + self->VisitBinding_(binding, static_cast(n.get())); \ + }); + +#define RELAX_VAR_BINDING_DISPATCH_IMPL(Type) \ + Type::VisitBindingVTable Type::InitVisitBindingVTable() { \ + VisitBindingVTable vtable; \ + RELAX_VISIT_BINDING_DISPATCH(ConstantNode); \ + RELAX_VISIT_BINDING_DISPATCH(TupleNode); \ + RELAX_VISIT_BINDING_DISPATCH(VarNode); \ + RELAX_VISIT_BINDING_DISPATCH(DataflowVarNode); \ + RELAX_VISIT_BINDING_DISPATCH(ShapeExprNode); \ + RELAX_VISIT_BINDING_DISPATCH(ExternFuncNode); \ + RELAX_VISIT_BINDING_DISPATCH(GlobalVarNode); \ + RELAX_VISIT_BINDING_DISPATCH(FunctionNode); \ + RELAX_VISIT_BINDING_DISPATCH(CallNode); \ + RELAX_VISIT_BINDING_DISPATCH(SeqExprNode); \ + RELAX_VISIT_BINDING_DISPATCH(IfNode); \ + RELAX_VISIT_BINDING_DISPATCH(OpNode); \ + RELAX_VISIT_BINDING_DISPATCH(TupleGetItemNode); \ + RELAX_VISIT_BINDING_DISPATCH(PrimValueNode); \ + RELAX_VISIT_BINDING_DISPATCH(StringImmNode); \ + RELAX_VISIT_BINDING_DISPATCH(DataTypeImmNode); \ + return vtable; \ + } \ + void Type::VisitBinding_(const VarBindingNode* binding) { \ + static VisitBindingVTable vtable = InitVisitBindingVTable(); \ + const Expr& value = binding->value; \ + ICHECK(value.defined()) << "Found null pointer node while traversing AST."; \ + ICHECK(vtable.can_dispatch(value)) \ + << "VisitVarBinding do not allow binding value type" << value->GetTypeKey(); \ + vtable(value, this, binding); \ + } + +// functions to be overriden. +#define RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(OP) \ + void ExprVisitor::VisitBinding_(const VarBindingNode* binding, const OP* value) { \ + this->VisitExpr(binding->value); \ + this->VisitVarDef(binding->var); \ + } + +// functions to be overriden. +#define RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(OP) \ + void ExprMutator::VisitBinding_(const VarBindingNode* binding, const OP* value) { \ + Expr new_value = this->VisitExpr(binding->value); \ + this->ReEmitBinding(binding, new_value); \ + } + +namespace tvm { +namespace relax { + +// ================== +// ExprVisitor + +void ExprVisitor::VisitExprDepStructInfoField(const StructInfo& struct_info) { + // recurse into struct info in case they depend on value + // under the current scope. + default_struct_info_field_visitor_.VisitStructInfo(struct_info); +} + +ExprVisitor::DefaultStructInfoFieldVisitor::DefaultStructInfoFieldVisitor(ExprVisitor* parent) + : parent_(parent) {} + +void ExprVisitor::DefaultStructInfoFieldVisitor::VisitStructInfoExprField(const Expr& expr) { + parent_->VisitExpr(expr); +} + +void ExprVisitor::DefaultStructInfoFieldVisitor::VisitStructInfoExprField(const PrimExpr& expr) { + parent_->VisitPrimExpr(expr); +} + +void ExprVisitor::DefaultStructInfoFieldVisitor::VisitStructInfo_(const FuncStructInfoNode* op) { + // Do not recurse into function struct info + // as they won't contain ref to values in current scope. +} + +void ExprVisitor::VisitExpr(const Expr& expr) { ExprFunctor::VisitExpr(expr); } + +void ExprVisitor::VisitExpr_(const ConstantNode* op) { + this->VisitSpan(op->span); + // Constant's StructInfo does not depend on Expr. +} + +void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { + this->VisitSpan(op->span); + // FuncStructInfo is not value-dep +} + +void ExprVisitor::VisitExpr_(const TupleNode* op) { + this->VisitSpan(op->span); + for (Expr field : op->fields) { + this->VisitExpr(field); + } + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +// Visit the use-site of a defined Var +void ExprVisitor::VisitExpr_(const VarNode* op) { + this->VisitSpan(op->span); + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +// Visit the use-site of a defined DataflowVar +void ExprVisitor::VisitExpr_(const DataflowVarNode* op) { + this->VisitSpan(op->span); + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +void ExprVisitor::VisitExpr_(const FunctionNode* op) { + this->VisitSpan(op->span); + for (Var param : op->params) { + this->VisitVarDef(param); + } + + this->VisitExpr(op->body); + // FuncStructInfo does not depend on Expr. +} + +void ExprVisitor::VisitExpr_(const CallNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->op); + + for (StructInfo sinfo_arg : op->sinfo_args) { + this->VisitExprDepStructInfoField(sinfo_arg); + } + + for (Expr arg : op->args) { + this->VisitExpr(arg); + } + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +void ExprVisitor::VisitExpr_(const IfNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->cond); + this->VisitExpr(op->true_branch); + this->VisitExpr(op->false_branch); + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +void ExprVisitor::VisitExpr_(const OpNode* op) { this->VisitSpan(op->span); } + +void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->tuple); + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +void ExprVisitor::VisitExpr_(const ShapeExprNode* op) { + for (PrimExpr val : op->values) { + this->VisitPrimExpr(val); + } + this->VisitSpan(op->span); + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +void ExprVisitor::VisitExpr_(const ExternFuncNode* op) { + this->VisitSpan(op->span); + // FuncStructInfo does not depend on Expr. +} + +void ExprVisitor::VisitExpr_(const SeqExprNode* op) { + this->VisitSpan(op->span); + for (BindingBlock block : op->blocks) { + this->VisitBindingBlock(block); + } + this->VisitExpr(op->body); + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} + +void ExprVisitor::VisitExpr_(const PrimValueNode* op) { + this->VisitPrimExpr(op->value); + this->VisitSpan(op->span); +} + +void ExprVisitor::VisitExpr_(const StringImmNode* op) { this->VisitSpan(op->span); } + +void ExprVisitor::VisitExpr_(const DataTypeImmNode* op) { this->VisitSpan(op->span); } + +void ExprVisitor::VisitSpan(const Span& span) {} + +void ExprVisitor::VisitPrimExpr(const PrimExpr& expr) {} + +// implementations of binding visitor dispatch +RELAX_VAR_BINDING_DISPATCH_IMPL(ExprVisitor); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(ConstantNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(TupleNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(VarNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(DataflowVarNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(ShapeExprNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(ExternFuncNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(GlobalVarNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(FunctionNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(CallNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(SeqExprNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(IfNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(OpNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(TupleGetItemNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(PrimValueNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(StringImmNode); +RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(DataTypeImmNode); + +void ExprVisitor::VisitBinding_(const MatchCastNode* binding) { + this->VisitExpr(binding->value); + this->VisitVarDef(binding->var); +} + +void ExprVisitor::VisitBindingBlock_(const BindingBlockNode* block) { + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } +} + +void ExprVisitor::VisitBindingBlock_(const DataflowBlockNode* block) { + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } +} + +void ExprVisitor::VisitVarDef_(const DataflowVarNode* var) { this->VisitSpan(var->span); } + +void ExprVisitor::VisitVarDef_(const VarNode* var) { this->VisitSpan(var->span); } + +void ExprVisitor::VisitBinding(const Binding& binding) { + if (const auto* node = binding.as()) { + VisitBinding_(node); + } else if (const auto* node = binding.as()) { + VisitBinding_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } +} + +void ExprVisitor::VisitBindingBlock(const BindingBlock& block) { + if (const auto* node = block.as()) { + VisitBindingBlock_(node); + } else if (const auto* node = block.as()) { + VisitBindingBlock_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } +} + +void ExprVisitor::VisitVarDef(const Var& var) { + if (const auto* node = var.as()) { + VisitVarDef_(node); + } else if (const auto* node = var.as()) { + VisitVarDef_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + } +} + +class ExprApplyVisit : public ExprVisitor { + public: + explicit ExprApplyVisit(std::function f) : f_(f) {} + + void VisitExpr(const Expr& e) final { + ExprVisitor::VisitExpr(e); + f_(e); + } + + private: + std::function f_; +}; + +void PostOrderVisit(const Expr& e, std::function fvisit) { + ExprApplyVisit(fvisit).VisitExpr(e); +} + +TVM_REGISTER_GLOBAL("relax.analysis.post_order_visit").set_body_typed([](Expr expr, PackedFunc f) { + PostOrderVisit(expr, [f](const Expr& n) { f(n); }); +}); + +// ================== +// ExprMutatorBase + +StructInfo ExprMutatorBase::VisitExprDepStructInfoField(const StructInfo& struct_info) { + // recurse into struct info in case they depend on value + // under the current scope. + return default_struct_info_field_mutator_.VisitStructInfo(struct_info); +} + +ExprMutatorBase::DefaultStructInfoFieldMutator::DefaultStructInfoFieldMutator( + ExprMutatorBase* parent) + : parent_(parent) {} + +Expr ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfoExprField(const Expr& expr) { + return parent_->VisitExpr(expr); +} + +PrimExpr ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfoExprField( + const PrimExpr& expr) { + return parent_->VisitPrimExpr(expr); +} + +StructInfo ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfo_( + const FuncStructInfoNode* op) { + // Do not recurse into function struct info + // as they won't contain ref to values in current scope. + return GetRef(op); +} + +Expr ExprMutatorBase::VisitExpr(const Expr& expr) { return ExprFunctor::VisitExpr(expr); } + +Expr ExprMutatorBase::VisitExpr_(const ConstantNode* op) { + // Constant' struct info won't be affected by Expr/PrimExpr change. + return GetRef(op); +} + +Expr ExprMutatorBase::VisitExpr_(const GlobalVarNode* op) { + // FuncStructInfo won't be affected by Expr/PrimExpr change. + return GetRef(op); +} + +Expr ExprMutatorBase::VisitExpr_(const TupleNode* op) { + bool unchanged = true; + tvm::Array fields; + for (Expr field : op->fields) { + Expr new_field = this->VisitExpr(field); + fields.push_back(new_field); + unchanged &= new_field.same_as(field); + } + + if (unchanged) { + // If tuple's struct info change it means that + // one of its fields' struct info will change + // so un-changed already implies that struct info won't change + return GetRef(op); + } else { + // when there is a change return a new tuple node + return Tuple(fields, op->span); + } +} + +// Visit the use-site of a defined Var +Expr ExprMutatorBase::VisitExpr_(const VarNode* op) { + // struct info of var-use should remain stable + // or the var itself will get replaced + return GetRef(op); +} + +// Visit the use-site of a defined DataflowVar +Expr ExprMutatorBase::VisitExpr_(const DataflowVarNode* op) { + // struct info of var-use should remain stable + // or the var itself will get replaced + return GetRef(op); +} + +Expr ExprMutatorBase::VisitExpr_(const FunctionNode* op) { + // struct info of function is not value dependent + // so no need to check struct_info field + Expr body = this->VisitExpr(op->body); + + if (body.same_as(op->body)) { + return GetRef(op); + } else { + return Function(op->params, body, op->ret_struct_info, op->attrs); + } +} + +Expr ExprMutatorBase::VisitExpr_(const CallNode* call_node) { + Expr new_op = this->VisitExpr(call_node->op); + bool unchanged = call_node->op.same_as(new_op); + + Array sinfo_args; + for (StructInfo sinfo_arg : call_node->sinfo_args) { + StructInfo new_sinfo_arg = this->VisitExprDepStructInfoField(sinfo_arg); + sinfo_args.push_back(new_sinfo_arg); + unchanged &= new_sinfo_arg.same_as(sinfo_arg); + } + + tvm::Array call_args; + for (Expr arg : call_node->args) { + Expr new_arg = this->VisitExpr(arg); + call_args.push_back(new_arg); + unchanged &= new_arg.same_as(arg); + } + + if (unchanged && VisitAndCheckStructInfoFieldUnchanged(call_node->struct_info_)) { + return GetRef(call_node); + } else { + return Call(new_op, call_args, call_node->attrs, sinfo_args, call_node->span); + } +} + +Expr ExprMutatorBase::VisitExpr_(const IfNode* op) { + Expr guard = this->VisitExpr(op->cond); + Expr true_b = this->VisitExpr(op->true_branch); + Expr false_b = this->VisitExpr(op->false_branch); + if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && + op->false_branch.same_as(false_b) && + VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { + return GetRef(op); + } else { + return If(guard, true_b, false_b, op->span); + } +} + +Expr ExprMutatorBase::VisitExpr_(const OpNode* op) { return GetRef(op); } + +Expr ExprMutatorBase::VisitExpr_(const TupleGetItemNode* op) { + auto t = this->VisitExpr(op->tuple); + if (op->tuple.same_as(t)) { + // struct info can be deterministically derived by tuple and index + // if t does not change, then struct info won't change. + return GetRef(op); + } else { + return TupleGetItem(t, op->index, op->span); + } +} + +Expr ExprMutatorBase::VisitExpr_(const PrimValueNode* op) { + auto value = this->VisitPrimExpr(op->value); + if (op->value.same_as(value)) { + // struct info can be deterministically derived by value + // if value does not change, then struct info won't change. + return GetRef(op); + } + return PrimValue(value, op->span); +} + +Expr ExprMutatorBase::VisitExpr_(const StringImmNode* op) { return GetRef(op); } + +Expr ExprMutatorBase::VisitExpr_(const DataTypeImmNode* op) { return GetRef(op); } + +Expr ExprMutatorBase::VisitExpr_(const ShapeExprNode* op) { + auto values = op->values.Map([this](const PrimExpr& e) { return this->VisitPrimExpr(e); }); + + if (values.same_as(op->values)) { + // If values does not change, struct info won't change. + return GetRef(op); + } else { + return ShapeExpr(values, op->span); + } +} + +Expr ExprMutatorBase::VisitExpr_(const ExternFuncNode* op) { + // StructInfo of function remains value independent. + return GetRef(op); +} + +Expr ExprMutatorBase::VisitExpr_(const SeqExprNode* op) { + bool all_blocks_unchanged = true; + Array blocks; + for (auto block : op->blocks) { + BindingBlock new_block = this->VisitBindingBlock(block); + if (!new_block->bindings.empty()) { + blocks.push_back(new_block); + } + all_blocks_unchanged &= block.same_as(new_block); + } + + Expr body = this->VisitExpr(op->body); + + if (all_blocks_unchanged && body.same_as(op->body) && + VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { + return GetRef(op); + } + return SeqExpr(blocks, body); +} + +BindingBlock ExprMutatorBase::VisitBindingBlock(const BindingBlock& block) { + Array bindings; + if (const auto* node = block.as()) { + for (auto binding : node->bindings) { + if (auto var_binding = binding.as()) { + Expr new_value = this->VisitExpr(var_binding->value); + bindings.push_back(VarBinding(var_binding->var, new_value)); + } else if (auto match_cast = binding.as()) { + Expr new_value = this->VisitExpr(match_cast->value); + bindings.push_back(MatchCast(match_cast->var, new_value, match_cast->struct_info)); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } + } + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + + if (block.as()) { + return DataflowBlock(bindings); + } else { + return BindingBlock(bindings); + } +} + +PrimExpr ExprMutatorBase::VisitPrimExpr(const PrimExpr& expr) { return expr; } + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc index 88046ed81f10..9db7cea6725d 100644 --- a/src/relax/ir/struct_info.cc +++ b/src/relax/ir/struct_info.cc @@ -21,7 +21,9 @@ * \file src/relax/ir/struct_info.cc * \brief Relax struct info. */ +#include #include +#include #include namespace tvm { @@ -228,7 +230,17 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Helper functions -// TODO(unity-team): add UpdateStructInfo once analysis.cc is upstreamed +void UpdateStructInfo(Expr expr, StructInfo struct_info) { + ICHECK(!expr->struct_info_.defined()) + << "the struct_info_ of the Expr to be updated must be nullptr for idempotency"; + expr->struct_info_ = struct_info; + // also set checked type + expr->checked_type_ = GetStaticType(struct_info); +} + +TVM_REGISTER_GLOBAL("relax.UpdateStructInfo").set_body_typed([](Expr expr, StructInfo struct_info) { + UpdateStructInfo(expr, struct_info); +}); TVM_REGISTER_GLOBAL("ir.ExprStructInfo").set_body_typed([](Expr expr) { return GetStructInfo(expr); diff --git a/src/relax/ir/struct_info_functor.cc b/src/relax/ir/struct_info_functor.cc new file mode 100644 index 000000000000..199491e3c63f --- /dev/null +++ b/src/relax/ir/struct_info_functor.cc @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file struct_info_functor.cc + * \brief Implementations of struct info functors. + */ +#include + +namespace tvm { +namespace relax { + +void StructInfoVisitor::VisitStructInfo_(const ObjectStructInfoNode* op) {} + +void StructInfoVisitor::VisitStructInfo_(const PrimStructInfoNode* op) {} + +void StructInfoVisitor::VisitStructInfo_(const ShapeStructInfoNode* op) { + if (op->values.defined()) { + for (PrimExpr value : op->values.value()) { + this->VisitStructInfoExprField(value); + } + } +} + +void StructInfoVisitor::VisitStructInfo_(const TensorStructInfoNode* op) { + if (op->shape.defined()) { + this->VisitStructInfoExprField(op->shape.value()); + } +} + +void StructInfoVisitor::VisitStructInfo_(const TupleStructInfoNode* op) { + for (StructInfo field : op->fields) { + this->VisitStructInfo(field); + } +} + +void StructInfoVisitor::VisitStructInfo_(const FuncStructInfoNode* op) { + if (op->params.defined()) { + for (StructInfo param : op->params.value()) { + this->VisitStructInfo(param); + } + } + this->VisitStructInfo(op->ret); +} + +StructInfo StructInfoMutator::VisitStructInfo_(const ObjectStructInfoNode* op) { + return GetRef(op); +} + +StructInfo StructInfoMutator::VisitStructInfo_(const PrimStructInfoNode* op) { + return GetRef(op); +} + +StructInfo StructInfoMutator::VisitStructInfo_(const ShapeStructInfoNode* op) { + Optional> values; + + if (op->values.defined()) { + // if no changes are made the original array will be returned. + values = op->values.value().Map( + [this](const PrimExpr& expr) { return this->VisitStructInfoExprField(expr); }); + } + + if (values.same_as(op->values)) { + return GetRef(op); + } else { + return ShapeStructInfo(values.value(), op->span); + } +} + +StructInfo StructInfoMutator::VisitStructInfo_(const TensorStructInfoNode* op) { + Optional shape; + + if (op->shape.defined()) { + shape = this->VisitStructInfoExprField(op->shape.value()); + } + + if (shape.same_as(op->shape)) { + return GetRef(op); + } else { + return TensorStructInfo(shape.value(), op->dtype, op->span); + } +} + +StructInfo StructInfoMutator::VisitStructInfo_(const TupleStructInfoNode* op) { + Array fields = + op->fields.Map([this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); + + if (fields.same_as(op->fields)) { + return GetRef(op); + } else { + return TupleStructInfo(fields, op->span); + } +} + +StructInfo StructInfoMutator::VisitStructInfo_(const FuncStructInfoNode* op) { + Optional> params; + + if (op->params.defined()) { + params = op->params.value().Map( + [this](const StructInfo& sinfo) { return this->VisitStructInfo(sinfo); }); + } + + StructInfo ret = this->VisitStructInfo(op->ret); + + if (params.same_as(op->params) && ret.same_as(op->ret)) { + return GetRef(op); + } else { + ICHECK(ret.defined()) << "FuncStructInfo that contains params must contain ret"; + return FuncStructInfo(params.value(), ret, op->span); + } +} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/type.cc b/src/relax/ir/type.cc new file mode 100644 index 000000000000..49ef1d7163f1 --- /dev/null +++ b/src/relax/ir/type.cc @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/type.cc + * \brief Relax type system. + */ +#include +#include + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(ShapeTypeNode); + +ShapeType::ShapeType(int ndim, Span span) { + ObjectPtr n = make_object(); + n->ndim = ndim; + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.ShapeType").set_body_typed([](int ndim, Span span) { + return ShapeType(ndim, span); +}); + +ObjectType::ObjectType(Span span) { + ObjectPtr n = make_object(); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ObjectTypeNode); + +TVM_REGISTER_GLOBAL("relax.ObjectType").set_body_typed([](Span span) { return ObjectType(span); }); + +DynTensorType::DynTensorType(int ndim, DataType dtype, Span span) { + ObjectPtr n = make_object(); + n->ndim = std::move(ndim); + n->dtype = std::move(dtype); + n->span = span; + data_ = std::move(n); +} + +DynTensorType DynTensorType::CreateUnknownNDim(DataType dtype, Span span) { + ObjectPtr n = make_object(); + n->ndim = -1; + n->dtype = std::move(dtype); + n->span = std::move(span); + return DynTensorType(std::move(n)); +} + +TVM_REGISTER_NODE_TYPE(DynTensorTypeNode); + +TVM_REGISTER_GLOBAL("relax.DynTensorType").set_body_typed([](int ndim, DataType dtype, Span span) { + return DynTensorType(ndim, dtype, span); +}); + +PackedFuncType::PackedFuncType(Span span) { + ObjectPtr n = make_object(); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(PackedFuncTypeNode); + +TVM_REGISTER_GLOBAL("relax.PackedFuncType").set_body_typed([](Span span) { + return PackedFuncType(span); +}); + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py new file mode 100644 index 000000000000..faf8fedcf4bf --- /dev/null +++ b/tests/python/relax/test_analysis_struct_info_analysis.py @@ -0,0 +1,418 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests analysis functions of struct info""" + +import pytest +import tvm +import tvm.testing +from tvm import relax as rx, TVMError +from tvm import tir + + +def test_get_static_type_basic(): + # object + s0 = rx.ObjectStructInfo() + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s0), rx.ObjectType()) + + # prim + s1 = rx.PrimStructInfo("float32") + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s1), tvm.ir.PrimType("float32")) + + +def test_get_static_type_shape(): + # shape + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s2 = rx.ShapeStructInfo([1, n + 1, m]) + s3 = rx.ShapeStructInfo(ndim=2) + + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s2), rx.ShapeType(ndim=3)) + + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s3), rx.ShapeType(ndim=2)) + + +def test_get_static_type_tensor(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + s4 = rx.TensorStructInfo([1, n + 1, m], "int64") + + tvm.ir.assert_structural_equal( + rx.analysis.get_static_type(s4), rx.DynTensorType(ndim=3, dtype="int64") + ) + + +def test_get_static_type_tuple(): + # tuple + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + s0 = rx.ObjectStructInfo() + s2 = rx.ShapeStructInfo([1, n + 1, m]) + s4 = rx.TensorStructInfo([1, n + 1, m], "int64") + t0 = rx.TupleStructInfo([s4, s0]) + t1 = rx.TupleStructInfo([t0, s2]) + + tvm.ir.assert_structural_equal( + rx.analysis.get_static_type(t1), + rx.TupleType( + [ + rx.TupleType([rx.DynTensorType(ndim=3, dtype="int64"), rx.ObjectType()]), + rx.ShapeType(ndim=3), + ] + ), + ) + + +def test_get_static_type_func(): + # tuple + def fn_info(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n], "float32") + return rx.FuncStructInfo([x, y], z) + + def fn_type(): + x = rx.DynTensorType(ndim=3, dtype="float32") + y = rx.DynTensorType(ndim=3, dtype="float32") + z = rx.DynTensorType(ndim=2, dtype="float32") + return rx.FuncType([x, y], z) + + f0 = fn_info(1) + + tvm.ir.assert_structural_equal(rx.analysis.get_static_type(fn_info(1)), fn_type()) + + +def test_erase_to_well_defined_basic(): + s0 = rx.ObjectStructInfo() + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s0), s0) + + # prim + s1 = rx.PrimStructInfo("float32") + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s1), s1) + + +def test_erase_to_well_defined_shape(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s2 = rx.ShapeStructInfo([1, n + 1, m]) + s3 = rx.ShapeStructInfo(ndim=2) + # have undefined + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s2), rx.ShapeStructInfo(ndim=3) + ) + # all defined + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s2, {n: n, m: m}), s2) + + # replacement + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s2, {n: 2, m: m + 1}), rx.ShapeStructInfo([1, 3, m + 1]) + ) + + # partial defined + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s2, {n: n}), rx.ShapeStructInfo(ndim=3) + ) + + +def test_erase_to_well_defined_tensor(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + rshape = rx.Var("shape", rx.ShapeStructInfo(ndim=2)) + s0 = rx.TensorStructInfo(rshape, dtype="int32") + + # undefined + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s0, None, None), + rx.TensorStructInfo(ndim=2, dtype="int32"), + ) + + # defined + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s0, None, {rshape: rshape}), s0 + ) + + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s0, None, {rshape: rx.ShapeExpr([1, 2])}), + rx.TensorStructInfo([1, 2], dtype="int32"), + ) + + s1 = rx.TensorStructInfo([m + 1, n], dtype="float32") + + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s1, {n: n, m: m}), s1) + + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s1, {n: 2, m: 3}), + rx.TensorStructInfo([4, 2], dtype="float32"), + ) + + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(s1, {m: m}, {rshape: rshape}), + rx.TensorStructInfo(ndim=2, dtype="float32"), + ) + + s2 = rx.TensorStructInfo([1, 2], dtype="float32") + + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(s2), s2) + + +def test_erase_to_well_defined_tuple(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + s0 = rx.ObjectStructInfo() + s2 = rx.ShapeStructInfo([1, m]) + s4 = rx.TensorStructInfo([1, n + 1, m], "int64") + t0 = rx.TupleStructInfo([s4, s0]) + t1 = rx.TupleStructInfo([t0, s2]) + + tvm.ir.assert_structural_equal( + rx.analysis.erase_to_well_defined(t1, {m: m + 1}), + rx.TupleStructInfo( + [ + rx.TupleStructInfo( + [rx.TensorStructInfo(ndim=3, dtype="int64"), rx.ObjectStructInfo()] + ), + rx.ShapeStructInfo([1, m + 1]), + ] + ), + ) + + +def test_erase_to_well_defined_func(): + def fn_info(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n], "float32") + return rx.FuncStructInfo([x, y], z) + + f0 = fn_info(1) + + tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(f0), f0) + + +def test_base_check(): + BR = rx.analysis.BaseCheckResult + bcheck = rx.analysis.struct_info_base_check + + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + obj0 = rx.ObjectStructInfo() + prim0 = rx.PrimStructInfo("int32") + prim1 = rx.PrimStructInfo("float32") + + shape0 = rx.ShapeStructInfo(ndim=-1) + shape1 = rx.ShapeStructInfo(ndim=2) + shape2 = rx.ShapeStructInfo(ndim=3) + shape3 = rx.ShapeStructInfo([1, 2, 3]) + shape4 = rx.ShapeStructInfo([1, n, 3]) + + tensor0 = rx.TensorStructInfo(ndim=-1, dtype="int32") + tensor1 = rx.TensorStructInfo(ndim=-1, dtype="float32") + tensor2 = rx.TensorStructInfo(ndim=2, dtype="int32") + tensor3 = rx.TensorStructInfo(ndim=2, dtype="float32") + tensor4 = rx.TensorStructInfo([n, m], "int32") + tensor5 = rx.TensorStructInfo([n, m, 1], "int32") + tensor6 = rx.TensorStructInfo([n, m, 2], "int32") + + # obj + assert bcheck(obj0, prim0) == BR.PASS + assert bcheck(obj0, shape1) == BR.PASS + assert bcheck(obj0, tensor2) == BR.PASS + assert obj0.is_base_of(tensor2) + + # prim + assert prim0.is_base_of(prim0) + assert not prim0.is_base_of(prim1) + assert bcheck(prim0, obj0) == BR.FAIL_L1 + assert bcheck(prim0, prim0) == BR.PASS + assert bcheck(prim0, prim1) == BR.FAIL_L0 + + # shape + assert bcheck(shape0, obj0) == BR.FAIL_L1 + assert bcheck(shape0, prim0) == BR.FAIL_L0 + + # unknown dim + assert bcheck(shape0, shape1) == BR.PASS + assert bcheck(shape1, shape0) == BR.FAIL_L1 + + # ndim mismatch + assert bcheck(shape1, shape2) == BR.FAIL_L0 + + # lhs do not have symbolic value but ndim match + assert bcheck(shape2, shape3) == BR.PASS + + # rhs do not symbolic but lhs do + assert bcheck(shape3, shape2) == BR.FAIL_L2 + + # shape mismatch + assert bcheck(shape3, shape4) == BR.FAIL_L2 + assert shape4.is_base_of(rx.ShapeStructInfo([1, n, 3])) + + # tensor + assert bcheck(tensor0, obj0) == BR.FAIL_L1 + assert bcheck(tensor0, prim0) == BR.FAIL_L0 + assert bcheck(tensor0, shape0) == BR.FAIL_L0 + + # dtype mismatch + assert bcheck(tensor0, tensor1) == BR.FAIL_L0 + assert bcheck(tensor0, tensor3) == BR.FAIL_L0 + assert bcheck(tensor3, tensor4) == BR.FAIL_L0 + assert bcheck(tensor1, tensor2) == BR.FAIL_L0 + + # ndim mismatch + assert bcheck(tensor2, tensor5) == BR.FAIL_L0 + + # static shape mismatch + assert bcheck(tensor5, tensor6) == BR.FAIL_L0 + + # match + assert tensor0.is_base_of(rx.TensorStructInfo(ndim=-1, dtype="int32")) + assert tensor0.is_base_of(tensor2) + assert tensor0.is_base_of(tensor4) + assert tensor0.is_base_of(tensor5) + assert tensor0.is_base_of(tensor6) + assert tensor2.is_base_of(tensor4) + assert tensor4.is_base_of(rx.TensorStructInfo([n, m], dtype="int32")) + + # tuple + t0 = rx.TupleStructInfo([obj0, tensor0]) + t1 = rx.TupleStructInfo([prim0, tensor4]) + t2 = rx.TupleStructInfo([obj0, tensor0, obj0]) + t3 = rx.TupleStructInfo([tensor0, obj0]) + + assert t0.is_base_of(t1) + + assert bcheck(t0, t2) == BR.FAIL_L0 + assert bcheck(t0, t3) == BR.FAIL_L1 + + assert rx.TupleStructInfo([t0, t1]).is_base_of(rx.TupleStructInfo([t1, t1])) + assert bcheck(rx.TupleStructInfo([t0, t1]), rx.TupleStructInfo([t1, t0])) == BR.FAIL_L1 + + def fn_info_shape(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n], "float32") + return rx.FuncStructInfo([x, y], z) + + def fn_info_erased(): + x = rx.TensorStructInfo(ndim=3, dtype="float32") + y = rx.TensorStructInfo(ndim=3, dtype="float32") + z = rx.TensorStructInfo(ndim=2, dtype="float32") + return rx.FuncStructInfo([x, y], z) + + assert fn_info_shape(1).is_base_of(fn_info_shape(1)) + assert fn_info_erased().is_base_of(fn_info_shape(1)) + assert bcheck(fn_info_shape(1), fn_info_erased()) == BR.FAIL_L2 + + fopaque = rx.FuncStructInfo.opaque_func() + assert fopaque.is_base_of(fn_info_shape(1)) + + +def _check_lca(lhs, rhs, target): + tvm.ir.assert_structural_equal(rx.analysis.struct_info_lca(lhs, rhs), target) + tvm.ir.assert_structural_equal(rx.analysis.struct_info_lca(rhs, lhs), target) + + +def test_struct_info_lca(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + obj0 = rx.ObjectStructInfo() + prim0 = rx.PrimStructInfo("int32") + prim1 = rx.PrimStructInfo("float32") + + shape0 = rx.ShapeStructInfo(ndim=-1) + shape1 = rx.ShapeStructInfo(ndim=2) + shape2 = rx.ShapeStructInfo(ndim=3) + shape3 = rx.ShapeStructInfo([1, 2, 3]) + shape4 = rx.ShapeStructInfo([1, n, 3]) + + tensor0 = rx.TensorStructInfo(ndim=-1, dtype="int32") + tensor1 = rx.TensorStructInfo(ndim=-1, dtype="float32") + tensor2 = rx.TensorStructInfo(ndim=2, dtype="int32") + tensor3 = rx.TensorStructInfo(ndim=2, dtype="float32") + tensor4 = rx.TensorStructInfo([n, m], "int32") + tensor5 = rx.TensorStructInfo([n, m, 1], "int32") + tensor6 = rx.TensorStructInfo([n, m, 2], "int32") + + # obj + _check_lca(obj0, prim0, obj0) + _check_lca(obj0, prim1, obj0) + + # shape + _check_lca(shape0, tensor0, obj0) + _check_lca(shape0, shape1, shape0) + _check_lca(shape1, shape2, shape0) + _check_lca(shape1, shape3, shape0) + + _check_lca(shape2, shape3, shape2) + _check_lca(shape3, shape4, shape2) + _check_lca(shape4, rx.ShapeStructInfo([1, n, 3]), shape4) + + # tensor + _check_lca(tensor0, prim0, obj0) + _check_lca(tensor0, tensor1, rx.TensorStructInfo(ndim=-1, dtype=None)) + _check_lca(tensor0, tensor2, tensor0) + _check_lca(tensor0, tensor4, tensor0) + + _check_lca(tensor2, tensor4, tensor2) + _check_lca(tensor5, tensor6, rx.TensorStructInfo(ndim=3, dtype="int32")) + _check_lca(tensor4, tensor5, rx.TensorStructInfo(ndim=-1, dtype="int32")) + _check_lca(tensor4, rx.TensorStructInfo([n, m], dtype="int32"), tensor4) + + # tuple + t0 = rx.TupleStructInfo([obj0, tensor0]) + t1 = rx.TupleStructInfo([prim0, tensor4]) + t2 = rx.TupleStructInfo([obj0, tensor0, obj0]) + t3 = rx.TupleStructInfo([tensor0, obj0]) + + _check_lca(t0, t1, t0) + _check_lca(t0, t2, obj0) + _check_lca(t0, t3, rx.TupleStructInfo([obj0, obj0])) + + t5 = rx.TupleStructInfo([t0, t1]) + t6 = rx.TupleStructInfo([t1, t2]) + + _check_lca(t5, t6, rx.TupleStructInfo([t0, obj0])) + + t7 = rx.TupleStructInfo([]) + _check_lca(t7, rx.TupleStructInfo([]), t7) + + def fn_info_shape(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n], "float32") + return rx.FuncStructInfo([x, y], z) + + def fn_info_erased(): + x = rx.TensorStructInfo(ndim=3, dtype="float32") + y = rx.TensorStructInfo(ndim=3, dtype="float32") + z = rx.TensorStructInfo(ndim=2, dtype="float32") + return rx.FuncStructInfo([x, y], z) + + fopaque0 = lambda: rx.FuncStructInfo.opaque_func() + fopaque1 = lambda: rx.FuncStructInfo.opaque_func(ret=prim0) + fopaque2 = lambda: rx.FuncStructInfo.opaque_func( + ret=rx.TensorStructInfo(ndim=2, dtype="float32") + ) + + _check_lca(fn_info_shape(1), fn_info_shape(2), fn_info_erased()) + _check_lca(fn_info_shape(2), fn_info_shape(2), fn_info_shape(2)) + + _check_lca(fopaque0(), fopaque1(), fopaque0()) + _check_lca(fopaque0(), fn_info_shape(1), fopaque0()) + _check_lca(fopaque2(), fn_info_shape(1), fopaque2()) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py new file mode 100644 index 000000000000..4eeaed1e0b50 --- /dev/null +++ b/tests/python/relax/test_expr.py @@ -0,0 +1,258 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import pytest +import tvm +from tvm import relax as rx +from tvm import tir +from tvm.script import relax as R + + +def _check_equal(x, y, map_free_vars=False): + tvm.ir.assert_structural_equal(x, y, map_free_vars) + tvm.ir.assert_structural_equal(y, x, map_free_vars) + + xhash = tvm.ir.structural_hash(x, map_free_vars) + yhash = tvm.ir.structural_hash(y, map_free_vars) + + assert xhash == yhash + + +def _check_json_roundtrip(x): + xret = tvm.ir.load_json(tvm.ir.save_json(x)) + _check_equal(x, xret, map_free_vars=True) + return xret + + +def test_var() -> None: + v0 = rx.Var("v0") + assert v0.name_hint == "v0" + assert v0._checked_type_ is None + assert v0.struct_info_ is None + shape = [54, 96] + v1 = rx.Var("v1", R.Tensor(shape, "float32")) + assert v1.name_hint == "v1" + for s0, s1 in zip(v1.struct_info.shape, shape): + assert s0 == s1 + assert v1.checked_type == rx.DynTensorType(2, "float32") + tvm.ir.assert_structural_equal(v1.struct_info, rx.TensorStructInfo(shape, "float32")) + + +def test_dataflow_var() -> None: + v0 = rx.DataflowVar("v0") + assert v0.name_hint == "v0" + assert v0._checked_type_ is None + assert v0.struct_info_ is None + + shape = [54, 96] + v1 = rx.DataflowVar("v1", R.Tensor(shape, "float16")) + assert v1.name_hint == "v1" + + assert v1._checked_type_ == rx.DynTensorType(2, "float16") + assert isinstance(v1, rx.DataflowVar) + tvm.ir.assert_structural_equal(v1.struct_info, rx.TensorStructInfo(shape, "float16")) + + +def test_tuple() -> None: + v0 = rx.Var("v0") + v1 = rx.Var("v1") + t = rx.Tuple((v0, v1)) + + assert t.fields[0] == v0 + assert t.fields[1] == v1 + assert t[0] == v0 + assert t[1] == v1 + assert t[-1] == v1 + assert t[-2] == v0 + + with pytest.raises(IndexError, match="Tuple index out of range"): + t[2] + + with pytest.raises(IndexError, match="Tuple index out of range"): + t[-3] + + +def test_match_cast() -> None: + # match_cast([16, 8], [m, n]) + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + shape = rx.const([16, 8], "int32") + var = rx.Var("v0", R.Shape()) + b0 = rx.MatchCast(var, shape, R.Tensor([m, n], "int32")) + assert b0.value == shape + assert b0.pattern[0] == m + assert b0.pattern[1] == n + assert b0.var is not None + assert b0.var.checked_type == rx.ShapeType() + + # var1: R.Tensor((m, n), "float32") = + # match_cast(var0: R.Tensor("float32", ndim=-1), R.Tensor((m, n), "float32")) + value = rx.Var("value", R.Tensor("float32", ndim=-1)) + + var = rx.Var("v1", R.Tensor([m, n], "float32")) + b1 = rx.MatchCast(var, value, R.Tensor([m, n], "float32")) + assert b1.value == value + assert b1.pattern[0] == m + assert b1.pattern[1] == n + assert b1.var is not None + assert b1.var.checked_type == rx.DynTensorType(2, "float32") + + +def test_match_cast() -> None: + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + ivalue = rx.Var("input_value") + sinfo = rx.TensorStructInfo([n, m], "float32") + b0 = rx.MatchCast(rx.Var("v"), ivalue, sinfo) + assert b0.value.same_as(ivalue) + assert b0.struct_info == sinfo + _check_json_roundtrip(b0) + + +def test_var_binding() -> None: + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b0 = rx.VarBinding(v0, val) + assert b0.var.name_hint == "v0" + assert b0.value == val + + +def test_binding_block() -> None: + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + shape = rx.const([16, 8], "int32") + b0 = rx.MatchCast(rx.Var("v0"), shape, R.Tensor([m, n], "int32")) + + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b1 = rx.VarBinding(v0, val) + + block0 = rx.BindingBlock([b0, b1]) + assert block0.bindings[0] == b0 + assert block0.bindings[1] == b1 + + +def test_dataflow_block() -> None: + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + shape = rx.const([16, 8], "int32") + b0 = rx.MatchCast(rx.Var("v0"), shape, R.Tensor([m, n], "int32")) + + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b1 = rx.VarBinding(v0, val) + + block0 = rx.DataflowBlock([b0, b1]) + assert block0.bindings[0] == b0 + assert block0.bindings[1] == b1 + assert isinstance(block0, rx.DataflowBlock) + + +def test_seq_expr() -> None: + x = rx.Var("foo") + bindings = [rx.VarBinding(x, rx.const(1))] + blocks = [rx.BindingBlock(bindings)] + seqe = rx.SeqExpr(blocks, x) + assert seqe.blocks[0] == blocks[0] + assert seqe.body == x + + +def test_func(): + x = rx.Var("foo", R.Tensor(dtype="float32", ndim=2)) + bindings = [rx.VarBinding(x, rx.const(1))] + blocks = [rx.BindingBlock(bindings)] + + seqe = rx.SeqExpr(blocks, x) + ret_struct_info = R.Tensor(dtype="float32", ndim=-1) + func = rx.Function([x], seqe, ret_struct_info) + func = func.with_attr("global_symbol", "func") + assert func.params[0] == x + assert func.body == seqe + assert func.ret_struct_info == ret_struct_info + assert func.attrs["global_symbol"] == "func" + + +def test_shape_of(): + shape = [96, 54] + v1 = rx.Var("v1", R.Tensor(shape)) + s1 = rx.get_shape_of(v1) + for x, y in zip(shape, s1): + assert x == y + + +def test_shape_expr(): + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + s = rx.ShapeExpr([m, n]) + assert s.values[0] == m + assert s.values[1] == n + assert s[0] == m + assert s[1] == n + assert s[-1] == n + assert s[-2] == m + assert isinstance(s.struct_info, rx.ShapeStructInfo) + + with pytest.raises(IndexError, match="ShapeExpr index out of range"): + s[2] + + with pytest.raises(IndexError, match="ShapeExpr index out of range"): + s[-3] + + shape_expr = rx.ShapeExpr([10, 20]) + assert shape_expr.values[0] == 10 + assert shape_expr.values[1] == 20 + assert shape_expr.checked_type == rx.ShapeType(ndim=2) + tvm.ir.assert_structural_equal(shape_expr.struct_info, R.Shape((10, 20))) + + x = rx.Var("v0", R.Tensor((10, 20), "float32")) + assert x.struct_info.shape[0] == 10 + assert x.struct_info.shape[1] == 20 + assert x.struct_info.shape.checked_type == rx.ShapeType(ndim=2) + tvm.ir.assert_structural_equal(x.struct_info.shape.struct_info, R.Shape((10, 20))) + + m = tir.Var("m", "int32") + with pytest.raises( + tvm.TVMError, match="the value in ShapeStructInfo can only have dtype of int64" + ): + rx.ShapeExpr([m, 3]) + + +def test_prim_value(): + pv = rx.PrimValue(tir.IntImm("int64", 1)) + assert pv.value.value == 1 + _check_equal(pv, rx.PrimValue(tir.IntImm("int64", 1))) + _check_json_roundtrip(pv) + + +def test_string_imm(): + s0 = rx.StringImm("hello") + s1 = rx.StringImm("hello") + assert s0.value == "hello" + _check_equal(s0, s1) + _check_json_roundtrip(s0) + + +def test_datatype_imm(): + d0 = rx.DataTypeImm("int32") + d1 = rx.DataTypeImm("int32") + assert d0.value == "int32" + _check_equal(d0, d1) + _check_json_roundtrip(d0) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_struct_info.py b/tests/python/relax/test_struct_info.py new file mode 100644 index 000000000000..80ebc3cb182a --- /dev/null +++ b/tests/python/relax/test_struct_info.py @@ -0,0 +1,241 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +import pytest + +from tvm import relax as rx, TVMError, tir + + +def _check_equal(x, y, map_free_vars=False): + tvm.ir.assert_structural_equal(x, y, map_free_vars) + tvm.ir.assert_structural_equal(y, x, map_free_vars) + + xhash = tvm.ir.structural_hash(x, map_free_vars) + yhash = tvm.ir.structural_hash(y, map_free_vars) + + assert xhash == yhash + + +def _check_json_roundtrip(x): + xret = tvm.ir.load_json(tvm.ir.save_json(x)) + _check_equal(x, xret, map_free_vars=True) + return xret + + +def test_object_struct_info(): + s0 = rx.ObjectStructInfo() + s1 = rx.ObjectStructInfo() + + # can turn into str + str(s0) + _check_equal(s0, s1) + + assert isinstance(s0, rx.ObjectStructInfo) + _check_json_roundtrip(s0) + + +def test_shape_type(): + t0 = rx.ShapeType() + t1 = rx.ShapeType() + assert t0 == t1 + + +def test_dyn_tensor_type(): + t0 = rx.DynTensorType() + assert t0.ndim == -1 + t1 = rx.DynTensorType(3, "int32") + assert t1.ndim == 3 + assert t1.dtype == "int32" + + +def test_prim_struct_info(): + s0 = rx.PrimStructInfo("float32") + s1 = rx.PrimStructInfo("float32") + s2 = rx.PrimStructInfo("int32") + + _check_equal(s0, s1) + + # can turn into str + str(s0) + + assert s0 == s1 + assert s0 != s2 + + assert isinstance(s0, rx.PrimStructInfo) + _check_json_roundtrip(s0) + _check_json_roundtrip(s1) + + assert s1.dtype == "float32" + assert s2.dtype == "int32" + + # wrong API constructors + with pytest.raises(TVMError): + rx.PrimStructInfo(1) + + +def test_shape_struct_info(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s0 = rx.ShapeStructInfo([1, n + 1, m]) + s1 = rx.ShapeStructInfo([1, n + 1, m]) + + _check_equal(s0, s1) + + assert s0 == s1 + assert s0.ndim == 3 + assert s1.ndim == 3 + + assert s0.values[2] == m + + assert isinstance(s0, rx.ShapeStructInfo) + _check_json_roundtrip(s0) + _check_json_roundtrip(s1) + + s2 = rx.ShapeStructInfo(ndim=2) + + assert s2.ndim == 2 + assert s2.values is None + _check_json_roundtrip(s2) + assert s0 != s2 + + # can turn into str + str(s0) + + # wrong argument type + with pytest.raises(TVMError): + rx.ShapeStructInfo(1) + + # cannot pass both ndim and values + with pytest.raises(ValueError): + rx.ShapeStructInfo([1, 2], ndim=3) + + # cannot pass both ndim and values even if they are consistent + with pytest.raises(ValueError): + rx.ShapeStructInfo([1, 2], ndim=2) + + +def test_tensor_struct_info(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s0 = rx.TensorStructInfo([1, n + 1, m], "float32") + s1 = rx.TensorStructInfo(rx.ShapeExpr([1, n + 1, m]), "float32") + + _check_equal(s0, s1) + + assert s0 == s1 + assert s0.ndim == 3 + assert s1.ndim == 3 + + assert isinstance(s0, rx.TensorStructInfo) + _check_json_roundtrip(s0) + _check_json_roundtrip(s1) + + s2 = rx.TensorStructInfo(ndim=2, dtype="int32") + + assert s2.ndim == 2 + assert s2.dtype == "int32" + assert s2.shape is None + _check_json_roundtrip(s2) + assert s0 != s2 + + # take in opaque var + rshape = rx.Var("shape", rx.ShapeStructInfo(ndim=2)) + + s3 = rx.TensorStructInfo(rshape, dtype="int32") + assert s3.dtype == "int32" + assert s3.shape == rshape + assert s3.ndim == 2 + _check_json_roundtrip(s3) + + # can turn into str + str(s0) + + # cannot pass both ndim and values + with pytest.raises(ValueError): + rx.TensorStructInfo([1, 2], ndim=3) + + # cannot pass both ndim and values even if they are consistent + with pytest.raises(ValueError): + rx.TensorStructInfo([1, 2], ndim=2) + + +def test_tuple_struct_info(): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + + s0 = rx.TensorStructInfo([1, 2, m + n], "float32") + s1 = rx.ObjectStructInfo() + + t0 = rx.TupleStructInfo([s0, s1]) + t1 = rx.TupleStructInfo([s0, rx.ObjectStructInfo()]) + t2 = rx.TupleStructInfo([s0, s0]) + + _check_equal(t0, t1) + + assert t0 == t1 + + assert isinstance(t0, rx.TupleStructInfo) + t0 = _check_json_roundtrip(t0) + t1 = _check_json_roundtrip(t1) + t2 = _check_json_roundtrip(t2) + + # can turn into str + str(t0) + + # wrong argument type + with pytest.raises(TVMError): + rx.TupleStructInfo(1) + + +def test_func_struct_info(): + def fn_info(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([c, n, m], "float32") + y = rx.TensorStructInfo([c, n, 1], "float32") + z = rx.TensorStructInfo([c, n, m], "float32") + return rx.FuncStructInfo([x, y], z) + + f0 = fn_info(1) + f1 = fn_info(1) + f2 = fn_info(2) + f3 = rx.FuncStructInfo.opaque_func() + + _check_equal(f0, f1) + + assert f0 == f1 + assert f0 != f2 + + assert len(f0.params) == 2 + assert isinstance(f0.ret, rx.TensorStructInfo) + assert f2.derive_func is None + assert f3.params is None + assert f3.derive_func is None + _check_equal(f3.ret, rx.ObjectStructInfo()) + + assert isinstance(f0, rx.FuncStructInfo) + f0 = _check_json_roundtrip(f0) + f1 = _check_json_roundtrip(f1) + f2 = _check_json_roundtrip(f2) + f3 = _check_json_roundtrip(f3) + + # can turn into str + str(f3) + + +if __name__ == "__main__": + tvm.testing.main() From ff488e94df71682844b184cc43d1bcbc9b741475 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Tue, 7 Feb 2023 01:30:11 -0500 Subject: [PATCH 06/81] [Unity] Relax BlockBuilder and ExprMutator (#13926) This PR adds BlockBuilder: the core data structure to construct Relax AST, and ExprMutator: performs AST mutation for implementing transformation passes. Co-Authored-by: Tianqi Chen Co-Authored-by: Altan Haan Co-Authored-by: Andrew Liu Co-Authored-by: Hongyi Jin <3231950289@qq.com> Co-Authored-by: Jiawei Liu Co-Authored-by: Junru Shao Co-Authored-by: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Co-Authored-by: masahi Co-Authored-by: Prakalp Srivastava Co-Authored-by: Ruihang Lai Co-Authored-by: Siyuan Feng Co-Authored-by: Steven S. Co-Authored-by: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Co-Authored-by: Yixin Dong Co-Authored-by: Yong Wu Co-Authored-by: Ziheng Jiang --- CMakeLists.txt | 2 + include/tvm/relax/analysis.h | 13 + include/tvm/relax/block_builder.h | 239 +++ include/tvm/relax/expr.h | 7 +- include/tvm/relax/expr_functor.h | 138 +- include/tvm/relax/op_attr_types.h | 75 + include/tvm/relax/struct_info.h | 5 +- include/tvm/relax/utils.h | 155 ++ include/tvm/te/operation.h | 2 +- python/tvm/ir/function.py | 16 + python/tvm/meta_schedule/utils.py | 50 +- python/tvm/relax/__init__.py | 11 + python/tvm/relax/analysis/analysis.py | 33 +- python/tvm/relax/block_builder.py | 801 +++++++++ python/tvm/relax/expr_functor.py | 1530 +++++++++++++++++ python/tvm/relax/op/__init__.py | 22 + python/tvm/relax/op/_ffi_api.py | 19 + python/tvm/relax/op/base.py | 358 ++++ python/tvm/relax/op/binary.py | 67 + python/tvm/relax/utils.py | 278 +++ python/tvm/te/__init__.py | 1 + python/tvm/te/operation.py | 54 +- src/ir/function.cc | 14 + src/relax/analysis/struct_info_analysis.cc | 149 ++ src/relax/ir/block_builder.cc | 969 +++++++++++ src/relax/ir/emit_te.cc | 78 + src/relax/ir/emit_te.h | 68 + src/relax/ir/expr_functor.cc | 244 +++ src/relax/ir/py_expr_functor.cc | 649 +++++++ src/relax/op/op.cc | 77 + src/relax/op/op_common.cc | 122 ++ src/relax/op/op_common.h | 285 +++ src/relax/op/tensor/binary.cc | 87 + src/relax/op/tensor/binary.h | 70 + src/relax/utils.cc | 41 + src/te/operation/create_primfunc.cc | 80 + src/te/operation/create_primfunc.h | 17 + .../test_analysis_struct_info_analysis.py | 143 ++ tests/python/relax/test_blockbuilder.py | 542 ++++++ tests/python/relax/test_expr.py | 4 +- tests/python/relax/test_expr_functor.py | 746 ++++++++ 41 files changed, 8233 insertions(+), 28 deletions(-) create mode 100644 include/tvm/relax/block_builder.h create mode 100644 include/tvm/relax/op_attr_types.h create mode 100644 include/tvm/relax/utils.h create mode 100644 python/tvm/relax/block_builder.py create mode 100644 python/tvm/relax/expr_functor.py create mode 100644 python/tvm/relax/op/__init__.py create mode 100644 python/tvm/relax/op/_ffi_api.py create mode 100644 python/tvm/relax/op/base.py create mode 100644 python/tvm/relax/op/binary.py create mode 100644 python/tvm/relax/utils.py create mode 100644 src/relax/ir/block_builder.cc create mode 100644 src/relax/ir/emit_te.cc create mode 100644 src/relax/ir/emit_te.h create mode 100644 src/relax/ir/py_expr_functor.cc create mode 100644 src/relax/op/op.cc create mode 100644 src/relax/op/op_common.cc create mode 100644 src/relax/op/op_common.h create mode 100644 src/relax/op/tensor/binary.cc create mode 100644 src/relax/op/tensor/binary.h create mode 100644 src/relax/utils.cc create mode 100644 tests/python/relax/test_blockbuilder.py create mode 100644 tests/python/relax/test_expr_functor.py diff --git a/CMakeLists.txt b/CMakeLists.txt index fa38ba6c6c8a..eecd67be94c1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -290,8 +290,10 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/support/*.cc src/script/*.cc src/relax/ir/*.cc + src/relax/op/*.cc src/relax/analysis/*.cc src/relax/backend/vm/*.cc + src/relax/utils.cc ) tvm_file_glob(GLOB CODEGEN_SRCS diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 82145032f458..ad2bd19aa41a 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -85,6 +85,19 @@ TVM_DLL Type GetStaticType(const StructInfo& info); */ TVM_DLL StructInfo StructInfoFromType(const Type& type); +/*! + * \return Derive the call's ret value struct info from inputs. + * \param finfo The function struct info. + * \param call The call expression to be derived. + * \param ctx The builder context. + * \param ana Optional context analyzer to prove symbolic expression equality. + * \return The derived struct info of the call. + * \note call->op field is ignored during derivation and we only rely on information + * presented by func_sinfo. + */ +TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call, + const BlockBuilder& ctx, arith::Analyzer* ana = nullptr); + /*! * \brief Erase the info to a corresponding more coarse grained * struct info that is still well-defined(with all the vars in scope). diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h new file mode 100644 index 000000000000..7222ae08f956 --- /dev/null +++ b/include/tvm/relax/block_builder.h @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/block_builder.h + * \brief The utility for constructing Relax binding blocks. + */ +#ifndef TVM_RELAX_BLOCK_BUILDER_H_ +#define TVM_RELAX_BLOCK_BUILDER_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief A builder to build Relax binding blocks. + * + * BlockBuilder provides the following three categories + * of main functionalities for IR building and transformations: + * + * - Global context management: manages the IRModule, + * allowing query, update the surrounding global context. + * Provide context tools for analysis. + * - Scope management: + * - Manages block scopes for bulding nested blocks. + * - Emit bindings to the current scope. + * - Construct blocks by calling EndScope. + * - Normalization: Take an Expr, normalize it + * to deduce shape/type, turn things into normal forms. + * + * Importantly, these three categories of features can be dependent + * on each other. For example, when we emit into scope we will call + * normalize to ensure the code is in normal form. Similarly, when we + * normalize we could choose to emit into the current context. + * + * We would encourage the developers to keep these three category + * in mind when using and developing BlockBuilder, we can group + * the code in a logically clean way. + * + * BlockBuilderNode is implemented as a virtual interface to + * allow logically grouped implementation and internal data + * structures that are hidden from the users. + */ +class BlockBuilderNode : public Object { + public: + //------------------------------- + // Global Context management + //------------------------------- + /*! + * \brief Get the name table for generating unique names. + * + * \return The name table. + */ + virtual NameTable* name_table() = 0; + + /*! + * \brief Get the context IRModule in this builder. + * + * \note The context + * \return The IRModule in this BlockBuilder. + */ + virtual IRModule GetContextIRModule() const = 0; + + /*! + * \brief Add a Relax function or a TIR PrimFunc to internal context module. + * \param func The function to be added. + * \param func_name_hint The name hint of the function to be added. + * \note If the function to be added already exists, return its + * GlobalVar directly. + * \return The global var bound to the added function. + */ + virtual GlobalVar AddFunction(const BaseFunc& func, String func_name_hint) = 0; + + /*! + * \brief Update a Relax function or a TIR PrimFunc in the internal context module. + * \param gv The global var referring the function to be updated. + * \param function The updated function. + */ + virtual void UpdateFunction(const GlobalVar& gv, BaseFunc function) = 0; + + /*! + * \brief Report an error during transformation construction. + * \param diagnostic The diagnostic information. + */ + virtual void ReportFatal(const Diagnostic& diagnostic) = 0; + + //------------------------------- + // Scope management + //------------------------------- + /*! + * \brief Lookup the binding value that var binds to in the current emitted sequences. + * \param var The input var. + * \return The Expr bound to the input \p var. + * \note For function parameters, this function returns NullOpt. + */ + virtual Optional LookupBinding(const Var& var) = 0; + + /*! + * \brief Begin a new scope, with optional parameters that + * are visible within the scope. + * + * \param params Parameters that are visible within the scope. + * + * \note This function should be called when new scope is introduced + * (function, seq) to properly track the variable availability + * and help the best effort deduction. + * + * \sa EndScope + */ + virtual void BeginScope(Optional> params) = 0; + + /*! \brief End the previously defined scope. */ + virtual void EndScope() = 0; + + /*! \brief Begin to build a DataflowBlock. */ + virtual void BeginDataflowBlock() = 0; + + /*! \brief Begin to build a BindingBlock. */ + virtual void BeginBindingBlock() = 0; + /*! + * \brief End building a BindingBlock. + * \return The BindingBlock being built. + */ + virtual BindingBlock EndBlock() = 0; + + /*! + * \brief Check if the block being built is DataflowBlock or not. + * \return A boolean that indicates if the block being built is DataflowBlock or not. + */ + virtual bool CurrentBlockIsDataFlow() = 0; + + /*! + * \brief Emits an Expr, and returns the variable it is bound to. + * \param expr The Expr to be emitted. + * \param name_hint Name hint for the bound variable. + * \return The new variable that \p expr is bound to. + * + * \note This Emit function normalizes the \p expr, and + * performs shape and type deductions by calling Normalize. + */ + virtual Var Emit(Expr expr, String name_hint = "") = 0; + + /*! + * \brief Emit a MatchCast. + * \param value The input value. + * \param struct_info The struct info to be matched. + * \param name_hint Name hint for the bound variable. + * \return The variable bound to the MatchCast. + */ + virtual Var EmitMatchCast(Expr value, StructInfo struct_info, String name_hint = "") = 0; + + /*! + * \brief Generate an output for the current dataflow block. + * \param output The output variable of the block. + * \param name_hint Name hint for the bound variable. + * \return The variable bound to \p output. + */ + virtual Var EmitOutput(Expr output, String name_hint = "") = 0; + + /*! + * \brief Emit a binding that is already normalized. + * + * \param normalized_binding A binding whose value is already normalized. + * + * \note This function requires binding to be pre-normalized. + */ + virtual void EmitNormalized(Binding normalized_binding) = 0; + + /*! + * \brief Convert an expression to normal form, and try to eagerly infer types and shapes. + * \param expr The input expression. + * \return The normalized expression. + * + * \note Invariant: If any of the sub expr have struct_info field. + * they must have already been normalized. + */ + virtual Expr Normalize(const Expr& expr) = 0; + + /*! + * \brief Normalize argument to a call or another IRNode. + * \param expr The input expression. + * \return The normalized expression. + * + * \note This function will create a binding var for non-leaf expressions such as Call. + */ + virtual Expr NormalizeArgument(const Expr& expr) = 0; + + /*! + * \brief Get the analyzer of the BlockBuilder. + * \return The BlockBuilder's arithmetic analyzer. + */ + virtual arith::Analyzer* GetAnalyzer() = 0; + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.BlockBuilder"; + TVM_DECLARE_BASE_OBJECT_INFO(BlockBuilderNode, Object); +}; + +class BlockBuilder : public ObjectRef { + public: + /*! + * \brief Create a BlockBuilder. + * + * \param ctx_mod Optional before-transformation context module for rewriting. + * \return The created BlockBuilder. + * + * \note When rewriting an existing IRModule, it is important to pass it in as + * ctx_mod so you can lookup the context functions for cross function + * call analysis. + */ + TVM_DLL static BlockBuilder Create(Optional ctx_mod); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BlockBuilder, ObjectRef, BlockBuilderNode); +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_BLOCK_BUILDER_H_ diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 9e563c7061dc..0788193ee7c4 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -171,8 +171,7 @@ class CallNode : public ExprNode { // skip sinfo_args check for primitive ops. equal->MarkGraphNode(); return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) && - (IsPrimitiveOp(op) || equal(sinfo_args, other->sinfo_args)) && - equal(struct_info_, other->struct_info_); + equal(sinfo_args, other->sinfo_args) && equal(struct_info_, other->struct_info_); } void SHashReduce(SHashReducer hash_reduce) const { @@ -180,9 +179,7 @@ class CallNode : public ExprNode { hash_reduce(op); hash_reduce(args); hash_reduce(attrs); - if (!IsPrimitiveOp(op)) { - hash_reduce(sinfo_args); - } + hash_reduce(sinfo_args); hash_reduce(struct_info_); } diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index 5735e8661f6f..655ecc52b656 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -26,15 +26,18 @@ #define TVM_RELAX_EXPR_FUNCTOR_H_ #include +#include #include #include #include #include #include +#include #include +#include #include - +#include namespace tvm { namespace relax { @@ -410,6 +413,139 @@ class ExprMutatorBase : public ExprFunctor { DefaultStructInfoFieldMutator default_struct_info_field_mutator_{this}; }; +/*! + * \brief A mutator works in normal form. + * + * ExprMutator expects input AST to be in the normal form, i.e., the expressions are normalized(no + * nesting and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are + * available. + */ +class ExprMutator : public ExprMutatorBase { + public: + using ExprMutatorBase::VisitExpr_; + + ExprMutator(Optional mod = NullOpt) { builder_ = BlockBuilder::Create(mod); } + Expr VisitExpr(const Expr& expr) override; + Expr VisitExpr_(const VarNode* op) override; + Expr VisitExpr_(const DataflowVarNode* op) override; + Expr VisitExpr_(const FunctionNode* op) override; + Expr VisitExpr_(const SeqExprNode* op) override; + Expr VisitExpr_(const IfNode* op) override; + + /*! + * \brief Generic dispatcher for bindings. + * \param binding The binding to be visited. + */ + virtual void VisitBinding(const Binding& binding); + // specific leaf level visitor functions + virtual void VisitBinding_(const VarBindingNode* binding); + virtual void VisitBinding_(const MatchCastNode* binding); + // second level dispatching based on binding value type. + // these dispatching functions get called from first-level dispatch on VarBinding + virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const TupleNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const VarNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const ExternFuncNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const GlobalVarNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const CallNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const SeqExprNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const IfNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const OpNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const PrimValueNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const StringImmNode* val); + virtual void VisitBinding_(const VarBindingNode* binding, const DataTypeImmNode* val); + /*! + * \brief Generic dispatcher for binding blocks. + * \param block The binding block to be visited. + * \return The binding block after transformation. + */ + virtual BindingBlock VisitBindingBlock(const BindingBlock& block) override; // NOLINT(*) + // specific leaf level visitor functions + virtual BindingBlock VisitBindingBlock_(const BindingBlockNode* block); + virtual BindingBlock VisitBindingBlock_(const DataflowBlockNode* block); + + /*! + * \brief Generic dispatcher for rewriting the var definition site. + * \param var The var to be visited. + * \return The var after post-order rewritten. + * \note VisitExpr_(const VarNode*) will only visit the usage site of an Var + */ + virtual Var VisitVarDef(const Var& var); + // specific leaf level visitor functions + virtual Var VisitVarDef_(const VarNode* var); + virtual Var VisitVarDef_(const DataflowVarNode* var); + + protected: + /*! + * \brief Try to remit binding and bind it to a new_value + * + * This function is called after VisitExpr(binding->value) in + * VisitBinding_(const VarBinding*). + * It will try to reuse the current binding when the new value's shape/type + * matches the original binding and no changes in var is needed. + * + * Otherwise, a new binding will be emitted to replace the var specified in + * the current binding. + */ + void ReEmitBinding(const VarBindingNode* binding, Expr new_value); + + /*! + * \brief Rewrite the expr with a new scope, used in a Function's body and the branches of If. + * + * \param body_expr The body to be visited. + * \param params Optional parameters that are visible within the scope. + * \return The expr after visiting. + * + * \note The body_expr must be an SeqExpr in the normal form. + */ + Expr VisitWithNewScope(const Expr& body_expr, Optional> params = NullOpt); + + /*! + * \brief Look up the value bound to a variable. + * \param var The var to be looked up. + * \return The value bound to the input \p var. + * \note For function parameters, this function returns NullOpt. + */ + Optional LookupBinding(const Var& var); + + /*! + * \brief Post-order rewrite a node and normalize. + * \tparam T The node type to be rewritten. + * \param op The node to be rewritten. + * \return The node after post rewritten. + */ + template + Expr VisitExprPostOrder_(const T* op) { + return builder_->Normalize(ExprMutator::VisitExpr_(op)); + } + + /*! + * \brief Create a new var with specified struct_info if the original var's shape or type does + * not match with the specified ones. + * \param var The var to be updated. + * \param struct_info The struct info to be updated. + * \return The var filled with struct_info + */ + Var WithStructInfo(Var var, StructInfo struct_info); + + /*! \brief Internal block builder to emit bindings during rewriting. */ + BlockBuilder builder_; + + /*! \brief Remap a var to a new var in use-site. */ + std::unordered_map var_remap_; + + private: + using TSelf = ExprMutator; + using VisitBindingVTable = + tvm::NodeFunctor; + // initialize the vtable. + static VisitBindingVTable InitVisitBindingVTable(); +}; + } // namespace relax } // namespace tvm #endif // TVM_RELAX_EXPR_FUNCTOR_H_ diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h new file mode 100644 index 000000000000..e171a8d47b0d --- /dev/null +++ b/include/tvm/relax/op_attr_types.h @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/op_attr_types.h + * \brief Data structures that can appear in operator attributes. + */ +#ifndef TVM_RELAX_OP_ATTR_TYPES_H_ +#define TVM_RELAX_OP_ATTR_TYPES_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Infer output struct info given the call + * + * \param call The call expression to be derived. + * \param ctx The builder context. + */ +using FInferStructInfo = + runtime::TypedPackedFunc; + +/*! + * \brief Packed function implementation for operators. The relax operator will be lowered to + * this packed function call during codegen. + */ +using FCallPacked = String; + +struct PrintAttrs : public tvm::AttrsNode { + std::string format; + TVM_DECLARE_ATTRS(PrintAttrs, "relax.attrs.PrintAttrs") { + TVM_ATTR_FIELD(format) + .describe("Python-style format string to use for displaying the input. Ignored if empty.") + .set_default(""); + } +}; + +struct AssertOpAttrs : public tvm::AttrsNode { + std::string format; + TVM_DECLARE_ATTRS(AssertOpAttrs, "relax.attrs.AssertOpAttrs") { + TVM_ATTR_FIELD(format) + .describe( + "Python-style format string to use for displaying " + "an error message if the assert fails. " + "Ignored if empty.") + .set_default(""); + } +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_OP_ATTR_TYPES_H_ diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h index f38a32f6bb83..b9aebc549474 100644 --- a/include/tvm/relax/struct_info.h +++ b/include/tvm/relax/struct_info.h @@ -22,16 +22,13 @@ #include #include #include +#include #include #include namespace tvm { namespace relax { -// TODO(relax-team) replace with real BlockBuilder -// once it is ready. -using BlockBuilder = ObjectRef; - /*! * \brief Opaque object. */ diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h new file mode 100644 index 000000000000..1457a16427cc --- /dev/null +++ b/include/tvm/relax/utils.h @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/utils.h + * \brief Utility classes and functions for working with the Relax IR. + */ +#ifndef TVM_RELAX_UTILS_H_ +#define TVM_RELAX_UTILS_H_ + +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Utility data structure for generating unique names for IR construction. + */ +class NameTable { + public: + /*! + * \brief Generate a unique name with a specified prefix. + * \param prefix The name prefix. + * \return The generated name. + */ + inline std::string GetUniqueName(std::string prefix) { + std::replace(prefix.begin(), prefix.end(), '.', '_'); + std::string unique_prefix = prefix; + auto it = alloc_map_.find(prefix); + if (it != alloc_map_.end()) { + while (alloc_map_.count(unique_prefix = prefix + std::to_string(++it->second)) > 0) { + } + } + alloc_map_[unique_prefix] = 0; + return unique_prefix; + } + + NameTable() = default; + + template + explicit NameTable(Iter begin, Iter end, Lambda f) { + // static_assert is more reader-friendly than SFINAE when template specialization is not needed. + static_assert(std::is_convertible::value, + "Lambda f must has a signature of [?](*it) -> string {}"); + for (auto it = begin; it != end; ++it) { + const std::string& name = f(*it); + const size_t idx_last_first_num = std::distance( + std::find_if(name.rbegin(), name.rend(), [](char c) { return !std::isdigit(c); }), + name.rend()); + // name = {O = others}{D = consecutive digits} + // let O -> prefix; + std::string prefix = name.substr(0, idx_last_first_num); + ICHECK(prefix.size() > 0 && std::isalpha(prefix[0])) << "Invalid variable name: " << name; + if (0 == alloc_map_.count(prefix)) alloc_map_[prefix] = 0; + if (idx_last_first_num < name.size()) { // has some digits. + // let D's nearest natural number -> idx; + // note: stoul("000123") = 123; + alloc_map_[prefix] = + std::max(alloc_map_[prefix], std::stoi(name.substr(idx_last_first_num))); + } + } + } + + template + explicit NameTable(Iter begin, Iter end) + : NameTable(begin, end, [](const decltype(*begin)& v) { return v; }) {} + + private: + std::unordered_map alloc_map_; +}; + +/*! + * \brief Bind the variables to a Relax expression. This is a helper + * function usually called by other pass functions to help optimizations. + * If any free variables are introduced into a function, those are added + * to the function parameters. + * Additionally this may change the order of parameters if you map a variable + * to a variable. + * + * \param expr The input expression. + * \param binds The variable to expression map that will be used to help the + * binding. + * + * \return The updated expression. + */ +TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); + +/*! + * \brief Check if the given type is a boolean scalar type (tensor of rank 0 with a boolean dtype). + * + * \param ty The input type. + * \param permit_unknown_rank If true, it will permit the input type to have unknown rank + * (ndim of -1), which will require a dynamic check. + * \param permit_unknown_dtype If true, it will permit the input type to have an unknown dtype + * (namely, void), which will require a dynamic check. + * + * \return True iff the input type is a boolean scalar type (or, depending on options, has unknown + * rank or dtype) + */ +TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true, + bool permit_unknown_dtype = true); + +/*! + * \brief Check if the given expression is a "leaf" node or tuple node for normalization purposes. + * + * The following expressions are defined as leaf nodes: Var, Constant, ShapeExpr, + * GlobalVar, Op, ExternFunc. + * + * Tuples are included in this list mainly for convenience in grouping operator arguments. + * *Note*: Since tuples can contain nested expressions, it is necessary to ensure that + * values nested inside them are also leaves. + * + * \param expr The input expression + * + * \return True iff the input expression is a "leaf" node (a value allowed to appear + * inline without being bound to a var during normalization). + */ +TVM_DLL bool IsLeafOrTuple(const Expr& expr); + +/*! + * \brief Copy the given function. The parameters of the original function would be copied to + * satisfy the restriction in the well-formed check: any two functions cannot share the same + * parameter variable. + * \param func The relax function to copy. + * \return The copied function. + */ +TVM_DLL Function CopyWithNewParams(Function func); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_UTILS_H_ diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 2c50f3c3157b..f5753afa560f 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -182,7 +182,7 @@ class PlaceholderOpNode : public OperationNode { } static constexpr const char* _type_key = "PlaceholderOp"; - TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode); + TVM_DECLARE_BASE_OBJECT_INFO(PlaceholderOpNode, OperationNode); }; /*! diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py index c3f1bf5f562a..d02698edb54d 100644 --- a/python/tvm/ir/function.py +++ b/python/tvm/ir/function.py @@ -66,3 +66,19 @@ def with_attr(self, attr_key_or_dict, attr_value=None): return _ffi_api.BaseFuncWithAttr( res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value) ) + + def without_attr(self, attr_key: str) -> "BaseFunc": + """Create a new copy of the function with an attribute without provided key. + + Parameters + ---------- + attr_key : str + The attribute key to delete from the attrubte pairs. + + + Returns + ------- + func : BaseFunc + A new copy of the function + """ + return _ffi_api.BaseFuncWithoutAttr(self, attr_key) diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 401fdab08a26..1f2cfd34016e 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -75,14 +75,27 @@ def _extract(inst: type, name: str): def method(*args, **kwargs): return getattr(inst, name)(*args, **kwargs) - if getattr(base, name) is getattr(cls, name) and name != "__str__": - # for task scheduler return None means calling default function - # otherwise it will trigger a TVMError of method not implemented - # on the c++ side when you call the method, __str__ not required - return None - return method + for inherit_cls, base_cls in zip(cls.__mro__, cls.__mro__[1:]): + # extract functions that differ from the base class + if not hasattr(base_cls, name): + continue + if getattr(base_cls, name) is getattr(inherit_cls, name) and name != "__str__": + continue + return method + + # for task scheduler return None means calling default function + # otherwise it will trigger a TVMError of method not implemented + # on the c++ side when you call the method, __str__ not required + return None assert isinstance(cls.__base__, type) + if hasattr(cls, "_type") and cls._type == "TVMDerivedObject": # type: ignore + raise TypeError( + ( + f"Inheritance from a decorated object `{cls.__name__}` is not allowed. " + f"Please inherit from `{cls.__name__}._cls`." + ) + ) assert hasattr( cls, "_tvm_metadata" ), "Please use the user-facing method overriding class, i.e., PyRunner." @@ -95,6 +108,9 @@ def method(*args, **kwargs): class TVMDerivedObject(metadata["cls"]): # type: ignore """The derived object to avoid cyclic dependency.""" + _cls = cls + _type = "TVMDerivedObject" + def __init__(self, *args, **kwargs): """Constructor.""" self.handle = None @@ -111,12 +127,22 @@ def __init__(self, *args, **kwargs): # using weakref to avoid cyclic dependency self._inst._outer = weakref.ref(self) - def __getattr__(self, name: str): - """Bridge the attribute function.""" - try: - return self._inst.__getattribute__(name) - except AttributeError: - return super(TVMDerivedObject, self).__getattr__(name) + def __getattr__(self, name): + # fall back to instance attribute if there is not any + # return self._inst.__getattribute__(name) + import inspect # pylint: disable=import-outside-toplevel + + result = self._inst.__getattribute__(name) + if inspect.ismethod(result): + + def method(*args, **kwargs): + return result(*args, **kwargs) + + # set __own__ to aviod implicit deconstruction + setattr(method, "__own__", self) + return method + + return result def __setattr__(self, name, value): if name not in ["_inst", "key", "handle"]: diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 01310f6455dd..ce175354d02c 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -21,6 +21,8 @@ from . import ty from . import analysis from . import vm +from . import block_builder +from . import op from . import struct_info # Expr @@ -60,6 +62,15 @@ from .exec_builder import ExecBuilder from .vm import VirtualMachine +# Operator +from .op.base import call_tir + +# BlockBuilder +from .block_builder import BlockBuilder + +# ExprFunctor +from .expr_functor import ExprFunctor, PyExprVisitor, PyExprMutator + # StructInfo from .struct_info import ( StructInfo, diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 301f3ecc7265..d81c477145ec 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -26,8 +26,8 @@ from tvm import tir from tvm.relax.ty import Type -from tvm.relax.struct_info import StructInfo -from tvm.relax.expr import Var, Expr +from tvm.relax.struct_info import StructInfo, FuncStructInfo +from tvm.relax.expr import Var, Expr, Call from . import _ffi_api @@ -116,6 +116,35 @@ def struct_info_base_check(base: StructInfo, derived: StructInfo) -> BaseCheckRe return _ffi_api.StructInfoBaseCheck(base, derived) # type: ignore +def derive_call_ret_struct_info( + func_sinfo: FuncStructInfo, call: Call, ctx: "tvm.relax.BlockBuilder" +) -> StructInfo: + """Derive the call's ret value struct info from inputs. + + Parameters + ---------- + func_sinfo: FuncStructInfo + The call's function signature. + + call: Call + The call expression + + ctx: tvm.relax.BlockBuilder + The context block builder. + + Returns + ------- + ret : StructInfo + The derived return value struct info. + + Note + ---- + This is an internal derivation function, call.op field is + ignored in this case and the derivation only depends on func_sinfo. + """ + return _ffi_api.DeriveCallRetStructInfo(func_sinfo, call, ctx) # type: ignore + + def struct_info_lca(lhs: StructInfo, rhs: StructInfo) -> StructInfo: """Unify the two struct info to their least common ancestor. diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py new file mode 100644 index 000000000000..77b45fdf5519 --- /dev/null +++ b/python/tvm/relax/block_builder.py @@ -0,0 +1,801 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, invalid-name +"""Developer API of constructing Relax AST.""" +import typing + +from typing import Dict, List, Optional, Union, Any, Callable +from tvm.ir.module import IRModule +from tvm.runtime import Object +from tvm import relax as rx, tir +import tvm +from .expr import ( + Expr, + te_tensor, + Var, + ShapeExpr, + GlobalVar, + BindingBlock, + Tuple, + BaseFunc, + Binding, +) +from .struct_info import ShapeStructInfo, StructInfo, TensorStructInfo +from .op.base import call_tir +from . import _ffi_api + + +class FunctionScope(object): + """Auxiliary scope for function""" + + def __init__(self, block_builder, name, params, attrs): + self._bb = block_builder + self._name = name + self._params = params + self._attrs = attrs + + def __enter__(self): + self._bb._enter_function_scope(self._name, self._params, self._attrs) + + def __exit__(self, exc_type, exc_val, exc_tb): + # __exit__ should properly handle the case where the with block exits with an exception + # when handling error case in exit, always check if there is already an exception + # been thrown in the with block + self._bb._exit_function_scope(exc_type, exc_val, exc_tb) + + +class DataflowScope(object): + """Auxiliary scope for Dataflow block""" + + def __init__(self, block_builder): + self._bb = block_builder + + def __enter__(self): + block = self._bb._end_block() + if len(block.bindings) > 0: + self._bb._blocks.append(block) + self._bb._begin_dataflow_block() + + def __exit__(self, ptype, value, trace): + block = self._bb._end_block() + if len(block.bindings) > 0: + self._bb._blocks.append(block) + self._bb._begin_binding_block() + + +class TestingScope(object): + """Auxiliary scope for testing purposes""" + + def __init__(self, block_builder, def_vars): + self._bb = block_builder + shape_vars = [] + for var in def_vars: + if isinstance(var, tvm.tir.Var): + shape_vars.append(var) + else: + raise ValueError("def_vars only can take tir.Var") + # setup a dummy var so shape is in scope. + sparam = rx.Var("sparam", rx.ShapeStructInfo(shape_vars)) + self._scope_params = [sparam] + + def __enter__(self): + self._bb.begin_scope(self._scope_params) + self._bb._begin_dataflow_block() + + def __exit__(self, ptype, value, trace): + self._bb._end_block() + self._bb.end_scope() + + +@tvm._ffi.register_object("relax.BlockBuilder") +class BlockBuilder(Object): + """A builder to build Relax IR for testing and dev. + + Examples + -------- + .. code-block:: python + + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16") + bb = rx.BlockBuilder() + with bb.function([x, y], "func"): + with bb.dataflow() as df: + lv0 = bb.emit(rx.add(x, y)) + lv1 = bb.emit(rx.multiply(lv0, y)) + gv0 = bb.emit_output(lv1) + bb.emit_func_output(gv0) + mod = bb.get() + + BlockBuilder can also be used to construct neural networks with nn.Module API + + .. code-block:: python + + from tvm.relax.testing import nn + + n = tir.Var("n", "int64") + input_size = 784 + hidden_sizes = [128, 32] + output_size = 10 + bb = rx.BlockBuilder() + + with bb.function("main"): + model = nn.Sequential( + nn.Linear(input_size, hidden_sizes[0]), + nn.ReLU(), + nn.Linear(hidden_sizes[0], hidden_sizes[1]), + nn.ReLU(), + nn.Linear(hidden_sizes[1], output_size), + nn.LogSoftmax(), + ) + data = nn.Placeholder((n, input_size), name="data") + output = model(data) + params = [data] + model.parameters() + builder.emit_func_output(output, params=params) + mod = bb.get() + """ + + _current = None + + @staticmethod + def current(): + """Returns the current BlockBuilder.""" + return BlockBuilder._current + + def __init__(self, mod: IRModule = None): + self._blocks: List[BindingBlock] = [] + # a boolean flag that tracks if emit_func_output has been called + self._is_emit_func_output_called = False + self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate, mod) # type: ignore + + def _begin_dataflow_block(self) -> None: + _ffi_api.BlockBuilderBeginDataflowBlock(self) # type: ignore + + def _begin_binding_block(self) -> None: + _ffi_api.BlockBuilderBeginBindingBlock(self) # type: ignore + + def _end_block(self) -> BindingBlock: + return _ffi_api.BlockBuilderEndBlock(self) # type: ignore + + def _enter_function_scope(self, name, params, attrs): + if BlockBuilder.current() is not None: + raise RuntimeError("BlockBuilder does not allow nested functions.") + BlockBuilder._current = self + self._func_name = name + self._func_params = params + self._func_attrs = attrs + self.begin_scope(params) + self._begin_binding_block() + + def _exit_function_scope(self, exc_type, exc_val, exc_tb): + # record + is_emit_func_output_called = self._is_emit_func_output_called + # recover to default state + self._blocks = [] + self._is_emit_func_output_called = False + BlockBuilder._current = None + + # NOTE: we must raise after we recover the state so future + # block builder scoping functions correctly + if exc_type is None: + if not is_emit_func_output_called: + raise RuntimeError("emit_func_output must be called in a relax function.") + + def _convert_te_arg( + self, te_args: Any, tir_var_map: Dict[tir.Var, tir.PrimExpr] + ) -> typing.Tuple[Any, List[tvm.te.Tensor]]: + """Helper function used by `call_te` to convert Relax expressions to TE tensor. + + In the common case, the type of te_args is a Relax expression and is converted + into a TE tensor. + If te_args is a nested or recursive datatype (i.e list, dict, tvm.ir.Map, tvm.ir.Array), + we recursive and convert any value of type Relax expression into a TE tensor. + Common values of type int, float, and str are preserved. + + In dynamic shape cases, the passed in arguments may contain TIR variable. + For example, the argument can be a Relax Var with TensorStructInfo, which + has symbolic shape, or the argument can be a ShapeExpr with symbolic variables. + To make the PrimFunc generated by `call_te` has independent variables with + the caller Relax function, we will substitute the TIR variables in the input + arguments with fresh ones, which is done by maintaining a TIR variable mapping. + + Parameters + ---------- + te_args : Any + Argument to convert to TE + + tir_var_map : Dict[tir.Var, tir.PrimExpr] + The TIR variable mapping, which maps TIR variables on the Relax function + side to the new set of variables used on the PrimFunc side. + + Returns + ------- + ret : (Any, [tvm.te.Tensor]) + A tuple of the converted te_args, and a list of te tensors for each converted + Relax expression + """ + te_args_list = [] + + def _copy_undefined_var(expr: tir.PrimExpr): + def _visit_expr(e: tir.PrimExpr): + if isinstance(e, tir.Var) and e not in tir_var_map: + new_var = tir.Var(e.name, e.dtype) + tir_var_map[e] = new_var + + tir.stmt_functor.post_order_visit(expr, _visit_expr) + + def _convert_te_arg_helper(arg): + if isinstance(arg, Expr): # type: ignore + if isinstance(arg.struct_info, TensorStructInfo): + assert isinstance( + arg.struct_info.shape, ShapeExpr + ), "emit_te now only supports Tensor that has ShapeExpr shape" + for shape_value in arg.struct_info.shape.values: + _copy_undefined_var(shape_value) + + arg = te_tensor(arg, tir_var_map) + te_args_list.append(arg) + return arg + elif isinstance(arg.struct_info, ShapeStructInfo): + assert isinstance( + arg, ShapeExpr + ), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr" + return [_convert_te_arg_helper(val) for val in arg.values] + elif isinstance(arg, (list, tvm.ir.Array)): + return [_convert_te_arg_helper(x) for x in arg] + elif isinstance(arg, tuple): + return tuple([_convert_te_arg_helper(x) for x in arg]) + elif isinstance(arg, (dict, tvm.ir.Map)): + for key in arg: + assert isinstance( + key, str + ), "emit_te only supports dict with string as the key currently" + return {k: _convert_te_arg_helper(arg[k]) for k in arg} + elif isinstance(arg, tir.PrimExpr): + _copy_undefined_var(arg) + return tir.stmt_functor.substitute(arg, tir_var_map) + elif isinstance(arg, (int, float, str, tvm.ir.Type, tvm.ir.Attrs)) or arg is None: + return arg + raise TypeError("not supported type in emit_te: {}".format(type(arg))) + + new_arg = _convert_te_arg_helper(te_args) + return new_arg, te_args_list + + def _get_unbound_tir_vars(self, args: List[tvm.te.Tensor]) -> List[tvm.tir.Var]: + """get unbound TIR vars (i.e TIR vars used in the shape but is not + itself a dimension of a shape)""" + bound_vars = set() + used_vars = set() + + def _populate_used_vars(expr): + if isinstance(expr, tvm.tir.Var): + used_vars.add(expr) + + for x in args: + for s in x.shape: + tvm.tir.stmt_functor.post_order_visit(s, _populate_used_vars) + if isinstance(s, tir.Var): + bound_vars.add(s) + + diff = used_vars - bound_vars + return list(diff) + + def function( + self, + name: str, + params: Optional[Union[Var, Tuple, List[Var]]] = None, + attrs: Optional[Dict[str, Object]] = None, + ) -> FunctionScope: + """Annotate a Relax function. + + Parameters + ---------- + name : str, optional + The name of the function + + params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional + The parameters of the function. + If params is None, it means deferring initialization of function parameters + until emit_func_output. + + attrs : Dict[str, Object], optional + The function attrs + + Returns + ------- + ret: FunctionScope + A FunctionScope for building a Relax function node. + """ + if not params: + params = None + elif isinstance(params, rx.Var): + params = [params] + elif isinstance(params, (list, tuple)): + for param in params: + if not isinstance(param, rx.Var): + raise TypeError( + "each element of function parameters must be of type tvm.relax.Var,\ + but got: {}".format( + type(param) + ) + ) + if attrs is None: + attrs = {} + return FunctionScope(self, name, params, attrs) + + def testing_scope(self, def_vars: List[tir.Var]) -> TestingScope: + """Start a scope for unit-testing purposes. + + Parameters + ---------- + def_vars: List[tir.Var] + List of symbolic variables that are marked as defined in scope. + + Returns + ------- + ret: TestingScope + A TestingScope to setup builder for emit and other purposes. + """ + return TestingScope(self, def_vars) + + def dataflow(self) -> DataflowScope: + """Annotate a Relax dataflow block. + + Returns + ------- + ret: DataflowScope + A DataflowScope for building a Relax dataflow block. + """ + return DataflowScope(self) + + def emit(self, expr: Expr) -> Var: + """Emit an expr. + This infers the shape and type of the expr, create a variable, + and bind the expr to the variable. + + Parameters + ---------- + expr : tvm.relax.Expr + The Expr to be emitted. + + Returns + ------- + ret : tvm.relax.Var + A newly created variable that gets bound to the input expr. + """ + return _ffi_api.BlockBuilderEmit(self, expr) # type: ignore + + def call_te(self, func: Callable, *args: Any, **kwargs: Any) -> Expr: + """Generate a call node according to the te function. + This function converts arguments from relax expression to te tensor, + The callback func should return a te tensor or a list of te tensors. + Please see detailed example in emit_te + + Parameters + ---------- + func : Callable + A function that returns a te tensor or a list of te tensors. + + args : Any, optional + arguments passed to the function. + + kwargs : Any, optional + The keyword arguments passed to the function. + Note that the key "primfunc_name_hint" is reserved for passing name hint + to the PrimFunc that gets generated. + + Returns + ------- + ret : tvm.relax.Call + A newly created call node + """ + + primfunc_name_hint = kwargs.pop("primfunc_name_hint", None) + tir_var_map: Dict[tir.Var, tir.PrimExpr] = dict() + new_args, te_arg_list = self._convert_te_arg(args, tir_var_map) + new_kwargs, te_kwarg_list = self._convert_te_arg(kwargs, tir_var_map) + + te_args = te_arg_list + te_kwarg_list + + te_out = func(*new_args, **new_kwargs) + assert isinstance(te_out, tvm.te.tensor.Tensor) or ( + isinstance(te_out, (tuple, list, tvm.ir.Array)) + and all(isinstance(t, tvm.te.tensor.Tensor) for t in te_out) + ), "only support te.tensor or tuple/list/Array of te.tensor as function output" + + outs = [te_out] if isinstance(te_out, tvm.te.tensor.Tensor) else list(te_out) + unbound_tir_vars = self._get_unbound_tir_vars(te_args + outs) + + inputs = [*te_args] + outs + tir_func = tvm.te.create_relax_prim_func(inputs, unbound_tir_vars, "int64") + + tir_func = tir_func.without_attr("global_symbol") + + if primfunc_name_hint: + gvar = self.add_func(tir_func, primfunc_name_hint) + else: + gvar = self.add_func(tir_func, func.__name__) + + call_args = [x.op.value for x in te_args] + + def _shape_with_old_tir_var( + shape_values: List[tir.PrimExpr], tir_var_inverse_map: Dict[tir.Var, tir.PrimExpr] + ): + return ShapeExpr( + [tir.stmt_functor.substitute(value, tir_var_inverse_map) for value in shape_values] + ) + + # Invert the TIR variable mapping, to convert the output shape back + # with old set of variables. + tir_var_inverse_map = {v: k for k, v in tir_var_map.items()} + + output_sinfo = [ + TensorStructInfo(_shape_with_old_tir_var(out.shape, tir_var_inverse_map), out.dtype) + for out in outs + ] + + # add arguments for extra parameters from unbound var + if len(unbound_tir_vars) > 0: + call = call_tir( + gvar, + call_args, + output_sinfo, + tir_vars=_shape_with_old_tir_var(unbound_tir_vars, tir_var_inverse_map), + ) + else: + call = call_tir(gvar, call_args, output_sinfo) + return call + + def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var: + """Emit a call node according to the te function. + This function converts arguments from relax expression to te tensor, + The callback func should return a te tensor or a list of te tensors. + + Parameters + ---------- + func : Callable + A function that returns a te tensor or a list of te tensors. + + args : Any, optional + arguments passed to the function. + + kwargs : Any, optional + The keyword arguments passed to the function. + Note that the key "primfunc_name_hint" is reserved for passing name hint + to the PrimFunc that gets generated. + + Returns + ------- + ret : tvm.relax.Var + A newly created variable that gets bound to the call code. + + Example + ------- + + .. code-block:: python + + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) + + def te_func(args, args_dict, msg): + A = args[0] + B = args_dict["B"] + return te.compute((128, 128), lambda i, j: A[i, j] + B[i, j]) + + with bb.function([x, y], "rx_func"): + out = bb.emit_te(te_func, [x], {"B": y}, msg="hello") + bb.emit_func_output(out) + + will result in TVMScript + + .. code-block:: python + + @tvm.script.ir_module + class Module: + @T.prim_func + def te_func(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, + var_compute: T.handle) -> None: + # function attr dict + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [n, m], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [n, m], dtype="float32") + compute = T.match_buffer(var_compute, [128, 128], dtype="float32") + # body + # with T.block("root") + for i0, i1 in T.grid(128, 128): + with T.block("compute"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads([rxplaceholder[i, j], rxplaceholder_1[i, j]]) + T.writes([compute[i, j]]) + compute[i, j] = rxplaceholder[i, j] + rxplaceholder_1[i, j] + + @R.function + def rx_func(x: Tensor((n, m), "float32"), y: Tensor((n, m), "float32")) -> Tensor: + # block 0 + gv = relax.call_tir("te_func", (x, y), R.Tensor((128, 128), "float32")) + return gv + + Example + ------- + + .. code-block:: python + + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + x = relax.Var("x", relax.TensorStructInfo([n], "float32")) + y = relax.Var("y", relax.TensorStructInfo([n + 1], "float32")) + + def te_func(A): + C = te.compute((n + 1), lambda i: A[i]) + return C + + with bb.function("rx_func", [x, y]): + x1 = bb.emit_te(te_func, y) + bb.emit_func_output(x1) + + will result in TVMScript + + .. code-block:: python + + @tvm.script.ir_module + class Module: + @T.prim_func + def te_func(var_rxplaceholder: T.handle, var_compute: T.handle, n: T.int64) -> None: + rxplaceholder = T.match_buffer(var_rxplaceholder, [n + T.int64(1)], + dtype="float32") + compute = T.match_buffer(var_compute, [n + T.int64(1)], dtype="float32") + # body + # with T.block("root") + for i0 in T.serial(0, n + T.int64(1)): + with T.block("compute"): + i = T.axis.spatial(n + T.int64(1), i0) + T.reads([rxplaceholder[i]]) + T.writes([compute[i]]) + compute[i] = rxplaceholder[i] + + @R.function + def rx_func(x: Tensor((n,), "float32"), y: Tensor(((n + 1),), "float32")) + -> Tensor(None, "float32", ndim=-1): + # block 0 + gv = relax.call_tir(te_func, (y,), R.Tensor((n + 1,), "float32"), (n,)) + return gv + """ + return self.emit(self.call_te(func, *args, **kwargs)) + + def match_cast(self, value: Expr, struct_info: StructInfo) -> Var: + """Emit a MatchCast. + + Parameters + ---------- + value : tvm.relax.Expr + The value of the MatchCast to be emitted. + + struct_info : StructInfo + The struct info to be matched. + + Returns + ------- + ret : tvm.relax.Var + A newly created variable that get bounds to be the casted result. + """ + return _ffi_api.BlockBuilderEmitMatchCast(self, value, struct_info) # type: ignore + + def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None: + """Emit output for the current dataflow block or function. + + Parameters + ---------- + output : Expr | Tuple | List[Expr] + The output of the current block/function. + + Returns + ------- + ret : tvm.relax.Var + The return variable which gets bound to the output. + """ + if isinstance(output, (list, tuple)): + output = Tuple(output) + return _ffi_api.BlockBuilderEmitOutput(self, output) # type: ignore + + def emit_func_output( + self, + output: Union[Expr, Tuple, List[Expr]], + params: Optional[Union[Var, Tuple, List[Var]]] = None, + ) -> None: + """Emit output for the function. + + Parameters + ---------- + output : Expr | Tuple | List[Expr] + The output of the current block/function. + + params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional + The parameters of the function to be built. + If params is None, it means the params have been initialized in the function with scope. + + Returns + ------- + ret : tvm.relax.Var + The return variable which gets bound to the output. + """ + if self._is_emit_func_output_called: + raise RuntimeError("emit_func_output must be called exactly once in a relax function.") + self._is_emit_func_output_called = True + + if self._func_params is not None and params is not None: + raise RuntimeError( + "function parameters have been initialized in the function with scope." + ) + + if self._func_params is None and params is None: + raise RuntimeError("Relax function must have parameter.") + + if self._func_params is None: + self._func_params = params + + if BlockBuilder.current() is not self: + raise RuntimeError("BlockBuilder._current must be self.") + + if isinstance(output, (list, tuple)): + output = Tuple(output) + + block = self._end_block() + if len(block.bindings) > 0: + self._blocks.append(block) + seqe = self.normalize(rx.SeqExpr(self._blocks, output)) + + # do not specify ret_struct_info and let constructor deduce + # from seqe.struct_info + func = rx.Function(self._func_params, seqe) + for key, value in self._func_attrs.items(): + func = func.with_attr(key, value) + self.end_scope() + self.add_func(func, self._func_name) + + def normalize(self, expr: Expr) -> Expr: + """Normalize an Expr to complete its shape and type. + + Parameters + ---------- + expr : Expr + The input expr. + + Returns + ------- + ret : Expr + The expr with normalized shape and type. + """ + return _ffi_api.BlockBuilderNormalize(self, expr) # type: ignore + + def get(self) -> tvm.IRModule: + """Return the IRModule being built. + + Returns + ------- + ret : tvm.IRModule + An IRModule with Relax and TIR functions being built. + """ + return _ffi_api.BlockBuilderGetContextIRModule(self) # type: ignore + + def get_unique_name(self, name_prefix: str) -> str: + """Generate a unique name with a specified prefix. + + Parameters + ---------- + name_hint : str + The name prefix. + + Returns + ------- + ret : str + The generated name. + """ + return _ffi_api.BlockBuilderGetUniqueName(self, name_prefix) # type: ignore + + def add_func(self, func: BaseFunc, func_name: str) -> GlobalVar: + """Add a Relax function or a TIR PrimFunc to the IRModule being built. + + Parameters + ---------- + func : BaseFunc + The function to be added. + + func_name : str + The name of the function to be added. + + Returns + ------- + gvar : GlobalVar + The global var bound to the added function. + """ + return _ffi_api.BlockBuilderAddFunction(self, func, func_name) # type: ignore + + def update_func(self, gv: GlobalVar, updated_func: BaseFunc) -> None: + """Add a Relax function or a TIR PrimFunc to the IRModule being built. + + Parameters + ---------- + gv : GlobalVar + The global var referring the function to be updated. + + updated_func : BaseFunc + The updated function. + """ + return _ffi_api.BlockBuilderUpdateFunction(self, gv, updated_func) # type: ignore + + def current_block_is_dataflow(self) -> bool: + """Check if the block being built is DataflowBlock or not. + + Returns + ------- + ret : bool + A boolean that indicates if the block being built is DataflowBlock or not. + """ + return _ffi_api.BlockBuilderCurrentBlockIsDataFlow(self) # type: ignore + + def emit_normalized(self, binding: Binding) -> None: + """Emit an already normalized binding. + + Parameters + ---------- + binding: Binding + The binding to be emitted. + """ + _ffi_api.BlockBuilderEmitNormalized(self, binding) # type: ignore + + def lookup_binding(self, var: Var) -> Optional[Expr]: + """Lookup a var in the binding table binding_table_. + + Parameters + ---------- + var: Var + The input var. + + Returns + ------- + expr: Expr + The Expr bound to the input var. + """ + return _ffi_api.BlockBuilderLookupBinding(self, var) # type: ignore + + def begin_scope(self, params: Optional[List[Var]] = None) -> None: + """Begin a new scope, with optional parameters that + are visible within the scope. + + Parameters + ---------- + params: Optional[List[Var]] + Parameters that are visible within the scope. + + Note + ---- + This function should be called when new scope is introduced + (function, seq) to properly track the variable availability + and help the best effort deduction. + """ + + return _ffi_api.BlockBuilderBeginScope(self, params) # type: ignore + + def end_scope(self) -> None: + """End the current scope. Please see `begin_scope` for details""" + + return _ffi_api.BlockBuilderEndScope(self) # type: ignore diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py new file mode 100644 index 000000000000..0252720f6ee8 --- /dev/null +++ b/python/tvm/relax/expr_functor.py @@ -0,0 +1,1530 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, arguments-differ +"""The expression functor of Relax.""" +from typing import Callable, Optional + +import tvm +from tvm.ir import Op +from tvm.meta_schedule.utils import derived_object +from tvm.runtime import Object + +from ..ir.module import IRModule +from . import _ffi_api +from .block_builder import BlockBuilder +from .expr import ( + Binding, + BindingBlock, + Call, + Constant, + Id, + DataflowBlock, + DataflowVar, + DataTypeImm, + Expr, + ExternFunc, + Function, + GlobalVar, + If, + MatchCast, + PrimValue, + SeqExpr, + ShapeExpr, + Span, + StringImm, + Tuple, + TupleGetItem, + Var, + VarBinding, +) +from .struct_info import StructInfo + +visitor = derived_object +""" +A decorator to wrap user-customized PyExprVisitor as TVM object _PyExprVisitor. + +Parameters +---------- +visitor_cls : PyExprVisitor + The user-customized PyExprVisitor. + +Returns +------- +cls : _PyExprVisitor + The decorated TVM object _PyExprVisitor(ExprVisitor on the C++ side). + +Example +------- +.. code-block:: python + + @relax.expr_functor.visitor + class MyExprVisitor(PyExprVisitor): + # customize visit function + def visit_call_(self, op: Call) -> None: + # just for demo purposes + ... + # myvisitor is now a special visitor that visit every Call with + # user-customized visit_call_ + myvisitor = MyExprVisitor() + # apply myvisitor to Expr/Binding/BindingBlock/VarDef + myvisitor.visit_expr(expr) + myvisitor.visit_binding(binding) + myvisitor.visit_binding_block(bindingblock) + myvisitor.visit_var_def(var) +""" + +mutator = derived_object +""" +A decorator to wrap user-customized PyExprMutator as TVM object _PyExprMutator. +Note: Cannot override visit function and post-order rewrite at the same time. + +Parameters +---------- +mutator_cls : PyExprMutator + The user-customized PyExprMutator. + +Returns +------- +cls : _PyExprMutator + The decorated TVM object _PyExprMutator(ExprMutator on the C++ side). + +Example +------- +.. code-block:: python + + @relax.expr_functor.mutator + class MyExprMutator(PyExprMutator): + # customize rewrite function + def visit_tuple_(self, op: Tuple) -> Expr: + # just for demo purposes + ... + + # mymutator is now a special mutator that rewrite every Tuple with + # user-customized visit_tuple_ + mymutator = MyExprMutator() + # apply mymutator to Expr/Binding/BindingBlock/VarDef + mymutator.visit_expr(expr) + mymutator.visit_binding(binding) + mymutator.visit_binding_block(bindingblock) + mymutator.visit_var_def(var) +""" + + +class ExprFunctor: + """ + An abstract visitor defined over Expr. + Defines the default dispatch over expressions, and + implements memoization. + """ + + def visit_expr(self, expr: Expr) -> Expr: + """Apply the visitor to an expression.""" + if isinstance(expr, Constant): # type: ignore + ret = self.visit_constant_(expr) + elif isinstance(expr, Tuple): + ret = self.visit_tuple_(expr) + elif isinstance(expr, DataflowVar): + ret = self.visit_dataflow_var_(expr) + elif isinstance(expr, Var): + ret = self.visit_var_(expr) + elif isinstance(expr, ShapeExpr): + ret = self.visit_shape_expr_(expr) + elif isinstance(expr, ExternFunc): + ret = self.visit_extern_func_(expr) + elif isinstance(expr, GlobalVar): # type: ignore + ret = self.visit_global_var_(expr) + elif isinstance(expr, Function): + ret = self.visit_function_(expr) + elif isinstance(expr, Call): # type: ignore + ret = self.visit_call_(expr) + elif isinstance(expr, SeqExpr): + ret = self.visit_seq_expr_(expr) + elif isinstance(expr, If): # type: ignore + ret = self.visit_if_(expr) + elif isinstance(expr, Op): + ret = self.visit_op_(expr) + elif isinstance(expr, TupleGetItem): + ret = self.visit_tuple_getitem_(expr) + elif isinstance(expr, PrimValue): + ret = self.visit_prim_value_(expr) + elif isinstance(expr, StringImm): + ret = self.visit_string_imm_(expr) + elif isinstance(expr, DataTypeImm): + ret = self.visit_data_type_imm_(expr) + else: + raise TypeError("Invalid type: {0}".format(type(expr))) + + return ret + + def visit_constant_(self, op: Constant): + raise NotImplementedError() + + def visit_tuple_(self, op: Tuple): + raise NotImplementedError() + + def visit_dataflow_var_(self, op: DataflowVar): + raise NotImplementedError() + + def visit_var_(self, op: Var): + raise NotImplementedError() + + def visit_shape_expr_(self, op: ShapeExpr): + raise NotImplementedError() + + def visit_extern_func_(self, op: ExternFunc): + raise NotImplementedError() + + def visit_global_var_(self, op: GlobalVar): + raise NotImplementedError() + + def visit_function_(self, op: Function): + raise NotImplementedError() + + def visit_call_(self, op: Call): + raise NotImplementedError() + + def visit_seq_expr_(self, op: SeqExpr): + raise NotImplementedError() + + def visit_if_(self, op: If): + raise NotImplementedError() + + def visit_op_(self, op: Op): + raise NotImplementedError() + + def visit_tuple_getitem_(self, op: TupleGetItem): + raise NotImplementedError() + + def visit_prim_value_(self, op: PrimValue): + raise NotImplementedError() + + def visit_string_imm_(self, op: StringImm): + raise NotImplementedError() + + def visit_data_type_imm_(self, op: DataTypeImm): + raise NotImplementedError() + + def visit_var_binding_(self, binding: VarBinding): + raise NotImplementedError() + + def visit_match_cast_(self, binding: MatchCast): + raise NotImplementedError() + + def visit_binding_block_(self, block: BindingBlock): + raise NotImplementedError() + + def visit_dataflow_block_(self, block: DataflowBlock): + raise NotImplementedError() + + def visit_var_def_(self, var: Var): + raise NotImplementedError() + + def visit_dataflow_var_def_(self, var: DataflowVar): + raise NotImplementedError() + + def visit_binding(self, binding: Binding): + if isinstance(binding, MatchCast): + self.visit_match_cast_(binding) + elif isinstance(binding, VarBinding): + self.visit_var_binding_(binding) + else: + raise TypeError("Invalid type: {0}".format(type(binding))) + + def visit_binding_block(self, block: BindingBlock): + if isinstance(block, DataflowBlock): + self.visit_dataflow_block_(block) + elif isinstance(block, BindingBlock): + self.visit_binding_block_(block) + else: + raise TypeError("Invalid type: {0}".format(type(block))) + + def visit_var_def(self, var: Var): + if isinstance(var, DataflowVar): + self.visit_dataflow_var_def_(var) + elif isinstance(var, Var): + self.visit_var_def_(var) + else: + raise TypeError("Invalid type: {0}".format(type(var))) + + +@tvm._ffi.register_object("expr_functor.PyExprVisitor") +class _PyExprVisitor(Object): + """ + A TVM object to support customization of ExprVisitor on the python side. + This is the decorated result returned from visitor decorator. + + WARNING: This is NOT the user facing class for method overwriting inheritance. + + See also: visitor, PyExprVisitor + """ + + def __init__( + self, + f_visit_expr: Callable = None, + f_visit_constant_: Callable = None, + f_visit_tuple_: Callable = None, + f_visit_var_: Callable = None, + f_visit_dataflow_var_: Callable = None, + f_visit_shape_expr_: Callable = None, + f_visit_extern_func_: Callable = None, + f_visit_global_var_: Callable = None, + f_visit_function_: Callable = None, + f_visit_call_: Callable = None, + f_visit_seq_expr_: Callable = None, + f_visit_if_: Callable = None, + f_visit_op_: Callable = None, + f_visit_tuple_getitem_: Callable = None, + f_visit_prim_value_: Callable = None, + f_visit_string_imm_: Callable = None, + f_visit_data_type_imm_: Callable = None, + f_visit_binding: Callable = None, + f_visit_var_binding_: Callable = None, + f_visit_match_cast_: Callable = None, + f_visit_binding_block: Callable = None, + f_visit_binding_block_: Callable = None, + f_visit_dataflow_block_: Callable = None, + f_visit_var_def: Callable = None, + f_visit_var_def_: Callable = None, + f_visit_dataflow_var_def_: Callable = None, + f_visit_span: Callable = None, + ) -> None: + """Constructor.""" + + self.__init_handle_by_constructor__( + _ffi_api.MakePyExprVisitor, # type: ignore + f_visit_expr, + f_visit_constant_, + f_visit_tuple_, + f_visit_var_, + f_visit_dataflow_var_, + f_visit_shape_expr_, + f_visit_extern_func_, + f_visit_global_var_, + f_visit_function_, + f_visit_call_, + f_visit_seq_expr_, + f_visit_if_, + f_visit_op_, + f_visit_tuple_getitem_, + f_visit_prim_value_, + f_visit_string_imm_, + f_visit_data_type_imm_, + f_visit_binding, + f_visit_var_binding_, + f_visit_match_cast_, + f_visit_binding_block, + f_visit_binding_block_, + f_visit_dataflow_block_, + f_visit_var_def, + f_visit_var_def_, + f_visit_dataflow_var_def_, + f_visit_span, + ) + + def visit_expr(self, expr: Expr) -> None: + """Generic dispatcher for Expr. + + Parameters + ---------- + expr : Expr + The expr to be visited. + """ + return _ffi_api.PyExprVisitorVisitExpr(self, expr) # type: ignore + + def visit_binding(self, binding: Binding) -> None: + """Generic dispatcher for Binding. + + Parameters + ---------- + binding : Binding + The binding to be visited. + """ + return _ffi_api.PyExprVisitorVisitBinding(self, binding) # type: ignore + + def visit_binding_block(self, block: BindingBlock) -> None: + """Generic dispatcher for BindingBlock. + + Parameters + ---------- + block : BindingBlock + The block to be visited. + """ + return _ffi_api.PyExprVisitorVisitBindingBlock(self, block) # type: ignore + + def visit_var_def(self, var: Var) -> None: + """Generic dispatcher for visiting the var definition site. + Note that visit_var_() will only visit the usage site of an Var. + + Parameters + ---------- + var : Var + The var to be visited. + """ + return _ffi_api.PyExprVisitorVisitVarDef(self, var) # type: ignore + + +class PyExprVisitor: + """ + An abstract ExprVisitor with customized methods on the python-side. + This is the user facing class for method overwriting inheritance. + _tvm_metadata discribes the class to inherit("cls"), the methods + that users can overwrite("methods"). + + Note: @relax.expr_functor.visitor is required for proper usage of any inherited class. + + See also: visitor, _PyExprVisitor + + Example: + @relax.expr_functor.visitor + def MyExprVisitor(PyExprVisitor): + ... + """ + + _tvm_metadata = { + "cls": _PyExprVisitor, + "methods": [ + "visit_expr", + "visit_constant_", + "visit_tuple_", + "visit_var_", + "visit_dataflow_var_", + "visit_shape_expr_", + "visit_extern_func_", + "visit_global_var_", + "visit_function_", + "visit_call_", + "visit_seq_expr_", + "visit_if_", + "visit_op_", + "visit_tuple_getitem_", + "visit_prim_value_", + "visit_string_imm_", + "visit_data_type_imm_", + "visit_binding", + "visit_var_binding_", + "visit_match_cast_", + "visit_binding_block", + "visit_binding_block_", + "visit_dataflow_block_", + "visit_var_def", + "visit_var_def_", + "visit_dataflow_var_def_", + "visit_span", + ], + } + + def visit_expr(self, expr: Expr) -> None: + """Generic dispatcher for Expr. + Users can customized this function to overwrite VisitExpr(const Expr& expr) on the C++ side. + + Parameters + ---------- + expr : Expr + The expr to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.PyExprVisitorVisitExpr(self._outer(), expr) # type: ignore + + def visit_binding(self, binding: Binding) -> None: + """Generic dispatcher for Binding. + Users can customized this function to overwrite VisitBinding(const Binding& binding) + on the C++ side. + + Parameters + ---------- + binding : Binding + The binding to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.PyExprVisitorVisitBinding(self._outer(), binding) # type: ignore + + def visit_binding_block(self, block: BindingBlock) -> None: + """Generic dispatcher for BindingBlock. + Users can customized this function to overwrite VisitBindingBlock(const BindingBlock& block) + on the C++ side. + + Parameters + ---------- + block : BindingBlock + The block to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.PyExprVisitorVisitBindingBlock(self._outer(), block) # type: ignore + + def visit_var_def(self, var: Var) -> None: + """Generic dispatcher for visiting the var definition site. + Users can customized this function to overwrite VisitVarDef(const Var& var) on the C++ side. + Note that visit_var_() will only visit the usage site of an Var. + + Parameters + ---------- + var : Var + The var to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.PyExprVisitorVisitVarDef(self._outer(), var) # type: ignore + + def visit_constant_(self, op: Constant) -> None: + """Visit Constant. + Users can customized this function to overwrite VisitExpr_(const ConstantNode* op) + on the C++ side. + + Parameters + ---------- + op : Constant + The Constant to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_tuple_(self, op: Tuple) -> None: + """Visit Tuple. + Users can customized this function to overwrite VisitExpr_(const TupleNode* op) + on the C++ side. + + Parameters + ---------- + op : Tuple + The Tuple to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_var_(self, op: Var) -> None: + """Visit Var. + Users can customized this function to overwrite VisitExpr_(const VarNode* op) + on the C++ side. + + Parameters + ---------- + op : Var + The Var to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_dataflow_var_(self, op: DataflowVar) -> None: + """Visit DataflowVar. + Users can customized this function to overwrite VisitExpr_(const DataflowVarNode* op) + on the C++ side. + + Parameters + ---------- + op : DataflowVar + The DataflowVar to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_shape_expr_(self, op: ShapeExpr) -> None: + """Visit ShapeExpr. + Users can customized this function to overwrite VisitExpr_(const ShapeExprNode* op) + on the C++ side. + + Parameters + ---------- + op : ShapeExpr + The ShapeExpr to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_extern_func_(self, op: ExternFunc) -> None: + """Visit ExternFunc. + Users can customized this function to overwrite VisitExpr_(const ExternFuncNode* op) + on the C++ side. + + Parameters + ---------- + op : ExternFunc + The ExternFunc to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_global_var_(self, op: GlobalVar) -> None: + """Visit GlobalVar. + Users can customized this function to overwrite VisitExpr_(const GlobalVarNode* op) + on the C++ side. + + Parameters + ---------- + op : GlobalVar + The GlobalVar to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_function_(self, op: Function) -> None: + """Visit Function. + Users can customized this function to overwrite VisitExpr_(const FunctionNode* op) + on the C++ side. + + Parameters + ---------- + op : Function + The Function to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_call_(self, op: Call) -> None: + """Visit Call. + Users can customized this function to overwrite VisitExpr_(const CallNode* op) + on the C++ side. + + Parameters + ---------- + op : Call + The Call to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_seq_expr_(self, op: SeqExpr) -> None: + """Visit SeqExpr. + Users can customized this function to overwrite VisitExpr_(const SeqExprNode* op) + on the C++ side. + + Parameters + ---------- + op : SeqExpr + The SeqExpr to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_if_(self, op: If) -> None: + """Visit If. + Users can customized this function to overwrite VisitExpr_(const IfNode* op) + on the C++ side. + + Parameters + ---------- + op : If + The If to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_op_(self, op: Op) -> None: + """Visit Op. + Users can customized this function to overwrite VisitExpr_(const OpNode* op) + on the C++ side. + + Parameters + ---------- + op : Op + The Op to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_tuple_getitem_(self, op: TupleGetItem) -> None: + """Visit TupleGetItem. + Users can customized this function to overwrite VisitExpr_(const TupleGetItemNode* op) + on the C++ side. + + Parameters + ---------- + op : TupleGetItem + The TupleGetItem to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_prim_value_(self, op: PrimValue) -> None: + """Visit PrimValue. + Users can customized this function to overwrite VisitExpr_(const PrimValueNode* op) + on the C++ side. + + Parameters + ---------- + op : PrimValue + The PrimValue to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_string_imm_(self, op: StringImm) -> None: + """Visit StringImm. + Users can customized this function to overwrite VisitExpr_(const StringImmNode* op) + on the C++ side. + + Parameters + ---------- + op : StringImm + The StringImm to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_data_type_imm_(self, op: DataTypeImm) -> None: + """Visit DataTypeImm. + Users can customized this function to overwrite VisitExpr_(const DataTypeImmNode* op) + on the C++ side. + + Parameters + ---------- + op : DataTypeImm + The DataTypeImm to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) # type: ignore + + def visit_var_binding_(self, binding: VarBinding) -> None: + """Visit VarBinding. + Users can customized this function to overwrite VisitBinding_(const VarBindingNode* binding) + on the C++ side. + + Parameters + ---------- + binding : VarBinding + The VarBinding to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBinding(self._outer(), binding) # type: ignore + + def visit_match_cast_(self, binding: MatchCast) -> None: + """Visit MatchCast. + Users can customized this function to overwrite VisitBinding_(const MatchCastNode* binding) + on the C++ side. + + Parameters + ---------- + binding : MatchCast + The MatchCast to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBinding(self._outer(), binding) # type: ignore + + def visit_binding_block_(self, block: BindingBlock) -> None: + """Visit BindingBlock. + Users can customized this function to overwrite VisitBindingBlock_(const BindingBlockNode* + block) on the C++ side. + + Parameters + ---------- + block : BindingBlock + The BindingBlock to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBindingBlock(self._outer(), block) # type: ignore + + def visit_dataflow_block_(self, block: DataflowBlock) -> None: + """Visit DataflowBlock. + Users can customized this function to overwrite VisitBindingBlock_(const DataflowBlockNode* + block) on the C++ side. + + Parameters + ---------- + block : DataflowBlock + The DataflowBlock to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBindingBlock(self._outer(), block) # type: ignore + + def visit_var_def_(self, var: Var) -> None: + """Visit the Var definition site. + Users can customized this function to overwrite VisitVarDef_(const VarNode* var) + on the C++ side. + + Parameters + ---------- + var : Var + The Var to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitVarDef(self._outer(), var) # type: ignore + + def visit_dataflow_var_def_(self, var: DataflowVar) -> None: + """Visit the DataflowVar definition site. + Users can customized this function to overwrite VisitVarDef_(const DataflowVarNode* var) + on the C++ side. + + Parameters + ---------- + var : DataflowVar + The DataflowVar to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitVarDef(self._outer(), var) # type: ignore + + def visit_span(self, span: Span) -> None: + """Visit Span. + Users can customized this function to overwrite VisitSpan(const Span& span) on the C++ side. + + Parameters + ---------- + span : Span + The Span to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitSpan(self._outer(), span) # type: ignore + + +@tvm._ffi.register_object("expr_functor.PyExprMutator") +class _PyExprMutator(Object): + """ + A TVM object to support customization of ExprMutator on the python side. + This is the decorated result returned from mutator decorator. + + WARNING: This is NOT the user facing class for method overwriting inheritance. + + See also: mutator, PyExprmutator + """ + + def __init__( + self, + builder: BlockBuilder = None, + f_visit_expr: Callable = None, + f_visit_constant_: Callable = None, + f_visit_tuple_: Callable = None, + f_visit_var_: Callable = None, + f_visit_dataflow_var_: Callable = None, + f_visit_shape_expr_: Callable = None, + f_visit_extern_func_: Callable = None, + f_visit_global_var_: Callable = None, + f_visit_function_: Callable = None, + f_visit_call_: Callable = None, + f_visit_seq_expr_: Callable = None, + f_visit_if_: Callable = None, + f_visit_op_: Callable = None, + f_visit_tuple_getitem_: Callable = None, + f_visit_prim_value_: Callable = None, + f_visit_string_imm_: Callable = None, + f_visit_data_type_imm_: Callable = None, + f_visit_binding: Callable = None, + f_visit_var_binding_: Callable = None, + f_visit_match_cast_: Callable = None, + f_visit_binding_block: Callable = None, + f_visit_binding_block_: Callable = None, + f_visit_dataflow_block_: Callable = None, + f_visit_var_def: Callable = None, + f_visit_var_def_: Callable = None, + f_visit_dataflow_var_def_: Callable = None, + f_visit_span: Callable = None, + ) -> None: + """Constructor.""" + + self.__init_handle_by_constructor__( + _ffi_api.MakePyExprMutator, # type: ignore + builder, + f_visit_expr, + f_visit_constant_, + f_visit_tuple_, + f_visit_var_, + f_visit_dataflow_var_, + f_visit_shape_expr_, + f_visit_extern_func_, + f_visit_global_var_, + f_visit_function_, + f_visit_call_, + f_visit_seq_expr_, + f_visit_if_, + f_visit_op_, + f_visit_tuple_getitem_, + f_visit_prim_value_, + f_visit_string_imm_, + f_visit_data_type_imm_, + f_visit_binding, + f_visit_var_binding_, + f_visit_match_cast_, + f_visit_binding_block, + f_visit_binding_block_, + f_visit_dataflow_block_, + f_visit_var_def, + f_visit_var_def_, + f_visit_dataflow_var_def_, + f_visit_span, + ) + + def visit_expr(self, expr: Expr) -> Expr: + """Generic dispatcher for Expr. + + Parameters + ---------- + expr : Expr + The expr to be visited. + + Returns + ------- + result : Expr + The Expr after transformation. + """ + return _ffi_api.PyExprMutatorVisitExpr(self, expr) # type: ignore + + def visit_binding(self, binding: Binding) -> None: + """Generic dispatcher for Binding. + + Parameters + ---------- + binding : Binding + The binding to be visited. + """ + return _ffi_api.PyExprMutatorVisitBinding(self, binding) # type: ignore + + def visit_binding_block(self, block: BindingBlock) -> BindingBlock: + """Generic dispatcher for BindingBlock. + + Parameters + ---------- + block : BindingBlock + The block to be visited. + + Returns + ------- + result : BindingBlock + The binding block after transformation. + """ + return _ffi_api.PyExprMutatorVisitBindingBlock(self, block) # type: ignore + + def visit_var_def(self, var: Var) -> Var: + """Generic dispatcher for visiting the var definition site. + Note that visit_var_() will only visit the usage site of an Var. + + Parameters + ---------- + var : Var + The var to be visited. + + Returns + ------- + result : Var + The var after post-order rewritten. + """ + return _ffi_api.PyExprMutatorVisitVarDef(self, var) # type: ignore + + +class PyExprMutator: + """ + An abstract ExprMutator with customized methods on the python-side. + This is the user facing class for method overwriting inheritance. + _tvm_metadata discribes the class to inherit("cls"), the methods that users can + overwrite("methods"), the constructor's parameters("fields") + + Note: @relax.expr_functor.mutator is required for proper usage of any inherited class. + + See also: visitor, _PyExprVisitor + + Example: + @relax.expr_functor.mutator + def MyExprMutator(PyExprMutator): + ... + """ + + _tvm_metadata = { + "cls": _PyExprMutator, + "fields": ["builder_"], + "methods": [ + "visit_expr", + "visit_constant_", + "visit_tuple_", + "visit_var_", + "visit_dataflow_var_", + "visit_shape_expr_", + "visit_extern_func_", + "visit_global_var_", + "visit_function_", + "visit_call_", + "visit_seq_expr_", + "visit_if_", + "visit_op_", + "visit_tuple_getitem_", + "visit_prim_value_", + "visit_string_imm_", + "visit_data_type_imm_", + "visit_binding", + "visit_var_binding_", + "visit_match_cast_", + "visit_binding_block", + "visit_binding_block_", + "visit_dataflow_block_", + "visit_var_def", + "visit_var_def_", + "visit_dataflow_var_def_", + "visit_span", + ], + } + + def __init__(self, mod: Optional[IRModule] = None) -> None: + """Constructor""" + self.builder_ = BlockBuilder(mod) + + def visit_expr(self, expr: Expr) -> Expr: + """Generic dispatcher for Expr. + Users can customized this function to overwrite VisitExpr(const Expr& expr) on the C++ side. + + Parameters + ---------- + expr : Expr + The expr to be visited. + + Returns + ------- + result : Expr + The Expr after transformation. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitExpr(self._outer(), expr) # type: ignore + + def visit_binding(self, binding: Binding) -> None: + """Generic dispatcher for Binding. + Users can customized this function to overwrite VisitBinding(const Binding& binding) + on the C++ side. + + Parameters + ---------- + binding : Binding + The binding to be visited. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitBinding(self._outer(), binding) # type: ignore + + def visit_binding_block(self, block: BindingBlock) -> BindingBlock: + """Generic dispatcher for BindingBlock. + Users can customized this function to overwrite VisitBindingBlock(const BindingBlock& block) + on the C++ side. + + Parameters + ---------- + block : BindingBlock + The block to be visited. + + Returns + ------- + result : BindingBlock + The binding block after transformation. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitBindingBlock(self._outer(), block) # type: ignore + + def visit_var_def(self, var: Var) -> Var: + """Generic dispatcher for visiting the var definition site. + Users can customized this function to overwrite VisitVarDef(const Var& var) on the C++ side. + Note that visit_var_() will only visit the usage site of an Var. + + Parameters + ---------- + var : Var + The var to be visited. + + Returns + ------- + result: Var + The var after post-order rewritten. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitVarDef(self._outer(), var) # type: ignore + + def visit_constant_(self, op: Constant) -> Expr: + """Visit Constant. + Users can customized this function to overwrite VisitExpr_(const ConstantNode* op) + on the C++ side. + + Parameters + ---------- + op : Constant + The Constant to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_tuple_(self, op: Tuple) -> Expr: + """Visit Tuple. + Users can customized this function to overwrite VisitExpr_(const TupleNode* op) + on the C++ side. + + Parameters + ---------- + op : Tuple + The Tuple to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_var_(self, op: Var) -> Expr: + """Visit Var. + Users can customized this function to overwrite VisitExpr_(const VarNode* op) + on the C++ side. + + Parameters + ---------- + op : Var + The Var to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_dataflow_var_(self, op: DataflowVar) -> Expr: + """Visit DataflowVar. + Users can customized this function to overwrite VisitExpr_(const DataflowVarNode* op) + on the C++ side. + + Parameters + ---------- + op : DataflowVar + The DataflowVar to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_shape_expr_(self, op: ShapeExpr) -> Expr: + """Visit ShapeExpr. + Users can customized this function to overwrite VisitExpr_(const ShapeExprNode* op) + on the C++ side. + + Parameters + ---------- + op : ShapeExpr + The ShapeExpr to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_extern_func_(self, op: ExternFunc) -> Expr: + """Visit ExternFunc. + Users can customized this function to overwrite VisitExpr_(const ExternFuncNode* op) + on the C++ side. + + Parameters + ---------- + op : ExternFunc + The ExternFunc to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_global_var_(self, op: GlobalVar) -> Expr: + """Visit GlobalVar. + Users can customized this function to overwrite VisitExpr_(const GlobalVarNode* op) + on the C++ side. + + Parameters + ---------- + op : GlobalVar + The GlobalVar to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_function_(self, op: Function) -> Expr: + """Visit Function. + Users can customized this function to overwrite VisitExpr_(const FunctionNode* op) + on the C++ side. + + Parameters + ---------- + op : Function + The Function to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_call_(self, op: Call) -> Expr: + """Visit Call. + Users can customized this function to overwrite VisitExpr_(const CallNode* op) + on the C++ side. + + Parameters + ---------- + op : Call + The Call to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_seq_expr_(self, op: SeqExpr) -> Expr: + """Visit SeqExpr. + Users can customized this function to overwrite VisitExpr_(const SeqExprNode* op) + on the C++ side. + + Parameters + ---------- + op : SeqExpr + The SeqExpr to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_if_(self, op: If) -> Expr: + """Visit If. + Users can customized this function to overwrite VisitExpr_(const IfNode* op) + on the C++ side. + + Parameters + ---------- + op : If + The If to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_op_(self, op: Op) -> Expr: + """Visit Op. + Users can customized this function to overwrite VisitExpr_(const OpNode* op) + on the C++ side. + + Parameters + ---------- + op : Op + The Op to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_tuple_getitem_(self, op: TupleGetItem) -> Expr: + """Visit TupleGetItem. + Users can customized this function to overwrite VisitExpr_(const TupleGetItemNode* op) + on the C++ side. + + Parameters + ---------- + op : TupleGetItem + The TupleGetItem to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_prim_value_(self, op: PrimValue) -> Expr: + """Visit PrimValue. + Users can customized this function to overwrite VisitExpr_(const PrimValueNode* op) + on the C++ side. + + Parameters + ---------- + op : PrimValue + The PrimValue to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_string_imm_(self, op: StringImm) -> Expr: + """Visit StringImm. + Users can customized this function to overwrite VisitExpr_(const StringImmNode* op) + on the C++ side. + + Parameters + ---------- + op : StringImm + The StringImm to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_data_type_imm_(self, op: DataTypeImm) -> Expr: + """Visit DataTypeImm. + Users can customized this function to overwrite VisitExpr_(const DataTypeImmNode* op) + on the C++ side. + + Parameters + ---------- + op : DataTypeImm + The DataTypeImm to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) # type: ignore + + def visit_var_binding_(self, binding: VarBinding) -> None: + """Visit VarBinding. + Users can customized this function to overwrite VisitBinding_(const VarBindingNode* binding) + on the C++ side. + + Parameters + ---------- + binding : VarBinding + The VarBinding to be visited. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBinding(self._outer(), binding) # type: ignore + + def visit_match_cast_(self, binding: MatchCast) -> None: + """Visit MatchCast. + Users can customized this function to overwrite VisitBinding_(const MatchCastNode* binding) + on the C++ side. + + Parameters + ---------- + binding : MatchCast + The MatchCast to be visited. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBinding(self._outer(), binding) # type: ignore + + def visit_binding_block_(self, block: BindingBlock) -> BindingBlock: + """Visit BindingBlock. + Users can customized this function to overwrite VisitBindingBlock_(const BindingBlockNode* + block) on the C++ side. + + Parameters + ---------- + block : BindingBlock + The BindingBlock to be visited. + + Returns + ------- + result : BindingBlock + The binding block after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBindingBlock(self._outer(), block) # type: ignore + + def visit_dataflow_block_(self, block: DataflowBlock) -> BindingBlock: + """Visit DataflowBlock. + Users can customized this function to overwrite VisitBindingBlock_(const DataflowBlockNode* + block) on the C++ side. + + Parameters + ---------- + block : DataflowBlock + The DataflowBlock to be visited. + + Returns + ------- + result : BindingBlock + The binding block after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBindingBlock(self._outer(), block) # type: ignore + + def visit_var_def_(self, var: Var) -> Var: + """Visit the Var definition site. + Users can customized this function to overwrite VisitVarDef_(const VarNode* var) + on the C++ side. + + Parameters + ---------- + var : Var + The Var to be visited. + + Returns + ------- + result : Var + The var after post-order rewritten. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitVarDef(self._outer(), var) # type: ignore + + def visit_dataflow_var_def_(self, var: DataflowVar) -> Var: + """Visit the DataflowVar definition site. + Users can customized this function to overwrite VisitVarDef_(const DataflowVarNode* var) + on the C++ side. + + Parameters + ---------- + var : DataflowVar + The DataflowVar to be visited. + + Returns + ------- + result : Var + The var after post-order rewritten. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitVarDef(self._outer(), var) # type: ignore + + def visit_span(self, span: Span) -> Span: + """Visit Span. + Users can customized this function to overwrite VisitSpan(const Span& span) on the C++ side. + + Parameters + ---------- + span : Span + The Span to be visited. + + Returns + ------- + result : Span + The span after transformation. + """ + raise NotImplementedError + + def visit_expr_post_order(self, expr: Expr) -> Expr: + """Post-order rewrite an Expr and normalize. + + Parameters + ---------- + expr : Expr + The Expr to be rewritten. + + Returns + ------- + result : Expr + The Expr after post-order rewritten. + """ + return _ffi_api.PyExprMutatorVisitExprPostOrder(self._outer(), expr) # type: ignore + + def set_var_remap(self, vid: Id, var: Var) -> None: + """Remap a var to a new var in use-site. + + Parameters + ---------- + vid : Id + The vid of the old var. + var : Var + The new var. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorSetVarRemap(self._outer(), vid, var) # type: ignore + + def get_var_remap(self, vid: Id) -> Var: + """Remap a var to a new var in use-site. + + Parameters + ---------- + vid : Id + The vid of the old var + + Returns + ------- + var : Var + The remapped var. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorGetVarRemap(self._outer(), vid) # type: ignore + + def visit_with_new_scope(self, expr: Expr) -> Expr: + """Rewrite the expr with a new scope, used in a Function's body and the branches of If. + + Parameters + ---------- + expr : Expr + The expr to be visited. + + Returns + ------- + var : Var + The expr after visiting. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitWithNewScope(self._outer(), expr) # type: ignore + + def lookup_binding(self, var: Var) -> Optional[Expr]: + """Look up the value bound to a variable. + Note: For function parameters, this function returns NullOpt. + + Parameters + ---------- + var : Var + The var to be looked up. + + Returns + ------- + var : Var + The value bound to the input var. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorLookupBinding(self._outer(), var) # type: ignore + + def with_struct_info(self, var: Var, struct_info: StructInfo) -> Var: + """Create a new var with specified shape and type if the original var's shape or type does + not match with the specified ones. + + Parameters + ---------- + var : Var + The var to be updated. + struct_info : StructInfo + The struct info. + + Returns + ------- + var : Var + The var filled with shape and type. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorWithStructInfo(self._outer(), var, struct_info) # type: ignore diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py new file mode 100644 index 000000000000..101b0827d630 --- /dev/null +++ b/python/tvm/relax/op/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax core operators.""" + +# Operators +from .base import * +from .binary import * diff --git a/python/tvm/relax/op/_ffi_api.py b/python/tvm/relax/op/_ffi_api.py new file mode 100644 index 000000000000..8dc6a1b4fbb0 --- /dev/null +++ b/python/tvm/relax/op/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for tvm.relax.op""" +import tvm._ffi + +tvm._ffi._init_api("relax.op", __name__) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py new file mode 100644 index 000000000000..d76b155beb83 --- /dev/null +++ b/python/tvm/relax/op/base.py @@ -0,0 +1,358 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# pylint: disable=redefined-builtin +"""The base Relax operators.""" +from typing import Union, List, Tuple, Optional + + +import tvm +from tvm.runtime.object import Object + +from . import _ffi_api +from ..expr import Expr, ShapeExpr, Call, ExternFunc +from ..expr import Tuple as RxTuple +from ..struct_info import StructInfo, TensorStructInfo +from ...ir import PrimExpr +from ..utils import args_converter + + +py_print = print # pylint: disable=invalid-name + + +def null_value() -> Call: + """Create a call node that represents a null value object. + + Returns + ------- + ret: Call + The created call node. + """ + return _ffi_api.null_value() # type: ignore + + +@args_converter.auto +def call_tir( + func: Union[str, Expr], + args: Expr, + out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]], + tir_vars: Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] = None, +) -> Call: + """ + Call a destination-passing-style function and return the output. + + Parameters + ---------- + func : Union[str, Expr] + The destination-passing-style function, can be ExternFunc or PrimFunc. + + args : Expr + The input arguments. + + out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]] + The structure info of the call_tir output. + It should be a single or a list of TensorStructInfo. Each one denotes the + structure info of a returned tensor. + + tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] + ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used + + Returns + ------- + ret: Call + A call node for the call_tir operator. + """ + if isinstance(func, str): + func = ExternFunc(func) + + if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore + args = RxTuple((args,)) + + if not isinstance(out_sinfo, list): + out_sinfo = [out_sinfo] + + if isinstance(tir_vars, (list, tuple)): + tir_vars = ShapeExpr(tir_vars) + + return _ffi_api.call_tir(func, args, out_sinfo, tir_vars) # type: ignore + + +@args_converter.auto +def call_builtin_with_ctx( + func: Union[str, Expr], + args: Expr, + *, + sinfo_args: Optional[Union[StructInfo, List[StructInfo]]] = None, +) -> Call: + """Call a builtin function func. + + Parameters + ---------- + func : Expr + The builtin function to be called. + + args : Expr + The input arguments. + + sinfo_args: Optional[Union[StructInfo, List[StructInfo]]] + The struct info arguments to the call node. + + Returns + ------- + ret: Call + The created call node. + """ + if isinstance(func, str): + func = ExternFunc(func) + + if sinfo_args is not None and not isinstance(sinfo_args, (list, tuple)): + sinfo_args = [sinfo_args] + + return _ffi_api.call_builtin_with_ctx( # type: ignore + func, + args, + sinfo_args, # type: ignore + ) + + +@args_converter.auto +def make_closure( + func: Expr, + args: Expr, +) -> Object: + """ + Create a closure with free variables and return the closure. + + Parameters + ---------- + func : Expr + The closure, can be ExternFunc or PrimFunc. + + args : Expr + The input arguments. + + + Returns + ------- + ret: Object + The VMClosure. + """ + + return _ffi_api.make_closure(func, args) # type: ignore + + +@args_converter.auto +def invoke_closure( + closure: Expr, + args: Expr, + sinfo_args: Union[List[StructInfo], StructInfo], +) -> Object: + """ + Invoke a closure. + + Parameters + ---------- + closure : Expr + The VMClosure object. + + args : Expr + The input arguments. + + type_args: Union[List[StructInfo], StructInfo] + The structure info arguments of the CallNode + + Returns + ------- + ret: Object + The result. + """ + + if not isinstance(sinfo_args, (list, tuple)): + sinfo_args = [sinfo_args] + + return _ffi_api.invoke_closure(closure, args, sinfo_args) # type: ignore + + +def render_object(val: tvm.Object) -> str: + """ + Given a TVM Object, renders it in string form. Used for Relax printing and assertions. + + Parameters + ---------- + val: tvm.Object + An object to render + + Returns + ------- + ret: str + A string representing the value, ideally human-readable + """ + if isinstance(val, tvm.runtime.ndarray.NDArray): + return str(val) + # no pretty-printer by default, so if we don't handle this, + # then we can't look inside tuples + if isinstance(val, tvm.runtime.container.ADT): + # the fields array of an ADT cannot be directly accessed in Python + # so we have to get the length and index into the fields separately + fields = ", ".join([render_object(val[i]) for i in range(len(val))]) + # special case: tag = 0 is a tuple + if val.tag == 0: + return f"({fields})" + return f"ADT(tag={val.tag}, fields=[{fields}])" + return str(val) + + +@tvm.register_func("relax.run.print") +def relax_print(format_str: str, *format_args: tvm.Object) -> None: + """ + Takes a list of values to print, formats with the given format string. + If the format string is empty, simply prints. + + Call from TVM script like this: + `relax.print(value1, value2, ..., valueN, format=format_str)` + or + `relax.print(value1, value2, ..., valueN) # format_str defaults to ""` + + Parameters + ---------- + format_str: str + The last argument is a Python-style format string for printing the value + + format_args: List[Object] + The values to print. + """ + val_strs = map(render_object, format_args) + if format_str == "": + py_print(*val_strs) + else: + py_print(format_str.format(*val_strs)) + + +def print(*values: List[Expr], format: str = "") -> Expr: + """Print op to print the values + + Parameters + ---------- + values : List[Expr] + The values to print. + + format_str: str + The format string. + + Returns + ------- + result : Expr + A relax Call, which will print the value during runtime. + """ + return _ffi_api.print(values, format) # type: ignore # pylint: disable=no-member + + +@tvm.register_func("relax.run.assert_op") +def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Object) -> None: + """ + A variadic function. The first value serves as the assertion condition: + If the condition is true, then the operator does nothing. + If the condition is false, then the operator raises an assertion error. + + Arguments after the first value serve as format arguments for the error message; + the last argument must be a format string for the error message (empty by default). + If the format string is the empty string, then the error message will simply include + a comma-separated list of the format arguments. + The condition argument is not included in the format string. + + Parameters + ---------- + condition: tvm.Object + The assertion condition. Must be a boolean scalar. + + format_str: str + The last argument is a Python-style format string for printing the value + + format_args: List[tvm.Object] + Values used for formatting the string. + """ + if not isinstance(format_str, str): + raise ValueError( + f"The format string argument to assert must be a string, given {type(format_str)})" + ) + + # should be guaranteed by the type system + if not isinstance(condition, tvm.runtime.ndarray.NDArray): + raise ValueError(f"The condition must be an NDArray, but given a {type(condition)}.") + + # may happen if the original program had unknown shape or dtype for the tensor's type + dtype = condition.dtype + if dtype != "bool": + raise ValueError(f"The condition must be a bool scalar, but given a {dtype} tensor") + shape = condition.shape + if len(shape) != 0: + raise ValueError(f"The condition must be a scalar, but it has a shape of {shape}") + + val = condition.numpy() + if not val: + error_message = "Assertion Failed" + if format_args or format_str != "": + rendered = map(render_object, format_args) + if format_str != "": + error_message = format_str.format(*rendered) + else: + error_message = ", ".join(rendered) + raise AssertionError(error_message) + + +def assert_op( + condition: Expr, format_args: Optional[Union[Expr, List[Expr]]] = None, format: str = "" +) -> Expr: + """ + Create a call to Relax's assert_op operation (`assert` is reserved in Python, + so the name must be distinct). + + Parameters + ---------- + condition: Expr + The assertion condition. + + format_args: Optional[Union[Expr, List[Expr]]] + Format arguments for the error message if the condition fails. + + format_str: str + The format string for the error message. + + Returns + ------- + result : Expr + A Call to the Relax assert operation. + """ + if format_args is None: + format_args = [] + if isinstance(format_args, Expr): # type: ignore + format_args = [format_args] + return _ffi_api.assert_op(condition, format_args, format) # type: ignore + + +def shape_of(expr: Expr) -> Expr: + """Get shape of a tensor. + + Parameters + ---------- + expr : Expr + The input Expr. + + Returns + ------- + result : Expr + A relax Call, which gets the shape of the input + """ + return _ffi_api.shape_of(expr) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py new file mode 100644 index 000000000000..eee0b6f3366a --- /dev/null +++ b/python/tvm/relax/op/binary.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin, invalid-name +"""Relax binary arithmetic and comparison operators.""" +from . import _ffi_api +from ..expr import Expr + +###################### Arithmetic operators ###################### + + +def add(x1: Expr, x2: Expr) -> Expr: + """Addition with numpy-style broadcasting. + + Parameters + ---------- + x1 : Expr + The first input tensor. + x2 : Expr + The second input tensor. + + Returns + ------- + result : Expr + The computed result. + + Examples + -------- + .. code:: python + + bb = relax.BlockBuilder() + a = relax.Var("a", relax.TensorStructInfo(shape=(2, 3), dtype="float32")) + b = relax.Var("b", relax.TensorStructInfo(shape=(2, 1), dtype="float32")) + c = bb.normalize(relax.op.add(a, b)) # c has TensorStructInfo(shape=(2, 3), dtype="float32") + """ + return _ffi_api.add(x1, x2) # type: ignore + + +def multiply(x1: Expr, x2: Expr) -> Expr: + """Multiplication with numpy-style broadcasting. + + Parameters + ---------- + x1 : Expr + The first input tensor. + x2 : Expr + The second input tensor. + + Returns + ------- + result : Expr + The computed result. + """ + return _ffi_api.multiply(x1, x2) # type: ignore diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py new file mode 100644 index 000000000000..5bfb0d87bf00 --- /dev/null +++ b/python/tvm/relax/utils.py @@ -0,0 +1,278 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utility functions for Relax""" +import functools +import inspect +from typing import Any, Callable, List, Optional, TypeVar + +from .. import tir +from ..runtime import String, convert_to_object +from ..tir import PrimExpr +from . import _ffi_api +from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm +from .expr import Tuple as rx_Tuple + + +def metadata_partitioner(rx_txt: str) -> List[str]: + """Extract Relax program and metadata section. + + Parameters + ---------- + rx_txt : str + The input relax text. + + Returns + ------- + output : List[str] + The result list of partitioned text, the first element + is the relax program, and the second is metadata section. + """ + partitions = [] + left_curly = 0 + meta_start = 0 + meta_end = 0 + for i, char in enumerate(rx_txt): + if i < 0: + raise ValueError("The program is invalid.") + if char == "{": + if meta_start == 0: + meta_start = i + left_curly += 1 + elif char == "}": + left_curly -= 1 + if left_curly == 0: + meta_end = i + 1 + break + + if meta_end == 0: + raise ValueError("The metadata section was not found.") + metadata = rx_txt[meta_start:meta_end] + rx_program = rx_txt[meta_end:-1] + + partitions.append(rx_program) + partitions.append(metadata) + + return partitions + + +def convert_to_expr(value: Any) -> Expr: + """Helper function to convert the input to Expr, which follows the rules: + 1. Return the input itself if it's already a `relax.Expr`; + 2. Return `relax.PrimValue` if the input is a `PrimExpr`; + 3. Return `relax.StringImm` if the input is `tvm.String` or `str`; + 4. Return `relax.ShapeExpr` if the input is a tuple/list of `PrimExpr` w/ int dtype; + 5. Return `relax.Tuple` if the input is a tuple/list of `Expr`. + + Notes + ----- + 1. `tvm.tir.StringImm` is not allowed because of ambiguity, + which can be either `relax.StringImm` or `relax.PrimValue`. + 2. We regard empty tuple/list as `relax.Tuple` instead of `relax.ShapeExpr` + """ + if isinstance(value, int): + return PrimValue(tir.IntImm("int64", value)) + + tvm_value = convert_to_object(value) + # Case 1 + if isinstance(tvm_value, Expr): # type: ignore + return tvm_value + # Note`` 1 + if isinstance(tvm_value, tir.StringImm): + raise TypeError( + "Cannot convert `tir.StringImm` to `relax.Expr` because of ambiguity," + "which can be either `relax.StringImm` or `relax.PrimValue` " + ) + # Case 2 + if isinstance(tvm_value, PrimExpr): + return PrimValue(value) + # Case 3 + if isinstance(tvm_value, String): + return StringImm(value) + # Case 4 & 5 + if isinstance(value, (tuple, list)): + # Note 2 + if len(value) == 0: + return rx_Tuple([]) + # Case 4 + opt_prim_value = [convert_to_object(v) for v in value] + if all([isinstance(v, PrimExpr) and v.dtype.startswith("int") for v in opt_prim_value]): + return ShapeExpr(value) + # Case 5 + # `convert_to_expr` ensures that all elements are `Expr` if no exception raises + return rx_Tuple([convert_to_expr(v) for v in value]) + raise TypeError(f"Cannot convert {value} with type {type(value)} to `relax.Expr`") + + +FType = TypeVar("FType", bound=Callable[..., Expr]) + + +class _ArgsConverter: + """A helper class to convert the arguments to Expr.""" + + @staticmethod + def convert(args_to_expr: List[str], args_to_list_expr: List[str]): + """Convert the arguments to Expr. + + Parameters + ---------- + args_to_expr : List[str] + The argument names to be converted to Expr. + + args_to_list_expr : List[str] + The argument names to be converted to List[Expr]. + + Returns + ------- + output : Callable[[FType], FType] + The decorator. + """ + + if any([x in args_to_list_expr for x in args_to_expr]): + raise ValueError(f"`args_to_expr` and `args_to_list_expr` should be disjoint.") + + def _convert(name: str, value: Any) -> Any: + if value is None: + return value + if name in args_to_expr: + try: + return convert_to_expr(value) + except: + raise TypeError( + f"Argument `{name}` is expected to be converted to `Expr`, " + f"but failed with input value: {value}" + ) + elif name in args_to_list_expr: + try: + return [convert_to_expr(x) for x in value] + except: + raise TypeError( + f"Argument `{name}` is expected to be converted to `List[Expr]`, " + f"but failed with input value: {value}" + ) + else: + return value + + def inner(func: FType) -> FType: + sig = inspect.signature(func) + param_names = list(sig.parameters.keys()) + for name in args_to_expr + args_to_list_expr: + if name not in param_names: + raise ValueError(f"Argument `{name}` is not found in function signature.") + + @functools.wraps(func) + def wrapper(*args, **kwargs): + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + for param in sig.parameters.values(): + if param.kind == param.VAR_POSITIONAL: + # *args case + values = [_convert(param.name, x) for x in bound.arguments[param.name]] + bound.arguments[param.name] = tuple(values) + elif param.kind == param.VAR_KEYWORD: + # **kwargs case + key_value = { + key: _convert(param.name, value) + for key, value in bound.arguments[param.name].items() + } + bound.arguments[param.name] = key_value + else: + bound.arguments[param.name] = _convert( + param.name, bound.arguments[param.name] + ) + return func(*bound.args, **bound.kwargs) + + return wrapper # type: ignore + + return inner + + @staticmethod + def to_expr(*arg_names: str) -> Callable: + """Convert the arguments to Expr. + + Parameters + ---------- + *arg_names: str + The list of argument names that need to be converted to Expr. + + Returns + ------- + output: Callable + The decorator. + """ + + return _ArgsConverter.convert(args_to_expr=list(arg_names), args_to_list_expr=[]) + + @staticmethod + def to_list_expr(*arg_names: str) -> Callable: + """Convert the arguments to List of Expr. + + Parameters + ---------- + *arg_names: str + The list of argument names that need to be converted to List of Expr. + + Returns + ------- + output: Callable + The decorator. + """ + + return _ArgsConverter.convert(args_to_expr=[], args_to_list_expr=list(arg_names)) + + @staticmethod + def auto(func: FType) -> FType: + """Decorator for automatically convert the arguments to Expr according to type annotation. + Only two patterns are supported: + + 1. The argument is Expr or Optional[Expr]. + + 2. The argument is List[Expr] or Optional[List[Expr]]. + + """ + sig = inspect.signature(func) + args_to_expr = [] + args_to_list_expr = [] + + for param in sig.parameters.values(): + anno = param.annotation + if anno in (Expr, Optional[Expr]): + args_to_expr.append(param.name) + if anno in (List[Expr], Optional[List[Expr]]): + args_to_list_expr.append(param.name) + + return _ArgsConverter.convert(args_to_expr, args_to_list_expr)(func) + + +args_converter = _ArgsConverter() # pylint: disable=invalid-name + + +def copy_with_new_params(func: Function) -> Function: + """Copy the given function. The parameters of the original function would be copied to + satisfy the restriction in the well-formed check: any two functions cannot share the same + parameter variable. + + Parameters + ---------- + func : Function + The relax function to copy. + + Returns + ------- + ret : Function + The copied function. + """ + return _ffi_api.CopyWithNewParams(func) # type: ignore diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index 0907ea2ebf85..40fac0f92f6d 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -41,6 +41,7 @@ from .operation import placeholder, compute, scan, extern, var, size_var, const from .operation import thread_axis, reduce_axis from .operation import create_prim_func +from .operation import create_relax_prim_func from .operation import extern_primfunc from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 59bc76f5041e..cfe5e073bae2 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -571,12 +571,64 @@ def create_prim_func( ops: List[_tensor.Tensor], index_dtype_override: Optional[str] = None ) -> tvm.tir.PrimFunc: """Create a TensorIR PrimFunc from tensor expression + Parameters + ---------- + ops : List[Tensor] + The source expression. + Example + ------- + We define a matmul kernel using following code: + .. code-block:: python + import tvm + from tvm import te + from tvm.te import create_prim_func + import tvm.script + A = te.placeholder((128, 128), name="A") + B = te.placeholder((128, 128), name="B") + k = te.reduce_axis((0, 128), "k") + C = te.compute((128, 128), lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), name="C") + func = create_prim_func([A, B, C]) + print(func.script()) + If we want to use TensorIR schedule to do transformations on such kernel, + we need to use `create_prim_func([A, B, C])` to create a schedulable PrimFunc. + The generated function looks like: + .. code-block:: python + @T.prim_func + def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) + for i, j, k in T.grip(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] += A[vi, vk] * B[vj, vk] + Returns + ------- + func : tir.PrimFunc + The created function. + """ + if not isinstance(ops, (list, tuple, Array)): + ops = [ops] + return _ffi_api.CreatePrimFunc(ops, index_dtype_override) + + +def create_relax_prim_func( + ops: List[_tensor.Tensor], + tir_var_list: List[tvm.tir.Var] = None, + index_dtype_override: Optional[str] = None, +) -> tvm.tir.PrimFunc: + """Create a TensorIR PrimFunc from tensor expression Parameters ---------- ops : List[Tensor] The source expression. + tir_var_list: List[Var] + TIR variables to add as parameters to generated PrimFunc + Example ------- We define a matmul kernel using following code: @@ -621,4 +673,4 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: """ if not isinstance(ops, (list, tuple, Array)): ops = [ops] - return _ffi_api.CreatePrimFunc(ops, index_dtype_override) + return _ffi_api.CreateRelaxPrimFunc(ops, tir_var_list, index_dtype_override) diff --git a/src/ir/function.cc b/src/ir/function.cc index 69752f529a3c..6a7ccc7cf27b 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -46,4 +46,18 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") } }); +TVM_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr") + .set_body_typed([](BaseFunc func, String key) -> BaseFunc { + if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else { + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + return func; + } + }); + } // namespace tvm diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index d9b139753455..2de06fe5d6f2 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -533,6 +533,155 @@ TVM_REGISTER_GLOBAL("relax.StructInfoIsBaseOf") return IsBaseOf(base, derived); }); +//-------------------------- +// DeriveStructInfo +//-------------------------- + +// NOTE: we are reusing StructInfoBaseChecker here to populate a mapping +// from the expressions in arg(rhs) to var in param. +class CallRetStructInfoDeriver : public StructInfoBaseChecker { + public: + explicit CallRetStructInfoDeriver(arith::Analyzer* ana) : StructInfoBaseChecker(ana) {} + + // No short cut, so we can recursively populate all pairs. + BaseCheckResult VisitStructInfo(const StructInfo& lhs, const StructInfo& other) final { + return StructInfoFunctor::VisitStructInfo(lhs, other); + } + + StructInfo Derive(const FuncStructInfo& finfo, const Call& call, const BlockBuilder& ctx) { + // opaque derivation + if (finfo->IsOpaque()) { + if (finfo->derive_func.defined()) { + // derive using custom derivation function. + return finfo->derive_func.value()(call, ctx); + } else { + // directly return the normal value. + return finfo->ret; + } + } + + // Normal function signature derivation. + auto params = finfo->params.value(); + if (params.size() != call->args.size()) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "number of arguments and parameters mismatch:" + << " expected " << params.size() << ", given " << call->args.size()); + } + // Visit each param arg pair, check and populate the var map + for (size_t i = 0; i < params.size(); ++i) { + auto arg_sinfo = GetStructInfo(call->args[i]); + BaseCheckResult res = this->VisitStructInfo(params[i], arg_sinfo); + // Report error if we find L1 level failure + // L2 level is best effort so we don't report. + // The behavior of L2 can be customized later. + if (res == BaseCheckResult::kFailL0 || res == BaseCheckResult::kFailL1) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "Argument " << i << " type mismatch:" + << " expected " << params[i] << ", given " << arg_sinfo); + } + } + // map the ret using the populated var map. + return EraseToWellDefined(finfo->ret, shape_var_map_, var_map_); + } + + protected: + // Whether to populate map in params. + bool populate_mapping_{true}; + // for simplicity, we make these fields public so the user can access them. + Map shape_var_map_; + Map var_map_; + + using StructInfoBaseChecker::ShapeMatchCheck; + + // Match shape values in between param(lhs) and arg(rhs) + BaseCheckResult PrimValueMatchCheck(const PrimExpr& param, const PrimExpr& arg) final { + if (!populate_mapping_) { + return StructInfoBaseChecker::PrimValueMatchCheck(param, arg); + } + + if (auto* ptr = param.as()) { + auto var = GetRef(ptr); + auto it = shape_var_map_.find(var); + // not populated + if (it == shape_var_map_.end()) { + shape_var_map_.Set(var, arg); + return BaseCheckResult::kPass; + } else { + // Best effort prove. + PrimExpr mapped_value = (*it).second; + if (analyzer_->CanProveEqual(mapped_value, arg)) return BaseCheckResult::kPass; + return BaseCheckResult::kFailL2; + } + } else { + // Best effort + // Do not attempt to do prove when param contains a symbolic expr. + // such expression might depends on a later defined var in params created by dyn fusion. + // example: f(a: Tensor[(n+1)], s: Shape[(n,)]), the (n+1) case here. + return StructInfoBaseChecker::PrimValueMatchCheck(param, arg); + } + } + + BaseCheckResult ShapeMatchCheck(const Expr& lhs, const Expr& rhs) final { + if (!populate_mapping_) { + return StructInfoBaseChecker::ShapeMatchCheck(lhs, rhs); + } + + if (auto* ptr = lhs.as()) { + auto var = GetRef(ptr); + auto it = var_map_.find(var); + // not populated + if (it == var_map_.end()) { + var_map_.Set(var, rhs); + return BaseCheckResult::kPass; + } else { + // Best effort prove. + Expr mapped_value = (*it).second; + if (CanProveShapeEqual(mapped_value, rhs, analyzer_)) return BaseCheckResult::kPass; + return BaseCheckResult::kFailL2; + } + } + auto lhs_shape = lhs.as(); + auto rhs_shape = rhs.as(); + ICHECK(lhs_shape) << "lhs must have a shape"; + if (!rhs_shape) return BaseCheckResult::kFailL2; + return ShapeMatchCheck(lhs_shape->values, rhs_shape->values); + } + + BaseCheckResult FuncParamsCheck(const Array& lhs, + const Array& rhs) final { + // Set populate mapping to false + // so we do not pick up symbolic vars in params with function type. + // + // @R.function + // def f(g: R.Func([R.Tensor[(n,)]], R.Tensor[(n+1,)]), + // x: R.Tensor[(m,)]) -> R.Tensor[(m,)]: + // ... + // + // For example, in the above function f, we should avoid + // pick up n in g's signature. + bool populate_mapping = false; + std::swap(populate_mapping_, populate_mapping); + auto ret = StructInfoBaseChecker::FuncParamsCheck(lhs, rhs); + std::swap(populate_mapping_, populate_mapping); + return ret; + } +}; + +StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call, + const BlockBuilder& ctx, arith::Analyzer* ana) { + if (ana == nullptr) { + arith::Analyzer inst; + return CallRetStructInfoDeriver(&inst).Derive(finfo, call, ctx); + } else { + return CallRetStructInfoDeriver(ana).Derive(finfo, call, ctx); + } +} + +TVM_REGISTER_GLOBAL("relax.analysis.DeriveCallRetStructInfo") + .set_body_typed([](const FuncStructInfo& finfo, const Call& call, const BlockBuilder& ctx) { + return DeriveCallRetStructInfo(finfo, call, ctx); + }); + //-------------------------- // UnifyToLCA //-------------------------- diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc new file mode 100644 index 000000000000..6a2d7ea5c584 --- /dev/null +++ b/src/relax/ir/block_builder.cc @@ -0,0 +1,969 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/block_builder.cc + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +// Block builder have three categories of logics that are interdependent with each other. +// +// The logics are somewhat interdependent with each other. +// To help us implement a block builder in two parts: +// +// - BlockBuilderImpl: implements ctx and scope management, with no normalization. +// - BlockBuilderImplWithNormalize: subclasses BlockBuilderImpl and implements normalization. +// +// The final blockbuilder create will be backed by BlockBuilderWithNormalize + +namespace tvm { +namespace relax { + +//--------------------------------------- +// ctx and scope management. +//--------------------------------------- +class BlockBuilderImpl : public BlockBuilderNode { + public: + explicit BlockBuilderImpl(IRModule context_mod) : context_mod_(std::move(context_mod)) {} + + ~BlockBuilderImpl() { + if (!block_stack_.empty()) { + LOG(WARNING) << "BlockBuilder destroyed with remaining blocks!"; + } + } + + //------------------------------- + // Global Context management + //------------------------------- + NameTable* name_table() final { return name_table_.get(); } + + IRModule GetContextIRModule() const final { return context_mod_; } + + GlobalVar AddFunction(const BaseFunc& func, String func_name_hint) final { + LazyInitCtxFuncDedupMap(); + auto it = ctx_func_dedup_map_->find(func); + if (it == ctx_func_dedup_map_->end()) { + context_mod_.CopyOnWrite(); + + String func_name = name_table_->GetUniqueName(func_name_hint); + while (context_mod_->ContainGlobalVar(func_name)) { + func_name = name_table_->GetUniqueName(func_name_hint); + } + GlobalVar gvar = GlobalVar(func_name); + + StructInfo finfo; + if (func->struct_info_.defined()) { + finfo = GetStructInfo(func); + } else if (auto* prim_func = func.as()) { + // NOTE: use a slightly different struct info than checked type + // in PrimFunc so handle can turn into Tensor. + // TODO(relax-team): add fine-grained PrimFunc struct info signature generation. + finfo = FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type)); + } else { + finfo = StructInfoFromType(func->checked_type()); + } + UpdateStructInfo(gvar, finfo); + + context_mod_->Add(gvar, func); + + ctx_func_dedup_map_->emplace(func, gvar); + return gvar; + } else { + return it->second; + } + } + + void UpdateFunction(const GlobalVar& gv, BaseFunc function) final { + context_mod_.CopyOnWrite(); + + // invalidate old dedup map + if (ctx_func_dedup_map_ != nullptr) { + auto it = context_mod_->functions.find(gv); + if (it != context_mod_->functions.end()) { + BaseFunc old_func = (*it).second; + auto ptr = ctx_func_dedup_map_->find(old_func); + ICHECK(ptr != ctx_func_dedup_map_->end()); + ctx_func_dedup_map_->erase(ptr); + } + } + + context_mod_->Update(gv, function); + + // add new dedup map item. + if (ctx_func_dedup_map_ != nullptr) { + ctx_func_dedup_map_->emplace(function, gv); + } + } + + void ReportFatal(const Diagnostic& diagnostic) final { + // TODO(relax-team): Print more context information by looking + // into the diagnostic->loc and surrounding IRModule. + // We do not materialzie DiagnosticContext to avoid double referencing to + // the change IRModule in COW. Additionally, we need to be able to + // continue use the builder after an error is thrown to avoid state building up. + // in an interactive environment. + LOG(FATAL) << diagnostic->message; + } + + //------------------------------- + // Scope management + //------------------------------- + Optional LookupBinding(const Var& var) final { + auto it = binding_table_.find(var->vid); + if (it == binding_table_.end()) return NullOpt; + return it->second; + } + + void BeginDataflowBlock() final { block_stack_.emplace_back(BlockFrame{{}, true}); } + + void BeginBindingBlock() final { block_stack_.emplace_back(BlockFrame{{}, false}); } + + void BeginScope(Optional> params) final { + // The current implementation handles the collection of shape var + // defined in parameter struct info annotations. The implementation + // is correct (since we will simply erase all relax Vars in EraseToWellDefined), + // but can be further improved. + // + // TODO(relax-team): Add support for relax Var in struct info annotations. + Map shape_var_map; + for (const Var& var : params.value_or(Array())) { + const Map& var_map = StructInfoVarCollector::Collect(GetStructInfo(var)); + for (const auto& kv : var_map) { + const tir::Var& shape_var = kv.first; + const PrimExpr& shape_expr = kv.second; + auto it = shape_var_map.find(shape_var); + if (it == shape_var_map.end()) { + shape_var_map.Set(shape_var, shape_expr); + } else { + const PrimExpr& old_shape_expr = (*it).second; + CHECK(analyzer_.CanProveEqual(old_shape_expr, shape_expr)) + << "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs " + << shape_expr; + } + shape_var_map.Set(shape_var, shape_expr); + } + } + scope_stack_.emplace_back(ScopeFrame({std::move(shape_var_map)})); + } + + void EndScope() final { scope_stack_.pop_back(); } + + BindingBlock EndBlock() final { + BlockFrame* cur_frame = CurrentBlockFrame(); + BindingBlock ret = cur_frame->is_dataflow ? DataflowBlock(cur_frame->bindings) + : BindingBlock(cur_frame->bindings); + block_stack_.pop_back(); + return ret; + } + + bool CurrentBlockIsDataFlow() final { return CurrentBlockFrame()->is_dataflow; } + + Var Emit(Expr expr, String name_hint) final { + return this->Emit(expr, CurrentBlockFrame()->is_dataflow, name_hint); + } + + Var EmitMatchCast(Expr value, StructInfo struct_info, String name_hint) final { + value = this->Normalize(value); + + CHECK(StructInfoBaseCheck(GetStructInfo(value), struct_info) != BaseCheckResult::kFailL0) + << "It is impossible to match cast any value into the target struct_info. " + "But got value struct info: " + << GetStructInfo(value) << ", given struct info: " << struct_info; + + // NOTE: do match cast checking later in a pass. + BlockFrame* cur_frame = CurrentBlockFrame(); + Var var = CreateVar(cur_frame->is_dataflow, name_hint); + UpdateStructInfo(var, struct_info); + + MatchCast match_cast(var, value, struct_info); + cur_frame->bindings.push_back(match_cast); + // NOTE match shape do not follow simple binding rule + // as a result should not appear in binding table. + return var; + } + + Var EmitOutput(Expr output, String name_hint) final { + BlockFrame* cur_frame = CurrentBlockFrame(); + + ICHECK(cur_frame->is_dataflow) << "EmitOutput has to be called inside dataflow block."; + + return Emit(output, false, name_hint); + } + + void EmitNormalized(Binding binding) final { + BlockFrame* cur_frame = CurrentBlockFrame(); + + if (const auto* var_binding = binding.as()) { + if (!cur_frame->is_dataflow) { + ICHECK(!var_binding->var.as()) + << "Cannot emit dataflow var in non-dataflow block"; + } + // normalized check + ICHECK(var_binding->var->struct_info_.defined()); + ICHECK(var_binding->value->struct_info_.defined()); + cur_frame->bindings.push_back(binding); + binding_table_[var_binding->var->vid] = var_binding->value; + } else if (const auto* match_cast = binding.as()) { + if (!cur_frame->is_dataflow) { + ICHECK(!match_cast->var.as()) + << "Cannot emit dataflow var in non-dataflow block"; + } + // normalized check + ICHECK(match_cast->var->struct_info_.defined()); + ICHECK(match_cast->value->struct_info_.defined()); + // NOTE match shape do not follow simple binding rule + // as a result should not appear in binding table. + cur_frame->bindings.push_back(binding); + } else { + LOG(FATAL) << "Unsupported binding type: " << binding->GetTypeKey(); + } + } + + arith::Analyzer* GetAnalyzer() final { return &analyzer_; } + + protected: + /*! + * \brief A representation of a block frame. + * + * A block frame is a record containing the bindings needed + * to build a binding block, and a boolean to indicate if the + * block being built is a DataflowBlock or not. + */ + struct BlockFrame { + /*! + * \brief List of bindings + */ + Array bindings; + /*! \brief Whether current block is dataflow block. */ + bool is_dataflow; + /*! + * \brief Binding map used by normalizer. + * + * \note The normalizer only caches reuse in the current block scope + * and will not cache bindings from parent scope. + */ + std::unordered_map normalize_binding_map; + }; + /*! + * \brief A representation of a scope frame. + * + * A scope frame records tracks the context of current scope. + */ + struct ScopeFrame { + // NOTE: for simplicity, only tracks symbolic var for now + // the scope is only used for erasure, so less information means + // more conservative analysis. + // Consider impl alternative: merge with block frame if we have more frame kinds. + // + // TODO(relax-team) tracks the var defined also through match-cast. + /*! \brief set of defined symbolic vars, value as themself. */ + Map shape_var_map; + }; + + /*! \brief A stack to store block frames. */ + std::vector block_stack_; + + /*! \brief A stack to store scope frames. */ + std::vector scope_stack_; + + /*! \brief A binding table that maps var to value. */ + std::unordered_map binding_table_; + + /*! \brief A name table to get unique names for IR construction. */ + std::unique_ptr name_table_ = std::make_unique(); + + /*! \brief The IRModule being built by the BlockBuilder. */ + IRModule context_mod_; + + /*! \brief Internal analzyer */ + arith::Analyzer analyzer_; + + /*! + * \return The current frame. + * \note Never hold the value of current frame between Normalize + * or other scope calls this value can change if the block stack get updated, + * then the block frame is no longer valid. + */ + BlockFrame* CurrentBlockFrame() { + ICHECK(!block_stack_.empty()) << "no block is being built"; + return &block_stack_.back(); + } + + /*! + * \return The current scope frame. + * \note only use this value + */ + ScopeFrame* CurrentScopeFrame() { + ICHECK(!scope_stack_.empty()) << "no scope is being opened"; + return &scope_stack_.back(); + } + + /*! + * \brief Emits an Expr, and returns the variable it is bound to. + * \param expr The Expr to be emitted. + * \param is_dataflow Is the bound variable a DataflowVar or not(i.e. Var). + * \param name_hint Name hint for the bound variable. + * \note This Emit function normalizes the \p expr, + * and performs shape/type deductions by calling Normalize. + * \return The new variable that \p expr is bound to. + */ + Var Emit(Expr expr, bool is_dataflow, String name_hint) { + expr = this->Normalize(expr); + + Var var = CreateVar(is_dataflow, name_hint); + + // set the values + UpdateStructInfo(var, Downcast(expr->struct_info_.value())); + + CurrentBlockFrame()->bindings.push_back(VarBinding(var, expr)); + + // update the binding table + binding_table_[var->vid] = expr; + + return var; + } + + /*! + * \brief Create var for bindings + * \param is_dataflow Is the bound variable a DataflowVar or not(i.e. Var). + * \param name_hint Name hint for the bound variable. + * \return The created var. + */ + Var CreateVar(bool is_dataflow, String name_hint) { + if (name_hint.empty()) { + name_hint = is_dataflow ? "lv" : "gv"; + } + Id vid = Id(name_table_->GetUniqueName(name_hint)); + return is_dataflow ? DataflowVar(vid, /*struct_info_annotation=*/NullOpt) + : Var(vid, /*struct_info_annotation=*/NullOpt); + } + + private: + /*! + * \brief A hashmap to store the mapping of Relax functions and TIR PrimFuncs + * in context_mod to their GlobalVar to avoid generating duplicated functions. + */ + std::unique_ptr> + ctx_func_dedup_map_ = nullptr; + + /*! + * \brief lazily initialize function dedeup map. + */ + void LazyInitCtxFuncDedupMap() { + if (ctx_func_dedup_map_ != nullptr) return; + ctx_func_dedup_map_ = std::make_unique< + std::unordered_map>(); + for (const auto& kv : context_mod_->functions) { + const GlobalVar gv = kv.first; + const BaseFunc func = kv.second; + ctx_func_dedup_map_->emplace(func, gv); + } + } + + // Collect all the variables that a parameter var can define. + // The collector is used to making sure that we record the + // shape vars as defined when calling BeginScope(params) + class StructInfoVarCollector : public StructInfoVisitor { + public: + static Map Collect(const StructInfo& struct_info) { + StructInfoVarCollector collector; + collector(struct_info); + return collector.shape_var_map_; + } + + private: + void VisitStructInfo_(const TensorStructInfoNode* op) final { + if (const auto* shape_expr = op->shape.as()) { + for (const PrimExpr& s : shape_expr->values) { + // Only collect single var defined shape. Ignore something like `R.Tensor((m + 1, n + 1)) + if (const auto* var = s.as()) { + shape_var_map_.Set(GetRef(var), s); + } + } + } + } + + void VisitStructInfo_(const ShapeStructInfoNode* op) final { + for (const PrimExpr& s : op->values.value_or(Array())) { + // Only collect single var defined shape. Ignore something like `R.Tensor((m + 1, n + 1)) + if (const auto* var = s.as()) { + shape_var_map_.Set(GetRef(var), s); + } + } + } + + private: + Map shape_var_map_; + }; +}; + +//--------------------------------------- +// Normalization +//--------------------------------------- +#define RELAX_EXPR_NORMALIZER_LEAF(OP) \ + Expr VisitExpr_(const OP* op) final { return GetRef(op); } + +// TODO(relax-team): Check normalize logic after struct info. + +// Normalizer on struct info: +// +// We take benefit of the following invariants(that are checked in constructor): +// - If an expr appears in StructInfo, then it is already normalized. +// As a result, we do not need to peek into StructInfo in Normalization. +// - Constant, ShapeExpr, already have their StructInfo populated in constructing time. +class Normalizer : public BlockBuilderImpl, private ExprFunctor { + public: + explicit Normalizer(IRModule context_mod) : BlockBuilderImpl(context_mod) {} + + Expr Normalize(const Expr& expr) final { + Expr normalized = this->VisitExpr(expr); + // Invariant: + // After Normalize: an Expr always have + // struct_info (with the exception of Op). + if (!normalized->IsInstance()) { + ICHECK(normalized->struct_info_.defined()) + << "The struct_info_ of an Expr except OpNode after " + "normalization must not be nullptr. However, this Expr does not have struct_info_: " + << normalized; + } + + return normalized; + } + + /*! + * \brief Normalize Argument values to call and other IR sub-fields. + * \param arg The argument. + * \return The normalized value. + * + * \note This function create a new binding for non-leaf expressions except for tuple. + */ + Expr NormalizeArgument(const Expr& arg) final { + // Temp patch to ensure we handle inline PrimFunc case. + // TODO(relax-team) remove such cases from parser and testcases. + if (auto* prim_func = arg.as()) { + return NormalizePrimFunc(GetRef(prim_func)); + } + + if (!block_stack_.empty()) { + // cache lookup + BlockFrame* cur_frame = CurrentBlockFrame(); + auto it = cur_frame->normalize_binding_map.find(arg); + if (it != cur_frame->normalize_binding_map.end()) { + return it->second; + } + } + // skip visit expr's cache, normalize arg + Expr post = ExprFunctor::VisitExpr(arg); + + if (!IsLeafOrTuple(arg)) { + ICHECK(!block_stack_.empty()) << "Cannot normalize non-leaf without a scope"; + Var var = this->Emit(post, ""); + // NOTE: current frame addr can change due to underlying vector + // re-allocation, redo lookup + CurrentBlockFrame()->normalize_binding_map[arg] = var; + return var; + } else { + return post; + } + } + + RELAX_EXPR_NORMALIZER_LEAF(ExternFuncNode); + RELAX_EXPR_NORMALIZER_LEAF(GlobalVarNode); + RELAX_EXPR_NORMALIZER_LEAF(OpNode); + RELAX_EXPR_NORMALIZER_LEAF(ConstantNode); + RELAX_EXPR_NORMALIZER_LEAF(ShapeExprNode); + RELAX_EXPR_NORMALIZER_LEAF(PrimValueNode); + RELAX_EXPR_NORMALIZER_LEAF(StringImmNode); + RELAX_EXPR_NORMALIZER_LEAF(DataTypeImmNode); + + template + Expr VisitVar_(const typename T::ContainerType* var) { + // Parameters and free-vars must be present with struct info + // Other vars must have already been normalized through binding + ICHECK(var->struct_info_.defined()) + << "Var " << var->name_hint() << " does not have struct info."; + return GetRef(var); + } + + Expr VisitExpr_(const VarNode* var) final { return VisitVar_(var); } + + Expr VisitExpr_(const DataflowVarNode* var) final { return VisitVar_(var); } + + // Temp patch to ensure we handle inline PrimFunc case. + // TODO(relax-team) remove such cases from parser and testcases. + Expr NormalizePrimFunc(tir::PrimFunc prim_func) { + if (!prim_func->struct_info_.defined()) { + auto finfo = FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type)); + UpdateStructInfo(prim_func, finfo); + } + return prim_func; + } + + Expr VisitExpr(const Expr& expr) final { + // Temp patch to ensure we handle inline PrimFunc case. + // TODO(relax-team) remove such cases from parser and testcases. + if (auto* prim_func = expr.as()) { + return NormalizePrimFunc(GetRef(prim_func)); + } + + // lookup normalize map + if (!block_stack_.empty()) { + BlockFrame* cur_frame = CurrentBlockFrame(); + auto it = cur_frame->normalize_binding_map.find(expr); + if (it != cur_frame->normalize_binding_map.end()) { + return it->second; + } + } + return ExprFunctor::VisitExpr(expr); + } + + Expr VisitExpr_(const TupleNode* op) final { + bool unchanged = true; + Array new_fields; + + for (const Expr& field : op->fields) { + Expr new_field = this->NormalizeArgument(field); + new_fields.push_back(new_field); + unchanged &= new_field.same_as(field); + } + + Tuple tuple = unchanged ? GetRef(op) : Tuple(new_fields, op->span); + // Update tuple fields. + if (!tuple->struct_info_.defined()) { + Array tuple_sinfo; + for (Expr field : tuple->fields) { + tuple_sinfo.push_back(GetStructInfo(field)); + } + UpdateStructInfo(tuple, TupleStructInfo(tuple_sinfo, op->span)); + } + return tuple; + } + + Expr VisitExpr_(const FunctionNode* op) final { + Expr new_body = this->VisitWithNewScope(op->body, op->params); + + if (new_body.same_as(op->body)) { + return GetRef(op); + } else { + return Function(op->params, new_body, op->ret_struct_info, op->attrs); + } + } + + Expr VisitExpr_(const CallNode* op) final { + Expr new_op = this->NormalizeArgument(op->op); + bool unchanged = new_op.same_as(op->op); + + Array new_args; + + for (Expr arg : op->args) { + Expr new_arg = this->NormalizeArgument(arg); + new_args.push_back(new_arg); + unchanged &= new_arg.same_as(arg); + } + + Call call; + if (unchanged) { + call = GetRef(op); + } else { + call = Call(new_op, new_args, op->attrs, op->sinfo_args); + } + + if (!call->struct_info_.defined()) { + auto inferred_sinfo = InferStructInfo(call); + UpdateStructInfo(call, inferred_sinfo); + } + + return call; + } + + Expr VisitExpr_(const SeqExprNode* op) final { + bool unchanged = true; + Array new_blocks; + for (BindingBlock block : op->blocks) { + BindingBlock new_block = this->VisitBindingBlock(block); + new_blocks.push_back(new_block); + unchanged &= new_block.same_as(block); + } + + this->BeginBindingBlock(); + // the body may not be a leaf expression, so check for that + Expr new_body = this->NormalizeArgument(op->body); + unchanged &= new_body.same_as(op->body); + BindingBlock prologue = this->EndBlock(); + + if (!prologue->bindings.empty()) { + new_blocks.push_back(prologue); + unchanged = false; + } + + // Combine nearby blocks if possible + Array normalized_blocks = NormalizeBlocks(new_blocks); + unchanged &= normalized_blocks.same_as(new_blocks); + + SeqExpr seq_expr; + if (unchanged) { + seq_expr = GetRef(op); + } else { + seq_expr = SeqExpr(normalized_blocks, new_body, op->span); + } + + // only do shape/type inference if the SeqExpr does not have shape/type + if (!seq_expr->struct_info_.defined()) { + UpdateStructInfo(seq_expr, EraseToWellDefinedInScope(GetStructInfo(seq_expr->body))); + } + return seq_expr; + } + + Expr VisitExpr_(const IfNode* op) final { + Expr new_cond = this->NormalizeArgument(op->cond); + Expr new_true = this->VisitWithNewScope(op->true_branch); + Expr new_false = this->VisitWithNewScope(op->false_branch); + + If if_node; + if (new_cond.same_as(op->cond) && new_true.same_as(op->true_branch) && + new_false.same_as(op->false_branch)) { + if_node = GetRef(op); + } else { + if_node = If(new_cond, new_true, new_false, op->span); + } + if (!if_node->struct_info_.defined()) { + auto true_info = EraseToWellDefinedInScope(GetStructInfo(new_true)); + auto false_info = EraseToWellDefinedInScope(GetStructInfo(new_false)); + UpdateStructInfo(if_node, StructInfoLCA(true_info, false_info)); + } + return if_node; + } + + Expr VisitExpr_(const TupleGetItemNode* op) final { + Expr new_tuple = this->NormalizeArgument(op->tuple); + + TupleGetItem node = new_tuple.same_as(op->tuple) ? GetRef(op) + : TupleGetItem(new_tuple, op->index); + + if (!node->struct_info_.defined()) { + auto opt = MatchStructInfo(node->tuple); + ICHECK(opt) << "The struct info of Tuple must be TupleStructInfo."; + UpdateStructInfo(node, opt.value()->fields[node->index]); + } + + return node; + } + + Binding VisitBinding(const Binding& binding) { + if (auto* var_binding = binding.as()) { + return this->VisitVarBinding(GetRef(var_binding)); + } else { + auto* match_cast = binding.as(); + ICHECK(match_cast) << "Unsupported binding type: " << binding->GetTypeKey(); + return this->VisitMatchCast(GetRef(match_cast)); + } + } + + VarBinding VisitVarBinding(VarBinding binding) { + Expr new_value = this->VisitExpr(binding->value); + if (!new_value.same_as(binding->value)) { + binding = VarBinding(binding->var, new_value, binding->span); + } + if (!binding->var->struct_info_.defined()) { + UpdateStructInfo(binding->var, GetStructInfo(new_value)); + } + return binding; + } + + MatchCast VisitMatchCast(MatchCast binding) { + Expr new_value = this->VisitExpr(binding->value); + if (!new_value.same_as(binding->value)) { + binding = MatchCast(binding->var, new_value, binding->struct_info, binding->span); + } + if (!binding->var->struct_info_.defined()) { + UpdateStructInfo(binding->var, binding->struct_info); + } + return binding; + } + + BindingBlock VisitBindingBlock(const BindingBlock& block) { + if (block.as()) { + this->BeginDataflowBlock(); + } else { + this->BeginBindingBlock(); + } + + bool unchanged = true; + for (const Binding& binding : block->bindings) { + Binding new_binding = this->VisitBinding(binding); + unchanged &= new_binding.same_as(binding); + + this->EmitNormalized(new_binding); + } + BindingBlock new_block = this->EndBlock(); + unchanged &= new_block->bindings.size() == block->bindings.size(); + if (unchanged) { + return block; + } + return new_block; + } + + private: + // Helper function to infer the type of a Call. + StructInfo InferStructInfo(const Call& call) { + if (auto* op_ptr = call->op.as()) { + // Case 1: the op field is a primitive op, look up FInferStructInfo attribute + Op op = GetRef(op_ptr); + ICHECK(op_map_infer_struct_info_.count(op)) + << " Cannot find the FInferStructInfo attribute registered to op: " << op->name; + return op_map_infer_struct_info_[op](call, GetRef(this)); + } else { + // derive using function parameters + ICHECK(call->op->struct_info_.defined()); + auto opt = MatchStructInfo(call->op); + ICHECK(opt) << "Call->op must contains a function struct info"; + FuncStructInfo finfo = opt.value(); + return DeriveCallRetStructInfo(finfo, call, GetRef(this), &analyzer_); + } + } + + // erase to well defined within current scope. + StructInfo EraseToWellDefinedInScope(StructInfo info) { + if (scope_stack_.empty()) { + return EraseToWellDefined(info); + } + auto* curr_scope = CurrentScopeFrame(); + auto f_shape_var_map = [curr_scope](tir::Var var) -> Optional { + auto it = curr_scope->shape_var_map.find(var); + if (it != curr_scope->shape_var_map.end()) return (*it).second; + return NullOpt; + }; + return EraseToWellDefined(info, f_shape_var_map); + } + + Expr VisitWithNewScope(const Expr& expr, Optional> params = NullOpt) { + // SeqExpr do not need to prepare for normalization. + if (expr.as()) { + this->BeginScope(params); + Expr ret = this->VisitExpr(expr); + this->EndScope(); + return ret; + } else { + this->BeginScope(params); + + this->BeginBindingBlock(); + Expr post = this->NormalizeArgument(expr); + BindingBlock prologue = this->EndBlock(); + // "New scopes" (function bodies, if/else clauses) must be wrapped in seq exprs. + // Don't wrap if it's already a seq and there are no bindings to add + if (post.as() && prologue->bindings.empty()) { + return post; + } + Array bindings; + if (!prologue->bindings.empty()) { + bindings.push_back(prologue); + } + + SeqExpr seq(bindings, post); + UpdateStructInfo(seq, EraseToWellDefinedInScope(GetStructInfo(seq->body))); + + this->EndScope(); + return seq; + } + } + + Array FlattenBlocks(const Array& blocks) { + // If there is a binding that is a seq expr, split the current block, + // add the nested blocks prior to the seq expr, and bind the seq expr body + // to the var + Array ret; + bool changed = false; + for (const BindingBlock& block : blocks) { + bool is_dataflow = block->IsInstance(); + Array current; + for (const Binding& binding : block->bindings) { + Expr value; + if (const auto* var_binding = binding.as()) { + value = var_binding->value; + } else if (const auto* match_cast = binding.as()) { + value = match_cast->value; + } else { + LOG(FATAL) << "Unknown binding type: " << binding->GetTypeKey(); + } + // if we encounter a nested seq, we have to flatten it: + // 1. Append the binding block we've accumulated so far + // 2. Reset the current block + // 3. Append the inner blocks + // 4. Add a binding of the current var to the seq expr's body to the current block + // then continue + if (auto seq = value.as()) { + changed = true; + ret.push_back(is_dataflow ? DataflowBlock(current) : BindingBlock(current)); + current = {}; + // We do not need to flatten recursively because the normalizer will have normalized + // and thus flattened the inner SeqExprs already + for (const BindingBlock& block : seq->blocks) { + if (is_dataflow && !block->IsInstance()) { + LOG(WARNING) << "Malformed AST: Seq expr nested inside a dataflow block contains a " + "non-dataflow block! " + << seq; + } + ret.push_back(block); + } + + if (const auto* var_binding = binding.as()) { + current.push_back(VarBinding(var_binding->var, seq->body)); + } else if (const auto* match_cast = binding.as()) { + current.push_back(MatchCast(match_cast->var, seq->body, match_cast->struct_info)); + } else { + LOG(FATAL) << "Unknown binding type: " << binding->GetTypeKey(); + } + } else { + current.push_back(binding); + } + } + ret.push_back(is_dataflow ? DataflowBlock(current) : BindingBlock(current)); + } + return changed ? ret : blocks; + } + + Array NormalizeBlocks(const Array& blocks) { + bool changed = false; + Array ret; + auto flattened = FlattenBlocks(blocks); + if (!flattened.same_as(blocks)) { + changed = true; + } + for (const BindingBlock& block : flattened) { + if (block->bindings.empty()) { + // Case 1. Skip empty blocks + changed = true; + } else if (!ret.empty() && ret.back()->type_index() == block->type_index()) { + // Case 2. Merge with previous block if possible + BindingBlock merged; + // NOTE: should check DataflowBlockNode first. + if (const auto* dataflow_block = ret.back().as()) { + auto n = make_object(*dataflow_block); + n->bindings.insert(n->bindings.end(), block->bindings.begin(), block->bindings.end()); + merged = DataflowBlock(n); + } else if (const auto* binding_block = ret.back().as()) { + auto n = make_object(*binding_block); + n->bindings.insert(n->bindings.end(), block->bindings.begin(), block->bindings.end()); + merged = BindingBlock(n); + } else { + LOG(FATAL) << "Unknown block type: " << ret.back()->GetTypeKey(); + } + ret.pop_back(); + ret.push_back(merged); + changed = true; + } else { + // Case 3. Add to the result + ret.push_back(block); + } + } + return changed ? ret : blocks; + } + + /*! \brief Operator struct info inference map. */ + tvm::OpAttrMap op_map_infer_struct_info_ = + Op::GetAttrMap("FInferStructInfo"); +}; + +BlockBuilder BlockBuilder::Create(Optional mod) { + ObjectPtr n = make_object(mod.value_or(IRModule())); + return BlockBuilder(n); +} + +//--------------------------------------- +// User facing function registration. +//--------------------------------------- +TVM_REGISTER_OBJECT_TYPE(BlockBuilderNode); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderCreate").set_body_typed([](Optional mod) { + return BlockBuilder::Create(mod); +}); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginDataflowBlock") + .set_body_method(&BlockBuilderNode::BeginDataflowBlock); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginBindingBlock") + .set_body_method(&BlockBuilderNode::BeginBindingBlock); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEndBlock") + .set_body_method(&BlockBuilderNode::EndBlock); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderNormalize") + .set_body_method(&BlockBuilderNode::Normalize); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit").set_body_typed([](BlockBuilder builder, Expr expr) { + return builder->Emit(expr); +}); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast") + .set_body_typed([](BlockBuilder builder, Expr value, StructInfo struct_info) { + return builder->EmitMatchCast(value, struct_info); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput") + .set_body_typed([](BlockBuilder builder, const Expr& output) { + return builder->EmitOutput(output); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitNormalized") + .set_body_typed([](BlockBuilder builder, Binding binding) { + return builder->EmitNormalized(binding); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderGetUniqueName") + .set_body_typed([](BlockBuilder builder, String name_hint) { + return builder->name_table()->GetUniqueName(name_hint); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderAddFunction") + .set_body_method(&BlockBuilderNode::AddFunction); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderUpdateFunction") + .set_body_method(&BlockBuilderNode::UpdateFunction); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderGetContextIRModule") + .set_body_method(&BlockBuilderNode::GetContextIRModule); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderCurrentBlockIsDataFlow") + .set_body_method(&BlockBuilderNode::CurrentBlockIsDataFlow); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderLookupBinding") + .set_body_method(&BlockBuilderNode::LookupBinding); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginScope") + .set_body_method(&BlockBuilderNode::BeginScope); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEndScope") + .set_body_method(&BlockBuilderNode::EndScope); +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc new file mode 100644 index 000000000000..bfb5896c9988 --- /dev/null +++ b/src/relax/ir/emit_te.cc @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/src/ir/emit_te.cc + */ +#include "./emit_te.h" + +#include +#include + +namespace tvm { +namespace relax { + +// RXPlaceholderOpNode +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "rxplaceholder(" << op->name << ", " << op << ")"; + }); + +TVM_REGISTER_NODE_TYPE(RXPlaceholderOpNode); + +te::Tensor TETensor(Expr value, Map tir_var_map, std::string name) { + auto n = make_object(); + n->name = name; + n->value = value; + + // If the value is a constant, it might come as an argument of EmitTE and thus its shape and + // checked-type might not be properly set. In this case we set the shape and dtype of the returned + // TE tensor. + if (const auto* constant = value.as()) { + n->dtype = DataType(constant->data->dtype); + + int ndim = constant->data->ndim; + ShapeTuple shape_tuple = constant->data.Shape(); + Array shape; + shape.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + shape.push_back(IntImm(DataType::Int(64), shape_tuple[i])); + } + n->shape = std::move(shape); + return te::PlaceholderOp(n).output(0); + } + ICHECK(value->struct_info_.defined()) << "value must be normalized and contain StructInfo"; + auto* tensor_sinfo = GetStructInfoAs(value); + ICHECK(tensor_sinfo) << "Value must be a tensor"; + auto* shape_expr = tensor_sinfo->shape.as(); + CHECK(shape_expr) + << "ValueError: Expression does not have an known symbolic shape, please consider use " + "match_cast " + << "to constrain the shape before passing into te_tensor"; + n->shape = shape_expr->values.Map( + [&tir_var_map](const PrimExpr& e) { return tir::Substitute(e, tir_var_map); }); + n->dtype = tensor_sinfo->dtype; + return te::PlaceholderOp(n).output(0); +} + +TVM_REGISTER_GLOBAL("relax.TETensor").set_body_typed(TETensor); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h new file mode 100644 index 000000000000..46207479c7ef --- /dev/null +++ b/src/relax/ir/emit_te.h @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/src/ir/emit_te.h + * \brief Tensor expression extension in Relax. + */ +#ifndef TVM_RELAX_IR_EMIT_TE_H_ +#define TVM_RELAX_IR_EMIT_TE_H_ + +#include +#include + +#include + +namespace tvm { +namespace relax { + +/*! + * \brief A placeholder op that represents a relax expression. + */ +class RXPlaceholderOpNode : public te::PlaceholderOpNode { + public: + /*! \brief The relax expression. */ + Expr value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("tag", &tag); + v->Visit("attrs", &attrs); + v->Visit("value", &value); + v->Visit("shape", &shape); + v->Visit("dtype", &dtype); + } + + static constexpr const char* _type_key = "RXPlaceholderOp"; + TVM_DECLARE_FINAL_OBJECT_INFO(RXPlaceholderOpNode, te::PlaceholderOpNode); +}; + +/*! + * \brief Create a TE tensor from relax expression, with TIR variables in the + * tensor shape substituted by the given mapping. + * \param value The relax expression, which is required to have TensorStructInfo. + * \param tir_var_map The mapping to substitute the TIR variables appeared in the + * shape of the input Expr. + * \param name The name of the created tensor. + */ +te::Tensor TETensor(Expr value, Map tir_var_map, std::string name); + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_IR_EMIT_TE_H_ diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index 048de7950f97..4c4b68f3d200 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -542,5 +542,249 @@ BindingBlock ExprMutatorBase::VisitBindingBlock(const BindingBlock& block) { PrimExpr ExprMutatorBase::VisitPrimExpr(const PrimExpr& expr) { return expr; } +// ================== +// ExprMutator + +Expr ExprMutator::VisitExpr(const Expr& expr) { + return builder_->Normalize(ExprFunctor::VisitExpr(expr)); +} + +// Visit the use-site of a defined Var +Expr ExprMutator::VisitExpr_(const VarNode* op) { + auto it = var_remap_.find(op->vid); + if (it != var_remap_.end()) { + return it->second; + } + + // default case return self. + return GetRef(op); +} + +// Visit the use-site of a defined DataflowVar +Expr ExprMutator::VisitExpr_(const DataflowVarNode* op) { + auto it = var_remap_.find(op->vid); + if (it != var_remap_.end()) { + return it->second; + } + + // default case return self. + return GetRef(op); +} + +Expr ExprMutator::VisitExpr_(const FunctionNode* op) { + tvm::Array params; + bool all_params_unchanged = true; + for (Var param : op->params) { + Var new_param = this->VisitVarDef(param); + params.push_back(new_param); + all_params_unchanged &= param.same_as(new_param); + } + + Expr body = this->VisitWithNewScope(op->body, params); + + // FuncStructInfo does not depend on Expr + if (all_params_unchanged && body.same_as(op->body)) { + return GetRef(op); + } else { + return Function(params, body, op->ret_struct_info, op->attrs); + } +} + +Expr ExprMutator::VisitExpr_(const IfNode* op) { + Expr guard = this->VisitExpr(op->cond); + Expr true_b = this->VisitWithNewScope(op->true_branch); + Expr false_b = this->VisitWithNewScope(op->false_branch); + if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && + op->false_branch.same_as(false_b) && + VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { + return GetRef(op); + } else { + return If(guard, true_b, false_b, op->span); + } +} + +Expr ExprMutator::VisitExpr_(const SeqExprNode* op) { + bool all_blocks_unchanged = true; + Array blocks; + for (auto block : op->blocks) { + BindingBlock new_block = this->VisitBindingBlock(block); + if (!new_block->bindings.empty()) { + blocks.push_back(new_block); + } + all_blocks_unchanged &= block.same_as(new_block); + } + + builder_->BeginBindingBlock(); + Expr body = this->VisitExpr(op->body); + BindingBlock prologue = builder_->EndBlock(); + if (!prologue->bindings.empty()) { + blocks.push_back(prologue); + all_blocks_unchanged = false; + } + + if (all_blocks_unchanged && body.same_as(op->body) && + VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { + return GetRef(op); + } else { + return SeqExpr(blocks, body); + } +} + +RELAX_VAR_BINDING_DISPATCH_IMPL(ExprMutator); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(ConstantNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(TupleNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(VarNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(DataflowVarNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(ShapeExprNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(ExternFuncNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(GlobalVarNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(FunctionNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(CallNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(SeqExprNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(IfNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(OpNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(TupleGetItemNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(PrimValueNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(StringImmNode); +RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(DataTypeImmNode); + +void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) { + Var new_var = this->VisitVarDef(binding->var); + + // fast path: reemit binding if nothing changes + if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + return; + } + + Var temp = WithStructInfo(new_var, GetStructInfo(new_value)); + if (!temp.same_as(new_var)) { + new_var = temp; + this->var_remap_[binding->var->vid] = new_var; + } + + builder_->EmitNormalized(VarBinding(new_var, new_value)); +} + +void ExprMutator::VisitBinding_(const MatchCastNode* binding) { + Var new_var = this->VisitVarDef(binding->var); + Expr new_value = this->VisitExpr(binding->value); + + // re-emit old binding if nothing changes + if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + } else { + new_value = builder_->NormalizeArgument(new_value); + builder_->EmitNormalized(MatchCast(new_var, new_value, binding->struct_info, binding->span)); + } +} + +BindingBlock ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) { + builder_->BeginBindingBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); +} + +BindingBlock ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) { + builder_->BeginDataflowBlock(); + for (auto binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); +} + +Var ExprMutator::VisitVarDef_(const DataflowVarNode* var) { + if (auto* sinfo = var->struct_info_.as()) { + StructInfo struct_info = this->VisitExprDepStructInfoField(GetRef(sinfo)); + if (struct_info.same_as(var->struct_info_)) { + return GetRef(var); + } else { + return DataflowVar(var->vid, struct_info, var->span); + } + } else { + return GetRef(var); + } +} + +Var ExprMutator::VisitVarDef_(const VarNode* var) { + if (auto* sinfo = var->struct_info_.as()) { + StructInfo struct_info = this->VisitExprDepStructInfoField(GetRef(sinfo)); + if (struct_info.same_as(var->struct_info_)) { + return GetRef(var); + } else { + return Var(var->vid, struct_info, var->span); + } + } else { + return GetRef(var); + } +} + +void ExprMutator::VisitBinding(const Binding& binding) { + if (const auto* node = binding.as()) { + VisitBinding_(node); + } else if (const auto* node = binding.as()) { + VisitBinding_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } +} + +BindingBlock ExprMutator::VisitBindingBlock(const BindingBlock& block) { + BindingBlock ret; + if (const auto* node = block.as()) { + ret = VisitBindingBlock_(node); + } else if (const auto* node = block.as()) { + ret = VisitBindingBlock_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + return ret; +} + +Var ExprMutator::VisitVarDef(const Var& var) { + Var ret; + if (const auto* node = var.as()) { + ret = VisitVarDef_(node); + } else if (const auto* node = var.as()) { + ret = VisitVarDef_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + } + return ret; +} + +Expr ExprMutator::VisitWithNewScope(const Expr& expr, Optional> params) { + ICHECK(expr->IsInstance()) + << "Normal form requires all new scope is stored as SeqExpr"; + builder_->BeginScope(params); + Expr ret = this->VisitExpr(expr); + builder_->EndScope(); + return ret; +} + +Optional ExprMutator::LookupBinding(const Var& var) { return builder_->LookupBinding(var); } + +Var ExprMutator::WithStructInfo(Var var, StructInfo struct_info) { + ICHECK(struct_info.defined()); + + // TODO(relax-team) add StructInfoEqual check + if (var->struct_info_.defined()) { + // use same-as as a quick path + if (var->struct_info_.same_as(struct_info) || + StructuralEqual()(var->struct_info_, struct_info)) { + return var; + } else { + Var new_var = var.as() ? DataflowVar(var->vid, struct_info, var->span) + : Var(var->vid, struct_info, var->span); + return new_var; + } + } else { + UpdateStructInfo(var, struct_info); + return var; + } +} + } // namespace relax } // namespace tvm diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc new file mode 100644 index 000000000000..7e86235aa61e --- /dev/null +++ b/src/relax/ir/py_expr_functor.cc @@ -0,0 +1,649 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/py_expr_functor.cc + * \brief The backbone of PyExprVisitor/PyExprMutator. + */ +#include + +namespace tvm { +namespace relax { + +/*! + * \brief The abstract interface of ExprVisitor. + */ +class PyExprVisitorNode : public Object, public ExprVisitor { + private: + using TSelf = PyExprVisitorNode; + using FType = tvm::NodeFunctor; + + public: + /*! \brief The packed function to the `VisitExpr(const Expr& expr)` function. */ + PackedFunc f_visit_expr{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ConstantNode* op)` function. */ + PackedFunc f_visit_constant_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleNode* op)` function. */ + PackedFunc f_visit_tuple_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */ + PackedFunc f_visit_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const DataflowVarNode* op)` function. */ + PackedFunc f_visit_dataflow_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ShapeExprNode* op)` function. */ + PackedFunc f_visit_shape_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ExternFuncNode* op)` function. */ + PackedFunc f_visit_extern_func_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const GlobalVarNode* op)` function. */ + PackedFunc f_visit_global_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const FunctionNode* op)` function. */ + PackedFunc f_visit_function_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */ + PackedFunc f_visit_call_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const SeqExprNode* op)` function. */ + PackedFunc f_visit_seq_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const IfNode* op)` function. */ + PackedFunc f_visit_if_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const OpNode* op)` function. */ + PackedFunc f_visit_op_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleGetItemNode* op)` function. */ + PackedFunc f_visit_tuple_getitem_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const PrimValueNode* op)` function. */ + PackedFunc f_visit_prim_value_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const StringImmNode* op)` function. */ + PackedFunc f_visit_string_imm_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const DataTypeImmNode* op)` function. */ + PackedFunc f_visit_data_type_imm_{nullptr}; + /*! \brief The packed function to the `VisitBinding(const Binding& binding)` function. */ + PackedFunc f_visit_binding{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)` + * function. */ + PackedFunc f_visit_var_binding_{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const MatchCastNode* binding)` + * function. */ + PackedFunc f_visit_match_cast_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)` + * function. */ + PackedFunc f_visit_binding_block{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const BindingBlockNode* block)` + * function. */ + PackedFunc f_visit_binding_block_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const DataflowBlockNode* block)` + * function. */ + PackedFunc f_visit_dataflow_block_{nullptr}; + /*! \brief The packed function to the `VisitVarDef(const Var& var)` function. */ + PackedFunc f_visit_var_def{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const VarNode* var)` function. */ + PackedFunc f_visit_var_def_{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const DataflowVarNode* var)` function. */ + PackedFunc f_visit_dataflow_var_def_{nullptr}; + /*! \brief The packed function to the `VisitSpan(const Span& span)` function. */ + PackedFunc f_visit_span{nullptr}; + + void VisitExpr(const Expr& expr) { + if (f_visit_expr != nullptr) { + f_visit_expr(expr); + } else { + // Need to init the overwrite VTable + static FType vtable = InitVTable(); + vtable(expr, this); + } + } + + void VisitBinding(const Binding& binding) + PY_EXPR_VISITOR_DEFAULT(binding, f_visit_binding, ExprVisitor::VisitBinding(binding)); + + void VisitBinding_(const VarBindingNode* binding) + PY_EXPR_VISITOR_DEFAULT(GetRef(binding), f_visit_var_binding_, + ExprVisitor::VisitBinding_(binding)); + void VisitBinding_(const MatchCastNode* binding) + PY_EXPR_VISITOR_DEFAULT(GetRef(binding), f_visit_match_cast_, + ExprVisitor::VisitBinding_(binding)); + + void VisitBindingBlock(const BindingBlock& block) + PY_EXPR_VISITOR_DEFAULT(block, f_visit_binding_block, ExprVisitor::VisitBindingBlock(block)); + + void VisitBindingBlock_(const BindingBlockNode* block) + PY_EXPR_VISITOR_DEFAULT(GetRef(block), f_visit_binding_block_, + ExprVisitor::VisitBindingBlock_(block)); + void VisitBindingBlock_(const DataflowBlockNode* block) + PY_EXPR_VISITOR_DEFAULT(GetRef(block), f_visit_dataflow_block_, + ExprVisitor::VisitBindingBlock_(block)); + + void VisitVarDef(const Var& var) + PY_EXPR_VISITOR_DEFAULT(var, f_visit_var_def, ExprVisitor::VisitVarDef(var)); + void VisitVarDef_(const VarNode* var) + PY_EXPR_VISITOR_DEFAULT(GetRef(var), f_visit_var_def_, ExprVisitor::VisitVarDef_(var)); + void VisitVarDef_(const DataflowVarNode* var) + PY_EXPR_VISITOR_DEFAULT(GetRef(var), f_visit_dataflow_var_def_, + ExprVisitor::VisitVarDef_(var)); + + void VisitSpan(const Span& span) + PY_EXPR_VISITOR_DEFAULT(span, f_visit_span, ExprVisitor::VisitSpan(span)); + + void VisitAttrs(AttrVisitor* v) {} + static constexpr const char* _type_key = "expr_functor.PyExprVisitor"; + TVM_DECLARE_BASE_OBJECT_INFO(PyExprVisitorNode, Object); + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + PY_EXPR_VISITOR_DISPATCH(ConstantNode, f_visit_constant_); + PY_EXPR_VISITOR_DISPATCH(TupleNode, f_visit_tuple_); + PY_EXPR_VISITOR_DISPATCH(VarNode, f_visit_var_); + PY_EXPR_VISITOR_DISPATCH(DataflowVarNode, f_visit_dataflow_var_); + PY_EXPR_VISITOR_DISPATCH(ShapeExprNode, f_visit_shape_expr_); + PY_EXPR_VISITOR_DISPATCH(ExternFuncNode, f_visit_extern_func_); + PY_EXPR_VISITOR_DISPATCH(GlobalVarNode, f_visit_global_var_); + PY_EXPR_VISITOR_DISPATCH(FunctionNode, f_visit_function_); + PY_EXPR_VISITOR_DISPATCH(CallNode, f_visit_call_); + PY_EXPR_VISITOR_DISPATCH(SeqExprNode, f_visit_seq_expr_); + PY_EXPR_VISITOR_DISPATCH(IfNode, f_visit_if_); + PY_EXPR_VISITOR_DISPATCH(OpNode, f_visit_op_); + PY_EXPR_VISITOR_DISPATCH(TupleGetItemNode, f_visit_tuple_getitem_); + PY_EXPR_VISITOR_DISPATCH(PrimValueNode, f_visit_prim_value_); + PY_EXPR_VISITOR_DISPATCH(StringImmNode, f_visit_string_imm_); + PY_EXPR_VISITOR_DISPATCH(DataTypeImmNode, f_visit_data_type_imm_); + return vtable; + } +}; + +TVM_REGISTER_NODE_TYPE(PyExprVisitorNode); + +/*! + * \brief Managed reference to PyExprVisitorNode. + * \sa PyExprVisitorNode + */ +class PyExprVisitor : public ObjectRef { + public: + /*! + * \brief Create a PyExprVisitor with customized methods on the python-side. + * \param f_visit_expr The packed function of `VisitExpr(const Expr& expr)`. + * \param f_visit_constant_ The packed function of `VisitExpr_(const ConstantNode* op)`. + * \param f_visit_tuple_ The packed function of `VisitExpr_(const TupleNode* op)`. + * \param f_visit_var_ The packed function of `VisitExpr_(const VarNode* op)`. + * \param f_visit_dataflow_var_ The packed function of `VisitExpr_(const DataflowVarNode* op)`. + * \param f_visit_shape_expr_ The packed function of `VisitExpr_(const ShapeExprNode* op)`. + * \param f_visit_extern_func_ The packed function of `VisitExpr_(const ExternFuncNode* op)`. + * \param f_visit_global_var_ The packed function of `VisitExpr_(const GlobalVarNode* op)`. + * \param f_visit_function_ The packed function of `VisitExpr_(const FunctionNode* op)`. + * \param f_visit_call_ The packed function of `VisitExpr_(const CallNode* op)`. + * \param f_visit_seq_expr_ The packed function of `VisitExpr_(const SeqExprNode* op)`. + * \param f_visit_if_ The packed function of `VisitExpr_(const IfNode* op)`. + * \param f_visit_op_ The packed function of `VisitExpr_(const OpNode* op)`. + * \param f_visit_tuple_getitem_ The packed function of `VisitExpr_(const TupleGetItemNode* op)`. + * \param f_visit_prim_value_ The packed function of `VisitExpr_(const PrimValueNode* op)`. + * \param f_visit_string_imm_ The packed function of `VisitExpr_(const StringImmNode* op)`. + * \param f_visit_data_type_imm_ The packed function of `VisitExpr_(const DataTypeImmNode* op)`. + * \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`. + * \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode* + * binding)`. + * \param f_visit_match_cast_ The packed function of `VisitBinding_(const MatchCastNode* + * binding)`. + * \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock& + * block)`. + * \param f_visit_binding_block_ The packed function of `VisitBindingBlock_(const + * BindingBlockNode* block)`. + * \param f_visit_dataflow_block_ The packed function of `VisitBindingBlock_(const + * DataflowBlockNode* block)`. + * \param f_visit_var_def The packed function of `VisitVarDef(const Var& var)`. + * \param f_visit_var_def_ The packed function of `VisitVarDef_(const VarNode* var)`. + * \param f_visit_dataflow_var_def_ The packed function of `VisitVarDef_(const DataflowVarNode* + * var)`. + * \param f_visit_span The packed function of `VisitSpan(const Span& span)`. + * \return The PyVisitor created. + */ + TVM_DLL static PyExprVisitor MakePyExprVisitor( + PackedFunc f_visit_expr, PackedFunc f_visit_constant_, PackedFunc f_visit_tuple_, + PackedFunc f_visit_var_, PackedFunc f_visit_dataflow_var_, PackedFunc f_visit_shape_expr_, + PackedFunc f_visit_extern_func_, PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, + PackedFunc f_visit_call_, PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, + PackedFunc f_visit_op_, PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_prim_value_, + PackedFunc f_visit_string_imm_, PackedFunc f_visit_data_type_imm_, PackedFunc f_visit_binding, + PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_cast_, + PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_, + PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_, + PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_span) { + ObjectPtr n = make_object(); + n->f_visit_expr = f_visit_expr; + n->f_visit_binding = f_visit_binding; + n->f_visit_binding_block = f_visit_binding_block; + n->f_visit_var_def = f_visit_var_def; + n->f_visit_span = f_visit_span; + n->f_visit_constant_ = f_visit_constant_; + n->f_visit_tuple_ = f_visit_tuple_; + n->f_visit_var_ = f_visit_var_; + n->f_visit_dataflow_var_ = f_visit_dataflow_var_; + n->f_visit_shape_expr_ = f_visit_shape_expr_; + n->f_visit_extern_func_ = f_visit_extern_func_; + n->f_visit_global_var_ = f_visit_global_var_; + n->f_visit_function_ = f_visit_function_; + n->f_visit_call_ = f_visit_call_; + n->f_visit_seq_expr_ = f_visit_seq_expr_; + n->f_visit_if_ = f_visit_if_; + n->f_visit_op_ = f_visit_op_; + n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_; + n->f_visit_prim_value_ = f_visit_prim_value_; + n->f_visit_string_imm_ = f_visit_string_imm_; + n->f_visit_data_type_imm_ = f_visit_data_type_imm_; + n->f_visit_var_binding_ = f_visit_var_binding_; + n->f_visit_match_cast_ = f_visit_match_cast_; + n->f_visit_binding_block_ = f_visit_binding_block_; + n->f_visit_dataflow_block_ = f_visit_dataflow_block_; + n->f_visit_var_def_ = f_visit_var_def_; + n->f_visit_dataflow_var_def_ = f_visit_dataflow_var_def_; + return PyExprVisitor(n); + } + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyExprVisitor, ObjectRef, PyExprVisitorNode); +}; + +/*! + * \brief The abstract interface of ExprMutator. + */ +class PyExprMutatorNode : public Object, public ExprMutator { + private: + using TSelf = PyExprMutatorNode; + using FType = tvm::NodeFunctor; + + public: + /*! \brief The packed function to the `VisitExpr(const Expr& expr)` function. */ + PackedFunc f_visit_expr{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ConstantNode* op)` function. */ + PackedFunc f_visit_constant_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleNode* op)` function. */ + PackedFunc f_visit_tuple_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */ + PackedFunc f_visit_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const DataflowVarNode* op)` function. */ + PackedFunc f_visit_dataflow_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ShapeExprNode* op)` function. */ + PackedFunc f_visit_shape_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ExternFuncNode* op)` function. */ + PackedFunc f_visit_extern_func_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const GlobalVarNode* op)` function. */ + PackedFunc f_visit_global_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const FunctionNode* op)` function. */ + PackedFunc f_visit_function_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */ + PackedFunc f_visit_call_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const SeqExprNode* op)` function. */ + PackedFunc f_visit_seq_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const IfNode* op)` function. */ + PackedFunc f_visit_if_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const OpNode* op)` function. */ + PackedFunc f_visit_op_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleGetItemNode* op)` function. */ + PackedFunc f_visit_tuple_getitem_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const PrimValueNode* op)` function. */ + PackedFunc f_visit_prim_value_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const StringImmNode* op)` function. */ + PackedFunc f_visit_string_imm_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const DataTypeImmNode* op)` function. */ + PackedFunc f_visit_data_type_imm_{nullptr}; + /*! \brief The packed function to the `VisitBinding(const Binding& binding)` function. */ + PackedFunc f_visit_binding{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)` + * function. */ + PackedFunc f_visit_var_binding_{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const MatchCastNode* binding)` + * function. */ + PackedFunc f_visit_match_cast_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)` + * function. */ + PackedFunc f_visit_binding_block{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const BindingBlockNode* block)` + * function. */ + PackedFunc f_visit_binding_block_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const DataflowBlockNode* block)` + * function. */ + PackedFunc f_visit_dataflow_block_{nullptr}; + /*! \brief The packed function to the `VisitVarDef(const Var& var)` function. */ + PackedFunc f_visit_var_def{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const VarNode* var)` function. */ + PackedFunc f_visit_var_def_{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const DataflowVarNode* var)` function. */ + PackedFunc f_visit_dataflow_var_def_{nullptr}; + /*! \brief The packed function to the `VisitSpan(const Span& span)` function. */ + PackedFunc f_visit_span{nullptr}; + + Expr VisitExpr(const Expr& expr) { + if (f_visit_expr != nullptr) { + return builder_->Normalize(f_visit_expr(expr)); + } else { + static FType vtable = InitVTable(); + return builder_->Normalize(vtable(expr, this)); + } + } + + void VisitBinding(const Binding& binding) { + if (f_visit_binding != nullptr) + f_visit_binding(binding); + else + ExprMutator::VisitBinding(binding); + } + + void VisitBinding_(const VarBindingNode* binding) { + if (f_visit_var_binding_ != nullptr) + f_visit_var_binding_(GetRef(binding)); + else + ExprMutator::VisitBinding_(binding); + } + + void VisitBinding_(const MatchCastNode* binding) { + if (f_visit_match_cast_ != nullptr) + f_visit_match_cast_(GetRef(binding)); + else + ExprMutator::VisitBinding_(binding); + } + + BindingBlock VisitBindingBlock(const BindingBlock& block) + PY_EXPR_MUTATOR_DEFAULT(block, f_visit_binding_block, ExprMutator::VisitBindingBlock(block), + BindingBlock); + + BindingBlock VisitBindingBlock_(const BindingBlockNode* block) + PY_EXPR_MUTATOR_DEFAULT(GetRef(block), f_visit_binding_block_, + ExprMutator::VisitBindingBlock_(block), BindingBlock); + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) + PY_EXPR_MUTATOR_DEFAULT(GetRef(block), f_visit_dataflow_block_, + ExprMutator::VisitBindingBlock_(block), BindingBlock); + + Var VisitVarDef(const Var& var) + PY_EXPR_MUTATOR_DEFAULT(var, f_visit_var_def, ExprMutator::VisitVarDef(var), Var); + Var VisitVarDef_(const VarNode* var) PY_EXPR_MUTATOR_DEFAULT(GetRef(var), f_visit_var_def_, + ExprMutator::VisitVarDef_(var), Var); + Var VisitVarDef_(const DataflowVarNode* var) + PY_EXPR_MUTATOR_DEFAULT(GetRef(var), f_visit_dataflow_var_def_, + ExprMutator::VisitVarDef_(var), Var); + + /*! + * \brief Dispatcher for post-order rewrite. + * \param expr The Expr to be rewritten. + * \return The Expr after post-order rewritten. + */ + Expr VisitExprPostOrder(const Expr& expr) { + static FType post_order_vtable = InitPostOrderVTable(); + return post_order_vtable(expr, this); + } + + using ExprMutator::builder_; + using ExprMutator::LookupBinding; + using ExprMutator::var_remap_; + using ExprMutator::VisitWithNewScope; + using ExprMutator::WithStructInfo; + + void VisitAttrs(AttrVisitor* v) { v->Visit("builder_", &builder_); } + static constexpr const char* _type_key = "expr_functor.PyExprMutator"; + TVM_DECLARE_BASE_OBJECT_INFO(PyExprMutatorNode, Object); + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + PY_EXPR_MUTATOR_DISPATCH(ConstantNode, f_visit_constant_); + PY_EXPR_MUTATOR_DISPATCH(TupleNode, f_visit_tuple_); + PY_EXPR_MUTATOR_DISPATCH(VarNode, f_visit_var_); + PY_EXPR_MUTATOR_DISPATCH(DataflowVarNode, f_visit_dataflow_var_); + PY_EXPR_MUTATOR_DISPATCH(ShapeExprNode, f_visit_shape_expr_); + PY_EXPR_MUTATOR_DISPATCH(ExternFuncNode, f_visit_extern_func_); + PY_EXPR_MUTATOR_DISPATCH(GlobalVarNode, f_visit_global_var_); + PY_EXPR_MUTATOR_DISPATCH(FunctionNode, f_visit_function_); + PY_EXPR_MUTATOR_DISPATCH(CallNode, f_visit_call_); + PY_EXPR_MUTATOR_DISPATCH(SeqExprNode, f_visit_seq_expr_); + PY_EXPR_MUTATOR_DISPATCH(IfNode, f_visit_if_); + PY_EXPR_MUTATOR_DISPATCH(OpNode, f_visit_op_); + PY_EXPR_MUTATOR_DISPATCH(TupleGetItemNode, f_visit_tuple_getitem_); + PY_EXPR_MUTATOR_DISPATCH(PrimValueNode, f_visit_prim_value_); + PY_EXPR_MUTATOR_DISPATCH(StringImmNode, f_visit_string_imm_); + PY_EXPR_MUTATOR_DISPATCH(DataTypeImmNode, f_visit_data_type_imm_); + return vtable; + } + + // initialize the vtable for post order visit. + static FType InitPostOrderVTable() { + FType post_order_vtable; + // Set dispatch + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(ConstantNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(TupleNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(VarNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(DataflowVarNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(ShapeExprNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(ExternFuncNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(GlobalVarNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(FunctionNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(CallNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(SeqExprNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(IfNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(OpNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(TupleGetItemNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(PrimValueNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(StringImmNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(DataTypeImmNode); + return post_order_vtable; + } +}; + +TVM_REGISTER_NODE_TYPE(PyExprMutatorNode); + +/*! + * \brief Managed reference to PyExprMutatorNode. + * \sa PyExprMutatorNode + */ +class PyExprMutator : public ObjectRef { + public: + /*! + * \brief Create a PyExprMutator with customized methods on the python-side. + * \param f_visit_expr The packed function of `VisitExpr(const Expr& expr)`. + * \param f_visit_constant_ The packed function of `VisitExpr_(const ConstantNode* op)`. + * \param f_visit_tuple_ The packed function of `VisitExpr_(const TupleNode* op)`. + * \param f_visit_var_ The packed function of `VisitExpr_(const VarNode* op)`. + * \param f_visit_dataflow_var_ The packed function of `VisitExpr_(const DataflowVarNode* op)`. + * \param f_visit_shape_expr_ The packed function of `VisitExpr_(const ShapeExprNode* op)`. + * \param f_visit_extern_func_ The packed function of `VisitExpr_(const ExternFuncNode* op)`. + * \param f_visit_global_var_ The packed function of `VisitExpr_(const GlobalVarNode* op)`. + * \param f_visit_function_ The packed function of `VisitExpr_(const FunctionNode* op)`. + * \param f_visit_call_ The packed function of `VisitExpr_(const CallNode* op)`. + * \param f_visit_seq_expr_ The packed function of `VisitExpr_(const SeqExprNode* op)`. + * \param f_visit_if_ The packed function of `VisitExpr_(const IfNode* op)`. + * \param f_visit_op_ The packed function of `VisitExpr_(const OpNode* op)`. + * \param f_visit_tuple_getitem_ The packed function of `VisitExpr_(const TupleGetItemNode* op)`. + * \param f_visit_prim_value_ The packed function of `VisitExpr_(const PrimValueNode* op)`. + * \param f_visit_string_imm_ The packed function of `VisitExpr_(const StringImmNode* op)`. + * \param f_visit_data_type_imm_ The packed function of `VisitExpr_(const DataTypeImmNode* op)`. + * \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`. + * \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode* + * binding)`. + * \param f_visit_match_cast_ The packed function of `VisitBinding_(const MatchCastNode* + * binding)`. + * \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock& + * block)`. + * \param f_visit_binding_block_ The packed function of `VisitBindingBlock_(const + * BindingBlockNode* block)`. + * \param f_visit_dataflow_block_ The packed function of `VisitBindingBlock_(const + * DataflowBlockNode* block)`. + * \param f_visit_var_def The packed function of `VisitVarDef(const Var& var)`. + * \param f_visit_var_def_ The packed function of `VisitVarDef_(const VarNode* var)`. + * \param f_visit_dataflow_var_def_ The packed function of `VisitVarDef_(const DataflowVarNode* + * var)`. + * \param f_visit_span The packed function of `VisitSpan(const Span& span)`. + * \return The PyExprMutator created. + */ + TVM_DLL static PyExprMutator MakePyExprMutator( + BlockBuilder builder_, PackedFunc f_visit_expr, PackedFunc f_visit_constant_, + PackedFunc f_visit_tuple_, PackedFunc f_visit_var_, PackedFunc f_visit_dataflow_var_, + PackedFunc f_visit_shape_expr_, PackedFunc f_visit_extern_func_, + PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, PackedFunc f_visit_call_, + PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, PackedFunc f_visit_op_, + PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_prim_value_, + PackedFunc f_visit_string_imm_, PackedFunc f_visit_data_type_imm_, PackedFunc f_visit_binding, + PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_cast_, + PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_, + PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_, + PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_span) { + ObjectPtr n = make_object(); + n->builder_ = builder_; + n->f_visit_expr = f_visit_expr; + n->f_visit_constant_ = f_visit_constant_; + n->f_visit_tuple_ = f_visit_tuple_; + n->f_visit_var_ = f_visit_var_; + n->f_visit_dataflow_var_ = f_visit_dataflow_var_; + n->f_visit_shape_expr_ = f_visit_shape_expr_; + n->f_visit_extern_func_ = f_visit_extern_func_; + n->f_visit_global_var_ = f_visit_global_var_; + n->f_visit_function_ = f_visit_function_; + n->f_visit_call_ = f_visit_call_; + n->f_visit_seq_expr_ = f_visit_seq_expr_; + n->f_visit_if_ = f_visit_if_; + n->f_visit_op_ = f_visit_op_; + n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_; + n->f_visit_prim_value_ = f_visit_prim_value_; + n->f_visit_string_imm_ = f_visit_string_imm_; + n->f_visit_data_type_imm_ = f_visit_data_type_imm_; + n->f_visit_binding = f_visit_binding; + n->f_visit_var_binding_ = f_visit_var_binding_; + n->f_visit_match_cast_ = f_visit_match_cast_; + n->f_visit_binding_block = f_visit_binding_block; + n->f_visit_binding_block_ = f_visit_binding_block_; + n->f_visit_dataflow_block_ = f_visit_dataflow_block_; + n->f_visit_var_def = f_visit_var_def; + n->f_visit_var_def_ = f_visit_var_def_; + n->f_visit_dataflow_var_def_ = f_visit_dataflow_var_def_; + n->f_visit_span = f_visit_span; + return PyExprMutator(n); + } + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyExprMutator, ObjectRef, PyExprMutatorNode); +}; + +TVM_REGISTER_GLOBAL("relax.MakePyExprVisitor").set_body_typed(PyExprVisitor::MakePyExprVisitor); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitExpr") + .set_body_typed([](PyExprVisitor visitor, const Expr& expr) { visitor->VisitExpr(expr); }); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitBinding") + .set_body_typed([](PyExprVisitor visitor, const Binding& binding) { + visitor->VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitBindingBlock") + .set_body_typed([](PyExprVisitor visitor, const BindingBlock& block) { + visitor->VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitVarDef") + .set_body_typed([](PyExprVisitor visitor, const Var& var) { visitor->VisitVarDef(var); }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitExpr") + .set_body_typed([](PyExprVisitor visitor, const Expr& expr) { + visitor->ExprVisitor::VisitExpr(expr); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBinding") + .set_body_typed([](PyExprVisitor visitor, const Binding& binding) { + visitor->ExprVisitor::VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBindingBlock") + .set_body_typed([](PyExprVisitor visitor, const BindingBlock& block) { + visitor->ExprVisitor::VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitVarDef") + .set_body_typed([](PyExprVisitor visitor, const Var& var) { + visitor->ExprVisitor::VisitVarDef(var); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitSpan") + .set_body_typed([](PyExprVisitor visitor, const Span& span) { + visitor->ExprVisitor::VisitSpan(span); + }); + +TVM_REGISTER_GLOBAL("relax.MakePyExprMutator").set_body_typed(PyExprMutator::MakePyExprMutator); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitExpr") + .set_body_typed([](PyExprMutator mutator, const Expr& expr) { + return mutator->VisitExpr(expr); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitBinding") + .set_body_typed([](PyExprMutator mutator, const Binding& binding) { + mutator->VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitBindingBlock") + .set_body_typed([](PyExprMutator mutator, const BindingBlock& block) { + return mutator->VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitVarDef") + .set_body_typed([](PyExprMutator mutator, const Var& var) { + return mutator->VisitVarDef(var); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitExpr") + .set_body_typed([](PyExprMutator mutator, const Expr& expr) { + return mutator->ExprMutator::VisitExpr(expr); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBinding") + .set_body_typed([](PyExprMutator mutator, const Binding& binding) { + return mutator->ExprMutator::VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBindingBlock") + .set_body_typed([](PyExprMutator mutator, const BindingBlock& block) { + return mutator->ExprMutator::VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitVarDef") + .set_body_typed([](PyExprMutator mutator, const Var& var) { + return mutator->ExprMutator::VisitVarDef(var); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitExprPostOrder") + .set_body_typed([](PyExprMutator mutator, const Expr& expr) { + return mutator->VisitExprPostOrder(expr); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitWithNewScope") + .set_body_typed([](PyExprMutator mutator, const Expr& expr) { + return mutator->VisitWithNewScope(expr); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorLookupBinding") + .set_body_typed([](PyExprMutator mutator, const Var& var) { + return mutator->LookupBinding(var); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorWithStructInfo") + .set_body_typed([](PyExprMutator mutator, Var var, StructInfo sinfo) { + return mutator->WithStructInfo(var, sinfo); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorSetVarRemap") + .set_body_typed([](PyExprMutator mutator, Id id, Var var) { + return mutator->var_remap_[id] = var; + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorGetVarRemap") + .set_body_typed([](PyExprMutator mutator, Id id) { return mutator->var_remap_[id]; }); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc new file mode 100644 index 000000000000..8640ed79adb0 --- /dev/null +++ b/src/relax/op/op.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +// call_tir + +StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { + if (call->sinfo_args.size() != 1) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "sinfo_args should have exact 1 output struct info."); + } + return call->sinfo_args[0]; +} + +RELAY_REGISTER_OP("relax.call_tir") + .set_num_inputs(3) + .add_argument("func", "Expr", "The destination-passing-style function.") + .add_argument("args", "Tuple", "The input arguments.") + .add_argument("packed_ints", "Expr", + "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " + "args if unused") + .set_attr("FInferStructInfo", InferStructInfoCallTIR); + +Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, + Optional packed_ints) { + for (const TensorStructInfo& sinfo : out_sinfo_list) { + const auto* shape = sinfo->shape.as(); + CHECK(shape != nullptr) << "out_sinfo of call_tir should have defined ShapeExpr as shape. " + "However, one given structure info is " + << sinfo; + } + + StructInfo out_sinfo{nullptr}; + if (out_sinfo_list.size() == 1) { + out_sinfo = out_sinfo_list[0]; + } else { + out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}); + } + + static const Op& op = Op::Get("relax.call_tir"); + Call call; + if (!packed_ints) { + // don't use additional optional argument + call = Call(op, {func, args}, {}, {out_sinfo}); + } else { + call = Call(op, {func, args, packed_ints.value()}, {}, {out_sinfo}); + } + return call; +} + +TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc new file mode 100644 index 000000000000..260f71e7bfb6 --- /dev/null +++ b/src/relax/op/op_common.cc @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "op_common.h" + +#include + +namespace tvm { +namespace relax { + +Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) { + Op op = Downcast(call->op); + int n_input = op->arguments.size(); + if (static_cast(call->args.size()) != n_input) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << op << " op should have " << n_input << " arguments"); + } + Array input_tensor_sinfo; + input_tensor_sinfo.reserve(n_input); + for (int i = 0; i < n_input; ++i) { + const auto* sinfo = GetStructInfoAs(call->args[i]); + if (sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << op << " requires the input " << op->arguments[i]->name + << " to be Tensor. However, the given one is " + << call->args[i]->struct_info_->GetTypeKey()); + } + input_tensor_sinfo.push_back(GetRef(sinfo)); + } + return input_tensor_sinfo; +} + +Optional> InferBinaryBroadcastShape(const Call& call, const BlockBuilder& ctx, + const Array& x1_shape, + const Array& x2_shape) { + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + int x1_ndim = x1_shape.size(); + int x2_ndim = x2_shape.size(); + int max_ndim = std::max(x1_ndim, x2_ndim); + + std::vector output_shape; + output_shape.reserve(max_ndim); + + int i = 1; + for (; i <= std::min(x1_ndim, x2_ndim); ++i) { + const PrimExpr& dim0 = x1_shape[x1_ndim - i]; + const PrimExpr& dim1 = x2_shape[x2_ndim - i]; + const auto* int_dim0 = dim0.as(); + const auto* int_dim1 = dim1.as(); + if (int_dim0 != nullptr && int_dim0->value == 1) { + output_shape.push_back(dim1); + } else if (int_dim1 != nullptr && int_dim1->value == 1) { + output_shape.push_back(dim0); + } else if (analyzer->CanProveEqual(dim0, dim1)) { + output_shape.push_back(dim0); + } else if (int_dim0 && int_dim1 && int_dim0->value != int_dim1->value) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "In " << call->op << ", the first input shape at dim " << x1_ndim - i + << " is " << dim0 << " and the second input shape at dim " << x2_ndim - i + << " is " << dim1 << ", which are not broadcastable."); + } else { + // Use simple fallback when shape mismatch. + return NullOpt; + } + } + auto& longer_shape = (x1_ndim > x2_ndim) ? x1_shape : x2_shape; + for (; i <= max_ndim; ++i) { + output_shape.push_back(longer_shape[max_ndim - i]); + } + return Array(output_shape.rbegin(), output_shape.rend()); +} + +std::vector NormalizeAxes(const Call& call, const BlockBuilder& ctx, int ndim, + const Array& axes) { + ICHECK_NE(ndim, kUnknownNDim) << "The ndim is required to be known for this function."; + std::vector appeared_dims_set; + std::vector axes_non_neg; + appeared_dims_set.resize(ndim, /*value=*/false); + axes_non_neg.reserve(axes.size()); + for (const Integer& axis : axes) { + int _axis = axis->value; + if (_axis < -ndim || _axis >= ndim) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "In " << call->op << ", the input axis " << _axis + << " is out of range. The input tensor has " << ndim + << " dimensions, so axis should be in range [" << -ndim << ", " << ndim + << ")."); + } else if (_axis < 0) { + _axis = ndim + _axis; + } + + if (appeared_dims_set[_axis]) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "In " << call->op + << ", the input axes is required to be non-repetitive. However, there are " + "multiple given axes referring to axis " + << _axis); + } + appeared_dims_set[_axis] = true; + axes_non_neg.push_back(_axis); + } + return axes_non_neg; +} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h new file mode 100644 index 000000000000..8e362bb4d55c --- /dev/null +++ b/src/relax/op/op_common.h @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file op_common.h + * \brief A set of utilities and common functionality + * for Relax ops. + */ +#ifndef TVM_RELAX_OP_OP_COMMON_H_ +#define TVM_RELAX_OP_OP_COMMON_H_ + +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +/************ Op input struct info getter ************/ + +/*! + * \brief Get the tensor struct info of the operator input. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \return The tensor struct info of each input. + * \note This function require every input to be Tensor. The number of call arguments is required + * to match the number of inputs of the op being called. + */ +Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx); + +/*! + * \brief Get the tensor struct info of the unary operator input. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \return The tensor struct info of the unary operator input. + * \throw Throw exception if the number of input is not one, or the struct info of the input is not + * a tensor struct info. + */ +inline TensorStructInfo GetUnaryInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) { + return GetInputTensorStructInfo(call, ctx)[0]; +} + +/************ Op registration macro ************/ + +/*! + * \brief Quick helper macro to register the operator to registry + * \param OpRegName The name of operator to register. The name passed in will + * be prepended with a prefix "relax." as the identifier string in the operator registry. + */ +#define RELAX_REGISTER_UNARY_OP(OpRegName) \ + TVM_REGISTER_OP("relax." OpRegName) \ + .set_num_inputs(1) \ + .add_argument("x", "Tensor", "The input tensor.") + +/*! + * \brief Quick helper macro to expose a make-function to construct the operator. + * \param OpName The name of the operator as well as the make-function name, which will + * be prepended with a prefix "relax.op." as the FFI identifier string for the make function, + * \param OpRegName The identifier of the operator in the registry. + */ +#define RELAX_UNARY_OP_INTERFACE(OpName, OpRegName) \ + Expr OpName(Expr x) { \ + static const Op& op = Op::Get("relax." OpRegName); \ + return Call(op, {std::move(x)}, Attrs(), {}); \ + } \ + TVM_REGISTER_GLOBAL("relax.op." OpRegName).set_body_typed(OpName) + +/************ Utilities ************/ + +/*! + * \brief Infer the struct info for unary elementwise ops. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param f_compute_out_dtype The function to compute the output dtype, with + * signature DataType f_compute_out_dtype(const TensorStructInfo& input_sinfo). + * \tparam require_float_dtype whether this op requires the input dtype to be float + * \tparam Ftype the type of f_compute_out_dtype + * \return The inferred struct info. + */ +template +inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx, + FType f_compute_out_dtype) { + TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + if (require_float_dtype && !input_sinfo->IsUnknownDtype() && !input_sinfo->dtype.is_float()) { + ctx->ReportFatal( + Diagnostic::Error(call->span) + << call->op + << " requires the input tensor to have float dtype. However, the given input dtype is " + << input_sinfo->dtype); + } + auto output_sinfo = make_object(*input_sinfo.get()); + output_sinfo->dtype = f_compute_out_dtype(input_sinfo); + return TensorStructInfo(output_sinfo); +} + +/*! + * \brief Infer the struct info for unary arithmetic elementwise ops. It's also + * used in some NN operators. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \tparam require_float_dtype whether this op requires the input dtype to be float + * \return The inferred struct info. + */ +template +StructInfo InferStructInfoUnaryArith(const Call& call, const BlockBuilder& ctx) { + return InferStructInfoUnary( + call, ctx, [](const TensorStructInfo& input_sinfo) { return input_sinfo->dtype; }); +} + +/************ Utilities ************/ + +/*! + * \brief Infer the output datatype for binary arithmetic operators. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param x1_sinfo The struct info of the first operand + * \param x2_sinfo The struct info of the second operand + * \return The inferred output dtype. + * \throw Throw exception if the dtype of two input TensorStructInfo don’t match + */ +inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& ctx, + const TensorStructInfo& x1_sinfo, + const TensorStructInfo& x2_sinfo) { + if (x1_sinfo->IsUnknownDtype() || x2_sinfo->IsUnknownDtype()) { + return DataType::Void(); + } else if (x1_sinfo->dtype != x2_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "Data types " << x1_sinfo->dtype << " and " << x2_sinfo->dtype + << " must be equal for binary operators"); + } + return x1_sinfo->dtype; +} + +/*! + * \brief Infer the output shape for binary broadcast operators. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param x1_shape The shape of the first operand. + * \param x2_shape The shape of the second operand. + * \return The inferred output shape after broadcasting. Or `NullOpt` if the output shape cannot be + * determined due to symbolic broadcast. + */ +Optional> InferBinaryBroadcastShape(const Call& call, const BlockBuilder& ctx, + const Array& x1_shape, + const Array& x2_shape); + +/*! + * \brief Convert all axes to non-negative indices, and meanwhile check if the given array of axes + * are all in range and non-repetitive with regards to the given ndim. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param ndim The ndim constraint, which is required to be known already. + * \param axes The axis indices to be checked + * \return The input axes in non-negative indexing. + * \throw Throw exception if there exists out-of-range axis index or repetitive indices. + */ +std::vector NormalizeAxes(const Call& call, const BlockBuilder& ctx, int ndim, + const Array& axes); + +/*! + * \brief Convert the given axis to non-negative index. Meanwhile check if the axis is in range + * with regards to the given ndim. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param ndim The ndim constraint. + * \param axis The axis index to be checked + * \return The input axis in non-negative indexing. + * \throw Throw exception the given axis is out-of-range. + */ +inline int NormalizeAxis(const Call& call, const BlockBuilder& ctx, int ndim, int axis) { + return NormalizeAxes(call, ctx, ndim, {axis})[0]; +} + +/*! + * \brief Convert an array of integers to int64 dtype. + * \param int_imms The input IntImms to be converted. + * \return The conversion result, where every IntImm has dtype int64 + */ +inline Array ConvertIntImmToInt64(const Array& int_imms) { + return int_imms.Map([](const IntImm& i) { return Downcast(cast(DataType::Int(64), i)); }); +} + +/************ Utilities for NN operators ************/ + +/*! + * \brief Complete the padding to a 4-length array. + * - If the padding length is 1, the same padding is used on all top/left/bottom/right sides + * - If the padding length is 2, top/bottom sides use padding[0] and left/right use padding[1] + * - If the padding length is 4, padding is in the order of (top, left, bottom, right) + * \param padding The given padding to be completed + * \return The completed padding. + * \throws Throws error if the input padding length is neither 1, 2 or 4. + */ +inline Array GetCompletePadding2D(Array padding) { + if (padding.size() == 1) { + return {padding[0], padding[0], padding[0], padding[0]}; + } else if (padding.size() == 2) { + return {padding[0], padding[1], padding[0], padding[1]}; + } else if (padding.size() == 4) { + return padding; + } + LOG(FATAL) << "The input padding length is expected to be either 1, 2 or 4. However, the given " + "padding is " + << padding; + throw; +} + +/*! + * \brief Check if the given tensor layout can be converted to the given target layout. + * If convertible, return the tensor layout and the bijective conversion in tir::Layout and + * tir::BijectiveLayout accordingly. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param tensor_layout The tensor layout to be checked + * \param tgt_layout The target layout to be matched + * \param tensor_name The name of the input tensor + * \return The tensor layout and the bijective conversion in tir::Layout and tir::BijectiveLayout + * accordingly. + */ +inline std::pair CheckTensorLayout(const Call& call, + const BlockBuilder& ctx, + const String& tensor_layout, + const String& tgt_layout, + const String& tensor_name) { + tir::Layout _tensor_layout(tensor_layout, DataType::Int(64)); + tir::BijectiveLayout tensor2tgt(_tensor_layout, tir::Layout(tgt_layout, DataType::Int(64))); + if (!tensor2tgt.defined()) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << call->op << " requires the given " << tensor_name + << " layout to be convertible from " << tgt_layout + << " layout. However, the given layout " << tensor_layout + << " is not convertible."); + } + return {_tensor_layout, tensor2tgt}; +} + +/*! + * \brief Check if the given tensor struct info has expected ndim per the given layout (or the ndim + * is unknown), and try to cast the shape to ShapeExpr. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \param sinfo The input tensor struct info to be checked. + * \param layout The layout that the given tensor is expected to have. + * \return The shape of the input tensor in ShapeExpr, or `NullOpt` if the shape is unknown. + */ +inline Optional CheckNdimPerLayoutAndGetShape(const Call& call, const BlockBuilder& ctx, + const TensorStructInfo& sinfo, + const tir::Layout& layout) { + if (!sinfo->IsUnknownNdim() && sinfo->ndim != static_cast(layout.ndim())) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "In " << call->op << ", layout " << layout << " requires the input to be " + << layout.ndim() << "-dim tensor. However, the given input has ndim " + << sinfo->ndim); + } + if (const auto* shape_expr = sinfo->shape.as()) { + return GetRef(shape_expr); + } + return NullOpt; +} + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_OP_COMMON_H_ diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc new file mode 100644 index 000000000000..dd61091f7aaa --- /dev/null +++ b/src/relax/op/tensor/binary.cc @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file binary.cc + * \brief binary broadcast operators. + */ + +#include "binary.h" + +#include + +namespace tvm { +namespace relax { + +template +StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, + FType f_compute_out_dtype) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo x1_sinfo = input_sinfo[0]; + TensorStructInfo x2_sinfo = input_sinfo[1]; + + // DateType + DataType output_dtype = f_compute_out_dtype(call, ctx, x1_sinfo, x2_sinfo); + + // ndims + int output_ndim; + if (x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) { + output_ndim = kUnknownNDim; + } else { + output_ndim = std::max(x1_sinfo->ndim, x2_sinfo->ndim); + } + + const auto* x1_shape = x1_sinfo->shape.as(); + const auto* x2_shape = x2_sinfo->shape.as(); + // Shapes and ndims + if (x1_shape && x2_shape) { + // If all inputs have shapes, directly infer shapes + Optional> output_shape = + InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values); + if (!output_shape.defined()) { + return TensorStructInfo(output_dtype, /*ndim=*/output_ndim); + } else { + ICHECK_EQ(static_cast(output_shape.value().size()), output_ndim); + return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype); + } + } else if (x1_sinfo->shape.defined() && x1_sinfo->shape.same_as(x2_sinfo->shape)) { + return TensorStructInfo(x1_sinfo->shape.value(), output_dtype); + } else { + return TensorStructInfo(output_dtype, /*ndim=*/output_ndim); + } +} + +StructInfo InferStructInfoBroadcastArith(const Call& call, const BlockBuilder& ctx) { + return InferStructInfoBroadcast(call, ctx, InferBinaryArithOpOutDtype); +} + +StructInfo InferStructInfoBroadcastCMP(const Call& call, const BlockBuilder& ctx) { + return InferStructInfoBroadcast( + call, ctx, + [](const Call& call, const BlockBuilder& ctx, const TensorStructInfo& x1_sinfo, + const TensorStructInfo& x2_sinfo) { return DataType::Bool(); }); +} + +/***************** Arithmetic operators *****************/ + +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(add); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h new file mode 100644 index 000000000000..a7aea576b685 --- /dev/null +++ b/src/relax/op/tensor/binary.h @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file binary.h + * \brief The functions to make Relax binary arithmetic and comparison operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_BINARY_H_ +#define TVM_RELAX_OP_TENSOR_BINARY_H_ + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Quick helper macro + * - Expose a make function to construct the node. + * - Register op to the registry. + * \param OpName The name of operator to register. The name passed in will + * 1. be prepended with a prefix "relax.op." as the FFI identifier string for the make function, + * 2. be prepended with a prefix "relax." as the identifier string in the operator registry. + */ +#define RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName) \ + Expr OpName(Expr x1, Expr x2) { \ + static const Op& op = Op::Get("relax." #OpName); \ + return Call(op, {x1, x2}, Attrs(), {}); \ + } \ + TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ + TVM_REGISTER_OP("relax." #OpName) \ + .set_num_inputs(2) \ + .add_argument("x1", "Tensor", "The first input tensor.") \ + .add_argument("x2", "Tensor", "The second input tensor.") + +#define RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(OpName) \ + RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName).set_attr( \ + "FInferStructInfo", InferStructInfoBroadcastArith) + +#define RELAX_REGISTER_CMP_OP_AND_IMPL(OpName) \ + RELAX_REGISTER_BINARY_OP_AND_IMPL(OpName).set_attr( \ + "FInferStructInfo", InferStructInfoBroadcastCMP) + +/***************** Arithmetic operators *****************/ + +/*! \brief Addition with numpy-style broadcasting. */ +Expr add(Expr x1, Expr x2); + +/*! \brief Multiplication with numpy-style broadcasting. */ +Expr multiply(Expr x1, Expr x2); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_BINARY_H_ diff --git a/src/relax/utils.cc b/src/relax/utils.cc new file mode 100644 index 000000000000..5846f8116df2 --- /dev/null +++ b/src/relax/utils.cc @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +namespace tvm { +namespace relax { + +bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dtype) { + const DynTensorTypeNode* tt = ty.as(); + if (!tt) { + return false; + } + bool correct_dtype = tt->dtype.is_bool() || (permit_unknown_dtype && tt->dtype.is_void()); + bool correct_rank = tt->ndim == 0 || (permit_unknown_rank && tt->ndim == -1); + return correct_dtype && correct_rank; +} + +bool IsLeafOrTuple(const Expr& expr) { + return expr.as() || expr.as() || expr.as() || + expr.as() || expr.as(); +} + +} // namespace relax +} // namespace tvm diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index bdcdbc023a1c..cca8b9e43322 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -581,5 +581,85 @@ TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body([](TVMArgs args, TVMRetValue* *ret = CreatePrimFunc(arg_list, index_dtype_override); }); +// Relax version impl +PrimFunc GenerateAndCompletePrimFunc(const Array& arg_list, + const Array& root_stmts, CreateFuncInfo* info, + const Optional> tir_var_list) { + Array parameters; + Map buffer_map; + for (const te::Tensor& tensor : arg_list) { + Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle())); + parameters.push_back(arg); + auto it = info->tensor2buffers.find(tensor); + ICHECK(it != info->tensor2buffers.end()); + buffer_map.Set(arg, it->second); + } + + // add additional arguments for tir vars that are left unbound by match buffer + if (tir_var_list) { + for (const Var& v : tir_var_list.value()) { + parameters.push_back(v); + } + } + + PrimFunc func = WithAttrs(PrimFunc(/*params=*/std::move(parameters), + /*body=*/SeqStmt::Flatten(root_stmts), + /*ret_type=*/VoidType(), + /*buffer_map=*/std::move(buffer_map)), + {{"global_symbol", String("main")}, {"tir.noalias", Bool(true)}}); + + const auto* complete = runtime::Registry::Get("script.Complete"); + ICHECK(complete); + func = (*complete)(std::move(func), info->root_alloc); + return func; +} + +PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, + const Array& constants, + const Optional>& tir_var_list, + std::optional index_dtype_override) { + // Infomations used in CreatePrimFunc and its sub-functions. + CreateFuncInfo info(arg_list); + // Root body stmts. + Array root_stmts; + // Analyzer + arith::Analyzer analyzer; + + // Step 1. Create ordered array of operations and validate they are supported. + Array order = CollectOrderedOps(arg_list); + + // Step 2. Initialize buffer binds map + InitializeBufferBinds(order, &info); + + // Step 3. Rewrite compute stages into blocks. + for (const te::Operation& op : order) { + RewriteStageToBlock(op, &info, &root_stmts, &analyzer); + } + auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info, tir_var_list); + func = tir::BindParams(func, constants); + if (index_dtype_override.has_value()) { + func = IndexDataTypeNormalizer(index_dtype_override.value()).Rewrite(std::move(func)); + } + auto result = LayoutFreePlaceholdersNormalizer().Process(std::move(func)); + return result; +} + +PrimFunc CreatePrimFunc(const Array& arg_list, + const Optional> tir_var_list, + std::optional index_dtype_override) { + return CreatePrimFuncWithConstants(arg_list, {}, tir_var_list, index_dtype_override); +} + +TVM_REGISTER_GLOBAL("te.CreateRelaxPrimFunc").set_body([](TVMArgs args, TVMRetValue* ret) { + Array arg_list = args[0]; + Optional> tir_var_list = args[1]; + std::optional index_dtype_override{std::nullopt}; + // Add conversion to make std::optional compatible with FFI. + if (args[2].type_code() != kTVMNullptr) { + index_dtype_override = args[2].operator DataType(); + } + *ret = CreatePrimFunc(arg_list, tir_var_list, index_dtype_override); +}); + } // namespace tir } // namespace tvm diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h index 4246347a16f3..946f024849bf 100644 --- a/src/te/operation/create_primfunc.h +++ b/src/te/operation/create_primfunc.h @@ -42,6 +42,23 @@ PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, const Array& constants, std::optional index_dtype_override = std::nullopt); +// Relax version +// TODO(relax-team) combine with the relay version +/*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ +PrimFunc CreatePrimFunc(const Array& arg_list, + const Optional> tir_var_list, + std::optional index_dtype_override); + +/*! \brief The same as above but create a PrimFunc with AllocateConstNode. If the size of the + * constants array is N, the last N tensors in arg_list will be treated as constant tensors. + * Constant tensors will not be part of the parameters of the created PrimFunc, instead constants + * will be embedded in the body as AllocateConstNode. + */ +PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, + const Array& constants, + const Optional>& tir_var_list, + std::optional index_dtype_override = std::nullopt); + } // namespace tir } // namespace tvm diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py index faf8fedcf4bf..03b98f8a565e 100644 --- a/tests/python/relax/test_analysis_struct_info_analysis.py +++ b/tests/python/relax/test_analysis_struct_info_analysis.py @@ -319,6 +319,149 @@ def fn_info_erased(): assert fopaque.is_base_of(fn_info_shape(1)) +def _check_derive(ctx, finfo, args_sinfo, ret): + gv = rx.GlobalVar("test") + rx.expr._update_struct_info(gv, finfo) + args = [] + for i, sinfo in enumerate(args_sinfo): + arg = rx.Var("arg%i" % i, sinfo) + args.append(arg) + call = rx.Call(gv, args) + derived_ret = rx.analysis.derive_call_ret_struct_info(finfo, call, ctx) + tvm.ir.assert_structural_equal(ret, derived_ret) + + +def test_derive_call_ret_struct_info(): + obj0 = rx.ObjectStructInfo() + prim0 = rx.PrimStructInfo("float32") + + n, m = tir.Var("n0", "int64"), tir.Var("m0", "int64") + bb = rx.BlockBuilder() + # derivation cases + with bb.testing_scope(def_vars=[n, m]): + + def func0(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.TensorStructInfo([n, m], "float32") + z = rx.TensorStructInfo([m + c, n], "float32") + return rx.FuncStructInfo([x], z) + + # Tensor => Tensor + _check_derive( + bb, + func0(1), + [rx.TensorStructInfo([10, 11], "float32")], + rx.TensorStructInfo([12, 10], "float32"), + ) + + _check_derive( + bb, + func0(2), + [rx.TensorStructInfo([n, m], "float32")], + rx.TensorStructInfo([m + 2, n], "float32"), + ) + + # passing in information that cannot deduce n, m + # it is still OK as type still matches, return an + # eriased output + _check_derive( + bb, + func0(2), + [rx.TensorStructInfo(ndim=2, dtype="float32")], + rx.TensorStructInfo(ndim=2, dtype="float32"), + ) + + # Error: wrong number of arguments + with pytest.raises(TVMError): + _check_derive( + bb, + func0(2), + [rx.TensorStructInfo(ndim=2, dtype="float32"), obj0], + rx.TensorStructInfo(ndim=2, dtype="float32"), + ) + + # Error:type mismatch + with pytest.raises(TVMError): + _check_derive(bb, func0(2), [obj0], obj0) + + # opaque derivation + fopaque0 = lambda: rx.FuncStructInfo.opaque_func() + fopaque1 = lambda: rx.FuncStructInfo.opaque_func(ret=prim0) + _check_derive(bb, fopaque0(), [obj0, prim0], obj0) + _check_derive(bb, fopaque1(), [obj0, prim0], prim0) + + # recursive tuple derivation + def func_tuple0(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x0 = rx.TensorStructInfo([n, c], "float32") + x1 = rx.TensorStructInfo([n + c, m], "float32") + z = rx.TupleStructInfo([rx.TensorStructInfo([m, n], "float32")]) + return rx.FuncStructInfo([rx.TupleStructInfo([x0, x1])], z) + + _check_derive( + bb, + func_tuple0(2), + [ + rx.TupleStructInfo( + [ + rx.TensorStructInfo([n, 2], "float32"), + rx.TensorStructInfo([n + 2, 10], "float32"), + ] + ) + ], + rx.TupleStructInfo([rx.TensorStructInfo([10, n], "float32")]), + ) + + def func_tuple1(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x0 = rx.TensorStructInfo([n, m], "float32") + x1 = rx.TensorStructInfo([n + c, c], "float32") + z = rx.TupleStructInfo([rx.TensorStructInfo([m, n], "float32")]) + return rx.FuncStructInfo([rx.TupleStructInfo([x0, x1])], z) + + # Still OK, to pass erased tensor into n+2, n is captured by other argument. + _check_derive( + bb, + func_tuple1(4), + [ + rx.TupleStructInfo( + [ + rx.TensorStructInfo([n, 4], "float32"), + rx.TensorStructInfo(ndim=2, dtype="float32"), + ] + ) + ], + rx.TupleStructInfo([rx.TensorStructInfo([4, n], "float32")]), + ) + + # tuple length mismatch is not causes an error + with pytest.raises(TVMError): + _check_derive( + bb, + func_tuple0(4), + [rx.TupleStructInfo([rx.TensorStructInfo([n, 4], "float32")])], + rx.TupleStructInfo([rx.TensorStructInfo([10, n], "float32")]), + ) + + # mixed shape types + def func_shape_mixed(c): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x0 = rx.ShapeStructInfo([n, m]) + f0 = func_tuple0(c) + z = rx.ShapeStructInfo([m + n, c]) + return rx.FuncStructInfo([x0, f0], z) + + _check_derive( + bb, + func_shape_mixed(3), + [ + rx.ShapeStructInfo([10, 20]), + rx.FuncStructInfo.opaque_func(ret=rx.ShapeStructInfo(ndim=2)), + ], + rx.ShapeStructInfo([30, 3]), + ) + + def _check_lca(lhs, rhs, target): tvm.ir.assert_structural_equal(rx.analysis.struct_info_lca(lhs, rhs), target) tvm.ir.assert_structural_equal(rx.analysis.struct_info_lca(rhs, lhs), target) diff --git a/tests/python/relax/test_blockbuilder.py b/tests/python/relax/test_blockbuilder.py new file mode 100644 index 000000000000..36a22f9712ea --- /dev/null +++ b/tests/python/relax/test_blockbuilder.py @@ -0,0 +1,542 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm +import tvm.testing + +from tvm import te, tir, topi +from tvm import relax as rx, relay +from tvm.ir.base import assert_structural_equal +from tvm.relax import ExternFunc +from tvm.tir.function import PrimFunc + + +@tvm.register_func("test.blockbuilder.nop") +def nop(): + pass + + +def test_block_builder(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + bb._begin_binding_block() + gv0 = bb.emit(rx.op.add(x, y)) + bb._begin_dataflow_block() + lv0 = bb.emit(rx.op.multiply(gv0, y)) + gv1 = bb.emit_output(rx.op.multiply(lv0, lv0)) + b0 = bb._end_block() + bb._begin_dataflow_block() + lv1 = bb.emit(rx.op.multiply(gv0, y)) + gv2 = bb.emit_output(rx.op.multiply(lv1, lv1)) + b1 = bb._end_block() + gv3 = bb.emit(rx.op.add(x, y)) + b2 = bb._end_block() + + assert isinstance(b0, rx.DataflowBlock) + assert isinstance(b1, rx.DataflowBlock) + assert not isinstance(b2, rx.DataflowBlock) + + +def test_function_single_block(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with bb.function("func", [x, y]): + with bb.dataflow(): + lv0 = bb.emit(rx.op.add(x, y)) + assert lv0.name_hint == "lv" + lv1 = bb.emit(rx.op.multiply(lv0, y)) + assert lv1.name_hint == "lv1" + gv0 = bb.emit_output(lv1) + assert gv0.name_hint == "gv" + bb.emit_func_output(gv0) + + func = bb.get()["func"] + assert func.params[0] == x + assert func.params[1] == y + assert func.body.body == gv0 + assert_structural_equal(gv0.struct_info, rx.TensorStructInfo([m, n], "float16")) + assert len(func.body.blocks) == 1 + assert len(func.body.blocks[0].bindings) == 3 + + +def test_function_multi_blocks(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with bb.function("func", [x, y]): + with bb.dataflow(): + lv0 = bb.emit(rx.op.add(x, y)) + assert lv0.name_hint == "lv" + gv0 = bb.emit_output(lv0) + assert gv0.name_hint == "gv" + gv1 = bb.emit(rx.op.add(gv0, gv0)) + assert gv1.name_hint == "gv1" + with bb.dataflow(): + lv1 = bb.emit(rx.op.add(gv1, gv1)) + assert lv1.name_hint == "lv1" + gv2 = bb.emit_output(gv1) + bb.emit_func_output(gv2) + + func = bb.get()["func"] + + assert_structural_equal(gv2.struct_info, rx.TensorStructInfo([m, n], "float16")) + assert func.params[0] == x + assert func.params[1] == y + assert func.body.body == gv2 + assert len(func.body.blocks) == 3 + assert len(func.body.blocks[0].bindings) == 2 + assert len(func.body.blocks[1].bindings) == 1 + assert len(func.body.blocks[2].bindings) == 2 + + +def test_multi_functions(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with bb.function("func1", [x, y]): + with bb.dataflow(): + lv0 = bb.emit(rx.op.add(x, y)) + assert lv0.name_hint == "lv" + gv0 = bb.emit_output(lv0) + bb.emit_func_output(gv0) + + with bb.function("func2", [x, y]): + with bb.dataflow(): + lv0 = bb.emit(rx.op.add(y, x)) + # TODO(@yuchen): enable block builder to reset local var unique name map + assert lv0.name_hint == "lv1" + gv0 = bb.emit_output(lv0) + bb.emit_func_output(gv0) + + mod = bb.get() + func1 = mod["func1"] + assert func1.params[0] == x + assert func1.params[1] == y + assert len(func1.body.blocks) == 1 + func2 = mod["func2"] + assert func2.params[0] == x + assert func2.params[1] == y + assert len(func2.body.blocks) == 1 + + +def test_binary_shape_type_deduction(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + k = tir.Var("k", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, 1], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + z = rx.Var("z", rx.TensorStructInfo([5], "float16")) + w = rx.Var("w", rx.TensorStructInfo([k], "float16")) + bb = rx.BlockBuilder() + + with bb.function("func", [x, y, z, w]): + with bb.dataflow(): + lv0 = bb.emit(rx.op.add(x, y)) + assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m, n], "float16")) + + lv1 = bb.emit(rx.op.multiply(x, z)) + assert_structural_equal(lv1.struct_info, rx.TensorStructInfo([m, 5], "float16")) + + lv2 = bb.emit(rx.op.multiply(z, w)) + assert isinstance(lv2.struct_info, rx.TensorStructInfo) + assert lv2.struct_info.ndim == 1 + assert lv2.struct_info.dtype == "float16" + + lv3 = bb.emit(rx.op.multiply(y, w)) + assert isinstance(lv3.struct_info, rx.TensorStructInfo) + assert lv3.struct_info.ndim == 1 + assert lv3.struct_info.dtype == "float16" + + gv0 = bb.emit_output(lv3) + bb.emit_func_output(gv0) + + assert isinstance(gv0.checked_type, rx.DynTensorType) + assert gv0.checked_type.ndim == 1 + assert gv0.checked_type.dtype == "float16" + + +def test_emit_match_cast(): + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + x = rx.Var("tensor_value", rx.TensorStructInfo(dtype="float32", ndim=-1)) + y = rx.Var("shape_value", rx.ShapeStructInfo([16, 8])) + bb = rx.BlockBuilder() + + with bb.function("func", [x, y]): + with bb.dataflow(): + # lv0: Tensor((m, n), "float32") = + # match_cast(x: Tensor(_, "float32"], [m, n)) + lv0 = bb.match_cast(x, rx.TensorStructInfo([m, n], "float32")) + assert isinstance(lv0, rx.DataflowVar) + assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m, n], "float32")) + + # lv1: Shape = match_cast(shape, rx.ShapeStructInfo([m, n])) + lv1 = bb.match_cast(y, rx.ShapeStructInfo([m, n])) + assert lv1.struct_info == rx.ShapeStructInfo([m, n]) + gv0 = bb.emit_output(lv1) + + bb.emit_func_output(gv0) + func = bb.get()["func"] + block = func.body.blocks[0] + b0, b1 = block.bindings[:2] + assert isinstance(b0, rx.MatchCast) + assert isinstance(b1, rx.MatchCast) + + assert b0.value == x + assert b0.struct_info == rx.TensorStructInfo([m, n], "float32") + assert b0.var == lv0 + + assert b1.value == y + assert b1.struct_info == rx.ShapeStructInfo([m, n]) + assert b1.var == lv1 + + +def test_emit_match_cast_binding_in_dataflow_block(): + bb = rx.BlockBuilder() + + x = rx.Var("x", rx.TensorStructInfo(dtype="float32", ndim=-1)) + m = tir.Var("m", dtype="int64") + gv = rx.Var("gv", rx.TensorStructInfo(dtype="float32", ndim=-1)) + match_cast = rx.MatchCast(gv, x, rx.TensorStructInfo((m,), "float32")) + + with bb.function("main", [x]): + with bb.dataflow(): + bb.emit_normalized(match_cast) + bb.emit_output(gv) + bb.emit_func_output(x) + + func = bb.get()["main"] + block = func.body.blocks[0] + b0 = block.bindings[0] + assert isinstance(b0, rx.MatchCast) + + assert b0.value == x + assert isinstance(b0.struct_info, rx.TensorStructInfo) + assert b0.struct_info.shape[0] == m + assert b0.var == gv + + +def test_normalize(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + # Call node + add_call = rx.op.multiply(x, y) + + bb.normalize(add_call) + shape = rx.get_shape_of(add_call) + + assert isinstance(shape, rx.ShapeExpr) + assert shape[0] == m + assert shape[1] == n + + # Tuple node + tuple_1 = rx.Tuple([x, y]) + bb.normalize(tuple_1) + assert isinstance(tuple_1.struct_info, rx.TupleStructInfo) + assert isinstance(tuple_1.struct_info.fields[0], rx.TensorStructInfo) + assert isinstance(tuple_1.struct_info.fields[1], rx.TensorStructInfo) + + # Nested Tuple + tuple_2 = rx.Tuple([x, rx.Tuple([x, y])]) + bb.normalize(tuple_2) + type_anno0 = x.checked_type + type_anno1 = y.checked_type + assert_structural_equal( + tuple_2.checked_type, rx.TupleType([type_anno0, rx.TupleType([type_anno0, type_anno1])]) + ) + assert isinstance(tuple_2.struct_info, rx.TupleStructInfo) + assert isinstance(tuple_2.struct_info.fields[0], rx.TensorStructInfo) + assert isinstance(tuple_2.struct_info.fields[1], rx.TupleStructInfo) + assert isinstance(tuple_2.struct_info.fields[1].fields[0], rx.TensorStructInfo) + assert isinstance(tuple_2.struct_info.fields[1].fields[1], rx.TensorStructInfo) + + +def test_call_te(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) + z = rx.Var("z", rx.TensorStructInfo([n, m], "float32")) + + def te_func(args, args_dict, msg): + A, B = args + C = args_dict["C"] + D = te.compute((128, 128), lambda i, j: A[i, j] + B[i, j]) + E = te.compute((128, 128), lambda i, j: D[i, j] - C[i, j]) + return E + + with bb.function("rx_func", [x, y, z]): + with bb.dataflow(): + out = bb.emit_output(bb.call_te(te_func, [x, y], {"C": z}, msg="hello")) + bb.emit_func_output(out) + + mod = bb.get() + rx_func = mod["rx_func"] + + assert rx_func.params[0] == x + assert rx_func.params[1] == y + assert rx_func.params[2] == z + assert rx_func.body.body == out + assert len(rx_func.body.blocks) == 1 + assert len(rx_func.body.blocks[0].bindings) == 1 + + +def test_call_te_with_unsupported_shape_arg(): + bb = rx.BlockBuilder() + x = rx.Var("x", rx.TensorStructInfo((200,), "float32")) + s = rx.Var("s", rx.ShapeStructInfo((200,))) + + with pytest.raises(AssertionError): + with bb.function("rx_func", [x]): + out = bb.emit(bb.call_te(topi.reshape, x, s)) + bb.emit_func_output(out) + + +def test_emit_te(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) + z = rx.Var("z", rx.TensorStructInfo([n, m], "float32")) + + def te_func(args, args_dict, msg): + A, B = args + C = args_dict["C"] + D = te.compute((128, 128), lambda i, j: A[i, j] + B[i, j]) + E = te.compute((128, 128), lambda i, j: D[i, j] - C[i, j]) + return E + + with bb.function("rx_func", [x, y, z]): + out = bb.emit_te(te_func, [x, y], {"C": z}, msg="hello") + bb.emit_func_output(out) + + mod = bb.get() + rx_func = mod["rx_func"] + + def get_tir_func(): + A = te.placeholder((n, m), dtype="float32", name="A") + B = te.placeholder((n, m), dtype="float32", name="B") + C = te.placeholder((n, m), dtype="float32", name="C") + out = te_func((A, B), {"C": C}, "") + return tvm.te.create_prim_func([A, B, C, out], index_dtype_override="int64") + + # check TIR structure matches expected + assert_structural_equal(mod["te_func"].body, get_tir_func().body) + + # check Relax function calls TIR function with call_tir call + assert rx_func.params[0] == x + assert rx_func.params[1] == y + assert rx_func.params[2] == z + assert rx_func.body.body == out + assert len(rx_func.body.blocks) == 1 + assert len(rx_func.body.blocks[0].bindings) == 1 + + call_node = rx_func.body.blocks[0].bindings[0].value + assert isinstance(call_node, rx.Call) + assert call_node.op == relay.op.get("relax.call_tir") + assert len(call_node.args) == 2 + assert call_node.args[0].name_hint == "te_func" + assert call_node.args[1][0] == x + assert call_node.args[1][1] == y + assert call_node.args[1][2] == z + + +def test_emit_te_multiple(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) + z = rx.Var("z", rx.TensorStructInfo([128, m], "float32")) + + def te_func(A): + B = te.compute((128, 128), lambda i, j: A[i, j] + 1) + return B + + with bb.function("rx_func", [x, y]): + x1 = bb.emit_te(te_func, x) + y1 = bb.emit_te(te_func, y) + z1 = bb.emit_te(te_func, z) + bb.emit_func_output(z1) + + mod = bb.get() + rx_func = mod["rx_func"] + + prim_func = [] + for gv in mod.get_global_vars(): + if isinstance(mod[gv], PrimFunc): + prim_func.append(mod[gv]) + + # only two PrimFuncs were generated since two of them are equal so got deduped + assert len(prim_func) == 2 + assert rx_func.body.blocks[0].bindings[0].value.args[0].name_hint == "te_func" + assert rx_func.body.blocks[0].bindings[1].value.args[0].name_hint == "te_func" + assert rx_func.body.blocks[0].bindings[2].value.args[0].name_hint == "te_func1" + + +def test_emit_te_multiple_output(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + + def te_func(A): + B0, B1 = te.compute((n, m), lambda i, j: (A[i, j] + 1, A[i, j] * 2), name="B") + return (B0, B1) + + with bb.function("rx_func", [x]): + y = bb.emit_te(te_func, x) + z = rx.TupleGetItem(y, 0) + bb.emit_func_output([y, z]) + + rx_func = bb.get()["rx_func"] + + # check call tir output shape is a Tuple of ShapeExpr + assert rx_func.params[0] == x + call_node = rx_func.body.blocks[0].bindings[0].value + assert call_node.op == relay.op.get("relax.call_tir") + assert call_node.args[0].name_hint == "te_func" + assert isinstance(call_node.sinfo_args[0], rx.TupleStructInfo) + assert len(call_node.sinfo_args[0].fields) == 2 + assert isinstance(call_node.sinfo_args[0].fields[0].shape, rx.ShapeExpr) + assert isinstance(call_node.sinfo_args[0].fields[1].shape, rx.ShapeExpr) + + +def test_emit_te_extern(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + y = rx.Var("y", rx.TensorStructInfo([m, n], "float32")) + + with bb.function("rx_cblas_matmul", [x, y]): + out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, transb=False) + bb.emit_func_output(out) + + mod = bb.get() + rx_func = mod["rx_cblas_matmul"] + + # check Relax function calls TIR function with call_tir call + assert rx_func.params[0] == x + assert rx_func.params[1] == y + assert len(rx_func.body.blocks) == 1 + call_node = rx_func.body.blocks[0].bindings[0].value + assert isinstance(call_node, rx.Call) + assert call_node.op == relay.op.get("relax.call_tir") + assert len(call_node.args) == 2 + assert call_node.args[0].name_hint == "matmul" + assert call_node.args[1][0] == x + assert call_node.args[1][1] == y + assert call_node.sinfo_args[0].shape[0] == n + assert call_node.sinfo_args[0].shape[1] == n + + +def test_nested_function_fail(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with pytest.raises(RuntimeError): + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, x)) + with bb.function("func1", [x, y]): + gv1 = bb.emit(rx.op.add(x, x)) + bb.emit_func_output(gv0) + + +def test_emit_func_output_twice_fail(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with pytest.raises(RuntimeError): + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, y)) + bb.emit_func_output(gv0) + bb.emit_func_output(gv0) + + +def test_func_params_twice_fail(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with pytest.raises(RuntimeError): + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, y)) + bb.emit_func_output(gv0, [x]) + + +def test_no_func_params_fail(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + with pytest.raises(RuntimeError): + with bb.function("func"): + gv0 = bb.emit(rx.Call(ExternFunc("test.blockbuilder.nop"), [])) + bb.emit_func_output(gv0) + + +def test_block_builder_scope_recovery(): + bb = rx.BlockBuilder() + + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) + y = rx.Var("y", rx.TensorStructInfo([m, n], "float32")) + + with pytest.raises(RuntimeError): + # this line fails + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, y)) + + # current should be recovered + assert rx.BlockBuilder.current() is None + + # second attempt to do it correctly. + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, y)) + bb.emit_func_output(gv0) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py index 4eeaed1e0b50..902c4785610f 100644 --- a/tests/python/relax/test_expr.py +++ b/tests/python/relax/test_expr.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. import numpy as np -import pytest import tvm from tvm import relax as rx from tvm import tir from tvm.script import relax as R +import pytest def _check_equal(x, y, map_free_vars=False): @@ -255,4 +255,4 @@ def test_datatype_imm(): if __name__ == "__main__": - pytest.main([__file__]) + tvm.testing.main() diff --git a/tests/python/relax/test_expr_functor.py b/tests/python/relax/test_expr_functor.py new file mode 100644 index 000000000000..8165107394c9 --- /dev/null +++ b/tests/python/relax/test_expr_functor.py @@ -0,0 +1,746 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relax, tir +from tvm.ir import Op +from tvm.ir.base import assert_structural_equal +from tvm.relax import PyExprMutator, PyExprVisitor +from tvm.relax.expr import ( + BindingBlock, + Call, + Constant, + DataflowBlock, + DataflowVar, + Expr, + ExternFunc, + Function, + GlobalVar, + If, + MatchCast, + SeqExpr, + ShapeExpr, + Tuple, + TupleGetItem, + PrimValue, + StringImm, + DataTypeImm, + Var, + VarBinding, +) +from tvm.script import relax as R +import pytest + +m, n = tir.Var("m", "int64"), tir.Var("n", "int64") +x = relax.Var("x", R.Tensor([n], "float32")) +y = relax.Var("y", R.Tensor([m, n], "float32")) +bb = relax.BlockBuilder() + + +@relax.expr_functor.visitor +class BasicVisitor(PyExprVisitor): + """Default ExprVisitor""" + + +class ASTLog: + """Helper class to log AST""" + + def __init__(self) -> None: + self.log = [] + self.indent = "\t" + self.level = 0 + + def push_scope(self): + self.level += 1 + + def pop_scope(self): + self.level -= 1 + + def add(self, s: str): + self.log.append(self.indent * self.level + s) + + def __str__(self) -> str: + return "\n".join(self.log) + + +@relax.expr_functor.visitor +class ASTPrinter(PyExprVisitor): + """Print relax AST in structured format. The shape of Node is ignored.""" + + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_constant_(self, op: Constant) -> None: + self.log.add("Constant") + + def visit_global_var_(self, op: GlobalVar) -> None: + self.log.add("GlobalVar") + + def visit_tuple_(self, op: Tuple) -> None: + self.log.add("Tuple") + self.log.push_scope() + for field in op.fields: + self.visit_expr(field) + self.log.pop_scope() + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_dataflow_var_(self, op: DataflowVar) -> None: + self.log.add("DataflowVar") + + def visit_function_(self, op: Function) -> None: + self.log.add("Function") + self.log.push_scope() + for param in op.params: + self.visit_var_def(param) + + self.visit_expr(op.body) + self.log.pop_scope() + + def visit_call_(self, op: Call) -> None: + self.log.add("Call") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + def visit_if_(self, op: If) -> None: + self.log.add("If") + self.log.push_scope() + self.visit_expr(op.cond) + self.visit_expr(op.true_branch) + self.visit_expr(op.false_branch) + self.log.pop_scope() + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + def visit_tuple_getitem_(self, op: TupleGetItem) -> None: + self.log.add("TupleGetItem") + self.log.push_scope() + self.visit_expr(op.tuple_value) + self.log.pop_scope() + + def visit_prim_value_(self, op: PrimValue) -> None: + self.log.add("PrimValue") + + def visit_string_imm_(self, op: StringImm) -> None: + self.log.add("StringImm") + + def visit_data_type_imm_(self, op: DataTypeImm) -> None: + self.log.add("DataTypeImm") + + def visit_shape_expr_(self, op: ShapeExpr) -> None: + self.log.add("ShapeExpr") + + def visit_extern_func_(self, op: ExternFunc) -> None: + self.log.add("ExternFunc") + + def visit_seq_expr_(self, op: SeqExpr) -> None: + self.log.add("SeqExpr") + self.log.push_scope() + for block in op.blocks: + self.visit_binding_block(block) + self.visit_expr(op.body) + self.log.pop_scope() + + def visit_var_binding_(self, binding: VarBinding) -> None: + self.log.add("VarBinding") + self.log.push_scope() + self.visit_expr(binding.value) + self.visit_var_def(binding.var) + self.log.pop_scope() + + def visit_match_cast_(self, binding: MatchCast) -> None: + self.log.add("MatchCast") + self.log.push_scope() + self.visit_var_def(binding.var) + self.visit_expr(binding.value) + self.log.pop_scope() + + def visit_binding_block_(self, block: BindingBlock) -> None: + self.log.add("BindingBlock") + self.log.push_scope() + for binding in block.bindings: + self.visit_binding(binding) + self.log.pop_scope() + + def visit_dataflow_block_(self, block: DataflowBlock) -> None: + self.log.add("DataflowBlock") + self.log.push_scope() + for binding in block.bindings: + self.visit_binding(binding) + self.log.pop_scope() + + def visit_var_def_(self, var: Var) -> None: + self.log.add("VarDef") + + def visit_dataflow_var_def_(self, var: DataflowVar) -> None: + self.log.add("DataflowVarDef") + + +@relax.expr_functor.mutator +class BasicMutator(PyExprMutator): + """Default ExprMutator""" + + +@relax.expr_functor.mutator +class ASTPostPrinterMutator(PyExprMutator): + """Print relax AST in the post order format.""" + + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_constant_(self, op: Constant) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Constant") + return op + + def visit_global_var_(self, op: GlobalVar) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("GlobalVar") + return op + + def visit_tuple_(self, op: Tuple) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Tuple") + return op + + def visit_var_(self, op: Var) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Var") + return op + + def visit_dataflow_var_(self, op: DataflowVar) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("DataflowVar") + return op + + def visit_function_(self, op: Function) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Function") + return op + + def visit_call_(self, op: Call) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Call") + return op + + def visit_if_(self, op: If) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("If") + return op + + def visit_op_(self, op: Op) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Op") + return op + + def visit_tuple_getitem_(self, op: TupleGetItem) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("TupleGetItem") + return op + + def visit_prim_value_(self, op: PrimValue) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("PrimValue") + return op + + def visit_string_imm_(self, op: StringImm) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("StringImm") + return op + + def visit_data_type_imm_(self, op: DataTypeImm) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("DataTypeImm") + return op + + def visit_shape_expr_(self, op: ShapeExpr) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("ShapeExpr") + return op + + def visit_extern_func_(self, op: ExternFunc) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("ExternFunc") + return op + + def visit_seq_expr_(self, op: SeqExpr) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("SeqExpr") + return op + + def visit_var_binding_(self, binding: VarBinding) -> None: + """Identical with ExprMutator::VisitBinding_(const VarBindingNode* binding) on the C++ side.""" + new_value = self.visit_expr(binding.value) + new_var = self.visit_var_def(binding.var) + + self.log.add("VarBinding") + if binding.var.same_as(new_var) and binding.value.same_as(new_value): + self.builder_.emit_normalized(binding) + return + + temp = self.with_struct_info(new_var, new_value.struct_info) + if not temp.same_as(new_var): + new_var = temp + self.set_var_remap(binding.var.vid, new_var) + + self.builder_.emit_normalized(VarBinding(new_var, new_value)) + + def visit_match_cast_(self, binding: MatchCast) -> None: + """Identical with ExprMutator::VisitBinding_(const MatchCastNode* binding) on the C++ side.""" + new_var = self.visit_var_def(binding.var) + new_value = self.visit_expr(binding.value) + + temp = self.with_struct_info(new_var, binding.struct_info) + if not temp.same_as(new_var): + new_var = temp + self.set_var_remap(binding.var.vid, new_var) + + self.log.add("MatchCast") + self.builder_.emit_normalized(MatchCast(new_var, new_value, binding.struct_info)) + + def visit_binding_block_(self, block: BindingBlock) -> BindingBlock: + """Identical with ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) on the C++ side.""" + self.builder_._begin_binding_block() + for binding in block.bindings: + self.visit_binding(binding) + self.log.add("BindingBlock") + return self.builder_._end_block() + + def visit_dataflow_block_(self, block: DataflowBlock) -> None: + """Identical with ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) on the C++ side.""" + self.builder_._begin_dataflow_block() + for binding in block.bindings: + self.visit_binding(binding) + self.log.add("DataflowBlock") + return self.builder_._end_block() + + def visit_var_def_(self, var: Var) -> None: + """Identical with ExprMutator::VisitVarDef_(const VarNode* var) on the C++ side.""" + self.log.add("VarDef") + return var + + def visit_dataflow_var_def_(self, var: DataflowVar) -> None: + """Identical with ExprMutator::VisitVarDef_(const DataflowVarNode* var) on the C++ side.""" + self.log.add("DataflowVarDef") + return var + + +def basic_check(expr, visitor_str, mutator_str): + def visit(f, expr): + if isinstance(expr, relax.Expr): + return f.visit_expr(expr) + elif isinstance(expr, relax.BindingBlock): + return f.visit_binding_block(expr) + + # check no overloading case + basic_visitor = BasicVisitor() + visit(basic_visitor, expr) + + # check the output log + log_visitor = ASTPrinter() + visit(log_visitor, expr) + assert str(log_visitor.log) == visitor_str + + # check no overloading case + basic_mutator = BasicMutator() + # skip normalize GlobalVar since it requires context IRModule to get the checked_type_ + if isinstance(expr, relax.Expr) and not isinstance(expr, relax.GlobalVar): + expr = bb.normalize(expr) + assert_structural_equal(visit(basic_mutator, expr), expr) + + # check the output log and return value + post_log_mutator = ASTPostPrinterMutator() + if isinstance(expr, relax.Expr) and not isinstance(expr, relax.GlobalVar): + expr = bb.normalize(expr) + assert_structural_equal(visit(post_log_mutator, expr), expr) + assert str(post_log_mutator.log) == mutator_str + + +def test_constant(): + basic_check(relax.const(1.0), "Constant", "Constant") + + +def test_var(): + basic_check(x, "Var", "Var") + + +def test_dataflow_var(): + lv = relax.DataflowVar("lv", R.Tensor([n], "float32")) + basic_check(lv, "DataflowVar", "DataflowVar") + + +def test_tuple(): + t = relax.Tuple([x, y]) + basic_check(t, "\n".join(["Tuple", "\tVar", "\tVar"]), "\n".join(["Var", "Var", "Tuple"])) + + +def test_global_var(): + gv = relax.GlobalVar("gv") + basic_check(gv, "GlobalVar", "GlobalVar") + + +def test_seq_expr(): + bindings = [relax.VarBinding(x, relax.const(1))] + blocks = [relax.BindingBlock(bindings)] + seq_expr = relax.SeqExpr(blocks, x) + basic_check( + seq_expr, + "\n".join( + [ + "SeqExpr", + "\tBindingBlock", + "\t\tVarBinding", + "\t\t\tConstant", + "\t\t\tVarDef", + "\tVar", + ] + ), + "\n".join(["Constant", "VarDef", "VarBinding", "BindingBlock", "Var", "SeqExpr"]), + ) + + +def test_shape_expr(): + x = relax.ShapeExpr([m, n]) + basic_check(x, "ShapeExpr", "ShapeExpr") + + +def test_call(): + call_node = relax.op.add(x, y) + basic_check( + call_node, + "\n".join(["Call", "\tOp", "\tVar", "\tVar"]), + "\n".join(["Op", "Var", "Var", "ShapeExpr", "Call"]), + ) + + +def test_if(): + if_node = relax.If(x, x, x) + basic_check( + if_node, + "\n".join(["If", "\tVar", "\tVar", "\tVar"]), + "\n".join(["Var", "Var", "SeqExpr", "Var", "SeqExpr", "If"]), + ) + + +def test_tuple_getitem(): + tuple_getitem_node = relax.TupleGetItem(relax.Tuple([x, y]), 0) + basic_check( + tuple_getitem_node, + "\n".join(["TupleGetItem", "\tTuple", "\t\tVar", "\t\tVar"]), + "\n".join(["Var", "Var", "Tuple", "TupleGetItem"]), + ) + + +def test_binding_block(): + bb._begin_binding_block() + gv0 = bb.emit(relax.op.add(x, y)) + gv1 = bb.match_cast(y, R.Tensor([m, n], "float32")) + b0 = bb._end_block() + basic_check( + b0, + "\n".join( + [ + "BindingBlock", + "\tVarBinding", + "\t\tCall", + "\t\t\tOp", + "\t\t\tVar", + "\t\t\tVar", + "\t\tVarDef", + "\tMatchCast", + "\t\tVarDef", + "\t\tVar", + ] + ), + "\n".join( + [ + "Op", + "Var", + "Var", + "Call", + "ShapeExpr", + "VarDef", + "VarBinding", + "Var", + "ShapeExpr", + "ShapeExpr", + "VarDef", + "MatchCast", + "BindingBlock", + ] + ), + ) + + +def test_dataflow_block(): + bb._begin_dataflow_block() + lv0 = bb.emit(relax.op.add(x, y)) + gv1 = bb.match_cast(y, R.Tensor([m, n], "float32")) + b0 = bb._end_block() + basic_check( + b0, + "\n".join( + [ + "DataflowBlock", + "\tVarBinding", + "\t\tCall", + "\t\t\tOp", + "\t\t\tVar", + "\t\t\tVar", + "\t\tDataflowVarDef", + "\tMatchCast", + "\t\tDataflowVarDef", + "\t\tVar", + ] + ), + "\n".join( + [ + "Op", + "Var", + "Var", + "Call", + "ShapeExpr", + "DataflowVarDef", + "VarBinding", + "Var", + "ShapeExpr", + "ShapeExpr", + "DataflowVarDef", + "MatchCast", + "DataflowBlock", + ] + ), + ) + + +def test_function(): + bindings = [relax.VarBinding(x, relax.const(1))] + blocks = [relax.BindingBlock(bindings)] + seq_expr = relax.SeqExpr(blocks, x) + func = relax.Function([x], seq_expr, R.Tensor([n], "float32")) + basic_check( + func, + "\n".join( + [ + "Function", + "\tVarDef", + "\tSeqExpr", + "\t\tBindingBlock", + "\t\t\tVarBinding", + "\t\t\t\tConstant", + "\t\t\t\tVarDef", + "\t\tVar", + ] + ), + "\n".join( + [ + "VarDef", + "Constant", + "VarDef", + "VarBinding", + "BindingBlock", + "Var", + "SeqExpr", + "Function", + ] + ), + ) + + +def test_extern_func(): + func = relax.ExternFunc("f") + basic_check(func, "ExternFunc", "ExternFunc") + + +def test_inherit(): + # The internal class is not instantiated. + class InternalVisitor(PyExprVisitor): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + call_node = relax.op.add(x, y) + lv = LeafVisitor() + lv.visit_expr(call_node) + assert str(lv.log) == "\n".join(["LeafCall", "\tOp", "\tVar", "\tVar"]) + + +def test_inherit_with_cls(): + # The decorator converts `InternalVisitor` to a wrapper class. + @relax.expr_functor.visitor + class InternalVisitor(PyExprVisitor): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + # `InternalVisitor._cls` refers to the original `InternalVisitor` users defined. + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor._cls): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + call_node = relax.op.add(x, y) + iv = InternalVisitor() + iv.visit_expr(call_node) + assert str(iv.log) == "\n".join(["InternalCall", "\tOp", "\tVar", "\tVar"]) + + lv = LeafVisitor() + lv.visit_expr(call_node) + assert str(lv.log) == "\n".join(["LeafCall", "\tOp", "\tVar", "\tVar"]) + + +def test_wrong_inherit(): + @relax.expr_functor.visitor + class InternalVisitor(PyExprVisitor): + def visit_call_(self, op: Call) -> None: + pass + + with pytest.raises( + TypeError, + match="Inheritance from a decorated object `LeafVisitor` is not allowed. Please inherit from `LeafVisitor._cls`.", + ): + + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor): + def visit_call_(self, op: Call) -> None: + pass + + +def test_call_visitor_super(): + @relax.expr_functor.visitor + class InternalVisitor(PyExprVisitor): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + super().visit_call_(op) # call PyExprVisitor.visit_call_ + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor._cls): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + super().visit_call_(op) # call InternalVisit.visit_call_ + + call_node = relax.op.add(x, y) + iv = InternalVisitor() + iv.visit_expr(call_node) + assert str(iv.log) == "\n".join(["InternalCall", "Op", "Var", "Var"]) + + lv = LeafVisitor() + lv.visit_expr(call_node) + assert str(lv.log) == "\n".join(["LeafCall", "InternalCall", "Op", "Var", "Var"]) + + +def test_call_mutator_super(): + @relax.expr_functor.mutator + class InternalMutator(PyExprMutator): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + return super().visit_call_(op) # call PyExprMutator.visit_call_ + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + return super().visit_var_(op) # call PyExprMutator.visit_var_ + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + return super().visit_op_(op) # call PyExprMutator.visit_op_ + + @relax.expr_functor.mutator + class LeafMutator(InternalMutator._cls): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + return super().visit_call_(op) # call InternalMutator.visit_call_ + + call_node = relax.op.add(x, y) + im = InternalMutator() + im.visit_expr(call_node) + assert str(im.log) == "\n".join(["InternalCall", "Op", "Var", "Var"]) + + lm = LeafMutator() + lm.visit_expr(call_node) + assert str(lm.log) == "\n".join(["LeafCall", "InternalCall", "Op", "Var", "Var"]) + + +if __name__ == "__main__": + tvm.testing.main() From 450d2a7f39010d2380b97e4868a5fd127fa90ed5 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 8 Feb 2023 22:31:47 +0800 Subject: [PATCH 07/81] [Unity] Relax TVMScript Parser. (#13932) This PR adds the TVMScript parser/ir_builder support based on the blockbuilder. Co-authored-by: Ruihang Lai Co-authored-by: Junru Shao Co-authored-by: Tianqi Chen Co-authored-by: Yuchen Jin Co-authored-by: Steven S. Lyubomirsky Co-authored-by: Yong Wu --- include/tvm/ir/expr.h | 1 + include/tvm/script/ir_builder/ir/frame.h | 11 +- include/tvm/script/ir_builder/ir/ir.h | 17 + include/tvm/script/ir_builder/relax/frame.h | 293 +++++ include/tvm/script/ir_builder/relax/ir.h | 137 +++ python/tvm/ir/expr.py | 50 +- python/tvm/script/ir_builder/base.py | 6 +- python/tvm/script/ir_builder/ir/__init__.py | 2 +- python/tvm/script/ir_builder/ir/ir.py | 45 + .../tvm/script/ir_builder/relax/__init__.py | 20 + .../tvm/script/ir_builder/relax/_ffi_api.py | 20 + python/tvm/script/ir_builder/relax/frame.py | 55 + python/tvm/script/ir_builder/relax/ir.py | 407 +++++++ python/tvm/script/parser/__init__.py | 3 +- python/tvm/script/parser/core/diagnostics.py | 2 +- python/tvm/script/parser/core/entry.py | 4 + python/tvm/script/parser/core/evaluator.py | 2 +- python/tvm/script/parser/core/parser.py | 50 +- python/tvm/script/parser/ir/parser.py | 4 + python/tvm/script/parser/relax/__init__.py | 17 +- python/tvm/script/parser/relax/entry.py | 327 +++++ python/tvm/script/parser/relax/parser.py | 276 +++++ python/tvm/script/parser/tir/entry.py | 4 +- python/tvm/script/parser/tir/operation.py | 12 +- python/tvm/script/parser/tir/parser.py | 26 + src/ir/module.cc | 16 +- src/script/ir_builder/ir/frame.cc | 12 +- src/script/ir_builder/ir/ir.cc | 41 +- src/script/ir_builder/ir/utils.h | 49 + src/script/ir_builder/relax/frame.cc | 273 +++++ src/script/ir_builder/relax/ir.cc | 236 ++++ src/script/ir_builder/relax/utils.h | 119 ++ src/script/ir_builder/tir/frame.cc | 15 +- src/script/ir_builder/tir/utils.h | 2 +- .../python/relax/test_tvmscript_ir_builder.py | 153 +++ tests/python/relax/test_tvmscript_parser.py | 1062 +++++++++++++++++ 36 files changed, 3720 insertions(+), 49 deletions(-) create mode 100644 include/tvm/script/ir_builder/relax/frame.h create mode 100644 include/tvm/script/ir_builder/relax/ir.h create mode 100644 python/tvm/script/ir_builder/relax/__init__.py create mode 100644 python/tvm/script/ir_builder/relax/_ffi_api.py create mode 100644 python/tvm/script/ir_builder/relax/frame.py create mode 100644 python/tvm/script/ir_builder/relax/ir.py create mode 100644 python/tvm/script/parser/relax/entry.py create mode 100644 python/tvm/script/parser/relax/parser.py create mode 100644 src/script/ir_builder/ir/utils.h create mode 100644 src/script/ir_builder/relax/frame.cc create mode 100644 src/script/ir_builder/relax/ir.cc create mode 100644 src/script/ir_builder/relax/utils.h create mode 100644 tests/python/relax/test_tvmscript_ir_builder.py create mode 100644 tests/python/relax/test_tvmscript_parser.py diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index d4ba628d36cf..c662067a0486 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -462,6 +462,7 @@ class GlobalVarNode : public RelayExprNode { v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); + v->Visit("struct_info_", &struct_info_); } bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const { diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index 887981ccffc8..dacfc361a6c7 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -38,12 +38,17 @@ namespace ir { */ class IRModuleFrameNode : public IRBuilderFrameNode { public: - Array global_vars; - Array functions; + /*! \brief A map from string names to global variables that ensures global uniqueness. */ + Map global_var_map; + /*! + * \brief A map from GlobalVar to all global functions. + * \note Only defined functions are in the map, while declared functions are not included. + */ + Map functions; void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); - v->Visit("global_vars", &global_vars); + v->Visit("global_vars", &global_var_map); v->Visit("functions", &functions); } diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h index f0e7cc6f5c2f..49bdcf60e6fb 100644 --- a/include/tvm/script/ir_builder/ir/ir.h +++ b/include/tvm/script/ir_builder/ir/ir.h @@ -37,6 +37,23 @@ namespace ir { */ TVM_DLL IRModuleFrame IRModule(); +/*! + * \brief Declare a Function without given the specific function implementation. + * \note It is usually used in cross-function call. And we can specify the function by `DefFunction` + * \param func_name The function unique name. + * \param func_signature A Function w/o body, which used to specify the function signature + * (i.e. func params and func return type/shape). + * \return The corresponding GlobalVar. + */ +TVM_DLL GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature); + +/*! + * \brief Define the function which is declared before. + * \param func_name The function unique name. + * \param func The given function implementation + */ +TVM_DLL void DefFunction(const String& func_name, const BaseFunc& func); + } // namespace ir } // namespace ir_builder } // namespace script diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h new file mode 100644 index 000000000000..0f544d3abcc2 --- /dev/null +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_ +#define TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +/*! \brief The base ir_builder frame for the relax dialect. */ +class RelaxFrameNode : public IRBuilderFrameNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); } + + static constexpr const char* _type_key = "script.ir_builder.relax.RelaxFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(RelaxFrameNode, IRBuilderFrameNode); +}; + +class RelaxFrame : public IRBuilderFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RelaxFrame, IRBuilderFrame, RelaxFrameNode); + + protected: + RelaxFrame() = default; +}; + +/*! \brief The base ir_builder frame for frames with SeqExpr + i.e. Functions, If branches + */ +class SeqExprFrameNode : public RelaxFrameNode { + public: + /*! \brief The binding blocks inside the frame. */ + Array binding_blocks; + /*! \brief The frame output expr. `NullOpt` when undefined. */ + Optional output; + + void VisitAttrs(tvm::AttrVisitor* v) { + RelaxFrameNode::VisitAttrs(v); + v->Visit("binding_blocks", &binding_blocks); + v->Visit("output", &output); + } + + static constexpr const char* _type_key = "script.ir_builder.relax.SeqExprFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(SeqExprFrameNode, RelaxFrameNode); + + public: + void EnterWithScope() override; + void ExitWithScope() override; +}; + +class SeqExprFrame : public RelaxFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SeqExprFrame, RelaxFrame, SeqExprFrameNode); +}; + +/*! \brief The ir_builder frame for the relax function. */ +class FunctionFrameNode : public SeqExprFrameNode { + public: + /*! + * \brief The function name. + * \note The name will not be specified in constructor, so it is "Optional", + * However, we must specify the name by `R.func_name` before exit this frame. + */ + Optional name; + /*! \brief The function params. */ + Array params; + /*! + * \brief The function return struct info. + * \note Usually the function return type can be deduced by the function body. + * But we can use this field to specify a more "accurate" return type. + * i.e. If the `ret_struct_info` is None, try to use the deduced type from body + * If the `ret_struct_info` is not None, we can still take body.struct_info + * if we ret_struct_info is base of body.struct_info. If not, we will + * take the specified `ret_struct_info`. + */ + Optional ret_struct_info; + + /*! \brief The function attributes. */ + Map attrs; + /*! \brief The block builder to create Relax function. */ + tvm::relax::BlockBuilder block_builder; + + void VisitAttrs(tvm::AttrVisitor* v) { + SeqExprFrameNode::VisitAttrs(v); + v->Visit("name", &name); + v->Visit("params", ¶ms); + v->Visit("ret_struct_info", &ret_struct_info); + v->Visit("attrs", &attrs); + v->Visit("binding_blocks", &binding_blocks); + v->Visit("output", &output); + // `block_builder` is not visited. + } + + static constexpr const char* _type_key = "script.ir_builder.relax.FunctionFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionFrameNode, SeqExprFrameNode); + + public: + void ExitWithScope() final; +}; + +class FunctionFrame : public SeqExprFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionFrame, SeqExprFrame, FunctionFrameNode); +}; + +/*! \brief The ir_builder frame for relax binding blocks. */ +class BlockFrameNode : public RelaxFrameNode { + public: + /*! \brief The flag that indicates whether the block is a dataflow block. */ + bool is_dataflow; + /*! \brief The variables emitted in this block. */ + Array emitted_vars; + /*! + * \brief A boolean indicating if the dataflow block is ended of construction. + * If it is true, any new binding trying to be emitted into this block will cause an error. + * \note Only used for a dataflow block. + */ + bool block_ended; + /*! + * \brief The output vars of the dataflow block. + * \note Only used for a dataflow block. + */ + Array output_vars; + + void VisitAttrs(tvm::AttrVisitor* v) { + RelaxFrameNode::VisitAttrs(v); + v->Visit("is_dataflow", &is_dataflow); + v->Visit("emitted_vars", &emitted_vars); + v->Visit("output_vars", &output_vars); + // `block_ended` is not visited. + } + + static constexpr const char* _type_key = "script.ir_builder.relax.BlockFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, RelaxFrameNode); + + public: + void EnterWithScope() final; + void ExitWithScope() final; +}; + +class BlockFrame : public RelaxFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, RelaxFrame, BlockFrameNode); +}; + +/*! + * \brief A frame that represents if statement. + * + * \sa IfFrame + */ +class IfFrameNode : public RelaxFrameNode { + public: + /*! \brief The condition of the if statement. */ + tvm::relax::Expr condition; + /*! \brief The Bindings in the true branch. */ + Optional then_expr; + /*! \brief The Bindings in the false branch. */ + Optional else_expr; + /*! \brief The Binding var. */ + tvm::relax::Var var; + /*! \brief The binding var name. */ + String var_name; + + void VisitAttrs(tvm::AttrVisitor* v) { + RelaxFrameNode::VisitAttrs(v); + v->Visit("condition", &condition); + v->Visit("then_expr", &then_expr); + v->Visit("else_expr", &else_expr); + v->Visit("var", &var); + v->Visit("var_name", &var_name); + } + + static constexpr const char* _type_key = "script.ir_builder.relax.IfFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, RelaxFrameNode); + + public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ + void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to IfFrameNode. + * + * \sa IfFrameNode + */ +class IfFrame : public RelaxFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, RelaxFrame, IfFrameNode); +}; + +/*! + * \brief A frame that represents then. + * + * \sa ThenFrame + */ +class ThenFrameNode : public SeqExprFrameNode { + public: + static constexpr const char* _type_key = "script.ir_builder.relax.ThenFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, SeqExprFrameNode); + + public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ + void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to ThenFrameNode. + * + * \sa ThenFrameNode + */ +class ThenFrame : public SeqExprFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, SeqExprFrame, ThenFrameNode); +}; + +/*! + * \brief A frame that represents else. + * + * \sa ElseFrame + */ +class ElseFrameNode : public SeqExprFrameNode { + public: + static constexpr const char* _type_key = "script.ir_builder.relax.ElseFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, SeqExprFrameNode); + + public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ + void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to ElseFrameNode. + * + * \sa ElseFrameNode + */ +class ElseFrame : public SeqExprFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, SeqExprFrame, ElseFrameNode); +}; + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_ diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h new file mode 100644 index 000000000000..72aab6684ebf --- /dev/null +++ b/include/tvm/script/ir_builder/relax/ir.h @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_ +#define TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +/////////////////////////////// Function //////////////////////////////// + +/*! + * \brief Start a function frame. + * \return The created ir_builder Function frame. + */ +TVM_DLL FunctionFrame Function(); + +/*! + * \brief Add a parameter to the last function frame. + * \param name The name of the parameter. + * \param struct_info The struct_info of the parameter. + * \return The created function parameter var. + */ +TVM_DLL tvm::relax::Var Arg(const String& name, const tvm::relax::StructInfo& struct_info); + +/*! + * \brief Specify the name of the last function frame. + * \param name The function name. + */ +TVM_DLL void FuncName(const String& name); + +/*! + * \brief Specify the attrs of the last function frame. + * \param attrs The function attrs. + */ +TVM_DLL void FuncAttrs(Map attrs); + +/*! + * \brief Specify the return struct info of the last function frame. + * \param ret_sinfo The return struct info. + */ +TVM_DLL void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo); + +/*! + * \brief Specify the return value of the last function frame. + * \param value The return value. + */ +TVM_DLL void FuncRetValue(const tvm::relax::Expr& value); + +///////////////////////////// BindingBlock ////////////////////////////// + +/*! + * \brief Start a binding block frame. + * \return The created ir_builder Block frame. + */ +TVM_DLL BlockFrame BindingBlock(); + +/*! + * \brief Start a dataflow binding block frame. + * \return The created ir_builder Block frame. + */ +TVM_DLL BlockFrame Dataflow(); + +/*! + * \brief Expose the dataflow block output variables as global ones + * \param vars The output variables of a dataflow block + */ +TVM_DLL void DataflowBlockOutput(const Array& vars); + +////////////////////////////// Bindings //////////////////////////////// + +/*! + * \brief Emit a binding to the last binding block frame. + * \param value The right side value of the bindings to be emitted. + * \param annotate_struct_info The optional struct info annotation for the emitted value. + * \return The left side var of the emitted binding. + */ +TVM_DLL tvm::relax::Var Emit( + const tvm::relax::Expr& value, + const Optional& annotate_struct_info = NullOpt); + +/*! + * \brief Emit a match_cast binding to the last binding block frame. + * \param value The value of the MatchCast to be emitted. + * \param struct_info The struct info of the MatchCast to be emitted. + * \return The left side var of the emitted binding. + */ +TVM_DLL tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, + const tvm::relax::StructInfo& struct_info); + +///////////////////////////// If Then Else ///////////////////////////// + +/*! + * \brief Create an if statement. + * \param condition The condition of if statement. + * \return The result IfFrame. + */ +IfFrame If(tvm::relax::Expr condition); +/*! + * \brief Create a then. + * \return The result ThenFrame. + */ +ThenFrame Then(); +/*! + * \brief Create an else. + * \return The result ElseFrame. + */ +ElseFrame Else(); + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_ diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index f90468de66c6..721e12e7f8d9 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -93,10 +93,17 @@ def __call__(self, *args): A call taking the variable as a function. """ # pylint: disable=import-outside-toplevel + + # TODO(@relax-team): replace with Relax base class after it's introduced if all(isinstance(x, RelayExpr) for x in args): - from tvm import relay + if all(is_relax_expr(x) for x in args): + from tvm import relax + + return relax.Call(self, args) + else: + from tvm import relay - return relay.Call(self, args) + return relay.Call(self, args) arg_types = [type(x) for x in args] raise RuntimeError( "Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types) @@ -185,3 +192,42 @@ def from_min_extent(min_value, extent, span=None): The constructed range. """ return _ffi_api.Range_from_min_extent(min_value, extent, span) + + +# TODO(@relax-team): remove when we have a RelaxExpr base class +def is_relax_expr(expr: RelayExpr) -> bool: + """check if a RelayExpr is a Relax expresssion. + + Parameters + ---------- + expr : RelayExpr + The expression to check. + + Returns + ------- + res : bool + If the expression is Relax expression, return True; otherwise return False. + """ + from tvm import relax # pylint: disable=import-outside-toplevel + + if isinstance( + expr, + ( + relax.Call, + relax.Constant, + relax.Tuple, + relax.TupleGetItem, + relax.If, + relax.Var, + relax.DataflowVar, + relax.ShapeExpr, + relax.SeqExpr, + relax.Function, + relax.ExternFunc, + relax.PrimValue, + relax.StringImm, + relax.DataTypeImm, + ), + ): + return True + return False diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index 7aa33ee49c72..b35bbd0a7df5 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -64,8 +64,10 @@ def __enter__(self) -> "IRBuilderFrame": _ffi_api.IRBuilderFrameEnter(self) # type: ignore[attr-defined] # pylint: disable=no-member return self - def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument - _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member + def __exit__(self, exc_type, exc_value, trace) -> None: # pylint: disable=unused-argument + if exc_type is None and exc_value is None: + # Do not execute `FrameExit` if the with scope exits because of exceptions + _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member def add_callback(self, callback: Callable[[], None]) -> None: """Add a callback method invoked when exiting the with-scope. diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py index ebb9728737ad..946be263a779 100644 --- a/python/tvm/script/ir_builder/ir/__init__.py +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -16,4 +16,4 @@ # under the License. """Package tvm.script.ir_builder.ir""" from .frame import IRModuleFrame -from .ir import ir_module +from .ir import decl_function, def_function, ir_module diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index 213180463cb2..796d6f3aad04 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -16,9 +16,54 @@ # under the License. """Package tvm.script.ir_builder.ir.ir""" +from tvm.ir import BaseFunc, GlobalVar + from . import _ffi_api from .frame import IRModuleFrame def ir_module() -> IRModuleFrame: + """Start a ir_module frame. + Returns + ------- + frame: IRModuleFrame + The constructed frame. + """ return _ffi_api.IRModule() # type: ignore[attr-defined] # pylint: disable=no-member + + +def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar: + """Declare a Function without given the specific function implementation. + Parameters + ---------- + func_name : str + The function unique name. + + func_signature: Optional[BaseFunc] + A Function w/o body, which used to specify the function signature + (i.e. func params and func return type/shape). + + Note + ---- + It is usually used in cross-function call. And we can specify the function by `DefFunction` + Returns + ------- + gv : GlobalVar + The corresponding GlobalVar. + """ + + return _ffi_api.DeclFunction( # type: ignore[attr-defined] # pylint: disable=no-member + func_name, func_signature + ) + + +def def_function(func_name: str, func: BaseFunc) -> None: + """Define the function which is declared before. + Parameters + ---------- + func_name : str + The function unique name. + func: BaseFunc + The given function implementation + """ + return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/ir_builder/relax/__init__.py b/python/tvm/script/ir_builder/relax/__init__.py new file mode 100644 index 000000000000..f0905acf34e3 --- /dev/null +++ b/python/tvm/script/ir_builder/relax/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +"""Package tvm.script.ir_builder.relax""" +from . import frame +from .ir import * # pylint: disable=wildcard-import,redefined-builtin diff --git a/python/tvm/script/ir_builder/relax/_ffi_api.py b/python/tvm/script/ir_builder/relax/_ffi_api.py new file mode 100644 index 000000000000..6e2098cf88af --- /dev/null +++ b/python/tvm/script/ir_builder/relax/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.script.ir_builder.relax""" +import tvm._ffi + +tvm._ffi._init_api("script.ir_builder.relax", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/relax/frame.py b/python/tvm/script/ir_builder/relax/frame.py new file mode 100644 index 000000000000..97e181fbe4be --- /dev/null +++ b/python/tvm/script/ir_builder/relax/frame.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""IR Builder Frame for Relax dialect""" +from tvm._ffi import register_object as _register_object + +from ..base import IRBuilderFrame + + +@_register_object("script.ir_builder.relax.RelaxFrame") +class RelaxFrame(IRBuilderFrame): + """The base ir_builder frame for the relax dialect.""" + + +@_register_object("script.ir_builder.relax.SeqExprFrame") +class SeqExprFrame(RelaxFrame): + ... + + +@_register_object("script.ir_builder.relax.FunctionFrame") +class FunctionFrame(SeqExprFrame): + """The ir_builder frame for the relax function.""" + + +@_register_object("script.ir_builder.relax.BlockFrame") +class BlockFrame(RelaxFrame): + """The ir_builder frame for relax binding blocks.""" + + +@_register_object("script.ir_builder.relax.IfFrame") +class IfFrame(RelaxFrame): + ... + + +@_register_object("script.ir_builder.relax.ThenFrame") +class ThenFrame(SeqExprFrame): + ... + + +@_register_object("script.ir_builder.relax.ElseFrame") +class ElseFrame(SeqExprFrame): + ... diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py new file mode 100644 index 000000000000..647ef8f25af7 --- /dev/null +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -0,0 +1,407 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin, wrong-import-order, no-member, invalid-name +"""IRBuilder for Relax dialect""" + +import builtins +import functools +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import tvm +from tvm import DataType, relax +from tvm.ir import PrimExpr +from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, const + +############################### Operators ############################### +from tvm.relax.op import ( + add, + assert_op, + call_builtin_with_ctx, + call_tir, + invoke_closure, + make_closure, + multiply, + null_value, + print, + shape_of, +) +from tvm.relax.struct_info import StructInfo +from tvm.relax.utils import args_converter +from tvm.runtime import Object as tvm_Object +from tvm.runtime import ObjectGeneric + +from . import _ffi_api, frame + +##################### Python Native Function Alias ###################### + +py_print = builtins.print +py_tuple = tuple +py_str = str + + +############################### Function ################################ + + +def function() -> frame.FunctionFrame: + """Start a function frame. + Returns + ------- + frame: FunctionFrame + The constructed function frame. + """ + return _ffi_api.Function() # type: ignore[attr-defined] # pylint: disable=no-member + + +def arg(name: py_str, struct_info: StructInfo) -> Var: + """Add a parameter to the last function frame. + Parameters + ---------- + name: str + The name of the parameter. + struct_info: StructInfo + The Struct Info of the parameter + + Returns + ------- + var: Var + The created function parameter var. + """ + + return _ffi_api.Arg(name, struct_info) # type: ignore[attr-defined] # pylint: disable=no-member + + +def func_name(name: py_str) -> None: + """Specify the name of the last function frame. + Parameters + ---------- + name: str + The function name. + """ + return _ffi_api.FuncName(name) # type: ignore[attr-defined] # pylint: disable=no-member + + +def func_attr(attrs: Dict[py_str, tvm_Object]) -> None: + """Specify the attrs of the last function frame. + Parameters + ---------- + attrs: Dict[str, Object] + The function attrs. + """ + return _ffi_api.FuncAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member + + +def func_ret_struct_info(ret_sinfo: StructInfo) -> None: + """Specify the return struct info of the last function frame. + Parameters + ---------- + ret_type: StructInfo + The function return struct info. + """ + return _ffi_api.FuncRetStructInfo(ret_sinfo) # type: ignore[attr-defined] # pylint: disable=no-member + + +def func_ret_value(value: Expr) -> None: + """Specify the return value of the last function frame. + Parameters + ---------- + value: Expr + The function return value. + """ + return _ffi_api.FuncRetValue(value) # type: ignore[attr-defined] # pylint: disable=no-member + + +############################# BindingBlock ############################## + + +def dataflow() -> frame.BlockFrame: + """Start a dataflow binding block frame. + Returns + ------- + frame: frame.BlockFrame + The created ir_builder Block frame. + """ + return _ffi_api.Dataflow() # type: ignore[attr-defined] # pylint: disable=no-member + + +def output(*vars: Tuple[Var]) -> None: + """Expose the dataflow block output variables as global ones. + Parameters + ---------- + vars: Tuple[Var] + The output variables of a dataflow block. + """ + return _ffi_api.DataflowBlockOutput(vars) # type: ignore[attr-defined] # pylint: disable=no-member + + +################################## Ops ################################# + + +@args_converter.auto +def call_packed( + func: py_str, + *args: Expr, + sinfo_args: Union[StructInfo, List[StructInfo]], + **kwargs: Any, +) -> Call: + """Create a relax Call, which calls a packed function. + Parameters + ---------- + func: str + The name of extern function. + *args : Expr + The arguments. + sinfo_args: Union[StructInfo, List[StructInfo]] + The list of structure info arguments. + kwargs: Expr + The keyword arguments. + + Returns + ------- + call: Call + The created Relax Call + """ + op = ExternFunc(func) + if sinfo_args is None: + raise ValueError("R.call_packed is required to have type_args") + if isinstance(sinfo_args, py_tuple): # type: ignore + sinfo_args = list(sinfo_args) + elif not isinstance(sinfo_args, list): + sinfo_args = [sinfo_args] + for i, sinfo_arg in enumerate(sinfo_args): + if callable(sinfo_arg): + sinfo_arg = sinfo_arg() + # Convert possible StructInfoProxy to StructInfo + if isinstance(sinfo_arg, ObjectGeneric): + sinfo_arg = sinfo_arg.asobject() + sinfo_args[i] = sinfo_arg + + is_default = False + if "attrs_type_key" in kwargs: + attrs_type_key = kwargs["attrs_type_key"] + kwargs.pop("attrs_type_key") + else: + attrs_type_key = "DictAttrs" + is_default = True + attrs = None + if kwargs or not is_default: + attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs) + + return Call(op, args, attrs=attrs, sinfo_args=sinfo_args) + + +def _sinfo_arg_wrapper(func): + """A wrapper to convert StructInfoProxies to StructInfo for builtin operators with sinfo_args""" + + def _convert_tensor_type(args): + if isinstance(args, (list, py_tuple)): # type: ignore + new_args = [_convert_tensor_type(x) for x in args] + return type(args)(new_args) + if isinstance(args, dict): + return {_convert_tensor_type(k): _convert_tensor_type(v) for k, v in args.items()} + if inspect.isfunction(args): + args = args() + if isinstance(args, ObjectGeneric): + args = args.asobject() + return args + + @functools.wraps(func) + def wrapped(*args, **kwargs): + return func(*_convert_tensor_type(args), **_convert_tensor_type(kwargs)) + + return wrapped # type: ignore + + +invoke_closure = _sinfo_arg_wrapper(invoke_closure) # pylint: disable=invalid-name + +call_builtin_with_ctx = _sinfo_arg_wrapper(call_builtin_with_ctx) # pylint: disable=invalid-name + +############################### Bindings ############################### + + +def emit(value: Expr, annotate_struct_info: Optional[StructInfo] = None) -> Var: + """Emit a binding to the last binding block frame. + Parameters + ---------- + value: Expr + The right side value of the bindings to be emitted. + + annotate_struct_info: Optional[StructInfo] + The optional struct info annotation for the emitted value. + + Returns + ------- + var: Var + The left side var of the emitted binding. + """ + return _ffi_api.Emit(value, annotate_struct_info) # type: ignore[attr-defined] # pylint: disable=no-member + + +def emit_match_cast(value: Expr, struct_info: StructInfo) -> Var: + """Emit a match_cast binding to the last binding block frame. + Parameters + ---------- + value: Expr + The value of the MatchCast to be emitted. + struct_info: StructInfo + The struct_info of the MatchCast to be emitted. + + Returns + ------- + var: Var + The left side var of the emitted binding. + """ + return _ffi_api.EmitMatchCast(value, struct_info) # type: ignore + + +############################# If Then Else ############################# + + +def If(condition: Expr) -> frame.IfFrame: # pylint: disable=invalid-name + """Create an if frame. + Parameters + ---------- + condition : Expr + The condition of if statement, executes the true branch if the condition is true, + otherwise jump into the false branch. + Returns + ------- + res : frame.IfFrame + The result IfFrame. + """ + return _ffi_api.If(condition) # type: ignore[attr-defined] # pylint: disable=no-member + + +def Then() -> frame.ThenFrame: # pylint: disable=invalid-name + """Create a then frame. + Returns + ------- + res : frame.ThenFrame + The result ThenFrame. + """ + return _ffi_api.Then() # type: ignore[attr-defined] # pylint: disable=no-member + + +def Else() -> frame.ElseFrame: # pylint: disable=invalid-name + """Create an else frame. + Returns + ------- + res : frame.ElseFrame + The result ElseFrame. + """ + return _ffi_api.Else() # type: ignore[attr-defined] # pylint: disable=no-member + + +############################### R.tuple ################################ + + +def tuple(*fields: Expr) -> Expr: + """Create a tuple expression. + Parameters + ---------- + *fields : Expr + The fields of the tuple. + Returns + ------- + res : Expr + The result tuple. + """ + if len(fields) == 0: + fields = py_tuple() + + return relax.Tuple(fields) # type: ignore[attr-defined] # pylint: disable=no-member + + +############################### PrimValue ############################## + + +def prim_value(value: PrimExpr) -> Expr: + """Create a prim value expression. + Parameters + ---------- + value : PrimExpr + The value of the prim value. + Returns + ------- + res : Expr + The result prim value. + """ + return relax.PrimValue(value) # type: ignore[attr-defined] # pylint: disable=no-member + + +def str(value: py_str) -> Expr: + """Create a string imm expression. + Parameters + ---------- + value : str + The value of the str. + Returns + ------- + res : Expr + The result str. + """ + return relax.StringImm(value) # type: ignore[attr-defined] # pylint: disable=no-member + + +def dtype(value: Union[py_str, DataType]) -> Expr: + """Create a dtype imm expression. + Parameters + ---------- + value : dtype + The value of the dtype. + Returns + ------- + res : Expr + The result dtype. + """ + return relax.DataTypeImm(value) # type: ignore[attr-defined] # pylint: disable=no-member + + +############################### Importer ############################### + +__all__ = [ + "Else", + "If", + "Then", + "TupleGetItem", + "add", + "arg", + "assert_op", + "call_packed", + "call_tir", + "call_builtin_with_ctx", + "const", + "dataflow", + "dtype", + "emit", + "emit_match_cast", + "func_attr", + "func_name", + "func_ret_struct_info", + "func_ret_value", + "function", + "invoke_closure", + "make_closure", + "multiply", + "null_value", + "output", + "prim_value", + "print", + "shape_of", + "str", + "tuple", +] diff --git a/python/tvm/script/parser/__init__.py b/python/tvm/script/parser/__init__.py index 5161a2601c49..678297799e6d 100644 --- a/python/tvm/script/parser/__init__.py +++ b/python/tvm/script/parser/__init__.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the Licens. """The parser""" -from . import _core, ir, tir +from . import _core, ir, tir, relax from ._core import parse from .ir import ir_module from .tir import prim_func +from .relax import function diff --git a/python/tvm/script/parser/core/diagnostics.py b/python/tvm/script/parser/core/diagnostics.py index ad7ae5034780..2767a97f6096 100644 --- a/python/tvm/script/parser/core/diagnostics.py +++ b/python/tvm/script/parser/core/diagnostics.py @@ -220,7 +220,7 @@ def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) level : diagnostics.DiagnosticLevel The diagnostic level. """ - lineno = node.lineno or self.source.start_line + lineno = node.lineno or 1 col_offset = node.col_offset or self.source.start_column end_lineno = node.end_lineno or lineno end_col_offset = node.end_col_offset or col_offset diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 9e6c100c954d..3c01b54a9f1a 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -43,6 +43,7 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) if extra_vars is None: import tvm # pylint: disable=import-outside-toplevel from tvm.script.parser import ir # pylint: disable=import-outside-toplevel + from tvm.script.parser import relax # pylint: disable=import-outside-toplevel from tvm.script.parser import tir # pylint: disable=import-outside-toplevel extra_vars = { @@ -51,6 +52,9 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) "ir": ir, "T": tir, "tir": tir, + "relax": relax, + "R": relax, + "tvm": tvm, } source = Source(program) diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 3a72a3c33106..075aedd89146 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -203,7 +203,7 @@ def _visit(self, node: doc.AST) -> Any: else: value = self._eval_expr(node.__class__(**fields)) except Exception as e: # pylint: disable=broad-except,invalid-name - self.parser.report_error(node, str(e)) + self.parser.report_error(node, e) return self._add_intermediate_result(value) def _eval_lambda(self, node: doc.Lambda) -> Any: diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 7c699c42aecb..105164ed5ffc 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -60,6 +60,10 @@ def context(): return context() +def _do_nothing(*args, **kwargs): # pylint: disable=unused-argument + pass + + class VarTableFrame: """The variable table frame. A frame of variable table stores the variables created in one block or scope. @@ -259,6 +263,17 @@ def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any: node = self.diag.source.as_ast() self.visit(node) + def get_dispatch_token(self, node: doc.FunctionDef) -> str: + if not isinstance(node, doc.FunctionDef): + self.report_error(node, "Only can get dispatch token for function.") + if not node.decorator_list: + self.report_error(node, "Function must be decorated") + # TODO: only the last decorator is parsed + decorator = self.eval_expr(node.decorator_list[-1]) + if not hasattr(decorator, "dispatch_token"): + self.report_error(node, "The parser does not understand the decorator") + return decorator.dispatch_token + def with_dispatch_token(self, token: str): """Add a new dispatching token as with statement. @@ -388,6 +403,8 @@ def report_error( # Only take the last line of the error message if isinstance(err, TVMError): msg = list(filter(None, str(err).split("\n")))[-1] + elif isinstance(err, KeyError): + msg = "KeyError: " + str(err) else: msg = str(err) self.diag.error(node, msg) @@ -457,30 +474,33 @@ def visit_tvm_annotation(self, node: doc.expr) -> Any: """ return _dispatch(self, "tvm_annotation")(self, node) - def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name - """The general function definition visiting method. + def visit_FunctionDef(self, node: doc.FunctionDef) -> None: # pylint: disable=invalid-name + """The general function definition visit method. Parameters ---------- node : doc.FunctionDef - The doc AST function definition node. - - Returns - ------- - res : Any - The visiting result. + The doc FunctionDef node. """ - if not node.decorator_list: - self.report_error(node, "Function must be decorated") - # TODO: only the last decorator is parsed - decorator = self.eval_expr(node.decorator_list[-1]) - if not hasattr(decorator, "dispatch_token"): - self.report_error(node, "The parser does not understand the decorator") - token = decorator.dispatch_token + token = self.get_dispatch_token(node) + current_token = self.dispatch_tokens[-1] func = dispatch.get(token=token, type_name="FunctionDef", default=None) if func is None: self.report_error(node, "The parser does not understand the decorator") + pre_func = dispatch.get( + token=current_token, type_name="pre_token_switch", default=_do_nothing + ) + post_func = dispatch.get( + token=current_token, type_name="post_token_switch", default=_do_nothing + ) + pre_func(self, node) _dispatch_wrapper(func)(self, node) + post_func(self, node) + + def visit_tvm_declare_function(self, node: doc.FunctionDef) -> None: + token = self.get_dispatch_token(node) + with self.with_dispatch_token(token): + _dispatch(self, "tvm_declare_function")(self, node) def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name """The general class definition visiting method. diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py index e0268412d284..13b3e298590f 100644 --- a/python/tvm/script/parser/ir/parser.py +++ b/python/tvm/script/parser/ir/parser.py @@ -32,8 +32,12 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: node : doc.ClassDef The doc AST class definition node. """ + with self.var_table.with_frame(): with I.ir_module(): + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + self.visit_tvm_declare_function(stmt) with self.with_dispatch_token("ir"): self.visit_body(node.body) diff --git a/python/tvm/script/parser/relax/__init__.py b/python/tvm/script/parser/relax/__init__.py index feb8e683401c..04f3fea21c2b 100644 --- a/python/tvm/script/parser/relax/__init__.py +++ b/python/tvm/script/parser/relax/__init__.py @@ -15,7 +15,18 @@ # specific language governing permissions and limitations # under the License. """Initial impl of relax parser for sugars""" -from tvm.relax import TensorStructInfo, ShapeStructInfo +from ...ir_builder.relax import * # pylint: disable=redefined-builtin +from ...ir_builder.relax import ir as _relax +from . import parser as _parser +from .entry import Callable, Object, Prim, Shape, Tensor, Tuple, function, match_cast -Tensor = TensorStructInfo -Shape = ShapeStructInfo +__all__ = _relax.__all__ + [ + "Callable", + "Object", + "Prim", + "Shape", + "Tensor", + "Tuple", + "function", + "match_cast", +] diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py new file mode 100644 index 000000000000..d93f9a2826bc --- /dev/null +++ b/python/tvm/script/parser/relax/entry.py @@ -0,0 +1,327 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring, invalid-name +import inspect +from typing import Any +from typing import Callable as _Callable +from typing import Dict, List, Optional, Set, TypeVar, Union + +from tvm.relax import ( + Expr, + FuncStructInfo, + Function, + ObjectStructInfo, + PrimStructInfo, + ShapeStructInfo, + StructInfo, + TensorStructInfo, + TupleStructInfo, +) +from tvm.runtime import ObjectGeneric +from tvm.tir import PrimExpr + +from .._core import parse, utils + +FType = TypeVar("FType", bound=_Callable) + +############################## R.function ############################## + + +def function(f: FType) -> Union[Function, FType]: + if not inspect.isfunction(f): + raise TypeError(f"Expect a function, but got: {f}") + if utils.is_defined_in_class(inspect.stack(), f): + return f + return parse(f, utils.inspect_function_capture(f)) + + +setattr(function, "dispatch_token", "relax") + + +############################# Struct Info ############################## + + +class StructInfoProxy(ObjectGeneric): + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> StructInfo: + raise NotImplementedError() + + def get_symbolic_vars(self) -> Set[str]: + return {} + + def asobject(self): + return self.as_struct_info(None) + + +############################### R.Tensor ############################### + + +def _eval_shape(expr: Union[str, PrimExpr], dict_globals: Optional[Dict[str, Any]]) -> PrimExpr: + if isinstance(expr, str): + code = compile(expr, "", "eval") + return eval(code, dict_globals or {}) # pylint: disable=eval-used + else: + return expr + + +class TensorProxy(StructInfoProxy): + shape: Optional[List[Union[str, PrimExpr]]] + dtype: str + ndim: int + + def __init__( + self, + shape: Optional[List[Union[PrimExpr, str]]] = None, + dtype: Optional[str] = None, + ndim: int = -1, + ) -> None: + self.shape = shape + self.dtype = dtype + self.ndim = ndim + super().__init__() + + def get_symbolic_vars(self) -> Set[str]: + if self.shape is None: + return {} + else: + return {s for s in self.shape if isinstance(s, str) and s.isidentifier()} + + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TensorStructInfo: + if self.shape is None: + return TensorStructInfo(None, self.dtype, self.ndim) + else: + if dict_globals is None and any([isinstance(s, str) for s in self.shape]): + raise ValueError( + "String-defined shape expr is only allowed when parsing function parameters " + "and return annotations for TVMScript." + ) + shape = [_eval_shape(s, dict_globals) for s in self.shape] + return TensorStructInfo(shape, self.dtype, self.ndim) + + +def Tensor( + shape: Optional[List[Union[PrimExpr, str]]] = None, + dtype: Optional[str] = None, + ndim: int = -1, +) -> TensorProxy: + # scalar tensor case + if shape is not None and len(shape) == 0: + shape = [] + if isinstance(shape, str) and dtype is None: + dtype = shape + shape = None + + if shape is not None and not isinstance(shape, (tuple, list)): + raise ValueError(f"shape must be a list or tuple, but got: {shape}") + return TensorProxy(shape, dtype, ndim) + + +############################## R.Callable ############################## + + +class CallableProxy(StructInfoProxy): + params: List[StructInfoProxy] + ret: StructInfoProxy + """Function type. + + A function type consists of a list of type parameters to enable + the definition of generic functions, + a set of type constraints which we omit for the time being, + a sequence of argument types, and a return type. + + Parameters + ---------- + params : List[StructInfoProxy] + The argument StructInfoProxy + + ret : StructInfoProxy + The return StructInfoProxy. + + """ + + def __init__( + self, + params: Union[StructInfoProxy, List[StructInfoProxy]], + ret: StructInfoProxy, + ) -> None: + if not isinstance(params, (list, tuple)): + params = [params] + # convert `R.Tensor` to `R.Tensor()` + self.params = [param() if callable(param) else param for param in params] + self.ret = ret() if callable(ret) else ret + + def get_symbolic_vars(self) -> Set[str]: + return set().union(*[p.get_symbolic_vars() for p in self.params]) + + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> FuncStructInfo: + params = [param.as_struct_info(dict_globals) for param in self.params] + ret = self.ret.as_struct_info(dict_globals) + return FuncStructInfo(params, ret) + + +def Callable( + params: Union[StructInfoProxy, List[StructInfoProxy]], + ret: StructInfoProxy, +) -> CallableProxy: + return CallableProxy(params, ret) + + +############################### R.Tuple ################################ + + +class TupleProxy(StructInfoProxy): + fields: List[StructInfoProxy] + """The type of tuple values. + + Parameters + ---------- + fields : List[StructInfoProxy] + The fields in the tuple + """ + + def __init__( + self, + *fields: List[StructInfoProxy], + ) -> None: + if len(fields) == 1 and isinstance(fields[0], (tuple, list)): + fields = fields[0] + # convert `R.Tensor` to `R.Tensor()` + self.fields = [field() if callable(field) else field for field in fields] + + def get_symbolic_vars(self) -> Set[str]: + return set().union(*[f.get_symbolic_vars() for f in self.fields]) + + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TupleStructInfo: + fields = [field.as_struct_info(dict_globals) for field in self.fields] + return TupleStructInfo(fields) + + +def Tuple(*fields: List[StructInfoProxy]) -> TupleProxy: + return TupleProxy(*fields) + + +############################### R.Shape ################################ + + +class ShapeProxy(StructInfoProxy): + values: Optional[List[PrimExpr]] + ndim: int + """The type of shape values. + + Parameters + ---------- + values : Optional[List[PrimExpr]] + The symbolic shape values if known. + + ndim : Optional[int] + The size of the shape. + """ + + def __init__( + self, + values: Optional[List[PrimExpr]] = None, + ndim: int = -1, + ) -> None: + self.values = values + self.ndim = ndim + + def get_symbolic_vars(self) -> Set[str]: + if self.values is None: + return {} + else: + return {v for v in self.values if isinstance(v, str) and v.isidentifier()} + + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo: + values = [_eval_shape(v, dict_globals) for v in self.values] if self.values else None + return ShapeStructInfo(values, self.ndim) + + +def Shape(values: Optional[List[PrimExpr]] = None, ndim: int = -1) -> ShapeProxy: + return ShapeProxy(values, ndim) + + +############################### R.Object ################################ + + +class ObjectProxy(StructInfoProxy): + """The proxy fo ObjectStructInfo. + + Parameters + ---------- + values : Optional[List[PrimExpr]] + The symbolic shape values if known. + + ndim : Optional[int] + The size of the shape. + """ + + def __init__(self) -> None: + pass + + def get_symbolic_vars(self) -> Set[str]: + return set() + + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo: + return ObjectStructInfo() + + +def Object() -> ObjectProxy: + return ObjectProxy() + + +################################ R.Prim ################################ + + +class PrimProxy(StructInfoProxy): + dtype: str + """The type of shape values. + + Parameters + ---------- + dtype : str + The data type. + """ + + def __init__(self, dtype: str) -> None: + self.dtype = dtype + + def get_symbolic_vars(self) -> Set[str]: + return set() + + def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo: + return PrimStructInfo(self.dtype) + + +def Prim(dtype: str) -> PrimProxy: + return PrimProxy(dtype) + + +############################ R.match_cast ############################# +class MatchCastPair: + value: Expr + struct_info: StructInfo + + def __init__(self, value: Expr, struct_info: StructInfo) -> None: + self.value = value + self.struct_info = struct_info + + +def match_cast(value: Expr, struct_info: StructInfo): + if value is None: + raise ValueError("value of match_cast cannot be None") + if struct_info is None: + raise ValueError("struct_info of match_cast cannot be None") + return MatchCastPair(value, struct_info) diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py new file mode 100644 index 000000000000..ef26ddd6e921 --- /dev/null +++ b/python/tvm/script/parser/relax/parser.py @@ -0,0 +1,276 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring + +import functools +import numbers +from typing import Any, Optional + +from tvm import relax, tir +from tvm.ir import structural_equal +from tvm.relax import StructInfo +from tvm.relax.utils import convert_to_expr +from tvm.script.ir_builder.relax.frame import BlockFrame + +from ...ir_builder import ir as I +from ...ir_builder import relax as R +from ...ir_builder.base import IRBuilder +from .._core import Parser, dispatch, doc +from .entry import MatchCastPair, StructInfoProxy, TupleProxy + + +def bind_assign_value( + self: Parser, + node: doc.expr, + var_name: str, + value: Any, + anno_sinfo: Optional[StructInfo] = None, +) -> Any: + var_table = self.var_table.get() + + if isinstance(value, tir.Var): + if value.name and var_name != value.name: + self.report_error( + node, + "Cannot define TIR variables with different names. The LHS of binding should " + "has the same name provided in RHS.", + ) + if var_name in var_table: + prev_value = var_table[var_name] + if not isinstance(prev_value, tir.Var): + self.report_error( + node, + "Cannot redefine a non-TIR-variable object to a TIR variable. Please " + "define the TIR variable with another name.", + ) + if prev_value.dtype != value.dtype: + self.report_error( + node, + "Expected the same dtype for TIR vars " + f"but got {value.dtype} vs {prev_value.dtype}", + ) + return prev_value + IRBuilder.name(var_name, value) + return value + + if isinstance(value, tuple): + value = convert_to_expr(value) + if isinstance(value, numbers.Number): + value = R.const(value) + + if isinstance(value, relax.Expr): + var = R.emit(value, anno_sinfo) + elif isinstance(value, MatchCastPair): + if anno_sinfo is not None and not structural_equal(anno_sinfo, value.struct_info): + self.report_error( + node, "Cannot specify inconsistent annotation for a match cast pair. " + ) + var = R.emit_match_cast(value.value, value.struct_info) + else: + raise TypeError(f"Unsupported type {type(value)} in assignment") + + IRBuilder.name(var_name, var) + return var + + +def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy: + try: + annotation = self.eval_expr(node) + if annotation is None: + return TupleProxy([]) + if callable(annotation): + annotation = annotation() + if isinstance(annotation, StructInfoProxy): + return annotation + else: + raise TypeError(f"Expected StructInfoProxy but got {type(annotation)}.") + except Exception as err: + self.report_error(node, str(err)) + raise err + + +def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> StructInfo: + var_table = self.var_table.get() if eval_str else None + try: + return eval_struct_info_proxy(self, node).as_struct_info(var_table) + except Exception as err: + self.report_error(node, str(err)) + raise err + + +def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> None: + # Collect symbolic vars from parameters + symbolic_vars = set() + for arg in node.args.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param_sinfo_proxy = eval_struct_info_proxy(self, arg.annotation) + symbolic_vars.update(param_sinfo_proxy.get_symbolic_vars()) + + # Define symbolic vars to the current var_table frame + for var_name in symbolic_vars: + self.var_table.add(var_name, tir.Var(var_name, "int64"), allow_shadowing=False) + + +@dispatch.register(token="relax", type_name="FunctionDef") +def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: + with self.var_table.with_frame(): + with self.with_dispatch_token("relax"): + with R.function(): + R.func_name(node.name) + collect_symbolic_var_from_params(self, node) + + if node.returns is not None: + ann_sinfo = eval_struct_info(self, node.returns, eval_str=True) + R.func_ret_struct_info(ann_sinfo) + + self.visit(node.args) + self.visit_body(node.body) + + +@dispatch.register(token="relax", type_name="tvm_declare_function") +def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None: + with self.var_table.with_frame(): + collect_symbolic_var_from_params(self, node) + + if node.returns is None: + # Use ObjectStructInfo as unknown return type + # NOTE: Cannot use VoidStructInfo here because the return type can be refined later. + ret_sinfo = relax.ObjectStructInfo() + else: + ret_sinfo = eval_struct_info(self, node.returns, eval_str=True) + params = [] + params_sinfo = [] + for arg in node.args.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) + params_sinfo.append(param_sinfo) + params.append(relax.Var(arg.arg, param_sinfo)) + + func_signature = relax.Function.create_empty(params, ret_sinfo) + global_var = I.decl_function(node.name, func_signature) + self.var_table.add(node.name, global_var) + + +@dispatch.register(token="relax", type_name="pre_token_switch") +def pre_token_switch(self: Parser, node: doc.Expr) -> None: # pylint: disable=unused-argument + ir_builder = IRBuilder() + ir_builder.__enter__() + + +@dispatch.register(token="relax", type_name="post_token_switch") +def post_token_switch(self: Parser, node: doc.Expr) -> None: + ir_builder = IRBuilder.current() + result = ir_builder.get() + ir_builder.__exit__(None, None, None) + var = R.emit(result) + IRBuilder.name(node.name, var) + self.var_table.add(node.name, var, allow_shadowing=False) + + +@dispatch.register(token="relax", type_name="Expr") +def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: + value = self.eval_expr(node.value) + if value is not None: + self.report_error(node, f"Unsupported Expr stmt type {value}.") + + +@dispatch.register(token="relax", type_name="arguments") +def visit_arguments(self: Parser, node: doc.arguments) -> None: + arg: doc.arg + for arg in node.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) + param = R.arg(arg.arg, param_sinfo) + + self.var_table.add(arg.arg, param) + + +@dispatch.register(token="relax", type_name="tvm_annotation") +def visit_tvm_annotation(self: Parser, node: doc.expr) -> StructInfo: + return eval_struct_info(self, node, eval_str=False) + + +@dispatch.register(token="relax", type_name="With") +def visit_with(self: Parser, node: doc.With) -> None: + # Currently only `with R.dataflow()` is supported + if len(node.items) != 1: + self.report_error(node, "Only one item is allowed.") + item = node.items[0] + if item.optional_vars is not None: + self.report_error( + item.context_expr, + "Relax syntax doesn't allow binding expressions in `with` to variables", + ) + frame = self.eval_expr(item.context_expr) + with self.var_table.with_frame(): + with frame: + self.visit(node.body) + if isinstance(frame, BlockFrame) and frame.is_dataflow: + output_vars = frame.output_vars + for var in output_vars: + self.var_table.add(var.name_hint, var, allow_shadowing=True) + + +@dispatch.register(token="relax", type_name="Assign") +def visit_assign(self: Parser, node: doc.Assign) -> None: + if len(node.targets) != 1: + self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.") + lhs = node.targets[0] + rhs = self.eval_expr(node.value) + self.eval_assign( + target=lhs, + source=rhs, + bind_value=bind_assign_value, + allow_shadowing=True, + ) + + +@dispatch.register(token="relax", type_name="AnnAssign") +def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: + lhs = node.target + rhs = self.eval_expr(node.value) + anno_sinfo = self.visit_tvm_annotation(node.annotation) + self.eval_assign( + target=lhs, + source=rhs, + bind_value=functools.partial(bind_assign_value, anno_sinfo=anno_sinfo), + allow_shadowing=True, + ) + + +@dispatch.register(token="relax", type_name="Return") +def visit_return(self: Parser, node: doc.Assign) -> None: + value = self.eval_expr(node.value) + value = convert_to_expr(value) + R.func_ret_value(value) + + +@dispatch.register(token="relax", type_name="If") +def visit_if(self: Parser, node: doc.If) -> None: + if node.orelse is None: + raise ValueError("Else statements are required for relax dialect.") + with R.If(self.eval_expr(node.test)) as if_frame: + with self.var_table.with_frame(): + with R.Then(): + self.visit_body(node.body) + with self.var_table.with_frame(): + with R.Else(): + self.visit_body(node.orelse) + self.var_table.add(if_frame.var_name, if_frame.var, allow_shadowing=True) diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index 411a7f8f3c83..649f817411f0 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -83,7 +83,7 @@ def __getitem__(self, keys) -> Buffer: return self(keys) if len(keys) >= 2 and not isinstance(keys[1], str): return self(keys) - return self(*keys) # pylint: disable=no-member # type: ignore + return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member class PtrProxy: @@ -93,7 +93,7 @@ class PtrProxy: def __call__(self, dtype, storage_scope="global"): if callable(dtype): dtype = dtype().dtype - return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore + return ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member @deprecated("T.Ptr[...]", "T.handle(...)") def __getitem__(self, keys): diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py index f0c04f47cdf6..ed8f07a06369 100644 --- a/python/tvm/script/parser/tir/operation.py +++ b/python/tvm/script/parser/tir/operation.py @@ -46,12 +46,12 @@ def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name for i in [0, 1]: # Case 1. binop - r(doc.Add, i, tir.Add) - r(doc.Sub, i, tir.Sub) - r(doc.Mult, i, tir.Mul) - r(doc.Div, i, tir.Div) - r(doc.FloorDiv, i, tir.FloorDiv) - r(doc.Mod, i, tir.FloorMod) + r(doc.Add, i, lambda a, b: a + b) + r(doc.Sub, i, lambda a, b: a - b) + r(doc.Mult, i, lambda a, b: a * b) + r(doc.Div, i, lambda a, b: a / b) + r(doc.FloorDiv, i, lambda a, b: a // b) + r(doc.Mod, i, lambda a, b: a % b) r(doc.LShift, i, lambda a, b: a << b) r(doc.RShift, i, lambda a, b: a >> b) r(doc.BitOr, i, lambda a, b: a | b) diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index fbef1a969179..f8893ce8cfb1 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -24,6 +24,7 @@ from tvm.ir import PrimType from tvm.tir import Buffer, IterVar, PrimExpr, Var +from ...ir_builder import ir as I from ...ir_builder import tir as T from ...ir_builder.base import IRBuilder from ...ir_builder.base import IRBuilderFrame as Frame @@ -471,3 +472,28 @@ def visit_return(self: Parser, node: doc.Return) -> None: The doc AST return node. """ self.report_error(node, "Return is not allowed.") + + +@dispatch.register(token="tir", type_name="tvm_declare_function") +def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None: + """The function declaration step for tir + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Return + The doc AST return node. + """ + + ret_type = None + if node.returns is not None: + ret_type = self.eval_expr(node.returns) + if callable(ret_type): + ret_type = PrimType(ret_type().dtype) + + # Only ret_type is needed for func_signature. + func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type) + global_var = I.decl_function(node.name, func_signature) + self.var_table.add(node.name, global_var) diff --git a/src/ir/module.cc b/src/ir/module.cc index 22c6faf3d69d..4a09bdaaf7c6 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -63,15 +63,25 @@ IRModule::IRModule(tvm::Map functions, } bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const { - if (functions.size() != other->functions.size()) return false; if (!equal(this->attrs, other->attrs)) return false; + + if (functions.size() != other->functions.size()) return false; + // Update GlobalVar remap + for (const auto& gv : this->GetGlobalVars()) { + if (!other->ContainGlobalVar(gv->name_hint)) return false; + if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false; + } for (const auto& kv : this->functions) { - if (!other->ContainGlobalVar(kv.first->name_hint)) return false; if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false; } + if (type_definitions.size() != other->type_definitions.size()) return false; + // Update GlobalTypeVar remap + for (const auto& gtv : this->GetGlobalTypeVars()) { + if (!other->ContainGlobalTypeVar(gtv->name_hint)) return false; + if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false; + } for (const auto& kv : this->type_definitions) { - if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false; if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false; } return true; diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index a81c56922dff..addf12928435 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -26,11 +26,15 @@ namespace ir_builder { namespace ir { void IRModuleFrameNode::ExitWithScope() { - ICHECK_EQ(functions.size(), global_vars.size()); - int n = functions.size(); Map func_map; - for (int i = 0; i < n; ++i) { - func_map.Set(global_vars[i], functions[i]); + CHECK_EQ(functions.size(), global_var_map.size()) + << "All functions must be defined in the IRModule. Got " << global_var_map.size() + << "declared function(s), but only " << functions.size() << "defined function(s)."; + for (const auto& kv : functions) { + const GlobalVar& gv = kv.first; + const BaseFunc& func = kv.second; + CHECK(func.defined()) << "ValueError: function " << gv->name_hint << " is not defined"; + func_map.Set(gv, func); } IRBuilder builder = IRBuilder::Current(); ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index a8cc452e4f0c..da2330b5772b 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -17,9 +17,12 @@ * under the License. */ #include +#include #include #include +#include "./utils.h" + namespace tvm { namespace script { namespace ir_builder { @@ -27,12 +30,48 @@ namespace ir { IRModuleFrame IRModule() { ObjectPtr n = make_object(); - n->global_vars.clear(); + n->global_var_map.clear(); n->functions.clear(); return IRModuleFrame(n); } +GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature) { + IRModuleFrame frame = FindModuleFrame("I.DeclFunction"); + CHECK(!frame->global_var_map.count(func_name)) + << "ValueError: function " << func_name << " already exists"; + GlobalVar gv = GlobalVar(func_name); + if (func_signature->struct_info_.defined()) { + gv->struct_info_ = tvm::relax::GetStructInfo(func_signature); + } else if (const auto* prim_func = func_signature.as()) { + gv->struct_info_ = + tvm::relax::FuncStructInfo::OpaqueFunc(tvm::relax::StructInfoFromType(prim_func->ret_type)); + } else { + LOG(FATAL) << "Unsupported function type: " << func_signature->GetTypeKey(); + } + CHECK(frame->functions.find(gv) == frame->functions.end()) + << "ValueError: function " << func_name << " has already been defined."; + frame->global_var_map.Set(func_name, gv); + if (func_signature.defined()) { + frame->functions.Set(gv, func_signature); + } + return gv; +} + +void DefFunction(const String& func_name, const BaseFunc& func) { + IRModuleFrame frame = FindModuleFrame("I.DefFunction"); + auto it = frame->global_var_map.find(func_name); + CHECK(it != frame->global_var_map.end()) + << "ValueError: function " << func_name << " does not exist, please declare it first."; + const GlobalVar& gv = (*it).second; + frame->functions.Set(gv, func); + if (func->checked_type_.defined()) { + gv->checked_type_ = func->checked_type_; + } +} + TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); } // namespace ir } // namespace ir_builder diff --git a/src/script/ir_builder/ir/utils.h b/src/script/ir_builder/ir/utils.h new file mode 100644 index 000000000000..58d5e53f7032 --- /dev/null +++ b/src/script/ir_builder/ir/utils.h @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ +#define TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ + +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace ir { + +inline IRModuleFrame FindModuleFrame(const String& method) { + IRBuilder builder = IRBuilder::Current(); + if (Optional frame = builder->FindFrame()) { + const Optional& last_module_frame = builder->GetLastFrame(); + if (last_module_frame.defined() && last_module_frame.value() == frame) { + return frame.value(); + } + } else { + LOG(FATAL) << "ValueError: IRModule frame not find. Please ensure '" << method + << "' is called under I.ir_module()"; + } + LOG(FATAL) << "ValueError: '" << method << "' must be called immediately under I.ir_module()"; + throw; +} + +} // namespace ir +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc new file mode 100644 index 000000000000..c78b9e73c534 --- /dev/null +++ b/src/script/ir_builder/relax/frame.cc @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "./utils.h" + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +void SeqExprFrameNode::ExitWithScope() { + // At this moment, there should be at most one BlockFrame which hasn't ended. In this case, call + // its `ExitBlockFrame` and check if there is any more unended BlockFrame. + if (Optional block_frame = IRBuilder::Current()->GetLastFrame()) { + block_frame.value()->ExitWithScope(); + ICHECK(!IRBuilder::Current()->GetLastFrame().defined()) + << "ValueError: There is some remaining BlockFrame that is not properly popped out."; + } + RelaxFrameNode::ExitWithScope(); +} + +void SeqExprFrameNode::EnterWithScope() { + RelaxFrameNode::EnterWithScope(); + BindingBlock()->EnterWithScope(); +} + +void FunctionFrameNode::ExitWithScope() { + using ir::IRModuleFrame; + using tvm::relax::Expr; + IRBuilder builder = IRBuilder::Current(); + SeqExprFrameNode::ExitWithScope(); + // Step 1: Create the function. + CHECK(output.defined()) << "ValueError: A Relax function must have a return value. Please use " + "`return` to return an Expr"; + this->block_builder->BeginScope(params); + Expr body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value())); + auto dict_attrs = attrs.empty() ? NullValue() : DictAttrs(attrs); + this->block_builder->EndScope(); + tvm::relax::Function func(/*params=*/params, + /*body=*/body, + /*ret_struct_info=*/ret_struct_info, + /*attrs=*/dict_attrs); + // Step 2: Update IRModule. + if (builder->frames.empty()) { + // Case 0. No outer frame, return function directly + ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; + builder->result = func; + } else if (Optional opt_frame = builder->FindFrame()) { + // Case 1. A global function of an IRModule + CHECK(name.defined()) << "ValueError: The function name must be defined before exiting the " + "function scope, if it's defined in a Module"; + const IRModuleFrame& frame = opt_frame.value(); + const String& func_name = name.value_or(""); + if (!frame->global_var_map.count(func_name)) { + // First time visiting the function. + ir::DeclFunction(func_name, func); + } + // Define the function. + // Note we do checks to disallow redefinition of functions inside the `DefFunction`. + ir::DefFunction(func_name, func); + } else { + LOG(FATAL) << "ValueError: Cannot find where to insert Relax.Function"; + } +} + +void BlockFrameNode::EnterWithScope() { + // Step 1. If the last frame is a block frame. The start of a new block frame marks the end of the + // last block frame. + Optional block_frame = IRBuilder::Current()->GetLastFrame(); + if (block_frame.defined()) { + block_frame.value()->ExitWithScope(); + // Block frames cannot appear consecutively. + ICHECK(!IRBuilder::Current()->GetLastFrame()); + } + // Step 2. Deal with the new block frame. + RelaxFrameNode::EnterWithScope(); + Optional func_frame = IRBuilder::Current()->FindFrame(); + CHECK(func_frame.defined()) + << "ValueError: Cannot find FunctionFrame when creating BindingBlocks, Please ensure " + "creating the block under Relax function scope."; + const tvm::relax::BlockBuilder& block_builder = func_frame.value()->block_builder; + if (is_dataflow) { + block_builder->BeginDataflowBlock(); + } else { + block_builder->BeginBindingBlock(); + } +} + +class DataflowBlockRewriter : public tvm::relax::ExprMutator { + public: + static tvm::relax::DataflowBlock Rewrite(const tvm::relax::DataflowBlock& block, + const Array& output_vars) { + DataflowBlockRewriter rewriter(output_vars); + return Downcast(rewriter.VisitBindingBlock(block)); + } + + private: + explicit DataflowBlockRewriter(const Array& output_vars) { + for (const tvm::relax::Var& var : output_vars) { + output_var_set_.insert(var.get()); + } + } + + tvm::relax::Var VisitVarDef_(const tvm::relax::DataflowVarNode* op) final { + auto it = output_var_set_.find(op); + if (it != output_var_set_.end()) { + // Rewrite dataflow vars to global vars + auto n = make_object(*op); + tvm::relax::Var new_var(n); + this->var_remap_[op->vid] = new_var; + return new_var; + } else { + return GetRef(op); + } + } + + private: + std::unordered_set output_var_set_; +}; + +void BlockFrameNode::ExitWithScope() { + // Step 1. Pop the current frame out of the frame stack. + RelaxFrameNode::ExitWithScope(); + + // Step 2. Get the constructed binding block from the block builder. The block should have at + // lease one binding - otherwise, the block is not supposed to be created. + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + tvm::relax::BindingBlock block = block_builder->EndBlock(); + if (block->bindings.empty()) { + return; + } + + // Step 3. Rewrite the dataflow block. + if (is_dataflow) { + // Step 3.1. Rewrite block binding + block = DataflowBlockRewriter::Rewrite(Downcast(block), output_vars); + + // Step 3.2. Collect global vars' reference in bindings + Map new_global_vars; + for (const tvm::relax::Binding& binding : block->bindings) { + if (!binding->var->IsInstance()) { + new_global_vars.Set(binding->var->vid, binding->var); + } + } + + // Step 3.3. Rewrite output vars + Array new_output_vars; + for (const auto& var : output_vars) { + auto it = new_global_vars.find(var->vid); + ICHECK(it != new_global_vars.end()); + new_output_vars.push_back((*it).second); + } + output_vars = std::move(new_output_vars); + } + + // Step 3. Get the last frame from the IRBuilder frame stack. + Optional opt_last_frame = IRBuilder::Current()->GetLastFrame(); + ICHECK(opt_last_frame.defined()); + RelaxFrame last_frame = opt_last_frame.value(); + + // Step 4. Since we popped out any possible block frame when entering the "with" scope of the + // current frame, the last frame cannot be a block frame. + ICHECK(!last_frame->IsInstance()); + + // Step 5. Push the block frame into the corresponding field of the last frame. + if (const auto* seq_frame = last_frame.as()) { + ICHECK(!seq_frame->output.defined()) + << "The function is not expected to have output values when emitting blocks."; + auto frame = GetRef(seq_frame); + frame->binding_blocks.push_back(block); + } else { + LOG(FATAL) << "ValueError: Currently the last frame is supposed to be either a function frame " + "or a block frame. However, the last frame is \"" + << last_frame->GetTypeKey() << "\"."; + } + + // Step 6. Start another binding block when a dataflow block ended. + if (is_dataflow) { + BindingBlock()->EnterWithScope(); + } +} + +void IfFrameNode::EnterWithScope() { + const Array& frames = IRBuilder::Current()->frames; + for (const IRBuilderFrame& frame : frames) { + const auto* block_frame = frame.as(); + if (block_frame && block_frame->is_dataflow) { + LOG(FATAL) << "ValueError: Cannot create an IfFrame inside a dataflow block."; + } + } + RelaxFrameNode::EnterWithScope(); +} + +void IfFrameNode::ExitWithScope() { + RelaxFrameNode::ExitWithScope(); + CHECK(then_expr.defined()) + << "ValueError: The body of then part is expected to be defined before exiting."; + CHECK(then_expr.defined()) + << "ValueError: The body of else part is expected to be defined before exiting."; + auto body = tvm::relax::If(condition, then_expr.value(), else_expr.value()); + var = Emit(body); + IRBuilder::Name(var_name, var); +} + +void ThenFrameNode::EnterWithScope() { + IfFrame frame = FindIfFrame("R.Then"); + CHECK(!frame->then_expr.defined()) + << "ValueError: Duplicate then branch declaration, previous one is " + << frame->then_expr.value(); + SeqExprFrameNode::EnterWithScope(); +} + +void ThenFrameNode::ExitWithScope() { + SeqExprFrameNode::ExitWithScope(); + String var_name; + output = GetSeqExprForBranch(GetRef(this), &var_name); + IfFrame frame = FindIfFrame("R.Then"); + frame->then_expr = output; + frame->var_name = var_name; +} + +void ElseFrameNode::EnterWithScope() { + IfFrame frame = FindIfFrame("R.Else"); + CHECK(frame->then_expr.defined()) << "The else branch should follow then branch"; + CHECK(!frame->else_expr.defined()) + << "ValueError: Duplicate else branch declaration, previous one is " + << frame->else_expr.value(); + SeqExprFrameNode::EnterWithScope(); +} + +void ElseFrameNode::ExitWithScope() { + SeqExprFrameNode::ExitWithScope(); + String var_name; + output = GetSeqExprForBranch(GetRef(this), &var_name); + IfFrame frame = FindIfFrame("R.Else"); + frame->else_expr = output; + CHECK(frame->var_name == var_name) + << "This last binding of both branches must have the same variable."; +} + +TVM_REGISTER_NODE_TYPE(FunctionFrameNode); +TVM_REGISTER_NODE_TYPE(SeqExprFrameNode); +TVM_REGISTER_NODE_TYPE(BlockFrameNode); +TVM_REGISTER_NODE_TYPE(IfFrameNode); +TVM_REGISTER_NODE_TYPE(ThenFrameNode); +TVM_REGISTER_NODE_TYPE(ElseFrameNode); + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc new file mode 100644 index 000000000000..ece645243c82 --- /dev/null +++ b/src/script/ir_builder/relax/ir.cc @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include + +#include "./utils.h" + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +///////////////////////////////// Vars ////////////////////////////////// + +using tvm::script::ir_builder::details::Namer; + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using tvm::relax::VarNode; + using tvm::relax::IdNode; + const VarNode* var = node.as(); + IdNode* vid = const_cast(var->vid.get()); + vid->name_hint = name; + }); + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using tvm::relax::DataflowVarNode; + using tvm::relax::IdNode; + const DataflowVarNode* var = node.as(); + IdNode* vid = const_cast(var->vid.get()); + vid->name_hint = name; + }); + +/////////////////////////////// Function //////////////////////////////// + +FunctionFrame Function() { + ObjectPtr n = make_object(); + const IRBuilder& ir_builder = IRBuilder::Current(); + Optional mod = NullOpt; + if (const Optional mod_frame = ir_builder->GetLastFrame()) { + mod = tvm::IRModule(mod_frame.value()->functions); + } + n->block_builder = tvm::relax::BlockBuilder::Create(/*mod=*/mod); + return FunctionFrame(n); +} + +tvm::relax::Var Arg(const String& name, const tvm::relax::StructInfo& struct_info) { + FunctionFrame frame = FindFunctionFrame("R.Arg"); + tvm::relax::Var var(name, struct_info); + frame->params.push_back(var); + return var; +} + +void FuncName(const String& name) { + FunctionFrame frame = FindFunctionFrame("R.func_name"); + if (frame->name.defined()) { + LOG(FATAL) << "ValueError: Duplicate function name, previous one is: \"" << frame->name.value() + << "\""; + } + frame->name = name; +} + +void FuncAttrs(Map attrs) { + FunctionFrame frame = FindFunctionFrame("R.func_attr"); + if (!frame->attrs.empty()) { + LOG(FATAL) << "ValueError: Duplicate function attrs, previous one is:\n" << frame->attrs; + } + frame->attrs = attrs; +} + +void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo) { + FunctionFrame frame = FindFunctionFrame("R.func_ret_struct_info"); + if (frame->ret_struct_info.defined()) { + LOG(FATAL) << "ValueError: Duplicate function return struct info, previous one is:\n " + << frame->ret_struct_info.value(); + } + frame->ret_struct_info = ret_sinfo; +} + +void FuncRetValue(const tvm::relax::Expr& value) { + // Step 0. Normalize the value. + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + tvm::relax::Expr normalized_value = block_builder->Normalize(value); + + // Step 1. The current Relax TVMScript syntax only allows function return appearing at the end of + // a function body. Therefore if there is any unended block frame when dealing with function + // return, we should end the block frame. + Optional block_frame = IRBuilder::Current()->GetLastFrame(); + if (block_frame.defined()) { + block_frame.value()->ExitWithScope(); + ICHECK(!IRBuilder::Current()->FindFrame()) + << "All block frame are supposed to be popped out already"; + } + // Step 2. Add the output value to the function frame. + FunctionFrame frame = FindFunctionFrame("return"); + CHECK(!frame->output.defined()) + << "ValueError: Relax functions don't support multiple return statement. Please make sure " + "the return statement appears at the end of function."; + + frame->output = std::move(normalized_value); +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Function").set_body_typed(Function); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Arg").set_body_typed(Arg); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncName").set_body_typed(FuncName); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncAttrs").set_body_typed(FuncAttrs); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetStructInfo").set_body_typed(FuncRetStructInfo); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetValue").set_body_typed(FuncRetValue); + +///////////////////////////// BindingBlock ////////////////////////////// + +BlockFrame Dataflow() { + ObjectPtr n = make_object(); + n->is_dataflow = true; + n->block_ended = false; + return BlockFrame(n); +} + +BlockFrame BindingBlock() { + ObjectPtr n = make_object(); + n->is_dataflow = false; + n->block_ended = false; + return BlockFrame(n); +} + +void DataflowBlockOutput(const Array& vars) { + // Step 1. Check that we're in a Dataflow block that is not ended. + Optional block_frame = IRBuilder::Current()->GetLastFrame(); + CHECK(block_frame.defined() && block_frame.value()->is_dataflow) + << "ValueError: `R.output` should appear inside a dataflow block. However, the current " + "innermost block is not a dataflow block."; + CHECK(!block_frame.value()->block_ended) + << "ValueError: It is not allowed for a dataflow block to have multiple output operation."; + + // Step 2. Mark the block frame ended of construction, so that any followup binding after this + // mark in the dataflow block will lead to an error. + block_frame.value()->block_ended = true; + + // Step 3. All the output variables must be global variables and must be emitted by this dataflow + // block. + const Array& emitted_vars = block_frame.value()->emitted_vars; + for (const tvm::relax::Var& var : vars) { + CHECK(std::find(emitted_vars.begin(), emitted_vars.end(), var) != emitted_vars.end()) + << "ValueError: An output variable is not emitted by this dataflow block. Please make sure " + "all dataflow block output variables are emitted exactly by this block."; + block_frame.value()->output_vars.push_back(var); + } +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Dataflow").set_body_typed(Dataflow); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.BindingBlock").set_body_typed(BindingBlock); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.DataflowBlockOutput") + .set_body_typed(DataflowBlockOutput); + +/////////////////////////////// Bindings /////////////////////////////// + +tvm::relax::Var Emit(const tvm::relax::Expr& expr, + const Optional& annotate_struct_info) { + using tvm::relax::GetStructInfo; + BlockFrame block_frame = CheckBlockFrameExistAndUnended(); + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + if (annotate_struct_info.defined()) { + const auto& sinfo = annotate_struct_info.value(); + if (!expr->struct_info_.defined()) { + UpdateStructInfo(expr, sinfo); + } else { + CHECK(StructInfoBaseCheck(sinfo, GetStructInfo(expr)) != tvm::relax::BaseCheckResult::kFailL0) + << "Invalid annotation. Got rhs value struct info: " << GetStructInfo(expr) + << ", given struct info: " << sinfo; + } + } + tvm::relax::Var var = block_builder->Emit(expr); + block_frame->emitted_vars.push_back(var); + return var; +} + +tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, + const tvm::relax::StructInfo& struct_info) { + BlockFrame block_frame = CheckBlockFrameExistAndUnended(); + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + + tvm::relax::Var var = block_builder->EmitMatchCast(value, struct_info); + block_frame->emitted_vars.push_back(var); + return var; +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Emit").set_body_typed(Emit); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchCast").set_body_typed(EmitMatchCast); + +///////////////////////////// If Then Else ///////////////////////////// + +IfFrame If(tvm::relax::Expr condition) { + ObjectPtr n = make_object(); + n->condition = condition; + n->then_expr = NullOpt; + n->else_expr = NullOpt; + return IfFrame(n); +} + +ThenFrame Then() { + ObjectPtr n = make_object(); + return ThenFrame(n); +} + +ElseFrame Else() { + ObjectPtr n = make_object(); + return ElseFrame(n); +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.If").set_body_typed(If); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Then").set_body_typed(Then); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Else").set_body_typed(Else); + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm diff --git a/src/script/ir_builder/relax/utils.h b/src/script/ir_builder/relax/utils.h new file mode 100644 index 000000000000..ae91d05769bd --- /dev/null +++ b/src/script/ir_builder/relax/utils.h @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_ +#define TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_ + +#include +#include +#include + +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +inline FunctionFrame FindFunctionFrame(const String& method) { + if (Optional frame = IRBuilder::Current()->FindFrame()) { + return frame.value(); + } + LOG(FATAL) << "ValueError: Function frame not find. Please ensure '" << method + << "' is called under R.function()"; + throw; +} + +inline IfFrame FindIfFrame(const String& method) { + if (Optional frame = IRBuilder::Current()->GetLastFrame()) { + return frame.value(); + } else { + LOG(FATAL) << "ValueError: IfThenElse frame not find. Please ensure '" << method + << "' is called under R.if_()"; + } + throw; +} + +inline tvm::relax::BlockBuilder GetBlockBuilder() { + Optional frame = IRBuilder::Current()->FindFrame(); + CHECK(frame.defined()) << "ValueError: Relax Function frame not find. Please ensure " + "assignment is called under R.function()"; + return frame.value()->block_builder; +} + +inline BlockFrame CheckBlockFrameExistAndUnended() { + // We check if the current block is "ended" - if a block is ended, it is not allowed to emit new + // bindings into this block, and we should throw exceptions. + + Optional block_frame = IRBuilder::Current()->GetLastFrame(); + CHECK(block_frame.defined()) << "ValueError: Block frame not find"; + CHECK(!block_frame.value()->block_ended) + << "ValueError: New binding is not allowed after dataflow block output."; + return block_frame.value(); +} + +inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String* var_name) { + // Step 0. Check frame type + std::string method; + if (frame->IsInstance()) { + method = "R.Then"; + } else if (frame->IsInstance()) { + method = "R.Else"; + } else { + ICHECK(false) << "TypeError: Unsupported frame type: " << frame->GetTypeKey(); + } + + // Step 1. Check non-empty block and last binding is non-dataflow + CHECK(!frame->binding_blocks.empty()) + << "Empty body is not allowed for '" << method << "' statements."; + const tvm::relax::BindingBlock& last_block = frame->binding_blocks.back(); + CHECK(!last_block->bindings.empty()) << "Blocks are expected to be non-empty."; + + // Step 2. Collect body from the last binding. + tvm::relax::Expr body; + const tvm::relax::Binding& last_binding = last_block->bindings.back(); + if (const auto* var_binding = last_binding.as()) { + CHECK(!var_binding->var->IsInstance()) + << "A non-dataflow var is expected in the last binding of '" << method << "'."; + body = var_binding->value; + *var_name = var_binding->var->name_hint(); + } else if (const auto* match_cast = last_binding.as()) { + CHECK(!match_cast->var->IsInstance()) + << "A non-dataflow var is expected in the last binding of '" << method << "'."; + body = var_binding->value; + *var_name = match_cast->var->name_hint(); + } else { + ICHECK(false) << "TypeError: Unsupported binding type: " << last_binding->GetTypeKey(); + } + + // Step 3. Re-collect binding blocks to remove the last binding. + Array new_blocks(frame->binding_blocks.begin(), + frame->binding_blocks.end() - 1); + Array last_block_bindings(last_block->bindings.begin(), + last_block->bindings.end() - 1); + new_blocks.push_back(tvm::relax::BindingBlock(last_block_bindings)); + + return tvm::relax::SeqExpr(new_blocks, body); +} + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_ diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 1e63201a40dd..dd8d3c2ed3f3 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include @@ -41,9 +42,17 @@ void PrimFuncFrameNode::ExitWithScope() { ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; builder->result = func; } else if (Optional opt_frame = builder->FindFrame()) { - ir::IRModuleFrame frame = opt_frame.value(); - frame->global_vars.push_back(GlobalVar(name.value_or(""))); - frame->functions.push_back(func); + CHECK(name.defined()) << "ValueError: The function name must be defined before exiting the " + "function scope, if it's defined in a Module"; + const ir::IRModuleFrame& frame = opt_frame.value(); + const String& func_name = name.value_or(""); + if (!frame->global_var_map.count(func_name)) { + // Case. First time visiting the function. + ir::DeclFunction(func_name, func); + } + // Define the function. + // Note we do checks to disallow redefinition of functions inside the `DefFunction`. + ir::DefFunction(func_name, func); } else { LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc"; } diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h index 485757063867..e8f125adc053 100644 --- a/src/script/ir_builder/tir/utils.h +++ b/src/script/ir_builder/tir/utils.h @@ -81,7 +81,7 @@ inline PrimFuncFrame FindPrimFuncFrame(const String& method) { * \return The top frame of BlockFrame. */ inline BlockFrame FindBlockFrame(const String& method) { - if (Optional frame = IRBuilder::Current()->GetLastFrame()) { + if (Optional frame = IRBuilder::Current()->FindFrame()) { return frame.value(); } LOG(FATAL) << "ValueError: Block frame not find. Please ensure '" << method diff --git a/tests/python/relax/test_tvmscript_ir_builder.py b/tests/python/relax/test_tvmscript_ir_builder.py new file mode 100644 index 000000000000..12d8b114b862 --- /dev/null +++ b/tests/python/relax/test_tvmscript_ir_builder.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm import relax, tir +from tvm.script.ir_builder import relax as R +from tvm.script.ir_builder.base import IRBuilder + + +def test_function_simple(): + """ + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + out = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + return out + """ + # create with Script IRBuilder + with IRBuilder() as ir_builder: + with R.function(): + R.func_name("foo") + R.func_attr({"Primitive": 1}) + x = R.arg("x", relax.TensorStructInfo((128, 128), "float32")) + R.func_ret_struct_info(relax.TensorStructInfo(dtype="float32", ndim=2)) + out = R.emit( + R.call_tir("extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32")) + ) + IRBuilder.name("out", out) + R.func_ret_value(out) + func = ir_builder.get() + # create with BlockBuilder + x = relax.Var("x", relax.TensorStructInfo((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,), attrs={"Primitive": 1}): + out = bb.emit( + relax.call_tir("extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32")) + ) + bb.emit_func_output(out) + mod = bb.get() + + tvm.ir.assert_structural_equal(func, mod["foo"]) + # check names + assert func.params[0].name_hint == "x" + assert func.body.body.name_hint == "out" + + +def test_match_cast(): + """ + @R.function + def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")): + m = T.var("int64") + n = T.var("int64") + _ = R.match_cast(x, R.Tensor((m,), "float32")) + y1 = R.match_cast(x, R.Tensor((n,), "float32")) + return (m, n * 2) + """ + # create with Script IRBuilder + with IRBuilder() as ir_builder: + with R.function(): + R.func_name("foo") + x = R.arg("x", relax.TensorStructInfo(ndim=-1, dtype="float32")) + y = R.arg("y", relax.TensorStructInfo(ndim=-1, dtype="float32")) + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + _ = R.emit_match_cast(x, relax.TensorStructInfo((m,), "float32")) + y1 = R.emit_match_cast(y, relax.TensorStructInfo((n,), "float32")) + IRBuilder.name("y1", y1) + R.func_ret_value(relax.ShapeExpr([m, n * 2])) + func = ir_builder.get() + + # create with BlockBuilder + x = relax.Var("x", relax.TensorStructInfo(dtype="float32", ndim=-1)) + y = relax.Var("y", relax.TensorStructInfo(dtype="float32", ndim=-1)) + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + _ = bb.match_cast(x, relax.TensorStructInfo((m,), "float32")) + y1 = bb.match_cast(y, relax.TensorStructInfo((n,), "float32")) + bb.emit_func_output(relax.ShapeExpr([m, n * 2])) + mod = bb.get() + + tvm.ir.assert_structural_equal(func, mod["foo"]) + + +def test_dataflow_block(): + """ + @R.function + def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2): + # block 0 + with R.dataflow(): + lv0 = R.call_tir("extern_func", (x,), R.Tensor((128, 128), dtype="float32")) + gv: Tensor((128, 128), "float32") = lv0 + R.output(gv) + return gv + """ + # create with Script IRBuilder + with IRBuilder() as ir_builder: + with R.function(): + R.func_name("foo") + x = R.arg("x", relax.TensorStructInfo((128, 128), "float32")) + with R.dataflow() as df: + lv0 = R.emit( + R.call_tir( + "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32") + ) + ) + IRBuilder.name("lv0", lv0) + gv = R.emit(lv0) + IRBuilder.name("gv", gv) + R.output(gv) + (gv,) = df.output_vars + R.func_ret_value(gv) + func = ir_builder.get() + + # create with BlockBuilder + x = relax.Var("x", relax.TensorStructInfo((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + with bb.dataflow(): + lv0 = bb.emit( + relax.call_tir( + "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32") + ) + ) + gv = bb.emit_output(lv0) + bb.emit_func_output(gv) + + tvm.ir.assert_structural_equal(func, bb.get()["foo"]) + + +def test_regression_py_print(): + # Test that the py_print directs to python builtin print + from tvm.script.ir_builder.relax.ir import py_print # pylint: disable=import-outside-toplevel + + assert py_print == print + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py new file mode 100644 index 000000000000..34b02fdbb8c3 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser.py @@ -0,0 +1,1062 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import pytest +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax, tir, topi +from tvm.relax import DynTensorType +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]] = None, +): + # TODO(relax-team): enable roundtrip testing when printer is ready + # test = parsed.script(show_meta=True) + # roundtrip_mod = tvm.script.parse(test) + # tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_simple_func(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): + R.func_attr({"Primitive": 1}) + gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + return gv0 + + x = relax.Var("x", R.Tensor((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,), attrs={"Primitive": 1}): + out = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))) + bb.emit_func_output(out) + + _check(foo, bb.get()["foo"]) + + +def test_error_report(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + # error: a = b = c is not allowed. + gv0 = gv1 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + return gv0 + + +def test_mismatch_cast_dims_and_ndim(): + with pytest.raises(Exception): + + @R.function + def f( + x: R.Tensor((2, 3), "float32", ndim=3) + ): # error: ndim and the shape dims are mismatch + return x + + +def test_unexpected_num_kw_args(): + with pytest.raises(Exception): + + @R.function + def f(x: R.Tensor(dtype="float32", ndim=1, foo=2)): # error: unexpected kw args foo + return x + + +def test_unexpected_ndim(): + with pytest.raises(Exception): + + @R.function + # error: dim is expected to be non-negative int or -1 for unknown + def f(x: R.Tensor(dtype="float32", ndim=-2)): + return x + + +def test_unexpected_ndim_type(): + with pytest.raises(Exception): + + @R.function + def f(x: R.Tensor(dtype="float32", ndim="1")): # error: dim is expected to be int + return x + + +def test_unexpected_tir_cast_args(): + + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: R.Tensor(("m",), "float32")): + m = T.var("int64") + # tir.cast expects 2 arguments, but got 3 + return R.call_tir("foo", (x,), R.Tensor((T.cast("int32", m, 1),), dtype="float32")) + + +def test_unexpected_tir_max_args(): + + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: R.Tensor(("m", "n"), "float32")): + m = T.var("int64") + # tir.max expects 2 arguments, but got 1 + return relax.call_tir("foo", (x,), R.Tensor((T.max(m),), dtype="float32")) + + +def test_func_type_annotation_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x, y): # error: the parameter type annotation is missing + z = R.add(x, y) + y = z + return y + + +def test_if_mismatch_var_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + z = R.add(w, w) # error: The binding var is expected to `y` + return z + + +def test_unassigned_call_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: R.Tensor): + R.add(x, x) + return x + + +def test_simple_module(): + @I.ir_module + class TestModule: + @T.prim_func + def tir_func( + x: T.Buffer((T.int64(128), T.int64(128)), "float32"), + y: T.Buffer((T.int64(128), T.int64(128)), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i, j in T.grid(T.int64(128), T.int64(128)): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + y[vi, vj] = x[vi, vj] + 1.0 + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): + # TODO(Siyuan): Need to change to `TestModule.tir_func` + gv0 = R.call_tir(tir_func, x, R.Tensor((128, 128), dtype="float32")) + return gv0 + + x = relax.Var("x", R.Tensor((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func") + bb.emit_func_output(out) + + _check(TestModule, bb.get()) + + +def test_relax_tensor_op(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor((4, 4), "float32"): + y = R.add(x, x) + z = R.multiply(x, y) + return z + + x = relax.Var("x", R.Tensor((4, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + y = bb.emit(relax.op.add(x, x)) + z = bb.emit(relax.op.multiply(x, y)) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + +def test_symbolic_shape(): + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64", "m") + n = T.var("int64", "n") + gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32")) + return gv0 + + @R.function + def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32")) + return gv0 + + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def mismatch_dtype(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float32", ndim=2): + m = T.var("int64") + n = T.var("int32") # The shape dtype should be int64 + gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32")) + return gv0 + + def _expected(name: str): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = relax.Var("x", R.Tensor([m, n], "float32")) + bb = relax.BlockBuilder() + with bb.function(name, (x,)): + out = bb.emit(relax.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32"))) + bb.emit_func_output(out) + return bb.get()[name] + + _check(foo, _expected("foo")) + _check(bar, _expected("bar")) + + +def test_shadowing(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")): + y = R.add(x, x) + z = R.multiply(x, y) + y = R.add(x, y) + y = z + y = R.multiply(y, x) + z = y + return z + + x = relax.Var("x", R.Tensor((4, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + y = bb.emit(relax.op.add(x, x)) + z = bb.emit(relax.op.multiply(x, y)) + y = bb.emit(relax.op.add(x, y)) + y = bb.emit(z) + y = bb.emit(relax.op.multiply(y, x)) + z = bb.emit(y) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + +def test_match_cast(): + @R.function + def foo(x: R.Tensor("float32"), y: R.Tensor("float32")): + m = T.var("int64") + n = T.var("int64") + x0 = R.match_cast(x, R.Tensor([m], "float32")) + with R.dataflow(): + y0 = R.match_cast(y, R.Tensor([n], "float32")) + gv = y0 + R.output(gv) + return (x0, (m, n * 2)) + + x = relax.Var("x", R.Tensor("float32")) + y = relax.Var("y", R.Tensor("float32")) + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + y2 = relax.Var("y", R.Tensor([n], "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + x0 = bb.match_cast(x, R.Tensor([m], "float32")) + with bb.dataflow(): + y0 = bb.match_cast(y, R.Tensor([n], "float32")) + bb.emit_output(y0) + bb.emit_func_output(relax.Tuple([x0, relax.ShapeExpr([m, n * 2])])) + + _check(foo, bb.get()["foo"]) + + +def test_tuple_return(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")): + gv0 = R.call_tir("extern_func_0", x, R.Tensor((4, 4), dtype="float32")) + gv1 = R.call_tir("extern_func_1", x, R.Tensor((4, 4), dtype="float32")) + return (gv0, gv1) + + x = relax.Var("x", R.Tensor((4, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + gv0 = bb.emit(relax.call_tir("extern_func_0", x, R.Tensor((4, 4), dtype="float32"))) + gv1 = bb.emit(relax.call_tir("extern_func_1", x, R.Tensor((4, 4), dtype="float32"))) + bb.emit_func_output(relax.Tuple((gv0, gv1))) + + _check(foo, bb.get()["foo"]) + + +def test_tuple_return_2(): + @R.function + def foo(x: R.Tensor("float32", ndim=2)): + n, m = T.var("int64"), T.var("int64") + x0 = R.match_cast(x, R.Tensor((n, m), "float32")) + return (x0, (n + 1, m, 1)) + + x = relax.Var("x", R.Tensor("float32", ndim=2)) + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + x0 = bb.match_cast(x, R.Tensor((n, m), "float32")) + bb.emit_func_output(relax.Tuple([x0, relax.ShapeExpr([n + 1, m, 1])])) + + _check(foo, bb.get()["foo"]) + + +def test_tuple_binding(): + @R.function + def foo(x: R.Tensor("float32", ndim=2)): + n, m = T.var("int64"), T.var("int64") + x0 = R.match_cast(x, R.Tensor((n, m), "float32")) + t0 = (x, x0) + t1 = (x, (n, m), t0) + return t1 + + x = relax.Var("x", R.Tensor("float32", ndim=2)) + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + x0 = bb.match_cast(x, R.Tensor((n, m), "float32")) + t0 = bb.emit(relax.Tuple([x, x0])) + t1 = bb.emit(relax.Tuple([x, relax.ShapeExpr([n, m]), t0])) + bb.emit_func_output(t1) + + _check(foo, bb.get()["foo"]) + + +def test_tuple_get_item(): + @R.function + def foo(x: R.Tensor, y: R.Tensor): + t1 = R.tuple(x, y) + t2 = (x, y) + a = t1[0] + b = R.TupleGetItem(t2, 1) + c = R.add(a, b) + return c + + x = relax.Var("x", R.Tensor()) + y = relax.Var("y", R.Tensor()) + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + t1 = bb.emit(relax.Tuple([x, y])) + t2 = bb.emit(relax.Tuple([x, y])) + a = bb.emit(relax.TupleGetItem(t1, 0)) + b = bb.emit(relax.TupleGetItem(t2, 1)) + c = bb.emit(relax.op.add(a, b)) + bb.emit_func_output(c) + + _check(foo, bb.get()["foo"]) + + +def test_dataflow_block(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + lv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + lv1 = R.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) + gv = lv1 + R.output(gv) + return gv + + x = relax.Var("x", R.Tensor((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + with bb.dataflow(): + lv0 = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))) + lv1 = bb.emit(relax.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32"))) + gv = bb.emit_output(lv1) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_dataflow_block_advanced(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + gv1 = R.call_tir("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) + with R.dataflow(): + m = T.var("int64") + n = T.var("int64") + lv0 = R.call_tir("extern_func", gv1, R.Tensor((128, 128), dtype="float32")) + lv1 = R.match_cast(lv0, R.Tensor((m, n), "float32")) + gv2 = R.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) + gv2 = R.call_tir("extern_func", gv2, R.Tensor((128, 128), dtype="float32")) + gv3 = R.match_cast(gv2, R.Tensor((m, n), "float32")) + gv3 = R.match_cast(lv0, R.Tensor((m, n), "float32")) + gv4 = gv3 + gv5 = gv2 + R.output(gv5, gv4) + gv6 = R.call_tir("extern_func", gv5, R.Tensor((128, 128), dtype="float32")) + gv7 = R.call_tir("extern_func", gv6, R.Tensor((128, 128), dtype="float32")) + return gv7 + + x = relax.Var("x", R.Tensor((128, 128), "float32")) + bb = relax.BlockBuilder() + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + with bb.function("foo", (x,)): + gv0 = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))) + gv1 = bb.emit(relax.call_tir("extern_func", gv0, R.Tensor((128, 128), dtype="float32"))) + with bb.dataflow(): + lv0 = bb.emit(relax.call_tir("extern_func", gv1, R.Tensor((128, 128), dtype="float32"))) + lv1 = bb.match_cast(lv0, R.Tensor((m, n), "float32")) + gv2 = bb.emit(relax.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32"))) + gv21 = bb.emit( + relax.call_tir("extern_func", gv2, R.Tensor((128, 128), dtype="float32")) + ) + gv3 = bb.match_cast(gv21, R.Tensor((m, n), "float32")) + gv31 = bb.match_cast(lv0, R.Tensor((m, n), "float32")) + gv32 = bb.emit_output(gv31) + gv22 = bb.emit_output(gv21) + gv4 = bb.emit(relax.call_tir("extern_func", gv22, R.Tensor((128, 128), dtype="float32"))) + gv5 = bb.emit(relax.call_tir("extern_func", gv4, R.Tensor((128, 128), dtype="float32"))) + bb.emit_func_output(gv5) + + _check(foo, bb.get()["foo"]) + + +def test_dataflow_binding_after_output(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + R.output(gv) + lv = R.call_tir("extern_func", gv, R.Tensor((128, 128), dtype="float32")) + return gv + + +def test_dataflow_output_global_var(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + with R.dataflow(): + gv1 = R.call_tir("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) + R.output(gv0, gv1) + return gv1 + + +def test_dataflow_multiple_output(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + R.output(gv) + R.output(gv) + return gv + + +def test_dataflow_output_outside_dataflow_block(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + R.output(gv) + return gv + + +def test_dataflow_scope_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: R.Tensor(ndim=2)): + with R.dataflow(): + y = R.add(x, x) + z = R.multiply(y, x) + w = R.add(z, x) + R.output(y, w) + t = R.multiply(y, z) # z is not in the outer scope + return t + + +def test_return_without_binding(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")): + return x + + x = relax.Var("x", R.Tensor((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + bb.emit_func_output(x) + + _check(foo, bb.get()["foo"]) + + +def test_multiple_return(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")): + return x + return x + + +def test_function_without_return(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")): + gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + + +def test_tensor_type_without_args(): + @R.function + def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + v = R.call_tir("tir_relu", x, R.Tensor((32, 32), dtype="float32")) + return v + + x = relax.Var("x", R.Tensor((32, 32), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x)): + v = bb.emit(relax.call_tir("tir_relu", x, R.Tensor((32, 32), dtype="float32"))) + bb.emit_func_output(v) + + _check(foo, bb.get()["foo"]) + + +def test_direct_return(): + @R.function + def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + return x + + x = relax.Var("x", R.Tensor((32, 32), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x)): + bb.emit_func_output(x) + + _check(foo, bb.get()["foo"]) + + +def test_call_packed(): + @R.function + def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + z = R.call_packed("vm.builtin.copy", x, sinfo_args=R.Tensor((32, 32), "float32")) + return z + + x = relax.Var("x", R.Tensor((32, 32), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x)): + z = bb.emit( + relax.Call( + relax.ExternFunc("vm.builtin.copy"), + (x,), + None, + sinfo_args=[R.Tensor((32, 32), "float32")], + ) + ) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + +def test_annotation(): + @R.function + def foo( + x: R.Tensor((32, "m"), "float32"), + y: R.Tensor(("m",), "float32"), + r: R.Tensor(dtype="int64"), + ) -> R.Object: + m = T.var("int64", "m") + z: R.Tensor((32, m), "float32") = R.multiply(x, y) + w: R.Tensor = R.multiply(z, z) + q: R.Tensor(ndim=2) = R.add(w, w) + t = R.add(w, z) + sh: R.Shape = R.call_packed("shape_of", x, sinfo_args=R.Shape) + o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, sinfo_args=R.Object) + return o + + def _check_struct_info(binding, expected_sinfo): + tvm.ir.assert_structural_equal(binding.var.struct_info, expected_sinfo) + tvm.ir.assert_structural_equal(binding.value.struct_info, expected_sinfo) + + # Cannot use block builder here because we need to check the annotated type, + # which may be inconsistent with deduced type. + assert isinstance(foo.ret_struct_info, relax.ObjectStructInfo) + m = relax.get_shape_of(foo.params[0])[1] + bindings = foo.body.blocks[0].bindings + + _check_struct_info(bindings[0], relax.TensorStructInfo([32, m], "float32")) + _check_struct_info(bindings[1], relax.TensorStructInfo(dtype="", ndim=-1)) + _check_struct_info(bindings[2], relax.TensorStructInfo(dtype="", ndim=2)) + _check_struct_info(bindings[3], relax.TensorStructInfo(dtype="", ndim=-1)) + _check_struct_info(bindings[4], relax.ShapeStructInfo(ndim=-1)) + _check_struct_info(bindings[5], relax.ObjectStructInfo()) + + +def test_annotate_override(): + @R.function + def foo(x: R.Tensor): + y = x + # z will be treated as object type even though it's a tensor + z: R.Object = R.add(x, y) + return z + + assert isinstance(foo.ret_struct_info, relax.ObjectStructInfo) + y_bind, z_bind = foo.body.blocks[0].bindings + assert isinstance(y_bind.var.struct_info, relax.TensorStructInfo) + assert isinstance(z_bind.var.struct_info, relax.ObjectStructInfo) + + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def test(x: R.Tensor): + # Error: x is of Tensor StructInfo, which can not annotate to R.Shape. + z: R.Shape = x + return z + + @R.function + def bar(x: R.Tensor): + # x is of Tensor StructInfo, the annotation of `z` is ignored. + z: R.Object = x + return z + + assert isinstance(bar.ret_struct_info, relax.TensorStructInfo) + (z_bind,) = bar.body.blocks[0].bindings + assert isinstance(z_bind.var.struct_info, relax.TensorStructInfo) + + +def test_call_tir_empty_shape(): + @R.function + def foo(x: R.Tensor((), "float32")): + z = R.call_tir("scalar_add", x, R.Tensor((), dtype="float32")) + return z + + (z_bind,) = foo.body.blocks[0].bindings + shape_expr = z_bind.value.sinfo_args[0].shape + + assert isinstance(shape_expr, relax.ShapeExpr) + assert len(shape_expr.values) == 0 + + +def test_call_tir_empty_tuple_arg(): + bb = relax.BlockBuilder() + dummy_param = relax.Var("dummy_param", R.Tensor(())) + with bb.function("foo", [dummy_param]): + output = bb.emit_te(topi.full, shape=(16, 32), dtype="float32", fill_value=1.0) + bb.emit_func_output(output) + + _check(bb.get()) + + +def test_call_tir_with_tir_var(): + @I.ir_module + class Module: + @R.function + def main( + dumb_param: R.Tensor(("n",), "float32"), x: R.Tensor(("n * 2", "float32")) + ) -> R.Tensor(("n * 2",), "float32"): + n = T.var("int64") + y = R.call_tir(copy, (x,), R.Tensor(((n * 2,)), dtype="float32"), tir_vars=(n,)) + return y + + @T.prim_func + def copy(var_x: T.handle, var_y: T.handle, n: T.int64): + X = T.match_buffer(var_x, (n * 2,), dtype="float32") + Y = T.match_buffer(var_y, (n * 2,), dtype="float32") + for i in T.grid(n * 2): + with T.block("block"): + vi = T.axis.remap("S", [i]) + Y[vi] = X[vi] + + _check(Module) + + +def test_local_function(): + @R.function + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + @R.function + def outer_func( + c1: R.Tensor((2, 3), "float32") + ) -> R.Callable((R.Tensor(None, "float32", ndim=2),), R.Tensor(None, "float32", ndim=2)): + @R.function + def inner_func(x1: R.Tensor((2, 3), "float32")): + s: R.Tensor((2, 3), "float32") = R.add(x1, c1) + return s + + return inner_func + + in_call = outer_func(x) + res = in_call(y) + return res + + main_bindings = main.body.blocks[0].bindings + assert len(main_bindings) == 3 + outer_func = main_bindings[0].value + assert isinstance(outer_func, relax.Function) + + outer_func_bindings = outer_func.body.blocks[0].bindings + assert len(outer_func_bindings) == 1 + inner_func = outer_func_bindings[0].value + assert isinstance(inner_func, relax.Function) + + @I.ir_module + class TestModule: + @R.function + def f(x: R.Tensor((128, 128), "float32"), y: R.Tensor((128, 128), "float32")): + @T.prim_func + def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) + + for i, j, k in T.grid(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] += A[vi, vk] * B[vj, vk] + + z = relax.call_tir(my_matmul, (x, y), R.Tensor((128, 128), dtype="float32")) + return z + + bindings = TestModule["f"].body.blocks[0].bindings + assert len(bindings) == 2 + tir_func = bindings[0].value + assert isinstance(tir_func, tir.PrimFunc) + + +def test_cross_function_call(): + @I.ir_module + class Mod0: + @R.function + def foo(x: R.Tensor((10, 5), "float32")): + s = R.add(x, x) + return s + + @R.function + def main(x: R.Tensor((10, 5), "float32")): + inner = foo + gv1 = inner(x) + gv2 = foo(x) + return (inner, gv1, gv2) + + @I.ir_module + class Mod1: + @R.function + def main(x: R.Tensor((10, 5), "float32")): + inner = foo + gv1 = inner(x) + gv2 = foo(x) + return (inner, gv1, gv2) + + @R.function + def foo(x: R.Tensor((10, 5), "float32")) -> R.Tensor((10, 5), "float32"): + s = R.add(x, x) + return s + + +def test_if_branch(): + @R.function + def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")) -> R.Tensor((1,), "float32"): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + y = R.add(w, w) + return y + + cond, x = foo.params + y_bind = foo.body.blocks[0].bindings[0] + y, ite = y_bind.var, y_bind.value + + assert isinstance(y, relax.Var) + assert y.name_hint == "y" + + assert isinstance(ite, relax.If) + assert isinstance(ite.true_branch, relax.SeqExpr) + assert isinstance(ite.false_branch, relax.SeqExpr) + + def check_call(call, op, args): + assert isinstance(call, relax.Call) + if isinstance(op, str): + assert call.op.name == op + else: + assert call.op == op + tvm.ir.assert_structural_equal(call.args, args) + + w_bind = ite.true_branch.blocks[0].bindings[0] + # the seq exprts in the branches are normalized to bind any call + # in the seq expr "body" to a var + y_bind = ite.true_branch.blocks[-1].bindings[-1] + assert w_bind.var.name_hint == "w" + check_call(w_bind.value, "relax.add", [x, x]) + check_call(y_bind.value, "relax.multiply", [w_bind.var, w_bind.var]) + + w_bind = ite.false_branch.blocks[0].bindings[0] + y_bind = ite.false_branch.blocks[-1].bindings[-1] + assert w_bind.var.name_hint == "w" + check_call(w_bind.value, "relax.multiply", [x, x]) + check_call(y_bind.value, "relax.add", [w_bind.var, w_bind.var]) + + +def test_if_inside_dataflow(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")): + with R.dataflow(): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + y = R.add(w, w) + R.output(y) + return y + + +def test_var_if_scoping_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + y = R.add(w, w) + return w # error: The w is not defined in the outer scope + + +def test_if_branch_var_scope(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + y = R.add(w, w) + return w + + +def test_erase_to_well_defined(): + @R.function + def foo(x: R.Tensor): + q = x + m, n = T.var("int64"), T.var("int64") + z = R.match_cast(q, R.Tensor((m, n))) + w = z + return w + + tvm.ir.assert_structural_equal(foo.ret_struct_info, R.Tensor(ndim=2)) + _check(foo) + + +def test_empty_tuple(): + @R.function + def foo(x: R.Tuple()): + y: R.Tuple() = R.tuple() + return y + + x = relax.Var("x", relax.TupleStructInfo([])) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + y = bb.emit(relax.Tuple([])) + bb.emit_func_output(y) + + _check(foo, bb.get()["foo"]) + + +def test_symbolic_shape_computing(): + # Tensor Case 1 + @R.function + def foo(x: R.Tensor(("m + 1",), "float32"), y: R.Tensor(("m", 1), "float32")): + z = R.add(x, y) + return z + + m = tir.Var("m", "int64") + x = relax.Var("x", relax.TensorStructInfo([m + 1], "float32")) + y = relax.Var("y", relax.TensorStructInfo([m, 1], "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + z = bb.emit(relax.op.add(x, y)) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + # Tensor Case 2 + @R.function + def bar( + x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32") + ) -> R.Tensor(("T.max(m, 20) + 1",), "float32"): + m = T.var("int64") + z = R.call_tir("test_intrin", (x, y), R.Tensor((T.max(m, 20) + 1,), dtype="float32")) + return z + + m = tir.Var("m", "int64") + x = relax.Var("x", relax.TensorStructInfo([m], "float32")) + y = relax.Var("y", relax.TensorStructInfo([tir.max(m, 20)], "float32")) + bb = relax.BlockBuilder() + with bb.function("bar", (x, y)): + z = bb.emit( + relax.call_tir("test_intrin", (x, y), R.Tensor((tir.max(m, 20) + 1,), dtype="float32")) + ) + bb.emit_func_output(z) + + _check(bar, bb.get()["bar"]) + + # Shape Case + @R.function + def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")): + m = T.var("int64") + z = R.call_tir("test_intrin", y, R.Tensor((m * 2,), dtype="float32")) + return z + + m = tir.Var("m", "int64") + x = relax.Var("x", relax.ShapeStructInfo([m])) + y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32")) + bb = relax.BlockBuilder() + with bb.function("baz", (x, y)): + z = bb.emit(relax.call_tir("test_intrin", (y), R.Tensor((m * 2,), dtype="float32"))) + bb.emit_func_output(z) + + _check(baz, bb.get()["baz"]) + + # Error Case + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor(("m + 1", "m * 2"), "float32")): # name 'm' is not defined + z = R.add(x, x) + return z + + +# TODO(relax-team): enable this when vm ops are ready +@pytest.mark.xfail +def test_vm_ops(): + @R.function + def foo(x: R.Tensor(("m", "n"), dtype="float32")): + m = T.var("int64") + n = T.var("int64") + storage = R.vm.alloc_storage((4 * m * n,), dtype="float32", runtime_device_index=0) + alloc = R.vm.alloc_tensor(storage, (m, n), offset=0, dtype="float32") + tensor = R.builtin.alloc_tensor((m, n), dtype="float32", runtime_device_index=0) + _ = R.vm.call_tir_dyn("te_func", (x, tensor, (m, n))) + gv = tensor + return alloc, gv + + +def test_prim_value(): + @R.function + def foo(): + gv = R.call_packed("test", 1, sinfo_args=R.Tensor((32, 32), "float32")) + return gv + + _check(foo) + + +def test_string_imm(): + @R.function + def foo(): + gv = R.call_packed("test", "hello", sinfo_args=R.Tensor((32, 32), "float32")) + return gv + + _check(foo) + + +def test_datatype_imm(): + @R.function + def foo(): + gv = R.call_packed("test", R.dtype("float32"), sinfo_args=R.Tensor((32, 32), "float32")) + return gv + + _check(foo) + + +def test_function_void_return_type(): + @tvm.script.ir_module + class Foo: + @R.function + def main(x: R.Tensor((3, 3), dtype="float32")): + res = mul(x) + return res + + @R.function + def mul(x: R.Tensor((3, 3), dtype="float32")): + res = R.multiply(x, x) + return res + + _check(Foo) + # Since the return type of function `mul` is not annotated, + # the function `main` regards it as a generic return type. + assert isinstance(Foo["main"].ret_struct_info, relax.ObjectStructInfo) + assert isinstance(Foo["mul"].ret_struct_info, relax.TensorStructInfo) + + @tvm.script.ir_module + class Bar: + @R.function + def main(x1: R.Tensor((3, 3), dtype="float32")): + res1 = mul(x1) + return res1 + + @R.function + def mul(x: R.Tensor((3, 3), dtype="float32")) -> None: + res = R.multiply(x, x) + return res + + # Since the return type of function `mul` is not annotated, + # the function `main` regards it as a generic return type. + _check(Bar) + tvm.ir.assert_structural_equal(Bar["main"].ret_struct_info, relax.TupleStructInfo([])) + tvm.ir.assert_structural_equal(Bar["mul"].ret_struct_info, relax.TupleStructInfo([])) + + +def test_class_normalize(): + @tvm.script.ir_module + class InputModule: + @R.function + def mul_add(x: R.Tensor) -> R.Tensor: + return R.multiply(R.add(x, x), R.add(x, x)) + + # The parser automatically normalizes the input AST to the following ANF form + @tvm.script.ir_module + class OutputModule: + @R.function + def mul_add(x: R.Tensor) -> R.Tensor: + gv = R.add(x, x) + gv1 = R.add(x, x) + return R.multiply(gv, gv1) + + _check(InputModule, OutputModule) + + +if __name__ == "__main__": + test_cross_function_call() + tvm.testing.main() From 409bf916ece3ace288942938811097b436909844 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 10 Feb 2023 07:41:11 -0800 Subject: [PATCH 08/81] [Unity] Relax TVMScript Printer (#13944) This PR introduces Relax as a dialect supported by the TVMScript Printer. Some caveats: - Needs to rebase to mainline before merging. - Some tests are skiped because some operators are not upstreamed to the unity branch yet. Co-authored-by: Tianqi Chen Co-authored-by: Yuchen Jin Co-authored-by: Steven S. Lyubomirsky Co-authored-by: Yong Wu Co-authored-by: Prakalp Srivastava Co-authored-by: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Siyuan Feng --- python/tvm/relax/expr.py | 53 +- python/tvm/script/parser/core/entry.py | 1 - src/relax/ir/expr.cc | 45 -- src/relax/ir/struct_info.cc | 43 -- src/script/printer/relax/binding.cc | 87 ++++ src/script/printer/relax/call.cc | 212 ++++++++ src/script/printer/relax/expr.cc | 136 +++++ src/script/printer/relax/function.cc | 78 +++ src/script/printer/relax/region.cc | 100 ++++ src/script/printer/relax/struct_info.cc | 149 ++++++ src/script/printer/relax/tir.cc | 89 ++++ src/script/printer/relax/type.cc | 89 ++++ src/script/printer/relax/utils.h | 101 ++++ .../relax/test_tvmscript_printer_relax.py | 488 ++++++++++++++++++ 14 files changed, 1541 insertions(+), 130 deletions(-) create mode 100644 src/script/printer/relax/binding.cc create mode 100644 src/script/printer/relax/call.cc create mode 100644 src/script/printer/relax/expr.cc create mode 100644 src/script/printer/relax/function.cc create mode 100644 src/script/printer/relax/region.cc create mode 100644 src/script/printer/relax/struct_info.cc create mode 100644 src/script/printer/relax/tir.cc create mode 100644 src/script/printer/relax/type.cc create mode 100644 src/script/printer/relax/utils.h create mode 100644 tests/python/relax/test_tvmscript_printer_relax.py diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 138724ed0693..f1cf815d8ea5 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -22,16 +22,18 @@ from typing import Any, Callable, Dict, List, Optional, Union import numpy as _np # type: ignore + import tvm import tvm._ffi -import tvm.relax import tvm.ir +import tvm.relax from tvm import DataType from tvm._ffi import base as _base -from tvm.runtime import ndarray as _nd, Object +from tvm.runtime import Object +from tvm.runtime import ndarray as _nd from ..ir import BaseFunc, Node, SourceName, Span -from ..runtime import String +from ..runtime import Scriptable, String from ..tir import PrimExpr from . import _ffi_api @@ -55,7 +57,7 @@ def __init__(self): # NOTE: place base struct info in expr to avoid cyclic dep # from expr to struct info. -class StructInfo(Node): +class StructInfo(Node, Scriptable): """The base class of all StructInfo. StructInfo contains both the static type @@ -110,7 +112,7 @@ def _binary_rhs_helper(rhs: "ExprWithOp") -> "ExprWithOp": raise TypeError(f"type {type(rhs)} not supported") -class ExprWithOp(Expr): +class ExprWithOp(Expr, Scriptable): """Basetype of all relax expressions that defines op overloading.""" def astype(self, dtype: Union[str, DataType]) -> "ExprWithOp": @@ -436,7 +438,7 @@ def __init__( @tvm._ffi.register_object("relax.expr.PrimValue") -class PrimValue(Expr): +class PrimValue(Expr, Scriptable): """The prim expr representing the value.""" value: PrimExpr @@ -448,7 +450,7 @@ def __init__(self, value: Union[PrimExpr, int], span: Span = None) -> None: @tvm._ffi.register_object("relax.expr.StringImm") -class StringImm(Expr): +class StringImm(Expr, Scriptable): """Represent a string literal constant.""" value: str @@ -458,7 +460,7 @@ def __init__(self, value: str, span: Span = None) -> None: @tvm._ffi.register_object("relax.expr.DataTypeImm") -class DataTypeImm(Expr): +class DataTypeImm(Expr, Scriptable): """Represent a data type constant.""" value: DataType @@ -468,11 +470,9 @@ def __init__(self, value: Union[DataType, str], span: Span = None) -> None: @tvm._ffi.register_object("relax.expr.Binding") -class Binding(Node): +class Binding(Node, Scriptable): """The base class of a binding in Relax.""" - ... - @tvm._ffi.register_object("relax.expr.MatchCast") class MatchCast(Binding): @@ -548,7 +548,7 @@ def __init__(self, blocks: List[BindingBlock], body: Expr, span: Span = None) -> @tvm._ffi.register_object("relax.expr.Function") -class Function(BaseFunc): +class Function(BaseFunc, Scriptable): """A Relax function.""" params: List[Var] @@ -588,35 +588,6 @@ def __call__(self, *args): """ return Call(self, args, None, None) - def script(self, show_meta: bool = False) -> str: - """Print relax.Function into TVMScript - - Parameters - ---------- - show_meta : bool - Whether to show meta information - - Returns - ------- - script : str - The TVM Script of the relax.Function - """ - return tvm._ffi.get_global_func("script.AsRelaxScript")(self, show_meta) # type: ignore - - def show(self, style: str = "light") -> None: - """ - A sugar for print highlighted TVM script. - - Parameters - ---------- - style : str, optional - Pygments styles extended by "light" (default) and "dark", by default "light" - """ - from tvm.script.highlight import cprint # pylint: disable=import-outside-toplevel - - # Use deferred import to avoid circular import while keeping cprint under tvm/script - cprint(self, style=style) - @tvm._ffi.register_object("relax.expr.ExternFunc") class ExternFunc(BaseFunc): diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 3c01b54a9f1a..d8a11f5b462a 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -54,7 +54,6 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) "tir": tir, "relax": relax, "R": relax, - "tvm": tvm, } source = Source(program) diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 45868a488a36..a0aaea886ddc 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -94,13 +94,6 @@ TVM_REGISTER_GLOBAL("relax.Call") .set_body_typed([](Expr op, Array args, Attrs attrs, Array sinfo_args, Span span) { return Call(op, args, attrs, sinfo_args, span); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "CallNode(" << node->op << ", " << node->args << ", " << node->attrs << ", " - << node->sinfo_args << ")"; - }); - If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { ObjectPtr n = make_object(); n->cond = std::move(cond); @@ -137,13 +130,6 @@ TVM_REGISTER_GLOBAL("relax.If") return If(cond, true_branch, false_branch, span); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "IfNode(" << node->cond << ", " << node->true_branch << ", " - << node->false_branch << ")"; - }); - Tuple::Tuple(tvm::Array fields, Span span) { ObjectPtr n = make_object(); n->fields = std::move(fields); @@ -179,12 +165,6 @@ Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional o return tuple; } -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "Tuple(" << node->fields << ")"; - }); - TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { ObjectPtr n = make_object(); n->tuple = std::move(tuple); @@ -216,12 +196,6 @@ TVM_REGISTER_GLOBAL("relax.TupleGetItem").set_body_typed([](Expr tuple, int inde return TupleGetItem(tuple, index); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; - }); - TVM_REGISTER_NODE_TYPE(ShapeExprNode); ShapeExpr::ShapeExpr(Array values, Span span) { @@ -245,19 +219,6 @@ TVM_REGISTER_GLOBAL("relax.ShapeExpr").set_body_typed([](Array values, return ShapeExpr(values, span); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - const ShapeExprNode* node = static_cast(ref.get()); - p->stream << "ShapeExpr("; - for (auto it = node->values.begin(); it != node->values.end(); it++) { - if (it != node->values.begin()) { - p->stream << ", "; - } - p->stream << *it; - } - p->stream << ")"; - }); - TVM_REGISTER_NODE_TYPE(VarNode); Var::Var(Id vid, Optional struct_info_annotation, Span span) { @@ -572,12 +533,6 @@ TVM_REGISTER_GLOBAL("relax.ExternFunc").set_body_typed([](String global_symbol, return ExternFunc(global_symbol, span); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - const auto* node = static_cast(ref.get()); - p->stream << "ExternFunc(\"" << node->global_symbol << "\")"; - }); - Expr GetShapeOf(const Expr& expr) { // default case, to be normalized. ICHECK(expr->struct_info_.defined()) << "GetShapeOf can only be applied to normalized expr"; diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc index 9db7cea6725d..4004ad28d560 100644 --- a/src/relax/ir/struct_info.cc +++ b/src/relax/ir/struct_info.cc @@ -41,11 +41,6 @@ TVM_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) { return ObjectStructInfo(span); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - p->stream << "ObjectStructInfo()"; - }); - // Prim PrimStructInfo::PrimStructInfo(DataType dtype, Span span) { ObjectPtr n = make_object(); @@ -60,12 +55,6 @@ TVM_REGISTER_GLOBAL("relax.PrimStructInfo").set_body_typed([](DataType dtype, Sp return PrimStructInfo(dtype, span); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - const auto* node = static_cast(ref.get()); - p->stream << "PrimStructInfo(" << node->dtype << ")"; - }); - // Shape ShapeStructInfo::ShapeStructInfo(Array values, Span span) { ObjectPtr n = make_object(); @@ -102,16 +91,6 @@ TVM_REGISTER_GLOBAL("relax.ShapeStructInfo") } }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - const auto* node = static_cast(ref.get()); - if (node->values.defined()) { - p->stream << "ShapeStructInfo(" << node->values.value() << ")"; - } else { - p->stream << "ShapeStructInfo(ndim=" << node->ndim << ")"; - } - }); - // Tensor TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Span span) { ObjectPtr n = make_object(); @@ -150,16 +129,6 @@ TVM_REGISTER_GLOBAL("relax.TensorStructInfo") } }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - const auto* node = static_cast(ref.get()); - if (node->shape.defined()) { - p->stream << "TensorStructInfo(" << node->shape.value() << ", " << node->dtype << ")"; - } else { - p->stream << "TensorStructInfo(" << node->dtype << ", ndim=" << node->ndim << ")"; - } - }); - // Tuple TupleStructInfo::TupleStructInfo(Array fields, Span span) { ObjectPtr n = make_object(); @@ -175,12 +144,6 @@ TVM_REGISTER_GLOBAL("relax.TupleStructInfo") return TupleStructInfo(fields, span); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - const auto* node = static_cast(ref.get()); - p->stream << "TupleStructInfo(" << node->fields << ")"; - }); - // Func FuncStructInfo::FuncStructInfo(Array params, StructInfo ret, Span span) { ObjectPtr n = make_object(); @@ -223,12 +186,6 @@ TVM_REGISTER_GLOBAL("relax.FuncStructInfoOpaqueFunc") } }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - const auto* node = static_cast(ref.get()); - p->stream << "FuncStructInfo(" << node->params << ", " << node->ret << ")"; - }); - // Helper functions void UpdateStructInfo(Expr expr, StructInfo struct_info) { ICHECK(!expr->struct_info_.defined()) diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc new file mode 100644 index 000000000000..8a50fe969850 --- /dev/null +++ b/src/script/printer/relax/binding.cc @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, const IRDocsifier& d, // + const Optional& var, const Optional& ann) { + using relax::SeqExpr; + ExprDoc cond = d->AsDoc(n->cond, n_p->Attr("cond")); + std::vector> branches{ + PrintSeqExpr(Downcast(n->true_branch), n_p->Attr("true_branch"), d, false), + PrintSeqExpr(Downcast(n->false_branch), n_p->Attr("false_branch"), d, false), + }; + if (var.defined()) { + for (Array& stmts : branches) { + ExprDoc ret = Downcast(stmts.back())->expr; + stmts.Set(stmts.size() - 1, AssignDoc(var.value(), ret, ann)); + } + } + return IfDoc(cond, branches[0], branches[1]); +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](relax::MatchCast n, ObjectPath n_p, IRDocsifier d) -> Doc { + using relax::StructInfo; + using relax::MatchStructInfo; + Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + ExprDoc rhs = Relax(d, "match_cast") + ->Call({d->AsDoc(n->value, n_p->Attr("value")), + d->AsDoc(n->struct_info, n_p->Attr("struct_info_"))}); + ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); + return AssignDoc(lhs, rhs, ann); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::VarBinding n, ObjectPath n_p, IRDocsifier d) -> Doc { + if (const auto if_ = n->value.as()) { + Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); + return PrintIfExpr(GetRef(if_), n_p->Attr("value"), d, lhs, ann); + } else if (n->value->IsInstance()) { + IdDoc lhs = DefineVar(n->var, d->frames.back(), d); + d->cfg->binding_names.push_back(lhs->name); + Doc ret = d->AsDoc(n->value, n_p->Attr("value")); + d->cfg->binding_names.pop_back(); + return ret; + } else { + ExprDoc rhs = d->AsDoc(n->value, n_p->Attr("value")); + Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); + ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); + return AssignDoc(lhs, rhs, ann); + } + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](relax::If n, ObjectPath n_p, IRDocsifier d) -> Doc { + return PrintIfExpr(n, n_p, d, NullOpt, NullOpt); + }); + +TVM_SCRIPT_REPR(relax::MatchCastNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::VarBindingNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::IfNode, ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc new file mode 100644 index 000000000000..2feb2082c510 --- /dev/null +++ b/src/script/printer/relax/call.cc @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +class AttrPrinter : public tvm::AttrVisitor { + public: + explicit AttrPrinter(const ObjectPath& p, const IRDocsifier& d, Array* keys, + Array* values) + : p(p), d(d), keys(keys), values(values) {} + + void Visit(const char* key, double* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::Float(*value, p->Attr(key))); + } + + void Visit(const char* key, int64_t* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::Int(*value, p->Attr(key))); + } + + void Visit(const char* key, uint64_t* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::Int(*value, p->Attr(key))); + } + + void Visit(const char* key, int* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::Int(*value, p->Attr(key))); + } + + void Visit(const char* key, bool* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::Boolean(*value, p->Attr(key))); + } + + void Visit(const char* key, std::string* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::Str(*value, p->Attr(key))); + } + + void Visit(const char* key, DataType* value) final { + keys->push_back(key); + values->push_back(LiteralDoc::DataType(*value, p->Attr(key))); + } + + void Visit(const char* key, runtime::ObjectRef* value) final { + keys->push_back(key); + values->push_back(d->AsDoc(*value, p->Attr(key))); + } + + void Visit(const char* key, void** value) final { + LOG(FATAL) << "TypeError: void is not allowed in Attrs"; + } + + void Visit(const char* key, runtime::NDArray* value) final { + LOG(FATAL) << "TypeError: NDArray is not allowed in Attrs"; + } + + const ObjectPath& p; + const IRDocsifier& d; + Array* keys; + Array* values; +}; + +ExprDoc PrintCallee(const relax::Expr& n, const ObjectPath& n_p, const IRDocsifier& d) { + // TODO(@junrushao): handle callee better + if (const auto* ext = n.as()) { + return LiteralDoc::Str(ext->global_symbol, n_p); + } else if (const auto* gv = n.as()) { + IdDoc callee(gv->name_hint); + callee->source_paths.push_back(n_p); + return callee; + } else { + return d->AsDoc(n, n_p); + } +} + +Optional PrintCallTIR(const relax::Call& n, const ObjectPath& n_p, const IRDocsifier& d) { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + if (!n->op.same_as(call_tir_op)) { + return NullOpt; + } + ICHECK(n->args.size() == 2 || n->args.size() == 3); + ICHECK(n->sinfo_args.size() == 1); + Array args; + Array kwargs_keys; + Array kwargs_values; + // Step 1. Print n->args[0], the callee + args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayIndex(0), d)); + // Step 2. Print n->args[1], the input arguments + args.push_back(d->AsDoc(n->args[1], n_p->Attr("args")->ArrayIndex(1))); + // Step 3. Print n->sinfo_args, the output struct info + relax::StructInfo o_sinfo = n->sinfo_args[0]; + ObjectPath o_sinfo_p = n_p->Attr("sinfo_args")->ArrayIndex(0); + kwargs_keys.push_back("out_sinfo"); + if (const auto* o = o_sinfo.as()) { + Array fields; + ObjectPath fields_p = o_sinfo_p->Attr("fields"); + for (int i = 0, l = o->fields.size(); i < l; ++i) { + fields.push_back(d->AsDoc(o->fields[i], fields_p->ArrayIndex(i))); + } + kwargs_values.push_back(ListDoc(fields)); + } else { + kwargs_values.push_back(d->AsDoc(o_sinfo, o_sinfo_p)); + } + // Step 4. Print n->args[2], the tir variables + if (n->args.size() == 3) { + kwargs_keys.push_back("tir_vars"); + kwargs_values.push_back(d->AsDoc(n->args[2], n_p->Attr("args")->ArrayIndex(2))); + } + return Relax(d, "call_tir")->Call(args, kwargs_keys, kwargs_values); +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::Call n, ObjectPath n_p, IRDocsifier d) -> Doc { + // Special case: call_tir + if (Optional doc = PrintCallTIR(n, n_p, d)) { + return doc.value(); + } + ExprDoc prefix{nullptr}; + Array args; + Array kwargs_keys; + Array kwargs_values; + // Step 1. Print op + if (const auto* op = n->op.as()) { + prefix = Relax(d, "call_packed"); + args.push_back(LiteralDoc::Str(op->global_symbol, n_p->Attr("op"))); + } else if (const auto* op = n->op.as()) { + prefix = IdDoc(op->name_hint); + prefix->source_paths.push_back(n_p->Attr("op")); + } else if (const auto* op = n->op.as()) { + std::string name = op->name; + if (name.rfind("relax.", 0) == 0) { + prefix = Relax(d, name.substr(6)); + } else { + prefix = IdDoc(name); + } + prefix->source_paths.push_back(n_p->Attr("op")); + } else if (n->op->IsInstance()) { + prefix = d->AsDoc(n->op, n_p->Attr("op")); + } else { + LOG(FATAL) << "TypeError: Unsupported op: " << n->op->GetTypeKey(); + } + // Step 2. Print args + if (!n->args.empty()) { + args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayIndex(0), d)); + } + for (int i = 1, l = n->args.size(); i < l; ++i) { + args.push_back(d->AsDoc(n->args[i], n_p->Attr("args")->ArrayIndex(i))); + } + // Step 3. Print attrs + if (n->attrs.defined()) { + if (n->op->IsInstance()) { + kwargs_keys.push_back("attrs_type_key"); + kwargs_values.push_back(LiteralDoc::Str(n->attrs->GetTypeKey(), n_p->Attr("attrs"))); + } + if (const auto* attrs = n->attrs.as()) { + std::vector> sorted; + for (const auto& kv : attrs->dict) { + sorted.push_back(kv); + } + std::sort(sorted.begin(), sorted.end()); + for (const auto& kv : sorted) { + kwargs_keys.push_back(kv.first); + kwargs_values.push_back( + d->AsDoc(kv.second, n_p->Attr("attrs")->Attr(kv.first))); + } + } else { + AttrPrinter printer(n_p->Attr("attrs"), d, &kwargs_keys, &kwargs_values); + const_cast(n->attrs.get())->VisitAttrs(&printer); + } + } + // Step 4. Print type_args + if (n->sinfo_args.size() > 0) { + ObjectPath sinfo_args_p = n_p->Attr("sinfo_args"); + Array sinfo_args; + for (int i = 0, l = n->sinfo_args.size(); i < l; ++i) { + sinfo_args.push_back( + d->AsDoc(n->sinfo_args[i], sinfo_args_p->ArrayIndex(i))); + } + kwargs_keys.push_back("sinfo_args"); + kwargs_values.push_back(TupleDoc(sinfo_args)); + } + return prefix->Call(args, kwargs_keys, kwargs_values); + }); + +TVM_SCRIPT_REPR(relax::CallNode, ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/expr.cc b/src/script/printer/relax/expr.cc new file mode 100644 index 000000000000..a786932fc3d9 --- /dev/null +++ b/src/script/printer/relax/expr.cc @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::PrimValue n, ObjectPath n_p, IRDocsifier d) -> Doc { + // TODO(@junrushao): float numbers + return Relax(d, "prim_value")->Call({d->AsDoc(n->value, n_p->Attr("value"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::StringImm n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "str")->Call({LiteralDoc::Str(n->value, n_p->Attr("value"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::DataTypeImm n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "dtype")->Call({LiteralDoc::DataType(n->value, n_p->Attr("value"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::Tuple n, ObjectPath n_p, IRDocsifier d) -> Doc { + // TODO(@junrushao): revisit tuple printing + if (n->fields.empty()) { + return Relax(d, "tuple")->Call({}); + } + Array fields_doc; + ObjectPath fields_p = n_p->Attr("fields"); + for (int i = 0, l = n->fields.size(); i < l; ++i) { + fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayIndex(i))); + } + return TupleDoc(fields_doc); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::TupleGetItem n, ObjectPath n_p, IRDocsifier d) -> Doc { + ExprDoc idx = LiteralDoc::Int(n->index, n_p->Attr("index")); + return d->AsDoc(n->tuple, n_p->Attr("tuple"))[{idx}]; + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::ShapeExpr n, ObjectPath n_p, IRDocsifier d) -> Doc { + Array values_doc; + ObjectPath values_p = n_p->Attr("values"); + for (int i = 0, l = n->values.size(); i < l; ++i) { + values_doc.push_back(PrintShapeVar(n->values[i], values_p->ArrayIndex(i), d)); + } + return TupleDoc(values_doc); + }); + +Optional SpecialScalar(const runtime::NDArray& n, const ObjectPath& p) { + DataType dtype = n.DataType(); + const void* data = n->data; + if (n->ndim != 0 || n->device.device_type != kDLCPU) { + return NullOpt; + } + if (dtype == DataType::Int(32)) { + return LiteralDoc::Int(*reinterpret_cast(data), p); + } else if (dtype == DataType::Int(64)) { + return LiteralDoc::Int(*reinterpret_cast(data), p); + } else if (dtype == DataType::Float(32)) { + return LiteralDoc::Float(*reinterpret_cast(data), p); + } else if (dtype == DataType::Float(64)) { + return LiteralDoc::Float(*reinterpret_cast(data), p); + } else if (dtype == DataType::Bool()) { + return LiteralDoc::Boolean(*reinterpret_cast(data), p); + } else { + return NullOpt; + } +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::Constant n, ObjectPath n_p, IRDocsifier d) -> Doc { + if (Optional s = SpecialScalar(n->data, n_p->Attr("data"))) { + return Relax(d, "const") + ->Call({ + s.value(), + LiteralDoc::DataType(n->data.DataType(), n_p->Attr("data")->Attr("dtype")), + }); + } + return d->AddMetadata(n); + }); + +Doc PrintRelaxVar(relax::Var n, ObjectPath p, IRDocsifier d) { + if (!d->IsVarDefined(n)) { + ExprDoc ann = d->AsDoc(n->struct_info_, p->Attr("struct_info_")); + Frame f = d->frames.back(); + ExprDoc var = DefineVar(n, f, d); + f->stmts.push_back(AssignDoc(var, NullOpt, ann)); + } + return d->GetVarDoc(n).value(); +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("", PrintRelaxVar); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("", PrintRelaxVar); + +TVM_SCRIPT_REPR(relax::PrimValueNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::StringImmNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::DataTypeImmNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::TupleNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::TupleGetItemNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::ShapeExprNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::VarNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::DataflowVarNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::ConstantNode, ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc new file mode 100644 index 000000000000..fa085fcad403 --- /dev/null +++ b/src/script/printer/relax/function.cc @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_REGISTER_NODE_TYPE(RelaxFrameNode); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](relax::Function n, ObjectPath n_p, IRDocsifier d) -> Doc { + std::unordered_set func_vars; + With f(d); + (*f)->AddDispatchToken(d, "relax"); + (*f)->is_func = true; + (*f)->func_vars = &func_vars; + // Step 1. Print the return type + Optional ret_type = NullOpt; + if (const auto& func_sinfo = relax::MatchStructInfo(n)) { + ret_type = d->AsDoc(func_sinfo.value()->ret, // + n_p->Attr("struct_info_")->Attr("ret")); + } + // Step 2. Print params + Array params; + { + ObjectPath params_p = n_p->Attr("params"); + for (int i = 0, l = n->params.size(); i < l; ++i) { + params.push_back(AssignDoc( + /*lhs=*/DefineVar(n->params[i], *f, d), + /*rhs=*/NullOpt, StructInfoAsAnn(n->params[i], params_p->ArrayIndex(i), d, NullOpt))); + } + } + // Step 3. Clean up func variables + (*f)->func_vars = nullptr; + // Step 4. Print attributes + if (n->attrs.defined() && !n->attrs->dict.empty()) { + (*f)->stmts.push_back( + ExprStmtDoc(Relax(d, "func_attr") // + ->Call({d->AsDoc(n->attrs, n_p->Attr("attrs"))}))); + } + // Step 5. Print body + Array body = + PrintSeqExpr(Downcast(n->body), n_p->Attr("body"), d, /*use_ret=*/true); + (*f)->stmts.insert((*f)->stmts.end(), body.begin(), body.end()); + return HeaderWrapper(d, FunctionDoc(IdDoc(FindFunctionName(d, n).value_or("main")), params, + {Relax(d, "function")}, ret_type, (*f)->stmts)); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::ExternFunc n, ObjectPath n_p, IRDocsifier d) -> Doc { + // TODO(@junrushao): print more information out of extern function. + return ExprStmtDoc(LiteralDoc::Str(n->global_symbol, n_p)); + }); + +TVM_SCRIPT_REPR(relax::FunctionNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::ExternFuncNode, ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/region.cc b/src/script/printer/relax/region.cc new file mode 100644 index 000000000000..1ac0b5ba14df --- /dev/null +++ b/src/script/printer/relax/region.cc @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +Array PrintSeqExpr(const relax::SeqExpr& n, const ObjectPath& n_p, const IRDocsifier& d, + bool use_ret) { + With f(d); + const Array& blocks = n->blocks; + ObjectPath blocks_p = n_p->Attr("blocks"); + Array* stmts = &(*f)->stmts; + for (int i = 0, l = blocks.size(); i < l; ++i) { + Doc block = d->AsDoc(blocks[i], blocks_p->ArrayIndex(i)); + if (const auto* stmt_block = block.as()) { + stmts->insert(stmts->end(), stmt_block->stmts.begin(), stmt_block->stmts.end()); + } else if (const auto* stmt = block.as()) { + stmts->push_back(GetRef(stmt)); + } else { + LOG(FATAL) << "TypeError: Unknown type: " << block->GetTypeKey(); + } + } + ExprDoc ret = d->AsDoc(n->body, n_p->Attr("body")); + if (use_ret) { + stmts->push_back(ReturnDoc(ret)); + } else { + stmts->push_back(ExprStmtDoc(ret)); + } + return *stmts; +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](relax::SeqExpr n, ObjectPath n_p, IRDocsifier d) -> Doc { + return StmtBlockDoc(PrintSeqExpr(n, n_p, d, false)); + }); + +Array PrintBindingBlock(const relax::BindingBlock& n, const ObjectPath& n_p, + const IRDocsifier& d, Array* non_dataflow_vars) { + const Array& bindings = n->bindings; + ObjectPath bindings_p = n_p->Attr("bindings"); + Array stmts; + for (int i = 0, l = bindings.size(); i < l; ++i) { + const relax::Binding& binding = bindings[i]; + ObjectPath binding_p = bindings_p->ArrayIndex(i); + ICHECK(binding->var.defined()); + Doc binding_doc = d->AsDoc(binding, binding_p); + if (const auto* stmt = binding_doc.as()) { + stmts.push_back(GetRef(stmt)); + } else if (const auto* stmt_block = binding_doc.as()) { + stmts.insert(stmts.end(), stmt_block->stmts.begin(), stmt_block->stmts.end()); + } else { + LOG(FATAL) << "TypeError: Unknown type: " << binding_doc->GetTypeKey(); + } + if (non_dataflow_vars != nullptr && !binding->var->IsInstance()) { + non_dataflow_vars->push_back(d->AsDoc(binding->var, binding_p->Attr("var"))); + } + } + return stmts; +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::BindingBlock n, ObjectPath n_p, IRDocsifier d) -> Doc { + return StmtBlockDoc(PrintBindingBlock(n, n_p, d, nullptr)); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::DataflowBlock n, ObjectPath n_p, IRDocsifier d) -> Doc { + Array non_dataflow_vars; + Array stmts = PrintBindingBlock(n, n_p, d, &non_dataflow_vars); + stmts.push_back(ExprStmtDoc(Relax(d, "output")->Call(non_dataflow_vars))); + return ScopeDoc(NullOpt, Relax(d, "dataflow")->Call({}), stmts); + }); + +TVM_SCRIPT_REPR(relax::SeqExprNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::BindingBlockNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::DataflowBlockNode, ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc new file mode 100644 index 000000000000..6f4a66c991d9 --- /dev/null +++ b/src/script/printer/relax/struct_info.cc @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::ObjectStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "Object"); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](relax::PrimStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "Prim")->Call({LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))}); + }); + +ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifier& d) { + ExprDoc expr_doc = d->AsDoc(e, e_p); + // Step 1. Find if `func_vars` are being collected + const RelaxFrameNode* f = nullptr; + for (const Frame& frame : d->frames) { + if (const auto* relax_frame = frame.as()) { + if (relax_frame->func_vars) { + f = relax_frame; + break; + } + } + } + // Step 2. Figure out if the PrimExpr contains at least a func var + bool func_var_mode = false; + if (f != nullptr) { + tir::PostOrderVisit(e, [f, &func_var_mode](const ObjectRef& obj) -> void { + if (const auto* var = obj.as()) { + if (f->func_vars->count(var)) { + func_var_mode = true; + } + } + }); + } + // Step 3. Stringify the PrimExpr if func var exists + if (func_var_mode) { + return LiteralDoc::Str(DocToPythonScript(expr_doc, d->cfg), e_p); + } + return expr_doc; +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](relax::ShapeStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + if (n->values.defined()) { + Array shape = n->values.value(); + ObjectPath shape_p = n_p->Attr("values"); + Array shape_docs; + for (int i = 0, ndim = shape.size(); i < ndim; ++i) { + shape_docs.push_back(PrintShapeVar(shape[i], shape_p->ArrayIndex(i), d)); + } + return Relax(d, "Shape")->Call({ListDoc(shape_docs)}); + } + return Relax(d, "Shape") + ->Call({}, {"ndim"}, {LiteralDoc::Int(n->ndim, n_p->Attr("ndim"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::TensorStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + Array args; + Array kwargs_keys; + Array kwargs_values; + if (n->shape.defined()) { + args.push_back(d->AsDoc(n->shape.value(), n_p->Attr("shape"))); + } + if (!n->IsUnknownDtype()) { + kwargs_keys.push_back("dtype"); + kwargs_values.push_back(LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))); + } + if (!n->shape.defined() && !n->IsUnknownNdim()) { + kwargs_keys.push_back("ndim"); + kwargs_values.push_back(LiteralDoc::Int(n->ndim, n_p->Attr("ndim"))); + } + if (args.empty() && kwargs_keys.empty()) { + return Relax(d, "Tensor"); + } + return Relax(d, "Tensor")->Call(args, kwargs_keys, kwargs_values); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::TupleStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + if (n->fields.empty()) { + return Relax(d, "Tuple"); + } + Array fields_doc; + ObjectPath fields_p = n_p->Attr("fields"); + for (int i = 0, l = n->fields.size(); i < l; ++i) { + fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayIndex(i))); + } + return Relax(d, "Tuple")->Call(fields_doc); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::FuncStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + if (n->IsOpaque()) { + return Relax(d, "Callable"); + } + // TODO(@junrushao): track symbolic shape relation + Array params_doc; + Array params = n->params.value(); + ObjectPath params_p = n_p->Attr("params"); + for (int i = 0, n_params = params.size(); i < n_params; ++i) { + params_doc.push_back(d->AsDoc(params[i], params_p->ArrayIndex(i))); + } + return Relax(d, "Callable") + ->Call({TupleDoc(params_doc), // + d->AsDoc(n->ret, n_p->Attr("ret"))}); + }); + +TVM_SCRIPT_REPR(relax::ObjectStructInfoNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::PrimStructInfoNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::ShapeStructInfoNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::TensorStructInfoNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::TupleStructInfoNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::FuncStructInfoNode, ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc new file mode 100644 index 000000000000..2c8bb0f1da6c --- /dev/null +++ b/src/script/printer/relax/tir.cc @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +Doc PrintTIRVar(tir::Var n, ObjectPath n_p, IRDocsifier d) { + ICHECK(n->dtype.is_int() && n->dtype.is_scalar()) << "TypeError: Relax only uses " + "scalar integer TIR variables, but gets: " + << n; + if (!d->IsVarDefined(n)) { + // Find the outmost Relax function frame. If not exist, the outmost Relax frame. + RelaxFrameNode* f = nullptr; + for (const Frame& frame : d->frames) { + if (const auto* relax_frame = frame.as()) { + if (relax_frame->is_func) { + f = const_cast(relax_frame); + break; + } else if (f == nullptr) { + f = const_cast(relax_frame); + } + } + } + // There should be at least one Relax frame + if (f == nullptr) { + LOG(FATAL) << "IndexError: No relax environment is found when printing a TIR var under " + "relax's dispatch token"; + } + // If the Relax function frame is collecting func vars + if (f->func_vars) { + ICHECK(f->is_func); + f->func_vars->insert(n.get()); + } + IdDoc var = d->Define(n, GetRef(f), n->name_hint.empty() ? "v" : n->name_hint); + var->source_paths.push_back(n_p); + f->stmts.push_back(AssignDoc(var, + TIR(d, "Var")->Call({ + LiteralDoc::Str(var->name, n_p->Attr("name_hint")), + LiteralDoc::DataType(n->dtype, n_p->Attr("dtype")), + }), + NullOpt)); + } + if (Optional doc = d->GetVarDoc(n)) { + return doc.value(); + } + LOG(FATAL) << "IndexError: Variable is not defined in the environment: " << n; +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("relax", PrintTIRVar); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("relax", PrintTIRVar); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "relax", [](tvm::IntImm n, ObjectPath n_p, IRDocsifier d) -> Doc { // + // TODO(@junrushao): support non-int64 cases + return LiteralDoc::Int(n->value, n_p); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "relax", [](tvm::GlobalVar n, ObjectPath n_p, IRDocsifier d) -> Doc { // + IdDoc ret(n->name_hint); + ret->source_paths.push_back(n_p); + return ret; + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/type.cc b/src/script/printer/relax/type.cc new file mode 100644 index 000000000000..d13d90b1d5ed --- /dev/null +++ b/src/script/printer/relax/type.cc @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::ShapeType n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "Shape") + ->Call({}, {"ndim"}, {LiteralDoc::Int(n->ndim, n_p->Attr("ndim"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "", [](relax::ObjectType n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "Object"); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](relax::DynTensorType n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "Tensor") + ->Call({}, {"ndim", "dtype"}, + {LiteralDoc::Int(n->ndim, n_p->Attr("ndim")), + LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](relax::PackedFuncType n, ObjectPath n_p, IRDocsifier d) -> Doc { + return Relax(d, "PackedFunc"); // TODO(@junrushao): verify if this is correct + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( // + "relax", [](tvm::TupleType n, ObjectPath n_p, IRDocsifier d) -> Doc { + if (n->fields.empty()) { + return Relax(d, "Tuple"); + } + Array fields_doc; + ObjectPath fields_p = n_p->Attr("fields"); + for (int i = 0, l = n->fields.size(); i < l; ++i) { + fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayIndex(i))); + } + return Relax(d, "Tuple")->Call(fields_doc); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "relax", [](tvm::FuncType n, ObjectPath n_p, IRDocsifier d) -> Doc { + Array arg_types_doc; + Array arg_types = n->arg_types; + ObjectPath arg_types_p = n_p->Attr("arg_types"); + for (int i = 0, n_params = arg_types.size(); i < n_params; ++i) { + arg_types_doc.push_back(d->AsDoc(arg_types[i], arg_types_p->ArrayIndex(i))); + } + return Relax(d, "Callable") + ->Call({TupleDoc(arg_types_doc), // + d->AsDoc(n->ret_type, n_p->Attr("ret_type"))}); + }); + +TVM_SCRIPT_REPR(relax::ShapeTypeNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::ObjectTypeNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::DynTensorTypeNode, ReprPrintRelax); +TVM_SCRIPT_REPR(relax::PackedFuncTypeNode, ReprPrintRelax); +TVM_REGISTER_GLOBAL("script.printer.ReprPrintRelax").set_body_typed(ReprPrintRelax); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h new file mode 100644 index 000000000000..7702f7b22dd2 --- /dev/null +++ b/src/script/printer/relax/utils.h @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_RELAX_UTILS_H_ +#define TVM_SCRIPT_PRINTER_RELAX_UTILS_H_ + +#include +#include +#include + +#include +#include +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace script { +namespace printer { + +class RelaxFrameNode : public FrameNode { + public: + bool is_func = false; + std::unordered_set* func_vars = nullptr; + + void VisitAttrs(AttrVisitor* v) { + FrameNode::VisitAttrs(v); + v->Visit("is_global_func", &is_func); + // `func_var_to_define` is not visited + } + + static constexpr const char* _type_key = "script.printer.RelaxFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(RelaxFrameNode, FrameNode); +}; + +class RelaxFrame : public Frame { + public: + explicit RelaxFrame(const IRDocsifier& d) { + ObjectPtr n = make_object(); + n->stmts.clear(); + n->d = d.get(); + n->is_func = false; + n->func_vars = nullptr; + data_ = std::move(n); + } + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RelaxFrame, Frame, RelaxFrameNode); +}; + +/*! \brief Redirected method for the ReprPrinter */ +inline std::string ReprPrintRelax(const ObjectRef& obj, const PrinterConfig& cfg) { + IRDocsifier d(cfg); + With f(d); + (*f)->AddDispatchToken(d, "relax"); + return Docsify(obj, d, *f, cfg); +} + +inline IdDoc DefineVar(const relax::Var& var, const Frame& frame, const IRDocsifier& d) { + return d->Define(var, frame, var->name_hint().empty() ? "v" : var->name_hint()); +} + +inline Optional StructInfoAsAnn(const relax::Var& v, const ObjectPath& v_p, + const IRDocsifier& d, const Optional& rhs) { + if (!v->struct_info_.defined()) { + return NullOpt; + } + if (const auto* call = rhs.as()) { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + if (call->op.same_as(call_tir_op)) { + return NullOpt; + } + } + return d->AsDoc(v->struct_info_, v_p->Attr("struct_info_")); +} + +Array PrintSeqExpr(const relax::SeqExpr& n, const ObjectPath& n_p, const IRDocsifier& d, + bool use_ret); + +ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifier& d); + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_RELAX_UTILS_H_ diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py new file mode 100644 index 000000000000..e2cb8bc5fc32 --- /dev/null +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -0,0 +1,488 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import pytest +from tvm import IRModule, relax, tir +from tvm.script import relax as R + + +def _assert_print(obj, expected): + if not isinstance(obj, str): + obj = obj.script(verbose_expr=True) + obj = obj.strip() + assert obj == expected.strip(), "\n" + obj + + +def test_function(): + @R.function + def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore + return a + + _assert_print( + func, + """ +# from tvm.script import relax as R + +@R.function +def main(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): + return a""", + ) + + +def test_extern_func(): + @R.function + def relax_func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore + return a + + obj = IRModule( + { + "func": relax_func, + "my_ext": relax.ExternFunc("my_ext"), + } + ) + _assert_print( + obj, + """ +# from tvm.script import ir as I +# from tvm.script import relax as R + +@I.ir_module +class Module: + "my_ext" + @R.function + def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): + return a +""", + ) + + +def test_object_struct_info(): + obj = relax.ObjectStructInfo() + _assert_print( + obj, + "R.Object", + ) + + +def test_prim_struct_info(): + obj = relax.PrimStructInfo("float32") + _assert_print(obj, 'R.Prim("float32")') + + +def test_shape_struct_info_0(): + obj = relax.ShapeStructInfo(ndim=-1) + _assert_print(obj, "R.Shape(ndim=-1)") + + +def test_shape_struct_info_1(): + obj = relax.ShapeStructInfo([1, 2, 3]) + _assert_print(obj, "R.Shape([1, 2, 3])") + + +def test_shape_struct_info_2(): + obj = relax.ShapeStructInfo([1, tir.Var("a", "int64"), 3]) + _assert_print( + obj, + """ +a = T.Var("a", "int64") +R.Shape([1, a, 3])""", + ) + + +def test_tensor_struct_info(): + obj = relax.TensorStructInfo( + shape=relax.ShapeExpr([1, tir.Var("a", "int64"), 3]), + dtype="float32", + ) + _assert_print( + obj, + """ +a = T.Var("a", "int64") +R.Tensor((1, a, 3), dtype="float32") +""", + ) + + +def test_tuple_struct_info_empty(): + obj = relax.TupleStructInfo([]) + _assert_print(obj, "R.Tuple") + + +def test_tuple_struct_info(): + obj = relax.TupleStructInfo( + [ + relax.PrimStructInfo("float32"), + relax.ObjectStructInfo(), + relax.ShapeStructInfo([1, tir.Var("a", "int64"), 3]), + ] + ) + _assert_print( + obj, + """ +a = T.Var("a", "int64") +R.Tuple(R.Prim("float32"), R.Object, R.Shape([1, a, 3])) +""", + ) + + +def test_func_struct_info(): + obj = relax.FuncStructInfo( + params=[ + relax.PrimStructInfo("float32"), + relax.ObjectStructInfo(), + relax.ShapeStructInfo([1, tir.Var("a", "int64"), 3]), + ], + ret=relax.TensorStructInfo( + shape=relax.ShapeExpr([1, 2, 3]), + dtype="float32", + ), + ) + _assert_print( + obj, + """ +a = T.Var("a", "int64") +R.Callable((R.Prim("float32"), R.Object, R.Shape([1, a, 3])), R.Tensor((1, 2, 3), dtype="float32")) +""", + ) + + +def test_shape_type(): + obj = relax.ShapeType(ndim=3) + _assert_print(obj, "R.Shape(ndim=3)") + + +def test_object_type(): + obj = relax.ObjectType() + _assert_print(obj, "R.Object") + + +def test_dyn_tensor_type(): + obj = relax.DynTensorType() + _assert_print(obj, 'R.Tensor(ndim=-1, dtype="float32")') + + +def test_packed_func_type(): + obj = relax.PackedFuncType() + _assert_print(obj, "R.PackedFunc") + + +def test_tuple_type(): + obj = relax.TupleType([relax.ShapeType(ndim=3), relax.ObjectType()]) + _assert_print( + obj._relax_script(), # pylint: disable=protected-access + "R.Tuple(R.Shape(ndim=3), R.Object)", + ) + + +def test_func_type(): + obj = relax.FuncType( + arg_types=[ + relax.ObjectType(), + relax.ShapeType(ndim=3), + ], + ret_type=relax.DynTensorType( + ndim=3, + dtype="float32", + ), + ) + _assert_print( + obj._relax_script(), # pylint: disable=protected-access + 'R.Callable((R.Object, R.Shape(ndim=3)), R.Tensor(ndim=3, dtype="float32"))', + ) + + +def test_prim_value(): + obj = relax.PrimValue(1) + _assert_print(obj, "R.prim_value(1)") + + +def test_string_imm(): + obj = relax.StringImm("hello") + _assert_print(obj, 'R.str("hello")') + + +def test_data_type_imm(): + obj = relax.DataTypeImm("float32") + _assert_print(obj, 'R.dtype("float32")') + + +def test_var(): + obj = relax.Var("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32")) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +a""", + ) + + +def test_dataflow_var(): + obj = relax.DataflowVar("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32")) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +a""", + ) + + +def test_tuple(): + obj = relax.Tuple( + [ + relax.Var("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32")), + relax.Var("b", relax.TensorStructInfo([1, tir.Var("y", "int64"), 3], "float32")), + relax.Var("c", relax.TensorStructInfo([1, tir.Var("z", "int64"), 3], "float32")), + ] + ) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +y = T.Var("y", "int64") +b: R.Tensor((1, y, 3), dtype="float32") +z = T.Var("z", "int64") +c: R.Tensor((1, z, 3), dtype="float32") +(a, b, c) +""", + ) + + +def test_tuple_get_item(): + obj = relax.TupleGetItem( + relax.Tuple( + [ + relax.Var("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32")), + relax.Var("b", relax.TensorStructInfo([1, tir.Var("y", "int64"), 3], "float32")), + relax.Var("c", relax.TensorStructInfo([1, tir.Var("z", "int64"), 3], "float32")), + ] + ), + 0, + ) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +y = T.Var("y", "int64") +b: R.Tensor((1, y, 3), dtype="float32") +z = T.Var("z", "int64") +c: R.Tensor((1, z, 3), dtype="float32") +(a, b, c)[0] +""", + ) + + +def test_shape_expr(): + obj = relax.ShapeExpr([1, 2, 3]) + _assert_print(obj, "(1, 2, 3)") + + +def test_call(): + x = tir.Var("x", "int64") + a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) + obj = relax.call_tir("my_func", args=a, out_sinfo=a.struct_info, tir_vars=[x]) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=(x,)) +""", + ) + + +@pytest.mark.skip(reason="`relax.op.sin` is not upstreamed yet") +def test_seq_expr(): + x = tir.Var("x", "int64") + a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) + b = relax.DataflowVar("b", relax.TensorStructInfo([1, x, 3], "float32")) + c = relax.Var("c", relax.TensorStructInfo([1, x, 3], "float32")) + + obj = relax.SeqExpr( + blocks=[ + relax.DataflowBlock( + bindings=[ + relax.VarBinding(b, relax.op.sin(a)), + relax.VarBinding(c, relax.op.sin(b)), + ] + ), + ], + body=c, + ) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +with R.dataflow(): + b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a) + c: R.Tensor((1, x, 3), dtype="float32") = R.sin(b) + R.output(c) +c +""", + ) + + +@pytest.mark.skip(reason="`relax.op.sin` is not upstreamed yet") +def test_binding_block(): + x = tir.Var("x", "int64") + a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) + b = relax.Var("b", relax.TensorStructInfo([1, x, 3], "float32")) + c = relax.Var("c", relax.TensorStructInfo([1, x, 3], "float32")) + obj = relax.BindingBlock( + bindings=[ + relax.VarBinding(b, relax.op.sin(a)), + relax.VarBinding(c, relax.op.sin(b)), + ] + ) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a) +c: R.Tensor((1, x, 3), dtype="float32") = R.sin(b) +""", + ) + + +@pytest.mark.skip(reason="`relax.op.sin` is not upstreamed yet") +def test_dataflow_block(): + x = tir.Var("x", "int64") + a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) + b = relax.DataflowVar("b", relax.TensorStructInfo([1, x, 3], "float32")) + c = relax.Var("c", relax.TensorStructInfo([1, x, 3], "float32")) + obj = relax.DataflowBlock( + bindings=[ + relax.VarBinding(b, relax.op.sin(a)), + relax.VarBinding(c, relax.op.sin(b)), + ] + ) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +with R.dataflow(): + b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a) + c: R.Tensor((1, x, 3), dtype="float32") = R.sin(b) + R.output(c) +""", + ) + + +def test_match_cast(): + x = tir.Var("x", "int64") + a = relax.Var("a", relax.TensorStructInfo([1, x, 3])) + b = relax.Var("b", relax.TensorStructInfo([1, 5, 3])) + obj = relax.MatchCast( + var=b, + value=a, + struct_info=b.struct_info, + ) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +b: R.Tensor((1, 5, 3), dtype="float32") = R.match_cast(a, R.Tensor((1, 5, 3), dtype="float32")) +""", + ) + + +@pytest.mark.skip(reason="`relax.op.sin` is not upstreamed yet") +def test_var_binding(): + x = tir.Var("x", "int64") + a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) + b = relax.Var("b", relax.TensorStructInfo([1, x, 3], "float32")) + obj = relax.VarBinding(b, relax.op.sin(a)) + _assert_print( + obj, + """ +x = T.Var("x", "int64") +a: R.Tensor((1, x, 3), dtype="float32") +b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a) +""", + ) + + +def test_if(): + a = relax.Var("a", relax.TensorStructInfo([], "bool")) + b = relax.Var("b", relax.TensorStructInfo([1, 2, 3], "float32")) + c = relax.Var("c", relax.TensorStructInfo([1, 2, 3], "float32")) + obj = relax.If( + a, + relax.SeqExpr([], b), + relax.SeqExpr([], c), + ) + _assert_print( + obj, + """ +a: R.Tensor((), dtype="bool") +if a: + b: R.Tensor((1, 2, 3), dtype="float32") + b +else: + c: R.Tensor((1, 2, 3), dtype="float32") + c +""", + ) + + +if __name__ == "__main__": + test_function() + test_extern_func() + + test_object_struct_info() + test_prim_struct_info() + test_shape_struct_info_0() + test_shape_struct_info_1() + test_shape_struct_info_2() + test_tensor_struct_info() + test_tuple_struct_info_empty() + test_tuple_struct_info() + test_func_struct_info() + + test_shape_type() + test_object_type() + test_dyn_tensor_type() + test_packed_func_type() + test_tuple_type() + test_func_type() + + test_prim_value() + test_string_imm() + test_data_type_imm() + + test_var() + test_dataflow_var() + # + test_tuple() + test_tuple_get_item() + test_shape_expr() + test_call() + + test_seq_expr() + test_binding_block() + test_dataflow_block() + + test_match_cast() + test_var_binding() + test_if() From 509576413356fedfb24ae7daf5d3a6dd2a255dbc Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Fri, 10 Feb 2023 16:19:37 -0800 Subject: [PATCH 09/81] [Unity] Relax VM codegen (#13954) --- python/tvm/relax/testing/runtime_builtin.py | 34 ++ src/relax/backend/vm/codegen_vm.cc | 447 ++++++++++++++++++ src/relax/op/op.cc | 220 ++++++++- src/relax/op/op_common.h | 25 +- src/runtime/relax_vm/builtin.cc | 23 +- tests/python/relax/test_runtime_builtin.py | 153 ++++++ .../relax/test_tvmscript_printer_relax.py | 41 +- tests/python/relax/test_vm_codegen_only.py | 333 +++++++++++++ 8 files changed, 1228 insertions(+), 48 deletions(-) create mode 100644 python/tvm/relax/testing/runtime_builtin.py create mode 100644 src/relax/backend/vm/codegen_vm.cc create mode 100644 tests/python/relax/test_runtime_builtin.py create mode 100644 tests/python/relax/test_vm_codegen_only.py diff --git a/python/tvm/relax/testing/runtime_builtin.py b/python/tvm/relax/testing/runtime_builtin.py new file mode 100644 index 000000000000..1b04364e69fa --- /dev/null +++ b/python/tvm/relax/testing/runtime_builtin.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Testing utilities for runtime builtin functions.""" +from enum import IntEnum + + +class MatchShapeCode(IntEnum): + """Code passed to match shape builtin""" + + ASSERT_EQUAL_TO_IMM = 0 + STORE_TO_HEAP = 1 + NO_OP = 2 + ASSERT_EQUAL_TO_LOAD = 3 + + +class MakeShapeCode(IntEnum): + """Code passed to match shape builtin""" + + USE_IMM = 0 + LOAD_SHAPE = 1 diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc new file mode 100644 index 000000000000..1782f1107a5b --- /dev/null +++ b/src/relax/backend/vm/codegen_vm.cc @@ -0,0 +1,447 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/vm/codegen_vm.cc + * \brief A codegen to generate VM executable from a Relax IRModule. + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../../../target/metadata_module.h" +#include "../../../target/source/codegen_source_base.h" + +namespace tvm { +namespace relax { +namespace relax_vm { + +using tvm::Target; +using namespace relax; +using namespace tvm::runtime; +using namespace tvm::runtime::relax_vm; + +// Helper function to get the function name of the registered packed function implementation of +// relax operator. +FCallPacked GetPackedFuncName(const Call& call) { + static auto op_map = Op::GetAttrMap("FCallPacked"); + if (call->op.as()) { + Op op = Downcast(call->op); + if (op_map.count(op)) { + return op_map[op]; + } + } + return {}; +} + +/*! + * \brief A class to generate VM executable for Relax functions. + */ +class CodeGenVM : public ExprFunctor { + public: + explicit CodeGenVM(relax::ExecBuilder builder, IRModule ctx_mod) + : builder_(builder), ctx_mod_(ctx_mod) {} + + static IRModule Run(relax::ExecBuilder builder, IRModule mod) { + IRModule res_mod = IRModule(Map()); + CodeGenVM codegen(builder, mod); + // Remove relax function and turn into TIR func. + for (auto& p : mod->functions) { + if (auto* func = p.second.as()) { + codegen.Codegen(GetRef(func)); + } else { + res_mod->Add(p.first, p.second); + } + } + return res_mod; + } + + protected: + size_t NewRegister() { return registers_num_++; } + + // Convert Arg value to a register, trigger copy if needed + Instruction::Arg EnsureReg(Instruction::Arg arg) { + if (arg.kind() == Instruction::ArgKind::kRegister) { + return arg; + } else { + RegName dst_reg = NewRegister(); + builder_->EmitCall("vm.builtin.copy", {arg}, dst_reg); + return Instruction::Arg::Register(dst_reg); + } + } + + void Codegen(const Function& func) { + Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(gsymbol.defined()) << "there should be no local functions in Relax VM codegen phase. " + "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?"; + + Array param_names; + for (Var param : func->params) { + param_names.push_back(param->name_hint()); + } + + builder_->EmitFunction(gsymbol.value(), func->params.size(), param_names); + + for (size_t i = 0; i < func->params.size(); ++i) { + RegName r = NewRegister(); + ICHECK_EQ(r, static_cast(i)); + this->var_arg_map_.insert({func->params[i], Instruction::Arg::Register(r)}); + } + Instruction::Arg ret = ExprFunctor::VisitExpr(func->body); + builder_->EmitRet(EnsureReg(ret)); + builder_->EndFunction(gsymbol.value()); + // reset register number to be 0; + registers_num_ = 0; + var_arg_map_.clear(); + } + + Instruction::Arg VisitExpr_(const SeqExprNode* op) final { + for (auto block : op->blocks) { + for (Binding binding : block->bindings) { + Instruction::Arg value; + if (auto* var_binding = binding.as()) { + value = this->VisitExpr(var_binding->value); + } else if (auto* match_cast = binding.as()) { + value = this->VisitExpr(match_cast->value); + } else { + LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey(); + } + this->var_arg_map_.insert({binding->var, value}); + } + } + + Instruction::Arg ret_reg = this->VisitExpr(op->body); + return ret_reg; + } + + Instruction::Arg VisitExpr_(const CallNode* call_node) final { + Call call = GetRef(call_node); + + if (call_node->op == null_value_op_) { + return Instruction::Arg::Register(Instruction::kVoidRegister); + } + + // allocate dst register. + RegName dst_reg = HasVoidStructInfo(call) ? Instruction::kVoidRegister : NewRegister(); + if (call->op.as()) { + if (call_node->op == call_builtin_with_ctx_op_) { + // TODO(relax-team) migrate most handling of op to + // directly map to call_builtin_with_ctx before codegen and simplify vm codegen. + EmitCallBuiltinWithCtx(call, dst_reg); + } else if (call_node->op == alloc_storage_op_) { + EmitAllocStorage(call, dst_reg); + } else if (call_node->op == alloc_tensor_op_) { + EmitAllocTensor(call, dst_reg); + } else { + // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those + // ops are handled in a pass when lowering them to TIR. + LOG(FATAL) << "CodeGenVM cannot handle this intrinsic now:\n" << call_node->op; + } + } else { + EmitNormalCall(call, dst_reg); + } + return Instruction::Arg::Register(dst_reg); + } + + Instruction::Arg VisitExpr_(const IfNode* op) final { + const If& ife = GetRef(op); + Instruction::Arg cond_value = this->VisitExpr(ife->cond); + + // Reserve a register for cond + RegName cond_reg = NewRegister(); + builder_->EmitCall("vm.builtin.read_if_cond", {cond_value}, cond_reg); + + // obtain the temp exec in progress. + vm::Executable* exec = builder_->exec(); + + // Record the offset of If instruction + size_t if_offset = exec->instr_offset.size(); + + builder_->EmitIf(Instruction::Arg::Register(cond_reg), 3); + size_t num_instr = exec->instr_offset.size(); + Instruction::Arg true_value = this->VisitExpr(ife->true_branch); + // Reserve a register for return + size_t merge_register = NewRegister(); + // Copy the output from true branch to merge register + builder_->EmitCall("vm.builtin.copy", {true_value}, merge_register); + + // Record the offset of Goto instruction + size_t goto_offset = exec->instr_offset.size(); + + builder_->EmitGoto(1); + + // Calculate the false offset of If + size_t false_offset = exec->instr_offset.size() - num_instr + 1; + + Instruction::Arg false_value = this->VisitExpr(ife->false_branch); + // Copy the output data of false branch to merge register + builder_->EmitCall("vm.builtin.copy", {false_value}, merge_register); + + // Update the offsets of the If instruction emitted above + // Jump to the behind of the next goto instruction + exec->SetInstructionData(if_offset, 2, static_cast(false_offset)); + // Update the pc_offset of Goto instruction + // Jump over the false branch + size_t pc_offset = exec->instr_offset.size() - goto_offset; + exec->SetInstructionData(goto_offset, 1, static_cast(pc_offset)); + return Instruction::Arg::Register(merge_register); + } + + Instruction::Arg VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + auto it = this->var_arg_map_.find(var); + ICHECK(it != this->var_arg_map_.end()) << "Var " << var << " is not defined"; + return it->second; + } + + Instruction::Arg VisitExpr_(const ConstantNode* op) final { + return builder_->ConvertConstant(op->data); + } + + Instruction::Arg VisitExpr_(const ShapeExprNode* op) final { + std::vector shape; + for (PrimExpr e : op->values) { + if (auto* int_value = e.as()) { + shape.push_back(int_value->value); + } else { + LOG(FATAL) << "Should only use constant shape after shape lowering: " << op->values; + } + } + return builder_->ConvertConstant(ShapeTuple(shape)); + } + + Instruction::Arg VisitExpr_(const PrimValueNode* op) final { + if (auto* int_imm = op->value.as()) { + return builder_->ConvertConstant(int_imm->value); + } else { + auto* float_imm = op->value.as(); + ICHECK(float_imm) << "PrimValue can only be IntImm/FloatImm for now"; + return builder_->ConvertConstant(float_imm->value); + } + } + + Instruction::Arg VisitExpr_(const StringImmNode* op) final { + return builder_->ConvertConstant(op->value); + } + + Instruction::Arg VisitExpr_(const DataTypeImmNode* op) final { + return builder_->ConvertConstant(op->value); + } + + Instruction::Arg VisitExpr_(const TupleNode* op) final { + Tuple tuple = GetRef(op); + std::vector args; + for (Expr arg : tuple->fields) { + args.push_back(this->VisitExpr(arg)); + } + size_t dst_register = NewRegister(); + builder_->EmitCall("vm.builtin.make_tuple", args, dst_register); + + return Instruction::Arg::Register(dst_register); + } + + Instruction::Arg VisitExpr_(const TupleGetItemNode* op) final { + TupleGetItem expr = GetRef(op); + std::vector args = {this->VisitExpr(expr->tuple)}; + + args.push_back(builder_->ConvertConstant(expr->index)); + + size_t dst_register = NewRegister(); + builder_->EmitCall("vm.builtin.tuple_getitem", args, dst_register); + + return Instruction::Arg::Register(dst_register); + } + + Instruction::Arg VisitExpr_(const GlobalVarNode* op) final { + GlobalVar gvar = GetRef(op); + Optional symbol; + VMFuncInfo::FuncKind kind = VMFuncInfo::FuncKind::kPackedFunc; + + // Run a look up in the env to see if it maps to an extern func. + auto it = ctx_mod_->functions.find(gvar); + if (it != ctx_mod_->functions.end()) { + BaseFunc func = (*it).second; + if (auto* efunc = func.as()) { + symbol = efunc->global_symbol; + kind = VMFuncInfo::FuncKind::kPackedFunc; + } else if (func.as()) { + symbol = gvar->name_hint; + kind = VMFuncInfo::FuncKind::kVMFunc; + } + } + // GlobalVar can be reference to a Relax function or a TIR primfunc + // At this point: all global var must corresponds to the right symbol. + // TODO(relax-team): switch everything to extern before splitting TIR/relax + // so we do not have idle global var here. + if (!symbol.defined()) { + symbol = gvar->name_hint; + kind = VMFuncInfo::FuncKind::kPackedFunc; + } + // declare the function to be safe. + ICHECK(symbol.defined()); + builder_->DeclareFunction(symbol.value(), kind); + return builder_->GetFunction(symbol.value()); + } + + Instruction::Arg VisitExpr_(const ExternFuncNode* op) final { + builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc); + return builder_->GetFunction(op->global_symbol); + } + + void EmitAllocStorage(const Call& call_node, RegName dst_reg) { + ICHECK_EQ(call_node->args.size(), 3); + // Handle args of the call + std::vector args; + args.push_back(Instruction::Arg::Register(Instruction::kVMRegister)); + // buffer size, dtype, device index + for (auto arg : call_node->args) { + args.push_back(this->VisitExpr(arg)); + } + builder_->EmitCall("vm.builtin.alloc_storage", args, dst_reg); + } + + void EmitAllocTensor(const Call& call_node, RegName dst_reg) { + ICHECK_EQ(call_node->args.size(), 4); + std::vector args; + args.reserve(4); + for (Expr arg : call_node->args) { + args.push_back(this->VisitExpr(arg)); + } + builder_->EmitCall("vm.builtin.alloc_tensor", args, dst_reg); + } + + void EmitCallBuiltinWithCtx(const Call& call_node, RegName dst_reg) { + std::vector args; + args.push_back(Instruction::Arg::Register(Instruction::kVMRegister)); + + auto func = this->VisitExpr(call_node->args[0]); + auto tuple_arg = Downcast(call_node->args[1]); + + // Handle args of the call + for (Expr arg : tuple_arg->fields) { + args.push_back(this->VisitExpr(arg)); + } + + builder_->EmitCall(func, args, dst_reg); + } + + void EmitNormalCall(const Call& call_node, RegName dst_reg) { + Instruction::Arg func = VisitExpr(call_node->op); + std::vector args = VisitArray(call_node->args); + builder_->EmitCall(func, args, dst_reg); + } + + // TODO(relax-team) revisit after PrimValue. + // Emit the `call_node` attributes as constants and append these constants to `args` vector. + void AppendAttrsAsConstants(const Call& call_node, std::vector& args) { + auto attrs = call_node->attrs; + if (!attrs.defined()) return; + + LOG(FATAL) << "Support for attributes of Op " << call_node->op + << " has not been implemented yet."; + return; + } + + // Emits call to packed function `name` with arguments copied over from `call_node` args and + // attributes. + void EmitPackedFuncCall(const Call& call_node, const FCallPacked& name, RegName dst_reg) { + std::vector args = VisitArray(call_node->args); + AppendAttrsAsConstants(call_node, args); + builder_->EmitCall(name, args, dst_reg); + } + + std::vector VisitArray(const Array& arr) { + std::vector ret; + for (size_t i = 0; i < arr.size(); ++i) { + ret.push_back(this->VisitExpr(arr[i])); + } + return ret; + } + + /*! \brief Internal ExecBuilder. */ + relax::ExecBuilder builder_; + /*! + * \brief Total number of virtual registers allocated. + * \note The first two registers are reserved for special registers. + */ + size_t registers_num_ = 0; + /*! \brief Map from var to register number. */ + std::unordered_map var_arg_map_; + /*! \brief the context module. */ + IRModule ctx_mod_; + /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */ + const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage"); + const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor"); + const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); + const Op& null_value_op_ = Op::Get("relax.null_value"); +}; + +/*! + * \brief Create the Relax VM executable from all relax.Function in mod. + * and add them to exec_builder. + * \param exec_builder Builder to collect executables. + * \param mod Input module. + * \return Left over IRModule that may contain otehr functions. + */ +IRModule VMCodeGen(ExecBuilder exec_builder, IRModule mod) { + return CodeGenVM::Run(exec_builder, mod); +} + +TVM_REGISTER_GLOBAL("relax.VMCodeGen").set_body_typed(VMCodeGen); + +/*! + * \brief Link the libaries together. + */ +Module VMLink(ExecBuilder builder, Target target, Optional lib, Array ext_libs, + Map params) { + // TODO(relax-team) Revisit the param and ext_lib options. + ObjectPtr executable = builder->Get(); + if (!lib.defined()) { + lib = codegen::CSourceModuleCreate(";", "", Array{}); + } + std::unordered_map conv_params; + for (const auto& [name, param] : params) { + conv_params[name] = param; + } + Module combined_lib = codegen::CreateMetadataModule( + conv_params, lib.value(), ext_libs, target, + + // TODO(@sunggg): Currently, CRT uses relay-specific executor for uTVM support. + // Before jumping into details, only support cpp runtime for now. + relay::Runtime::Create("cpp"), + relay::Executor::Create("graph"), // TODO(@sunggg): pass arbitrarily executor. CPP runtime + // won't use this anyways. + relay::backend::ExecutorCodegenMetadata()); + executable->Import(combined_lib); + return Module(executable); +} + +TVM_REGISTER_GLOBAL("relax.VMLink").set_body_typed(VMLink); + +} // namespace relax_vm +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 8640ed79adb0..ca66b0a9ef75 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -18,13 +18,46 @@ */ #include #include -#include #include #include +#include "op_common.h" + namespace tvm { namespace relax { +bool EqualConstInt(const PrimExpr& lhs, int64_t value) { + if (const int64_t* pvalue = tir::as_const_int(lhs)) { + return pvalue[0] == value; + } + return false; +} + +bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs) { + PrimExpr diff = lhs - rhs; + if (const int64_t* pdiff = tir::as_const_int(diff)) { + return pdiff[0] == 0; + } + tvm::arith::Analyzer ana; + diff = ana.Simplify(diff); + if (const int64_t* pdiff = tir::as_const_int(diff)) { + return pdiff[0] == 0; + } + return false; +} + +StructInfo ReturnVoidStructInfo(const Call& call, const BlockBuilder& ctx) { + return TupleStructInfo(Array()); +} + +StructInfo ReturnObjectStructInfo(const Call& call, const BlockBuilder& ctx) { + return ObjectStructInfo(); +} + +StructInfo ReturnShapeStructInfo(const Call& call, const BlockBuilder& ctx) { + return ShapeStructInfo(kUnknownNDim); +} + // call_tir StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { @@ -73,5 +106,190 @@ Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR); +// call builtin +StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) { + if (call->sinfo_args.size() == 0) { + // by default return void. + return TupleStructInfo(Array()); + } else { + ICHECK_EQ(call->sinfo_args.size(), 1); + return call->sinfo_args[0]; + } +} + +TVM_REGISTER_OP("relax.call_builtin_with_ctx") + .set_num_inputs(4) + .add_argument("func", "Expr", "The builtin packed func.") + .add_argument("args", "Tuple", "The input arguments.") + .set_attr("FInferStructInfo", InferStructInfoCallBuiltinWithCtx); + +Expr MakeCallBuiltinWithCtx(Expr func, Tuple args, Array sinfo_args) { + static const Op& op = Op::Get("relax.call_builtin_with_ctx"); + return Call(op, {func, args}, Attrs(), sinfo_args); +} + +TVM_REGISTER_GLOBAL("relax.op.call_builtin_with_ctx").set_body_typed(MakeCallBuiltinWithCtx); + +TVM_REGISTER_OP("relax.null_value") + .set_num_inputs(0) + .set_attr("FInferStructInfo", ReturnObjectStructInfo); + +Expr MakeCallNullValue() { + static const Op& op = Op::Get("relax.null_value"); + return Call(op, {}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.null_value").set_body_typed(MakeCallNullValue); + +// make_closure + +RELAY_REGISTER_OP("relax.make_closure") + .set_num_inputs(2) + .add_argument("func", "Expr", "The closure.") + .add_argument("args", "Tuple", "The captured variables.") + .set_attr("FInferStructInfo", ReturnObjectStructInfo); + +Expr MakeClosure(Expr func, Tuple args) { + static const Op& op = Op::Get("relax.make_closure"); + return Call(op, {func, args}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.make_closure").set_body_typed(MakeClosure); + +// invoke_closure + +StructInfo InferStructInfoInvokeClosure(const Call& call, const BlockBuilder& ctx) { + if (call->sinfo_args.empty()) { + return ObjectStructInfo(); + } else if (call->sinfo_args.size() == 1) { + return call->sinfo_args[0]; + } else { + return TupleStructInfo(call->sinfo_args); + } +} + +RELAY_REGISTER_OP("relax.invoke_closure") + .set_num_inputs(2) + .add_argument("closure", "Expr", "The VMClosure.") + .add_argument("args", "Tuple", "The captured variables.") + .set_attr("FInferStructInfo", InferStructInfoInvokeClosure); + +Expr InvokeClosure(Expr closure, Tuple args, Array sinfo_args) { + static const Op& op = Op::Get("relax.invoke_closure"); + return Call(op, {closure, args}, {}, sinfo_args); +} + +TVM_REGISTER_GLOBAL("relax.op.invoke_closure").set_body_typed(InvokeClosure); + +// shape_of + +RELAY_REGISTER_OP("relax.shape_of") + .set_num_inputs(1) + .add_argument("input", "Expr", "The input expression") + .set_attr("FInferStructInfo", ReturnShapeStructInfo); + +Expr MakeShapeOf(Expr expr) { + static const Op& op = Op::Get("relax.shape_of"); + return Call(op, {expr}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.shape_of").set_body_typed(MakeShapeOf); + +// alloc_tensor + +StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& ctx) { + ICHECK(call->args[0].as()) + << "must be ShapeExpr, but got " << call->args[0]->GetTypeKey(); + ICHECK(call->args[1].as()) + << "must be DataTypeImm, but got " << call->args[1]->GetTypeKey(); + DataType out_dtype; + if (const auto* dtype_node = call->args[1].as()) { + const DataTypeImm dtype_imm = GetRef(dtype_node); + out_dtype = dtype_imm->value; + } + return TensorStructInfo(call->args[0], out_dtype); +} + +RELAY_REGISTER_OP("relax.builtin.alloc_tensor") + .set_num_inputs(3) + .add_argument("shape", "Expr", "The shape of the tensor to allocate.") + .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.") + .add_argument("runtime_device_index", "int64_t", + "The device index indicating on which device the tensor is to be " + "allocated at runtime. Index -1 is reserved for the host device.") + .set_attr("FInferStructInfo", InferStructInfoAllocateTensor); + +Expr MakeAllocTensor(Expr shape, DataType dtype, int64_t runtime_device_index) { + static const Op& op = Op::Get("relax.builtin.alloc_tensor"); + return Call(op, {shape, DataTypeImm(dtype), PrimValue::Int64(runtime_device_index)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor); + +// vm alloc_storage + +RELAY_REGISTER_OP("relax.vm.alloc_storage") + .set_num_inputs(3) + .add_argument("size", "Expr", "The size of the storage to allocate.") + .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.") + .add_argument("runtime_device_index", "int64_t", + "The device index indicating on which device the tensor is " + "to be allocated at runtime.") + .set_attr("FInferStructInfo", ReturnObjectStructInfo); + +Expr MakeVMAllocStorage(Expr size, int64_t runtime_device_index, DataType dtype) { + static const Op& op = Op::Get("relax.vm.alloc_storage"); + return Call(op, {size, PrimValue::Int64(runtime_device_index), DataTypeImm(dtype)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.vm.alloc_storage").set_body_typed(MakeVMAllocStorage); + +// vm alloc_tensor + +Expr InferShapeVMAllocTensor(const Call& call, DiagnosticContext diag_ctx) { return call->args[1]; } + +StructInfo InferStructInfoVMAllocTensor(const Call& call, const BlockBuilder& ctx) { + DataType out_dtype; + if (const auto* dtype_node = call->args[3].as()) { + const DataTypeImm dtype_imm = GetRef(dtype_node); + out_dtype = dtype_imm->value; + } + if (const auto* output_shape = call->args[1].as()) { + return TensorStructInfo(GetRef(output_shape), out_dtype); + } + return TensorStructInfo(out_dtype, kUnknownNDim); +} + +RELAY_REGISTER_OP("relax.vm.alloc_tensor") + .set_num_inputs(4) + .add_argument("storage", "Expr", "The storage to allocate the tensor to.") + .add_argument("offset", "int", "Storage offset to allocate the tensor.") + .add_argument("shape", "Expr", "The shape of the tensor to allocate.") + .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.") + .set_attr("FInferStructInfo", InferStructInfoVMAllocTensor); + +Expr MakeVMAllocTensor(Expr storage, int offset, Expr shape, DataType dtype) { + static const Op& op = Op::Get("relax.vm.alloc_tensor"); + return Call(op, {storage, PrimValue::Int64(offset), shape, DataTypeImm(dtype)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.vm.alloc_tensor").set_body_typed(MakeVMAllocTensor); + +// vm call_tir_dyn + +RELAY_REGISTER_OP("relax.vm.call_tir_dyn") + .set_num_inputs(2) + .add_argument("func", "Expr", "The destination-passing-style function.") + .add_argument("args", "Tuple", + "The input arguments (list of tensors and last argument is ShapeExpr)") + .set_attr("FInferStructInfo", ReturnVoidStructInfo); + +Expr MakeCallTIRDyn(Expr func, Tuple args) { + static const Op& op = Op::Get("relax.vm.call_tir_dyn"); + return Call(op, {func, args}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.vm.call_tir_dyn").set_body_typed(MakeCallTIRDyn); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 8e362bb4d55c..c6d335b2a1bd 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -115,7 +115,30 @@ inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx } /*! - * \brief Infer the struct info for unary arithmetic elementwise ops. It's also + * \brief Infer the struct info by returning the struct info of the input argument. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \tparam arg_index The index of the argument to infer the output dtype from. + * \return The inferred struct info. + */ +template +StructInfo ReturnStructInfoFromArg(const Call& call, const BlockBuilder& ctx) { + Op op = Downcast(call->op); + int n_input = op->arguments.size(); + if (static_cast(call->args.size()) != n_input) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << op << " op should have " << n_input << " arguments"); + } + if (arg_index >= n_input) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << op << " op has only " << n_input + << "arguments, but try to get the arg with index " << arg_index); + } + return GetStructInfo(call->args[arg_index]); +} + +/*! + * \brief Infer the struct info for unary arithmetic elementwise ops. It's also * used in some NN operators. * \param call The context Call to the operator. * \param ctx The error reporting context. diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index 0ef63c8a4147..15a4f8702b03 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -19,7 +19,6 @@ /*! * \file src/runtime/relax_vm/builtin.cc */ -#include #include #include #include @@ -214,14 +213,13 @@ TVM_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo * \param err_ctx Additional context if error occurs. */ void CheckTupleInfo(ObjectRef arg, int64_t size, Optional err_ctx) { - using Tuple = runtime::ADT; // a function that lazily get context for error reporting - auto* ptr = arg.as(); + auto* ptr = arg.as(); CHECK(ptr != nullptr) << "TypeError: " << err_ctx.value_or("") << " expect a Tuple but get " << arg->GetTypeKey(); - CHECK(static_cast(ptr->size) == size) + CHECK(static_cast(ptr->size()) == size) << "ValueError: " << err_ctx.value_or("") << " expect a Tuple with " << size << " elements, " - << " but get a Tuple with " << ptr->size << " elements."; + << " but get a Tuple with " << ptr->size() << " elements."; } TVM_REGISTER_GLOBAL("vm.builtin.check_tuple_info").set_body_typed(CheckTupleInfo); @@ -321,6 +319,10 @@ TVM_REGISTER_GLOBAL("vm.builtin.copy").set_body([](TVMArgs args, TVMRetValue* rv *rv = args[0]; }); +TVM_REGISTER_GLOBAL("vm.builtin.reshape").set_body_typed([](NDArray data, ShapeTuple new_shape) { + return data.CreateView(new_shape, data->dtype); +}); + /*! * \brief Load the scalar value in cond and return the result value. * \param cond The condition @@ -367,8 +369,15 @@ TVM_REGISTER_GLOBAL("vm.builtin.read_if_cond").set_body_typed(ReadIfCond); //------------------------------------- // Data structure API //------------------------------------- -TVM_REGISTER_GLOBAL("vm.builtin.tuple_getitem").set_body_typed([](runtime::ADT arr, int64_t index) { - return arr[index]; +TVM_REGISTER_GLOBAL("vm.builtin.tuple_getitem") + .set_body_typed([](runtime::Array arr, int64_t index) { return arr[index]; }); + +TVM_REGISTER_GLOBAL("vm.builtin.make_tuple").set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Array arr; + for (int i = 0; i < args.num_args; ++i) { + arr.push_back(args[i].operator ObjectRef()); + } + *rv = arr; }); } // namespace relax_vm diff --git a/tests/python/relax/test_runtime_builtin.py b/tests/python/relax/test_runtime_builtin.py new file mode 100644 index 000000000000..b4ba54b45554 --- /dev/null +++ b/tests/python/relax/test_runtime_builtin.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import pytest +import numpy as np + +from tvm.ir import assert_structural_equal +from tvm.relax.testing.runtime_builtin import MatchShapeCode, MakeShapeCode + + +def test_make_shape(): + MK = MakeShapeCode + make_shape = tvm.get_global_func("vm.builtin.make_shape") + heap = tvm.nd.array(np.arange(10).astype("int64")) + s = make_shape(heap, 3, MK.USE_IMM, 10, MK.LOAD_SHAPE, 0, MK.LOAD_SHAPE, 2) + + assert s == tvm.runtime.container.ShapeTuple([10, 0, 2]) + + +def test_match_shape(): + MS = MatchShapeCode + match_shape = tvm.get_global_func("vm.builtin.match_shape") + heap = tvm.nd.array(np.zeros(10).astype("int64")) + + assert heap.numpy()[2] == 0 + + s = tvm.runtime.container.ShapeTuple([1, 2, 3]) + x = tvm.nd.array(np.zeros([1, 2, 3])) + + match_shape(s, heap, 3, MS.ASSERT_EQUAL_TO_IMM, 1, MS.STORE_TO_HEAP, 2, MS.NO_OP, 0, "") + + assert heap.numpy()[2] == 2 + + match_shape( + x, + heap, + 3, + MS.ASSERT_EQUAL_TO_IMM, + 1, + MS.ASSERT_EQUAL_TO_LOAD, + 2, + MS.ASSERT_EQUAL_TO_IMM, + 3, + "", + ) + + with pytest.raises(RuntimeError): + match_shape(s, heap, 2, MS.ASSERT_EQUAL_TO_IMM, 1, MS.STORE_TO_HEAP, 2, "") + + with pytest.raises(RuntimeError): + match_shape(s, heap, 3, MS.ASSERT_EQUAL_TO_IMM, 2, MS.STORE_TO_HEAP, 2, MS.NO_OP, 0, "") + + +def test_check_shape_info(): + check_shape_info = tvm.get_global_func("vm.builtin.check_shape_info") + s = tvm.runtime.container.ShapeTuple([1, 2, 3]) + + check_shape_info(s, 3, "") + check_shape_info(s, -1, "") + + # wrong ndim + with pytest.raises(ValueError): + check_shape_info(s, 2, "") + + # wrong type + with pytest.raises(TypeError): + check_shape_info([], 2, "") + + +def test_check_tensor_info(): + check_tensor_info = tvm.get_global_func("vm.builtin.check_tensor_info") + x = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + + check_tensor_info(x, 2, "int32", "") + check_tensor_info(x, -1, "int32", "") + check_tensor_info(x, 2, "", "") + check_tensor_info(x, -1, "", "") + + # allow not passing in dtype + check_tensor_info(x, 2, "") + check_tensor_info(x, -1, "") + + # ndim mismatch + with pytest.raises(ValueError, match=r".* ndim .*"): + check_tensor_info(x, 3, "int32", "") + + # dtype mismatch + with pytest.raises(ValueError, match=r"myerror.* dtype .*"): + check_tensor_info(x, 2, "float32", "myerror") + + # error with context + with pytest.raises(ValueError, match=r".* myerror .*"): + check_tensor_info(x, 3, "myerror") + + # wrong type + with pytest.raises(TypeError): + check_tensor_info([], 2, "", "") + + +def test_check_tuple_info(): + check_tuple_info = tvm.get_global_func("vm.builtin.check_tuple_info") + x = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + t = tvm.runtime.convert([x, x, x]) + + check_tuple_info(t, 3, "") + + # size + with pytest.raises(ValueError, match=r".*elements.*"): + check_tuple_info(t, 2, "") + + # wrong type + with pytest.raises(TypeError): + check_tuple_info(x, 2, "") + + +def test_check_func_info(): + check_func_info = tvm.get_global_func("vm.builtin.check_func_info") + f = tvm.runtime.convert(lambda x: x) + x = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + + check_func_info(f, "") + + # wrong type + with pytest.raises(TypeError, match=".*myerror.*"): + check_func_info(x, "myerror") + + +def test_tuple_getitem(): + tuple_getitem = tvm.get_global_func("vm.builtin.tuple_getitem") + x = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + y = tvm.nd.array(np.zeros((2, 3)).astype("int32")) + t = tvm.runtime.convert([x, y]) + + assert tuple_getitem(t, 0) == x + assert tuple_getitem(t, 1) == y + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index e2cb8bc5fc32..58596f968f98 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-docstring +import tvm import pytest from tvm import IRModule, relax, tir from tvm.script import relax as R @@ -447,42 +448,4 @@ def test_if(): if __name__ == "__main__": - test_function() - test_extern_func() - - test_object_struct_info() - test_prim_struct_info() - test_shape_struct_info_0() - test_shape_struct_info_1() - test_shape_struct_info_2() - test_tensor_struct_info() - test_tuple_struct_info_empty() - test_tuple_struct_info() - test_func_struct_info() - - test_shape_type() - test_object_type() - test_dyn_tensor_type() - test_packed_func_type() - test_tuple_type() - test_func_type() - - test_prim_value() - test_string_imm() - test_data_type_imm() - - test_var() - test_dataflow_var() - # - test_tuple() - test_tuple_get_item() - test_shape_expr() - test_call() - - test_seq_expr() - test_binding_block() - test_dataflow_block() - - test_match_cast() - test_var_binding() - test_if() + tvm.testing.main() diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py new file mode 100644 index 000000000000..b5e77091776a --- /dev/null +++ b/tests/python/relax/test_vm_codegen_only.py @@ -0,0 +1,333 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test last-stage of codegen VM. + +Restrictions: all shape lowered, explicit allocation. +""" +import tvm +import pytest +import numpy as np +from tvm import relax, TVMError +from tvm.script import relax as R, tir as T +from tvm.relax.testing.vm import check_saved_func +from tvm.relax.testing.runtime_builtin import MatchShapeCode, MakeShapeCode + +EXEC_MODE = ["bytecode"] + + +def codegen(mod, target, exec_mode="bytecode"): + builder = relax.ExecBuilder() + tir_mod = relax.vm._vmcodegen(builder, mod, exec_mode=exec_mode) + return relax.vm._vmlink(builder, target, tir_mod) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_copy(exec_mode): + @tvm.script.ir_module + class TestVMMove: + @R.function + def foo(x: R.Tensor((3, 4), "float32")): + R.func_attr({"global_symbol": "foo"}) + z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32"))) + return z + + mod = TestVMMove + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + inp = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = check_saved_func(vm, "foo", inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_if_cond_const(exec_mode): + @tvm.script.ir_module + class TestVMIfCondConst: + @R.function + def main(x: R.Tensor(ndim=2, dtype="float32")) -> R.Tensor(ndim=2, dtype="float32"): + R.func_attr({"global_symbol": "main"}) + if relax.const(True, dtype="bool"): + ret = x + else: + ret = x + return ret + + mod = TestVMIfCondConst + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(3, 4)) + res = vm["main"](inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy()) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_exec_serialize_export_library(exec_mode): + @tvm.script.ir_module + class TestVMMove: + @R.function + def foo(x: R.Tensor((3, 4), "float32")): + R.func_attr({"global_symbol": "foo"}) + z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((3, 4), dtype="float32"))) + return z + + mod = TestVMMove + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target) + from tvm.contrib import utils + + temp_dir = utils.tempdir() + path_exec = temp_dir.relpath("exec.so") + ex.mod.export_library(path_exec) + + loaded_exec = relax.vm.Executable(tvm.runtime.load_module(path_exec)) + assert ex.as_text() == loaded_exec.as_text() + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_if_cond(exec_mode): + @tvm.script.ir_module + class TestVMCompileIf: + @R.function + def ife(cond: R.Tensor((), "bool"), x: R.Tensor((3, 4), "float32")) -> R.Tensor: + R.func_attr({"global_symbol": "ife"}) + if cond: + w = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor)) + else: + w = R.call_packed("test.vm.mul", x, x, sinfo_args=(R.Tensor)) + return w + + mod = TestVMCompileIf + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(3, 4)) + res = vm["ife"](tvm.nd.array(1), inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy() + inp.numpy(), rtol=1e-7, atol=1e-7) + res = vm["ife"](tvm.nd.array(True), inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy() + inp.numpy(), rtol=1e-7, atol=1e-7) + res = vm["ife"](tvm.nd.array(0), inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy() * inp.numpy(), rtol=1e-7, atol=1e-7) + res = vm["ife"](tvm.nd.array(False), inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy() * inp.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_return_const_tuple(exec_mode): + @tvm.script.ir_module + class ReturnConstTuple: + @R.function + def main(x: R.Tensor(ndim=2, dtype="float32")): + R.func_attr({"global_symbol": "main"}) + y = R.const([1, 2]) + z = (y, R.const([3, 4]), x) + return z + + mod = ReturnConstTuple + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(2, 3)) + res0, res1, res2 = vm["main"](inp) + tvm.testing.assert_allclose(res0.numpy(), np.array([1, 2])) + tvm.testing.assert_allclose(res1.numpy(), np.array([3, 4])) + tvm.testing.assert_allclose(res2.numpy(), inp.numpy()) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_const_as_call_arg(exec_mode): + @tvm.script.ir_module + class TestVMConstAsCallArg: + @R.function + def main(x: R.Tensor(ndim=2, dtype="float32")): + R.func_attr({"global_symbol": "main"}) + a = R.call_packed( + "test.vm.add", + relax.const([1, 2]), + relax.const([3, 4]), + sinfo_args=(R.Tensor(ndim=2, dtype="float32")), + ) + b = R.call_packed( + "test.vm.add", + a, + x, + sinfo_args=(R.Tensor(ndim=2, dtype="float32")), + ) + return b + + mod = TestVMConstAsCallArg + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(1, 2)) + res = vm["main"](inp) + tvm.testing.assert_allclose(res.numpy(), np.array([4, 6]) + inp.numpy()) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_shape_check_builtin(exec_mode): + MS = MatchShapeCode + MK = MakeShapeCode + # slot assignment: + # 0: n, 1: m + sindex = {"n": 0, "m": 1} + + @tvm.script.ir_module + class TestVMShapeCheck: + @R.function + def main(x: R.Tensor(["n", "m"], "float32")) -> R.Shape(ndim=3): + R.func_attr({"global_symbol": "main"}) + n = T.Var("n", "int64") + k = T.Var("k", "int64") + shape_heap = R.call_builtin_with_ctx( + "vm.builtin.alloc_shape_heap", + [R.prim_value(3)], + sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ) + _ = R.call_packed( + "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", sinfo_args=[R.Tuple()] + ) + _ = R.call_packed( + "vm.builtin.match_shape", + x, + shape_heap, + 2, + MS.STORE_TO_HEAP, + sindex["n"], + MS.STORE_TO_HEAP, + sindex["m"], + "", + sinfo_args=[R.Tuple()], + ) + # construct shape value for return + s = R.call_packed( + "vm.builtin.make_shape", + shape_heap, + 3, + MK.LOAD_SHAPE, + sindex["m"], + MK.LOAD_SHAPE, + sindex["n"], + MK.USE_IMM, + 2, + sinfo_args=[R.Shape(ndim=3)], + ) + return s + + mod = TestVMShapeCheck + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x = tvm.nd.array(np.zeros((1, 2)).astype("float32")) + res = vm["main"](x) + assert res == tvm.runtime.container.ShapeTuple([2, 1, 2]) + + # wrong input type + with pytest.raises(TypeError): + vm["main"]([]) + + # wrong ndim + with pytest.raises(ValueError, match=r".*ndim.*"): + vm["main"](tvm.nd.array(np.zeros(1).astype("float32"))) + + # wrong dtype + with pytest.raises(ValueError, match=r".*dtype.*"): + vm["main"](tvm.nd.array(np.zeros((1, 2)).astype("int32"))) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_prim_value(exec_mode): + @tvm.script.ir_module + class TestVMPrimValue: + @R.function + def main(): + R.func_attr({"global_symbol": "main"}) + ret = R.prim_value(T.int64(1)) + return ret + + mod = TestVMPrimValue + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"]() + assert res == 1 + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_string_imm(exec_mode): + @tvm.script.ir_module + class TestVMStringImm: + @R.function + def main(): + R.func_attr({"global_symbol": "main"}) + ret = R.str("hello") + return ret + + mod = TestVMStringImm + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"]() + assert res == "hello" + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_datatype_imm(exec_mode): + @tvm.script.ir_module + class TestDataTypeImm: + @R.function + def main(): + R.func_attr({"global_symbol": "main"}) + ret = R.dtype("float32") + return ret + + mod = TestDataTypeImm + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"]() + assert res == "float32" + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_builtin_reshape(exec_mode): + @tvm.script.ir_module + class TestVMBuiltinReshape: + @R.function + def main(x: R.Tensor((3, 4), "float32")): + R.func_attr({"global_symbol": "main"}) + y = R.call_packed( + "vm.builtin.reshape", x, (6, 2), sinfo_args=R.Tensor((6, 2), "float32") + ) + return y + + mod = TestVMBuiltinReshape + target = tvm.target.Target("llvm", host="llvm") + ex = codegen(mod, target, exec_mode) + dev = tvm.cpu() + vm = relax.VirtualMachine(ex, dev) + + input_np = np.random.rand(3, 4).astype("float32") + input = tvm.nd.array(input_np, dev) + res = vm["main"](input) + expected = input_np.reshape(6, 2) + tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-7, atol=1e-7) + + +if __name__ == "__main__": + tvm.testing.main() From 55c2d1f6657271e3e2449f02967d8f23a008a46d Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Fri, 10 Feb 2023 23:33:46 -0800 Subject: [PATCH 10/81] [Unity] Relax VM shape lowering pass (#13956) This PR introduces Relax `FunctionPass` and `DataflowBlockPass` API, and the `VMShapeLower` pass to lower the shape expression in Relax to TIR functions and VM shape heap builtin functions. Co-Authored-by: Ziheng Jiang Co-Authored-by: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Co-Authored-by: Altan Haan Co-Authored-by: Junru Shao Co-Authored-by: Prakalp Srivastava Co-Authored-by: Ruihang Lai Co-Authored-by: Siyuan Feng Co-Authored-by: Steven S. Co-Authored-by: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Co-Authored-by: Tianqi Chen Co-Authored-by: Yong Wu --- include/tvm/relax/backend.h | 44 ++ include/tvm/relax/transform.h | 72 ++ python/tvm/relax/__init__.py | 1 + python/tvm/relax/transform/__init__.py | 20 + python/tvm/relax/transform/_ffi_api.py | 19 + python/tvm/relax/transform/transform.py | 345 +++++++++ src/relax/backend/vm/vm_shape_lower.cc | 725 ++++++++++++++++++ src/relax/ir/transform.cc | 413 ++++++++++ .../test_backend_transform_shape_lower.py | 429 +++++++++++ 9 files changed, 2068 insertions(+) create mode 100644 include/tvm/relax/backend.h create mode 100644 include/tvm/relax/transform.h create mode 100644 python/tvm/relax/transform/__init__.py create mode 100644 python/tvm/relax/transform/_ffi_api.py create mode 100644 python/tvm/relax/transform/transform.py create mode 100644 src/relax/backend/vm/vm_shape_lower.cc create mode 100644 src/relax/ir/transform.cc create mode 100644 tests/python/relax/test_backend_transform_shape_lower.py diff --git a/include/tvm/relax/backend.h b/include/tvm/relax/backend.h new file mode 100644 index 000000000000..4ebeacac0ff3 --- /dev/null +++ b/include/tvm/relax/backend.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/backend.h + * \brief Relax backend specific transformation passes. + */ +#ifndef TVM_RELAX_BACKEND_H_ +#define TVM_RELAX_BACKEND_H_ + +#include + +namespace tvm { +namespace relax { +namespace transform { + +/*! + * \brief Lower the shape expression in relax to VM shape heap and TIR functions. + * + * \return The Pass. + */ +TVM_DLL Pass VMShapeLower(); + +} // namespace transform +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_BACKEND_H_ diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h new file mode 100644 index 000000000000..fa288a7f06c2 --- /dev/null +++ b/include/tvm/relax/transform.h @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform.h + * \brief Relax specific transformation passes. + */ +#ifndef TVM_RELAX_TRANSFORM_H_ +#define TVM_RELAX_TRANSFORM_H_ + +#include +#include + +namespace tvm { +namespace relax { +namespace transform { + +using Pass = tvm::transform::Pass; +using PassInfo = tvm::transform::PassInfo; +using PassContext = tvm::transform::PassContext; +using Function = tvm::relax::Function; +using DataflowBlock = tvm::relax::DataflowBlock; + +/*! + * \brief Create a function pass. + * + * \param pass_func The packed function that contains the optimization. + * \param opt_level The optimization level of the function pass. + * \param name The name of the function pass. + * \param required The list of the passes that the function pass is dependent on. + * + * \return The created function pass. + */ +TVM_DLL Pass CreateFunctionPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, String name, tvm::Array required); + +/*! + * \brief Create a dataflowblock pass. + * + * \param pass_func The packed function that contains the optimization. + * \param opt_level The optimization level of the dataflowblock pass. + * \param name The name of the dataflowblock pass. + * \param required The list of the passes that the dataflowblock pass is dependent on. + * + * \return The created dataflowblock pass. + */ +TVM_DLL Pass CreateDataflowBlockPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, String name, tvm::Array required); + +} // namespace transform +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_TRANSFORM_H_ diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index ce175354d02c..a6306b788e5a 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -20,6 +20,7 @@ from . import expr from . import ty from . import analysis +from . import transform from . import vm from . import block_builder from . import op diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py new file mode 100644 index 000000000000..eb4d5f710c53 --- /dev/null +++ b/python/tvm/relax/transform/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax transformations. """ + +from .transform import * diff --git a/python/tvm/relax/transform/_ffi_api.py b/python/tvm/relax/transform/_ffi_api.py new file mode 100644 index 000000000000..667aa62c2c95 --- /dev/null +++ b/python/tvm/relax/transform/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for tvm.transform""" +import tvm._ffi + +tvm._ffi._init_api("relax.transform", __name__) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py new file mode 100644 index 000000000000..f20f06c52284 --- /dev/null +++ b/python/tvm/relax/transform/transform.py @@ -0,0 +1,345 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Relax transformation passes.""" +import functools +import inspect +import types +from typing import Callable, Union + +import tvm.ir +from . import _ffi_api + + +@tvm._ffi.register_object("relax.FunctionPass") +class FunctionPass(tvm.ir.transform.Pass): + """A pass that works on each tvm.relax.Function in a module. A function + pass class should be created through `function_pass`. + """ + + +@tvm._ffi.register_object("relax.DataflowBlockPass") +class DataflowBlockPass(tvm.ir.transform.Pass): + """A pass that works on each tvm.relax.DataflowBlock in a module.""" + + +def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass: + """Lower the symbolic shape and argument and match-cast structinfo matching. + + Parameters + ---------- + emit_err_ctx: Optional[bool] + Whether emit err context string, can be turned off for testing purposes. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.VMShapeLower(emit_err_ctx) # type: ignore + + +def _wrap_class_function_pass(pass_cls, pass_info): + """Wrap a python class as function pass.""" + + class PyFunctionPass(FunctionPass): + """Internal wrapper class to create a class instance.""" + + def __init__(self, *args, **kwargs): + # initialize handle in case pass_cls creation failed. + self.handle = None + inst = pass_cls(*args, **kwargs) + + # it is important not to capture self to + # avoid a cyclic dependency + def _pass_func(func, mod, ctx): + return inst.transform_function(func, mod, ctx) + + self.__init_handle_by_constructor__( + _ffi_api.MakeFunctionPass, _pass_func, pass_info # type: ignore + ) + self._inst = inst + + def __getattr__(self, name): + # fall back to instance attribute if there is not any + return self._inst.__getattribute__(name) + + functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__) + PyFunctionPass.__name__ = pass_cls.__name__ + PyFunctionPass.__doc__ = pass_cls.__doc__ + PyFunctionPass.__module__ = pass_cls.__module__ + return PyFunctionPass + + +def function_pass( + pass_func=None, + opt_level=None, + name=None, + required=None, +) -> Union[Callable, FunctionPass]: + """Decorate a function pass. + + This function returns a callback when pass_func + is provided. Otherwise, it returns the created function pass using the + given optimization function. + + Parameters + ---------- + pass_func : Optional[Callable[(Function, Module, PassContext) -> Function]] + The transformation function or class. + + opt_level : int + The optimization level of this function pass. + + name : Optional[str] + The name of the function pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the function pass is dependent on. + + Returns + ------- + create_function_pass : Union[Callable, FunctionPass] + + A decorator will be returned if pass_func is not provided, + otherwise return the decorated result. + The returned decorator has two behaviors depending on the input: + A new FunctionPass will be returned when we decorate a pass function. + A new FunctionPass class will be returned when we decorate a class type. + + Examples + -------- + The following code block decorates a function pass class. + + .. code-block:: python + + @relax.transform.function_pass(opt_level=1) + class TestReplaceFunc: + def __init__(self, new_func): + self.new_func = new_func + + def transform_function(self, func, mod, ctx): + # just for demo purposes + # transform func to new_func + return self.new_func + + @R.function + def f1(x: Tensor[(m, n), "float32"]): + return x + + @tvm.script.ir_module + class InputMod: + @R.function + def f2(x: Tensor[(m, n), "float32"]): + gv0 = relax.add(x, x) + return gv0 + # fpass is now a special pass that replaces every + # function to f1 + fpass = TestReplaceFunc(f1) + # now every function in InputMod is replaced by f1 + updated_mod = fpass(InputMod) + + + The following code creates a function pass by decorating + a user defined transform function. + + .. code-block:: python + + @relax.transform.function_pass(opt_level=2) + def transform(func, mod, ctx): + # my transformations here. + return func + + function_pass = transform + assert isinstance(function_pass, relax.transform.FunctionPass) + assert function_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the follwoing: + updated_mod = function_pass(m) + # Now transform should have been applied to every function in + # the provided module m. And the updated module will be returned. + """ + + if opt_level is None: + raise ValueError("Please provide opt_level for the function pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of " + "list/tuple.") + + def create_function_pass(pass_arg): + """Internal function that creates a function pass""" + fname = name if name else pass_arg.__name__ + info = tvm.transform.PassInfo(opt_level, fname, required) + if inspect.isclass(pass_arg): + return _wrap_class_function_pass(pass_arg, info) + if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Function pass") + return _ffi_api.MakeFunctionPass(pass_arg, info) # type: ignore + + if pass_func: + return create_function_pass(pass_func) + return create_function_pass + + +def _wrap_class_dataflowblock_pass(pass_cls, pass_info): + """Wrap a python class as dataflowblock pass""" + + class PyDataflowBlockPass(DataflowBlockPass): + """Internal wrapper class to create a class instance.""" + + def __init__(self, *args, **kwargs): + # initialize handle in case pass_cls creation failed. + self.handle = None + inst = pass_cls(*args, **kwargs) + + # it is important not to capture self to + # avoid a cyclic dependency + def _pass_func(func, mod, ctx): + return inst.transform_dataflowblock(func, mod, ctx) + + self.__init_handle_by_constructor__( + _ffi_api.MakeDataflowBlockPass, _pass_func, pass_info # type: ignore + ) + self._inst = inst + + def __getattr__(self, name): + # fall back to instance attribute if there is not any + return self._inst.__getattribute__(name) + + functools.update_wrapper(PyDataflowBlockPass.__init__, pass_cls.__init__) + PyDataflowBlockPass.__name__ = pass_cls.__name__ + PyDataflowBlockPass.__doc__ = pass_cls.__doc__ + PyDataflowBlockPass.__module__ = pass_cls.__module__ + return PyDataflowBlockPass + + +def dataflowblock_pass( + pass_func=None, opt_level=None, name=None, required=None +) -> Union[Callable, DataflowBlockPass]: + """Decorate a dataflowblock pass. + + This function returns a callback when pass_func + is provided. Otherwise, it returns the created dataflowblock pass using the + given optimization function. + + Parameters + ---------- + pass_func : Optional[Callable[(DataflowBlock, Module, PassContext) -> DataflowBlock]] + The transformation function or class. + + opt_level : int + The optimization level of this dataflowblock pass. + + name : Optional[str] + The name of the dataflowblock pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the dataflowblock pass is dependent on. + + Returns + ------- + create_dataflowblock_pass : Union[Callable, DataflowBlockPass] + + A decorator will be returned if pass_func is not provided, + otherwise return the decorated result. + The returned decorator has two behaviors depending on the input: + A new DataflowBlockPass will be returned when we decorate a pass function. + A new DataflowBlockPass class will be returned when we decorate a class type. + + Examples + -------- + The following code block decorates a dataflowblock pass class. + + .. code-block:: python + + @relax.transform.dataflowblock_pass(opt_level=1) + class TestReplaceBinding: + # Simple test function to replace the first VarBinding to another. + + def __init__(self): + # create a new VarBinding + m, n = tir.Var("m", "int64"), tir.Var("n", "int64") + lv0 = relax.Var("lv1", relax.TensorStructInfo([m, n], "float32")) + val = relax.const(np.random.rand(24, 56)) + self.new_binding = relax.VarBinding(lv0, val) + + def transform_dataflowblock(self, block, mod, ctx): + # just for demo purposes + # Replace the first binding in the DataflowBlock + new_bindings = [self.new_binding, block.bindings[1]] + new_block = relax.expr.DataflowBlock(new_bindings, block.span) + return new_block + + @tvm.script.ir_module + class InputMod: + @R.function + def f1(x: Tensor[(m, n), "float32"]): + with relax.dataflow(): + lv0 = relax.multiply(x, x) + gv0 = relax.add(x, x) + relax.output(gv0) + return gv0 + # block_pass is now a special pass that replaces every + # first binding to the constant value binding + block_pass = TestReplaceBinding() + # now every first binding in DataflowBlock of InputMod + # is replaced by new_binding + updated_mod = block_pass(InputMod) + + + The following code creates a dataflowblock pass by decorating + a user defined transform function. + + .. code-block:: python + + @relax.transform.dataflowblock_pass(opt_level=2) + def transform(block, mod, ctx): + # my transformations here. + return block + + block_pass = transform + assert isinstance(block_pass, relax.transform.DataflowBlockPass) + assert block_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the follwoing: + updated_mod = block_pass(m) + # Now transform should have been applied to every DataflowBlock in + # the provided module m. And the updated module will be returned. + """ + + if opt_level is None: + raise ValueError("Please provide opt_level for the dataflowblock pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of " + "list/tuple.") + + def create_dataflowblock_pass(pass_arg): + """Internal function that creates a dataflowblock pass""" + fname = name if name else pass_arg.__name__ + info = tvm.transform.PassInfo(opt_level, fname, required) + if inspect.isclass(pass_arg): + return _wrap_class_dataflowblock_pass(pass_arg, info) + if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for DataflowBlock pass") + return _ffi_api.MakeDataflowBlockPass(pass_arg, info) # type: ignore + + if pass_func: + return create_dataflowblock_pass(pass_func) + return create_dataflowblock_pass diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc new file mode 100644 index 000000000000..090bcf01b5a5 --- /dev/null +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -0,0 +1,725 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/backend/vm/vm_shape_lower.cc + * \brief Lower the function boundary type checks and symbolic shape computations. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! \brief A slot used in PrimExpr lowering. */ +struct PrimExprSlot { + /*! \brief The existing */ + PrimExpr expr; + /*! \brief The slot index */ + int index; + // The following three members are auxiliary data + // to help shape rewriting. + /*! + * \brief List of slots whose PrimExpr uses this PrimExpr. + * \note Users won't be empty only if PrimExpr is a Var and it does not include itself. + */ + std::vector user_slots; + /*! + * \brief Number of outstanding vars that are not defined in this PrimExpr. + * \note This is a helper counter used in analysis to perform computations. + */ + int outstanding_defs = 0; + /*! \brief Whether we have computed the value. */ + bool value_computed = false; +}; + +/*! + * \brief Helper dats structure to collect pairs of match shapes + * in a recursive matching process. + */ +struct MatchShapeTodoItem { + Expr input; + Array pattern; + String err_ctx; +}; + +/*! \brief Slot map used for shape lowering. */ +using PrimExprSlotMap = + std::unordered_map; + +// Collector to collect PrimExprSlotMap +class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor { + public: + // collect the PrimExpr slot for a given function + static void Collect(Function func, std::vector>* slot_vec, + PrimExprSlotMap* slot_map) { + PrimExprSlotCollector collector; + collector.slot_vec_ = slot_vec; + collector.slot_map_ = slot_map; + // collect shape declaration in func params + for (auto param : func->params) { + collector.VisitStructInfo(GetStructInfo(param)); + collector.VisitExpr(param); + } + collector.VisitExpr(func->body); + } + + private: + void VisitPrimExpr(const PrimExpr& expr) final { + if (expr->IsInstance()) return; + if (slot_map_->count(expr) == 0) { + auto slot = std::make_unique(); + slot->expr = expr; + slot->index = static_cast(slot_vec_->size()); + slot_map_->emplace(expr, slot.get()); + slot_vec_->emplace_back(std::move(slot)); + } + } + + void VisitBinding_(const MatchCastNode* op) final { + // Visit the match cast struct info so we can define + // the symbolic variables here. + this->VisitStructInfo(op->struct_info); + } + + void VisitExpr_(const FunctionNode* op) final { + // Do not recurse into function node as it is self-contained + } + + void VisitStructInfo_(const FuncStructInfoNode* op) final { + // Do not recurse into function struct info as it is self-contained + } + + void VisitStructInfoExprField(const PrimExpr& expr) final { VisitPrimExpr(expr); } + + void VisitStructInfoExprField(const Expr& expr) final { ExprVisitor::VisitExpr(expr); } + + std::vector>* slot_vec_; + PrimExprSlotMap* slot_map_; +}; + +/*! + * \brief Main logic to transform the shape lowered functions + * + * Consider the following input: + * + * \code + * + * def f(x: R.Tuple(R.Tensor([m, n+1]), R.Tensor([n, 2])) -> R.Tensor: + * return x + * + * \endcode + * + * Overall flow of the algorithm: + * - Preprocess: PrimExprSlot collection, we scan the function and allocate PrimExprSlot + * for each PrimExpr. In the above example, the result mapping from the slot index + * to expr would be {0:m, 1: n+1: 2: n}. Note that "n+1" also get a slot. + * PrimExprSlot also comes with auxiliary fields that track whether its value + * can be readily computed. + * + * Steps at each matching point: + * - Step 0: We call CheckMatchCast, + * which will recursively unpack the StructInfo, and generate static information checks. + * Note that this step only generates functions for checking types and ndim info, but not + * the symbolic shape variables. The symbolic shape-matching results will be returned as + * vector. This is because symbolic shape matching may not be completed + * in a single round. Importantly, CheckMatchCast also deals with tuple unpacking. + * + * - Step 1: We then call RunMatch to generate the statements for matching symbolic shapes. + * In the above example, the first round will store the value of m, n to their corresponding + * slot. RunMatch may return outstanding items. In the above example x.shape[1] == n+1 cannot + * be checked in the first round. RunMatch will populate new vars(this case n, m), these vars + * are added to a ready queue (ready_vars_) + * + * - Step 2: We EmitOutstandingPrimExprCompute to check if ready_vars will trigger new values + * to be computed. We eagerly compute all the outstanding values. The trigger is done through + * a ref counter which decreases when each outstanding def is satisfied. + * This step can also generate additional TIR functions to carry out shape computations. + * + * - Step 3: RunMatch again for given outstanding match todos. This time all invariants + * should be checked. + * + * The above step would populate each slot(which is backed by an element in shape_heap). + * Each time we find a symbolic shape tuple, we call MakeShape for given slot indices + * in the shape_heap. + * + * + * Key functions in the flow: + * - PrimExprSlotCollector: preprocessing and collecting the slots + * - CheckMatchCast: recursively structinfo unpacking, generate checks and match items. + * - RunMatch: generate symbolic shape matches + * - EmitOutstandingPrimExprCompute: tracks the variables to be computed and emit shape computation + * - VisitExpr_(ShapeExprNode*): makes symbolic shape tuple. + * + * The checks and symbolic shape all maps to runtime builtin functions. Please checkout + * runtime/relax_vm/builtin.cc for their definitions. + * + * Shape computation are lowered to host-side TIR functions that load var from slot + * and store computed results into the slot. For a given slot map: {0:m, 1: n+1: 2: n} + * It will create the shape_func below that loads data from H[2](n's slot) run compute + * and store back to H[1](n+1's slot). + * + * \code + * + * @T.prim_func + * def shape_func(H: T.Buffer([3], "int64")): + * H[1] = H[2] + 1 + * + * \endcode + * + * The current implementation will batch all shape computations at each match point. + * For example, all the expressions that depend on n, m will be computed in a single + * shape_func at the function boundary. If there are follow-up match_cast points, + * that defines new variable, then we might we will generate new shape functions + * to compute expressions that depend on these variables. + */ +class VMShapeLowerMutator + : public ExprMutator, + public StructInfoFunctor*)> { + public: + static IRModule Lower(IRModule mod, bool emit_err_ctx) { + VMShapeLowerMutator mutator(mod, emit_err_ctx); + + for (auto& kv : mod->functions) { + if (auto* func = kv.second.as()) { + Function updated_func = mutator.Rewrite(kv.first, GetRef(func)); + mutator.builder_->UpdateFunction(kv.first, updated_func); + } + } + return mutator.builder_->GetContextIRModule(); + } + + private: + explicit VMShapeLowerMutator(IRModule mod, bool emit_err_ctx) + : ExprMutator(mod), emit_err_ctx_(emit_err_ctx) {} + + using ExprMutator::VisitExpr_; + + // Unit rewrite function per function. + Function Rewrite(GlobalVar gvar, Function func) { + // prepare mapping and heap var + PrimExprSlotCollector::Collect(func, &slot_vec_, &slot_map_); + heap_size_ = IntImm(ShapeDType(), static_cast(slot_vec_.size())); + VarBinding shape_heap_binding = this->AllocShapeHeapBinding(heap_size_); + shape_heap_ = shape_heap_binding->var; + + // prepare slot information + this->PopulateSlotInfo(); + + Array blocks; + + builder_->BeginScope(func->params); + + { + // Check the parameter section. + builder_->BeginBindingBlock(); + this->builder_->EmitNormalized(shape_heap_binding); + std::vector match_todos; + for (size_t i = 0; i < func->params.size(); ++i) { + StructInfo sinfo = GetStructInfo(func->params[i]); + std::ostringstream err_ctx; + err_ctx << "ErrorContext(fn=" << gvar->name_hint << ", loc=param[" << i + << "], param=" << func->params[i]->name_hint() << ", annotation=" << sinfo << ") "; + this->CheckMatchCast(sinfo, func->params[i], true, err_ctx.str(), &match_todos); + } + // insert heap generation logic. + match_todos = this->RunMatch(match_todos, false); + this->EmitOutstandingPrimExprCompute(); + this->RunMatch(match_todos, true); + + BindingBlock pre_block = builder_->EndBlock(); + blocks.push_back(pre_block); + } + + // new body. + auto body_seq = Downcast(this->VisitWithNewScope(func->body, func->params)); + blocks.insert(blocks.end(), body_seq->blocks.begin(), body_seq->blocks.end()); + + { + // Insert the return value check + builder_->BeginBindingBlock(); + std::ostringstream err_ctx; + err_ctx << "ErrorContext(fn=" << gvar->name_hint + << ", loc=return, annotation=" << func->ret_struct_info << ") "; + std::vector match_todos; + // NOTE: the return value's shape computation must already be defined. + this->CheckMatchCast(func->ret_struct_info, body_seq->body, false, err_ctx.str(), + &match_todos); + // NOTE: the return value's shape computation must already be defined. + this->RunMatch(match_todos, true); + BindingBlock post_block = builder_->EndBlock(); + blocks.push_back(post_block); + } + + auto new_body = builder_->Normalize(SeqExpr(blocks, body_seq->body)); + // create a new function + return Function(func->params, new_body, func->ret_struct_info, func->attrs); + } + + //------------------------------------------------------- + // PrimExpr slot handling + //------------------------------------------------------- + static DataType ShapeDType() { return DataType::Int(64); } + + /*! \brief populate additional information in the slot. */ + void PopulateSlotInfo() { + for (auto& kv : slot_map_) { + auto* slot = kv.second; + if (!slot->expr.as()) { + Array dep_vars = tir::UndefinedVars(slot->expr); + for (auto var : dep_vars) { + auto it = slot_map_.find(var); + ICHECK(it != slot_map_.end()) + << "Var " << var << "is not defined in the function but is referenced by " + << slot->expr; + auto* var_slot = it->second; + // populate the use slot. + var_slot->user_slots.push_back(slot); + } + // set outstanding defs. + slot->outstanding_defs += static_cast(dep_vars.size()); + } + } + } + //------------------------------------------------------- + // Helper functions + //------------------------------------------------------- + StringImm GetErrContext(String err_ctx) const { + return emit_err_ctx_ ? StringImm(err_ctx) : StringImm(""); + } + + VarBinding AllocShapeHeapBinding(IntImm heap_size) { + if (heap_size->value > 0) { + TensorStructInfo heap_sinfo(ShapeDType(), 1); + Var var("shape_heap", heap_sinfo); + // set up the builtin func. + Call call(call_builtin_with_ctx_op_, + {builtin_alloc_shape_heap_, Tuple({PrimValue(heap_size)})}, Attrs(), {heap_sinfo}); + UpdateStructInfo(call, heap_sinfo); + return VarBinding(var, call); + } else { + Var var("shape_heap", ObjectStructInfo()); + Call call(null_value_op_, {}); + UpdateStructInfo(call, ObjectStructInfo()); + return VarBinding(var, call); + } + } + + //------------------------------------------------------- + // Expr mutation overloading. + //------------------------------------------------------- + Expr VisitExpr_(const FunctionNode* op) final { + LOG(FATAL) << "VMShapeLower do not work for local functions, make sure " + << " to run it after LambdaLift"; + return GetRef(op); + } + + Expr VisitExpr_(const ShapeExprNode* op) final { + using runtime::relax_vm::MakeShapeCode; + // Constant shape can be preserved. + bool is_const_shape = std::all_of(op->values.begin(), op->values.end(), [](const PrimExpr& e) { + return e->IsInstance(); + }); + if (is_const_shape) { + return GetRef(op); + } + + Array args = {shape_heap_, PrimValue::Int64(static_cast(op->values.size()))}; + for (PrimExpr expr : op->values) { + if (auto* int_expr = expr.as()) { + args.push_back(PrimValue::Int64(static_cast(MakeShapeCode::kUseImm))); + args.push_back(PrimValue::Int64(int_expr->value)); + } else { + auto it = slot_map_.find(expr); + ICHECK(it != slot_map_.end()); + auto* slot = it->second; + ICHECK(slot->value_computed) << "PrimExpr " << expr << " has not been computed"; + args.push_back(PrimValue::Int64(static_cast(MakeShapeCode::kLoadShape))); + args.push_back(PrimValue::Int64(slot->index)); + } + } + + // make_shape(heap, n, c[0], r[0], c[1], r[1] ..., c[n], r[n]) + Call call(builtin_make_shape_, args, Attrs(), + {ShapeStructInfo(static_cast(op->values.size()))}); + return call; + } + + void VisitBinding_(const MatchCastNode* binding) final { + Expr value = ExprMutator::VisitExpr(binding->value); + std::vector match_todos; + std::ostringstream err_ctx; + err_ctx << "ErrorContext(match_cast, struct_info=" << binding->struct_info << ") "; + // always_check=false + this->CheckMatchCast(binding->struct_info, value, false, err_ctx.str(), &match_todos); + + match_todos = this->RunMatch(match_todos, false); + this->EmitOutstandingPrimExprCompute(); + this->RunMatch(match_todos, true); + + // These checks are emitted as extra, in codegen + // match-cast is simply ignored and treated as a normal binding. + builder_->EmitNormalized(GetRef(binding)); + } + + // Do not override shape in struct info fields + // We only override the shape that are already part of the normal function values + // If future passes lift those values out into the values, + // then codegen may not be able to handle symbolic values. + // Place this pass as last pass before codegen. + StructInfo VisitExprDepStructInfoField(const StructInfo& sinfo) final { return sinfo; } + + //------------------------------------------------------- + // Shape computations. + //------------------------------------------------------- + /*! + * \brief Execute the match todo items. + * + * This function can populate vars in the match items when seeing it for the first time. + * These new vars will be added to this->ready_vars_. + * + * If an item contains PrimExpr that are yet to be computed (but may be computable through + * vars defined in this round), it will be returned to the caller. + * + * The caller should call EmitOutstandingPrimExprCompute, then call RunMatch again. + * + * \param match_todos The list of match items to be executed. + * \param require_value_computed Whether we require all expr to be computed. + * \return List of outstanding items that contains value that are yet to be computed. + */ + std::vector RunMatch(const std::vector& match_todos, + bool require_value_computed) { + std::vector outstanding_todos; + + using runtime::relax_vm::MatchShapeCode; + for (const MatchShapeTodoItem& item : match_todos) { + int64_t shape_len = static_cast(item.pattern.size()); + bool all_nop = true; + int num_outstanding_exprs = 0; + + Array args = {item.input, shape_heap_, PrimValue::Int64(shape_len)}; + + for (PrimExpr expr : item.pattern) { + MatchShapeCode code = MatchShapeCode::kNoOp; + int64_t rvalue = 0; + if (auto* int_expr = expr.as()) { + code = MatchShapeCode::kAssertEqualToImm; + rvalue = int_expr->value; + } else { + auto it = slot_map_.find(expr); + ICHECK(it != slot_map_.end()); + auto* slot = it->second; + if (slot->value_computed) { + code = MatchShapeCode::kAssertEqualToLoad; + rvalue = slot->index; + } else { + // the value is not yet computed + ICHECK(!require_value_computed) << "PrimExpr " << expr << " is not computed"; + if (expr.as()) { + // if it is a var, we will populate it in this round. + // otherwise, we skip and mark it as outstanding + code = MatchShapeCode::kStoreToHeap; + rvalue = slot->index; + slot->value_computed = true; + ready_vars_.push_back(slot); + } else { + code = MatchShapeCode::kNoOp; + rvalue = 0; + ++num_outstanding_exprs; + } + } + } + all_nop = all_nop && code == MatchShapeCode::kNoOp; + args.push_back(PrimValue::Int64(static_cast(code))); + args.push_back(PrimValue::Int64(rvalue)); + } + if (num_outstanding_exprs != 0) { + outstanding_todos.push_back(item); + } + args.push_back(GetErrContext(item.err_ctx)); + if (!all_nop) { + Call call(builtin_match_shape_, args, Attrs(), {void_sinfo_}); + builder_->Emit(call, "_"); + } + } + return std::move(outstanding_todos); + } + + /*! + * \brief Compute a list of prim expr that now be computed + * for given ready vars. + */ + std::vector GetReadyPrimExprSlots() { + std::vector to_compute; + for (PrimExprSlot* slot : ready_vars_) { + for (PrimExprSlot* user : slot->user_slots) { + ICHECK_GT(user->outstanding_defs, 0); + user->outstanding_defs -= 1; + if (user->outstanding_defs == 0) { + to_compute.push_back(user); + } + } + } + ready_vars_.clear(); + return to_compute; + } + + /*! + * \brief Check the dependent expressions of ready_vars_, + * + * If there are outstanding PrimExpr that can now be computed + * we generate a PrimFunc that compute the extra shape values + * + * We will then clear the ready_vars. + * + * \return Number of PrimExpr computed. + */ + size_t EmitOutstandingPrimExprCompute() { + std::vector to_compute = GetReadyPrimExprSlots(); + if (to_compute.size() == 0) return 0; + ICHECK_GT(heap_size_->value, 0); + // construct a PrimFunc that compute the shape. + tir::Var heap("heap", DataType::Handle()); + Array buffer_shape{heap_size_}; + tir::Buffer buffer = tir::decl_buffer(buffer_shape, ShapeDType(), "H", "global"); + Map buffer_map; + buffer_map.Set(heap, buffer); + + auto var_map = [&](const tir::Var& var) -> Optional { + auto it = slot_map_.find(var); + ICHECK(it != slot_map_.end()); + return tir::BufferLoad(buffer, {IntImm(ShapeDType(), it->second->index)}); + }; + + Array seq; + for (PrimExprSlot* slot : to_compute) { + ICHECK(!slot->value_computed); + slot->value_computed = true; + PrimExpr value = tir::Substitute(slot->expr, var_map); + seq.push_back(tir::BufferStore(buffer, value, {IntImm(ShapeDType(), slot->index)})); + } + + tir::Stmt body = tir::SeqStmt::Flatten(seq); + Array params{heap}; + Type ret_type = VoidType(); + + // TODO(relax-team): Consider attach the target attribute to + // the shape_func to indicate that this is a host function + // This could require us to attach target to the relax function here. + tir::PrimFunc shape_func(params, body, ret_type, buffer_map); + GlobalVar shape_func_var = builder_->AddFunction(shape_func, "shape_func"); + builder_->Emit(Call(shape_func_var, {shape_heap_}), "_"); + return to_compute.size(); + } + //------------------------------------------------------- + // StructInfo value match logic + // + // CheckMatchCast is the only function needed by + // other code sections + //------------------------------------------------------- + /*! + * \brief Insert runtime check of the match cast condition(value, struct_info). + * + * \param struct_info The struct info to be matched. + * \param value The input value. + * \param always_check Whether we insert runtime check even if we can prove + * that value's struct info already satisfies the condition. + * This option is necessary for argument checking per our calling convention. + * + * \param err_ctx Extra error context to bring more informative error reporting. + * \param match_todos List of match shape todo items collected when recursively + * visit the match cast. + */ + void CheckMatchCast(const StructInfo& struct_info, Expr value, bool always_check, + const String& err_ctx, std::vector* match_todos) { + return this->VisitStructInfo(struct_info, value, always_check, err_ctx, match_todos); + } + + void VisitStructInfo(const StructInfo& struct_info, Expr value, bool always_check, + const String& err_ctx, std::vector* match_todos) final { + // short-cut, if the struct info already satisfies the + // constraint during match cast, we can skip matching + if (!always_check && IsBaseOf(struct_info, GetStructInfo(value))) return; + return StructInfoFunctor::VisitStructInfo(struct_info, value, always_check, err_ctx, + match_todos); + } + + void VisitStructInfo_(const ObjectStructInfoNode* op, Expr value, bool always_check, + const String& err_ctx, std::vector* match_todos) final { + } + + void VisitStructInfo_(const PrimStructInfoNode* op, Expr value, bool always_check, + const String& err_ctx, std::vector* match_todos) final { + // TODO(relax-team) add PrimValue checks later. + LOG(FATAL) << "MatchCast of PrimValue is not yet supported"; + } + + void VisitStructInfo_(const ShapeStructInfoNode* op, Expr value, bool always_check, + const String& err_ctx, std::vector* match_todos) final { + // emit runtime check of shape + if (always_check || !IsBaseOf(ShapeStructInfo(op->ndim), GetStructInfo(value))) { + // check_shape_info(value, ndim, err_ctx) + Call call(builtin_check_shape_info_, + {value, PrimValue::Int64(op->ndim), GetErrContext(err_ctx)}, Attrs(), + {void_sinfo_}); + builder_->Emit(call, "_"); + } + if (op->values.defined()) { + MatchShapeTodoItem item; + item.input = value; + item.pattern = op->values.value(); + item.err_ctx = err_ctx; + match_todos->push_back(item); + } + } + + void VisitStructInfo_(const TensorStructInfoNode* op, Expr value, bool always_check, + const String& err_ctx, std::vector* match_todos) final { + // emit runtime check of shape + if (always_check || !IsBaseOf(TensorStructInfo(op->dtype, op->ndim), GetStructInfo(value))) { + // check_tensor_info(value, ndim, dtype, err_ctx) + Call call(builtin_check_tensor_info_, + {value, PrimValue::Int64(op->ndim), DataTypeImm(op->dtype), GetErrContext(err_ctx)}, + Attrs(), {void_sinfo_}); + builder_->Emit(call, "_"); + } + + if (auto* shape_expr = op->shape.as()) { + MatchShapeTodoItem item; + item.input = value; + item.pattern = shape_expr->values; + item.err_ctx = err_ctx; + match_todos->push_back(item); + } else if (op->shape.as()) { + // NOTE: This part of the logic is left empty for future support as it is less common. + // Future implementors: we can emit a binding here and assert here. + LOG(FATAL) << "Cannot handle Tensor shape pattern where a var appears multiple times"; + } else { + ICHECK(!op->shape.defined()) << "Can only handle tensor shape pattern var"; + } + } + + // Internal helper function to make tuple get item. + // This function will try to simplify constant tuples + // the return value **always** have struct info. + Expr MakeTupleGetItem(Expr value, int64_t index) { + if (auto* tuple_expr = value.as()) { + return tuple_expr->fields[index]; + } else if (auto* tuple_sinfo = GetStructInfoAs(value)) { + // value is tuple type, it is OK to run tuple get item. + auto ret = TupleGetItem(value, index); + UpdateStructInfo(ret, tuple_sinfo->fields[index]); + return ret; + } else { + // call runtime tuple get item, and return a object. + Call call(builtin_tuple_getitem_, {value, PrimValue::Int64(index)}, Attrs(), {object_sinfo_}); + UpdateStructInfo(call, ObjectStructInfo()); + return call; + } + } + + void VisitStructInfo_(const TupleStructInfoNode* op, Expr value, bool always_check, + const String& err_ctx, std::vector* match_todos) final { + auto* value_tinfo = GetStructInfoAs(value); + if (value_tinfo) { + CHECK_EQ(value_tinfo->fields.size(), op->fields.size()) + << "TypeError: " << err_ctx << " during match-cast we find tuple size mismatch"; + } + if (always_check || !value_tinfo) { + // check_tuple_info(value, tuple_size) + Call call(builtin_check_tuple_info_, + {value, PrimValue::Int64(static_cast(op->fields.size())), + GetErrContext(err_ctx)}, + Attrs(), {void_sinfo_}); + builder_->Emit(call, "_"); + } + // recursively visit each sub-field and run matching + for (size_t i = 0; i < op->fields.size(); ++i) { + this->VisitStructInfo(op->fields[i], MakeTupleGetItem(value, i), always_check, err_ctx, + match_todos); + } + } + + void VisitStructInfo_(const FuncStructInfoNode* op, Expr value, bool always_check, + const String& err_ctx, std::vector* match_todos) final { + // we only check function is callable. + if (!always_check && MatchStructInfo(value)) return; + // check_func_info(value, err_ctx) + Call call(builtin_check_func_info_, {value, GetErrContext(err_ctx)}, Attrs(), {void_sinfo_}); + builder_->Emit(call, "_"); + } + + //------------------------------------------------------- + // Private member fields. + //------------------------------------------------------- + /*! \brief whether to emit error context, can be turned off for testing purposes. */ + bool emit_err_ctx_{true}; + /*! \brief heap ptr to store the PrimExpr slots. */ + Var shape_heap_; + /*! \brief heap size. */ + IntImm heap_size_; + /*! \brief index => slot. */ + std::vector> slot_vec_; + /*! \brief Expr => slot. */ + PrimExprSlotMap slot_map_; + /*! + * \brief List of vars that are being defined but + * have not go through outstanding shape compute check. + */ + std::vector ready_vars_; + // call builtin cop + const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); + const Op& null_value_op_ = Op::Get("relax.null_value"); + // common struct info + const StructInfo object_sinfo_ = ObjectStructInfo(); + const StructInfo void_sinfo_ = TupleStructInfo(Array({})); + // check function + const ExternFunc builtin_alloc_shape_heap_{"vm.builtin.alloc_shape_heap"}; + const ExternFunc builtin_match_shape_{"vm.builtin.match_shape"}; + const ExternFunc builtin_make_shape_{"vm.builtin.make_shape"}; + const ExternFunc builtin_check_shape_info_{"vm.builtin.check_shape_info"}; + const ExternFunc builtin_check_tensor_info_{"vm.builtin.check_tensor_info"}; + const ExternFunc builtin_check_tuple_info_{"vm.builtin.check_tuple_info"}; + const ExternFunc builtin_check_func_info_{"vm.builtin.check_func_info"}; + const ExternFunc builtin_tuple_getitem_{"vm.builtin.tuple_getitem"}; +}; + +namespace transform { + +Pass VMShapeLower(bool emit_err_ctx) { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return VMShapeLowerMutator::Lower(mod, emit_err_ctx); }; + return CreateModulePass(pass_func, 0, "VMShapeLower", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.VMShapeLower").set_body_typed([](bool emit_err_ctx) { + return VMShapeLower(emit_err_ctx); +}); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc new file mode 100644 index 000000000000..1b077d8b887a --- /dev/null +++ b/src/relax/ir/transform.cc @@ -0,0 +1,413 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/ir/transform.cc + * \brief Relax specific transformation passes. + */ +#include +#include +#include +#include +#include +#include +#include +#include +namespace tvm { +namespace relax { +namespace transform { + +TVM_REGISTER_PASS_CONFIG_OPTION("relax.fallback_device_type", IntImm); + +// TODO(@yuchen): will need to dedup with FunctionPass in Relay when we upstream +class FunctionPass; + +/*! + * \brief Function-level passes are used to implement various global + * optimizations for a given Relax IRModule. It fetches one function at a time + * from the function list in the IRModule for optimization. + * + * Note that the scope of passes at this level is a Relax function. Therefore, + * we cannot add or delete a function through these passes as they are not aware + * of the global information. + */ +class FunctionPassNode : public tvm::transform::PassNode { + public: + /* \brief The pass meta data.*/ + PassInfo pass_info; + + /*! \brief The packed pass function sketches the real optimization. For + * instance, we can implement a pass that works on a Relax function as a + * `pass_func` and let it run on a given IRModule. The same `pass_func` will + * then be applied on each function in the IRModule. + */ + runtime::TypedPackedFunc pass_func; + + FunctionPassNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } + + /*! + * \brief Run a function pass on given pass context. + * + * \param mod The IRModule that an optimization pass is applied on. + * \param pass_ctx The context that an optimization pass executes on. + * + * \return Return the updated IRModule. + */ + IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; + + /*! + * \brief Get the pass information/meta data. + */ + PassInfo Info() const override { return pass_info; } + + static constexpr const char* _type_key = "relax.FunctionPass"; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode); + + private: + /* + * \brief Check if a function should be skipped for optimization. + * + * \param func The target function to be checked. + * + * \return Return true if the function will be skipped, otherwise false. + */ + bool SkipFunction(const Function& func) const; +}; + +class FunctionPass : public Pass { + public: + /*! + * \brief The constructor + * \param pass_func The packed function which implements a pass. + * \param pass_info The pass info. + */ + TVM_DLL FunctionPass( + runtime::TypedPackedFunc pass_func, + PassInfo pass_info); + + TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode); +}; + +FunctionPass::FunctionPass( + runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { + auto n = make_object(); + n->pass_func = std::move(pass_func); + n->pass_info = std::move(pass_info); + data_ = std::move(n); +} + +// Perform IRModule -> IRModule optimizations at the Function level. +IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { + DiagnosticContext previous = DiagnosticContext::Default(mod); + + if (pass_ctx->diag_ctx) { + DiagnosticContext tmp = pass_ctx->diag_ctx.value(); + pass_ctx->diag_ctx = previous; + previous = tmp; + } else { + pass_ctx->diag_ctx = previous; + } + + ICHECK(pass_ctx->diag_ctx) + << "The diagnostic context was set at the top of this block this is a bug."; + + const PassInfo& pass_info = Info(); + + ICHECK(mod.defined()); + + VLOG_CONTEXT << pass_info->name; + VLOG(0) << "Executing function pass with opt level: " << pass_info->opt_level; + VLOG(1) << "Input module:" << std::endl << mod; + + IRModule updated_mod = mod->ShallowCopy(); + + std::vector > updates; + for (const auto& it : updated_mod->functions) { + // only picks up relax::Function + if (auto* n = it.second.as()) { + Function func = GetRef(n); + auto updated_func = SkipFunction(func) ? func : pass_func(func, updated_mod, pass_ctx); + updates.push_back({it.first, updated_func}); + } + } + + for (const auto& pair : updates) { + updated_mod->Add(pair.first, pair.second, true); + } + + ICHECK(pass_ctx->diag_ctx) + << "The diagnostic context was set at the top of this block, this is a bug."; + + pass_ctx->diag_ctx.value().Render(); + pass_ctx->diag_ctx = previous; + + VLOG(1) << "Output module:" << std::endl << updated_mod; + + return updated_mod; +} + +bool FunctionPassNode::SkipFunction(const Function& func) const { + // TODO(@yuchen): will need to revisit in the future + return (func->GetAttr(relay::attr::kCompiler).defined()) || + func->GetAttr(relay::attr::kSkipOptimization, 0) != 0; +} + +Pass CreateFunctionPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, String name, tvm::Array required) { + PassInfo pass_info = PassInfo(opt_level, name, required); + return FunctionPass(pass_func, pass_info); +} + +TVM_REGISTER_NODE_TYPE(FunctionPassNode); + +TVM_REGISTER_GLOBAL("relax.transform.MakeFunctionPass") + .set_body_typed( + [](runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { return FunctionPass(pass_func, pass_info); }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PassInfo info = node->Info(); + p->stream << "Run Function pass: " << info->name << " at the optimization level " + << info->opt_level; + }); + +class DataflowBlockPass; + +/*! + * \brief DataflowBlock-level passes are used to implement various dataflow block + * optimizations for a given Relax IRModule. It fetches one dataflow block at a time + * from the functions in an IRModule, and yields a rewritten DataflowBlock. + * + * Note that the scope of passes at this level is a Relax DataflowBlock. Therefore, + * we cannot modify the global scope Vars and symbolic shape Vars defined inside the dataflow block. + */ +class DataflowBlockPassNode : public tvm::transform::PassNode { + public: + /* \brief The pass meta data.*/ + PassInfo pass_info; + + /*! \brief The packed pass function sketches the real optimization. For + * instance, we can implement a pass that works on a Relax DataflowBlock as a + * `pass_func` and let it run on a given IRModule. The same `pass_func` will + * then be applied on each DataflowBlock in the IRModule. + */ + runtime::TypedPackedFunc pass_func; + + DataflowBlockPassNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } + + IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; + + PassInfo Info() const override { return pass_info; } + + static constexpr const char* _type_key = "relax.DataflowBlockPass"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockPassNode, PassNode); +}; + +/*! \brief Helper to apply the passed function to dataflow blocks.*/ +class DataflowBlockMutator : public ExprMutator { + public: + DataflowBlockMutator( + runtime::TypedPackedFunc pass_func, + IRModule mod, PassContext pass_ctx) + : pass_func_(pass_func), mod_(mod), pass_ctx_(pass_ctx) {} + + /*! + * \brief Rewrite the DataflowBlockNode with pass_func_ + * + * This function will check that there are no rewrites of the global scope Vars + * and symbolic shape Vars defined inside the dataflow block. + */ + BindingBlock VisitBindingBlock_(const DataflowBlockNode* n) final { + // collect Global Scope Vars and Symbolic Vars inside the DataflowBlock + Map global_scope_vars; + Map symbolic_vars; + for (const Binding& binding : n->bindings) { + Var var = binding->var; + if (const auto* match_cast = binding.as()) { + auto collected_vars = SymbolicVarCollector::Collect(match_cast->struct_info); + for (const tir::VarNode* var : collected_vars) { + symbolic_vars.Set(var->name_hint, GetRef(var)); + } + } + if (!var.as()) { + global_scope_vars.Set(var->name_hint(), var); + } + } + + // apply pass_func_ to the DataflowBlock + DataflowBlock block = GetRef(n); + DataflowBlock updated_block = pass_func_(block, mod_, pass_ctx_); + + // raise error if there are updates of recorded Global Scope Vars and Symbolic Vars + for (const Binding& binding : updated_block->bindings) { + Var var = binding->var; + if (const auto* match_cast = binding.as()) { + auto collected_vars = SymbolicVarCollector::Collect(match_cast->struct_info); + for (const tir::VarNode* var : collected_vars) { + if (symbolic_vars.count(var->name_hint) > 0) { + tir::Var old_var = symbolic_vars[var->name_hint]; + ICHECK(var == old_var.get()) + << "Error: DataflowBlock Pass should not rewrite any Symbolic Var."; + symbolic_vars.erase(var->name_hint); + } + } + } + if (!var.as() && global_scope_vars.count(var->name_hint()) > 0) { + ICHECK(var.same_as(global_scope_vars[var->name_hint()])) + << "Error: DataflowBlock Pass should not rewrite any GlobalScope Var."; + global_scope_vars.erase(var->name_hint()); + } + } + ICHECK(global_scope_vars.empty() && symbolic_vars.empty()) + << "Error: DataflowBlock Pass should not delete any GlobalScope/Symbolic Var."; + + return std::move(updated_block); + } + + private: + class SymbolicVarCollector : public StructInfoVisitor { + public: + static std::unordered_set Collect(const StructInfo& info) { + SymbolicVarCollector collector; + collector.VisitStructInfo(info); + return std::move(collector.symbolic_vars_); + } + + private: + void VisitStructInfoExprField(const PrimExpr& expr) final { + if (const tir::VarNode* sym_var = expr.as()) { + symbolic_vars_.insert(sym_var); + } + } + + private: + std::unordered_set symbolic_vars_; + }; + + runtime::TypedPackedFunc pass_func_; + IRModule mod_; + PassContext pass_ctx_; +}; + +class DataflowBlockPass : public Pass { + public: + /*! + * \brief The constructor + * \param pass_func The packed function which implements a pass. + * \param pass_info The pass info. + */ + TVM_DLL DataflowBlockPass( + runtime::TypedPackedFunc pass_func, + PassInfo pass_info); + + TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlockPass, Pass, DataflowBlockPassNode); +}; + +DataflowBlockPass::DataflowBlockPass( + runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { + auto n = make_object(); + n->pass_func = std::move(pass_func); + n->pass_info = std::move(pass_info); + data_ = std::move(n); +} + +// Perform IRModule -> IRModule transformations at the DataflowBlock level. +IRModule DataflowBlockPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { + DiagnosticContext previous = DiagnosticContext::Default(mod); + + if (pass_ctx->diag_ctx) { + DiagnosticContext tmp = pass_ctx->diag_ctx.value(); + pass_ctx->diag_ctx = previous; + previous = tmp; + } else { + pass_ctx->diag_ctx = previous; + } + + ICHECK(pass_ctx->diag_ctx) + << "The diagnostic context was set at the top of this block, this is a bug."; + + const PassInfo& pass_info = Info(); + + ICHECK(mod.defined()); + + VLOG_CONTEXT << pass_info->name; + VLOG(0) << "Executing DataflowBlock pass with opt level: " << pass_info->opt_level; + VLOG(1) << "Input module:" << std::endl << mod; + + IRModule updated_mod = mod->ShallowCopy(); + + DataflowBlockMutator dataflow_block_mutator(pass_func, updated_mod, pass_ctx); + std::vector > updates; + for (const auto& it : updated_mod->functions) { + // only picks up relax::Function + if (auto* n = it.second.as()) { + Function func = GetRef(n); + Function updated_func = Downcast(dataflow_block_mutator.VisitExpr(func)); + updates.push_back({it.first, updated_func}); + } + } + + for (const auto& pair : updates) { + updated_mod->Add(pair.first, pair.second, true); + } + + ICHECK(pass_ctx->diag_ctx) + << "The diagnostic context was set at the top of this block this is a bug."; + + pass_ctx->diag_ctx.value().Render(); + pass_ctx->diag_ctx = previous; + + VLOG(1) << "Output module:" << std::endl << updated_mod; + + return updated_mod; +} + +Pass CreateDataflowBlockPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, String name, tvm::Array required) { + PassInfo pass_info = PassInfo(opt_level, name, required); + return DataflowBlockPass(pass_func, pass_info); +} + +TVM_REGISTER_NODE_TYPE(DataflowBlockPassNode); + +TVM_REGISTER_GLOBAL("relax.transform.MakeDataflowBlockPass") + .set_body_typed( + [](runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { return DataflowBlockPass(pass_func, pass_info); }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PassInfo info = node->Info(); + p->stream << "Run DataflowBlock pass: " << info->name << " at the optimization level " + << info->opt_level; + }); +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py new file mode 100644 index 000000000000..0bf0f175dd7e --- /dev/null +++ b/tests/python/relax/test_backend_transform_shape_lower.py @@ -0,0 +1,429 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm.script +import tvm.testing +from tvm import relax +from tvm.ir import assert_structural_equal +from tvm.relax.testing.runtime_builtin import MakeShapeCode, MatchShapeCode +from tvm.script import relax as R +from tvm.script import tir as T + + +def test_const_shape_arg(): + MS = MatchShapeCode + + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Shape([1, 2]), y: R.Shape): + return x + + @T.prim_func + def extra_func(H: T.Buffer(T.int64(4), "int64")): + """Extra function, checks if the pass preserves it.""" + H[T.int64(1)] = H[T.int64(0)] + T.int64(1) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Shape([1, 2]), y: R.Shape): + shape_heap = R.null_value() + _ = R.call_packed("vm.builtin.check_shape_info", x, 2, "", sinfo_args=[R.Tuple()]) + _ = R.call_packed("vm.builtin.check_shape_info", y, -1, "", sinfo_args=[R.Tuple()]) + _ = R.call_packed( + "vm.builtin.match_shape", + x, + shape_heap, + 2, + MS.ASSERT_EQUAL_TO_IMM, + 1, + MS.ASSERT_EQUAL_TO_IMM, + 2, + "", + sinfo_args=[R.Tuple()], + ) + return x + + @T.prim_func + def extra_func(H: T.Buffer(T.int64(4), "int64")): + H[T.int64(1)] = H[T.int64(0)] + T.int64(1) + + before = Before + expected = Expected + after = relax.transform.VMShapeLower(emit_err_ctx=False)(before) + assert_structural_equal(after, expected) + + +def test_static_fn_check(): + """Check static shape and function.""" + MS = MatchShapeCode + + @tvm.script.ir_module + class Before: + @R.function + def main(f: R.Callable([R.Object], R.Object), y: R.Shape([1, 2])): + return y + + @tvm.script.ir_module + class Expected: + @R.function + def main(f: R.Callable([R.Object], R.Object), y: R.Shape([1, 2])): + shape_heap = R.null_value() + _ = R.call_packed("vm.builtin.check_func_info", f, "", sinfo_args=[R.Tuple()]) + _ = R.call_packed("vm.builtin.check_shape_info", y, 2, "", sinfo_args=[R.Tuple()]) + _ = R.call_packed( + "vm.builtin.match_shape", + y, + shape_heap, + 2, + MS.ASSERT_EQUAL_TO_IMM, + 1, + MS.ASSERT_EQUAL_TO_IMM, + 2, + "", + sinfo_args=[R.Tuple()], + ) + return y + + before = Before + expected = Expected + after = relax.transform.VMShapeLower(emit_err_ctx=False)(before) + assert_structural_equal(after, expected) + + +def test_simple_symbolic_shape(): + MS = MatchShapeCode + + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor(["n", 2, "m"], "float32")): + return x + + sindex = { + "n": 0, + "m": 1, + } + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(["n", 2, "m"], "float32")): + shape_heap = R.call_builtin_with_ctx( + "vm.builtin.alloc_shape_heap", + [R.prim_value(2)], + sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ) + _ = R.call_packed( + "vm.builtin.check_tensor_info", x, 3, R.dtype("float32"), "", sinfo_args=[R.Tuple()] + ) + _ = R.call_packed( + "vm.builtin.match_shape", + x, + shape_heap, + 3, + MS.STORE_TO_HEAP, + sindex["n"], + MS.ASSERT_EQUAL_TO_IMM, + 2, + MS.STORE_TO_HEAP, + sindex["m"], + "", + sinfo_args=[R.Tuple()], + ) + return x + + before = Before + expected = Expected + after = relax.transform.VMShapeLower(emit_err_ctx=False)(before) + assert_structural_equal(after, expected) + + +def test_symbolic_compute(): + MS = MatchShapeCode + MK = MakeShapeCode + + @tvm.script.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None) + ) -> R.Shape(ndim=3): + n = T.Var("n", "int64") + k = T.Var("k", "int64") + z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None)) + return (k + 1, m, 2) + + # slot assignment: + # 0: n, 1: m, 2:k, 3: k+1 + sindex = {"n": 0, "m": 1, "k": 2, "k+1": 3} + + @tvm.script.ir_module + class Expected: + @T.prim_func + def shape_func(H: T.Buffer(T.int64(4), "int64")): + # generated compute function + H[T.int64(sindex["k+1"])] = H[T.int64(sindex["k"])] + T.int64(1) + + @R.function + def main( + x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None) + ) -> R.Shape(ndim=3): + n = T.Var("n", "int64") + k = T.Var("k", "int64") + shape_heap = R.call_builtin_with_ctx( + "vm.builtin.alloc_shape_heap", + [R.prim_value(4)], + sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ) + _ = R.call_packed( + "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", sinfo_args=[R.Tuple()] + ) + _ = R.call_packed( + "vm.builtin.check_tensor_info", y, 3, R.dtype(""), "", sinfo_args=[R.Tuple()] + ) + _ = R.call_packed( + "vm.builtin.match_shape", + x, + shape_heap, + 2, + MS.STORE_TO_HEAP, + sindex["n"], + MS.STORE_TO_HEAP, + sindex["m"], + "", + sinfo_args=[R.Tuple()], + ) + _ = R.call_packed( + "vm.builtin.match_shape", + y, + shape_heap, + 3, + MS.STORE_TO_HEAP, + sindex["k"], + MS.ASSERT_EQUAL_TO_LOAD, + sindex["m"], + MS.NO_OP, + 0, + "", + sinfo_args=[R.Tuple()], + ) + _ = shape_func(shape_heap) + # extra assertion on y's shape after shape computation + _ = R.call_packed( + "vm.builtin.match_shape", + y, + shape_heap, + 3, + MS.ASSERT_EQUAL_TO_LOAD, + sindex["k"], + MS.ASSERT_EQUAL_TO_LOAD, + sindex["m"], + MS.ASSERT_EQUAL_TO_LOAD, + sindex["k+1"], + "", + sinfo_args=[R.Tuple()], + ) + z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None)) + # construct shape value for return + s = R.call_packed( + "vm.builtin.make_shape", + shape_heap, + 3, + MK.LOAD_SHAPE, + sindex["k+1"], + MK.LOAD_SHAPE, + sindex["m"], + MK.USE_IMM, + 2, + sinfo_args=[R.Shape(ndim=3)], + ) + return s + + before = Before + expected = Expected + after = relax.transform.VMShapeLower(emit_err_ctx=False)(before) + assert_structural_equal(after, expected) + + +def test_tuple_handling(): + MS = MatchShapeCode + + @tvm.script.ir_module + class Before: + @R.function + def main( + x: R.Tuple( + R.Tensor(["n", "m"], "float32"), R.Tuple(R.Shape, R.Tensor(["n", "k"], "int32")) + ) + ): + return x + + # slot assignment: + sindex = {"n": 0, "m": 1, "k": 2} + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tuple( + R.Tensor(["n", "m"], "float32"), R.Tuple(R.Shape, R.Tensor(["n", "k"], "int32")) + ) + ): + shape_heap = R.call_builtin_with_ctx( + "vm.builtin.alloc_shape_heap", + [R.prim_value(3)], + sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ) + # recursively unpack tuple for static info check + _ = R.call_packed("vm.builtin.check_tuple_info", x, 2, "", sinfo_args=[R.Tuple()]) + t0 = x[0] + _ = R.call_packed( + "vm.builtin.check_tensor_info", + t0, + 2, + R.dtype("float32"), + "", + sinfo_args=[R.Tuple()], + ) + t1 = x[1] + _ = R.call_packed("vm.builtin.check_tuple_info", t1, 2, "", sinfo_args=[R.Tuple()]) + t1x0 = t1[0] + _ = R.call_packed("vm.builtin.check_shape_info", t1x0, -1, "", sinfo_args=[R.Tuple()]) + t1x1 = t1[1] + _ = R.call_packed( + "vm.builtin.check_tensor_info", + t1x1, + 2, + R.dtype("int32"), + "", + sinfo_args=[R.Tuple()], + ) + # match shape checks. + _ = R.call_packed( + "vm.builtin.match_shape", + t0, + shape_heap, + 2, + MS.STORE_TO_HEAP, + sindex["n"], + MS.STORE_TO_HEAP, + sindex["m"], + "", + sinfo_args=[R.Tuple()], + ) + _ = R.call_packed( + "vm.builtin.match_shape", + t1x1, + shape_heap, + 2, + MS.ASSERT_EQUAL_TO_LOAD, + sindex["n"], + MS.STORE_TO_HEAP, + sindex["k"], + "", + sinfo_args=[R.Tuple()], + ) + return x + + before = Before + expected = Expected + after = relax.transform.VMShapeLower(emit_err_ctx=False)(before) + assert_structural_equal(after, expected) + + +def test_return_match_check(): + """Test when return body is not same as ret_struct_info, runtime match check needed.""" + MS = MatchShapeCode + + @tvm.script.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["n", "m"], "float32"), y: R.Object + ) -> R.Tuple(R.Tensor(["n", "m"], "float32")): + return y + + # slot assignment: + sindex = { + "n": 0, + "m": 1, + } + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(["n", "m"], "float32"), y: R.Object + ) -> R.Tuple(R.Tensor(["n", "m"], "float32")): + shape_heap = R.call_builtin_with_ctx( + "vm.builtin.alloc_shape_heap", + [R.prim_value(2)], + sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ) + _ = R.call_packed( + "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", sinfo_args=[R.Tuple()] + ) + _ = R.call_packed( + "vm.builtin.match_shape", + x, + shape_heap, + 2, + MS.STORE_TO_HEAP, + sindex["n"], + MS.STORE_TO_HEAP, + sindex["m"], + "", + sinfo_args=[R.Tuple()], + ) + _ = R.call_packed("vm.builtin.check_tuple_info", y, 1, "", sinfo_args=[R.Tuple()]) + # emit runtime function call since y do not have the right type. + y1 = R.call_packed("vm.builtin.tuple_getitem", y, 0, sinfo_args=[R.Object]) + # run check + _ = R.call_packed( + "vm.builtin.check_tensor_info", + y1, + 2, + R.dtype("float32"), + "", + sinfo_args=[R.Tuple()], + ) + # shape check + _ = R.call_packed( + "vm.builtin.match_shape", + y1, + shape_heap, + 2, + MS.ASSERT_EQUAL_TO_LOAD, + sindex["n"], + MS.ASSERT_EQUAL_TO_LOAD, + sindex["m"], + "", + sinfo_args=[R.Tuple()], + ) + + return y + + before = Before + expected = Expected + after = relax.transform.VMShapeLower(emit_err_ctx=False)(before) + assert_structural_equal(after, expected) + + +if __name__ == "__main__": + tvm.testing.main() From 9e47ae6808cfaf14059b2a9a1f67f162f20f3332 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Sat, 11 Feb 2023 09:26:36 -0800 Subject: [PATCH 11/81] [Unity] e2e Relax minimum build flow (#13961) This PR introduces the e2e Relax lowering flow (`relax.vm.build`). Tests for each pass in the flow are added. Co-Authored-by: Altan Haan Co-Authored-by: Andrew Liu Co-Authored-by: Hongyi Jin <3231950289@qq.com> Co-Authored-by: Jiawei Liu Co-Authored-by: Junru Shao Co-Authored-by: Prakalp Srivastava Co-Authored-by: Ruihang Lai Co-Authored-by: Siyuan Feng Co-Authored-by: Steven S. Co-Authored-by: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Co-Authored-by: Tianqi Chen Co-Authored-by: Yong Wu Co-Authored-by: Ziheng Jiang --- CMakeLists.txt | 1 + include/tvm/relax/analysis.h | 16 + include/tvm/relax/backend.h | 7 + include/tvm/relax/transform.h | 35 + python/tvm/relax/analysis/analysis.py | 45 + python/tvm/relax/op/__init__.py | 3 + python/tvm/relax/op/builtin/__init__.py | 20 + python/tvm/relax/op/builtin/_ffi_api.py | 19 + python/tvm/relax/op/builtin/builtin.py | 44 + python/tvm/relax/op/manipulate.py | 62 ++ python/tvm/relax/op/memory/__init__.py | 20 + python/tvm/relax/op/memory/_ffi_api.py | 19 + python/tvm/relax/op/memory/memory.py | 108 +++ python/tvm/relax/testing/__init__.py | 20 + python/tvm/relax/testing/nn.py | 194 ++++ python/tvm/relax/transform/transform.py | 53 + python/tvm/relax/vm.py | 4 +- python/tvm/script/ir_builder/relax/ir.py | 6 + src/relax/analysis/tir_op_pattern_kind.cc | 447 +++++++++ src/relax/backend/vm/vm_builtin_lower.cc | 208 ++++ src/relax/op/op.cc | 81 ++ src/relax/op/tensor/manipulate.cc | 163 ++++ src/relax/op/tensor/manipulate.h | 45 + src/relax/transform/attach_global_symbol.cc | 68 ++ src/relax/transform/call_tir_rewrite.cc | 137 +++ .../transform/rewrite_dataflow_reshape.cc | 110 +++ src/relax/transform/to_non_dataflow.cc | 67 ++ tests/python/relax/test_analysis.py | 172 ++++ tests/python/relax/test_transform.py | 141 +++ .../test_transform_attach_global_symbol.py | 88 ++ ...test_transform_rewrite_dataflow_reshape.py | 166 ++++ tests/python/relax/test_vm_build.py | 908 ++++++++++++++++++ 32 files changed, 3476 insertions(+), 1 deletion(-) create mode 100644 python/tvm/relax/op/builtin/__init__.py create mode 100644 python/tvm/relax/op/builtin/_ffi_api.py create mode 100644 python/tvm/relax/op/builtin/builtin.py create mode 100644 python/tvm/relax/op/manipulate.py create mode 100644 python/tvm/relax/op/memory/__init__.py create mode 100644 python/tvm/relax/op/memory/_ffi_api.py create mode 100644 python/tvm/relax/op/memory/memory.py create mode 100644 python/tvm/relax/testing/__init__.py create mode 100644 python/tvm/relax/testing/nn.py create mode 100644 src/relax/analysis/tir_op_pattern_kind.cc create mode 100644 src/relax/backend/vm/vm_builtin_lower.cc create mode 100644 src/relax/op/tensor/manipulate.cc create mode 100644 src/relax/op/tensor/manipulate.h create mode 100644 src/relax/transform/attach_global_symbol.cc create mode 100644 src/relax/transform/call_tir_rewrite.cc create mode 100644 src/relax/transform/rewrite_dataflow_reshape.cc create mode 100644 src/relax/transform/to_non_dataflow.cc create mode 100644 tests/python/relax/test_analysis.py create mode 100644 tests/python/relax/test_transform.py create mode 100644 tests/python/relax/test_transform_attach_global_symbol.py create mode 100644 tests/python/relax/test_transform_rewrite_dataflow_reshape.py create mode 100644 tests/python/relax/test_vm_build.py diff --git a/CMakeLists.txt b/CMakeLists.txt index eecd67be94c1..d0470677e128 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -292,6 +292,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/relax/ir/*.cc src/relax/op/*.cc src/relax/analysis/*.cc + src/relax/transform/*.cc src/relax/backend/vm/*.cc src/relax/utils.cc ) diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index ad2bd19aa41a..24cfe5b9bf11 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -259,6 +259,22 @@ TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived, */ TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, arith::Analyzer* ana = nullptr); + +/*! + * \brief Check if the given PrimFunc is essentially doing a reshape operation. + * The reshape operation also includes expand_dims, squeeze, flatten, etc. + * \details Here the allowed reshape pattern is: for example, assume the operation is + * `B[l_0, l_1, ..., l_b] = A[r_0, r_1, ..., r_a]`, we check if we can prove that the flattened + * index of l_0, ..., l_b under buffer B equals to the flattened index of r_0, ..., r_a under + * buffer A. + * \param func The function to be examined. + * \return A boolean indicating if the given PrimFunc is doing a reshape. + * \note According to the description above, the returned result can only be false-negative and + * cannot be false-positive, since whenever we cannot prove the equality, we return false. This + * property guarantees the safety of this function. + */ +TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func); + } // namespace relax } // namespace tvm diff --git a/include/tvm/relax/backend.h b/include/tvm/relax/backend.h index 4ebeacac0ff3..2fb11f5a6f83 100644 --- a/include/tvm/relax/backend.h +++ b/include/tvm/relax/backend.h @@ -30,6 +30,13 @@ namespace tvm { namespace relax { namespace transform { +/*! + * \brief Perform builtin lowering to map most of the op to VM builtin functions. + * + * \return The Pass. + */ +TVM_DLL Pass VMBuiltinLower(); + /*! * \brief Lower the shape expression in relax to VM shape heap and TIR functions. * diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index fa288a7f06c2..ff98b16d251e 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -65,6 +65,41 @@ TVM_DLL Pass CreateDataflowBlockPass( const runtime::TypedPackedFunc& pass_func, int opt_level, String name, tvm::Array required); +/*! + * \brief Transform all dataflow structure to non-dataflow version. + * + * \return The Pass. + */ +TVM_DLL Pass ToNonDataflow(); + +/*! + * \brief Perform explicit tensor allocation for call_tir. + * + * \return The Pass. + */ +TVM_DLL Pass CallTIRRewrite(); + +/*! + * \brief Convert all reshape-like call_tir whose corresponding binding + * vars are DataflowVars to relax.reshape operator calls. The relax.reshape + * calls will be lowered an external builtin function call in a subsequent + * pass, where the external builtin function does a CreateView operation + * at runtime, instead of doing real data copy. + * Here "reshape-like" includes reshape, expand_dims, flatten, etc. + * + * \return The Pass. + * \note The pass is applied at the first stage of Relax VM build, before + * rewriting call_tir, as this pass requires dataflow information. + */ +TVM_DLL Pass RewriteDataflowReshape(); + +/*! + * \brief Attach global_symbol to Relax functions and TIR Primfuncs for codegen. + * + * \return The Pass. + */ +TVM_DLL Pass AttachGlobalSymbol(); + } // namespace transform } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index d81c477145ec..27416c3a7919 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -162,3 +162,48 @@ def struct_info_lca(lhs: StructInfo, rhs: StructInfo) -> StructInfo: The corresponding lca result. """ return _ffi_api.StructInfoLCA(lhs, rhs) # type: ignore + + +def post_order_visit(expr, fvisit): + """Recursively visit the ir in post DFS order node, + apply fvisit. Each node is guaranteed to be visited + only once. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + fvisit : function + The visitor function to be applied. + """ + return _ffi_api.post_order_visit(expr, fvisit) # type: ignore + + +def has_reshape_pattern(func: tir.PrimFunc) -> bool: + """Check if the given PrimFunc is essentially doing a reshape operation. + The reshape operation also includes expand_dims, squeeze, flatten, etc. + + Here the allowed reshape pattern is: for example, assume the operation is + `B[l_0, l_1, ..., l_b] = A[r_0, r_1, ..., r_a]`, we check if we can prove + that the flattened index of l_0, ..., l_b under buffer B equals to the + flattened index of r_0, ..., r_a under buffer A. + + Parameters + ---------- + func : tir.PrimFunc + The function to be examined. + + Returns + ------- + ret : bool + A boolean indicating if the given PrimFunc is doing a reshape. + + Notes + ----- + According to the description above, the returned result can only be + false-negative and cannot be false-positive, since whenever we cannot + prove the equality, we return false. This property guarantees the safety + of this function. + """ + return _ffi_api.has_reshape_pattern(func) # type: ignore diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 101b0827d630..9a131cdf957f 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -20,3 +20,6 @@ # Operators from .base import * from .binary import * +from .manipulate import * +from . import builtin +from . import memory diff --git a/python/tvm/relax/op/builtin/__init__.py b/python/tvm/relax/op/builtin/__init__.py new file mode 100644 index 000000000000..04837724b165 --- /dev/null +++ b/python/tvm/relax/op/builtin/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax builtin operators.""" + +from .builtin import * diff --git a/python/tvm/relax/op/builtin/_ffi_api.py b/python/tvm/relax/op/builtin/_ffi_api.py new file mode 100644 index 000000000000..42fe8cb65234 --- /dev/null +++ b/python/tvm/relax/op/builtin/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for tvm.relax.op.builtin""" +import tvm._ffi + +tvm._ffi._init_api("relax.op.builtin", __name__) diff --git a/python/tvm/relax/op/builtin/builtin.py b/python/tvm/relax/op/builtin/builtin.py new file mode 100644 index 000000000000..0afe6a42d09a --- /dev/null +++ b/python/tvm/relax/op/builtin/builtin.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""The builtin Relax operators.""" + +from ...expr import Call, Expr +from ...utils import args_converter +from . import _ffi_api + + +@args_converter.auto +def alloc_tensor(shape: Expr, dtype: str, runtime_device_index: int) -> Call: + """Construct a Call to allocate a tensor with specific shape, dtype, runtime_device_index. + + Parameters + ---------- + shape : Expr + The shape of the tensor to be allocated. + + dtype : str + The datatype of the tensor to be allocated. + + runtime_device_index : int + The device index indicating on which device the tensor is to be allocated at runtime. + Index -1 is reserved for the host device. + + Returns + ------- + result : Call + A relax Call, which gets the allocated tensor. + """ + return _ffi_api.alloc_tensor(shape, dtype, runtime_device_index) # type: ignore diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py new file mode 100644 index 000000000000..fa9c81522596 --- /dev/null +++ b/python/tvm/relax/op/manipulate.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Manipulation operators.""" +from typing import Tuple, Union + +from tvm.ir.expr import PrimExpr + + +from . import _ffi_api +from ..expr import Expr + + +PrimExprLike = Union[int, PrimExpr] + + +def reshape(x: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr: + """Reshape the input array. + + ``-1`` infers the dimension of the output shape by using the remainder of + the input dimensions keeping the size of the new array same as that of the input array. + At most one dimension of shape can be -1. + + .. code-block:: python + + x.shape = (2, 3, 4), shape = (6, 1, -1), result.shape = (6, 1, 4) + x.shape = (2, 3, 4), shape = (3, -1, 8), result.shape = (3, 1, 8) + x.shape = (2, 3, 4), shape = (-1,), result.shape = (24,) + + Parameters + ---------- + x : relax.Expr + The input data to the operator. + + shape : Union[Tuple[PrimExprLike], Expr] + The new shape. Should be compatible with the original shape. + + Returns + ------- + result : relax.Expr + The reshaped result. + + Note + ---- + The ``-1`` inference is only performed at compile-time. + That is to say, in any case the dimension length of ``-1`` cannot be inferred in + compile-time, an error will be thrown. + """ + return _ffi_api.reshape(x, shape) # type: ignore diff --git a/python/tvm/relax/op/memory/__init__.py b/python/tvm/relax/op/memory/__init__.py new file mode 100644 index 000000000000..e039590251fc --- /dev/null +++ b/python/tvm/relax/op/memory/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax memory primitives.""" + +from .memory import * diff --git a/python/tvm/relax/op/memory/_ffi_api.py b/python/tvm/relax/op/memory/_ffi_api.py new file mode 100644 index 000000000000..475de481b22e --- /dev/null +++ b/python/tvm/relax/op/memory/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for tvm.relax.op.memory""" +import tvm._ffi + +tvm._ffi._init_api("relax.op.memory", __name__) diff --git a/python/tvm/relax/op/memory/memory.py b/python/tvm/relax/op/memory/memory.py new file mode 100644 index 000000000000..b58b987d2a3e --- /dev/null +++ b/python/tvm/relax/op/memory/memory.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""Relax memory primitives.""" + +from . import _ffi_api +from ...expr import Expr, Call +from ...utils import args_converter + + +@args_converter.auto +def alloc_storage(size: Expr, virtual_device_index: int, storage_scope: str, dtype: str) -> Call: + """Construct a Call to allocate a storage with specific size, virtual_device_index, + storage_scope and dtype. + + Parameters + ---------- + size : Expr + The size of the storage to be allocated. + + virtual_device_index : int + The virtual device index indicating on which device the storage is to be allocated. + Index -1 is reserved for the host device. + + storage_scope : str + The storage scope to allocate the storage to. + + dtype : str + The datatype of the storage to be allocated. + + Returns + ------- + result : Call + A relax Call, which gets the allocated storage. + """ + return _ffi_api.alloc_storage(size, virtual_device_index, storage_scope, dtype) # type: ignore + + +@args_converter.auto +def alloc_tensor(storage: Expr, offset: int, shape: Expr, dtype: str) -> Call: + """Construct a Call to allocate a tensor on a certain storage starting from the given offset. + + Parameters + ---------- + storage : Expr + The storage to allocate the tensor to. + + offset : int + The storage offset to allocate the tensor. + + shape : Expr + The shape of the tensor to be allocated. + + dtype : str + The datatype of the tensor to be allocated. + + Returns + ------- + result : Call + A relax Call, which gets the allocated tensor. + """ + return _ffi_api.alloc_tensor(storage, offset, shape, dtype) # type: ignore + + +@args_converter.auto +def kill_storage(storage: Expr) -> Call: + """Construct a Call to kill a storage. + + Parameters + ---------- + storage : Expr + The storage to be killed. + + Returns + ------- + result : Call + A relax Call to kill a storage. + """ + return _ffi_api.kill_storage(storage) # type: ignore + + +@args_converter.auto +def kill_tensor(tensor: Expr) -> Call: + """Construct a Call to kill a tensor. + + Parameters + ---------- + tensor : Expr + The tensor to be killed. + + Returns + ------- + result : Call + A relax Call to kill a tensor. + """ + return _ffi_api.kill_tensor(tensor) # type: ignore diff --git a/python/tvm/relax/testing/__init__.py b/python/tvm/relax/testing/__init__.py new file mode 100644 index 000000000000..ab1dd6f5155e --- /dev/null +++ b/python/tvm/relax/testing/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""The Relax testing namespace containing nn and translator.""" + +from .nn import * diff --git a/python/tvm/relax/testing/nn.py b/python/tvm/relax/testing/nn.py new file mode 100644 index 000000000000..830ddd779fe5 --- /dev/null +++ b/python/tvm/relax/testing/nn.py @@ -0,0 +1,194 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin +"""PyTorch-like nn.Module API for constructing workloads.""" + + +from typing import List, Any, Callable, Union +import typing +import numpy as np # type: ignore + +import tvm +from tvm import relax, topi, tir + + +def emit_te(func: Callable, *args: Any, **kwargs: Any) -> relax.Var: + return relax.BlockBuilder.current().emit_te(func, *args, **kwargs) + + +class Placeholder(relax.Var): + """A placeholder variable that can represent model input.""" + + def __init__( + self, shape: Union[List[Any], typing.Tuple[Any, ...]], dtype="float32", name="data" + ): + if not isinstance(shape, (list, tuple)): + raise TypeError("the shape of Placeholder is expected to be a list or a tuple") + super().__init__( + relax.BlockBuilder.current().get_unique_name(name), relax.TensorStructInfo(shape, dtype) + ) + + +class Parameter(relax.Var): + """A special kind of relax Var that represents model parameter(weight).""" + + def __init__( + self, shape: Union[List[Any], typing.Tuple[Any, ...]], dtype="float32", name="param" + ): + if not isinstance(shape, (list, tuple)): + raise TypeError("the shape of Parameter is expected to be a list or a tuple") + super().__init__( + relax.BlockBuilder.current().get_unique_name(name), relax.TensorStructInfo(shape, dtype) + ) + + +class Module: + """Base class for all model modules. + + A neural network or a layer can subclass this class. + + Example + ------- + .. code-block:: python + + # Define a linear layer + class Linear(Module) + def __init__(self, in_features, out_features, bias=True): + self.in_features = in_features + self.out_features = out_features + self.weight = Parameter((in_features, out_features), name="linear_weight") + if bias: + self.bias = Parameter((out_features,), name="linear_bias") + else: + self.bias = None + + # All submodules should implement forward. + # Defines the forward computation performed at every call. + def forward(self, input: relax.Expr) -> relax.Var: + y = emit_te(topi.matmul, input, self.weight) + if self.bias is not None: + y = emit_te(topi.add, y, self.bias) + return y + """ + + def parameters(self) -> List[Parameter]: + """Return the list of parameters in the module.""" + return _unpack_params(self.__dict__) + + def forward(self, input: relax.Expr): + """Define the computation performed at every call.""" + raise NotImplementedError() + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + +def _unpack_params(value: object) -> List[relax.Var]: + if isinstance(value, Parameter): + return [value] + if isinstance(value, Module): + return value.parameters() + if isinstance(value, dict): + params = [] + for v in value.values(): + params += _unpack_params(v) + return params + if isinstance(value, (list, tuple)): + params = [] + for v in value: + params += _unpack_params(v) + return params + if value is None or isinstance(value, (int, float, str)): + return [] + raise TypeError("not supported type when unpacking parameters: {}".format(type(value))) + + +def init_params(mod: tvm.IRModule) -> List[tvm.nd.array]: + """Utility function to initialize model's parameters.""" + shape_dict = {v.name_hint: v.struct_info.shape for v in mod["main"].params} + params = [] + for k, v in shape_dict.items(): + if k.startswith("data"): + continue + if isinstance(v, relax.ShapeExpr): + shape = [] + for i in v: + if isinstance(i, tir.IntImm): + shape.append(int(i)) + else: + raise TypeError("cannot initialize for unknown-shape parameters.") + params.append(tvm.nd.array(np.zeros(shape).astype(np.float32))) + else: + raise TypeError("cannot initialize for unknown-shape parameters.") + return params + + +class Sequential(Module): + """A sequential container that concatenates modules in it. + + Example + ------- + .. code-block:: python + + model = nn.Sequential( + nn.Conv2d(1, 20, 5), + nn.ReLU(), + nn.Conv2d(20, 64, 5), + nn.ReLU() + ) + """ + + def __init__(self, *modules: Module): + self.modules = modules + + def forward(self, input: relax.Expr) -> relax.Var: + for module in self.modules: + input = module(input) + return input + + +class ReLU(Module): + """Applies the rectified linear unit activation function on the input.""" + + def forward(self, input: relax.Expr) -> relax.Var: + return emit_te(topi.nn.relu, input) + + +class LogSoftmax(Module): + """Applies log softmax activation function on the input.""" + + def forward(self, input: relax.Expr) -> relax.Var: + return emit_te(topi.nn.log_softmax, input) + + +class Linear(Module): + """Applies a linear transformation to the input data: :math:`y = xA + b`.""" + + def __init__(self, in_features, out_features, bias=True): + self.in_features = in_features + self.out_features = out_features + self.weight = Parameter((in_features, out_features), name="linear_weight") + if bias: + self.bias = Parameter((out_features,), name="linear_bias") + else: + self.bias = None + + def forward(self, input: relax.Expr) -> relax.Var: + y = emit_te(topi.matmul, input, self.weight) + if self.bias is not None: + y = emit_te(topi.add, y, self.bias) + return y diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index f20f06c52284..cab18797c672 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -37,6 +37,49 @@ class DataflowBlockPass(tvm.ir.transform.Pass): """A pass that works on each tvm.relax.DataflowBlock in a module.""" +def ToNonDataflow() -> tvm.ir.transform.Pass: + """Transform all dataflow structure to non-dataflow version. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.ToNonDataflow() # type: ignore + + +def CallTIRRewrite() -> tvm.ir.transform.Pass: + """Perform explicit tensor allocation for call_tir. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.CallTIRRewrite() # type: ignore + + +def RewriteDataflowReshape() -> tvm.ir.transform.Pass: + """Convert all reshape-like call_tir to VM reshape operator call. + The VM reshape operator calls will be further lowered to a CreateView + operation at runtime, instead of doing real data copy. + Here "reshape-like" includes reshape, expand_dims, flatten, etc. + + Returns + ------- + ret : tvm.ir.transform.Pass + """ + return _ffi_api.RewriteDataflowReshape() # type: ignore + + +def VMBuiltinLower() -> tvm.ir.transform.Pass: + """Lowering generic intrinsic to VM intrinsics. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.VMBuiltinLower() # type: ignore + + def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass: """Lower the symbolic shape and argument and match-cast structinfo matching. @@ -52,6 +95,16 @@ def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass: return _ffi_api.VMShapeLower(emit_err_ctx) # type: ignore +def AttachGlobalSymbol() -> tvm.ir.transform.Pass: + """Attach global_symbol to Relax functions and TIR Primfuncs for codegen. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.AttachGlobalSymbol() # type: ignore + + def _wrap_class_function_pass(pass_cls, pass_info): """Wrap a python class as function pass.""" diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py index ba16dfb07985..ff6bf816b62b 100644 --- a/python/tvm/relax/vm.py +++ b/python/tvm/relax/vm.py @@ -581,7 +581,9 @@ def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")): if isinstance(target, str): target = tvm.target.Target(target) - passes = [relax.transform.ToNonDataflow()] + passes = [] + passes.append(relax.transform.RewriteDataflowReshape()) + passes.append(relax.transform.ToNonDataflow()) passes.append(relax.transform.CallTIRRewrite()) passes.append(relax.transform.VMBuiltinLower()) passes.append(relax.transform.VMShapeLower()) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 647ef8f25af7..0692ec5683c0 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -31,13 +31,16 @@ from tvm.relax.op import ( add, assert_op, + builtin, call_builtin_with_ctx, call_tir, invoke_closure, make_closure, + memory, multiply, null_value, print, + reshape, shape_of, ) from tvm.relax.struct_info import StructInfo @@ -381,6 +384,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "add", "arg", "assert_op", + "builtin", "call_packed", "call_tir", "call_builtin_with_ctx", @@ -396,11 +400,13 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "function", "invoke_closure", "make_closure", + "memory", "multiply", "null_value", "output", "prim_value", "print", + "reshape", "shape_of", "str", "tuple", diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc new file mode 100644 index 000000000000..b7ac8faddd23 --- /dev/null +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -0,0 +1,447 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +using namespace tir; + +class PatternKindAnalyzer : public StmtExprVisitor { + public: + explicit PatternKindAnalyzer(const tir::PrimFunc& func) { + for (const tir::Var& param : func->params) { + Optional param_buf = func->buffer_map.Get(param); + if (param_buf.defined()) { + param_buffers_.insert(param_buf.value()); + } + } + } + + private: + bool IsOutputBlock(const BlockNode* block) { + for (const BufferRegion& write_region : block->writes) { + if (param_buffers_.count(write_region->buffer)) { + return true; + } + } + return false; + } + + void VisitStmt_(const BufferStoreNode* op) final { + // We only support one buffer store in a block (ususally generated by TE compute) + // If we have already seen buffer store in the current block, classify as Opaque. + if (store_.defined()) { + kind_ = relay::kOpaque; + return; + } + store_ = GetRef(op); + StmtVisitor::VisitStmt_(op); + } + + void VisitExpr_(const BufferLoadNode* op) final { + loads_.push_back(GetRef(op)); + ExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const BlockNode* op) final { + if (op->name_hint == "root") { + // Skip the root block + StmtVisitor::VisitStmt(op->body); + return; + } + + // Step 1. Clear loads and store + loads_.clear(); + store_ = NullOpt; + // Step 2. Visit block body. + StmtVisitor::VisitStmt(op->body); + BufferStore store = store_.value(); + + // Step 3. Checking load store indices pattern + relay::OpPatternKind index_pair_pattern = relay::kElemWise; + bool has_elem_wise = false; + for (const BufferLoad& load : loads_) { + // Since elemwise is stricter than broadcast and broadcast is stricter than injective, + // while the order amount enums: kElemWise < kBroadcast < kInjective. + // We can simpily use `std::max` to detect these three patterns. + // E.g Here is only one store node but two load nodes, like C[i, j] = A[i, j] + B[i] + // Buffer C and A are elemwise but C and B are broadcast. So the whole block follows + // broadcast pattern. + if (IsElemwisePattern(store, load)) { + index_pair_pattern = std::max(index_pair_pattern, relay::kElemWise); + has_elem_wise = true; + } else if (IsBroadcastPattern(store, load)) { + index_pair_pattern = std::max(index_pair_pattern, relay::kBroadcast); + } else if (IsInjectivePattern(store, load)) { + index_pair_pattern = std::max(index_pair_pattern, relay::kInjective); + } else { + index_pair_pattern = relay::kOpaque; + break; + } + } + // If there is a index pair is kElemWise and others are kBroadcast, we regard it as kElemWise + // e.g. A[i, j] = B[i, j] + C[i] + if (index_pair_pattern == relay::kBroadcast && has_elem_wise) { + index_pair_pattern = relay::kElemWise; + } + // If the block index pattern is not opaque, update kind. + if (index_pair_pattern != relay::kOpaque) { + // This rule for softmax: reduce + injective. + if (IsOutputBlock(op) && kind_ == relay::kCommReduce) { + kind_ = relay::kOutEWiseFusable; + } else { + kind_ = std::max(kind_, index_pair_pattern); + } + return; + } + + // Step 4. Checking if the block contains reduce axis by looking into block iterators. + bool has_reduction = false; + Array reduce_vars; + for (const IterVar& it : op->iter_vars) { + if (it->iter_type == kCommReduce) { + has_reduction = true; + reduce_vars.push_back(it->var); + } + } + + if (has_reduction) { + if (IsFMA(op->body)) { + // FMA is regards as kOutEWiseFusable, e.g. Matmul or Conv. + kind_ = std::max(kind_, relay::kOutEWiseFusable); + return; + } else { + for (size_t i = 0; i < loads_.size(); ++i) { + // If it's not a pure reduce, regards as kOutEWiseFusable. + // This rule works for pooling for now. + if (!IsPureReducePattern(reduce_vars, loads_[i]->indices)) { + kind_ = std::max(kind_, relay::kOutEWiseFusable); + return; + } + } + } + kind_ = std::max(kind_, relay::kCommReduce); + } else { + kind_ = relay::kOpaque; + } + } + + /********** Helper Functions **********/ + + /*! \brief Checking if two arrays contains same elements. */ + static bool IsSameArray(const Array& lhs, const Array& rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (size_t i = 0; i < lhs.size(); ++i) { + if (!lhs[i].same_as(rhs[i])) { + return false; + } + } + return true; + } + + /*! + * \brief Checking the load indices and store indices follows elemwise pattern. + * It's elemwise pattern iff load indices and store indices are the same. + * E.g A[i, j] = B[i, j] + */ + static bool IsElemwisePattern(const BufferStore& store, const BufferLoad& load) { + return IsSameArray(store->indices, load->indices); + } + + /*! + * \brief Checking the load indices and store indices follows broadcast pattern. + * It's broadcast pattern iff all load indices are in the store indices in order + * E.g. A[i, j] = B[i] is broadcast since all load indices(`i`) are in the store indices + * A[i, j] = B[i, k] is not broadcast since `k` are not in the store indices. + * A[i, j] = B[j, i] is not broadcast the load indices are not in the same order as store's + */ + static bool IsBroadcastPattern(const BufferStore& store, const BufferLoad& load) { + size_t ndim_load_buf = load->buffer->shape.size(); + size_t ndim_store_buf = store->buffer->shape.size(); + + for (size_t i = 0, j = 0; i < ndim_load_buf; ++i) { + if (is_const_int(load->buffer->shape[i], 1) && is_const_int(load->indices[i], 0)) { + // Skip unit load dimensions + // E.g. A[i, j] = B[1, j] is still broadcast + continue; + } + + // Try to find the i-th load indice in the store indices. + while (j < ndim_store_buf && !store->indices[j].same_as(load->indices[i])) { + ++j; + } + + // It's not broadcast if we cannot find load indices in the store indices in order. + if (j == ndim_store_buf) { + return false; + } + } + return true; + } + + /*! + * \brief Checking the load indices and store indices follows injective pattern. + * It's injective pattern iff all load indice vars are in the store indices, no matter orders. + * Note that we only support store indices are direct vars so far, which can be enhance later. + * E.g. A[i, j] = B[j, i] is injective. + * A[i, j] = B[i - j] is injective since the load indice vars are only i, j + */ + static bool IsInjectivePattern(const BufferStore& store, const BufferLoad& load) { + std::unordered_set vars; + for (const PrimExpr& store_index : store->indices) { + if (const auto* v = store_index.as()) { + vars.insert(v); + } else { + return false; + } + } + for (const PrimExpr& load_index : load->indices) { + // return false if there are vars used in load indices but not in store indices. + if (tir::UsesVar(load_index, [&vars](const tir::VarNode* var) { return !vars.count(var); })) { + return false; + } + } + return true; + } + + /*! + * \brief Checking the load indices and store indices allow data reuse. + * It allow data reuse iff there is any vars in load indices but they are not in store indices + * E.g. Store = A[i, j] and Load = B[i, j, k] allow data reuse. + * Store = A[i, j] and Load = B[i, j + k] allow data reuse. + */ + static bool IsAllowReusePattern(const BufferStore& store, const BufferLoad& load) { + std::unordered_set vars; + for (const PrimExpr& index : store->indices) { + if (const auto* v = index.as()) { + vars.insert(v); + } else { + return false; + } + } + for (const PrimExpr& index : load->indices) { + PreOrderVisit(index, [&](const ObjectRef& node) { + if (const auto* v = node.as()) { + if (vars.count(v)) { + vars.erase(v); + } + } + return true; + }); + } + return !vars.empty(); + } + + /*! \brief Checking if the stmt is multiply add. E.g. C[i, j] += A[i, k] * B[j, k] */ + static bool IsFMA(const Stmt& body) { + if (const auto* store = body.as()) { + if (const auto* add = store->value.as()) { + if (const auto* l = add->a.as()) { + if (const auto* r = add->b.as()) { + bool incremental = + store->buffer.same_as(l->buffer) && IsSameArray(store->indices, l->indices); + const auto* l_load = r->a.as(); + const auto* r_load = r->b.as(); + if (incremental && l_load && r_load) { + return IsAllowReusePattern(GetRef(store), GetRef(l_load)) && + IsAllowReusePattern(GetRef(store), GetRef(r_load)); + } + } + } + } + } + return false; + } + + /*! + * \brief Checking if it is pure reduce pattern. + * It's pure reduce pattern iff all reduces axis are directly reduce var + * E.g. A[i] = sum(B[i, j]) is pure reduce + * A[i] = sum(B[i, j + k]) is not pure reduce + * pooling is not pure reduce + */ + static bool IsPureReducePattern(Array reduce_loops, Array indices) { + for (const PrimExpr& e : indices) { + int id = -1; + if (UsesVar(e, [&](const tir::VarNode* var) { + for (size_t i = 0; i < reduce_loops.size(); ++i) { + if (reduce_loops[i].get() == var) { + id = i; + return true; + } + } + return false; + })) { + if (!reduce_loops[id].same_as(e)) { + return false; + } + } + } + return true; + } + + private: + /*! + * \brief The BufferStore node in the current block. + * \note We only support one BufferStore node in a block (ususally generated by TE compute) + */ + Optional store_; + /*! \brief The BufferLoad nodes in the current block. */ + Array loads_; + /*! \brief The result of op pattern. */ + relay::OpPatternKind kind_ = relay::kElemWise; + /*! \brief The buffers from function params. I.e. the input and output buffers. */ + std::unordered_set param_buffers_; + + public: + relay::OpPatternKind GetResult() { return kind_; } +}; + +relay::OpPatternKind AnalyzeOpPatternKind(const PrimFunc& func) { + PatternKindAnalyzer analyzer(func); + analyzer(func->body); + return analyzer.GetResult(); +} + +bool HasReshapePattern(const PrimFunc& func) { + class ReshapeDetector : public StmtVisitor { + public: + static bool Detect(const Buffer& src_buffer, const Buffer& dst_buffer, Stmt stmt) { + ReshapeDetector detector(src_buffer, dst_buffer); + detector(stmt); + return detector.is_reshape_; + } + + private: + explicit ReshapeDetector(const Buffer& src_buffer, const Buffer& dst_buffer) + : is_reshape_(false), src_buffer_(src_buffer), dst_buffer_(dst_buffer) {} + + void VisitStmt_(const ForNode* loop) final { + ana_.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + // To detect the reshape pattern, we require each For to have + // either another For or a BlockRealize as body. + if (!(loop->body->IsInstance() || loop->body->IsInstance())) { + return; + } + this->VisitStmt(loop->body); + } + + void VisitStmt_(const BlockRealizeNode* block_realize) final { + // Constructing the mapping from block iterators to iterator + // binding values. The mapping will be used in the substitution of + // the flattened buffer access index. + const Block& block = block_realize->block; + const Array& block_iter = block->iter_vars; + const Array& iter_values = block_realize->iter_values; + ICHECK_EQ(block_iter.size(), iter_values.size()); + int n_iter = block_iter.size(); + for (int i = 0; i < n_iter; ++i) { + // To detect the reshape pattern, we require each block iter to be data-parallel. + if (block_iter[i]->iter_type != tir::IterVarType::kDataPar) { + return; + } + var_map_.Set(block_iter[i]->var, iter_values[i]); + } + + // Recurse into the block. + this->VisitStmt(block); + } + + void VisitStmt_(const BlockNode* block) final { + // Step 0. If the block body is a ForNode, recurse into it. + if (block->body->IsInstance()) { + this->VisitStmt(block->body); + return; + } + + // Step 1. Get the load/store pattern of the block body. + // To detect the reshape pattern, we require the block body to be a + // BufferStore, which has a BufferLoad as value. + const auto* buffer_store = block->body.as(); + if (buffer_store == nullptr) { + return; + } + const auto* buffer_load = buffer_store->value.as(); + if (buffer_load == nullptr) { + return; + } + // Further, we require the buffer being stored and being loaded to + // match the parameter of the PrimFunc, namely `dst_buffer_` and `src_buffer_`. + if (!(buffer_store->buffer.same_as(dst_buffer_) && + buffer_load->buffer.same_as(src_buffer_))) { + return; + } + + // Step 3. Calculate the flattened access index according to the load/store pattern. + auto f_calc_flattened_idx = [](const Buffer& buffer, const Array& indices) { + ICHECK_EQ(indices.size(), buffer->shape.size()); + int ndim = indices.size(); + PrimExpr idx = 0; + for (int i = 0; i < ndim; ++i) { + idx = idx * buffer->shape[i] + indices[i]; + } + return idx; + }; + PrimExpr src_idx = f_calc_flattened_idx(src_buffer_, buffer_load->indices); + PrimExpr dst_idx = f_calc_flattened_idx(dst_buffer_, buffer_store->indices); + + // Step 4. Substitute the block iterators in the flattened index + // with loop variables, and check if we can prove their equality. + src_idx = tir::Substitute(std::move(src_idx), var_map_); + dst_idx = tir::Substitute(std::move(dst_idx), var_map_); + if (ana_.CanProveEqual(src_idx, dst_idx)) { + this->is_reshape_ = true; + } + } + + bool is_reshape_; + /*! \brief The mapping from block vars to block binding values. */ + Map var_map_; + const Buffer& src_buffer_; + const Buffer& dst_buffer_; + arith::Analyzer ana_; + }; + + if (func->params.size() < 2) { + return false; + } + Optional src_buffer = func->buffer_map.Get(func->params.front()); + Optional dst_buffer = func->buffer_map.Get(func->params.back()); + if (!(src_buffer.defined() && dst_buffer.defined())) { + return false; + } + + // To detect the reshape pattern, we require each For to have + // either another For or a BlockRealize as body. + ICHECK(func->body->IsInstance()); + return ReshapeDetector::Detect(src_buffer.value(), dst_buffer.value(), func->body); +} + +TVM_REGISTER_GLOBAL("relax.analysis.has_reshape_pattern").set_body_typed(HasReshapePattern); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/vm_builtin_lower.cc new file mode 100644 index 000000000000..6613b39626da --- /dev/null +++ b/src/relax/backend/vm/vm_builtin_lower.cc @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/backend/vm/vm_builtin_lower.cc + * \brief Lowers most builtin functions and packed calls. + */ +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +// This pass lowers most ops to VM specific builtins. +// TODO(relax-team): revisit after PrimValue. +class VMBuiltinLowerMutator : public ExprMutator { + public: + using ExprMutator::VisitExpr_; + + // A workaround to remove the CallNodes of killing tensors and storages. + void VisitBinding_(const VarBindingNode* binding) final { + const auto* call = binding->value.as(); + if (call != nullptr && (call->op == mem_kill_storage_op_ || call->op == mem_kill_tensor_op_)) { + return; + } + ExprMutator::VisitBinding_(binding); + } + + Expr VisitExpr_(const CallNode* call_node) final { + // post-order mutation + Call call = Downcast(VisitExprPostOrder_(call_node)); + + if (call->op == call_tir_dyn_op_) { + return CallTIRDyn(call); + } else if (call->op == reshape_op_) { + return Reshape(call); + } else if (call->op == make_closure_op_) { + return MakeClosure(call); + } else if (call->op == invoke_closure_op_) { + return InvokeClosure(call); + } else if (call->op == alloc_tensor_op_) { + return MakeAllocTensor(call); + } else if (call->op == mem_alloc_storage_op_) { + return MakeMemAllocStorage(call); + } else if (call->op == mem_alloc_tensor_op_) { + return MakeMemAllocTensor(call); + } else { + return call; + } + } + + Expr ComputeStorageSize(const Expr& shape, const DataType& dtype) const { + // Question: what if the dtype of tensor_type is unknown? + // Symbolic/static shape case + if (auto* shape_expr = shape.as()) { + int64_t elem_bytes = runtime::GetVectorBytes(dtype); + PrimExpr ret = IntImm(DataType::Int(64), elem_bytes); + for (PrimExpr dim : shape_expr->values) { + ret = ret * dim; + } + return ShapeExpr({ret}); + } else { + return Call(builtin_compute_alloc_shape_, {shape, DataTypeImm(dtype)}, Attrs(), + {GetStructInfo(shape)}); + } + } + + Expr MakeAllocTensor(const Call& call) { + ShapeExpr output_shape = Downcast(call->args[0]); + DataTypeImm output_dtype = Downcast(call->args[1]); + DataType dtype = output_dtype->value; + Expr storage_size = ComputeStorageSize(output_shape, dtype); + PrimValue runtime_device_index = Downcast(call->args[2]); + Var storage = builder_->Emit( + Call(vm_alloc_storage_op_, {storage_size, runtime_device_index, output_dtype}, Attrs()), + "storage"); + Expr shape = call->args[0]; + PrimValue offset = PrimValue::Int64(0); + return Call(vm_alloc_tensor_op_, {storage, offset, shape, DataTypeImm(dtype)}, Attrs()); + } + + Expr MakeMemAllocStorage(const Call& call) { + PrimValue runtime_device_index = Downcast(call->args[1]); + DataTypeImm output_dtype = Downcast(call->args[3]); + return Call(vm_alloc_storage_op_, {call->args[0], runtime_device_index, output_dtype}, Attrs()); + } + + Expr MakeMemAllocTensor(const Call& call) { + PrimValue offset = Downcast(call->args[1]); + DataTypeImm dtype = Downcast(call->args[3]); + return Call(vm_alloc_tensor_op_, {call->args[0], offset, call->args[2], dtype}, Attrs()); + } + + Expr CallTIRDyn(const Call& call_node) { + ICHECK(call_node->args.size() == 2); + ICHECK(call_node->args[0]->IsInstance()); + ICHECK(call_node->args[1]->IsInstance()); + Array args; + + auto tir_args = Downcast(call_node->args[1]); + args.push_back(call_node->args[0]); + for (Expr arg : tir_args->fields) { + args.push_back(arg); + } + return Call(builtin_call_tir_dyn_, args, Attrs(), {void_sinfo_}); + } + + Expr Reshape(const Call& call_node) { + ICHECK(call_node->args.size() == 2); + ICHECK(call_node->struct_info_.defined()); + CHECK(call_node->args[1]->IsInstance()) + << "VMBuiltinLower expects the shape arg of reshape op to be a ShapeExpr"; + return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)}); + } + + Expr MakeClosure(const Call& call_node) { + ICHECK(call_node->args.size() == 2); + ICHECK(call_node->args[0]->IsInstance()); + ICHECK(call_node->args[1]->IsInstance()); + + Array args; + auto func = call_node->args[0]; + auto closure_args = Downcast(call_node->args[1]); + + args.push_back(func); + for (Expr arg : closure_args->fields) { + args.push_back(arg); + } + + return Call(builtin_make_closure_, args, Attrs(), {object_sinfo_}); + } + + Expr InvokeClosure(const Call& call_node) { + ICHECK(call_node->args.size() == 2); + ICHECK(call_node->args[0]->IsInstance()); + ICHECK(call_node->args[1]->IsInstance()); + + Array args; + + args.push_back(call_node->args[0]); + + // args for the invoke_closure + auto invoke_closure_args = Downcast(call_node->args[1]); + for (Expr arg : invoke_closure_args->fields) { + args.push_back(arg); + } + return Call(call_builtin_with_ctx_op_, {builtin_invoke_closure_, Tuple(args)}, Attrs(), + {object_sinfo_}); + } + + const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); + const StructInfo object_sinfo_ = ObjectStructInfo(); + const StructInfo void_sinfo_ = TupleStructInfo(Array({})); + // object to pattern match. + const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn"); + const Op& reshape_op_ = Op::Get("relax.reshape"); + const Op& make_closure_op_ = Op::Get("relax.make_closure"); + const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); + const Op& alloc_tensor_op_ = Op::Get("relax.builtin.alloc_tensor"); + const Op& mem_alloc_storage_op_ = Op::Get("relax.memory.alloc_storage"); + const Op& mem_alloc_tensor_op_ = Op::Get("relax.memory.alloc_tensor"); + const Op& mem_kill_storage_op_ = Op::Get("relax.memory.kill_storage"); + const Op& mem_kill_tensor_op_ = Op::Get("relax.memory.kill_tensor"); + // functions to lower to + const Op& vm_alloc_storage_op_ = Op::Get("relax.vm.alloc_storage"); + const Op& vm_alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor"); + // Function to compute allocated shape. + const ExternFunc builtin_compute_alloc_shape_{"vm.builtin.compute_alloc_shape"}; + const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"}; + const ExternFunc builtin_reshape_{"vm.builtin.reshape"}; + const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"}; + const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"}; +}; + +Expr VMBuiltinLower(const Expr& e) { return VMBuiltinLowerMutator().VisitExpr(e); } + +namespace transform { + +Pass VMBuiltinLower() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(VMBuiltinLower(f)); }; + return CreateFunctionPass(pass_func, 0, "VMBuiltinLower", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.VMBuiltinLower").set_body_typed(VMBuiltinLower); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index ca66b0a9ef75..ba167a45bc68 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -226,6 +226,87 @@ Expr MakeAllocTensor(Expr shape, DataType dtype, int64_t runtime_device_index) { TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor); +// memory planning alloc_storage + +RELAY_REGISTER_OP("relax.memory.alloc_storage") + .set_num_inputs(4) + .add_argument("total_space", "Expr", "The total space of the storage to allocate.") + .add_argument( + "virtual_device_index", "int64_t", + "The virtual device index indicating on which device the storage is to be allocated, " + "Index -1 is reserved for the host device.") + .add_argument("storage_scope", "string", + "The storage scope of the storage to allocate. Default is global.") + .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.") + .set_attr("FInferStructInfo", ReturnObjectStructInfo); + +Expr MakeAllocStorage(Expr size, int64_t virtual_device_index, std::string storage_scope, + DataType dtype) { + static const Op& op = Op::Get("relax.memory.alloc_storage"); + return Call( + op, + {size, PrimValue::Int64(virtual_device_index), StringImm(storage_scope), DataTypeImm(dtype)}, + Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.memory.alloc_storage").set_body_typed(MakeAllocStorage); + +// memory planning alloc_tensor + +StructInfo InferStructInfoMemAllocTensor(const Call& call, const BlockBuilder& ctx) { + ICHECK(GetStructInfoAs(call->args[2])) + << "must be a Expr of ShapeStructInfo, but got " << call->args[1]->GetTypeKey(); + DataType out_dtype; + if (const auto* dtype_node = call->args[3].as()) { + const DataTypeImm dtype_imm = GetRef(dtype_node); + out_dtype = dtype_imm->value; + } + return TensorStructInfo(call->args[2], out_dtype); +} + +RELAY_REGISTER_OP("relax.memory.alloc_tensor") + .set_num_inputs(4) + .add_argument("storage", "Expr", "The storage to allocate the tensor to.") + .add_argument("offset", "int", "Storage offset to allocate the tensor.") + .add_argument("shape", "Expr", "The shape of the tensor to allocate.") + .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.") + .set_attr("FInferStructInfo", InferStructInfoMemAllocTensor); + +Expr MakeMemAllocTensor(Expr storage, int offset, Expr shape, DataType dtype) { + static const Op& op = Op::Get("relax.memory.alloc_tensor"); + return Call(op, {storage, PrimValue::Int64(offset), shape, DataTypeImm(dtype)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.memory.alloc_tensor").set_body_typed(MakeMemAllocTensor); + +// memory planning kill_storage + +RELAY_REGISTER_OP("relax.memory.kill_storage") + .set_num_inputs(1) + .add_argument("storage", "Expr", "The storage to be killed.") + .set_attr("FInferStructInfo", ReturnVoidStructInfo); + +Expr MakeMemKillStorage(Expr storage) { + static const Op& op = Op::Get("relax.memory.kill_storage"); + return Call(op, {storage}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.memory.kill_storage").set_body_typed(MakeMemKillStorage); + +// memory planning kill_tensor + +RELAY_REGISTER_OP("relax.memory.kill_tensor") + .set_num_inputs(1) + .add_argument("tensor", "Expr", "The tensor to be killed.") + .set_attr("FInferStructInfo", ReturnVoidStructInfo); + +Expr MakeMemKillTensor(Expr tensor) { + static const Op& op = Op::Get("relax.memory.kill_tensor"); + return Call(op, {tensor}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.memory.kill_tensor").set_body_typed(MakeMemKillTensor); + // vm alloc_storage RELAY_REGISTER_OP("relax.vm.alloc_storage") diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc new file mode 100644 index 000000000000..2088a8306e7a --- /dev/null +++ b/src/relax/op/tensor/manipulate.cc @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file manipulate.cc + * \brief Manipulation operators. + */ + +#include "manipulate.h" + +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +// Helper function for flatten and reshape. +PrimExpr ComputeShapeProduct(const Array& shape_values) { + PrimExpr shape_prod = IntImm(DataType::Int(64), 1); + for (PrimExpr value : shape_values) { + shape_prod *= value; + } + return shape_prod; +} + +/* relax.reshape */ +Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { + if (const auto* e = shape.as()) { + return GetRef(e); + } + + const auto* array = shape.as(); + CHECK(array != nullptr) << "Reshape only expects the input new shape to be either an Expr or an " + "Array of PrimExprs. However, the given new shape is " + << shape; + int dim_to_infer = -1; + PrimExpr new_shape_prod = IntImm(DataType::Int(64), 1); + for (int i = 0; i < static_cast(array->size()); ++i) { + const auto* _len = array->at(i).as(); + CHECK(_len != nullptr) << "Reshape only expects the input new shape to be either an Expr or an " + "Array of PrimExprs. However, the given new shape is " + << shape; + PrimExpr len = GetRef(_len); + CHECK(len->dtype.is_int()) << "Reshape requires the new shape values to be all " + "integers. However, the give new shape is " + << shape; + const auto* int_len = len.as(); + if (int_len != nullptr && int_len->value == -1) { + CHECK_EQ(dim_to_infer, -1) << "Reshape accepts at most one \"-1\" in the new shape. However, " + "there are multiple \"-1\" in the given new shape " + << shape; + dim_to_infer = i; + } else { + CHECK(int_len == nullptr || int_len->value > 0) + << "Reshape requires all values in the new shape to be positive except a single \"-1\". " + "However, the given new shape is " + << shape; + // We expect any symbolic not to signal the intent of -1, and therefore do no check for + // symbolic value here. + new_shape_prod = new_shape_prod * len; + } + } + + Array array_ref = GetRef>(array); + // When there is no dimension to infer, just return the input array as ShapeExpr. + if (dim_to_infer == -1) { + return ShapeExpr(array_ref); + } + + // Otherwise, we require the input tensor to have known shape value for inference. + const auto* data_sinfo = GetStructInfoAs(data); + CHECK(data_sinfo != nullptr) + << "Reshape expects the input data to be a Tensor. However, the given input is " + << data->struct_info_->GetTypeKey(); + CHECK(data_sinfo->shape.defined()) + << "Reshape expects the input tensor to have known shape when there is some dimension length " + "to infer. However, the given input has no shape."; + const auto* shape_sinfo = GetStructInfoAs(data_sinfo->shape.value()); + CHECK(shape_sinfo != nullptr && shape_sinfo->values.defined()) + << "Reshape expects the input tensor to have known shape when there is some dimension length " + "to infer. However, the given input shape is " + << data_sinfo->shape << " whose shape value is unknown."; + + arith::Analyzer analyzer; + PrimExpr old_shape_prod = ComputeShapeProduct(shape_sinfo->values.value()); + array_ref.Set(dim_to_infer, analyzer.Simplify(floordiv(old_shape_prod, new_shape_prod))); + return ShapeExpr(array_ref); +} + +Expr reshape(Expr x, ObjectRef shape) { + Expr shape_in_expr = ConvertNewShapeToExpr(x, shape); + static const Op& op = Op::Get("relax.reshape"); + return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.reshape").set_body_typed(reshape); + +StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 2) { + ctx->ReportFatal(Diagnostic::Error(call->span) << "Reshape op should take 2 arguments"); + } + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* new_shape_sinfo = GetStructInfoAs(call->args[1]); + if (data_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "Reshape requires the input data to be Tensor. However, the given one is " + << call->args[0]->struct_info_->GetTypeKey()); + } + if (new_shape_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call->span) + << "Reshape requires the input new shape to be Shape. However, the given one is " + << call->args[1]->struct_info_->GetTypeKey()); + } + + Optional> old_shape_values; + if (data_sinfo->shape.defined()) { + const auto* old_shape_sinfo = GetStructInfoAs(data_sinfo->shape.value()); + ICHECK_NOTNULL(old_shape_sinfo); + old_shape_values = old_shape_sinfo->values; + } + + if (new_shape_sinfo->values.defined() && old_shape_values.defined()) { + PrimExpr new_shape_prod = ComputeShapeProduct(new_shape_sinfo->values.value()); + PrimExpr old_shape_prod = ComputeShapeProduct(old_shape_values.value()); + if (ctx->GetAnalyzer()->CanProve(old_shape_prod != new_shape_prod)) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "Reshape expects the new shape to be convertible from the old shape. " + "However, the old shape is " + << data_sinfo->shape << ", with product " << old_shape_prod + << ", while the new shape is " << call->args[1] << ", with product " + << new_shape_prod); + } + } + return TensorStructInfo(call->args[1], data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.reshape") + .set_num_inputs(2) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("shape", "Shape", "The input new shape.") + .set_attr("FInferStructInfo", InferStructInfoReshape); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h new file mode 100644 index 000000000000..1a3eb0547d7f --- /dev/null +++ b/src/relax/op/tensor/manipulate.h @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file manipulate.h + * \brief The functions to make Relax tensor manipulation operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_MANIPULATE_H_ +#define TVM_RELAX_OP_TENSOR_MANIPULATE_H_ + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Reshape the input array, supporting `-1` inference in the new + * shape when the new shape is given as an Array of PrimExpr. + * \param x The input data to the operator. + * \param shape The new shape. Should be compatible with the original shape. + * It is required to be either an Array of PrimExpr, or a Shape in Relax + * \return The reshaped result. + */ +Expr reshape(Expr x, ObjectRef shape); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_MANIPULATE_H_ diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc new file mode 100644 index 000000000000..be779e97bcf5 --- /dev/null +++ b/src/relax/transform/attach_global_symbol.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/attach_global_symbol.cc + * \brief Attach global_symbol to Relax functions and TIR Primfuncs for codegen. + */ + +#include +#include + +namespace tvm { +namespace relax { + +class GlobalSymbolAttacher { + public: + explicit GlobalSymbolAttacher(IRModule mod) : mod_(mod) {} + + IRModule Attach() { + IRModule ret; + for (auto& p : mod_->functions) { + BaseFunc func = p.second; + if (auto* prim_func = func.as()) { + func = WithAttr(GetRef(prim_func), "global_symbol", p.first->name_hint); + } else if (auto* relax_func = func.as()) { + func = WithAttr(GetRef(relax_func), "global_symbol", p.first->name_hint); + } else { + LOG(FATAL) << "Unsupported function type: " << func->GetTypeKey(); + throw; + } + ret->Add(p.first, func); + } + return ret; + } + + private: + IRModule mod_; +}; + +namespace transform { + +Pass AttachGlobalSymbol() { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return GlobalSymbolAttacher(mod).Attach(); }; + return CreateModulePass(pass_func, 0, "AttachGlobalSymbol", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.AttachGlobalSymbol").set_body_typed(AttachGlobalSymbol); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc new file mode 100644 index 000000000000..2ea039e0229b --- /dev/null +++ b/src/relax/transform/call_tir_rewrite.cc @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/call_tir_rewrite.cc + * \brief Perform explicit tensor allocation for call_tir. + */ +#include +#include +#include +#include +#include + +#include "../../relay/transforms/pattern_utils.h" + +namespace tvm { +namespace relax { + +// ================== +// CallTIRMutator +// Perform explicit tensor allocation for call_tir. +// Example: +// lv0: Tensor(n, m) = rx.call_tir(func, (x), (n, m), dtype="float32") +// --> +// gv0 = rx.call("relax.builtin.alloc_tensor", [n, m], dtype="float32") +// rx.call_packed(func, x, gv0) + +class CallTIRMutator : public ExprMutator { + public: + using ExprMutator::VisitExpr_; + Expr VisitExpr_(const CallNode* call) override { + // post-order mutation + Expr expr = VisitExprPostOrder_(call); + call = expr.as(); + + static const Op& call_tir_op = Op::Get("relax.call_tir"); + static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); + static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn"); + + if (call->op == call_tir_op) { + Array outs; + if (const auto& _tensor_sinfo = MatchStructInfo(expr)) { + // single output case + const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value(); + ICHECK(tensor_sinfo->shape.defined()) + << "the TensorStructInfo shape of call_tir has not populated"; + outs.push_back( + builder_->Emit(Call(alloc_tensor_op, // + {Downcast(tensor_sinfo->shape.value()), + DataTypeImm(tensor_sinfo->dtype), PrimValue::Int64(0)}, // + Attrs()), + "alloc")); + } else if (const auto& _tuple_sinfo = MatchStructInfo(expr)) { + // multiple output case + const TupleStructInfo& tuple_sinfo = _tuple_sinfo.value(); + for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { + const auto& field = tuple_sinfo->fields[i]; + + ICHECK(field->IsInstance()) + << "call_tir expects Tuple of TensorStructInfo, but got " << field + << " as an element of TupleStructInfo"; + const auto& field_tensor = Downcast(field); + ICHECK(field_tensor->shape.defined()) + << "call_tir expects all TensorStructInfo has shape, but got " << field_tensor + << " as an element of TupleStructInfo"; + outs.push_back( + builder_->Emit(Call(alloc_tensor_op, + {Downcast(field_tensor->shape.value()), + DataTypeImm(field_tensor->dtype), PrimValue::Int64(0)}, + Attrs()), + "alloc")); + } + } else { + LOG(FATAL) << "TypeError: The struct info of call_tir expects to be TensorStructInfo or " + "TupleStructInfo, but got" + << expr->struct_info_; + } + + Array args; + if (call->args[1].as()) { + args = Downcast(call->args[1])->fields; + args.insert(args.end(), outs.begin(), outs.end()); + + if (call->args.size() == 2) { + builder_->Emit(Call(call->args[0], args), "_"); + } else { + // unpack semantics + args.push_back(call->args[2]); + builder_->Emit(Call(call_tir_dyn_op, {call->args[0], Tuple(args)}), "_"); + } + } else { + args = outs; + args.insert(args.begin(), call->args[1]); + builder_->Emit(Call(call->args[0], args), "_"); + } + + if (outs.size() == 1) { + return outs[0]; + } + return std::move(Tuple(outs)); + } + + return GetRef(call); + } +}; + +Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); } + +namespace transform { + +Pass CallTIRRewrite() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(CallTIRRewrite(f)); }; + return CreateFunctionPass(pass_func, 0, "CallTIRRewrite", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc new file mode 100644 index 000000000000..aec0911ecc5a --- /dev/null +++ b/src/relax/transform/rewrite_dataflow_reshape.cc @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/rewrite_dataflow_reshape.cc + * \brief Transform all reshape within dataflow block to a relax.reshape operator + */ +#include +#include +#include + +#include "../op/tensor/manipulate.h" + +namespace tvm { +namespace relax { + +class DataflowReshapeRewriter : public ExprMutator { + public: + explicit DataflowReshapeRewriter(const IRModule& mod) : mod_(mod) {} + + private: + using ExprMutator::VisitExpr_; + + BindingBlock VisitBindingBlock(const BindingBlock& block) final { + // We only rewrite the bindings inside dataflow blocks. + if (const auto* dataflow_block = block.as()) { + return VisitBindingBlock_(dataflow_block); + } else { + return block; + } + } + + void VisitBinding_(const VarBindingNode* binding) final { + // We only rewrite the bindings that are not dataflow output (which means they are not + // externally referenced) + if (!binding->var->IsInstance()) { + this->builder_->EmitNormalized(GetRef(binding)); + } else { + ExprMutator::VisitBinding_(binding); + } + } + + Expr VisitExpr_(const CallNode* call) final { + if (!IsCallingTIRReshape(call)) { + return GetRef(call); + } + + // We bring the calls of reshape PrimFunc back to calls of high-level + // relax.reshape op, which will be lowered to calls of the ExternFunc + // vm.builtin.reshape in the VMBuiltinLower pass. + Array args = Downcast(call->args[1])->fields; + ICHECK_EQ(args.size(), 1); + TensorStructInfo res_sinfo = Downcast(call->struct_info_); + ICHECK(res_sinfo->shape.defined()); + return reshape(args[0], res_sinfo->shape.value()); + } + + bool IsCallingTIRReshape(const CallNode* call) { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + if (call->op != call_tir_op) { + return false; + } + const auto* gv = call->args[0].as(); + if (gv == nullptr) { + return false; + } + const auto* func = mod_->functions.Get(GetRef(gv)).as(); + ICHECK_NOTNULL(func); + return HasReshapePattern(GetRef(func)); + } + + const IRModule& mod_; +}; + +Expr RewriteDataflowReshape(const Function& f, const IRModule& mod) { + return DataflowReshapeRewriter(mod)(f); +} + +namespace transform { + +Pass RewriteDataflowReshape() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(RewriteDataflowReshape(f, m)); + }; + return CreateFunctionPass(pass_func, 0, "RewriteDataflowReshape", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.RewriteDataflowReshape") + .set_body_typed(RewriteDataflowReshape); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/to_non_dataflow.cc b/src/relax/transform/to_non_dataflow.cc new file mode 100644 index 000000000000..db2e9d7ee5e7 --- /dev/null +++ b/src/relax/transform/to_non_dataflow.cc @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/to_non_dataflow.cc + * \brief Transform all dataflow structure to non-dataflow version. + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +class ToNonDFMutator : public ExprMutator { + public: + Var VisitVarDef(const Var& var) final { + if (var.as()) { + Var new_var = Var(var->vid, GetStructInfo(var), var->span); + this->var_remap_[var->vid] = new_var; + return new_var; + } + return var; + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final { + builder_->BeginBindingBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); + } +}; + +Expr ToNonDataflow(const Expr& e) { return ToNonDFMutator().VisitExpr(e); } + +namespace transform { + +Pass ToNonDataflow() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(ToNonDataflow(f)); }; + return CreateFunctionPass(pass_func, 0, "ToNonDataflow", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.ToNonDataflow").set_body_typed(ToNonDataflow); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py new file mode 100644 index 000000000000..5dd83f2da24c --- /dev/null +++ b/tests/python/relax/test_analysis.py @@ -0,0 +1,172 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import List, Set, Union + +import tvm +import tvm.testing +from tvm import tir +from tvm import relax as rx +from tvm.relax.analysis import has_reshape_pattern +from tvm.script import relax as R, tir as T + + +def test_reshape_pattern_reshape(): + @T.prim_func + def reshape( + rxplaceholder: T.Buffer((1, 2, 3, 4), "float32"), + T_reshape: T.Buffer((8, 3), "float32"), + ): + for i0, i1 in T.grid(8, 3): + with T.block("T_reshape"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads( + rxplaceholder[ + (ax0 * 3 + ax1) // 24, + (ax0 * 3 + ax1) % 24 // 12, + (ax0 * 3 + ax1) % 12 // 4, + (ax0 * 3 + ax1) % 4, + ] + ) + T.writes(T_reshape[ax0, ax1]) + T_reshape[ax0, ax1] = rxplaceholder[ + (ax0 * 3 + ax1) // 24, + (ax0 * 3 + ax1) % 24 // 12, + (ax0 * 3 + ax1) % 12 // 4, + (ax0 * 3 + ax1) % 4, + ] + + assert has_reshape_pattern(reshape) + + +def test_reshape_pattern_reshape_scheduled(): + @T.prim_func + def reshape_scheduled( + rxplaceholder: T.Buffer((1, 2, 3, 4), "float32"), + T_reshape: T.Buffer((8, 3), "float32"), + ): + for i0_i1_fused_0 in T.thread_binding(1, thread="blockIdx.x"): + for i0_i1_fused_1 in T.thread_binding(24, thread="threadIdx.x"): + with T.block("T_reshape"): + ax0 = T.axis.spatial(8, (i0_i1_fused_0 * 24 + i0_i1_fused_1) // 3) + ax1 = T.axis.spatial(3, (i0_i1_fused_0 * 24 + i0_i1_fused_1) % 3) + T.reads( + rxplaceholder[ + (ax0 * 3 + ax1) // 24, + (ax0 * 3 + ax1) % 24 // 12, + (ax0 * 3 + ax1) % 12 // 4, + (ax0 * 3 + ax1) % 4, + ] + ) + T.writes(T_reshape[ax0, ax1]) + T_reshape[ax0, ax1] = rxplaceholder[ + (ax0 * 3 + ax1) // 24, + (ax0 * 3 + ax1) % 24 // 12, + (ax0 * 3 + ax1) % 12 // 4, + (ax0 * 3 + ax1) % 4, + ] + + assert has_reshape_pattern(reshape_scheduled) + + +def test_reshape_pattern_expand_dims(): + @T.prim_func + def expand_dims( + rxplaceholder: T.Buffer((2, 3, 4), "float32"), + expand_dims: T.Buffer((2, 1, 1, 1, 3, 1, 4, 1), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(2, 1, 1, 1, 3, 1, 4, 1): + with T.block("expand_dims"): + i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1 = T.axis.remap( + "SSSSSSSS", [i0, i1, i2, i3, i4, i5, i6, i7] + ) + T.reads(rxplaceholder[i0_1, i4_1, i6_1]) + T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1]) + expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1] = rxplaceholder[ + i0_1, i4_1, i6_1 + ] + + assert has_reshape_pattern(expand_dims) + + +def test_reshape_pattern_with_raggedness(): + @T.prim_func + def reshape_raggedness( + A: T.Buffer((100, 768), "float32"), + src_indptr: T.Buffer((9,), "int32"), + B: T.Buffer((100, 12, 64), "float32"), + ): + for b in T.serial(8): + with T.block("block0"): + vb = T.axis.spatial(8, b) + for i in T.serial(src_indptr[vb + 1] - src_indptr[vb]): + for h in T.serial(12): + for f in T.serial(64): + with T.block("block1"): + vi, vh, vf = T.axis.remap("SSS", [i, h, f]) + B[src_indptr[vb] + vi, vh, vf] = A[ + src_indptr[vb] + vi, vh * 64 + vf + ] + + assert has_reshape_pattern(reshape_raggedness) + + +def test_reshape_pattern_reject_seqstmt(): + @T.prim_func + def identity_bias(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")): + C = T.alloc_buffer((128, 128), "float32") + for i0, i1 in T.grid(4, 4): + with T.block("identity"): + vi0, vi1 = T.axis.remap("SS", [i0, i1]) + C[vi0, vi1] = A[vi0, vi1] + for i0, i1 in T.grid(4, 4): + with T.block("identity"): + vi0, vi1 = T.axis.remap("SS", [i0, i1]) + B[vi0, vi1] = C[vi0, vi1] + T.float32(1) + + @T.prim_func + def identity_identity(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")): + C = T.alloc_buffer((128, 128), "float32") + for i0, i1 in T.grid(4, 4): + with T.block("identity"): + vi0, vi1 = T.axis.remap("SS", [i0, i1]) + C[vi0, vi1] = A[vi0, vi1] + for i0, i1 in T.grid(4, 4): + with T.block("identity"): + vi0, vi1 = T.axis.remap("SS", [i0, i1]) + B[vi0, vi1] = C[vi0, vi1] + + assert not has_reshape_pattern(identity_bias) + assert not has_reshape_pattern(identity_identity) + + +def test_reshape_pattern_reject_reduction(): + @T.prim_func + def reduction(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4,), "float32")): + for i0, i1 in T.grid(4, 4): + with T.block("identity"): + vi0, vi1 = T.axis.remap("SR", [i0, i1]) + with T.init(): + B[vi0] = T.float32(0) + B[vi0] = B[vi0] + A[vi0, vi1] + + assert not has_reshape_pattern(reduction) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py new file mode 100644 index 000000000000..624b7877cd11 --- /dev/null +++ b/tests/python/relax/test_transform.py @@ -0,0 +1,141 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm +from tvm import relax +from tvm.ir import structural_equal +from tvm.ir.base import assert_structural_equal + +import tvm.script +from tvm.script import tir as T, relax as R + + +def test_to_non_dataflow(): + @tvm.script.ir_module + class TestToNonDataflow: + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + m, n = T.var("int64"), T.var("int64") + with R.dataflow(): + lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32")) + R.output(gv0) + return gv0 + + mod = TestToNonDataflow + + old_vars = [] + + def fvisit(e): + if isinstance(e, relax.Var): + nonlocal old_vars + old_vars.append(e) + + relax.analysis.post_order_visit(mod["foo"], fvisit) + x, lv0, gv0 = old_vars + + new_mod = relax.transform.ToNonDataflow()(mod) + + new_vars = [] + + def fvisit(e): + if isinstance(e, relax.Var): + nonlocal new_vars + new_vars.append(e) + + relax.analysis.post_order_visit(new_mod["foo"], fvisit) + + assert x == new_vars[0] + assert lv0 != new_vars[1] + assert isinstance(lv0, relax.DataflowVar) + assert not isinstance(new_vars[1], relax.DataflowVar) + + assert isinstance(gv0, relax.Var) + assert isinstance(new_vars[2], relax.Var) + assert gv0 == new_vars[2] + + +def test_call_tir_rewrite(): + @tvm.script.ir_module + class TestCallTIRRewrite: + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + m, n = T.var("int64"), T.var("int64") + gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + return gv0 + + mod = TestCallTIRRewrite + + # before rewrite + v0 = mod["foo"].body.blocks[0].bindings[0].var + s0 = mod["foo"].body.blocks[0].bindings[0].value + assert isinstance(s0, relax.Call) + assert s0.op.name == "relax.call_tir" + + # after rewrite + new_mod = relax.transform.CallTIRRewrite()(mod) + func = new_mod["foo"] + + block = func.body.blocks[0] + assert not isinstance(block, relax.DataflowBlock) + + s1 = block.bindings[0].value + assert isinstance(s1, relax.Call) + assert s1.op.name == "relax.builtin.alloc_tensor" + assert isinstance(s1.args[0], relax.ShapeExpr) + assert structural_equal(s1.args[0], s0.sinfo_args[0].shape) + s2 = block.bindings[1].value + assert s2.op.global_symbol == "test.op.identity" + + +def test_vm_builtin_lower(): + @tvm.script.ir_module + class TestVMBuiltinLower: + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor: + m, n = T.var("int64"), T.var("int64") + alloc = R.builtin.alloc_tensor((m, n), runtime_device_index=0, dtype="float32") + _ = R.call_packed( + "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) + ) + gv0 = alloc + return gv0 + + mod = TestVMBuiltinLower + + # after vm builtin lowering + new_mod = relax.transform.VMBuiltinLower()(mod) + func = new_mod["foo"] + + assert isinstance(new_mod, tvm.IRModule) + assert isinstance(func, tvm.relax.expr.Function) + + block = func.body.blocks[0] + s1 = block.bindings[0].value + assert isinstance(s1, relax.Call) + assert s1.op.name == "relax.vm.alloc_storage" + s2 = block.bindings[1].value + assert isinstance(s2, relax.Call) + s3 = block.bindings[2].value + assert isinstance(s3, relax.Call) + assert isinstance(s3.op, relax.ExternFunc) + assert s3.op.global_symbol == "test.op.identity" + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_transform_attach_global_symbol.py b/tests/python/relax/test_transform_attach_global_symbol.py new file mode 100644 index 000000000000..edfc646e2108 --- /dev/null +++ b/tests/python/relax/test_transform_attach_global_symbol.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm +from tvm import tir, relax +from tvm.ir import assert_structural_equal + +import tvm.script +from tvm.script import tir as T, relax as R + + +@tvm.script.ir_module +class Before: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + m = T.var("int64") + n = T.var("int64") + k = T.var("int64") + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")) -> R.Tensor: + m, n, k = T.var("int64"), T.var("int64"), T.var("int64") + gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), dtype="float32")) + return gv0 + + +def test_basic(): + @tvm.script.ir_module + class Expected: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.var("int64") + n = T.var("int64") + k = T.var("int64") + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def main( + x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32") + ) -> R.Tensor: + R.func_attr({"global_symbol": "main"}) + m, n, k = T.var("int64"), T.var("int64"), T.var("int64") + gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), dtype="float32")) + return gv0 + + before = Before + expected = Expected + after = relax.transform.AttachGlobalSymbol()(before) + assert_structural_equal(after, expected) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py new file mode 100644 index 000000000000..2c53d85c5636 --- /dev/null +++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm import relax +from tvm.script import relax as R, tir as T + + +def test_reshape_expand_dims(): + @tvm.script.ir_module + class Module: + @T.prim_func + def reshape( + rxplaceholder: T.Buffer((T.int64(8), T.int64(3)), "float32"), + T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"), + ): + for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + rxplaceholder[ + (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3), + (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3), + ] + ) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) + T_reshape[v_ax0, v_ax1, v_ax2] = rxplaceholder[ + (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3), + (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3), + ] + + @T.prim_func + def expand_dims( + rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"), + expand_dims: T.Buffer( + (T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)), + "float32", + ), + ): + for i0, i1, i2, i3, i4 in T.grid( + T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3) + ): + with T.block("expand_dims"): + i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(rxplaceholder[i0_1, i2_1, i4_1]) + T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1]) + expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1] = rxplaceholder[i0_1, i2_1, i4_1] + + @R.function + def main( + x: R.Tensor((8, 3), dtype="float32") + ) -> R.Tensor((2, 1, 4, 1, 3), dtype="float32"): + with R.dataflow(): + y = R.call_tir(reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32")) + z = R.call_tir(expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4, 1, 3), "float32")) + R.output(z) + return z + + @tvm.script.ir_module + class Expected: + @T.prim_func + def reshape( + rxplaceholder: T.Buffer((T.int64(8), T.int64(3)), "float32"), + T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"), + ): + for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + rxplaceholder[ + (v_ax0 * T.int64(12) + v_ax1 * T.int64(3) + v_ax2) // T.int64(3), + (v_ax1 * T.int64(12) + v_ax2 * T.int64(3) + v_ax2) % T.int64(3), + ] + ) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) + T_reshape[v_ax0, v_ax1, v_ax2] = rxplaceholder[ + (v_ax0 * T.int64(12) + v_ax1 * T.int64(3) + v_ax2) // T.int64(3), + (v_ax1 * T.int64(12) + v_ax2 * T.int64(3) + v_ax2) % T.int64(3), + ] + + @T.prim_func + def expand_dims( + rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"), + expand_dims: T.Buffer( + (T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)), "float32" + ), + ): + for i0, i1, i2, i3, i4 in T.grid( + T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3) + ): + with T.block("expand_dims"): + i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(rxplaceholder[i0_1, i2_1, i4_1]) + T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1]) + expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1] = rxplaceholder[i0_1, i2_1, i4_1] + + @R.function + def main( + x: R.Tensor((8, 3), dtype="float32") + ) -> R.Tensor((2, 1, 4, 1, 3), dtype="float32"): + with R.dataflow(): + y: R.Tensor((2, 4, 3), "float32") = R.reshape(x, (2, 4, 3)) + # Note: `z` is the output var of the dataflow block, and is thus + # not expected to be rewritten. + z = R.call_tir( + expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4, 1, 3), dtype="float32") + ) + R.output(z) + return z + + assert relax.analysis.has_reshape_pattern(Module["expand_dims"]) + mod = relax.transform.RewriteDataflowReshape()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_reshape_non_dataflow(): + @tvm.script.ir_module + class Module: + @T.prim_func + def reshape( + rxplaceholder: T.Buffer((T.int64(8), T.int64(3)), "float32"), + T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"), + ): + for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + rxplaceholder[ + (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3), + (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3), + ] + ) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) + T_reshape[v_ax0, v_ax1, v_ax2] = rxplaceholder[ + (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3), + (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3), + ] + + @R.function + def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), dtype="float32"): + y = R.call_tir(reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32")) + return y + + assert relax.analysis.has_reshape_pattern(Module["reshape"]) + # The binding var of the call_tir is not a DataflowVar. So the pass does no change. + mod = relax.transform.RewriteDataflowReshape()(Module) + tvm.ir.assert_structural_equal(mod, Module) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py new file mode 100644 index 000000000000..534d2308daa9 --- /dev/null +++ b/tests/python/relax/test_vm_build.py @@ -0,0 +1,908 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +from typing import Tuple, Callable + +import sys +import tempfile +import numpy as np +import pytest +import tvm +import tvm.script +import tvm.testing +from tvm import relax, rpc, te, tir, topi +from tvm.contrib import utils +from tvm.relax.testing import nn +from tvm.script import relax as R, tir as T +from tvm.relax.testing.vm import check_saved_func + +EXEC_MODE = ["bytecode"] + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_compile_simple(exec_mode): + @tvm.script.ir_module + class TestVMCompileStage0: + @R.function + def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): + z = R.call_packed( + "test.vm.identity", x, y, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) + ) + return y + + mod = TestVMCompileStage0 + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + inp1 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) + inp2 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) + vm = relax.VirtualMachine(ex, tvm.cpu()) + vm["foo"](inp1, inp2) + tvm.testing.assert_allclose(inp2.numpy(), inp1.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_match_check(exec_mode): + @tvm.script.ir_module + class TestMatchCheck: + @R.function + def foo(x: R.Tensor(["n", "m"], "int32"), y: R.Object) -> R.Tensor(["m", "n"], dtype=None): + return y + + mod = TestMatchCheck + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x0 = tvm.nd.array(np.zeros((1, 2)).astype("int32")) + y0 = tvm.nd.array(np.zeros((2, 1)).astype("float32")) + y1 = tvm.nd.array(np.zeros((1, 2)).astype("float32")) + y2 = tvm.nd.array(np.zeros((2, 1, 1)).astype("float32")) + + vm["foo"](x0, y0) + + with pytest.raises(RuntimeError, match=".*return.*"): + vm["foo"](x0, y1) + + with pytest.raises(ValueError, match=".*return.*"): + vm["foo"](x0, y2) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_compile_stage2(exec_mode): + @tvm.script.ir_module + class TestVMCompileStage2: + @R.function + def foo(x: R.Tensor(dtype="float32")) -> R.Shape: + n, m = T.var("int64"), T.var("int64") + _ = R.match_cast(x, R.Tensor((n, m), "float32")) + return (n * 2, m * 3) + + mod = TestVMCompileStage2 + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + shape = (32, 16) + arr = tvm.nd.array(np.random.rand(*shape).astype("float32")) + res = vm["foo"](arr) + assert res[0] == shape[0] * 2 + assert res[1] == shape[1] * 3 + + # dtype mismatch + with pytest.raises(ValueError, match=".*dtype.*"): + vm["foo"](tvm.nd.array(np.zeros((1, 2)).astype("int32"))) + + # ndim mismatch + with pytest.raises(ValueError, match=".*match_cast.*ndim.*"): + vm["foo"](tvm.nd.array(np.zeros((1,)).astype("float32"))) + + # type mismach + with pytest.raises(TypeError): + vm["foo"]([]) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_compile_stage3(exec_mode): + @tvm.script.ir_module + class TestVMCompileStage3: + @R.function + def foo(x: R.Tensor((32, 16), "float32")) -> R.Tensor: + with R.dataflow(): + y = R.call_tir("test.vm.identity", (x), R.Tensor((32, 16), dtype="float32")) + R.output(y) + return y + + mod = TestVMCompileStage3 + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + shape = (32, 16) + inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + res = vm["foo"](inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_compile_e2e(exec_mode): + @tvm.script.ir_module + class TestVMCompileE2E: + @R.function + def foo(x: R.Tensor(dtype="float32")) -> R.Tensor: + with R.dataflow(): + n, m = T.var("int64"), T.var("int64") + _ = R.match_cast(x, R.Tensor((n, m), "float32")) + y = R.call_tir("test.vm.tile", (x), R.Tensor((n, m * 2), dtype="float32")) + R.output(y) + return y + + mod = TestVMCompileE2E + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + shape = (32, 16) + inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + res = check_saved_func(vm, "foo", inp) + tvm.testing.assert_allclose(res.numpy(), np.tile(inp.numpy(), (1, 2)), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_compile_e2e_func_param_with_shape(exec_mode): + @tvm.script.ir_module + class TestVMCompileE2E2: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.var("int32") + n = T.var("int32") + k = T.var("int32") + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def func( + x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32") + ) -> R.Tensor: + m, k = T.var("int64"), T.var("int64") + gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((m, k), dtype="float32")) + return gv0 + + mod = TestVMCompileE2E2 + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + data = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) + weight = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) + res = check_saved_func(vm, "func", data, weight) + expected = np.dot(data.numpy(), weight.numpy()) + tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_emit_te_extern(exec_mode): + if not tvm.get_global_func("tvm.contrib.cblas.matmul", True): + print("skip because extern function is not available") + return + bb = relax.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = relax.Var("x", R.Tensor([n, m], "float32")) + y = relax.Var("y", R.Tensor([m, n], "float32")) + + with bb.function("rx_cblas_matmul", [x, y]): + out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, transb=False) + bb.emit_func_output(out) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) + weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) + res = check_saved_func(vm, "rx_cblas_matmul", data, weight) + expected = np.dot(data.numpy(), weight.numpy()) + tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_emit_te_concat(exec_mode): + # concatenate of two vectors of size (n,) and (m,) + bb = relax.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = relax.Var("x", R.Tensor([n], "float32")) + y = relax.Var("y", R.Tensor([m], "float32")) + + def te_func(A, B): + C = te.compute((n + m), lambda i: tvm.tir.if_then_else(i < n, A[i], B[i - n])) + return C + + with bb.function("rx_func", [x, y]): + x1 = bb.emit_te(te_func, x, y) + bb.emit_func_output(x1) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array( + np.random.rand( + 1, + ).astype(np.float32) + ) + inp2 = tvm.nd.array( + np.random.rand( + 2, + ).astype(np.float32) + ) + res = check_saved_func(vm, "rx_func", inp, inp2) + tvm.testing.assert_allclose( + res.numpy(), np.append(inp.numpy(), inp2.numpy()), rtol=1e-7, atol=1e-7 + ) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_emit_te_dtype_change(exec_mode): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + x = relax.Var("x", R.Tensor([n], "float32")) + + # convert a tensor with dtype of float32 to int16 + def te_func(A): + B = te.compute((n,), lambda i: A[i].astype("int16")) + return B + + with bb.function("rx_func", [x]): + y = bb.emit_te(te_func, x) + bb.emit_func_output(y) + + mod = bb.get() + + new_mod = relax.transform.CallTIRRewrite()(mod) + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array( + np.random.rand( + 1, + ).astype(np.float32) + ) + res = check_saved_func(vm, "rx_func", inp) + np.testing.assert_allclose(res.numpy(), inp.numpy().astype("int16")) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_emit_te_floor_symbolic_shape(exec_mode): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + x = relax.Var("x", R.Tensor([n], "float32")) + + def te_func(A): + C = te.compute((tir.floordiv(n, 2),), lambda i: A[i] + 1) + return C + + with bb.function("rx_func", [x]): + x1 = bb.emit_te(te_func, x) + bb.emit_func_output(x1) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + shape = (9,) + inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + res = check_saved_func(vm, "rx_func", inp) + + def expected_output(): + output_shape = (shape[0] // 2,) + return inp.numpy()[: output_shape[0]] + 1 + + tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_emit_te_constant_param_cpu(exec_mode): + x_np = np.random.rand(2, 2).astype("float32") + c_np = np.random.rand(2, 2).astype("float32") + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 2), "float32")) + c = relax.const(c_np, "float32") + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, c) + gv = bb.emit_output(lv0) + bb.emit_func_output(gv) + + mod = bb.get() + exec = relax.vm.build(mod, "llvm", exec_mode=exec_mode) + dev = tvm.cpu() + vm = relax.VirtualMachine(exec, dev) + + add_res = check_saved_func(vm, "main", tvm.nd.array(x_np, dev)) + tvm.testing.assert_allclose(add_res.numpy(), x_np + c_np, rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +@tvm.testing.requires_gpu +def test_vm_emit_te_constant_param_gpu(exec_mode): + x_np = np.random.rand(2, 2).astype("float32") + c_np = np.random.rand(2, 2).astype("float32") + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 2), "float32")) + c = relax.const(c_np, "float32") + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, c) + gv = bb.emit_output(lv0) + bb.emit_func_output(gv) + + mod = bb.get() + sch = tvm.tir.Schedule(mod, debug_mask="all") + loops = sch.get_loops(sch.get_block(name="T_add", func_name="add")) + sch.bind(loops[0], "threadIdx.x") + + exec = relax.vm.build(sch.mod, "cuda", exec_mode=exec_mode) + dev = tvm.cuda() + vm = relax.VirtualMachine(exec, dev) + + add_res = check_saved_func(vm, "main", tvm.nd.array(x_np, dev)) + tvm.testing.assert_allclose(add_res.numpy(), x_np + c_np, rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_relax_symbolic_shape(exec_mode): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + x = relax.Var("x", R.Tensor([n], "float32")) + y = relax.Var("y", R.Tensor([(n // 2) + 1], "float32")) + + def te_func(A, B): + C = te.compute((n,), lambda i: A[i] + B[i // 2]) + return C + + with bb.function("rx_func", [x, y]): + x1 = bb.emit_te(te_func, x, y) + bb.emit_func_output(x1) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + shape1 = (5,) + shape2 = (3,) + inp = tvm.nd.array(np.random.rand(*shape1).astype(np.float32)) + inp2 = tvm.nd.array(np.random.rand(*shape2).astype(np.float32)) + res = check_saved_func(vm, "rx_func", inp, inp2) + + def expected_output(): + return inp.numpy() + np.repeat(inp2.numpy(), 2)[:5] + + tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_relax_dyn_tir_shape(exec_mode): + # case where TIR variables are unbound in generated PrimFunc + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + + def te_func(A): + C = te.compute((n + 1), lambda i: A[i]) + return C + + with bb.function("rx_func"): + x = nn.Placeholder((n,), dtype="float32", name="x") + y = nn.Placeholder((n + 1,), dtype="float32", name="y") + + x1 = bb.emit_te(te_func, y) + bb.emit_func_output(x1, params=[x, y]) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + + ex.mod.export_library("exec.so") + exec1 = relax.vm.Executable(tvm.runtime.load_module("exec.so")) + os.remove("exec.so") + assert ex.as_text() == exec1.as_text() + + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(2).astype(np.float32)) + inp2 = tvm.nd.array(np.random.rand(3).astype(np.float32)) + + res = check_saved_func(vm, "rx_func", inp, inp2) + + tvm.testing.assert_allclose(res.numpy(), inp2.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_tuple(exec_mode): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + + with bb.function("rx_func"): + x = nn.Placeholder((n,), dtype="float32", name="x") + y = nn.Placeholder((n,), dtype="float32", name="y") + tup = relax.Tuple([x, y]) + item = tup[0] + bb.emit_func_output([tup, item], params=[x, y]) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + shape = (5,) + inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + inp2 = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + (res1, res2), res3 = vm["rx_func"](inp, inp2) + + tvm.testing.assert_allclose(res1.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(res2.numpy(), inp2.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(res3.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_tuplegetitem(exec_mode): + @tvm.script.ir_module + class TestVMTupleGetItem: + @R.function + def tuple_get_item( + x: R.Tensor(ndim=2, dtype="float32"), + y: R.Tensor(ndim=2, dtype="float32"), + ): + t = (x, y) + a = t[0] + b = t[1] + c = R.call_packed("test.vm.add", a, b, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))) + return c + + mod = TestVMTupleGetItem + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + y_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + res = check_saved_func(vm, "tuple_get_item", x_inp, y_inp) + tvm.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_lower_memory_alloc_storage_tensor(exec_mode): + @tvm.script.ir_module + class TestMemoryAllocStorageTensor: + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")): + storage = R.memory.alloc_storage( + (24,), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + y = R.memory.alloc_tensor(storage, 0, (2, 3), dtype="float32") + _ = copy(x, y) + return y + + @T.prim_func + def copy(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + for i0, i1 in T.grid(2, 3): + with T.block("block"): + vi0, vi1 = T.axis.remap("SS", [i0, i1]) + B[vi0, vi1] = A[vi0, vi1] + + mod = TestMemoryAllocStorageTensor + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + y = vm["main"](x) + tvm.testing.assert_allclose(y.numpy(), x.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_sub_func_call(exec_mode): + @tvm.script.ir_module + class TestVMSubFunction: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.var("int32") + n = T.var("int32") + k = T.var("int32") + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def relax_matmul_tir( + x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") + ) -> R.Tensor((32, 32), dtype="float32"): + with R.dataflow(): + gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) + R.output(gv0) + return gv0 + + @R.function + def relax_matmul_packed( + x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") + ) -> R.Object: + gv0 = R.call_packed("test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))) + return gv0 + + @R.function + def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Object: + gv0 = relax_matmul_tir(x, w) + gv1 = relax_matmul_packed(gv0, gv0) + return gv1 + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(TestVMSubFunction, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x_inp = tvm.nd.array(np.random.rand(32, 32).astype(np.float32)) + y_inp = tvm.nd.array(np.random.rand(32, 32).astype(np.float32)) + res = check_saved_func(vm, "main", x_inp, y_inp) + product = np.dot(x_inp.numpy(), y_inp.numpy()) + expected = product * product + tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_recursion(exec_mode): + @tvm.script.ir_module + class TestVMRecursion: + @R.function + def recursion(n: R.Tensor((1,), "float32")) -> R.Tensor: + cond = R.call_packed( + "test.vm.equal_zero", n, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) + ) + if cond: + res = R.const(1.0) + else: + gv0 = R.call_packed( + "test.vm.subtract_one", n, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) + ) + tmp = recursion(gv0) + res = R.call_packed( + "test.vm.add", tmp, tmp, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) + ) + return res + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(TestVMRecursion, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + inp = np.empty(1).astype("float32") + recursion_runs = np.random.randint(1, 10) + inp.fill(recursion_runs) + inp = tvm.nd.array(inp) + res = check_saved_func(vm, "recursion", inp) + tvm.testing.assert_allclose(res.numpy(), np.power(2.0, recursion_runs), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_closure(exec_mode): + @tvm.script.ir_module + class TestClosure: + @R.function + def lifted_func_1(x: R.Tensor((2, 3), "float32"), env: R.Tensor((2, 3), "float32")): + return R.call_packed("test.vm.add", x, env, sinfo_args=(R.Tensor)) + + @R.function + def main( + x: R.Tensor((2, 3), "float32"), + y: R.Tensor((2, 3), "float32"), + ): + clo = R.make_closure(lifted_func_1, (x,)) + res = R.invoke_closure(clo, (y,), sinfo_args=(R.Tensor)) + return res + + mod = TestClosure + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + y_inp = tvm.nd.array(np.array([[3.1, 4.0, 5.0], [6.0, 7.1, 9.0]], dtype="float32")) + res = check_saved_func(vm, "main", x_inp, y_inp) + tvm.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy()) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_time_evaluator(exec_mode): + @tvm.script.ir_module + class TestTimeEvaluator: + @R.function + def main(x: R.Tensor((1,), "float32"), y: R.Tensor((1,), "float32")): + return R.call_packed( + "test.vm.add", x, y, sinfo_args=(R.Tensor(ndim=1, dtype="float32")) + ) + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(TestTimeEvaluator, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x = tvm.nd.array(np.random.rand(1).astype("float32")) + y = tvm.nd.array(np.random.rand(1).astype("float32")) + + # ensure we can use time_evaluator with the stateful API + vm.set_input("main", x, y) + timing_res = vm.time_evaluator("invoke_stateful", tvm.cpu())("main") + # just checking that it has some results at all + assert timing_res.results + + # ensure we can use it with a closure + vm.save_function("main", "saved_main", x, y) + timing_res = vm.time_evaluator("saved_main", tvm.cpu())() + assert timing_res.results + + +@tvm.script.ir_module +class TestVMSetInput: + @T.prim_func + def test_vm_mul(x: T.handle, y: T.handle, z: T.handle): + T.func_attr({"global_symbol": "test_vm_mul"}) + m = T.var("int32") + n = T.var("int32") + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (m, n)) + C = T.match_buffer(z, (m, n)) + + for i, j in T.grid(m, n): + with T.block("mul"): + vi = T.axis.spatial(m, i) + vj = T.axis.spatial(n, j) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = A[vi, vj] * B[vi, vj] + + # test returning a tuple + @R.function + def test_vm_tuple( + x: R.Tensor((), "int32") + ) -> R.Tuple(R.Tensor((), "int32"), R.Tensor((), "int32")): + return (x, x) + + # nested tuple too + @R.function + def test_vm_nested_tuple( + x: R.Tensor((), "int32") + ) -> R.Tuple( + R.Tuple( + R.Tensor((), "int32"), + R.Tuple( + R.Tensor((), "int32"), + ), + ), + R.Tensor((), "int32"), + ): + return ((x, (x,)), x) + + @R.function + def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: + gv0 = R.call_tir("test_vm_mul", (x, w), R.Tensor((32, 32), dtype="float32")) + return gv0 + + +def set_input_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: + a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + vm.set_input("main", a, b) + vm.invoke_stateful("main") + res0 = vm.get_outputs("main") + + data_dict = {"x": a, "w": b} + vm.set_input("main", **data_dict) + vm.invoke_stateful("main") + res1 = vm.get_outputs("main") + tvm.testing.assert_allclose(res0.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(res0.numpy(), res1.numpy(), rtol=1e-7, atol=1e-7) + + # bug! If you don't bind the NDArray to a var, the memory will get corrupted. + # Possibly due to object lifecycles and other FFI issues + a = tvm.nd.array(np.array(2).astype("int32"), device) + vm.set_input("test_vm_tuple", a) + vm.invoke_stateful("test_vm_tuple") + res2 = vm.get_outputs("test_vm_tuple") + # the results are NDArrays wrapped around scalars, + # so we have to get the scalar out of the NDArray + assert tuple(map(lambda a: int(a.numpy()), res2)) == (2, 2) + + b = tvm.nd.array(np.array(1).astype("int32"), device) + vm.set_input("test_vm_nested_tuple", b) + vm.invoke_stateful("test_vm_nested_tuple") + res3 = vm.get_outputs("test_vm_nested_tuple") + assert len(res3) == 2 and len(res3[0]) == 2 and len(res3[0][1]) == 1 + result_cast = ((int(res3[0][0].numpy()), (int(res3[0][1][0].numpy()),)), int(res3[1].numpy())) + assert result_cast == ((1, (1,)), 1) + + +def set_input_attempt_stateless(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: + # this should fail: once you set inputs, you cannot run statelessly + a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + vm.set_input("main", a, b) + # must use invoke stateful! + vm["main"]() + + +def set_input_attempt_invoke(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: + # this should fail: if the function needs inputs, you can't invoke directly + vm.invoke_stateful("main") + + +def set_input_attempt_get(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: + # this should fail: you can't get outputs without invoking the function first + a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + vm.set_input("main", a, b) + _ = vm.get_outputs("main") + + +def make_vm(mod, exec_mode) -> Tuple[relax.VirtualMachine, tvm.runtime.Device]: + """Returns a local VM for the given mod and the device""" + target = tvm.target.Target("llvm", host="llvm") + exec = relax.vm.build(TestVMSetInput, target, exec_mode=exec_mode) + exec.mod.export_library("exec.so") + exec_loaded = relax.vm.Executable(tvm.runtime.load_module("exec.so")) + os.remove("exec.so") + device = tvm.cpu() + return relax.VirtualMachine(exec_loaded, device), device + + +def run_on_rpc( + mod: tvm.IRModule, + trial_func: Callable[[relax.VirtualMachine, tvm.runtime.Device], None], + exec_mode: str, +): + """ + Sets up a VM over localhost using the given mod and runs the given trial function. + The trial function should take a VM and a device + """ + target = tvm.target.Target("llvm", host="llvm") + exec = relax.vm.build(mod, target, exec_mode=exec_mode) + temp = utils.tempdir() + path = temp.relpath("vm_library.so") + exec.mod.export_library(path) + + # Use local rpc server for testing. + # Server must use popen so it doesn't inherit the current process state. It + # will crash otherwise. + # Adapted from relay/test_vm.py + def check_remote(server): + remote = rpc.connect(server.host, server.port, session_timeout=10) + + # Upload the serialized Executable. + remote.upload(path) + # Get a handle to remote Executable. + rexec = remote.load_module("vm_library.so") + + device = remote.cpu() + # Build a VM out of the executable and context. + vm = relax.vm.VirtualMachine(exec=rexec, device=device) + trial_func(vm, device) + + check_remote(rpc.Server("127.0.0.1")) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_set_input(exec_mode): + set_input_trial(*make_vm(TestVMSetInput, exec_mode)) + + +def save_function_kwargs_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: + # just checking that we can use kwargs for the args when saving a function + a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + vm.save_function("main", "saved_main", x=a, w=b) + res0 = vm["saved_main"]() + tvm.testing.assert_allclose(res0.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_save_function_kwargs(exec_mode): + save_function_kwargs_trial(*make_vm(TestVMSetInput, exec_mode)) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_save_function_kwargs_rpc(exec_mode): + run_on_rpc(TestVMSetInput, save_function_kwargs_trial, exec_mode) + + +def save_function_time_evaluator_trial( + vm: relax.VirtualMachine, device: tvm.runtime.Device +) -> None: + # just checking that the saved function can be called in the time evaluator + a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + vm.save_function("main", "saved_main", a, b) + vm.time_evaluator("saved_main", device)() + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_save_function_time_evaluator(exec_mode): + save_function_time_evaluator_trial(*make_vm(TestVMSetInput, exec_mode)) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_save_function_time_evaluator(exec_mode): + run_on_rpc(TestVMSetInput, save_function_time_evaluator_trial, exec_mode) + + +# if you set an input, you should not be able to call statelessly +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +@pytest.mark.xfail() +def test_set_input_stateless_failure(exec_mode): + set_input_attempt_stateless(*make_vm(TestVMSetInput, exec_mode)) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +@pytest.mark.xfail() +def test_set_input_stateless_failure_rpc(exec_mode): + run_on_rpc(TestVMSetInput, set_input_attempt_stateless, exec_mode) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +@pytest.mark.xfail() +def test_set_input_invoke_failure(exec_mode): + set_input_attempt_invoke(*make_vm(TestVMSetInput, exec_mode)) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +@pytest.mark.xfail() +def test_set_input_invoke_failure_rpc(exec_mode): + run_on_rpc(TestVMSetInput, set_input_attempt_invoke, exec_mode) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +@pytest.mark.xfail() +def test_set_input_get_failure(exec_mode): + set_input_attempt_get(*make_vm(TestVMSetInput, exec_mode)) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +@pytest.mark.xfail() +def test_set_input_get_failure_rpc(exec_mode): + run_on_rpc(TestVMSetInput, set_input_attempt_get, exec_mode) + + +if __name__ == "__main__": + tvm.testing.main() From 819c720640cca808b1b44583cce1f99a70bc884d Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Tue, 14 Feb 2023 10:42:01 +0800 Subject: [PATCH 12/81] [Unity][TVMScript] Use explicit `R.shape` in TVMScript (#13979) As we've introduced `arg_sinfo` in CallNode, implicit shape constructor is not widely used in TVMScript. This PR removes the implicit shape since it may cause confusion between shape and tuple. --- python/tvm/relax/utils.py | 16 ++------- python/tvm/script/ir_builder/relax/ir.py | 18 ++++++++++ python/tvm/script/parser/relax/entry.py | 22 +++++++++--- src/script/printer/relax/expr.cc | 2 +- src/script/printer/relax/struct_info.cc | 14 +++++++- .../test_backend_transform_shape_lower.py | 2 +- tests/python/relax/test_transform.py | 2 +- tests/python/relax/test_tvmscript_parser.py | 36 +++++++++++++------ .../relax/test_tvmscript_printer_relax.py | 4 +-- tests/python/relax/test_vm_build.py | 6 ++-- tests/python/relax/test_vm_codegen_only.py | 14 ++++---- 11 files changed, 93 insertions(+), 43 deletions(-) diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 5bfb0d87bf00..0bb82c79f4f8 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -23,7 +23,7 @@ from ..runtime import String, convert_to_object from ..tir import PrimExpr from . import _ffi_api -from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm +from .expr import Expr, Function, PrimValue, StringImm from .expr import Tuple as rx_Tuple @@ -74,14 +74,12 @@ def convert_to_expr(value: Any) -> Expr: 1. Return the input itself if it's already a `relax.Expr`; 2. Return `relax.PrimValue` if the input is a `PrimExpr`; 3. Return `relax.StringImm` if the input is `tvm.String` or `str`; - 4. Return `relax.ShapeExpr` if the input is a tuple/list of `PrimExpr` w/ int dtype; - 5. Return `relax.Tuple` if the input is a tuple/list of `Expr`. + 4. Return `relax.Tuple` if the input is a tuple/list of `Expr`. Notes ----- 1. `tvm.tir.StringImm` is not allowed because of ambiguity, which can be either `relax.StringImm` or `relax.PrimValue`. - 2. We regard empty tuple/list as `relax.Tuple` instead of `relax.ShapeExpr` """ if isinstance(value, int): return PrimValue(tir.IntImm("int64", value)) @@ -102,16 +100,8 @@ def convert_to_expr(value: Any) -> Expr: # Case 3 if isinstance(tvm_value, String): return StringImm(value) - # Case 4 & 5 + # Case 4 if isinstance(value, (tuple, list)): - # Note 2 - if len(value) == 0: - return rx_Tuple([]) - # Case 4 - opt_prim_value = [convert_to_object(v) for v in value] - if all([isinstance(v, PrimExpr) and v.dtype.startswith("int") for v in opt_prim_value]): - return ShapeExpr(value) - # Case 5 # `convert_to_expr` ensures that all elements are `Expr` if no exception raises return rx_Tuple([convert_to_expr(v) for v in value]) raise TypeError(f"Cannot convert {value} with type {type(value)} to `relax.Expr`") diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 0692ec5683c0..0e6595cb4514 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -329,6 +329,23 @@ def tuple(*fields: Expr) -> Expr: return relax.Tuple(fields) # type: ignore[attr-defined] # pylint: disable=no-member +############################### R.shape ################################ + + +def shape(value: List[PrimExpr]) -> Expr: + """Create a ShapeExpr. + Parameters + ---------- + value : List[PrimExpr] + The fields of the tuple. + Returns + ------- + res : Expr + The result tuple. + """ + return relax.ShapeExpr(value) # pylint: disable=no-member # type: ignore + + ############################### PrimValue ############################## @@ -407,6 +424,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "prim_value", "print", "reshape", + "shape", "shape_of", "str", "tuple", diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index d93f9a2826bc..7e51264cb37c 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -22,6 +22,7 @@ from tvm.relax import ( Expr, + ShapeExpr, FuncStructInfo, Function, ObjectStructInfo, @@ -84,17 +85,22 @@ class TensorProxy(StructInfoProxy): def __init__( self, - shape: Optional[List[Union[PrimExpr, str]]] = None, + shape: Optional[Union[List[Union[PrimExpr, str]], Expr]] = None, dtype: Optional[str] = None, ndim: int = -1, ) -> None: self.shape = shape + if isinstance(shape, Expr) and not isinstance(shape, ShapeExpr): + raise ValueError( + "Only ShapeExpr is allowed as shape expr, but got: " + f"{shape} with type: {type(shape)}" + ) self.dtype = dtype self.ndim = ndim super().__init__() def get_symbolic_vars(self) -> Set[str]: - if self.shape is None: + if self.shape is None or isinstance(self.shape, Expr): return {} else: return {s for s in self.shape if isinstance(s, str) and s.isidentifier()} @@ -102,6 +108,8 @@ def get_symbolic_vars(self) -> Set[str]: def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TensorStructInfo: if self.shape is None: return TensorStructInfo(None, self.dtype, self.ndim) + elif isinstance(self.shape, ShapeExpr): + return TensorStructInfo(self.shape, self.dtype, self.ndim) else: if dict_globals is None and any([isinstance(s, str) for s in self.shape]): raise ValueError( @@ -113,7 +121,7 @@ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> Tenso def Tensor( - shape: Optional[List[Union[PrimExpr, str]]] = None, + shape: Optional[Union[List[Union[PrimExpr, str]], ShapeExpr]] = None, dtype: Optional[str] = None, ndim: int = -1, ) -> TensorProxy: @@ -124,8 +132,12 @@ def Tensor( dtype = shape shape = None - if shape is not None and not isinstance(shape, (tuple, list)): - raise ValueError(f"shape must be a list or tuple, but got: {shape}") + if ( + shape is not None + and not isinstance(shape, (tuple, list)) + and not isinstance(shape, ShapeExpr) + ): + raise ValueError(f"shape must be a list/tuple or a ShapeExpr, but got: {shape}") return TensorProxy(shape, dtype, ndim) diff --git a/src/script/printer/relax/expr.cc b/src/script/printer/relax/expr.cc index a786932fc3d9..66d7d187d0c8 100644 --- a/src/script/printer/relax/expr.cc +++ b/src/script/printer/relax/expr.cc @@ -71,7 +71,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) for (int i = 0, l = n->values.size(); i < l; ++i) { values_doc.push_back(PrintShapeVar(n->values[i], values_p->ArrayIndex(i), d)); } - return TupleDoc(values_doc); + return Relax(d, "shape")->Call({ListDoc(values_doc)}); }); Optional SpecialScalar(const runtime::NDArray& n, const ObjectPath& p) { diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc index 6f4a66c991d9..c541619ec887 100644 --- a/src/script/printer/relax/struct_info.cc +++ b/src/script/printer/relax/struct_info.cc @@ -89,7 +89,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) Array kwargs_keys; Array kwargs_values; if (n->shape.defined()) { - args.push_back(d->AsDoc(n->shape.value(), n_p->Attr("shape"))); + // Need to dig into ShapeExpr to preserve the `R.shape` prefix + if (const auto* shape = n->shape.value().as()) { + auto shape_expr = GetRef(shape); + ObjectPath shape_p = n_p->Attr("shape")->Attr("values"); + Array shape_docs; + for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) { + shape_docs.push_back( + PrintShapeVar(shape_expr->values[i], shape_p->ArrayIndex(i), d)); + } + args.push_back(TupleDoc(shape_docs)); + } else { + args.push_back(d->AsDoc(n->shape.value(), n_p->Attr("shape"))); + } } if (!n->IsUnknownDtype()) { kwargs_keys.push_back("dtype"); diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py index 0bf0f175dd7e..5cd104dd013f 100644 --- a/tests/python/relax/test_backend_transform_shape_lower.py +++ b/tests/python/relax/test_backend_transform_shape_lower.py @@ -167,7 +167,7 @@ def main( n = T.Var("n", "int64") k = T.Var("k", "int64") z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None)) - return (k + 1, m, 2) + return R.shape([k + 1, m, 2]) # slot assignment: # 0: n, 1: m, 2:k, 3: k+1 diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 624b7877cd11..12dd095c6b5d 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -109,7 +109,7 @@ class TestVMBuiltinLower: @R.function def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor: m, n = T.var("int64"), T.var("int64") - alloc = R.builtin.alloc_tensor((m, n), runtime_device_index=0, dtype="float32") + alloc = R.builtin.alloc_tensor(R.shape([m, n]), runtime_device_index=0, dtype="float32") _ = R.call_packed( "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) ) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 34b02fdbb8c3..c9a16fbcacb7 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -22,10 +22,9 @@ import tvm.script import tvm.testing from tvm import IRModule, relax, tir, topi -from tvm.relax import DynTensorType -from tvm.script import ir as I -from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script.parser import ir as I +from tvm.script.parser import relax as R +from tvm.script.parser import tir as T def _check( @@ -202,6 +201,23 @@ def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor((4, 4), "float32"): _check(foo, bb.get()["foo"]) +def test_relax_base_op(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")): + alloc = R.builtin.alloc_tensor(R.shape([4, 4]), runtime_device_index=0, dtype="float32") + shape = R.shape_of(alloc) + return shape + + x = relax.Var("x", R.Tensor((4, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + alloc = bb.emit(relax.op.builtin.alloc_tensor(relax.ShapeExpr((4, 4)), "float32", 0)) + shape = bb.emit(relax.op.shape_of(alloc)) + bb.emit_func_output(shape) + # todo(yongwww): comment this check because 0 was changed to R.prim_value(0) in the printed IR + # _check(foo, bb.get()["foo"]) + + def test_symbolic_shape(): @R.function def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): @@ -274,7 +290,7 @@ def foo(x: R.Tensor("float32"), y: R.Tensor("float32")): y0 = R.match_cast(y, R.Tensor([n], "float32")) gv = y0 R.output(gv) - return (x0, (m, n * 2)) + return (x0, R.shape([m, n * 2])) x = relax.Var("x", R.Tensor("float32")) y = relax.Var("y", R.Tensor("float32")) @@ -314,7 +330,7 @@ def test_tuple_return_2(): def foo(x: R.Tensor("float32", ndim=2)): n, m = T.var("int64"), T.var("int64") x0 = R.match_cast(x, R.Tensor((n, m), "float32")) - return (x0, (n + 1, m, 1)) + return (x0, R.shape([n + 1, m, 1])) x = relax.Var("x", R.Tensor("float32", ndim=2)) n, m = tir.Var("n", "int64"), tir.Var("m", "int64") @@ -332,7 +348,7 @@ def foo(x: R.Tensor("float32", ndim=2)): n, m = T.var("int64"), T.var("int64") x0 = R.match_cast(x, R.Tensor((n, m), "float32")) t0 = (x, x0) - t1 = (x, (n, m), t0) + t1 = (x, R.shape([n, m]), t0) return t1 x = relax.Var("x", R.Tensor("float32", ndim=2)) @@ -965,9 +981,9 @@ def test_vm_ops(): def foo(x: R.Tensor(("m", "n"), dtype="float32")): m = T.var("int64") n = T.var("int64") - storage = R.vm.alloc_storage((4 * m * n,), dtype="float32", runtime_device_index=0) - alloc = R.vm.alloc_tensor(storage, (m, n), offset=0, dtype="float32") - tensor = R.builtin.alloc_tensor((m, n), dtype="float32", runtime_device_index=0) + storage = R.vm.alloc_storage(R.shape([4 * m * n]), dtype="float32", runtime_device_index=0) + alloc = R.vm.alloc_tensor(storage, shape=R.shape([m, n]), offset=0, dtype="float32") + tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32", runtime_device_index=0) _ = R.vm.call_tir_dyn("te_func", (x, tensor, (m, n))) gv = tensor return alloc, gv diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 58596f968f98..db90c66422d0 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -292,7 +292,7 @@ def test_tuple_get_item(): def test_shape_expr(): obj = relax.ShapeExpr([1, 2, 3]) - _assert_print(obj, "(1, 2, 3)") + _assert_print(obj, "R.shape([1, 2, 3])") def test_call(): @@ -304,7 +304,7 @@ def test_call(): """ x = T.Var("x", "int64") a: R.Tensor((1, x, 3), dtype="float32") -R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=(x,)) +R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=R.shape([x])) """, ) diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index 534d2308daa9..0a881691accc 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -88,7 +88,7 @@ class TestVMCompileStage2: def foo(x: R.Tensor(dtype="float32")) -> R.Shape: n, m = T.var("int64"), T.var("int64") _ = R.match_cast(x, R.Tensor((n, m), "float32")) - return (n * 2, m * 3) + return R.shape([n * 2, m * 3]) mod = TestVMCompileStage2 target = tvm.target.Target("llvm", host="llvm") @@ -511,9 +511,9 @@ class TestMemoryAllocStorageTensor: @R.function def main(x: R.Tensor((2, 3), dtype="float32")): storage = R.memory.alloc_storage( - (24,), virtual_device_index=0, storage_scope="global", dtype="float32" + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" ) - y = R.memory.alloc_tensor(storage, 0, (2, 3), dtype="float32") + y = R.memory.alloc_tensor(storage, 0, R.shape([2, 3]), dtype="float32") _ = copy(x, y) return y diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py index b5e77091776a..4b79ecf70fa1 100644 --- a/tests/python/relax/test_vm_codegen_only.py +++ b/tests/python/relax/test_vm_codegen_only.py @@ -18,13 +18,15 @@ Restrictions: all shape lowered, explicit allocation. """ -import tvm -import pytest import numpy as np -from tvm import relax, TVMError -from tvm.script import relax as R, tir as T +import pytest +import tvm +import tvm.testing +from tvm import relax +from tvm.relax.testing.runtime_builtin import MakeShapeCode, MatchShapeCode from tvm.relax.testing.vm import check_saved_func -from tvm.relax.testing.runtime_builtin import MatchShapeCode, MakeShapeCode +from tvm.script import relax as R +from tvm.script import tir as T EXEC_MODE = ["bytecode"] @@ -312,7 +314,7 @@ class TestVMBuiltinReshape: def main(x: R.Tensor((3, 4), "float32")): R.func_attr({"global_symbol": "main"}) y = R.call_packed( - "vm.builtin.reshape", x, (6, 2), sinfo_args=R.Tensor((6, 2), "float32") + "vm.builtin.reshape", x, R.shape([6, 2]), sinfo_args=R.Tensor((6, 2), "float32") ) return y From e48d4d2379a091205bb9bac2204a59807e38fab4 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 14 Feb 2023 14:02:20 -0500 Subject: [PATCH 13/81] [Unity] Relax op: index (#13987) This PR is about the high-level tensor computation operators in Relax. This PR includes the tensor indexing operators. --- include/tvm/relax/attrs/index.h | 62 ++ python/tvm/relax/op/__init__.py | 2 + python/tvm/relax/op/index.py | 90 +++ python/tvm/relax/op/op_attrs.py | 29 + python/tvm/script/ir_builder/relax/ir.py | 4 + src/relax/op/tensor/index.cc | 195 ++++++ src/relax/op/tensor/index.h | 65 ++ tests/python/relax/test_op_index.py | 593 ++++++++++++++++++ .../relax/test_tvmscript_parser_op_index.py | 82 +++ 9 files changed, 1122 insertions(+) create mode 100644 include/tvm/relax/attrs/index.h create mode 100644 python/tvm/relax/op/index.py create mode 100644 python/tvm/relax/op/op_attrs.py create mode 100644 src/relax/op/tensor/index.cc create mode 100644 src/relax/op/tensor/index.h create mode 100644 tests/python/relax/test_op_index.py create mode 100644 tests/python/relax/test_tvmscript_parser_op_index.py diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h new file mode 100644 index 000000000000..c95395a80376 --- /dev/null +++ b/include/tvm/relax/attrs/index.h @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/index.h + * \brief Attributes for indexing operators. + */ +#ifndef TVM_RELAX_ATTRS_INDEX_H_ +#define TVM_RELAX_ATTRS_INDEX_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in take operator */ +struct TakeAttrs : public tvm::AttrsNode { + Optional axis; + + TVM_DECLARE_ATTRS(TakeAttrs, "relax.attrs.TakeAttrs") { + TVM_ATTR_FIELD(axis).describe("The axis over which to select values."); + } +}; // struct TakeAttrs + +/*! \brief Attributes used in strided_slice operator */ +struct StridedSliceAttrs : public tvm::AttrsNode { + Array axes; + Array begin; + Array end; + Optional> strides; + + TVM_DECLARE_ATTRS(StridedSliceAttrs, "relax.attrs.StridedSliceAttrs") { + TVM_ATTR_FIELD(axes).describe("Axes along which slicing is applied."); + TVM_ATTR_FIELD(begin).describe("The indices to begin with in the slicing, inclusive."); + TVM_ATTR_FIELD(end).describe("The indices indicating end of the slice, exclusive."); + TVM_ATTR_FIELD(strides).describe( + "Specifies the stride values, it can be negative in that case, the input tensor will be " + "reversed in that particular axis. If not specified, it by default is an list of ones of " + "the same length as `axes`."); + } +}; // struct StridedSliceAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_INDEX_H_ diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 9a131cdf957f..3393a5dcae67 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -20,6 +20,8 @@ # Operators from .base import * from .binary import * +from .index import * from .manipulate import * +from .op_attrs import * from . import builtin from . import memory diff --git a/python/tvm/relax/op/index.py b/python/tvm/relax/op/index.py new file mode 100644 index 000000000000..2a7afa5ba0f9 --- /dev/null +++ b/python/tvm/relax/op/index.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Indexing operators.""" +from typing import List, Optional, Union + +from tvm.ir.expr import PrimExpr + +from . import _ffi_api +from ..expr import Expr + +PrimExprLike = Union[int, PrimExpr] + + +def take(x: Expr, indices: Expr, axis: Optional[int] = None) -> Expr: + """Take elements from a tensor along an axis. + + Parameters + ---------- + x : relax.Expr + The source tensor. + + indices : relax.Expr + The indices of the values to extract. + It is required to be a one-dimensional tensor which has integer dtype. + + axis : Optional[int] + The axis over which to select values. + If it is none, the input tensor is required to be one-dimensional. + + Returns + ------- + ret : relax.Expr + The taken result. + """ + return _ffi_api.take(x, indices, axis) # type: ignore + + +def strided_slice( + x: Expr, + axes: List[int], + begin: List[PrimExprLike], + end: List[PrimExprLike], + strides: Optional[List[PrimExprLike]] = None, +) -> Expr: + """Strided slice of a tensor. + + Parameters + ---------- + x : relax.Expr + The source tensor to be sliced. + + axes : List[int] + Axes along which slicing is applied. + + begin : List[PrimExprLike] + The indices to begin with in the slicing, inclusive. + + end : List[PrimExprLike] + The indices indicating end of the slice, exclusive. + + strides : Optional[List[PrimExprLike]] + Specifies the stride values, it can be negative in that case, + the input tensor will be reversed in that particular axis. + If not specified, it by default is an list of ones of the same length as `axes`. + + Returns + ------- + ret : relax.Expr + The sliced result. + + Note + ---- + strided_slice require the input `begin`, `end` and `strides` to have the + same length as `axes`. + """ + return _ffi_api.strided_slice(x, axes, begin, end, strides) # type: ignore diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py new file mode 100644 index 000000000000..44cb2cf3a5b4 --- /dev/null +++ b/python/tvm/relax/op/op_attrs.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The attributes node used for Relax operators""" +from tvm.ir import Attrs +import tvm._ffi + + +@tvm._ffi.register_object("relax.attrs.TakeAttrs") +class TakeAttrs(Attrs): + """Attributes used in take operator""" + + +@tvm._ffi.register_object("relax.attrs.StridedSliceAttrs") +class StridedSliceAttrs(Attrs): + """Attributes used in strided_slice operator""" diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 0e6595cb4514..75a00ea04985 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -42,6 +42,8 @@ print, reshape, shape_of, + strided_slice, + take, ) from tvm.relax.struct_info import StructInfo from tvm.relax.utils import args_converter @@ -427,5 +429,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "shape", "shape_of", "str", + "strided_slice", + "take", "tuple", ] diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc new file mode 100644 index 000000000000..246abef9084b --- /dev/null +++ b/src/relax/op/tensor/index.cc @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file index.cc + * \brief indexing operators. + */ + +#include "index.h" + +#include +#include + +namespace tvm { +namespace relax { + +/* relax.take */ +TVM_REGISTER_NODE_TYPE(TakeAttrs); + +Expr take(Expr x, Expr indices, Optional axis) { + ObjectPtr attrs = make_object(); + attrs->axis = std::move(axis); + + static const Op& op = Op::Get("relax.take"); + return Call(op, {std::move(x), std::move(indices)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.take").set_body_typed(take); + +StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo data_sinfo = input_sinfo[0]; + TensorStructInfo indices_sinfo = input_sinfo[1]; + if (indices_sinfo->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Take op requires the input indices to be 1-dimensional tensor. However, " + "the given indices ndim is " + << indices_sinfo->ndim); + } else if (!indices_sinfo->IsUnknownDtype() && + !(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Take op requires the input indices to have integer dtype. However, the " + "given indices dtype is " + << indices_sinfo->dtype); + } + + const auto* attrs = call->attrs.as(); + if (!attrs->axis.defined() && data_sinfo->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Take op expects the input data to be 1-dimensional tensor when the axis " + "is not specified. However, the given data tensor has ndim " + << data_sinfo->ndim); + } + if (data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + + int axis = attrs->axis.defined() + ? NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()->value) + : 0; + const auto* data_shape = data_sinfo->shape.as(); + const auto* indices_shape = indices_sinfo->shape.as(); + if (data_shape == nullptr || indices_shape == nullptr) { + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim); + } + + Array output_shape = data_shape->values; + output_shape.Set(axis, indices_shape->values[0]); + return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.take") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("x", "Tensor", "The source tensor.") + .add_argument("indices", "Tensor", "The indices of the values to extract.") + .set_attr("FInferStructInfo", InferStructInfoTake); + +/* relax.strided_slice */ +TVM_REGISTER_NODE_TYPE(StridedSliceAttrs); + +Expr strided_slice(Expr x, // + Array axes, // + Array begin, // + Array end, // + Optional> strides) { + int n_axis = axes.size(); + CHECK_EQ(static_cast(begin.size()), n_axis) + << "StridedSlice requires the number of begin indices to equal the number of axes."; + CHECK_EQ(static_cast(end.size()), n_axis) + << "StridedSlice requires the number of end indices to equal the number of axes."; + if (strides.defined()) { + CHECK_EQ(static_cast(strides.value().size()), n_axis) + << "StridedSlice requires the number of strides to equal the number of axes."; + } + + // Todo(relax-team): We are going to support dynamic strided slice, where + // begin/end/stride can be not static at compile time. Therefore, begin/end/stride + // should not be part of StridedSliceAttrs, as we only allow static values to + // reside in attributes. However, using ShapeExpr to represent these + // arrays is not conceptually right, because they are not describing a + // concrete shape. The proper way to support dynamic strided slice is to use + // Tuple of PrimValue to represent begin/end/stride. Since at this moment + // we have no support for PrimValue, we store begin/end/stride as attribute + // fields as a workaround. + // Will switch to Tuple of PrimValue after introducing PrimValue. + auto f_convert_to_int64 = [](const PrimExpr& value) { + if (value->IsInstance()) { + return cast(DataType::Int(64), value); + } + CHECK(value.dtype() == DataType::Int(64)) << "strided_slice expects the input begin/end/stride " + "values to be all int64. However, the given " + << value << " has dtype " << value->dtype; + return value; + }; + + ObjectPtr attrs = make_object(); + attrs->axes = std::move(axes); + attrs->begin = begin.Map(f_convert_to_int64); + attrs->end = end.Map(f_convert_to_int64); + attrs->strides = strides.defined() ? strides.value().Map(f_convert_to_int64) : strides; + + static const Op& op = Op::Get("relax.strided_slice"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice); + +StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + if (attrs->axes.empty()) { + return data_sinfo; + } + + if (data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + + std::vector axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes); + const auto* data_shape = data_sinfo->shape.as(); + if (data_shape == nullptr) { + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim); + } + + int n_axis = axes.size(); + Array strides = attrs->strides.defined() + ? attrs->strides.value() + : Array(n_axis, IntImm(DataType::Int(64), 1)); + std::vector int_strides; + int_strides.reserve(n_axis); + // Only do output shape inference when all the begin/end/stride values are integers. + for (int i = 0; i < n_axis; ++i) { + const auto* int_begin = attrs->begin[i].as(); + const auto* int_end = attrs->end[i].as(); + const auto* int_stride = strides[i].as(); + if (!int_begin || !int_end || !int_stride) { + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim); + } + int_strides.push_back(int_stride->value); + } + + Array output_shape = data_shape->values; + for (int i = 0; i < n_axis; ++i) { + PrimExpr len = int_strides[i] < 0 ? ceildiv(attrs->begin[i] - attrs->end[i], -int_strides[i]) + : ceildiv(attrs->end[i] - attrs->begin[i], int_strides[i]); + output_shape.Set(axes[i], len); + } + return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.strided_slice") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The source tensor to be sliced.") + .set_attr("FInferStructInfo", InferStructInfoStridedSlice); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/index.h b/src/relax/op/tensor/index.h new file mode 100644 index 000000000000..6944493a0fd6 --- /dev/null +++ b/src/relax/op/tensor/index.h @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file index.h + * \brief The functions to make Relax tensor indexing operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_INDEX_H_ +#define TVM_RELAX_OP_TENSOR_INDEX_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Take elements from a tensor along an axis. + * \param x The source tensor. + * \param indices The indices of the values to extract. + * It is required to be a one-dimensional tensor which has integer dtype. + * \param axis The axis over which to select values. + * If it is `NullOpt`, the input tensor is required to be one-dimensional. + * \return The taken result. + */ +Expr take(Expr x, Expr indices, Optional axis); + +/*! + * \brief Strided slice of a tensor. + * \param x The source tensor to be sliced. + * \param axes Axes along which slicing is applied. + * \param begin The indices to begin with in the slicing, inclusive. + * \param end The indices indicating end of the slice, exclusive. + * \param strides Specifies the stride values, it can be negative in that case, + * the input tensor will be reversed in that particular axis. + * If it is `NullOpt`, it by default is an list of ones of the same length as `axes`. + * \return The sliced result + */ +Expr strided_slice(Expr x, // + Array axes, // + Array begin, // + Array end, // + Optional> strides); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_INDEX_H_ diff --git a/tests/python/relax/test_op_index.py b/tests/python/relax/test_op_index.py new file mode 100644 index 000000000000..77a04b1a1aab --- /dev/null +++ b/tests/python/relax/test_op_index.py @@ -0,0 +1,593 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + idx = relax.Var("idx", R.Tensor((2,), "float32")) + assert relax.op.take(x, idx, axis=1).op == Op.get("relax.take") + assert relax.op.strided_slice(x, axes=[0], begin=[0], end=[2]).op == Op.get( + "relax.strided_slice" + ) + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_take_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((4, 10), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((4, 10))) + x4 = relax.Var("x", R.Tensor(ndim=2)) + x5 = relax.Var("x", R.Tensor()) + y0 = relax.Var("y", R.Tensor((10,), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=1)) + y2 = relax.Var("y", R.Tensor((10,))) + y3 = relax.Var("y", R.Tensor(ndim=1)) + idx0 = relax.Var("idx", R.Tensor((6,), "int64")) + idx1 = relax.Var("idx", R.Tensor("int64", ndim=1)) + idx2 = relax.Var("idx", R.Tensor((6,))) + idx3 = relax.Var("idx", R.Tensor(ndim=1)) + + _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((4, 6), "float32")) + _check_inference( + bb, relax.op.take(x0, idx0, axis=-1), relax.TensorStructInfo((4, 6), "float32") + ) + _check_inference( + bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.take(x2, idx0, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx0, axis=1), relax.TensorStructInfo((4, 6), dtype="")) + _check_inference(bb, relax.op.take(x4, idx0, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x5, idx0, axis=1), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.take(x2, idx1, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx1, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x4, idx1, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x5, idx1, axis=1), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.take(x0, idx2, axis=1), relax.TensorStructInfo((4, 6), "float32")) + _check_inference( + bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx2, axis=1), relax.TensorStructInfo((4, 6), dtype="")) + _check_inference(bb, relax.op.take(x4, idx2, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x5, idx2, axis=1), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.take(x0, idx3, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x1, idx3, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.take(x2, idx3, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx3, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x4, idx3, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x5, idx3, axis=1), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.take(y0, idx0), relax.TensorStructInfo((6,), "float32")) + _check_inference(bb, relax.op.take(y1, idx0), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y2, idx0), relax.TensorStructInfo((6,), dtype="")) + _check_inference(bb, relax.op.take(y3, idx0), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y0, idx1), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y1, idx1), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y2, idx1), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y3, idx1), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y0, idx2), relax.TensorStructInfo((6,), "float32")) + _check_inference(bb, relax.op.take(y1, idx2), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y2, idx2), relax.TensorStructInfo((6,), dtype="")) + _check_inference(bb, relax.op.take(y3, idx2), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y0, idx3), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y1, idx3), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y2, idx3), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y3, idx3), relax.TensorStructInfo(dtype="", ndim=1)) + + +def test_take_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + i = tir.Var("i", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((m, n))) + y0 = relax.Var("y", R.Tensor((n,), "float32")) + y1 = relax.Var("y", R.Tensor((n,))) + idx0 = relax.Var("idx", R.Tensor((i,), "int64")) + idx1 = relax.Var( + "idx", + R.Tensor( + (i,), + ), + ) + + _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((m, i), "float32")) + _check_inference(bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo((m, i), dtype="")) + _check_inference(bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo((m, i), "float32")) + _check_inference(bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo((m, i), dtype="")) + _check_inference(bb, relax.op.take(y0, idx0), relax.TensorStructInfo((i,), "float32")) + _check_inference(bb, relax.op.take(y1, idx0), relax.TensorStructInfo((i,), dtype="")) + _check_inference(bb, relax.op.take(y0, idx1), relax.TensorStructInfo((i,), "float32")) + _check_inference(bb, relax.op.take(y1, idx1), relax.TensorStructInfo((i,), dtype="")) + + +def test_take_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + sx0 = relax.Var("sx", relax.ShapeStructInfo((4, 10))) + sx1 = relax.Var("sx", relax.ShapeStructInfo(ndim=2)) + sx2 = relax.Var("sx", relax.ShapeStructInfo()) + sidx0 = relax.Var("sidx", relax.ShapeStructInfo((6,))) + sidx1 = relax.Var("sidx", relax.ShapeStructInfo(ndim=1)) + x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(sx2, "float32")) + x3 = relax.Var("x", R.Tensor((4, 10), "float32")) + idx0 = relax.Var("idx", relax.TensorStructInfo(sidx0, "int64")) + idx1 = relax.Var("idx", relax.TensorStructInfo(sidx1, "int64")) + idx2 = relax.Var("idx", R.Tensor((6,), "int64")) + + _check_inference( + bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x0, idx2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.take(x2, idx0, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.take(x2, idx1, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.take(x3, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x3, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + + +def test_take_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((4, 10), "float16")) + x1 = relax.Var("x", R.Tensor((4, 10), "int16")) + x2 = relax.Var("x", R.Tensor((4, 10), "int32")) + idx0 = relax.Var("idx", R.Tensor((6,), "int32")) + idx1 = relax.Var("idx", R.Tensor((6,), "int8")) + idx2 = relax.Var("idx", R.Tensor((6,), "uint32")) + + _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((4, 6), "float16")) + _check_inference(bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo((4, 6), "int16")) + _check_inference(bb, relax.op.take(x2, idx0, axis=1), relax.TensorStructInfo((4, 6), "int32")) + _check_inference(bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo((4, 6), "float16")) + _check_inference(bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo((4, 6), "int16")) + _check_inference(bb, relax.op.take(x2, idx1, axis=1), relax.TensorStructInfo((4, 6), "int32")) + _check_inference(bb, relax.op.take(x0, idx2, axis=1), relax.TensorStructInfo((4, 6), "float16")) + _check_inference(bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo((4, 6), "int16")) + _check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorStructInfo((4, 6), "int32")) + + +def test_take_infer_struct_info_indices_not_one_dimensional(): + bb = relax.BlockBuilder() + sidx0 = relax.Var("sidx", relax.ShapeStructInfo((6, 6))) + sidx1 = relax.Var("sidx", relax.ShapeStructInfo(())) + sidx2 = relax.Var("sidx", relax.ShapeStructInfo(ndim=2)) + sidx3 = relax.Var("sidx", relax.ShapeStructInfo(ndim=0)) + sidx4 = relax.Var("sidx", relax.ShapeStructInfo()) + x = relax.Var("x", R.Tensor((4, 10), "float32")) + idx0 = relax.Var("idx", R.Tensor((6, 6), "int64")) + idx1 = relax.Var("idx", R.Tensor((), "int64")) + idx2 = relax.Var("idx", R.Tensor("int64", ndim=2)) + idx3 = relax.Var("idx", R.Tensor("int64", ndim=0)) + idx4 = relax.Var("idx", R.Tensor("int64")) + idx5 = relax.Var("idx", relax.TensorStructInfo(sidx0, "int64")) + idx6 = relax.Var("idx", relax.TensorStructInfo(sidx1, "int64")) + idx7 = relax.Var("idx", relax.TensorStructInfo(sidx2, "int64")) + idx8 = relax.Var("idx", relax.TensorStructInfo(sidx3, "int64")) + idx9 = relax.Var("idx", relax.TensorStructInfo(sidx4, "int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx1, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx2, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx3, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx4, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx5, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx6, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx7, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx8, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx9, axis=1)) + + +def test_take_infer_struct_info_indices_not_integer_dtype(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((4, 10), "float32")) + idx0 = relax.Var("idx", R.Tensor((6, 6), "float32")) + idx1 = relax.Var("idx", R.Tensor((6, 6), "float64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx1, axis=1)) + + +def test_take_infer_struct_info_multi_dimensional_without_axis(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((4, 10), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + idx0 = relax.Var("idx", R.Tensor((6,), "int64")) + idx1 = relax.Var("idx", R.Tensor("int64", ndim=1)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x0, idx0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x1, idx0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x2, idx0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x0, idx1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x1, idx1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x2, idx1)) + + +def test_take_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((4, 10), "float32")) + idx = relax.Var("idx", R.Tensor((6,), "int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx, axis=-3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx, axis=2)) + + +def test_take_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((4, 10))) + x1 = relax.Var("x", R.Tensor((4, 10), "float32")) + idx0 = relax.Var("idx", relax.ShapeStructInfo((6,))) + idx1 = relax.Var("idx", R.Tensor((6,), "int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x0, idx1, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x1, idx0, axis=1)) + + +def test_strided_slice_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((8, 9, 10, 10))) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, + relax.op.strided_slice( + x0, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ), + relax.TensorStructInfo((4, 9, 10, 3), "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice( + x1, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.strided_slice( + x2, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.strided_slice( + x3, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ), + relax.TensorStructInfo((4, 9, 10, 3), dtype=""), + ) + _check_inference( + bb, + relax.op.strided_slice( + x4, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ), + relax.TensorStructInfo(dtype="", ndim=4), + ) + _check_inference( + bb, + relax.op.strided_slice( + x5, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ), + relax.TensorStructInfo(dtype=""), + ) + _check_inference( + bb, + relax.op.strided_slice( + x0, axes=[-1, -3, -4], begin=[8, 0, 1], end=[0, 9, 8], strides=[-3, 1, 2] + ), + relax.TensorStructInfo((4, 9, 10, 3), "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x0, axes=[1, 2], begin=[1, 0], end=[8, 9]), + relax.TensorStructInfo((8, 7, 9, 10), "float32"), + ) + + +def test_strided_slice_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((m, n))) + + _check_inference( + bb, + relax.op.strided_slice(x0, axes=[0], begin=[1], end=[3]), + relax.TensorStructInfo((2, n), "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x0, axes=[0], begin=[1], end=[8], strides=[3]), + relax.TensorStructInfo((3, n), "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x1, axes=[0], begin=[1], end=[3]), + relax.TensorStructInfo((2, n), dtype=""), + ) + _check_inference( + bb, + relax.op.strided_slice(x1, axes=[0], begin=[1], end=[8], strides=[3]), + relax.TensorStructInfo((3, n), dtype=""), + ) + + +def test_strided_slice_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((8, 10))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, dtype="")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, dtype="")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, dtype="")) + + _check_inference( + bb, + relax.op.strided_slice(x0, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo(dtype="float32", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x1, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo(dtype="float32", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x2, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x3, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo(dtype="", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x4, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo(dtype="", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x5, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo(dtype=""), + ) + + +def test_strided_slice_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((8, 9), "float16")) + x1 = relax.Var("x", R.Tensor((8, 9), "int32")) + x2 = relax.Var("x", R.Tensor((8, 9), "int64")) + + _check_inference( + bb, + relax.op.strided_slice(x0, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo((8, 9), "float16"), + ) + _check_inference( + bb, + relax.op.strided_slice(x1, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo((8, 9), "int32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x2, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo((8, 9), "int64"), + ) + + +def test_strided_slice_infer_struct_info_symbolic_begin_end_strides(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + x = relax.Var("x", R.Tensor((8, 9), "float32")) + + _check_inference( + bb, + relax.op.strided_slice(x, axes=[0], begin=[a], end=[8]), + relax.TensorStructInfo(dtype="float32", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x, axes=[0], begin=[0], end=[a]), + relax.TensorStructInfo(dtype="float32", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[a]), + relax.TensorStructInfo(dtype="float32", ndim=2), + ) + + +def test_strided_slice_infer_struct_info_no_axis(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((m, n))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor(dtype="float32", ndim=2)) + x2 = relax.Var("x", R.Tensor(dtype="float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, + relax.op.strided_slice(x0, axes=[], begin=[], end=[]), + relax.TensorStructInfo((m, n), "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x1, axes=[], begin=[], end=[]), + relax.TensorStructInfo(dtype="float32", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x2, axes=[], begin=[], end=[]), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x3, axes=[], begin=[], end=[]), + relax.TensorStructInfo(s0, "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x4, axes=[], begin=[], end=[]), + relax.TensorStructInfo(s1, "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x5, axes=[], begin=[], end=[]), + relax.TensorStructInfo(s2, "float32"), + ) + + +def test_strided_slice_begin_end_strides_int64(): + x = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32")) + strided_slice = relax.op.strided_slice( + x, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ) + + assert strided_slice.attrs.begin[0].dtype == "int64" + assert strided_slice.attrs.begin[1].dtype == "int64" + assert strided_slice.attrs.begin[2].dtype == "int64" + assert strided_slice.attrs.end[0].dtype == "int64" + assert strided_slice.attrs.end[1].dtype == "int64" + assert strided_slice.attrs.end[2].dtype == "int64" + assert strided_slice.attrs.strides[0].dtype == "int64" + assert strided_slice.attrs.strides[1].dtype == "int64" + assert strided_slice.attrs.strides[2].dtype == "int64" + + +def test_strided_slice_inconsistent_axes_begin_end_strides_length(): + x = relax.Var("x", R.Tensor((8, 9), "float32")) + + with pytest.raises(TVMError): + relax.op.strided_slice(x, axes=[1], begin=[], end=[9]) + with pytest.raises(TVMError): + relax.op.strided_slice(x, axes=[1], begin=[0], end=[]) + with pytest.raises(TVMError): + relax.op.strided_slice(x, axes=[1], begin=[0], end=[9], strides=[]) + + +def test_strided_slice_infer_struct_info_repetitive_axes(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((8, 9), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x, axes=[0, 0], begin=[0, 0], end=[8, 8])) + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x, axes=[0, -2], begin=[0, 0], end=[8, 8])) + + +def test_strided_slice_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((8, 9), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x, axes=[2], begin=[0], end=[8])) + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x, axes=[-3], begin=[0], end=[8])) + + +def test_strided_slice_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((8, 9))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((8, 9), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x0, axes=[0], begin=[0], end=[8])) + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x1, axes=[0], begin=[0], end=[8])) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_index.py b/tests/python/relax/test_tvmscript_parser_op_index.py new file mode 100644 index 000000000000..b271d1a7f3bc --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_index.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_take(): + @R.function + def foo( + x: R.Tensor((2, 3, 4), "float32"), indices: R.Tensor((3,), "int64") + ) -> R.Tensor((2, 3, 3), "float32"): + gv: R.Tensor((2, 3, 3), "float32") = R.take(x, indices, axis=2) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + indices = relax.Var("indices", R.Tensor((3,), "int64")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, indices]): + gv = bb.emit(relax.op.take(x, indices, axis=2)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_strided_slice(): + @R.function + def foo(x: R.Tensor((8, 9, 10, 10), "float32")) -> R.Tensor((4, 9, 10, 3), "float32"): + gv: R.Tensor((4, 9, 10, 3), "float32") = R.strided_slice( + x, + axes=[0, 1, -1], + begin=[1, 0, 8], + end=[8, 9, 0], + strides=[2, 1, -3], + ) + return gv + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32")) + with bb.function("foo", [x]): + gv = bb.emit( + relax.op.strided_slice( + x, axes=[0, 1, -1], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() From 886689a5f7991ab9fb15680be7e91d1de630d819 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 14 Feb 2023 14:55:09 -0500 Subject: [PATCH 14/81] [Unity] Relax op: datatype (#13986) --- include/tvm/relax/attrs/datatype.h | 44 ++++++++ python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/datatype.py | 42 +++++++ python/tvm/relax/op/op_attrs.py | 5 + python/tvm/script/ir_builder/relax/ir.py | 2 + src/relax/op/tensor/datatype.cc | 60 ++++++++++ src/relax/op/tensor/datatype.h | 45 ++++++++ tests/python/relax/test_op_datatype.py | 105 ++++++++++++++++++ .../test_tvmscript_parser_op_datatype.py | 54 +++++++++ 9 files changed, 358 insertions(+) create mode 100644 include/tvm/relax/attrs/datatype.h create mode 100644 python/tvm/relax/op/datatype.py create mode 100644 src/relax/op/tensor/datatype.cc create mode 100644 src/relax/op/tensor/datatype.h create mode 100644 tests/python/relax/test_op_datatype.py create mode 100644 tests/python/relax/test_tvmscript_parser_op_datatype.py diff --git a/include/tvm/relax/attrs/datatype.h b/include/tvm/relax/attrs/datatype.h new file mode 100644 index 000000000000..79cb345688c9 --- /dev/null +++ b/include/tvm/relax/attrs/datatype.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/datatype.h + * \brief Attributes for datatype operators. + */ +#ifndef TVM_RELAX_ATTRS_DATATYPE_H_ +#define TVM_RELAX_ATTRS_DATATYPE_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in astype operator */ +struct AstypeAttrs : public tvm::AttrsNode { + DataType dtype; + + TVM_DECLARE_ATTRS(AstypeAttrs, "relax.attrs.AstypeAttrs") { + TVM_ATTR_FIELD(dtype).describe("Target data type"); + } +}; // struct AstypeAttrs. + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_DATATYPE_H_ diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 3393a5dcae67..f3ab9085b87e 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -20,6 +20,7 @@ # Operators from .base import * from .binary import * +from .datatype import * from .index import * from .manipulate import * from .op_attrs import * diff --git a/python/tvm/relax/op/datatype.py b/python/tvm/relax/op/datatype.py new file mode 100644 index 000000000000..5c02776dd7ee --- /dev/null +++ b/python/tvm/relax/op/datatype.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Datatype operators.""" +from typing import Union + +from tvm import DataType + +from . import _ffi_api +from ..expr import Expr + + +def astype(x: Expr, dtype: Union[str, DataType]) -> Expr: + """Cast input tensor to the given data type. + + Parameters + ---------- + x : relax.Expr + The input data to the operator. + + dtype: Union[str, DataType] + The target data type + + Returns + ------- + result : relax.Expr + The casted result. + """ + return _ffi_api.astype(x, dtype) # type: ignore diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 44cb2cf3a5b4..cb3336394407 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -19,6 +19,11 @@ import tvm._ffi +@tvm._ffi.register_object("relax.attrs.AstypeAttrs") +class AstypeAttrs(Attrs): + """Attributes used in astype operator""" + + @tvm._ffi.register_object("relax.attrs.TakeAttrs") class TakeAttrs(Attrs): """Attributes used in take operator""" diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 75a00ea04985..aaee0f4e2f89 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -31,6 +31,7 @@ from tvm.relax.op import ( add, assert_op, + astype, builtin, call_builtin_with_ctx, call_tir, @@ -403,6 +404,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "add", "arg", "assert_op", + "astype", "builtin", "call_packed", "call_tir", diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc new file mode 100644 index 000000000000..0c647aa866be --- /dev/null +++ b/src/relax/op/tensor/datatype.cc @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file datatype.cc + * \brief Datatype operators. + */ + +#include "datatype.h" + +#include + +namespace tvm { +namespace relax { + +/* relax.astype */ +TVM_REGISTER_NODE_TYPE(AstypeAttrs); + +Expr astype(Expr x, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + + static const Op& op = Op::Get("relax.astype"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.astype").set_body_typed(astype); + +StructInfo InferStructInfoAstype(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + ObjectPtr new_sinfo = make_object(*sinfo.get()); + new_sinfo->dtype = attrs->dtype; + return TensorStructInfo(new_sinfo); +} + +TVM_REGISTER_OP("relax.astype") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor") + .set_attr("FInferStructInfo", InferStructInfoAstype); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/datatype.h b/src/relax/op/tensor/datatype.h new file mode 100644 index 000000000000..6afa7a50d462 --- /dev/null +++ b/src/relax/op/tensor/datatype.h @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file datatype.h + * \brief The functions to make Relax datatype operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_DATATYPE_H_ +#define TVM_RELAX_OP_TENSOR_DATATYPE_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Cast input tensor to the given data type. + * \param x The input data to the operator. + * \param dtype The target data type + * \return The casted result. + */ +Expr astype(Expr x, DataType dtype); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_DATATYPE_H_ diff --git a/tests/python/relax/test_op_datatype.py b/tests/python/relax/test_op_datatype.py new file mode 100644 index 000000000000..56bbe464cf20 --- /dev/null +++ b/tests/python/relax/test_op_datatype.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + assert relax.op.astype(x, "float16").op == Op.get("relax.astype") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_astype_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor(ndim=2)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.astype(x0, "float16"), relax.TensorStructInfo((2, 3), "float16")) + _check_inference( + bb, relax.op.astype(x1, "float16"), relax.TensorStructInfo(dtype="float16", ndim=2) + ) + _check_inference(bb, relax.op.astype(x2, "float16"), relax.TensorStructInfo(dtype="float16")) + _check_inference(bb, relax.op.astype(x3, "float16"), relax.TensorStructInfo((2, 3), "float16")) + _check_inference( + bb, relax.op.astype(x4, "float16"), relax.TensorStructInfo(dtype="float16", ndim=2) + ) + _check_inference(bb, relax.op.astype(x5, "float16"), relax.TensorStructInfo(dtype="float16")) + + +def test_astype_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((m, n))) + + _check_inference(bb, relax.op.astype(x0, "float16"), relax.TensorStructInfo((m, n), "float16")) + _check_inference(bb, relax.op.astype(x1, "float16"), relax.TensorStructInfo((m, n), "float16")) + + +def test_astype_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.astype(x0, "float16"), relax.TensorStructInfo(s0, "float16")) + _check_inference(bb, relax.op.astype(x1, "float16"), relax.TensorStructInfo(s1, "float16")) + _check_inference(bb, relax.op.astype(x2, "float16"), relax.TensorStructInfo(s2, "float16")) + + +def test_astype_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int32")) + + _check_inference(bb, relax.op.astype(x0, "float32"), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.astype(x1, "int32"), relax.TensorStructInfo((2, 3), "int32")) + _check_inference(bb, relax.op.astype(x2, "int8"), relax.TensorStructInfo((2, 3), "int8")) + + +def test_astype_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.astype(x0, "float16")) + with pytest.raises(TVMError): + bb.normalize(relax.op.astype(x1, "float16")) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_datatype.py b/tests/python/relax/test_tvmscript_parser_op_datatype.py new file mode 100644 index 000000000000..ec71e868d45b --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_datatype.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_astype(): + @R.function + def expected(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float16"): + gv: R.Tensor((2, 3, 4), "float16") = R.astype(x, "float16") + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("main", [x]): + gv = bb.emit(relax.op.astype(x, "float16")) + bb.emit_func_output(gv) + + _check(expected, bb.get()["main"]) + + +if __name__ == "__main__": + tvm.testing.main() From 20ca7c07bac9b389979e2d8915917f0323a2e279 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 14 Feb 2023 14:55:30 -0500 Subject: [PATCH 15/81] [Unity] Relax op: set (#13990) This PR is about the high-level tensor computation operators in Relax. This PR includes the set operators. Co-authored-by: Prakalp Srivastava --- include/tvm/relax/attrs/set.h | 62 ++ python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/op_attrs.py | 5 + python/tvm/relax/op/set.py | 101 ++ python/tvm/script/ir_builder/relax/ir.py | 2 + src/relax/op/tensor/set.cc | 103 +++ src/relax/op/tensor/set.h | 40 + tests/python/relax/test_op_set.py | 862 ++++++++++++++++++ .../relax/test_tvmscript_parser_op_set.py | 68 ++ 9 files changed, 1244 insertions(+) create mode 100644 include/tvm/relax/attrs/set.h create mode 100644 python/tvm/relax/op/set.py create mode 100644 src/relax/op/tensor/set.cc create mode 100644 src/relax/op/tensor/set.h create mode 100644 tests/python/relax/test_op_set.py create mode 100644 tests/python/relax/test_tvmscript_parser_op_set.py diff --git a/include/tvm/relax/attrs/set.h b/include/tvm/relax/attrs/set.h new file mode 100644 index 000000000000..3fae7646ff8e --- /dev/null +++ b/include/tvm/relax/attrs/set.h @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/set.h + * \brief Attributes for set operators. + */ +#ifndef TVM_RELAX_ATTRS_SET_H_ +#define TVM_RELAX_ATTRS_SET_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in unique operator */ +struct UniqueAttrs : public tvm::AttrsNode { + bool sorted; + bool return_index; + bool return_inverse; + bool return_counts; + Optional axis; + + TVM_DECLARE_ATTRS(UniqueAttrs, "relax.attrs.UniqueAttrs") { + TVM_ATTR_FIELD(sorted).describe( + "Whether to sort the unique elements in ascending order before returning as output."); + TVM_ATTR_FIELD(return_index) + .describe( + "Whether to return an additional tensor with indices for where elements in the unique " + "tensor come from the original input."); + TVM_ATTR_FIELD(return_inverse) + .describe( + "Whether to return an additional tensor with indices for where elements in the " + "original input ended up in the returned unique list."); + TVM_ATTR_FIELD(return_counts) + .describe("Whether to return an additional tensor with counts of each unique elements"); + TVM_ATTR_FIELD(axis).describe( + "The dimension to apply unique. If it is NullOpt, the unique values of the flattened input " + "is are returned."); + } +}; // struct UniqueAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_SET_H_ diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index f3ab9085b87e..da29c3715dec 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -24,5 +24,6 @@ from .index import * from .manipulate import * from .op_attrs import * +from .set import * from . import builtin from . import memory diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index cb3336394407..47c3b2879878 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -32,3 +32,8 @@ class TakeAttrs(Attrs): @tvm._ffi.register_object("relax.attrs.StridedSliceAttrs") class StridedSliceAttrs(Attrs): """Attributes used in strided_slice operator""" + + +@tvm._ffi.register_object("relax.attrs.UniqueAttrs") +class UniqueAttrs(Attrs): + """Attributes used for the unique operator""" diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py new file mode 100644 index 000000000000..b7ee0f381169 --- /dev/null +++ b/python/tvm/relax/op/set.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-outside-toplevel, redefined-builtin, unused-argument +"""Set operators.""" +from typing import Optional + +import numpy as np # type: ignore +import tvm + +from . import _ffi_api +from ..expr import Expr + + +def unique( + x: Expr, + sorted: bool = True, + return_index: bool = False, + return_inverse: bool = False, + return_counts: bool = False, + axis: Optional[int] = None, +) -> Expr: + """Find the unique elements in a given tensor. + In addition, it optionally returns + - the indices of the input tensor that give the unique values; + - the indices of the unique tensor that reconstruct the input tensor; + - the number of times each unique value comes up in the input tensor. + + Parameters + ---------- + x : relax.Expr + The input tensor. + + sorted : bool + Whether to sort the unique elements in ascending order before + returning as output. + + return_index : bool + Whether to return an additional tensor with indices for where elements in + the unique tensor come from the original input. + + return_inverse : bool + Whether to return an additional tensor with indices for where elements in + the original input ended up in the returned unique list. + + return_counts : bool + Whether to return an additional tensor with counts of each unique elements. + + axis : Optional + The dimension to apply unique. + If not specified, the unique values of the flattened input are returned. + + Returns + ------- + ret : relax.Expr + The created relax call with + """ + + return _ffi_api.unique( # type: ignore + x, sorted, return_index, return_inverse, return_counts, axis + ) + + +@tvm.register_func("relax.run.unique") +def numpy_unique( + x: tvm.nd.array, + sorted: int, + return_index: int, + return_inverse: int, + return_counts: int, + axis: Optional[int], +) -> tvm.nd.array: + """Returns the unique elements of the input tensor. + + Uses numpy.unique to compute unique elements. + """ + import builtins + + # TODO(prakalp): add support for returning a tuple when return_inverse or return_counts is True + if bool(return_index) or bool(return_inverse) or bool(return_counts): + raise NotImplementedError("missing support return_inverse or return_counts set to true") + x_numpy = x.numpy() + # TODO(prakalp): use torch.unique instead of numpy when torch is installed in ci. + output_sorted_numpy, indices = np.unique(x_numpy, return_index=True) + if sorted: + return tvm.nd.array(output_sorted_numpy) + output_numpy = [x_numpy.flatten()[index] for index in builtins.sorted(indices, reverse=True)] + return tvm.nd.array(output_numpy) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index aaee0f4e2f89..537adec6154c 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -45,6 +45,7 @@ shape_of, strided_slice, take, + unique, ) from tvm.relax.struct_info import StructInfo from tvm.relax.utils import args_converter @@ -434,4 +435,5 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "strided_slice", "take", "tuple", + "unique", ] diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc new file mode 100644 index 000000000000..4d5a274e17fa --- /dev/null +++ b/src/relax/op/tensor/set.cc @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file set.cc + * \brief Relax set operators. + */ + +#include "set.h" + +#include +#include + +namespace tvm { +namespace relax { + +/* relax.unique */ +TVM_REGISTER_NODE_TYPE(UniqueAttrs); + +Expr unique(Expr x, bool sorted, bool return_index, bool return_inverse, bool return_counts, + Optional axis) { + ObjectPtr attrs = make_object(); + attrs->sorted = sorted; + attrs->return_index = return_index; + attrs->return_inverse = return_inverse; + attrs->return_counts = return_counts; + attrs->axis = std::move(axis); + + static const Op& op = Op::Get("relax.unique"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.unique").set_body_typed(unique); + +StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) { + // Normalize the axis for sanity check purpose. + NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()->value); + } + + int n_int_return = static_cast(attrs->return_index) + + static_cast(attrs->return_inverse) + + static_cast(attrs->return_counts); + + std::vector output_sinfo; + output_sinfo.reserve(1 + n_int_return); + + // unique values + if (data_sinfo->ndim == 0) { + output_sinfo.push_back( + TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), data_sinfo->dtype)); + } else if (attrs->axis.defined()) { + output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim)); + } else { + output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, /*ndim=*/1)); + } + + // index, reverse and counts + TensorStructInfo int_return{nullptr}; + if (data_sinfo->ndim == 0) { + int_return = + TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), DataType::Int(64)); + } else { + int_return = TensorStructInfo(DataType::Int(64), /*ndim=*/1); + } + for (int i = 0; i < n_int_return; ++i) { + output_sinfo.push_back(int_return); + } + + if (output_sinfo.size() == 1) { + return output_sinfo[0]; + } else { + return TupleStructInfo(output_sinfo); + } +} + +TVM_REGISTER_OP("relax.unique") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor") + .set_attr("FInferStructInfo", InferStructInfoUnique) + .set_attr("FCallPacked", "relax.run.unique"); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/set.h b/src/relax/op/tensor/set.h new file mode 100644 index 000000000000..83d2619e4d2c --- /dev/null +++ b/src/relax/op/tensor/set.h @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. Sex The NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. Sex The License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file set.h + * \brief The functions to make Relax set operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_SET_H_ +#define TVM_RELAX_OP_TENSOR_SET_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +Expr unique(Expr x, bool sorted, bool return_index, bool return_inverse, bool return_counts, + Optional axis); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_SET_H_ diff --git a/tests/python/relax/test_op_set.py b/tests/python/relax/test_op_set.py new file mode 100644 index 000000000000..755d5e8f870c --- /dev/null +++ b/tests/python/relax/test_op_set.py @@ -0,0 +1,862 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + assert relax.op.unique(x).op == Op.get("relax.unique") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_unique_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4))) + + _check_inference( + bb, + relax.op.unique( + x0, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.unique( + x0, return_index=False, return_inverse=False, return_counts=True, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x0, return_index=False, return_inverse=True, return_counts=False, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=False, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x0, return_index=True, return_inverse=False, return_counts=False, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=False, return_counts=False, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=False, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=False, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=False, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=-2), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x0, sorted=True, return_index=True, return_inverse=True, return_counts=True, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x0, sorted=True, return_index=True, return_inverse=True, return_counts=True, axis=1 + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x1, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.unique( + x1, return_index=False, return_inverse=True, return_counts=False, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=False, return_inverse=True, return_counts=False, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=False, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x2, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.unique( + x2, return_index=True, return_inverse=False, return_counts=False, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=True, return_inverse=False, return_counts=False, axis=1), + relax.TupleStructInfo( + [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="int64", ndim=1)] + ), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=False, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=False, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x3, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x3, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="", ndim=3), + ) + _check_inference( + bb, + relax.op.unique( + x3, return_index=False, return_inverse=False, return_counts=True, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x3, return_index=False, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x3, return_index=False, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x3, return_index=False, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x3, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x3, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + + +def test_unique_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + x = relax.Var("x", R.Tensor((a, b, c), "float32")) + + _check_inference( + bb, + relax.op.unique( + x, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=False, return_inverse=False, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=False, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=False, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=False, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + + +def test_unique_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference( + bb, + relax.op.unique( + x0, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.unique( + x0, return_index=False, return_inverse=False, return_counts=True, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x1, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.unique( + x1, return_index=False, return_inverse=False, return_counts=True, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=False, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="int64", ndim=1)] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=False, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=False, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + + +def test_unique_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 4), "int32")) + + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float16", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="int8", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="int32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + + +def test_unique_infer_struct_info_input_zero_rank(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(())) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + x0 = relax.Var("x", R.Tensor((), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=0)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((1,), "float32"), + relax.TensorStructInfo((1,), "int64"), + relax.TensorStructInfo((1,), "int64"), + relax.TensorStructInfo((1,), "int64"), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=False, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((1,), "float32"), + relax.TensorStructInfo((1,), "int64"), + relax.TensorStructInfo((1,), "int64"), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x2, return_index=True, return_inverse=False, return_counts=False, axis=None + ), + relax.TupleStructInfo( + [relax.TensorStructInfo((1,), "float32"), relax.TensorStructInfo((1,), "int64")] + ), + ) + _check_inference( + bb, + relax.op.unique( + x3, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo((1,), "float32"), + ) + + +def test_unique_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.unique(x0, axis=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.unique(x0, axis=-4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.unique(x1, axis=0)) + + +def test_unique_infer_struct_info_wrong_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.unique(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.unique(x1)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_set.py b/tests/python/relax/test_tvmscript_parser_op_set.py new file mode 100644 index 000000000000..8e01fa6f6215 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_set.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_unique(): + @R.function + def foo( + x: R.Tensor((2, 3, 4), dtype="float32") + ) -> R.Tuple( + R.Tensor(dtype="float32", ndim=3), + R.Tensor(dtype="int64", ndim=1), + R.Tensor(dtype="int64", ndim=1), + ): + gv: R.Tuple( + R.Tensor(dtype="float32", ndim=3), + R.Tensor(dtype="int64", ndim=1), + R.Tensor(dtype="int64", ndim=1), + ) = R.unique( + x, sorted=True, return_index=False, return_inverse=True, return_counts=True, axis=1 + ) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit( + relax.op.unique(x, sorted=True, return_inverse=True, return_counts=True, axis=1) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() From c06d16f0cb1985546b11c50554b557a9e51f79cf Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 14 Feb 2023 14:57:05 -0500 Subject: [PATCH 16/81] [Unity] Relax op: image (#13994) This PR is about the high-level tensor computation operators in Relax. This PR includes the image operators. --- include/tvm/relax/attrs/image.h | 81 ++++++ python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/image/__init__.py | 19 ++ python/tvm/relax/op/image/_ffi_api.py | 20 ++ python/tvm/relax/op/image/image.py | 128 +++++++++ python/tvm/relax/op/op_attrs.py | 5 + python/tvm/script/ir_builder/relax/ir.py | 2 + src/relax/op/image/resize.cc | 113 ++++++++ src/relax/op/image/resize.h | 43 +++ tests/python/relax/test_op_image.py | 245 ++++++++++++++++++ .../relax/test_tvmscript_parser_op_image.py | 54 ++++ 11 files changed, 711 insertions(+) create mode 100644 include/tvm/relax/attrs/image.h create mode 100644 python/tvm/relax/op/image/__init__.py create mode 100644 python/tvm/relax/op/image/_ffi_api.py create mode 100644 python/tvm/relax/op/image/image.py create mode 100644 src/relax/op/image/resize.cc create mode 100644 src/relax/op/image/resize.h create mode 100644 tests/python/relax/test_op_image.py create mode 100644 tests/python/relax/test_tvmscript_parser_op_image.py diff --git a/include/tvm/relax/attrs/image.h b/include/tvm/relax/attrs/image.h new file mode 100644 index 000000000000..13463aaa4849 --- /dev/null +++ b/include/tvm/relax/attrs/image.h @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/image.h + * \brief Attributes for image operators. + */ +#ifndef TVM_RELAX_ATTRS_IMAGE_H_ +#define TVM_RELAX_ATTRS_IMAGE_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in image resize2d operator */ +struct Resize2DAttrs : public tvm::AttrsNode { + Array roi; + String layout; + String method; + String coordinate_transformation_mode; + String rounding_method; + double cubic_alpha; + int cubic_exclude; + double extrapolation_value; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Resize2DAttrs, "relax.attrs.Resize2DAttrs") { + TVM_ATTR_FIELD(roi).describe( + "Region of Interest for coordinate transformation mode 'tf_crop_and_resize'"); + TVM_ATTR_FIELD(layout).describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Resize is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method).describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "linear - Bilinear Interpolation" + "cubic - Bicubic Interpolation"); + TVM_ATTR_FIELD(coordinate_transformation_mode) + .describe( + "Describes how to transform the coordinate in the resized tensor" + "to the coordinate in the original tensor." + "Refer to the ONNX Resize operator specification for details" + "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(rounding_method) + .describe( + "indicates how to find the \"nearest\" pixel in nearest_neighbor method" + "Available options are round, floor, and ceil."); + TVM_ATTR_FIELD(cubic_alpha).describe("Spline Coefficient for Bicubic Interpolation"); + TVM_ATTR_FIELD(cubic_exclude) + .describe("Flag to exclude exterior of the image during bicubic interpolation"); + TVM_ATTR_FIELD(extrapolation_value) + .describe("Value to return when roi is outside of the image"); + TVM_ATTR_FIELD(out_dtype).describe( + "The dtype of the output tensor. It it is not specified, the output will have the same " + "dtype as input if not specified."); + } +}; // struct Resize2dAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_IMAGE_H_ diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index da29c3715dec..38573512691c 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -26,4 +26,5 @@ from .op_attrs import * from .set import * from . import builtin +from . import image from . import memory diff --git a/python/tvm/relax/op/image/__init__.py b/python/tvm/relax/op/image/__init__.py new file mode 100644 index 000000000000..f2552ad6ac51 --- /dev/null +++ b/python/tvm/relax/op/image/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import +"""Image operators.""" +from .image import * diff --git a/python/tvm/relax/op/image/_ffi_api.py b/python/tvm/relax/op/image/_ffi_api.py new file mode 100644 index 000000000000..e666203ae7ff --- /dev/null +++ b/python/tvm/relax/op/image/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Constructor APIs""" +import tvm._ffi + +tvm._ffi._init_api("relax.op.image", __name__) diff --git a/python/tvm/relax/op/image/image.py b/python/tvm/relax/op/image/image.py new file mode 100644 index 000000000000..562de5021d53 --- /dev/null +++ b/python/tvm/relax/op/image/image.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Image operators.""" +from typing import Optional, Tuple, Union + +from tvm import DataType +from tvm.ir.expr import PrimExpr + +from . import _ffi_api +from ...expr import Expr, ShapeExpr + + +PrimExprLike = Union[int, PrimExpr] + + +def resize2d( + data: Expr, + size: Union[Expr, PrimExprLike, Tuple[PrimExprLike]], + roi: Optional[Union[float, Tuple[float]]] = None, + layout: str = "NCHW", + method: str = "linear", + coordinate_transformation_mode: str = "half_pixel", + rounding_method: str = "round", + cubic_alpha: float = -0.5, + cubic_exclude: int = 0, + extrapolation_value: float = 0.0, + out_dtype: Optional[Union[str, DataType]] = None, +) -> Expr: + """Image resize2d operator. + + This operator takes data as input and does 2D scaling to the given scale factor. + In the default case, where the data_layout is `NCHW` + with data of shape (n, c, h, w) + out will have a shape (n, c, size[0], size[1]) + + method indicates the algorithm to be used while calculating the out value + and method can be one of ("linear", "nearest_neighbor", "cubic") + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + size: Union[Expr, PrimExprLike, Tuple[PrimExprLike]] + The out size to which the image will be resized. + If specified as a list, it is required to have length either 1 or 2. + If specified as an Expr, it is required to have ndim 2. + + roi: Optional[Union[float, Tuple[float]]] + The region of interest for cropping the input image. Expected to be of + size 4, and format [start_h, start_w, end_h, end_w]. + Only used if coordinate_transformation_mode is tf_crop_and_resize. + + layout : str + Layout of the input. + + method : str + Scale method to used [nearest_neighbor, linear, cubic]. + + coordinate_transformation_mode : str + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. Definitions can be found + in topi/image/resize.py. + [half_pixel, align_corners, asymmetric, pytorch_half_pixel, + tf_half_pixel_for_nn, and tf_crop_and_resize]. + + rounding_method: str + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + + cubic_alpha: float + Spline Coefficient for bicubic interpolation + + cubic_exclude: int + Flag to exclude exterior of the image during bicubic interpolation + + extrapolation_value: float + Fill value to use when roi is outside of the image + + out_dtype : Optional[Union[str, DataType]] + The dtype of the output tensor. + It it is not specified, the output will have the same dtype as input if not specified. + + Returns + ------- + result: relax.Expr + The resized result. + """ + if roi is None: + roi = (0.0, 0.0, 0.0, 0.0) # type: ignore + elif isinstance(roi, float): + roi = (roi, roi, roi, roi) # type: ignore + + if isinstance(size, (int, PrimExpr)): + size = (size, size) + if isinstance(size, tuple): + if len(size) == 1: + size = ShapeExpr([size[0], size[0]]) + else: + size = ShapeExpr(size) + + return _ffi_api.resize2d( # type: ignore + data, + size, + roi, + layout, + method, + coordinate_transformation_mode, + rounding_method, + cubic_alpha, + cubic_exclude, + extrapolation_value, + out_dtype, + ) diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 47c3b2879878..fb64443b7e09 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -34,6 +34,11 @@ class StridedSliceAttrs(Attrs): """Attributes used in strided_slice operator""" +@tvm._ffi.register_object("relax.attrs.Resize2DAttrs") +class Resize2DAttrs(Attrs): + """Attributes used in image resize2d operator""" + + @tvm._ffi.register_object("relax.attrs.UniqueAttrs") class UniqueAttrs(Attrs): """Attributes used for the unique operator""" diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 537adec6154c..22b85f6f402f 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -35,6 +35,7 @@ builtin, call_builtin_with_ctx, call_tir, + image, invoke_closure, make_closure, memory, @@ -420,6 +421,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "func_ret_struct_info", "func_ret_value", "function", + "image", "invoke_closure", "make_closure", "memory", diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc new file mode 100644 index 000000000000..2711b3cc45f5 --- /dev/null +++ b/src/relax/op/image/resize.cc @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file resize.cc + * \brief Image resize operators. + */ + +#include "resize.h" + +#include + +namespace tvm { +namespace relax { + +/* relax.resize2d */ +TVM_REGISTER_NODE_TYPE(Resize2DAttrs); + +Expr resize2d(Expr data, Expr size, Array roi, String layout, String method, + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + int cubic_exclude, double extrapolation_value, DataType out_dtype) { + ObjectPtr attrs = make_object(); + attrs->roi = std::move(roi); + attrs->layout = std::move(layout); + attrs->method = std::move(method); + attrs->coordinate_transformation_mode = std::move(coordinate_transformation_mode); + attrs->rounding_method = std::move(rounding_method); + attrs->cubic_alpha = cubic_alpha; + attrs->cubic_exclude = cubic_exclude; + attrs->extrapolation_value = extrapolation_value; + attrs->out_dtype = out_dtype; + + static const Op& op = Op::Get("relax.image.resize2d"); + return Call(op, {std::move(data), std::move(size)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.image.resize2d").set_body_typed(resize2d); + +StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 1 && call->args.size() != 2) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "Resize2D expects either one or two arguments, while the given number of arguments is " + << call->args.size()); + } + + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* size_sinfo = GetStructInfoAs(call->args[1]); + const auto* size_value = call->args[1].as(); + if (data_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Resize2D expects the input data to be a Tensor, while the given data is " + << call->args[0]->GetTypeKey()); + } + if (size_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "Resize2D expects the given output image size to be a Shape, while the given one is " + << call->args[1]->GetTypeKey()); + } + if (size_sinfo->ndim != 2) { + ctx->ReportFatal(Diagnostic::Error(call) << "Resize2D expects the given output image size to " + "be a 2-dim shape, while the given one has ndim " + << size_sinfo->ndim); + } + + const auto* attrs = call->attrs.as(); + auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"data"); + + DataType out_dtype = attrs->out_dtype.is_void() ? data_sinfo->dtype : attrs->out_dtype; + + Optional data_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, GetRef(data_sinfo), data_layout); + if (!data_shape.defined() || size_value == nullptr) { + return TensorStructInfo(out_dtype, data_layout.ndim()); + } + + Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + Array out_NCHW_shape(data_NCHW_shape); + out_NCHW_shape.Set(2, size_value->values[0]); + out_NCHW_shape.Set(3, size_value->values[1]); + + Array out_shape = data2NCHW.BackwardShape(out_NCHW_shape); + return TensorStructInfo(ShapeExpr(out_shape), out_dtype); +} + +TVM_REGISTER_OP("relax.image.resize2d") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("size", "Shape", "The output image shape.") + .set_attr("FInferStructInfo", InferStructInfoResize2D); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/image/resize.h b/src/relax/op/image/resize.h new file mode 100644 index 000000000000..085a1cbc5d5f --- /dev/null +++ b/src/relax/op/image/resize.h @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file resize.h + * \brief The functions to make Relax image resize operator calls. + */ + +#ifndef TVM_RELAX_OP_IMAGE_RESIZE_H_ +#define TVM_RELAX_OP_IMAGE_RESIZE_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! \brief Image resize2d operator. */ +Expr resize2d(Expr data, Expr size, Array roi, String layout, String method, + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + int cubic_exclude, double extrapolation_value, DataType out_dtype); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_IMAGE_RESIZE_H_ diff --git a/tests/python/relax/test_op_image.py b/tests/python/relax/test_op_image.py new file mode 100644 index 000000000000..b06b51a2a198 --- /dev/null +++ b/tests/python/relax/test_op_image.py @@ -0,0 +1,245 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + assert relax.op.image.resize2d(x, (28, 28)).op == Op.get("relax.image.resize2d") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_resize2d_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32")) + x3 = relax.Var("x", R.Tensor("float32", ndim=4)) + x4 = relax.Var("x", R.Tensor("float32", ndim=5)) + x5 = relax.Var("x", R.Tensor("float32")) + x6 = relax.Var("x", R.Tensor(ndim=4)) + x7 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, relax.op.image.resize2d(x0, (28, 28)), relax.TensorStructInfo((2, 3, 28, 28), "float32") + ) + _check_inference( + bb, + relax.op.image.resize2d(x0, size=28), + relax.TensorStructInfo((2, 3, 28, 28), "float32"), + ) + _check_inference( + bb, + relax.op.image.resize2d(x0, size=(28, 30)), + relax.TensorStructInfo((2, 3, 28, 30), "float32"), + ) + _check_inference( + bb, + relax.op.image.resize2d(x1, size=28, layout="NHWC"), + relax.TensorStructInfo((2, 28, 28, 3), "float32"), + ) + _check_inference( + bb, + relax.op.image.resize2d(x0, size=28, out_dtype="float16"), + relax.TensorStructInfo((2, 3, 28, 28), "float16"), + ) + _check_inference( + bb, + relax.op.image.resize2d(x2, size=28, layout="NCHW16c"), + relax.TensorStructInfo((2, 4, 28, 28, 16), "float32"), + ) + _check_inference( + bb, relax.op.image.resize2d(x3, size=28), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, + relax.op.image.resize2d(x4, size=28, layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, relax.op.image.resize2d(x5, size=28), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.image.resize2d(x6, size=28), relax.TensorStructInfo(dtype="", ndim=4) + ) + _check_inference( + bb, + relax.op.image.resize2d(x6, size=28, out_dtype="float32"), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, relax.op.image.resize2d(x7, size=28), relax.TensorStructInfo(dtype="", ndim=4) + ) + + +def test_resize2d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + oh = tir.Var("oh", "int64") + ow = tir.Var("ow", "int64") + x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, ih, iw, 16), "float32")) + + _check_inference( + bb, relax.op.image.resize2d(x0, size=oh), relax.TensorStructInfo((n, c, oh, oh), "float32") + ) + _check_inference( + bb, + relax.op.image.resize2d(x0, size=(oh, ow)), + relax.TensorStructInfo((n, c, oh, ow), "float32"), + ) + _check_inference( + bb, + relax.op.image.resize2d(x1, size=(oh, ow), layout="NCHW16c"), + relax.TensorStructInfo((n, c, oh, ow, 16), "float32"), + ) + + +def test_resize2d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.image.resize2d(x0, size=32), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, + relax.op.image.resize2d(x1, size=32, layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.image.resize2d(x2, size=32, layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + + +def test_resize2d_infer_struct_info_pool_size_var(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + s0 = relax.Var("s", relax.ShapeStructInfo((30, 30))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + + _check_inference( + bb, + relax.op.image.resize2d(x0, s0), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, relax.op.image.resize2d(x0, s1), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + + +def test_resize2d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64")) + _check_inference( + bb, relax.op.image.resize2d(x0, size=28), relax.TensorStructInfo((2, 3, 28, 28), "float16") + ) + _check_inference( + bb, relax.op.image.resize2d(x1, size=28), relax.TensorStructInfo((2, 3, 28, 28), "int8") + ) + _check_inference( + bb, relax.op.image.resize2d(x2, size=28), relax.TensorStructInfo((2, 3, 28, 28), "int64") + ) + + +def test_resize2d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x, size=28, layout="OIHW")) + + +def test_resize2d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, size=28, layout="NCHW16c")) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x1, size=28, layout="NCHW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x2, size=28)) + + +def test_resize2d_wrong_pool_size_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) + s0 = relax.ShapeExpr((3,)) + s1 = relax.Var("s", relax.ShapeStructInfo((30, 30, 30))) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s3 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s4 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + s5 = relax.Var("s", relax.ShapeStructInfo()) + + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, (3, 3, 3))) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, s0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, s1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, s2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, s3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, s4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, s5)) + + +def test_resize2d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), "float32"))) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + s0 = relax.Var("s", R.Tensor((3, 3))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, size=32)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x1, size=32)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x2, s0)) + with pytest.raises(TVMError): + relax.op.image.resize2d(x2, [30, 30]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_image.py b/tests/python/relax/test_tvmscript_parser_op_image.py new file mode 100644 index 000000000000..a90da37812ef --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_image.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_resize2d(): + @R.function + def foo(x: R.Tensor((2, 14, 14, 3), "float32")) -> R.Tensor((2, 28, 28, 3), "float32"): + gv: R.Tensor((2, 28, 28, 3), "float32") = R.image.resize2d(x, size=(28, 28), layout="NHWC") + return gv + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 14, 14, 3), "float32")) + with bb.function("foo", [x]): + gv = bb.emit(relax.op.image.resize2d(x, (28, 28), layout="NHWC")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() From 27dde569cc1017fe5b76c9681c6c886fa57e1bbb Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 14 Feb 2023 14:58:21 -0500 Subject: [PATCH 17/81] [Unity] Relax op: arithmetic, comparison (#13983) This PR is about the high-level tensor computation operators in Relax. This PR includes the unary, binary and ternary arithmetic and comparison operators. Co-authored-by: Siyuan Feng Co-authored-by: Chaofan Lin <1713833595@qq.com> --- python/tvm/relax/op/__init__.py | 13 + python/tvm/relax/op/binary.py | 165 ++++++ python/tvm/relax/op/ternary.py | 43 ++ python/tvm/relax/op/unary.py | 529 ++++++++++++++++++ python/tvm/script/ir_builder/relax/ir.py | 74 +++ src/relax/op/op.cc | 2 +- src/relax/op/op_common.cc | 17 +- src/relax/op/op_common.h | 21 +- src/relax/op/tensor/binary.cc | 12 + src/relax/op/tensor/binary.h | 29 + src/relax/op/tensor/ternary.cc | 108 ++++ src/relax/op/tensor/ternary.h | 45 ++ src/relax/op/tensor/unary.cc | 91 +++ src/relax/op/tensor/unary.h | 144 +++++ tests/python/relax/test_op_binary.py | 209 +++++++ tests/python/relax/test_op_ternary.py | 162 ++++++ tests/python/relax/test_op_unary.py | 203 +++++++ tests/python/relax/test_tvmscript_parser.py | 7 +- .../test_tvmscript_parser_op_arith_cmp.py | 179 ++++++ 19 files changed, 2027 insertions(+), 26 deletions(-) create mode 100644 python/tvm/relax/op/ternary.py create mode 100644 python/tvm/relax/op/unary.py create mode 100644 src/relax/op/tensor/ternary.cc create mode 100644 src/relax/op/tensor/ternary.h create mode 100644 src/relax/op/tensor/unary.cc create mode 100644 src/relax/op/tensor/unary.h create mode 100644 tests/python/relax/test_op_binary.py create mode 100644 tests/python/relax/test_op_ternary.py create mode 100644 tests/python/relax/test_op_unary.py create mode 100644 tests/python/relax/test_tvmscript_parser_op_arith_cmp.py diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 38573512691c..344576fe13b2 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -25,6 +25,19 @@ from .manipulate import * from .op_attrs import * from .set import * +from .ternary import * +from .unary import * from . import builtin from . import image from . import memory + + +def _register_op_make(): + # pylint: disable=import-outside-toplevel + from . import _ffi_api + from .. import expr + + expr._op_ffi_api = _ffi_api # type: ignore + + +_register_op_make() diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py index eee0b6f3366a..4042f9bbc9aa 100644 --- a/python/tvm/relax/op/binary.py +++ b/python/tvm/relax/op/binary.py @@ -49,6 +49,42 @@ def add(x1: Expr, x2: Expr) -> Expr: return _ffi_api.add(x1, x2) # type: ignore +def divide(x1: Expr, x2: Expr) -> Expr: + """Division with numpy-style broadcasting. + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.divide(x1, x2) # type: ignore + + +def floor_divide(x1: Expr, x2: Expr) -> Expr: + """Floor division with numpy-style broadcasting. + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.floor_divide(x1, x2) # type: ignore + + def multiply(x1: Expr, x2: Expr) -> Expr: """Multiplication with numpy-style broadcasting. @@ -65,3 +101,132 @@ def multiply(x1: Expr, x2: Expr) -> Expr: The computed result. """ return _ffi_api.multiply(x1, x2) # type: ignore + + +def subtract(x1: Expr, x2: Expr) -> Expr: + """Subtraction with numpy-style broadcasting. + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.subtract(x1, x2) # type: ignore + + +###################### Comparison operators ###################### + + +def equal(x1: Expr, x2: Expr) -> Expr: + """Broadcasted element-wise test for (lhs == rhs). + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.equal(x1, x2) # type: ignore + + +def greater(x1: Expr, x2: Expr) -> Expr: + """Broadcasted element-wise test for (lhs > rhs). + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.greater(x1, x2) # type: ignore + + +def greater_equal(x1: Expr, x2: Expr) -> Expr: + """Broadcasted element-wise test for (lhs >= rhs). + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.greater_equal(x1, x2) # type: ignore + + +def less(x1: Expr, x2: Expr) -> Expr: + """Broadcasted element-wise test for (lhs < rhs). + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.less(x1, x2) # type: ignore + + +def less_equal(x1: Expr, x2: Expr) -> Expr: + """Broadcasted element-wise test for (lhs <= rhs). + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.less_equal(x1, x2) # type: ignore + + +def not_equal(x1: Expr, x2: Expr) -> Expr: + """Broadcasted element-wise test for (lhs != rhs). + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.not_equal(x1, x2) # type: ignore diff --git a/python/tvm/relax/op/ternary.py b/python/tvm/relax/op/ternary.py new file mode 100644 index 000000000000..7c320cc1ca48 --- /dev/null +++ b/python/tvm/relax/op/ternary.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin, invalid-name +"""Relax ternary arithmetic operators.""" +from . import _ffi_api +from ..expr import Expr + + +def ewise_fma(x1: Expr, x2: Expr, x3: Expr) -> Expr: + """Elementwise fused multiply-add operator + Returns elementwise result of :math:`x1 * x2 + x3` + + Parameters + ---------- + x1 : relax.Expr + The left hand operand of the multiplication + + x2 : relax.Expr + The right hand operand of the multiplication + + x3 : relax.Expr + The operand of the addition + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.ewise_fma(x1, x2, x3) # type: ignore diff --git a/python/tvm/relax/op/unary.py b/python/tvm/relax/op/unary.py new file mode 100644 index 000000000000..866d2a8273d6 --- /dev/null +++ b/python/tvm/relax/op/unary.py @@ -0,0 +1,529 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin, invalid-name +"""Relax unary arithmetic operators.""" +from . import _ffi_api +from ..expr import Expr +from ..utils import args_converter + +###################### Arithmetic operators ###################### + + +def abs(x: Expr) -> Expr: + """Compute element-wise absolute value of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.abs(x) # type: ignore + + +def acos(x: Expr) -> Expr: + """Compute element-wise arc cos of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.acos(x) # type: ignore + + +def acosh(x: Expr) -> Expr: + """Compute element-wise arc cosh of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.acosh(x) # type: ignore + + +def asin(x: Expr) -> Expr: + """Compute element-wise arc sin of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.asin(x) # type: ignore + + +def asinh(x: Expr) -> Expr: + """Compute element-wise arc sinh of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.asinh(x) # type: ignore + + +def atan(x: Expr) -> Expr: + """Compute element-wise arc tan of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.atan(x) # type: ignore + + +def atanh(x: Expr) -> Expr: + """Compute element-wise arc tanh of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.atanh(x) # type: ignore + + +def ceil(x: Expr) -> Expr: + """Take ceil of input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.ceil(x) # type: ignore + + +def cos(x: Expr) -> Expr: + """Compute element-wise cos of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.cos(x) # type: ignore + + +def cosh(x: Expr) -> Expr: + """Compute element-wise cosh of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.cosh(x) # type: ignore + + +def exp(x: Expr) -> Expr: + """Compute element-wise exp of data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.exp(x) # type: ignore + + +def floor(x: Expr) -> Expr: + """Take floor of input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.floor(x) # type: ignore + + +def log(x: Expr) -> Expr: + """Compute element-wise natural logarithm of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.log(x) # type: ignore + + +def negative(x: Expr) -> Expr: + """Compute element-wise negative of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result + """ + return _ffi_api.negative(x) # type: ignore + + +def round(x: Expr) -> Expr: + """Rounds each element of the input data to nearest integer. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.round(x) # type: ignore + + +def sigmoid(x: Expr) -> Expr: + """Compute element-wise sigmoid of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.sigmoid(x) # type: ignore + + +def sign(x: Expr) -> Expr: + """Returns an indication of the sign of a number for each element of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.sign(x) # type: ignore + + +def sin(x: Expr) -> Expr: + """Compute element-wise sin of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.sin(x) # type: ignore + + +def sinh(x: Expr) -> Expr: + """Compute element-wise sinh of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.sinh(x) # type: ignore + + +def square(x: Expr) -> Expr: + """Squares each element of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.square(x) # type: ignore + + +def sqrt(x: Expr) -> Expr: + """Compute element-wise square root of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.sqrt(x) # type: ignore + + +def tan(x: Expr) -> Expr: + """Compute element-wise tan of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.tan(x) # type: ignore + + +def tanh(x: Expr) -> Expr: + """Compute element-wise tanh of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.tanh(x) # type: ignore + + +@args_converter.auto +def clip(x: Expr, min: Expr, max: Expr) -> Expr: + """Clips tensor values to a specified min and max. + + Parameters + ---------- + x : relax.Expr + The input data + + min : relax.Expr + The minimum value + + max : relax.Expr + The maximum value + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.clip(x, min, max) # type: ignore + + +###################### Check operators ###################### + + +def isfinite(x: Expr) -> Expr: + """Check if input value is finite. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.isfinite(x) # type: ignore + + +def isinf(x: Expr) -> Expr: + """Check if input value is infinite. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.isinf(x) # type: ignore + + +def isnan(x: Expr) -> Expr: + """Check if input value is Nan. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.isnan(x) # type: ignore diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 22b85f6f402f..a5cb574a06f0 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -29,23 +29,60 @@ ############################### Operators ############################### from tvm.relax.op import ( + abs, + acos, + acosh, + asin, + asinh, + atan, + atanh, add, assert_op, astype, builtin, call_builtin_with_ctx, call_tir, + ceil, + clip, + cos, + cosh, + divide, + equal, + ewise_fma, + exp, + floor, + floor_divide, + greater, + greater_equal, image, invoke_closure, + isfinite, + isinf, + isnan, + less, + less_equal, + log, make_closure, memory, multiply, + negative, + not_equal, null_value, print, reshape, + round, shape_of, + sigmoid, + sign, + sin, + sinh, + square, + sqrt, strided_slice, + subtract, take, + tan, + tanh, unique, ) from tvm.relax.struct_info import StructInfo @@ -403,6 +440,13 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "If", "Then", "TupleGetItem", + "abs", + "acos", + "acosh", + "asin", + "asinh", + "atan", + "atanh", "add", "arg", "assert_op", @@ -411,31 +455,61 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "call_packed", "call_tir", "call_builtin_with_ctx", + "ceil", + "clip", + "cos", + "cosh", "const", "dataflow", + "divide", "dtype", "emit", "emit_match_cast", + "equal", + "ewise_fma", + "exp", + "floor", + "floor_divide", "func_attr", "func_name", "func_ret_struct_info", "func_ret_value", "function", + "greater", + "greater_equal", "image", "invoke_closure", + "isfinite", + "isinf", + "isnan", + "less", + "less_equal", + "log", "make_closure", "memory", "multiply", + "negative", + "not_equal", "null_value", "output", "prim_value", "print", "reshape", + "round", "shape", "shape_of", + "sigmoid", + "sign", + "sin", + "sinh", + "square", + "sqrt", "str", "strided_slice", + "subtract", "take", + "tan", + "tanh", "tuple", "unique", ] diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index ba167a45bc68..f478871e218f 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -62,7 +62,7 @@ StructInfo ReturnShapeStructInfo(const Call& call, const BlockBuilder& ctx) { StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { if (call->sinfo_args.size() != 1) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << "sinfo_args should have exact 1 output struct info."); } return call->sinfo_args[0]; diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc index 260f71e7bfb6..c82c325d9ba7 100644 --- a/src/relax/op/op_common.cc +++ b/src/relax/op/op_common.cc @@ -28,7 +28,7 @@ Array GetInputTensorStructInfo(const Call& call, const BlockBu Op op = Downcast(call->op); int n_input = op->arguments.size(); if (static_cast(call->args.size()) != n_input) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << op << " op should have " << n_input << " arguments"); } Array input_tensor_sinfo; @@ -36,7 +36,7 @@ Array GetInputTensorStructInfo(const Call& call, const BlockBu for (int i = 0; i < n_input; ++i) { const auto* sinfo = GetStructInfoAs(call->args[i]); if (sinfo == nullptr) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << op << " requires the input " << op->arguments[i]->name << " to be Tensor. However, the given one is " << call->args[i]->struct_info_->GetTypeKey()); @@ -70,7 +70,7 @@ Optional> InferBinaryBroadcastShape(const Call& call, const Bloc } else if (analyzer->CanProveEqual(dim0, dim1)) { output_shape.push_back(dim0); } else if (int_dim0 && int_dim1 && int_dim0->value != int_dim1->value) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << "In " << call->op << ", the first input shape at dim " << x1_ndim - i << " is " << dim0 << " and the second input shape at dim " << x2_ndim - i << " is " << dim1 << ", which are not broadcastable."); @@ -96,17 +96,16 @@ std::vector NormalizeAxes(const Call& call, const BlockBuilder& ctx, int nd for (const Integer& axis : axes) { int _axis = axis->value; if (_axis < -ndim || _axis >= ndim) { - ctx->ReportFatal(Diagnostic::Error(call->span) - << "In " << call->op << ", the input axis " << _axis - << " is out of range. The input tensor has " << ndim - << " dimensions, so axis should be in range [" << -ndim << ", " << ndim - << ")."); + ctx->ReportFatal(Diagnostic::Error(call) << "In " << call->op << ", the input axis " << _axis + << " is out of range. The input tensor has " << ndim + << " dimensions, so axis should be in range [" + << -ndim << ", " << ndim << ")."); } else if (_axis < 0) { _axis = ndim + _axis; } if (appeared_dims_set[_axis]) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << "In " << call->op << ", the input axes is required to be non-repetitive. However, there are " "multiple given axes referring to axis " diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index c6d335b2a1bd..29e02946c6d1 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -104,7 +104,7 @@ inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); if (require_float_dtype && !input_sinfo->IsUnknownDtype() && !input_sinfo->dtype.is_float()) { ctx->ReportFatal( - Diagnostic::Error(call->span) + Diagnostic::Error(call) << call->op << " requires the input tensor to have float dtype. However, the given input dtype is " << input_sinfo->dtype); @@ -126,11 +126,11 @@ StructInfo ReturnStructInfoFromArg(const Call& call, const BlockBuilder& ctx) { Op op = Downcast(call->op); int n_input = op->arguments.size(); if (static_cast(call->args.size()) != n_input) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << op << " op should have " << n_input << " arguments"); } if (arg_index >= n_input) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << op << " op has only " << n_input << "arguments, but try to get the arg with index " << arg_index); } @@ -151,8 +151,6 @@ StructInfo InferStructInfoUnaryArith(const Call& call, const BlockBuilder& ctx) call, ctx, [](const TensorStructInfo& input_sinfo) { return input_sinfo->dtype; }); } -/************ Utilities ************/ - /*! * \brief Infer the output datatype for binary arithmetic operators. * \param call The context Call to the operator. @@ -168,7 +166,7 @@ inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& if (x1_sinfo->IsUnknownDtype() || x2_sinfo->IsUnknownDtype()) { return DataType::Void(); } else if (x1_sinfo->dtype != x2_sinfo->dtype) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << "Data types " << x1_sinfo->dtype << " and " << x2_sinfo->dtype << " must be equal for binary operators"); } @@ -269,11 +267,10 @@ inline std::pair CheckTensorLayout(const Call tir::Layout _tensor_layout(tensor_layout, DataType::Int(64)); tir::BijectiveLayout tensor2tgt(_tensor_layout, tir::Layout(tgt_layout, DataType::Int(64))); if (!tensor2tgt.defined()) { - ctx->ReportFatal(Diagnostic::Error(call->span) - << call->op << " requires the given " << tensor_name - << " layout to be convertible from " << tgt_layout - << " layout. However, the given layout " << tensor_layout - << " is not convertible."); + ctx->ReportFatal(Diagnostic::Error(call) << call->op << " requires the given " << tensor_name + << " layout to be convertible from " << tgt_layout + << " layout. However, the given layout " + << tensor_layout << " is not convertible."); } return {_tensor_layout, tensor2tgt}; } @@ -291,7 +288,7 @@ inline Optional CheckNdimPerLayoutAndGetShape(const Call& call, const const TensorStructInfo& sinfo, const tir::Layout& layout) { if (!sinfo->IsUnknownNdim() && sinfo->ndim != static_cast(layout.ndim())) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << "In " << call->op << ", layout " << layout << " requires the input to be " << layout.ndim() << "-dim tensor. However, the given input has ndim " << sinfo->ndim); diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index dd61091f7aaa..b7a07c520208 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -81,7 +81,19 @@ StructInfo InferStructInfoBroadcastCMP(const Call& call, const BlockBuilder& ctx /***************** Arithmetic operators *****************/ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(add); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(divide); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_divide); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(subtract); + +/***************** Comparison operators *****************/ + +RELAX_REGISTER_CMP_OP_AND_IMPL(equal); +RELAX_REGISTER_CMP_OP_AND_IMPL(greater); +RELAX_REGISTER_CMP_OP_AND_IMPL(greater_equal); +RELAX_REGISTER_CMP_OP_AND_IMPL(less); +RELAX_REGISTER_CMP_OP_AND_IMPL(less_equal); +RELAX_REGISTER_CMP_OP_AND_IMPL(not_equal); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index a7aea576b685..b565b159bb48 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -61,9 +61,38 @@ namespace relax { /*! \brief Addition with numpy-style broadcasting. */ Expr add(Expr x1, Expr x2); +/*! \brief Division with numpy-style broadcasting. */ +Expr divide(Expr x1, Expr x2); + +/*! \brief Floor division with numpy-style broadcasting. */ +Expr floor_divide(Expr x1, Expr x2); + /*! \brief Multiplication with numpy-style broadcasting. */ Expr multiply(Expr x1, Expr x2); +/*! \brief Subtraction with numpy-style broadcasting. */ +Expr subtract(Expr x1, Expr x2); + +/***************** Comparison operators *****************/ + +/*! \brief Broadcasted element-wise test for (lhs == rhs). */ +Expr equal(Expr x1, Expr x2); + +/*! \brief Broadcasted element-wise test for (lhs > rhs). */ +Expr greater(Expr x1, Expr x2); + +/*! \brief Broadcasted element-wise test for (lhs >= rhs). */ +Expr greter_equal(Expr x1, Expr x2); + +/*! \brief Broadcasted element-wise test for (lhs < rhs). */ +Expr less(Expr x1, Expr x2); + +/*! \brief Broadcasted element-wise test for (lhs <= rhs). */ +Expr less_equal(Expr x1, Expr x2); + +/*! \brief Broadcasted element-wise test for (lhs != rhs). */ +Expr not_equal(Expr x1, Expr x2); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc new file mode 100644 index 000000000000..8820c07afd25 --- /dev/null +++ b/src/relax/op/tensor/ternary.cc @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ternary.cc + * \brief ternary operators. + */ + +#include "ternary.h" + +namespace tvm { +namespace relax { + +StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo t1 = input_sinfo[0]; + TensorStructInfo t2 = input_sinfo[1]; + TensorStructInfo t3 = input_sinfo[2]; + + int ndim = kUnknownNDim; + if (!t1->IsUnknownNdim()) { + ndim = t1->ndim; + } + if (!t2->IsUnknownNdim()) { + if (ndim == kUnknownNDim) { + ndim = t2->ndim; + } else if (t2->ndim != ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The 3 arguments of EwiseFMA must have the same number of dimensions"); + } + } + if (!t3->IsUnknownNdim()) { + if (ndim == kUnknownNDim) { + ndim = t3->ndim; + } else if (t3->ndim != ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The 3 arguments of EwiseFMA must have the same number of dimensions"); + } + } + + DataType output_dtype; + if (t1->IsUnknownDtype() || t2->IsUnknownDtype() || t3->IsUnknownDtype()) { + output_dtype = DataType::Void(); + } else if (t1->dtype != t2->dtype || t2->dtype != t3->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Data types " << t1->dtype << ", " << t2->dtype << ", and " << t3->dtype + << " must be equal for EwiseFMA"); + } else { + output_dtype = t1->dtype; + } + + auto* s1 = t1->shape.as(); + auto* s2 = t2->shape.as(); + auto* s3 = t3->shape.as(); + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + if (s1 && s2 && s3) { + Array output_shape; + for (int i = 0; i < ndim; ++i) { + PrimExpr dim1 = s1->values[i]; + PrimExpr dim2 = s2->values[i]; + PrimExpr dim3 = s3->values[i]; + if (analyzer->CanProveEqual(dim1, dim2) && analyzer->CanProveEqual(dim2, dim3)) { + output_shape.push_back(dim1); + } else { + ctx->ReportFatal(Diagnostic::Error(call) + << "The 3 arguments of EwiseFMA must have the same shape"); + } + } + return TensorStructInfo(ShapeExpr(output_shape), output_dtype); + } else if (t1->shape.defined() && t1->shape.same_as(t2->shape) && t1->shape.same_as(t3->shape)) { + return TensorStructInfo(t1->shape.value(), output_dtype); + } + + return TensorStructInfo(output_dtype, ndim); +} + +TVM_REGISTER_OP("relax.ewise_fma") + .set_num_inputs(3) + .add_argument("x1", "Tensor", "The left hand operand of the multiplication") + .add_argument("x2", "Tensor", "The right hand operand of the multiplication") + .add_argument("x3", "Tensor", "The operand of the addition") + .set_attr("FInferStructInfo", InferStructInfoEwiseFMA); + +Expr ewise_fma(Expr x1, Expr x2, Expr x3) { + static const Op& op = Op::Get("relax.ewise_fma"); + return Call(op, {x1, x2, x3}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.ewise_fma").set_body_typed(ewise_fma); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/ternary.h b/src/relax/op/tensor/ternary.h new file mode 100644 index 000000000000..ba22c56d9efd --- /dev/null +++ b/src/relax/op/tensor/ternary.h @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ternary.h + * \brief The functions to make Relax ternary operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_TERNARY_H_ +#define TVM_RELAX_OP_TENSOR_TERNARY_H_ + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Elementwise fused multiply-add operator + * Returns elementwise result of `x1 * x2 + x3` + * \param x1 The left hand operand of the multiplication + * \param x2 The right hand operand of the multiplication + * \param x3 The operand of the addition + * \return The computed result. + */ +Expr ewise_fma(Expr x1, Expr x2, Expr x3); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_TERNARY_H_ diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc new file mode 100644 index 000000000000..f1117c1826c5 --- /dev/null +++ b/src/relax/op/tensor/unary.cc @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file unary.cc + * \brief Relax unary arithmetic operators. + */ + +#include "unary.h" + +#include + +namespace tvm { +namespace relax { + +StructInfo InferStructInfoUnaryCheck(const Call& call, const BlockBuilder& ctx) { + return InferStructInfoUnary( + call, ctx, [](const TensorStructInfo& input_sinfo) { return DataType::Bool(); }); +} + +/***************** Arithmetic operators *****************/ + +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(abs, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(acos, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(acosh, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(asin, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(asinh, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(atan, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(atanh, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(ceil, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(cos, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(cosh, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(exp, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(floor, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(log, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(negative, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(round, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sigmoid, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sign, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sin, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sinh, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(square, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sqrt, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tan, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tanh, /*require_float_dtype=*/true); + +// relax.clip +TVM_REGISTER_OP("relax.clip") + .set_num_inputs(3) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("min", "PrimValue", "The lower-bound of the range to be clipped to") + .add_argument("max", "PrimValue", "The upper-bound of the range to be clipped to") + .set_attr("FInferStructInfo", ReturnStructInfoFromArg<0>); + +Expr clip(Expr x, Expr min, Expr max) { + CHECK(min->IsInstance()) + << "The argument `min` of relax.clip is expected to be a PrimValue, but got" + << min->GetTypeKey(); + CHECK(max->IsInstance()) + << "The argument `max` of relax.clip is expected to be a PrimValue, but got" + << max->GetTypeKey(); + static const Op& op = Op::Get("relax.clip"); + return Call(op, {std::move(x), std::move(min), std::move(max)}); +} + +TVM_REGISTER_GLOBAL("relax.op.clip").set_body_typed(clip); + +/***************** Check operators *****************/ + +RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(isfinite); +RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(isinf); +RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(isnan); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/unary.h b/src/relax/op/tensor/unary.h new file mode 100644 index 000000000000..8f6404c5d9ed --- /dev/null +++ b/src/relax/op/tensor/unary.h @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. Sex The NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. Sex The License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file unary.h + * \brief The functions to make Relax unary arithmetic operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_UNARY_H_ +#define TVM_RELAX_OP_TENSOR_UNARY_H_ + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Quick helper macro to + * - expose a make-function interface which construct the call node. + * - register op to the registry. + * \param OpName The name of operator to register. + * \param RequireFloatDtype A boolean indicating if the input is required to have float dtype. + * (Only for unary arith operators since all check operators don't require float dtype.) + */ +#define RELAX_REGISTER_UNARY_OP_AND_IMPL(OpName) \ + RELAX_UNARY_OP_INTERFACE(OpName, #OpName); \ + RELAX_REGISTER_UNARY_OP(#OpName) + +#define RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(OpName, RequireFloatDtype) \ + RELAX_REGISTER_UNARY_OP_AND_IMPL(OpName).set_attr( \ + "FInferStructInfo", InferStructInfoUnaryArith) + +#define RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(OpName) \ + RELAX_REGISTER_UNARY_OP_AND_IMPL(OpName).set_attr( \ + "FInferStructInfo", InferStructInfoUnaryCheck) // require_float_dtype=false for check op + +/***************** Arithmetic operators *****************/ + +/*! + * \brief Compute element-wise absolute value of the input data. + * \param x The input data. + * \return The computed result. + */ +Expr abs(Expr x); + +/*! \brief Compute element-wise arc cos of the input data. */ +Expr acos(Expr x); + +/*! \brief Compute element-wise arc cosh of the input data. */ +Expr acosh(Expr x); + +/*! \brief Compute element-wise arc sin of the input data. */ +Expr asin(Expr x); + +/*! \brief Compute element-wise arc sinh of the input data. */ +Expr asinh(Expr x); + +/*! \brief Compute element-wise arc tan of the input data. */ +Expr atan(Expr x); + +/*! \brief Compute element-wise arc tanh of the input data. */ +Expr atanh(Expr x); + +/*! \brief Take ceil of input data. */ +Expr ceil(Expr x); + +/*! \brief Compute element-wise cos of the input data. */ +Expr cos(Expr x); + +/*! \brief Compute element-wise cosh of the input data. */ +Expr cosh(Expr x); + +/*! \brief Compute element-wise exp of data. */ +Expr exp(Expr x); + +/*! \brief Take floor of input data. */ +Expr floor(Expr x); + +/*! \brief Compute element-wise natural logarithm of data. */ +Expr log(Expr x); + +/*! \brief Compute element-wise negative value of data. */ +Expr negative(Expr x); + +/*! \brief Rounds each element of the input data to nearest integer. */ +Expr round(Expr x); + +/*! \brief Compute element-wise sigmoid of data. */ +Expr sigmoid(Expr x); + +/*! \brief Returns an indication of the sign of a number for each element of the input data. */ +Expr sign(Expr x); + +/*! \brief Compute element-wise sin of data. */ +Expr sin(Expr x); + +/*! \brief Compute element-wise sinh of data. */ +Expr sinh(Expr x); + +/*! \brief Compute element-wise square root of data. */ +Expr sqrt(Expr x); + +/*! \brief Squares each element of the input data. */ +Expr square(Expr x); + +/*! \brief Compute element-wise tan of data. */ +Expr tan(Expr x); + +/*! \brief Compute element-wise tanh of data. */ +Expr tanh(Expr x); + +/*! \brief Clips tensor values to a specified min and max. */ +Expr clip(Expr x, Expr min, Expr max); + +/***************** Check operators *****************/ + +/*! \brief Check if input value is finite. */ +Expr isfinite(Expr x); + +/*! \brief Check if input value is infinite. */ +Expr isinf(Expr x); + +/*! \brief Check if input value is Nan. */ +Expr isnan(Expr x); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_UNARY_H_ diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py new file mode 100644 index 000000000000..a4ae8ce31ac7 --- /dev/null +++ b/tests/python/relax/test_op_binary.py @@ -0,0 +1,209 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Callable +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + assert relax.op.add(x, y).op == Op.get("relax.add") + assert relax.op.divide(x, y).op == Op.get("relax.divide") + assert relax.op.floor_divide(x, y).op == Op.get("relax.floor_divide") + assert relax.op.multiply(x, y).op == Op.get("relax.multiply") + assert relax.op.subtract(x, y).op == Op.get("relax.subtract") + + assert relax.op.equal(x, y).op == Op.get("relax.equal") + assert relax.op.greater(x, y).op == Op.get("relax.greater") + assert relax.op.greater_equal(x, y).op == Op.get("relax.greater_equal") + assert relax.op.less(x, y).op == Op.get("relax.less") + assert relax.op.less_equal(x, y).op == Op.get("relax.less_equal") + assert relax.op.not_equal(x, y).op == Op.get("relax.not_equal") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +(binary_arith_op,) = tvm.testing.parameters( + (relax.op.add,), + (relax.op.divide,), + (relax.op.floor_divide,), + (relax.op.multiply,), + (relax.op.subtract,), +) + + +def test_binary_arith_infer_struct_info(binary_arith_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor((1, 3), "float32")) + x2 = relax.Var("x", R.Tensor((3, 2, 3), "float32")) + x3 = relax.Var("x", R.Tensor((3, 1, 3), "float32")) + x4 = relax.Var("x", R.Tensor("float32", ndim=2)) + x5 = relax.Var("x", R.Tensor()) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor((4, 3, 2, 1), "float32")) + y2 = relax.Var("y", R.Tensor("float32", ndim=2)) + y3 = relax.Var("y", R.Tensor("float32", ndim=-1)) + + _check_inference(bb, binary_arith_op(x0, y0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, binary_arith_op(x1, y0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, binary_arith_op(x1, y1), relax.TensorStructInfo((4, 3, 2, 3), "float32")) + _check_inference(bb, binary_arith_op(x2, y2), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, binary_arith_op(x3, y2), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, binary_arith_op(x4, y0), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x4, y1), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference(bb, binary_arith_op(x4, y2), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x4, y3), relax.TensorStructInfo(dtype="float32", ndim=-1)) + _check_inference(bb, binary_arith_op(x5, y0), relax.TensorStructInfo(dtype="", ndim=-1)) + + +(binary_cmp_op,) = tvm.testing.parameters( + (relax.op.equal,), + (relax.op.greater,), + (relax.op.greater_equal,), + (relax.op.less,), + (relax.op.less_equal,), + (relax.op.not_equal,), +) + + +def test_binary_cmp_infer_struct_info(binary_cmp_op: Callable): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor((2, 3), "int32")) + _check_inference(bb, binary_cmp_op(x, y0), relax.TensorStructInfo((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y1), relax.TensorStructInfo((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y0), relax.TensorStructInfo((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y1), relax.TensorStructInfo((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y0), relax.TensorStructInfo((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y1), relax.TensorStructInfo((2, 3), "bool")) + + +def test_binary_infer_struct_info_shape_symbolic(binary_arith_op: Callable): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + k = tir.Var("k", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((1, n), "float32")) + x2 = relax.Var("x", R.Tensor((k, n, m), "float32")) + x3 = relax.Var("x", R.Tensor((3, 1, n), "float32")) + x4 = relax.Var("x", R.Tensor("float32", ndim=2)) + y0 = relax.Var("y", R.Tensor((m, n), "float32")) + y1 = relax.Var("y", R.Tensor((m, n + 2), "float32")) + y2 = relax.Var("y", R.Tensor((4, k, m, 1), "float32")) + y3 = relax.Var("y", R.Tensor("float32", ndim=2)) + y4 = relax.Var("y", R.Tensor("float32", ndim=-1)) + _check_inference(bb, binary_arith_op(x0, y0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, binary_arith_op(x0, y1), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x1, y0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, binary_arith_op(x1, y2), relax.TensorStructInfo((4, k, m, n), "float32")) + _check_inference(bb, binary_arith_op(x2, y2), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference(bb, binary_arith_op(x2, y3), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, binary_arith_op(x3, y3), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, binary_arith_op(x4, y0), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x4, y2), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference(bb, binary_arith_op(x4, y3), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x4, y4), relax.TensorStructInfo(dtype="float32", ndim=-1)) + + +def test_binary_infer_struct_info_shape_var(binary_arith_op: Callable): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=4)) + s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=1)) + s4 = relax.Var("s4", relax.ShapeStructInfo()) + x = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + y0 = relax.Var("y", relax.TensorStructInfo(s0, "float32")) + y1 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) + y2 = relax.Var("y", relax.TensorStructInfo(s2, "float32")) + y3 = relax.Var("y", relax.TensorStructInfo(s3, "float32")) + y4 = relax.Var("y", relax.TensorStructInfo(s4, "float32")) + + _check_inference(bb, binary_arith_op(x, y0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, binary_arith_op(x, y1), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x, y2), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference(bb, binary_arith_op(x, y3), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x, y4), relax.TensorStructInfo(dtype="float32")) + + +def test_binary_arith_infer_struct_info_more_input_dtype(binary_arith_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float64")) + y0 = relax.Var("y", R.Tensor((2, 3), "float64")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + y1 = relax.Var("y", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int64")) + y2 = relax.Var("y", R.Tensor((2, 3), "int64")) + + _check_inference(bb, binary_arith_op(x0, y0), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, binary_arith_op(x1, y1), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, binary_arith_op(x2, y2), relax.TensorStructInfo((2, 3), "int64")) + + +def test_binary_infer_struct_info_shape_unequal_const_int(binary_arith_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 4), "float32")) + with pytest.raises(TVMError): + bb.normalize(binary_arith_op(x0, y0)) + + +def test_binary_arith_infer_struct_info_dtype_mismatch(binary_arith_op: Callable): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 3), "int32")) + with pytest.raises(TVMError): + bb.normalize(binary_arith_op(x, y)) + + +def test_binary_wrong_input_number(binary_arith_op: Callable): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + + with pytest.raises(TypeError): + binary_arith_op(x, x, x) + with pytest.raises(TypeError): + binary_arith_op(x) + with pytest.raises(TypeError): + binary_arith_op(x, x, x, x) + + +def test_binary_infer_struct_info_wrong_input_type(binary_arith_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(binary_arith_op(x0, y)) + with pytest.raises(TVMError): + bb.normalize(binary_arith_op(x1, y)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_ternary.py b/tests/python/relax/test_op_ternary.py new file mode 100644 index 000000000000..5ea7a01da701 --- /dev/null +++ b/tests/python/relax/test_op_ternary.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + z = relax.Var("z", R.Tensor((2, 3), "float32")) + assert relax.op.ewise_fma(x, y, z).op == Op.get("relax.ewise_fma") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_ewise_fma_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3))) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor(dtype="float32", ndim=2)) + z0 = relax.Var("z", R.Tensor((2, 3), "float32")) + z1 = relax.Var("z", R.Tensor("float32")) + + _check_inference(bb, relax.op.ewise_fma(x0, y0, z0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.ewise_fma(x0, y1, z0), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.ewise_fma(x0, y1, z1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.ewise_fma(x1, y0, z0), relax.TensorStructInfo((2, 3), dtype="")) + + +def test_ewise_fma_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + y0 = relax.Var("y", R.Tensor((m, n), "float32")) + y1 = relax.Var("y", R.Tensor(dtype="float32", ndim=2)) + z0 = relax.Var("z", R.Tensor((m, n), "float32")) + + _check_inference(bb, relax.op.ewise_fma(x0, y0, z0), relax.TensorStructInfo((m, n), "float32")) + _check_inference( + bb, relax.op.ewise_fma(x0, y1, z0), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + + +def test_ewise_fma_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + y = relax.Var("y", relax.TensorStructInfo(s0, "float32")) + z = relax.Var("z", relax.TensorStructInfo(s0, "float32")) + + _check_inference(bb, relax.op.ewise_fma(x0, y, z), relax.TensorStructInfo(s0, "float32")) + _check_inference( + bb, relax.op.ewise_fma(x1, y, z), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.ewise_fma(x2, y, z), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + + +def test_ewise_fma_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float64")) + y0 = relax.Var("y", R.Tensor((2, 3), "float64")) + z0 = relax.Var("z", R.Tensor((2, 3), "float64")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + y1 = relax.Var("y", R.Tensor((2, 3), "int8")) + z1 = relax.Var("z", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int64")) + y2 = relax.Var("y", R.Tensor((2, 3), "int64")) + z2 = relax.Var("z", R.Tensor((2, 3), "int64")) + + _check_inference(bb, relax.op.ewise_fma(x0, y0, z0), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, relax.op.ewise_fma(x1, y1, z1), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, relax.op.ewise_fma(x2, y2, z2), relax.TensorStructInfo((2, 3), "int64")) + + +def test_ewise_fma_infer_struct_info_dtype_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "int32")) + y1 = relax.Var("y", R.Tensor((2, 3), "float32")) + z0 = relax.Var("z", R.Tensor((2, 3), "float32")) + z1 = relax.Var("z", R.Tensor((2, 3), "int8")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.ewise_fma(x, y0, z0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.ewise_fma(x, y1, z1)) + + +def test_ewise_fma_infer_struct_info_ndim_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor((2, 3, 4), "float32")) + z0 = relax.Var("z", R.Tensor((2, 3), "float32")) + z1 = relax.Var("z", R.Tensor(dtype="float32", ndim=4)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.ewise_fma(x, y1, z0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.ewise_fma(x, y0, z1)) + + +def test_ewise_fma_wrong_input_number(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + + with pytest.raises(TypeError): + relax.op.ewise_fma(x) + with pytest.raises(TypeError): + relax.op.ewise_fma(x, x) + with pytest.raises(TypeError): + relax.op.ewise_fma(x, x, x, x) + + +def test_ewise_fma_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", relax.ShapeStructInfo((2, 3))) + y1 = relax.Var("y", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + z = relax.Var("z", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.ewise_fma(x, y0, z)) + with pytest.raises(TVMError): + bb.normalize(relax.op.ewise_fma(x, y1, z)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_unary.py b/tests/python/relax/test_op_unary.py new file mode 100644 index 000000000000..45336661a1ae --- /dev/null +++ b/tests/python/relax/test_op_unary.py @@ -0,0 +1,203 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Callable +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + assert relax.op.abs(x).op == Op.get("relax.abs") + assert relax.op.acos(x).op == Op.get("relax.acos") + assert relax.op.acosh(x).op == Op.get("relax.acosh") + assert relax.op.asin(x).op == Op.get("relax.asin") + assert relax.op.asinh(x).op == Op.get("relax.asinh") + assert relax.op.atan(x).op == Op.get("relax.atan") + assert relax.op.atanh(x).op == Op.get("relax.atanh") + assert relax.op.ceil(x).op == Op.get("relax.ceil") + assert relax.op.cos(x).op == Op.get("relax.cos") + assert relax.op.cosh(x).op == Op.get("relax.cosh") + assert relax.op.exp(x).op == Op.get("relax.exp") + assert relax.op.floor(x).op == Op.get("relax.floor") + assert relax.op.isfinite(x).op == Op.get("relax.isfinite") + assert relax.op.isinf(x).op == Op.get("relax.isinf") + assert relax.op.isnan(x).op == Op.get("relax.isnan") + assert relax.op.log(x).op == Op.get("relax.log") + assert relax.op.negative(x).op == Op.get("relax.negative") + assert relax.op.round(x).op == Op.get("relax.round") + assert relax.op.sigmoid(x).op == Op.get("relax.sigmoid") + assert relax.op.sin(x).op == Op.get("relax.sin") + assert relax.op.sinh(x).op == Op.get("relax.sinh") + assert relax.op.square(x).op == Op.get("relax.square") + assert relax.op.sqrt(x).op == Op.get("relax.sqrt") + assert relax.op.tan(x).op == Op.get("relax.tan") + assert relax.op.tanh(x).op == Op.get("relax.tanh") + assert relax.op.clip(x, 0, 6).op == Op.get("relax.clip") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +unary_arith_op, require_float_dtype = tvm.testing.parameters( + (relax.op.abs, False), + (relax.op.acos, True), + (relax.op.acosh, True), + (relax.op.asin, True), + (relax.op.asinh, True), + (relax.op.atan, True), + (relax.op.atanh, True), + (relax.op.ceil, False), + (relax.op.cos, True), + (relax.op.cosh, True), + (relax.op.exp, True), + (relax.op.floor, False), + (relax.op.log, True), + (relax.op.negative, False), + (relax.op.round, False), + (relax.op.sigmoid, True), + (relax.op.sign, False), + (relax.op.sin, True), + (relax.op.sinh, True), + (relax.op.square, False), + (relax.op.sqrt, True), + (relax.op.tan, True), + (relax.op.tanh, True), +) + + +def test_unary_arith_infer_struct_info(unary_arith_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32", ndim=-1)) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor()) + + _check_inference(bb, unary_arith_op(x0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, unary_arith_op(x2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, unary_arith_op(x3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, unary_arith_op(x4), relax.TensorStructInfo(dtype="")) + + +def test_unary_arith_infer_struct_info_shape_symbolic(unary_arith_op: Callable): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((4, n), "float32")) + + _check_inference(bb, unary_arith_op(x0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo((4, n), "float32")) + + +def test_unary_arith_infer_struct_info_shape_var(unary_arith_op: Callable): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference(bb, unary_arith_op(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo(s1, "float32")) + + +def test_unary_arith_infer_struct_info_more_input_dtype( + unary_arith_op: Callable, require_float_dtype: bool +): + if require_float_dtype: + return + + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float64")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int64")) + + _check_inference(bb, unary_arith_op(x0), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, unary_arith_op(x2), relax.TensorStructInfo((2, 3), "int64")) + + +def test_unary_arith_infer_struct_info_invalid_input_dtype( + unary_arith_op: Callable, require_float_dtype: bool +): + if not require_float_dtype: + return + + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3), "int64")) + + with pytest.raises(TVMError): + bb.normalize(unary_arith_op(x0)) + with pytest.raises(TVMError): + bb.normalize(unary_arith_op(x1)) + + +def test_unary_arith_wrong_input_number(unary_arith_op: Callable): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + + with pytest.raises(TypeError): + unary_arith_op(x, x) + with pytest.raises(TypeError): + unary_arith_op(x, x, x) + + +def test_unary_arith_infer_struct_info_wrong_input_type(unary_arith_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(unary_arith_op(x0)) + with pytest.raises(TVMError): + bb.normalize(unary_arith_op(x1)) + + +def test_clip_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32", ndim=-1)) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.clip(x0, 0, 6), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.clip(x1, 0, 6), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.clip(x2, 0, 6), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.clip(x3, 0, 6), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.clip(x4, 0, 6), relax.TensorStructInfo(dtype="")) + + # Symbolic + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x5 = relax.Var("x", R.Tensor((m, n), "float32")) + x6 = relax.Var("x", R.Tensor((4, n), "float32")) + + _check_inference(bb, relax.op.clip(x5, 0, 6), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.clip(x6, 0, 6), relax.TensorStructInfo((4, n), "float32")) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index c9a16fbcacb7..f6d2e4c20e48 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -31,10 +31,9 @@ def _check( parsed: Union[relax.Function, IRModule], expect: Optional[Union[relax.Function, IRModule]] = None, ): - # TODO(relax-team): enable roundtrip testing when printer is ready - # test = parsed.script(show_meta=True) - # roundtrip_mod = tvm.script.parse(test) - # tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) if expect: tvm.ir.assert_structural_equal(parsed, expect) diff --git a/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py new file mode 100644 index 000000000000..ffb8576b27dc --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py @@ -0,0 +1,179 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union, Callable + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +(unary_arith_op,) = tvm.testing.parameters( + (relax.op.abs,), + (relax.op.acos,), + (relax.op.acosh,), + (relax.op.asin,), + (relax.op.asinh,), + (relax.op.atan,), + (relax.op.atanh,), + (relax.op.ceil,), + (relax.op.cos,), + (relax.op.cosh,), + (relax.op.exp,), + (relax.op.floor,), + (relax.op.log,), + (relax.op.negative,), + (relax.op.round,), + (relax.op.sigmoid,), + (relax.op.sign,), + (relax.op.sin,), + (relax.op.sinh,), + (relax.op.square,), + (relax.op.sqrt,), + (relax.op.tan,), + (relax.op.tanh,), +) + + +def test_unary_arith(unary_arith_op: Callable): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = unary_arith_op(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(unary_arith_op(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +(unary_check_op,) = tvm.testing.parameters( + (relax.op.isfinite,), + (relax.op.isinf,), + (relax.op.isnan,), +) + + +def test_unary_check(unary_check_op: Callable): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv: R.Tensor((2, 3), "bool") = unary_check_op(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(unary_check_op(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +(binary_arith_op,) = tvm.testing.parameters( + (relax.op.add,), + (relax.op.divide,), + (relax.op.floor_divide,), + (relax.op.multiply,), + (relax.op.subtract,), +) + + +def test_binary_arith(binary_arith_op: Callable): + @R.function + def foo( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32") + ) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = binary_arith_op(x, y) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 1), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, y]): + gv = bb.emit(binary_arith_op(x, y)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +(binary_cmp_op,) = tvm.testing.parameters( + (relax.op.equal,), + (relax.op.greater,), + (relax.op.greater_equal,), + (relax.op.less,), + (relax.op.less_equal,), + (relax.op.not_equal,), +) + + +def test_binary_cmp(binary_cmp_op: Callable): + @R.function + def foo( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32") + ) -> R.Tensor((2, 3), "bool"): + gv: R.Tensor((2, 3), "bool") = binary_cmp_op(x, y) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 1), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, y]): + gv = bb.emit(binary_cmp_op(x, y)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_relax_ewise_fma(): + @R.function + def foo( + x: R.Tensor((2, 3, 4), dtype="float32"), + y: R.Tensor((2, 3, 4), dtype="float32"), + z: R.Tensor((2, 3, 4), dtype="float32"), + ) -> R.Tensor((2, 3, 4), dtype="float32"): + gv: R.Tensor((2, 3, 4), dtype="float32") = R.ewise_fma(x, y, z) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + y = relax.Var("y", R.Tensor((2, 3, 4), "float32")) + z = relax.Var("z", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, y, z]): + gv = bb.emit(relax.op.ewise_fma(x, y, z)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() From d4a7cfccdca4454dc76f97ff6518649dee759054 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 14 Feb 2023 15:01:07 -0500 Subject: [PATCH 18/81] [Unity] Relax op: statistical (#13991) This PR is about the high-level tensor computation operators in Relax. This PR includes the statistical operators. --- include/tvm/relax/attrs/statistical.h | 48 ++++ python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/op_attrs.py | 5 + python/tvm/relax/op/statistical.py | 218 ++++++++++++++++++ python/tvm/script/ir_builder/relax/ir.py | 18 ++ src/relax/op/tensor/statistical.cc | 96 ++++++++ src/relax/op/tensor/statistical.h | 92 ++++++++ tests/python/relax/test_op_statistical.py | 204 ++++++++++++++++ .../test_tvmscript_parser_op_statistical.py | 174 ++++++++++++++ 9 files changed, 856 insertions(+) create mode 100644 include/tvm/relax/attrs/statistical.h create mode 100644 python/tvm/relax/op/statistical.py create mode 100644 src/relax/op/tensor/statistical.cc create mode 100644 src/relax/op/tensor/statistical.h create mode 100644 tests/python/relax/test_op_statistical.py create mode 100644 tests/python/relax/test_tvmscript_parser_op_statistical.py diff --git a/include/tvm/relax/attrs/statistical.h b/include/tvm/relax/attrs/statistical.h new file mode 100644 index 000000000000..bb1ab2195d9a --- /dev/null +++ b/include/tvm/relax/attrs/statistical.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/statistical.h + * \brief Attributes for statistical operators. + */ +#ifndef TVM_RELAX_ATTRS_STATISTICAL_H_ +#define TVM_RELAX_ATTRS_STATISTICAL_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes for statistical operators */ +struct StatisticalAttrs : public tvm::AttrsNode { + Optional> axis; + bool keepdims; + + TVM_DECLARE_ATTRS(StatisticalAttrs, "relax.attrs.StatisticalAttrs") { + TVM_ATTR_FIELD(axis).describe("The axis or axes along which to perform the reduction."); + TVM_ATTR_FIELD(keepdims).describe( + "If this is set to `True`, the reduced axes are left in the result as dimension with size " + "one."); + } +}; // struct StatisticalAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_STATISTICAL_H_ diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 344576fe13b2..68152c2056e1 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -24,6 +24,7 @@ from .index import * from .manipulate import * from .op_attrs import * +from .statistical import * from .set import * from .ternary import * from .unary import * diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index fb64443b7e09..1fb8853040fd 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -34,6 +34,11 @@ class StridedSliceAttrs(Attrs): """Attributes used in strided_slice operator""" +@tvm._ffi.register_object("relax.attrs.StatisticalAttrs") +class StatisticalAttrs(Attrs): + """Attributes used in statistical operator""" + + @tvm._ffi.register_object("relax.attrs.Resize2DAttrs") class Resize2DAttrs(Attrs): """Attributes used in image resize2d operator""" diff --git a/python/tvm/relax/op/statistical.py b/python/tvm/relax/op/statistical.py new file mode 100644 index 000000000000..4669c783adda --- /dev/null +++ b/python/tvm/relax/op/statistical.py @@ -0,0 +1,218 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin +"""Statistical operators.""" +from typing import List, Optional, Union + +from . import _ffi_api +from ..expr import Expr + + +def max(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the max of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a max operation is performed. + The default, axis=None, will compute the max of all elements in the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.max(x, axis, keepdims) # type: ignore + + +def mean(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the mean of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a mean operation is performed. + The default, axis=None, will compute the mean of all elements in the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.mean(x, axis, keepdims) # type: ignore + + +def min(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the min of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a min operation is performed. + The default, axis=None, will compute the min of all elements in the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.min(x, axis, keepdims) # type: ignore + + +def prod(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the product of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a product is performed. + The default, axis=None, will compute the product of all elements of the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as + dimensions with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.prod(x, axis, keepdims) # type: ignore + + +def std(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the standard deviation of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a standard deviation is performed. + The default, axis=None, will compute the std of all elements of the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as + dimensions with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.std(x, axis, keepdims) # type: ignore + + +def sum(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the sum of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a sum is performed. + The default, axis=None, will sum all of the elements of the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as + dimensions with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.sum(x, axis, keepdims) # type: ignore + + +def variance(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the variance of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a variance operation is performed. + The default, axis=None, will compute the variance of all elements in the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.variance(x, axis, keepdims) # type: ignore diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index a5cb574a06f0..47779a602452 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -63,15 +63,24 @@ less_equal, log, make_closure, + max, + mean, memory, + min, multiply, negative, not_equal, null_value, print, + prod, reshape, round, shape_of, + std, + strided_slice, + sum, + take, + variance, sigmoid, sign, sin, @@ -486,7 +495,10 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "less_equal", "log", "make_closure", + "max", + "mean", "memory", + "min", "multiply", "negative", "not_equal", @@ -494,10 +506,15 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "output", "prim_value", "print", + "prod", "reshape", "round", "shape", "shape_of", + "std", + "str", + "strided_slice", + "sum", "sigmoid", "sign", "sin", @@ -511,5 +528,6 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "tan", "tanh", "tuple", + "variance", "unique", ] diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc new file mode 100644 index 000000000000..41b99fbe36c1 --- /dev/null +++ b/src/relax/op/tensor/statistical.cc @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file statistical.cc + * \brief Statistical operators. + */ + +#include "statistical.h" + +#include + +namespace tvm { +namespace relax { + +StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + + std::vector axes; + if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) { + axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value()); + } + + int out_ndim; + if (attrs->keepdims) { + out_ndim = data_sinfo->ndim; + } else if (!attrs->axis.defined()) { + out_ndim = 0; + } else if (data_sinfo->IsUnknownNdim()) { + out_ndim = kUnknownNDim; + } else { + out_ndim = data_sinfo->ndim - axes.size(); + ICHECK_GE(out_ndim, 0); + } + + // The inference rule for reduction operator output shapes: + // - axes is None, keepdims is false -> return the zero-rank shape; + // - axes is None, keepdims is true -> return the shape whose ndim is the same as input and every + // value is 1. + // - axes is not None, keepdims is false -> the returned shape does not contain the input axes. + // - axes is not None, keepdims is true -> the returned shape has value 1 at the positions of the + // input axes + const auto* data_shape = data_sinfo->shape.as(); + if (data_shape == nullptr) { + if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) { + return TensorStructInfo( + ShapeExpr(Array(out_ndim, IntImm(DataType::Int(64), /*value=*/1))), + data_sinfo->dtype); + } else { + return out_ndim == 0 ? TensorStructInfo(ShapeExpr(Array()), data_sinfo->dtype) + : TensorStructInfo(data_sinfo->dtype, out_ndim); + } + } + + Array out_shape; + out_shape.reserve(out_ndim); + for (int i = 0; i < data_sinfo->ndim; ++i) { + if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) == axes.end()) { + out_shape.push_back(data_shape->values[i]); + } else if (attrs->keepdims) { + out_shape.push_back(IntImm(DataType::Int(64), /*value=*/1)); + } + } + ICHECK_EQ(static_cast(out_shape.size()), out_ndim); + return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype); +} + +TVM_REGISTER_NODE_TYPE(StatisticalAttrs); + +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(max); +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(mean); +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(min); +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(prod); +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(std); +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(sum); +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(variance); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/statistical.h b/src/relax/op/tensor/statistical.h new file mode 100644 index 000000000000..7d322d11293c --- /dev/null +++ b/src/relax/op/tensor/statistical.h @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file statistical.h + * \brief The functions to make Relax statistical operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_STATISTICAL_H_ +#define TVM_RELAX_OP_TENSOR_STATISTICAL_H_ + +#include + +#include +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Quick helper macro + * - Expose a make function to construct the node. + * - Register op to the registry. + * \param OpName The name of operator to register. The name passed in will + * 1. be prepended with a prefix "relax.op." as the FFI identifier string for the make function, + * 2. be prepended with a prefix "relax." as the identifier string in the operator registry. + */ +#define RELAX_REGISTER_STATISTICAL_OP_INTERFACE(OpName) \ + Expr OpName(Expr x, Optional> axis, bool keepdims) { \ + ObjectPtr attrs = make_object(); \ + attrs->axis = std::move(axis); \ + attrs->keepdims = keepdims; \ + static const Op& op = Op::Get("relax." #OpName); \ + return Call(op, {std::move(x)}, Attrs{attrs}, {}); \ + } \ + TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ + TVM_REGISTER_OP("relax." #OpName) \ + .set_num_inputs(1) \ + .add_argument("x", "Tensor", "The input data tensor") \ + .set_attr("FInferStructInfo", InferStructInfoStatistical) + +/*! + * \brief Computes the maximum value of tensor elements over given axes. + * \param x The input data tensor + * \param axis Axis or axes along which a max is performed. Being `NullOpt` means to max all the + * elements of the input tensor + * \param keepdims If this is set to True, the axes which are reduced are left in the result as + * dimensions with size one. With this option, the result will broadcast correctly against the + * input tensor. + * \return The result after reduction. + */ +Expr max(Expr x, Optional> axis, bool keepdims); + +/*! \brief Computes the mean of tensor elements over given axes. */ +Expr mean(Expr x, Optional> axis, bool keepdims); + +/*! \brief Computes the min of tensor elements over given axes. */ +Expr min(Expr x, Optional> axis, bool keepdims); + +/*! \brief Computes the product of tensor elements over given axes. */ +Expr prod(Expr x, Optional> axis, bool keepdims); + +/*! \brief Computes the standard deviation of tensor elements over given axes. */ +Expr std(Expr x, Optional> axis, bool keepdims); + +/*! \brief Computes the sum of tensor elements over given axes. */ +Expr sum(Expr x, Optional> axis, bool keepdims); + +/*! \brief Computes the variance of tensor elements over given axes. */ +Expr variance(Expr x, Optional> axis, bool keepdims); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_STATISTICAL_H_ diff --git a/tests/python/relax/test_op_statistical.py b/tests/python/relax/test_op_statistical.py new file mode 100644 index 000000000000..b1bdd8e44d85 --- /dev/null +++ b/tests/python/relax/test_op_statistical.py @@ -0,0 +1,204 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + assert relax.op.max(x).op == Op.get("relax.max") + assert relax.op.mean(x).op == Op.get("relax.mean") + assert relax.op.min(x).op == Op.get("relax.min") + assert relax.op.prod(x).op == Op.get("relax.prod") + assert relax.op.std(x).op == Op.get("relax.std") + assert relax.op.sum(x).op == Op.get("relax.sum") + assert relax.op.variance(x).op == Op.get("relax.variance") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_statistical_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4, 5))) + + _check_inference(bb, relax.op.sum(x0, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32")) + _check_inference( + bb, + relax.op.sum(x0, axis=[1, 2], keepdims=True), + relax.TensorStructInfo((2, 1, 1, 5), "float32"), + ) + _check_inference(bb, relax.op.sum(x0, axis=None), relax.TensorStructInfo((), "float32")) + _check_inference( + bb, + relax.op.sum(x0, axis=None, keepdims=True), + relax.TensorStructInfo((1, 1, 1, 1), "float32"), + ) + _check_inference( + bb, relax.op.mean(x1, axis=[1, 2]), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, + relax.op.mean(x1, axis=[1, 2], keepdims=True), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference(bb, relax.op.mean(x1, axis=None), relax.TensorStructInfo((), "float32")) + _check_inference( + bb, + relax.op.mean(x1, axis=None, keepdims=True), + relax.TensorStructInfo((1, 1, 1, 1), "float32"), + ) + _check_inference( + bb, relax.op.variance(x2, axis=[1, 2]), relax.TensorStructInfo(dtype="float32") + ) + _check_inference( + bb, + relax.op.variance(x2, axis=[1, 2], keepdims=True), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference(bb, relax.op.variance(x2, axis=None), relax.TensorStructInfo((), "float32")) + _check_inference( + bb, + relax.op.variance(x2, axis=None, keepdims=True), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference(bb, relax.op.max(x3, axis=[1, 2]), relax.TensorStructInfo((2, 5), dtype="")) + _check_inference( + bb, + relax.op.max(x3, axis=[1, 2], keepdims=True), + relax.TensorStructInfo((2, 1, 1, 5), dtype=""), + ) + _check_inference(bb, relax.op.max(x3, axis=None), relax.TensorStructInfo((), dtype="")) + _check_inference( + bb, + relax.op.max(x3, axis=None, keepdims=True), + relax.TensorStructInfo((1, 1, 1, 1), dtype=""), + ) + _check_inference(bb, relax.op.prod(x0, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32")) + _check_inference( + bb, + relax.op.prod(x0, axis=[1, 2], keepdims=True), + relax.TensorStructInfo((2, 1, 1, 5), "float32"), + ) + _check_inference(bb, relax.op.std(x0, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32")) + _check_inference( + bb, + relax.op.std(x0, axis=[1, 2], keepdims=True), + relax.TensorStructInfo((2, 1, 1, 5), "float32"), + ) + _check_inference(bb, relax.op.sum(x0, axis=[-1, -4]), relax.TensorStructInfo((3, 4), "float32")) + _check_inference(bb, relax.op.sum(x0, axis=[]), relax.TensorStructInfo((2, 3, 4, 5), "float32")) + + +def test_statistical_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d = tir.Var("d", "int64") + x = relax.Var("x", R.Tensor((a, b, c, d), "float32")) + + _check_inference(bb, relax.op.min(x, axis=[1, 2]), relax.TensorStructInfo((a, d), "float32")) + _check_inference( + bb, + relax.op.min(x, axis=[1, 2], keepdims=True), + relax.TensorStructInfo((a, 1, 1, d), "float32"), + ) + _check_inference(bb, relax.op.min(x, axis=None), relax.TensorStructInfo((), "float32")) + _check_inference( + bb, + relax.op.min(x, axis=None, keepdims=True), + relax.TensorStructInfo((1, 1, 1, 1), "float32"), + ) + + +def test_statistical_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference(bb, relax.op.max(x0), relax.TensorStructInfo((), dtype="float32")) + _check_inference( + bb, relax.op.max(x0, keepdims=True), relax.TensorStructInfo((1, 1, 1, 1), dtype="float32") + ) + _check_inference( + bb, relax.op.max(x0, axis=[2, 3]), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, + relax.op.max(x0, axis=[2, 3], keepdims=True), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference(bb, relax.op.max(x1), relax.TensorStructInfo((), dtype="float32")) + _check_inference(bb, relax.op.max(x1, keepdims=True), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.max(x1, axis=[2, 3]), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.max(x1, axis=[2, 3], keepdims=True), relax.TensorStructInfo(dtype="float32") + ) + + +def test_statistical_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) + + _check_inference(bb, relax.op.sum(x0), relax.TensorStructInfo((), "float16")) + _check_inference(bb, relax.op.sum(x1), relax.TensorStructInfo((), "int8")) + + +def test_statistical_infer_struct_info_axis_out_of_range_repetitive(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.mean(x0, axis=[4])) + with pytest.raises(TVMError): + bb.normalize(relax.op.mean(x1, axis=[3, 3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.mean(x0, axis=[-1, 3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.mean(x1, axis=[-4, -4])) + with pytest.raises(TVMError): + bb.normalize(relax.op.mean(x0, axis=[-5])) + + +def test_statistical_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.variance(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.variance(x1)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_statistical.py b/tests/python/relax/test_tvmscript_parser_op_statistical.py new file mode 100644 index 000000000000..221d2a17a8b8 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_statistical.py @@ -0,0 +1,174 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_sum(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3), "float32"): + gv: R.Tensor((1, 3), "float32") = R.sum(x, axis=[1, 3]) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.sum(x, axis=[1, 3])) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_sum_without_specified_axis(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.sum(x) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.sum(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_sum_keep_dims(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 1, 3, 1), "float32"): + gv: R.Tensor((1, 1, 3, 1), "float32") = R.sum(x, axis=[1, 3], keepdims=True) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.sum(x, axis=[1, 3], keepdims=True)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_mean(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3), "float32"): + gv: R.Tensor((1, 3), "float32") = R.mean(x, axis=[1, 3]) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.mean(x, axis=[1, 3])) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_variance(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1,), "float32"): + gv: R.Tensor((1,), "float32") = R.variance(x, axis=[-1, -2, -3]) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.variance(x, axis=[-1, -2, -3])) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_max(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 1, 1, 1), "float32"): + gv: R.Tensor((1, 1, 1, 1), "float32") = R.variance(x, axis=[-1, -2, -3], keepdims=True) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.variance(x, axis=[-1, -2, -3], keepdims=True)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_min(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3, 4), "float32"): + gv: R.Tensor((1, 3, 4), "float32") = R.min(x, axis=1) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.min(x, axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_prod(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3, 4), "float32"): + gv: R.Tensor((1, 3, 4), "float32") = R.prod(x, axis=1) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.prod(x, axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_std(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3, 4), "float32"): + gv: R.Tensor((1, 3, 4), "float32") = R.std(x, axis=1) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.std(x, axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() From fcf4f59db8ee3fe276af63364d7005baedcc5f9e Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 14 Feb 2023 15:03:22 -0500 Subject: [PATCH 19/81] [Unity] Relax op: neural networks (#13993) This PR is about the high-level tensor computation operators in Relax. This PR includes the neural network operators. --- include/tvm/relax/attrs/nn.h | 190 ++++ python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/nn/__init__.py | 19 + python/tvm/relax/op/nn/_ffi_api.py | 20 + python/tvm/relax/op/nn/nn.py | 524 ++++++++++ python/tvm/relax/op/op_attrs.py | 35 + python/tvm/script/ir_builder/relax/ir.py | 2 + src/relax/op/nn/convolution.cc | 146 +++ src/relax/op/nn/convolution.h | 63 ++ src/relax/op/nn/nn.cc | 245 +++++ src/relax/op/nn/nn.h | 81 ++ src/relax/op/nn/pooling.cc | 184 ++++ src/relax/op/nn/pooling.h | 46 + tests/python/relax/test_op_nn.py | 929 ++++++++++++++++++ tests/python/relax/test_op_nn_convolution.py | 429 ++++++++ tests/python/relax/test_op_nn_pooling.py | 429 ++++++++ .../relax/test_tvmscript_parser_op_nn.py | 193 ++++ 17 files changed, 3536 insertions(+) create mode 100644 include/tvm/relax/attrs/nn.h create mode 100644 python/tvm/relax/op/nn/__init__.py create mode 100644 python/tvm/relax/op/nn/_ffi_api.py create mode 100644 python/tvm/relax/op/nn/nn.py create mode 100644 src/relax/op/nn/convolution.cc create mode 100644 src/relax/op/nn/convolution.h create mode 100644 src/relax/op/nn/nn.cc create mode 100644 src/relax/op/nn/nn.h create mode 100644 src/relax/op/nn/pooling.cc create mode 100644 src/relax/op/nn/pooling.h create mode 100644 tests/python/relax/test_op_nn.py create mode 100644 tests/python/relax/test_op_nn_convolution.py create mode 100644 tests/python/relax/test_op_nn_pooling.py create mode 100644 tests/python/relax/test_tvmscript_parser_op_nn.py diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h new file mode 100644 index 000000000000..694a51070683 --- /dev/null +++ b/include/tvm/relax/attrs/nn.h @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/nn.h + * \brief Attributes for neural network operators. + */ +#ifndef TVM_RELAX_ATTRS_NN_H_ +#define TVM_RELAX_ATTRS_NN_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in Conv2d operator */ +struct Conv2DAttrs : public tvm::AttrsNode { + Array strides; + Array padding; + Array dilation; + int groups; + String data_layout; + String kernel_layout; + String out_layout; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Conv2DAttrs, "relax.attrs.Conv2DAttrs") { + TVM_ATTR_FIELD(strides).describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding).describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilation).describe( + "Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).describe( + "Number of groups to split the input into for grouped convolution. The number of input and " + "output channels should be divisible by the number of groups."); + TVM_ATTR_FIELD(data_layout) + .describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .describe( + "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .describe( + "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(out_dtype).describe( + "Output data type, set to explicit type under mixed precision setting"); + } +}; // struct Conv2dAttrs + +/*! \brief Attributes used in max_pool2d operator */ +struct MaxPool2DAttrs : public tvm::AttrsNode { + Array pool_size; + Array strides; + Array padding; + Array dilation; + bool ceil_mode; + String layout; + String out_layout; + + TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relax.attrs.MaxPool2DAttrs") { + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides).describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(dilation).describe("Specifies the dilation of the convolution."); + TVM_ATTR_FIELD(padding).describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(ceil_mode).describe( + "A boolean indicating if use ceil or floor to compute the output shape. By using ceil, " + "every element in the input tensor will be covered by a sliding window."); + TVM_ATTR_FIELD(layout).describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .describe( + "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); + } +}; // struct MaxPool2dAttrs + +/*! \brief Attributes for 2d adaptive pool operator */ +struct AdaptivePool2DAttrs : public tvm::AttrsNode { + Optional> output_size; + String layout; + String out_layout; + + TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relax.attrs.AdaptivePool2DAttrs") { + TVM_ATTR_FIELD(output_size).describe("Output height and width."); + TVM_ATTR_FIELD(layout).describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(out_layout) + .describe( + "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); + } +}; // struct AdaptivePool2DAttrs + +/*! \brief Attributes used in softmax operators */ +struct SoftmaxAttrs : public tvm::AttrsNode { + int axis; + + TVM_DECLARE_ATTRS(SoftmaxAttrs, "relax.attrs.SoftmaxAttrs") { + TVM_ATTR_FIELD(axis).describe("The axis to sum over when computing softmax."); + } +}; + +/*! \brief Attributes used in batch_norm operator */ +struct BatchNormAttrs : public tvm::AttrsNode { + int axis; + double epsilon; + bool center; + bool scale; + + TVM_DECLARE_ATTRS(BatchNormAttrs, "relax.attrs.BatchNormAttrs") { + TVM_ATTR_FIELD(axis).describe("The axis along which the normalization is applied."); + TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero"); + TVM_ATTR_FIELD(center).describe( + "Indicating if the beta offset will be added to the normalized tensor."); + TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied."); + } +}; // struct BatchNormAttrs + +/*! \brief Attributes used in layer_norm operator */ +struct LayerNormAttrs : public tvm::AttrsNode { + Array axes; + double epsilon; + bool center; + bool scale; + + TVM_DECLARE_ATTRS(LayerNormAttrs, "relax.attrs.LayerNormAttrs") { + TVM_ATTR_FIELD(axes).describe("The axes that along which the normalization is applied."); + TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero"); + TVM_ATTR_FIELD(center).describe( + "Indicating if the beta offset will be added to the normalized tensor."); + TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied."); + } +}; // struct LayerNormAttrs + +/*! \brief Attributes used in dropout operator */ +struct DropoutAttrs : public tvm::AttrsNode { + double rate; + + TVM_DECLARE_ATTRS(DropoutAttrs, "relax.attrs.DropoutAttrs") { + TVM_ATTR_FIELD(rate).describe( + "Fraction of the input that gets dropped out during training time"); + } +}; // struct DropoutAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_NN_H_ diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 68152c2056e1..6c6fffc7c65e 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -31,6 +31,7 @@ from . import builtin from . import image from . import memory +from . import nn def _register_op_make(): diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py new file mode 100644 index 000000000000..af2aa106bca7 --- /dev/null +++ b/python/tvm/relax/op/nn/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import +"""Neural network related operators.""" +from .nn import * diff --git a/python/tvm/relax/op/nn/_ffi_api.py b/python/tvm/relax/op/nn/_ffi_api.py new file mode 100644 index 000000000000..1785345ac1b1 --- /dev/null +++ b/python/tvm/relax/op/nn/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Constructor APIs""" +import tvm._ffi + +tvm._ffi._init_api("relax.op.nn", __name__) diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py new file mode 100644 index 000000000000..cdf0e9646492 --- /dev/null +++ b/python/tvm/relax/op/nn/nn.py @@ -0,0 +1,524 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax Neural Network (NN) operators""" +from typing import List, Optional, Tuple, Union + +from tvm import DataType + +from . import _ffi_api +from ...expr import Expr + + +def conv2d( + data: Expr, + weight: Expr, + strides: Union[int, Tuple[int, int]] = (1, 1), + padding: Union[int, Tuple[int, ...]] = (0, 0), + dilation: Union[int, Tuple[int, int]] = (1, 1), + groups: int = 1, + data_layout: str = "NCHW", + kernel_layout: str = "OIHW", + out_layout: Optional[str] = None, + out_dtype: Optional[Union[str, DataType]] = None, +) -> Expr: + r"""2D convolution. + + This operator takes the weight as the convolution kernel + and convolves it with data to produce an output. + + + In the default case, where the data_layout is `NCHW` + and kernel_layout is `OIHW`, conv2d takes in + a data Tensor with shape `(batch_size, in_channels, height, width)`, + and a weight Tensor with shape `(channels, in_channels, kernel_h, kernel_w)`, + where `kernel_h` and `kernel_w` is the lengths of the `H` and `W` kernel dimensions, + to produce an output Tensor with the following rule: + + .. math:: + + \mbox{out}[b, c, y, x] = \sum_{dy, dx, k} + \mbox{data}[b, k, \mbox{strides}[0] * y + dy, \mbox{strides}[1] * x + dx] * + \mbox{weight}[c, k, dy, dx] + + Padding and dilation are applied to data and weight respectively before the computation. + This operator accepts data layout specification. + Semantically, the operator will convert the layout to the canonical layout + (`NCHW` for data and `OIHW` for weight), perform the computation, + then convert to the out_layout. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + weight : relax.Expr + The weight expressions. + + strides : Union[int, Tuple[int, int]] + The strides of convolution. It is required to have length either 1 or 2. + + padding : Union[int, Tuple[int, ...]] + The padding of convolution on both sides of inputs before convolution. + It is required to have length either 1, 2 or 4. + + dilation : Union[int, Tuple[int, int]] + Specifies the dilation rate to be used for dilated convolution. + It is required to have length either 1 or 2. + + groups : int + Number of groups to split the input into for grouped convolution. + The number of input and output channels should be divisible by the number of groups. + + data_layout : str + Layout of the input. + + kernel_layout : str + Layout of the weight. + + out_layout : Optional[str] + Layout of the output. If not specified, it is the same as data_layout + + out_dtype : Optional[Union[str, DataType]] + Specifies the output data type for mixed precision conv2d. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(strides, int): + strides = (strides, strides) + if isinstance(dilation, int): + dilation = (dilation, dilation) + if isinstance(padding, int): + padding = (padding, padding, padding, padding) + + return _ffi_api.conv2d( # type: ignore + data, + weight, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + out_layout, + out_dtype, + ) + + +def max_pool2d( + data: Expr, + pool_size: Union[int, Tuple[int, int]] = (1, 1), + strides: Union[int, Tuple[int, int]] = (1, 1), + padding: Union[int, Tuple[int, ...]] = (0, 0), + dilation: Union[int, Tuple[int, int]] = (1, 1), + ceil_mode: bool = False, + layout: str = "NCHW", + out_layout: Optional[str] = None, +) -> Expr: + r"""2D maximum pooling operator. + + This operator takes data as input and does 2D max value calculation + with in pool_size sized window by striding defined by stride + + + In the default case, where the data_layout is `NCHW` + a data Tensor with shape `(batch_size, in_channels, height, width)`, + to produce an output Tensor with the following rule: + + with data of shape (b, c, h, w) and pool_size (kh, kw) + + .. math:: + + \mbox{out}(b, c, y, x) = \max_{m=0, \ldots, kh-1} \max_{n=0, \ldots, kw-1} + \mbox{data}(b, c, \mbox{stride}[0] * y + m, \mbox{stride}[1] * x + n) + + Padding is applied to data before the computation. + ceil_mode is used to take ceil or floor while computing out shape. + This operator accepts data layout specification. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + pool_size : Union[int, Tuple[int, int]] + The size of window for pooling. It is required to have length either 1 or 2. + + strides : Union[int, Tuple[int, int]] + The strides of pooling. It is required to have length either 1 or 2. + + padding : Union[int, Tuple[int, ...]] + The padding for pooling. It is required to have length either 1, 2 or 4. + + dilation : Union[int, Tuple[int, int]] + The dilation of pooling. It is required to have length either 1 or 2. + + ceil_mode : bool + A boolean indicating if use ceil or floor to compute the output shape. + By using ceil, every element in the input tensor will be covered by a sliding window. + + layout : str + Layout of the input. + + out_layout : Optional[str] + Layout of the output. If not specified, it is the same as data_layout + + Returns + ------- + result : Expr + The computed result. + """ + if isinstance(pool_size, int): + pool_size = (pool_size, pool_size) + if isinstance(strides, int): + strides = (strides, strides) + if isinstance(dilation, int): + dilation = (dilation, dilation) + if isinstance(padding, int): + padding = (padding, padding, padding, padding) + + return _ffi_api.max_pool2d( # type: ignore + data, pool_size, strides, padding, dilation, ceil_mode, layout, out_layout + ) + + +def adaptive_avg_pool2d( + data: Expr, + output_size: Optional[Union[int, Tuple[int, int]]] = None, + layout: str = "NCHW", + out_layout: Optional[str] = None, +) -> Expr: + r"""2D adaptive average pooling operator. This operator is experimental. + + This operator takes data as input and does 2D average value calculation + across each window represented by WxH. + + + In the default case, where the data_layout is `NCHW` + a data Tensor with shape `(batch_size, in_channels, height, width)`, + to produce an output Tensor with shape + (batch_size, in_channels, output_height, output_width). + + The pooling kernel and stride sizes are automatically chosen for + desired output sizes. + + For output_size: + If this argument is not provided, input height and width will be used + as output height and width. + + If a single integer is provided for output_size, the output size is + (N x C x output_size x output_size) for any input (NCHW). + + If a tuple of integers (height, width) are provided for output_size, + the output size is (N x C x height x width) for any input (NCHW). + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + output_size : Optional[Union[int, Tuple[int, int]]] + Output height and width. + If not specified, it will be the same as the input height and width. + If specified, it is required to have length either 1 or 2. + + layout : str + Layout of the input. + + out_layout : Optional[str] + Layout of the output. If not specified, it is the same as data_layout + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(output_size, int): + output_size = (output_size, output_size) + return _ffi_api.adaptive_avg_pool2d(data, output_size, layout, out_layout) # type: ignore + + +def relu(data: Expr) -> Expr: + """Rectified linear unit. + + .. math:: + text{ReLU}(x) = max(x, 0) + + Parameters + ---------- + data : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.relu(data) # type: ignore + + +def gelu(data: Expr) -> Expr: + """Gaussian Error Linear Units function + + .. math:: + text{GeLU}(x) = 0.5 * x * (1 + erf(x * 0.5**0.5)) + + where :math:`erf` is the Gauss Error function. + + Parameters + ---------- + data : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.gelu(data) # type: ignore + + +def silu(data: Expr) -> Expr: + """Sigmoid Linear Unit function + + .. math:: + text{SiLU}(x) = x * sigmoid(x) + + Parameters + ---------- + data : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.silu(data) # type: ignore + + +def softmax(data: Expr, axis: int = -1) -> Expr: + r"""Computes softmax. + + .. math:: text{softmax}(x)_i = frac{exp(x_i)}{\sum_j exp(x_j)} + + Parameters + ---------- + data: relax.Expr + The input data to the operator. + + axis: int + The axis to sum over when computing softmax. + If not specified, it is by default the last axis of the input tensor. + Supports negative indexing. + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.softmax(data, axis) # type: ignore + + +def batch_norm( + data: Expr, + gamma: Expr, + beta: Expr, + moving_mean: Expr, + moving_var: Expr, + axis: int, + epsilon: float = 1e-5, + center: bool = True, + scale: bool = True, +) -> Expr: + r""" + Batch normalization layer (Ioffe and Szegedy, 2014). + Normalizes the input at each batch, i.e. applies a transformation + that maintains the mean activation close to 0 and the activation + standard deviation close to 1. + + .. math:: + + data\_mean[i] = mean(data[:,i,:,...]) \\ + data\_var[i] = var(data[:,i,:,...]) + + Then compute the normalized output, which has the same shape as input, as following: + + .. math:: + + out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} + * gamma[i] + beta[i] + + Both *mean* and *var* returns a scalar by treating the input as a vector. + + Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta`` + have shape *(k,)*. + + Besides the inputs and the outputs, this operator accepts two auxiliary + states, ``moving_mean`` and ``moving_var``, which are *k*-length + vectors. They are global statistics for the whole dataset, which are updated by + + .. code:: python + + moving_mean = moving_mean * momentum + data_mean * (1 - momentum) + moving_var = moving_var * momentum + data_var * (1 - momentum) + + The parameter ``axis`` specifies which axis of the input shape denotes + the 'channel' (separately normalized groups). The default is 1. + Specifying -1 sets the channel axis to be the last item in the input shape. + + .. note:: + + This operator can be optimized away for inference. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + gamma : relax.Expr + The gamma scale factor. + + beta : relax.Expr + The beta offset factor. + + moving_mean : relax.Expr + Running mean of input. + + moving_var : relax.Expr + Running variance of input. + + axis : int + The axis along which the normalization is applied. + + epsilon : float + Small float added to variance to avoid dividing by zero. + + center : bool + Indicating if the beta offset will be added to the normalized tensor. + + scale : bool + Indicating if the gamma scale will be multiplied. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.batch_norm( # type: ignore + data, gamma, beta, moving_mean, moving_var, axis, epsilon, center, scale + ) + + +def layer_norm( + data: Expr, + gamma: Expr, + beta: Expr, + axes: Union[int, List[int]], + epsilon: float = 1e-5, + center: bool = True, + scale: bool = True, +) -> Expr: + r""" + Layer normalization (Lei Ba and et al., 2016). + Applies layer normalization to the n-dimensional input array. + This operator takes an n-dimensional input array and normalizes + the input using the given axis: + + .. math:: + + out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}} + * gamma + beta + + Unlike batch normalization, the mean and var are computed along the channel dimension. + + Assume the input has size k on axis 1, then both gamma and beta have shape (k,). + + .. note:: + + This operator can be optimized away for inference. + + Parameters + ---------- + data : relax.Expr + Input to which layer_norm will be applied. + + gamma : relax.Expr + The gamma scale factor. + + beta : relax.Expr + The beta offset factor. + + axes : Union[int, List[int]] + The axes that along which the normalization is applied. + + epsilon : float + Small float added to variance to avoid dividing by zero. + + center : bool + Indicating if the beta offset will be added to the normalized tensor. + + scale : bool + Indicating if the gamma scale will be multiplied. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axes, int): + axes = [axes] + return _ffi_api.layer_norm(data, gamma, beta, axes, epsilon, center, scale) # type: ignore + + +def dropout(data: Expr, rate: float = 0.5) -> Expr: + """Applies the dropout operation to the input tensor. + + During training, each element of the input is set to zero with + probability ``p``. The whole array is scaled by ``1/(1-p)`` + to keep the expected sum of the input unchanged. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + rate : float + The probability for an element to be reset to 0. + + Returns + ------- + result : relax.Expr + The result of dropout, which is a tuple of two tensors. + The first one is the original tensor and the second one is a + mask tensor (1.0 where element not dropped, 0.0 where dropped) + """ + return _ffi_api.dropout(data, rate) # type: ignore diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 1fb8853040fd..68f84b3514a9 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -34,6 +34,41 @@ class StridedSliceAttrs(Attrs): """Attributes used in strided_slice operator""" +@tvm._ffi.register_object("relax.attrs.Conv2DAttrs") +class Conv2DAttrs(Attrs): + """Attributes for nn.conv2d""" + + +@tvm._ffi.register_object("relax.attrs.MaxPool2DAttrs") +class MaxPool2DAttrs(Attrs): + """Attributes for nn.max_pool2d""" + + +@tvm._ffi.register_object("relax.attrs.AdaptivePool2DAttrs") +class AdaptivePool2DAttrs(Attrs): + """Attributes for 2d adaptive pool operator""" + + +@tvm._ffi.register_object("relax.attrs.SoftmaxAttrs") +class SoftmaxAttrs(Attrs): + """Attributes for nn.softmax""" + + +@tvm._ffi.register_object("relax.attrs.BatchNormAttrs") +class BatchNormAttrs(Attrs): + """Attributes used in batch_norm operator""" + + +@tvm._ffi.register_object("relax.attrs.LayerNormAttrs") +class LayerNormAttrs(Attrs): + """Attributes used in layer_norm operator""" + + +@tvm._ffi.register_object("relax.attrs.DropoutAttrs") +class DropoutAttrs(Attrs): + """Attributes for dropout operator""" + + @tvm._ffi.register_object("relax.attrs.StatisticalAttrs") class StatisticalAttrs(Attrs): """Attributes used in statistical operator""" diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 47779a602452..1f0e31428c63 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -93,6 +93,7 @@ tan, tanh, unique, + nn, ) from tvm.relax.struct_info import StructInfo from tvm.relax.utils import args_converter @@ -530,4 +531,5 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "tuple", "variance", "unique", + "nn", ] diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc new file mode 100644 index 000000000000..a3ddd3e350a0 --- /dev/null +++ b/src/relax/op/nn/convolution.cc @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/op/nn/convolution.cc + * \brief Convolution operators + */ + +#include "convolution.h" + +#include + +namespace tvm { +namespace relax { + +/* relax.nn.conv2d */ +TVM_REGISTER_NODE_TYPE(Conv2DAttrs); + +Expr conv2d(Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, String data_layout, String kernel_layout, + Optional out_layout, DataType out_dtype) { + padding = GetCompletePadding2D(std::move(padding)); + if (strides.size() == 1) { + strides.push_back(strides[0]); + } + if (dilation.size() == 1) { + dilation.push_back(dilation[0]); + } + + CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, " + "the given number of groups is " + << groups; + CHECK_EQ(strides.size(), 2) + << "The input strides length is expected to be 2. However, the given strides is " << strides; + CHECK_EQ(dilation.size(), 2) + << "The input dilation length is expected to be 2. However, the given dilation is " + << dilation; + return MakeConv(std::move(data), std::move(weight), std::move(strides), + std::move(padding), std::move(dilation), groups, data_layout, + std::move(kernel_layout), out_layout.value_or(data_layout), + out_dtype, /*op_name=*/"relax.nn.conv2d"); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.conv2d").set_body_typed(conv2d); + +StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo data_sinfo = input_sinfo[0]; + TensorStructInfo weight_sinfo = input_sinfo[1]; + + const auto* attrs = call->attrs.as(); + auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->data_layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"data"); + auto [weight_layout, weight2OIHW] = CheckTensorLayout(call, ctx, attrs->kernel_layout, // + /*tgt_layout=*/"OIHW", // + /*tensor_name=*/"kernel"); + auto [out_layout, out2NCHW] = CheckTensorLayout(call, ctx, attrs->out_layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"output"); + + Optional data_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + Optional weight_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); + + DataType out_dtype = attrs->out_dtype.is_void() + ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) + : attrs->out_dtype; + if (!data_shape.defined() || !weight_shape.defined()) { + return TensorStructInfo(out_dtype, out_layout.ndim()); + } + + Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + Array weight_OIHW_shape = weight2OIHW.ForwardShape(weight_shape.value()->values); + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + PrimExpr input_channel_data = data_NCHW_shape[1]; + PrimExpr input_channel_kernel = weight_OIHW_shape[1]; + if (analyzer->CanProve(input_channel_data != input_channel_kernel * attrs->groups)) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "The channel size of the data should equal to the product of input channel size of the " + "weight and the number of groups. However, the data channel size is " + << input_channel_data << " while the weight input channel size and number of groups are " + << input_channel_kernel << " and " << attrs->groups); + } else if (!analyzer->CanProveEqual(input_channel_data, input_channel_kernel * attrs->groups)) { + // Todo(relax-team): Trust the input shape at this moment, and revisit + // this condition with runtime shape check + } + if (analyzer->CanProve(floormod(weight_OIHW_shape[0], attrs->groups) != 0)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Conv2d expects the number of output channels to be divisible by the " + "number of groups. However, the number of output channels is " + << weight_OIHW_shape[0] << " while the number of groups is " << attrs->groups); + } else if (!analyzer->CanProveEqual(floormod(weight_OIHW_shape[0], attrs->groups), 0)) { + // Todo(relax-team): Trust the input shape at this moment, and revisit + // this condition with runtime shape check + } + + PrimExpr input_h = data_NCHW_shape[2]; + PrimExpr input_w = data_NCHW_shape[3]; + PrimExpr kernel_h = weight_OIHW_shape[2]; + PrimExpr kernel_w = weight_OIHW_shape[3]; + PrimExpr padding_h = attrs->padding[0] + attrs->padding[2]; + PrimExpr padding_w = attrs->padding[1] + attrs->padding[3]; + + std::vector out_NCHW_shape; + out_NCHW_shape.resize(4); + out_NCHW_shape[0] = data_NCHW_shape[0]; + out_NCHW_shape[1] = weight_OIHW_shape[0]; + + PrimExpr numerator_h = input_h + padding_h - attrs->dilation[0] * (kernel_h - 1) - 1; + PrimExpr numerator_w = input_w + padding_w - attrs->dilation[1] * (kernel_w - 1) - 1; + out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[0]) + 1); + out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1); + + Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); + return TensorStructInfo(ShapeExpr(out_shape), out_dtype); +} + +TVM_REGISTER_OP("relax.nn.conv2d") + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_attrs_type() + .set_attr("FInferStructInfo", InferStructInfoConv2d); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h new file mode 100644 index 000000000000..a65617b48d90 --- /dev/null +++ b/src/relax/op/nn/convolution.h @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file convolution.h + * \brief The functions to make Relax neural network convolution operator calls. + */ + +#ifndef TVM_RELAX_OP_NN_CONVOLUTION_H_ +#define TVM_RELAX_OP_NN_CONVOLUTION_H_ + +#include + +#include +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +template +inline Expr MakeConv(Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, String data_layout, String kernel_layout, + String out_layout, DataType out_dtype, std::string op_name) { + auto attrs = make_object(); + attrs->strides = ConvertIntImmToInt64(strides); + attrs->padding = ConvertIntImmToInt64(padding); + attrs->dilation = ConvertIntImmToInt64(dilation); + attrs->groups = groups; + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->out_layout = std::move(out_layout); + attrs->out_dtype = std::move(out_dtype); + const Op& op = Op::Get(op_name); + return Call(op, {data, weight}, Attrs(attrs), {}); +} + +/*! \brief 2D convolution */ +Expr conv2d(Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, String data_layout, String kernel_layout, + Optional out_layout, DataType out_dtype); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_NN_CONVOLUTION_H_ diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc new file mode 100644 index 000000000000..66ae10fe6ccd --- /dev/null +++ b/src/relax/op/nn/nn.cc @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "nn.h" + +#include +#include + +namespace tvm { +namespace relax { + +/* relax.nn.relu */ +RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(relu, "nn.relu", /*require_float_dtype=*/false); + +/* relax.nn.gelu */ +RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(gelu, "nn.gelu", /*require_float_dtype=*/true); + +/* relax.nn.silu */ +RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(silu, "nn.silu", /*require_float_dtype=*/true); + +/* relax.nn.softmax */ +TVM_REGISTER_NODE_TYPE(SoftmaxAttrs); + +Expr softmax(Expr data, int axis) { + auto attrs = make_object(); + attrs->axis = axis; + static const Op& op = Op::Get("relax.nn.softmax"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.softmax").set_body_typed(softmax); + +StructInfo InferStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + if (data_sinfo->IsUnknownNdim()) { + return data_sinfo; + } + if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) { + ctx->ReportFatal(Diagnostic::Error(call) << "Softmax requires the input tensor to have float " + "dtype. However, the given input dtype is " + << data_sinfo->dtype); + } + const auto* attrs = call->attrs.as(); + NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis); + + return data_sinfo; +} + +TVM_REGISTER_OP("relax.nn.softmax") + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_attrs_type() + .set_attr("FInferStructInfo", InferStructInfoSoftmax); + +bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, + const Array& input_sinfo, Array axes) { + Op op = Downcast(call->op); + int n_input = op->arguments.size(); + + TensorStructInfo data_sinfo = input_sinfo[0]; + + std::vector axes_non_neg; + if (!data_sinfo->IsUnknownNdim()) { + axes_non_neg = NormalizeAxes(call, ctx, data_sinfo->ndim, axes); + } + int n_axis = axes.size(); + if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) { + ctx->ReportFatal( + Diagnostic::Error(call) + << op << " requires the input data to have float dtype. However, the given data dtype is " + << data_sinfo->dtype); + } + for (int i = 1; i < n_input; ++i) { + if (input_sinfo[i]->dtype != data_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << op + << " requires all the input tensors to have the same dtype. However, the " + << op->arguments[i]->name << " has dtype " << input_sinfo[i]->dtype + << " which is other than the input data's dtype " << data_sinfo->dtype); + } else if (input_sinfo[i]->ndim != n_axis) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " requires the input " << op->arguments[i]->name + << " to have as many dimensions as the length of input axes. However, the " + "given one has ndim " + << input_sinfo[i]->ndim << ", which is other than the length of axes " + << n_axis); + } + } + + std::vector> axis_lengths; + axis_lengths.reserve(n_input); + if (const auto* data_shape = data_sinfo->shape.as()) { + std::vector lengths; + lengths.reserve(n_axis); + for (int d = 0; d < n_axis; ++d) { + lengths.push_back(data_shape->values[axes_non_neg[d]]); + } + axis_lengths.push_back(lengths); + } + for (int i = 1; i < n_input; ++i) { + if (const auto* shape = input_sinfo[i]->shape.as()) { + axis_lengths.push_back(shape->values); + } + } + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + for (int i = 1; i < static_cast(axis_lengths.size()); ++i) { + for (int d = 0; d < n_axis; ++d) { + if (analyzer->CanProve(axis_lengths[0][d] != axis_lengths[i][d])) { + ctx->ReportFatal(Diagnostic::Error(call) + << op + << " requires the input gamma, beta, etc., to have size same as the " + "lengths of the data on the given axes. However, there exists " + << axis_lengths[0] << " and " << axis_lengths[i] << " that are unequal."); + } else if (!analyzer->CanProveEqual(axis_lengths[0][d], axis_lengths[i][d])) { + return true; + } + } + } + return false; +} + +/* relax.nn.batch_norm */ +TVM_REGISTER_NODE_TYPE(BatchNormAttrs); + +Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // + int axis, double epsilon, bool center, bool scale) { + ObjectPtr attrs = make_object(); + attrs->axis = axis; + attrs->epsilon = epsilon; + attrs->center = center; + attrs->scale = scale; + + static const Op& op = Op::Get("relax.nn.batch_norm"); + return Call(op, + {std::move(data), std::move(gamma), std::move(beta), std::move(moving_mean), + std::move(moving_var)}, + Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.batch_norm").set_body_typed(batch_norm); + +StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + + const auto* attrs = call->attrs.as(); + bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, {attrs->axis}); + + DataType dtype = input_sinfo[0]->dtype; + if (unknown_shape) { + return TupleStructInfo({TensorStructInfo(dtype, input_sinfo[0]->ndim), + TensorStructInfo(dtype, /*ndim=*/1), + TensorStructInfo(dtype, /*ndim=*/1)}); + } else { + return TupleStructInfo({input_sinfo[0], input_sinfo[3], input_sinfo[4]}); + } +} + +TVM_REGISTER_OP("relax.nn.batch_norm") + .set_attrs_type() + .set_num_inputs(5) + .add_argument("data", "Tensor", "Input to which batch_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .add_argument("moving_mean", "Tensor", "Running mean of input.") + .add_argument("moving_var", "Tensor", "Running variance of input.") + .set_attr("FInferStructInfo", InferStructInfoBatchNorm); + +/* relax.nn.layer_norm */ +TVM_REGISTER_NODE_TYPE(LayerNormAttrs); + +Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double epsilon, bool center, + bool scale) { + ObjectPtr attrs = make_object(); + attrs->axes = std::move(axes); + attrs->epsilon = epsilon; + attrs->center = center; + attrs->scale = scale; + + static const Op& op = Op::Get("relax.nn.layer_norm"); + return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.layer_norm").set_body_typed(layer_norm); + +StructInfo InferStructInfoLayerNorm(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + + const auto* attrs = call->attrs.as(); + bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes); + + return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim) + : input_sinfo[0]; +} + +TVM_REGISTER_OP("relax.nn.layer_norm") + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which batch_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .set_attr("FInferStructInfo", InferStructInfoLayerNorm); + +/* relax.nn.dropout */ +TVM_REGISTER_NODE_TYPE(DropoutAttrs); + +Expr dropout(Expr data, double rate) { + ObjectPtr attrs = make_object(); + attrs->rate = rate; + + static const Op& op = Op::Get("relax.nn.dropout"); + return Call(op, {std::move(data)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.dropout").set_body_typed(dropout); + +StructInfo InferStructInfoDropout(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + return TupleStructInfo({data_sinfo, data_sinfo}); +} + +TVM_REGISTER_OP("relax.nn.dropout") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "Input to which dropout will be applied.") + .set_attr("FInferStructInfo", InferStructInfoDropout); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h new file mode 100644 index 000000000000..df2b978fc296 --- /dev/null +++ b/src/relax/op/nn/nn.h @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file nn.h + * \brief The functions to make Relax neural network operator calls. + */ + +#ifndef TVM_RELAX_OP_NN_NN_H_ +#define TVM_RELAX_OP_NN_NN_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Quick helper macro to + * - expose a make-function interface which construct the call node. + * - register op to the registry. + * \param OpName The name of operator to register. + * \param OpRegName The identifier of the operator in the registry. + * \param RequireFloatDtype A boolean indicating if the input is required to have float dtype. + */ +#define RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(OpName, OpRegName, RequireFloatDtype) \ + RELAX_REGISTER_UNARY_OP(OpRegName).set_attr( \ + "FInferStructInfo", InferStructInfoUnaryArith); \ + RELAX_UNARY_OP_INTERFACE(OpName, OpRegName); + +/*! \brief Rectified linear unit. */ +Expr relu(Expr data); + +/*! \brief Gaussian Error Linear Units function. */ +Expr gelu(Expr data); + +/*! \brief Sigmoid Linear Unit function. */ +Expr silu(Expr data); + +/*! \brief Softmax function. */ +Expr softmax(Expr data, int axis); + +/*! \brief Compute batch normalization. */ +Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // + int axis, double epsilon, bool center, bool scale); + +/*! \brief Compute layer normalization. */ +Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double epsilon, bool center, + bool scale); + +/*! + * \brief Applies the dropout operation to the input tensor. + * \param data The input data to the operator. + * \param rate The probability for an element to be reset to 0. + * \return A Tuple of two tensors. + * The first one is the original tensor and the second one is a + * mask tensor (1.0 where element not dropped, 0.0 where dropped) + */ +Expr dropout(Expr data, double rate); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_NN_NN_H_ diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc new file mode 100644 index 000000000000..a4c1e6b17dd3 --- /dev/null +++ b/src/relax/op/nn/pooling.cc @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "pooling.h" + +#include +#include + +namespace tvm { +namespace relax { + +/* relax.nn.max_pool2d */ +TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs); + +Expr max_pool2d(Expr data, Array pool_size, Array strides, Array padding, + Array dilation, bool ceil_mode, String layout, + Optional out_layout) { + padding = GetCompletePadding2D(std::move(padding)); + if (pool_size.size() == 1) { + pool_size.push_back(pool_size[0]); + } + if (strides.size() == 1) { + strides.push_back(strides[0]); + } + if (dilation.size() == 1) { + dilation.push_back(dilation[0]); + } + + CHECK_EQ(pool_size.size(), 2) + << "The input pool_size length is expected to be 2. However, the given pool_size is " + << pool_size; + CHECK_EQ(strides.size(), 2) + << "The input strides length is expected to be 2. However, the given strides is " << strides; + CHECK_EQ(dilation.size(), 2) + << "The input dilation length is expected to be 2. However, the given dilation is " + << dilation; + + auto attrs = make_object(); + attrs->pool_size = std::move(pool_size); + attrs->strides = ConvertIntImmToInt64(strides); + attrs->padding = ConvertIntImmToInt64(padding); + attrs->dilation = ConvertIntImmToInt64(dilation); + attrs->ceil_mode = ceil_mode; + attrs->layout = layout; + attrs->out_layout = out_layout.value_or(layout); + static const Op& op = Op::Get("relax.nn.max_pool2d"); + return Call(op, {std::move(data)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.max_pool2d").set_body_typed(max_pool2d); + +StructInfo InferStructInfoMaxPool2D(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + + const auto* attrs = call->attrs.as(); + auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"data"); + auto [out_layout, out2NCHW] = CheckTensorLayout(call, ctx, attrs->out_layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"output"); + + Optional data_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + if (!data_shape.defined()) { + return TensorStructInfo(data_sinfo->dtype, out_layout.ndim()); + } + + Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + + PrimExpr input_h = data_NCHW_shape[2]; + PrimExpr input_w = data_NCHW_shape[3]; + PrimExpr kernel_h = attrs->pool_size[0]; + PrimExpr kernel_w = attrs->pool_size[1]; + PrimExpr padding_h = attrs->padding[0] + attrs->padding[2]; + PrimExpr padding_w = attrs->padding[1] + attrs->padding[3]; + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + std::vector out_NCHW_shape; + out_NCHW_shape.resize(4); + out_NCHW_shape[0] = data_NCHW_shape[0]; + out_NCHW_shape[1] = data_NCHW_shape[1]; + + PrimExpr numerator_h = input_h + padding_h - attrs->dilation[0] * (kernel_h - 1) - 1; + PrimExpr numerator_w = input_w + padding_w - attrs->dilation[1] * (kernel_w - 1) - 1; + if (attrs->ceil_mode) { + numerator_h += attrs->strides[0] - 1; + numerator_w += attrs->strides[1] - 1; + } + out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[0]) + 1); + out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1); + + Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); + return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.nn.max_pool2d") + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") + .set_attrs_type() + .set_attr("FInferStructInfo", InferStructInfoMaxPool2D); + +/* relax.nn.adaptive_avg_pool2d */ +TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs); + +Expr adaptive_avg_pool2d(Expr data, Optional> output_size, String layout, + Optional out_layout) { + ObjectPtr attrs = make_object(); + attrs->layout = layout; + attrs->out_layout = out_layout.value_or(layout); + if (output_size.defined()) { + Array _output_size = output_size.value(); + if (_output_size.size() == 1) { + _output_size.push_back(_output_size[0]); + } + CHECK_EQ(_output_size.size(), 2) + << "The output_size length is expected to be 2. However, the given output_size is " + << _output_size; + attrs->output_size = std::move(_output_size); + } + + static const Op& op = Op::Get("relax.nn.adaptive_avg_pool2d"); + return Call(op, {std::move(data)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool2d").set_body_typed(adaptive_avg_pool2d); + +StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + + const auto* attrs = call->attrs.as(); + auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"data"); + auto [out_layout, out2NCHW] = CheckTensorLayout(call, ctx, attrs->out_layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"output"); + + Optional data_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + if (!data_shape.defined()) { + if (data_sinfo->shape.defined() && attrs->out_layout == attrs->layout && + !attrs->output_size.defined()) { + return data_sinfo; + } else { + return TensorStructInfo(data_sinfo->dtype, out_layout.ndim()); + } + } + + Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + Array out_NCHW_shape(data_NCHW_shape); + if (attrs->output_size.defined()) { + out_NCHW_shape.Set(2, attrs->output_size.value()[0]); + out_NCHW_shape.Set(3, attrs->output_size.value()[1]); + } + + Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); + return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") + .set_attr("FInferStructInfo", InferStructInfoAdaptiveAvgPool2D); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/nn/pooling.h b/src/relax/op/nn/pooling.h new file mode 100644 index 000000000000..3c1792d21f6b --- /dev/null +++ b/src/relax/op/nn/pooling.h @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file pooling.h + * \brief The functions to make Relax neural network pooling operator calls. + */ + +#ifndef TVM_RELAX_OP_NN_POOLING_H_ +#define TVM_RELAX_OP_NN_POOLING_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! \brief 2D maximum pooling operator. */ +Expr max_pool2d(Expr data, Array pool_size, Array strides, Array padding, + Array dilation, bool ceil_mode, String layout, Optional out_layout); + +/*! \brief 2D adaptive average pooling operator. */ +Expr adaptive_avg_pool2d(Expr data, Optional> output_size, String layout, + Optional out_layout); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_NN_POOLING_H_ diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py new file mode 100644 index 000000000000..d047448309ab --- /dev/null +++ b/tests/python/relax/test_op_nn.py @@ -0,0 +1,929 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + assert relax.op.nn.relu(x).op == Op.get("relax.nn.relu") + assert relax.op.nn.gelu(x).op == Op.get("relax.nn.gelu") + assert relax.op.nn.silu(x).op == Op.get("relax.nn.silu") + assert relax.op.nn.softmax(x).op == Op.get("relax.nn.softmax") + assert relax.op.nn.dropout(x).op == Op.get("relax.nn.dropout") + + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + gamma = relax.Var("gamma", R.Tensor((3,), "float32")) + beta = relax.Var("beta", R.Tensor((3,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_var = relax.Var("moving_var", R.Tensor((3,), "float32")) + assert relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1).op == Op.get( + "relax.nn.batch_norm" + ) + assert relax.op.nn.layer_norm(x, gamma, beta, axes=1).op == Op.get("relax.nn.layer_norm") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_linear_unit_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32", ndim=-1)) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.nn.relu(x0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.nn.silu(x1), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.nn.gelu(x2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.nn.relu(x3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.nn.gelu(x4), relax.TensorStructInfo(dtype="")) + + +def test_linear_unit_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((4, n), "float32")) + + _check_inference(bb, relax.op.nn.silu(x0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo((4, n), "float32")) + + +def test_linear_unit_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference(bb, relax.op.nn.gelu(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo(s1, "float32")) + + +def test_linear_unit_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float64")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int64")) + + _check_inference(bb, relax.op.nn.relu(x0), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, relax.op.nn.relu(x2), relax.TensorStructInfo((2, 3), "int64")) + + +def test_linear_unit_infer_struct_info_invalid_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3), "int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.gelu(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.silu(x1)) + + +def test_linear_unit_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.gelu(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.silu(x1)) + + +def test_softmax_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32", ndim=-1)) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.nn.softmax(x1, axis=0), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference(bb, relax.op.nn.softmax(x2, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.nn.softmax(x3, axis=-1), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.nn.softmax(x4, axis=-2), relax.TensorStructInfo(dtype="")) + + +def test_softmax_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((4, n), "float32")) + + _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.nn.softmax(x1, axis=0), relax.TensorStructInfo((4, n), "float32")) + + +def test_softmax_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.nn.softmax(x1), relax.TensorStructInfo(s1, "float32")) + + +def test_softmax_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3), "float64")) + + _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((2, 3), "float16")) + _check_inference(bb, relax.op.nn.softmax(x1), relax.TensorStructInfo((2, 3), "float64")) + + +def test_softmax_infer_struct_info_invalid_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3), "int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.softmax(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.softmax(x1)) + + +def test_softmax_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.softmax(x, axis=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.softmax(x, axis=-4)) + + +def test_softmax_wrong_with_multiple_axes(): + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + with pytest.raises(TVMError): + relax.op.nn.softmax(x, axis=[1, 2]) + with pytest.raises(TVMError): + relax.op.nn.softmax(x, axis=[-1, -2, -3]) + + +def test_softmax_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.softmax(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.softmax(x1)) + + +def test_batch_norm_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor(ndim=4)) + x4 = relax.Var("x", R.Tensor()) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor("float32", ndim=1)) + gamma2 = relax.Var("gamma", R.Tensor(ndim=1)) + beta0 = relax.Var("beta", R.Tensor((3,), "float32")) + beta1 = relax.Var("beta", R.Tensor((3,))) + moving_mean0 = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_mean1 = relax.Var("moving_mean", R.Tensor((3,))) + moving_var0 = relax.Var("moving_var", R.Tensor((3,), "float32")) + moving_var1 = relax.Var("moving_var", R.Tensor("float32", ndim=1)) + moving_var2 = relax.Var("moving_var", R.Tensor(ndim=1)) + + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 28, 28), "float32"), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo((3,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var0, axis=-3), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 28, 28), "float32"), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo((3,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x1, gamma0, beta0, moving_mean0, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo((3,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma1, beta0, moving_mean0, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 28, 28), "float32"), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo((3,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var1, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 28, 28), "float32"), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x1, gamma1, beta0, moving_mean0, moving_var1, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x2, gamma1, beta0, moving_mean0, moving_var1, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo((3,), "float32"), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x3, gamma2, beta1, moving_mean1, moving_var2, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(ndim=4, dtype=""), + relax.TensorStructInfo((3,), dtype=""), + relax.TensorStructInfo(dtype="", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x4, gamma2, beta1, moving_mean1, moving_var2, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype=""), + relax.TensorStructInfo((3,), dtype=""), + relax.TensorStructInfo(dtype="", ndim=1), + ] + ), + ) + + +def test_batch_norm_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c0 = tir.Var("c", "int64") + c1 = tir.Var("c", "int64") + h = tir.Var("h", "int64") + w = tir.Var("w", "int64") + x0 = relax.Var("x", R.Tensor((n, c0, h, w), "float32")) + x1 = relax.Var("x", R.Tensor((n, c1, h, w), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + gamma0 = relax.Var("gamma", R.Tensor((c0,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((c1,), "float32")) + gamma2 = relax.Var("gamma", R.Tensor("float32", ndim=1)) + beta = relax.Var("beta", R.Tensor((c0,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((c0,), "float32")) + moving_var0 = relax.Var("moving_var", R.Tensor((c0,), "float32")) + moving_var1 = relax.Var("moving_var", R.Tensor((c1,), "float32")) + moving_var2 = relax.Var("moving_var", R.Tensor("float32", ndim=1)) + + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((n, c0, h, w), "float32"), + relax.TensorStructInfo((c0,), "float32"), + relax.TensorStructInfo((c0,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x1, gamma0, beta, moving_mean, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x2, gamma0, beta, moving_mean, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorStructInfo((c0,), "float32"), + relax.TensorStructInfo((c0,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma1, beta, moving_mean, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var1, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=4), + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma2, beta, moving_mean, moving_var0, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((n, c0, h, w), "float32"), + relax.TensorStructInfo((c0,), "float32"), + relax.TensorStructInfo((c0,), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var2, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((n, c0, h, w), "float32"), + relax.TensorStructInfo((c0,), "float32"), + relax.TensorStructInfo(dtype="float32", ndim=1), + ] + ), + ) + + +def test_batch_norm_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s1", relax.ShapeStructInfo()) + s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=1)) + s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=1)) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + gamma = relax.Var("gamma", relax.TensorStructInfo(s2, "float32")) + beta = relax.Var("beta", relax.TensorStructInfo(s3, "float32")) + moving_mean = relax.Var("moving_mean", relax.TensorStructInfo(s2, "float32")) + moving_var = relax.Var("moving_var", relax.TensorStructInfo(s3, "float32")) + + _check_inference( + bb, + relax.op.nn.batch_norm(x0, gamma, beta, moving_mean, moving_var, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(s0, "float32"), + relax.TensorStructInfo(s2, "float32"), + relax.TensorStructInfo(s3, "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.batch_norm(x1, gamma, beta, moving_mean, moving_var, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(s1, "float32"), + relax.TensorStructInfo(s2, "float32"), + relax.TensorStructInfo(s3, "float32"), + ] + ), + ) + + +def test_batch_norm_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16")) + gamma = relax.Var("gamma", R.Tensor((3,), "float16")) + beta = relax.Var("beta", R.Tensor((3,), "float16")) + moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float16")) + moving_var = relax.Var("moving_var", R.Tensor((3,), "float16")) + + _check_inference( + bb, + relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 28, 28), "float16"), + relax.TensorStructInfo((3,), "float16"), + relax.TensorStructInfo((3,), "float16"), + ] + ), + ) + + +def test_batch_norm_infer_struct_info_invalid_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8")) + gamma0 = relax.Var("gamma", R.Tensor((3,), "int8")) + beta0 = relax.Var("beta", R.Tensor((3,), "int8")) + moving_mean0 = relax.Var("moving_mean", R.Tensor((3,), "int8")) + moving_var0 = relax.Var("moving_var", R.Tensor((3,), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int32")) + gamma1 = relax.Var("gamma", R.Tensor((3,), "int32")) + beta1 = relax.Var("beta", R.Tensor((3,), "int32")) + moving_mean1 = relax.Var("moving_mean", R.Tensor((3,), "int32")) + moving_var1 = relax.Var("moving_var", R.Tensor((3,), "int32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x1, gamma1, beta1, moving_mean1, moving_var1, axis=1)) + + +def test_batch_norm_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + gamma = relax.Var("gamma", R.Tensor((3,), "float32")) + beta = relax.Var("beta", R.Tensor((3,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_var = relax.Var("moving_var", R.Tensor((3,), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=-5)) + + +def test_batch_norm_infer_struct_info_dtype_mismatch(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8")) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((3,))) + beta = relax.Var("beta", R.Tensor((3,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_var0 = relax.Var("moving_var", R.Tensor((3,), "float32")) + moving_var1 = relax.Var("moving_var", R.Tensor((3,), "float16")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x1, gamma0, beta, moving_mean, moving_var0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x0, gamma1, beta, moving_mean, moving_var0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var1, axis=1)) + + +def test_batch_norm_infer_struct_info_ndim_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((3, 1), "float32")) + beta = relax.Var("beta", R.Tensor((3,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_var0 = relax.Var("moving_var", R.Tensor((3,), "float32")) + moving_var1 = relax.Var("moving_var", R.Tensor((1, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x, gamma1, beta, moving_mean, moving_var0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x, gamma0, beta, moving_mean, moving_var1, axis=1)) + + +def test_batch_norm_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + c = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, c, 28, 28), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4,), "float32")) + gamma2 = relax.Var("gamma", R.Tensor((c + 2,), "float32")) + beta0 = relax.Var("beta", R.Tensor((3,), "float32")) + beta1 = relax.Var("beta", R.Tensor((c,), "float32")) + moving_mean0 = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_mean1 = relax.Var("moving_mean", R.Tensor((c,), "float32")) + moving_var0 = relax.Var("moving_var", R.Tensor((3,), "float32")) + moving_var1 = relax.Var("moving_var", R.Tensor((4,), "float32")) + moving_var2 = relax.Var("moving_var", R.Tensor((c,), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x0, gamma1, beta0, moving_mean0, moving_var0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var1, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x1, gamma2, beta1, moving_mean1, moving_var2, axis=1)) + + +def test_batch_norm_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float32")) + gamma1 = relax.Var("gamma", relax.FuncStructInfo([], R.Tensor((3,), "float32"))) + beta = relax.Var("beta", R.Tensor((3,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32")) + moving_var = relax.Var("moving_var", R.Tensor((3,), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x1, gamma0, beta, moving_mean, moving_var, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_norm(x0, gamma1, beta, moving_mean, moving_var, axis=1)) + + +def test_layer_norm_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4, 5))) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", R.Tensor("float32", ndim=2)) + gamma2 = relax.Var("gamma", R.Tensor((4, 5))) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((4, 5))) + + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), "float32"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, 3]), + relax.TensorStructInfo((2, 3, 4, 5), "float32"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x1, gamma0, beta0, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x2, gamma0, beta0, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma1, beta0, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), dtype="float32"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x3, gamma2, beta1, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), dtype=""), + ) + + +def test_layer_norm_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c0 = tir.Var("c", "int64") + c1 = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((n, a, b, c0), "float32")) + x1 = relax.Var("x", R.Tensor((n, a, b, c1), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + gamma0 = relax.Var("gamma", R.Tensor((b, c0), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((b, c1), "float32")) + beta = relax.Var("beta", R.Tensor((b, c0), "float32")) + + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma0, beta, axes=[-2, -1]), + relax.TensorStructInfo((n, a, b, c0), "float32"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x1, gamma0, beta, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma1, beta, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x2, gamma0, beta, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x2, gamma1, beta, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + + +def test_layer_norm_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s1", relax.ShapeStructInfo()) + s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=2)) + s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=2)) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + gamma = relax.Var("gamma", relax.TensorStructInfo(s2, "float32")) + beta = relax.Var("beta", relax.TensorStructInfo(s3, "float32")) + + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma, beta, axes=[2, 3]), + relax.TensorStructInfo(s0, "float32"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x1, gamma, beta, axes=[2, 3]), + relax.TensorStructInfo(s1, "float32"), + ) + + +def test_layer_norm_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float16")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float64")) + gamma1 = relax.Var("gamma", R.Tensor((4, 5), "float64")) + beta1 = relax.Var("beta", R.Tensor((4, 5), "float64")) + + _check_inference( + bb, + relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), "float16"), + ) + _check_inference( + bb, + relax.op.nn.layer_norm(x1, gamma1, beta1, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), "float64"), + ) + + +def test_layer_norm_infer_struct_info_invalid_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "int8")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int32")) + gamma1 = relax.Var("gamma", R.Tensor((4, 5), "int32")) + beta1 = relax.Var("beta", R.Tensor((4, 5), "int32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, -1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x1, gamma1, beta1, axes=[-2, -1])) + + +def test_layer_norm_infer_struct_info_axis_out_of_range_and_repetitive(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma = relax.Var("gamma", R.Tensor((4, 5), "float32")) + beta = relax.Var("beta", R.Tensor((4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x, gamma, beta, axes=[3, 4])) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x, gamma, beta, axes=[3, -1])) + + +def test_layer_norm_infer_struct_info_dtype_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4, 5), "int8")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((4, 5))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x, gamma1, beta0, axes=[-2, -1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x, gamma0, beta1, axes=[-2, -1])) + + +def test_layer_norm_infer_struct_info_ndim_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4,), "float32")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((3, 4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x, gamma1, beta0, axes=[-2, -1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x, gamma0, beta1, axes=[-2, -1])) + + +def test_layer_norm_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + c0 = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, c0), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4, 6), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4, c0), "float32")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((4, c0 - 2), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, -1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x1, gamma1, beta1, axes=[-2, -1])) + + +def test_layer_norm_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", relax.FuncStructInfo([], R.Tensor((4, 5), "float32"))) + beta = relax.Var("beta", R.Tensor((4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x1, gamma0, beta, axes=[-2, -1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.layer_norm(x0, gamma1, beta, axes=[-2, -1])) + + +def test_dropout_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32", ndim=-1)) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, + relax.op.nn.dropout(x0), + relax.TupleStructInfo( + [relax.TensorStructInfo((2, 3), "float32"), relax.TensorStructInfo((2, 3), "float32")] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x2), + relax.TupleStructInfo( + [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="float32")] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x3), + relax.TupleStructInfo( + [relax.TensorStructInfo((2, 3), dtype=""), relax.TensorStructInfo((2, 3), dtype="")] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x4), + relax.TupleStructInfo([relax.TensorStructInfo(dtype=""), relax.TensorStructInfo(dtype="")]), + ) + + +def test_dropout_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = relax.Var("x", R.Tensor((m, n), "float32")) + + _check_inference( + bb, + relax.op.nn.dropout(x), + relax.TupleStructInfo( + [relax.TensorStructInfo((m, n), "float32"), relax.TensorStructInfo((m, n), "float32")] + ), + ) + + +def test_dropout_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference( + bb, + relax.op.nn.dropout(x0), + relax.TupleStructInfo( + [relax.TensorStructInfo(s0, "float32"), relax.TensorStructInfo(s0, "float32")] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x1), + relax.TupleStructInfo( + [relax.TensorStructInfo(s1, "float32"), relax.TensorStructInfo(s1, "float32")] + ), + ) + + +def test_dropout_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float64")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int64")) + + _check_inference( + bb, + relax.op.nn.dropout(x0), + relax.TupleStructInfo( + [relax.TensorStructInfo((2, 3), "float64"), relax.TensorStructInfo((2, 3), "float64")] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x1), + relax.TupleStructInfo( + [relax.TensorStructInfo((2, 3), "int8"), relax.TensorStructInfo((2, 3), "int8")] + ), + ) + _check_inference( + bb, + relax.op.nn.dropout(x2), + relax.TupleStructInfo( + [relax.TensorStructInfo((2, 3), "int64"), relax.TensorStructInfo((2, 3), "int64")] + ), + ) + + +def test_dropout_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.dropout(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.dropout(x1)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_nn_convolution.py b/tests/python/relax/test_op_nn_convolution.py new file mode 100644 index 000000000000..6533d434206b --- /dev/null +++ b/tests/python/relax/test_op_nn_convolution.py @@ -0,0 +1,429 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + assert relax.op.nn.conv2d(x, w).op == Op.get("relax.nn.conv2d") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_conv2d_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 28, 28, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor()) + x5 = relax.Var("x", R.Tensor((2, 4, 28, 28, 16), "float32")) + w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32")) + w2 = relax.Var("w", R.Tensor("float32", ndim=4)) + w3 = relax.Var("w", R.Tensor("float32")) + w4 = relax.Var("w", R.Tensor((48, 4, 3, 3, 16), "float32")) + + _check_inference( + bb, relax.op.nn.conv2d(x0, w0), relax.TensorStructInfo((2, 4, 26, 26), "float32") + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, out_dtype="float16"), + relax.TensorStructInfo((2, 4, 26, 26), "float16"), + ) + _check_inference( + bb, relax.op.nn.conv2d(x0, w0, padding=1), relax.TensorStructInfo((2, 4, 28, 28), "float32") + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, padding=[1, 2]), + relax.TensorStructInfo((2, 4, 28, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, padding=[1, 2, 3, 4]), + relax.TensorStructInfo((2, 4, 30, 32), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, strides=2), + relax.TensorStructInfo((2, 4, 13, 13), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, strides=(2, 3)), + relax.TensorStructInfo((2, 4, 13, 9), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, dilation=2), + relax.TensorStructInfo((2, 4, 24, 24), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, dilation=(2, 1)), + relax.TensorStructInfo((2, 4, 24, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x1, w0, data_layout="NHWC"), + relax.TensorStructInfo((2, 26, 26, 4), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, out_layout="NHWC"), + relax.TensorStructInfo((2, 26, 26, 4), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w1, kernel_layout="IOHW"), + relax.TensorStructInfo((2, 4, 26, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d( + x5, w4, data_layout="NCHW16c", kernel_layout="OIHW16i", out_layout="NHWC16c" + ), + relax.TensorStructInfo((2, 26, 26, 3, 16), "float32"), + ) + _check_inference( + bb, relax.op.nn.conv2d(x2, w0), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.conv2d(x3, w0), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.conv2d(x0, w2), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.conv2d(x0, w3), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference(bb, relax.op.nn.conv2d(x4, w0), relax.TensorStructInfo(dtype="", ndim=4)) + + +def test_conv2d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + ki = tir.Var("ki", "int64") + ko = tir.Var("ko", "int64") + kh = tir.Var("kh", "int64") + kw = tir.Var("kw", "int64") + x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) + w0 = relax.Var("w", R.Tensor((ko, ki, kh, kw), "float32")) + w1 = relax.Var("w", R.Tensor((ko, c, kh, kw), "float32")) + w2 = relax.Var("w", R.Tensor((ko, c, kh, kw, c16), "float32")) + + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0), + relax.TensorStructInfo((n, ko, ih + 1 - kh, iw + 1 - kw), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w1), + relax.TensorStructInfo((n, ko, ih + 1 - kh, iw + 1 - kw), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d( + x1, w2, data_layout="NCHW16c", kernel_layout="OIHW16i", out_layout="NCHW" + ), + relax.TensorStructInfo((n, ko, ih + 1 - kh, iw + 1 - kw), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, strides=(2, 2), padding=(1, 1), dilation=(2, 2)), + relax.TensorStructInfo( + (n, ko, tvm.tir.floordiv(ih + 3, 2) + 1 - kh, tvm.tir.floordiv(iw + 3, 2) + 1 - kw), + "float32", + ), + ) + + +def test_conv2d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s3 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + w = relax.Var("w", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.nn.conv2d(x0, w), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference( + bb, + relax.op.nn.conv2d(x1, w, data_layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w, out_layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x2, w), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + + +def test_conv2d_infer_struct_info_groups(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 8, 28, 28, 16), "float32")) + w0 = relax.Var("w", R.Tensor((48, 16, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((48, 2, 3, 3, 8), "float32")) + + _check_inference( + bb, relax.op.nn.conv2d(x0, w0, groups=8), relax.TensorStructInfo((2, 48, 26, 26), "float32") + ) + _check_inference( + bb, + relax.op.nn.conv2d(x0, w1, kernel_layout="OIHW8i", groups=8), + relax.TensorStructInfo((2, 48, 26, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x1, w0, data_layout="NCHW16c", groups=8), + relax.TensorStructInfo((2, 3, 26, 26, 16), "float32"), + ) + + +def test_conv2d_infer_struct_info_symbolic_groups(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + ic = tir.Var("c", "int64") + oc = tir.Var("oc", "int64") + x = relax.Var("x", R.Tensor((n, ic * 4, 28, 28), "float32")) + w0 = relax.Var("w", R.Tensor((oc * 4, ic, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((oc, ic, 3, 3), "float32")) + + _check_inference( + bb, + relax.op.nn.conv2d(x, w0, groups=4), + relax.TensorStructInfo((n, oc * 4, 26, 26), "float32"), + ) + _check_inference( + bb, relax.op.nn.conv2d(x, w1, groups=4), relax.TensorStructInfo((n, oc, 26, 26), "float32") + ) + + +def test_conv2d_infer_struct_info_input_channel_group_incompatible(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + ic = tir.Var("c", "int64") + oc = tir.Var("oc", "int64") + x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32")) + w0 = relax.Var("w", R.Tensor((48, 20, 3, 3), "float32")) + x1 = relax.Var("x", R.Tensor((n, ic * 6, 28, 28), "float32")) + w1 = relax.Var("w", R.Tensor((oc, ic - 1, 3, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w0, groups=6)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x1, w1, groups=6)) + + +def test_conv2d_infer_struct_info_output_channel_group_incompatible(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + ic = tir.Var("c", "int64") + oc = tir.Var("oc", "int64") + x0 = relax.Var("x", R.Tensor((2, 120, 28, 28), "float32")) + w0 = relax.Var("w", R.Tensor((128, 20, 3, 3), "float32")) + x1 = relax.Var("x", R.Tensor((n, ic * 6, 28, 28), "float32")) + w1 = relax.Var("w", R.Tensor((oc * 6 + 4, ic * 6, 3, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w0, groups=6)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x1, w1, groups=6)) + + +def test_conv2d_non_positive_group(): + x = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((48, 16, 3, 3), "float32")) + + with pytest.raises(TVMError): + relax.op.nn.conv2d(x, w, groups=0) + with pytest.raises(TVMError): + relax.op.nn.conv2d(x, w, groups=-2) + + +def test_conv2d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16")) + w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float64")) + w1 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float64")) + x2 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8")) + w2 = relax.Var("w", R.Tensor((4, 3, 3, 3), "int8")) + x3 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int32")) + w3 = relax.Var("w", R.Tensor((4, 3, 3, 3), "int32")) + + _check_inference( + bb, relax.op.nn.conv2d(x0, w0), relax.TensorStructInfo((2, 4, 26, 26), "float16") + ) + _check_inference( + bb, relax.op.nn.conv2d(x1, w1), relax.TensorStructInfo((2, 4, 26, 26), "float64") + ) + _check_inference(bb, relax.op.nn.conv2d(x2, w2), relax.TensorStructInfo((2, 4, 26, 26), "int8")) + _check_inference( + bb, relax.op.nn.conv2d(x3, w3), relax.TensorStructInfo((2, 4, 26, 26), "int32") + ) + + +def test_conv2d_infer_struct_info_mixed_precision(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16")) + w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8")) + w1 = relax.Var("w", R.Tensor((4, 3, 3, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 28, 28))) + w2 = relax.Var("w", R.Tensor((4, 3, 3, 3))) + + _check_inference( + bb, + relax.op.nn.conv2d(x0, w0, out_dtype="float32"), + relax.TensorStructInfo((2, 4, 26, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x1, w1, out_dtype="int32"), + relax.TensorStructInfo((2, 4, 26, 26), "int32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d(x2, w2, out_dtype="float32"), + relax.TensorStructInfo((2, 4, 26, 26), "float32"), + ) + + +def test_conv2d_unequal_input_channel(): + bb = relax.BlockBuilder() + ic = tir.Var("ic", "int64") + x0 = relax.Var("x", R.Tensor([2, 3, 28, 28], "float32")) + w0 = relax.Var("w", R.Tensor([3, 4, 3, 3], "float32")) + x1 = relax.Var("x", R.Tensor([2, ic, 28, 28], "float32")) + w1 = relax.Var("w", R.Tensor([4, ic + 2, 3, 3], "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x1, w1)) + + +def test_conv2d_stride_padding_dilation_int64(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + conv2d = relax.op.nn.conv2d(x, w, strides=(1, 1), padding=(1, 1), dilation=(1, 1)) + + assert conv2d.attrs.strides[0].dtype == "int64" + assert conv2d.attrs.strides[1].dtype == "int64" + assert conv2d.attrs.padding[0].dtype == "int64" + assert conv2d.attrs.padding[1].dtype == "int64" + assert conv2d.attrs.padding[2].dtype == "int64" + assert conv2d.attrs.padding[3].dtype == "int64" + assert conv2d.attrs.dilation[0].dtype == "int64" + assert conv2d.attrs.dilation[1].dtype == "int64" + + +def test_conv2d_wrong_strides_padding_dilation_length(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + with pytest.raises(TVMError): + relax.op.nn.conv2d(x, w, strides=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.conv2d(x, w, padding=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.conv2d(x, w, dilation=(1, 2, 3)) + + +def test_conv2d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x, w, data_layout="OIHW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x, w, kernel_layout="NHWC")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x, w, out_layout="OHWI")) + + +def test_conv2d_dtype_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3, 3), "int8")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x, w)) + + +def test_conv2d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=3)) + w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((4, 3, 6, 3, 3), "float32")) + w2 = relax.Var("w", R.Tensor("float32", ndim=6)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w1, data_layout="NCHW16c")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x1, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x2, w0)) + + +def test_conv2d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) + w1 = relax.Var("w", relax.FuncStructInfo([], R.Tensor((4, 3, 3, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x0, w1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv2d(x1, w0)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py new file mode 100644 index 000000000000..0eec5de21c98 --- /dev/null +++ b/tests/python/relax/test_op_nn_pooling.py @@ -0,0 +1,429 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + assert relax.op.nn.max_pool2d(x).op == Op.get("relax.nn.max_pool2d") + assert relax.op.nn.adaptive_avg_pool2d(x).op == Op.get("relax.nn.adaptive_avg_pool2d") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_max_pool2d_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32")) + + _check_inference( + bb, relax.op.nn.max_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32") + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x0, pool_size=3), + relax.TensorStructInfo((2, 3, 30, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x0, pool_size=(5, 3)), + relax.TensorStructInfo((2, 3, 28, 30), "float32"), + ) + _check_inference( + bb, relax.op.nn.max_pool2d(x0, padding=1), relax.TensorStructInfo((2, 3, 34, 34), "float32") + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x0, padding=[1, 2]), + relax.TensorStructInfo((2, 3, 34, 36), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x0, strides=2), + relax.TensorStructInfo((2, 3, 16, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x0, dilation=2), + relax.TensorStructInfo((2, 3, 32, 32), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x1, layout="NHWC"), + relax.TensorStructInfo((2, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x0, out_layout="NHWC"), + relax.TensorStructInfo((2, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x6, layout="NCHW16c", out_layout="NHWC16c"), + relax.TensorStructInfo((2, 32, 32, 4, 16), "float32"), + ) + _check_inference( + bb, relax.op.nn.max_pool2d(x2), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.max_pool2d(x3), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference(bb, relax.op.nn.max_pool2d(x4), relax.TensorStructInfo(dtype="", ndim=4)) + _check_inference(bb, relax.op.nn.max_pool2d(x5), relax.TensorStructInfo(dtype="", ndim=4)) + + +def test_max_pool2d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) + + _check_inference( + bb, + relax.op.nn.max_pool2d( + x0, pool_size=(3, 3), strides=(3, 3), padding=(2, 2), dilation=(2, 2) + ), + relax.TensorStructInfo( + ( + n, + c, + tvm.tir.floordiv(ih - 1, 3) + 1, + tvm.tir.floordiv(iw - 1, 3) + 1, + ), + "float32", + ), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x1, layout="NCHW16c", out_layout="NHWC"), + relax.TensorStructInfo((n, ih, iw, c * 16), "float32"), + ) + + +def test_max_pool2d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.nn.max_pool2d(x0), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x1, layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x2), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + + +def test_max_pool2d_infer_struct_info_ceil_mode(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + + _check_inference( + bb, + relax.op.nn.max_pool2d(x, pool_size=3, strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 16, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool2d(x, pool_size=(5, 3), strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 15, 16), "float32"), + ) + + +def test_max_pool2d_infer_struct_info_ceil_mode_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + x = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + + _check_inference( + bb, + relax.op.nn.max_pool2d( + x, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), dilation=(2, 2), ceil_mode=True + ), + relax.TensorStructInfo((n, c, tvm.tir.floordiv(ih, 2), tvm.tir.floordiv(iw, 2)), "float32"), + ) + + +def test_max_pool2d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64")) + _check_inference( + bb, relax.op.nn.max_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float16") + ) + _check_inference(bb, relax.op.nn.max_pool2d(x1), relax.TensorStructInfo((2, 3, 32, 32), "int8")) + _check_inference( + bb, relax.op.nn.max_pool2d(x2), relax.TensorStructInfo((2, 3, 32, 32), "int64") + ) + + +def test_conv2d_stride_padding_dilation_int64(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + max_pool2d = relax.op.nn.max_pool2d(x, (3, 3), strides=(1, 1), padding=(1, 1), dilation=(1, 1)) + + assert max_pool2d.attrs.strides[0].dtype == "int64" + assert max_pool2d.attrs.strides[1].dtype == "int64" + assert max_pool2d.attrs.padding[0].dtype == "int64" + assert max_pool2d.attrs.padding[1].dtype == "int64" + assert max_pool2d.attrs.padding[2].dtype == "int64" + assert max_pool2d.attrs.padding[3].dtype == "int64" + assert max_pool2d.attrs.dilation[0].dtype == "int64" + assert max_pool2d.attrs.dilation[1].dtype == "int64" + + +def test_max_pool2d_wrong_pool_size_strides_padding_dilation_length(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + with pytest.raises(TVMError): + relax.op.nn.max_pool2d(x, pool_size=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.max_pool2d(x, strides=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.max_pool2d(x, padding=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.max_pool2d(x, dilation=(1, 2, 3)) + + +def test_max_pool2d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool2d(x, layout="OIHW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool2d(x, out_layout="OHWI")) + + +def test_max_pool2d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool2d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool2d(x1)) + + +def test_max_pool2d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool2d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool2d(x1)) + + +def test_adaptive_avg_pool2d_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32")) + + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32") + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, output_size=30), + relax.TensorStructInfo((2, 3, 30, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, output_size=(28, 30)), + relax.TensorStructInfo((2, 3, 28, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x1, layout="NHWC"), + relax.TensorStructInfo((2, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, out_layout="NHWC"), + relax.TensorStructInfo((2, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x6, layout="NCHW16c", out_layout="NHWC16c"), + relax.TensorStructInfo((2, 32, 32, 4, 16), "float32"), + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x2), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x3), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x4), relax.TensorStructInfo(dtype="", ndim=4) + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x5), relax.TensorStructInfo(dtype="", ndim=4) + ) + + +def test_adaptive_avg_pool2d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) + + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((n, c, ih, iw), "float32") + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, output_size=256), + relax.TensorStructInfo((n, c, 256, 256), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, output_size=(256, 128)), + relax.TensorStructInfo((n, c, 256, 128), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x1, layout="NCHW16c", out_layout="NHWC"), + relax.TensorStructInfo((n, ih, iw, c * 16), "float32"), + ) + + +def test_adaptive_avg_pool2d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, output_size=32), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x1, layout="NCHW16c"), + relax.TensorStructInfo(s1, "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, out_layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x2, out_layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + + +def test_adaptive_avg_pool2d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64")) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float16") + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x1), relax.TensorStructInfo((2, 3, 32, 32), "int8") + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x2), relax.TensorStructInfo((2, 3, 32, 32), "int64") + ) + + +def test_adaptive_avg_pool2d_wrong_output_size_ndim(): + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + with pytest.raises(TVMError): + relax.op.nn.adaptive_avg_pool2d(x, (32, 32, 32)) + + +def test_adaptive_avg_pool2d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool2d(x, layout="OIHW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool2d(x, out_layout="OHWI")) + + +def test_adaptive_avg_pool2d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool2d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool2d(x1)) + + +def test_adaptive_avg_pool2d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool2d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool2d(x1)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py b/tests/python/relax/test_tvmscript_parser_op_nn.py new file mode 100644 index 000000000000..4e52bccb8637 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_nn.py @@ -0,0 +1,193 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_conv2d(): + @R.function + def foo( + x: R.Tensor((2, 3, 228, 228), "float32"), w: R.Tensor((16, 3, 5, 5), "float32") + ) -> R.Tensor((2, 16, 224, 224), "float16"): + gv: R.Tensor((2, 16, 224, 224), "float16") = R.nn.conv2d(x, w, out_dtype="float16") + return gv + + x = relax.Var("x", R.Tensor([2, 3, 228, 228], "float32")) + w = relax.Var("w", R.Tensor([16, 3, 5, 5], "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, w]): + gv = bb.emit(relax.op.nn.conv2d(x, w, out_dtype="float16")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_max_pool2d(): + @R.function + def foo( + x: R.Tensor((1, 1, 32, 32), dtype="float32") + ) -> R.Tensor((1, 1, 30, 30), dtype="float32"): + gv: R.Tensor((1, 1, 30, 30), dtype="float32") = R.nn.max_pool2d(x, pool_size=(3,)) + return gv + + x = relax.Var("x", R.Tensor([1, 1, 32, 32], "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.max_pool2d(x, pool_size=(3,))) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_adaptive_avg_pool2d(): + @R.function + def foo(x: R.Tensor((2, 64, 8, 9), "float32")) -> R.Tensor((2, 64, 7, 7), "float32"): + gv: R.Tensor((2, 64, 7, 7), "float32") = R.nn.adaptive_avg_pool2d(x, output_size=(7, 7)) + return gv + + x = relax.Var("x", R.Tensor((2, 64, 8, 9), dtype="float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.adaptive_avg_pool2d(x, output_size=(7, 7))) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_gelu(): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.nn.gelu(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.gelu(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_softmax(): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.nn.softmax(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.softmax(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_batch_norm(): + @R.function + def foo( + x: R.Tensor((2, 4, 3, 3), dtype="float32"), + gamma: R.Tensor((4,), dtype="float32"), + beta: R.Tensor((4,), dtype="float32"), + moving_mean: R.Tensor((4,), dtype="float32"), + moving_var: R.Tensor((4,), dtype="float32"), + ) -> R.Tuple( + R.Tensor((2, 4, 3, 3), dtype="float32"), + R.Tensor((4,), dtype="float32"), + R.Tensor((4,), dtype="float32"), + ): + gv: R.Tuple( + R.Tensor((2, 4, 3, 3), dtype="float32"), + R.Tensor((4,), dtype="float32"), + R.Tensor((4,), dtype="float32"), + ) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1) + return gv + + x = relax.Var("x", R.Tensor((2, 4, 3, 3), "float32")) + gamma = relax.Var("gamma", R.Tensor((4,), "float32")) + beta = relax.Var("beta", R.Tensor((4,), "float32")) + moving_mean = relax.Var("moving_mean", R.Tensor((4,), "float32")) + moving_var = relax.Var("moving_var", R.Tensor((4,), "float32")) + + bb = relax.BlockBuilder() + with bb.function("foo", [x, gamma, beta, moving_mean, moving_var]): + gv = bb.emit(relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_layer_norm(): + @R.function + def foo( + x: R.Tensor((2, 3, 4, 5), "float32"), + gamma: R.Tensor((4, 5), "float32"), + beta: R.Tensor((4, 5), "float32"), + ) -> R.Tensor((2, 3, 4, 5), "float32"): + gv: R.Tensor((2, 3, 4, 5), "float32") = R.nn.layer_norm(x, gamma, beta, axes=[-2, -1]) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma = relax.Var("gamma", R.Tensor((4, 5), "float32")) + beta = relax.Var("beta", R.Tensor((4, 5), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, gamma, beta]): + gv = bb.emit(relax.op.nn.layer_norm(x, gamma, beta, axes=[-2, -1])) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_dropout(): + @R.function + def foo( + x: R.Tensor((2, 3), "float32") + ) -> R.Tuple(R.Tensor((2, 3), "float32"), R.Tensor((2, 3), "float32")): + gv: R.Tuple(R.Tensor((2, 3), "float32"), R.Tensor((2, 3), "float32")) = R.nn.dropout( + x, rate=0.5 + ) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.dropout(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() From 35f17cfab93290b84def84ee7e5c3c00647cbe34 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 14 Feb 2023 15:05:09 -0500 Subject: [PATCH 20/81] [Unity] Relax op: creation (#13984) This PR is about the high-level tensor computation operators in Relax. This PR includes the tensor creation operators. --- include/tvm/relax/attrs/create.h | 54 ++ python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/create.py | 209 ++++++ python/tvm/relax/op/op_attrs.py | 10 + python/tvm/script/ir_builder/relax/ir.py | 16 + src/relax/op/tensor/create.cc | 264 ++++++++ src/relax/op/tensor/create.h | 90 +++ tests/python/relax/test_op_create.py | 638 ++++++++++++++++++ .../relax/test_tvmscript_parser_op_create.py | 162 +++++ 9 files changed, 1444 insertions(+) create mode 100644 include/tvm/relax/attrs/create.h create mode 100644 python/tvm/relax/op/create.py create mode 100644 src/relax/op/tensor/create.cc create mode 100644 src/relax/op/tensor/create.h create mode 100644 tests/python/relax/test_op_create.py create mode 100644 tests/python/relax/test_tvmscript_parser_op_create.py diff --git a/include/tvm/relax/attrs/create.h b/include/tvm/relax/attrs/create.h new file mode 100644 index 000000000000..6af176a42c9d --- /dev/null +++ b/include/tvm/relax/attrs/create.h @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/create.h + * \brief Attributes for tensor creation operators. + */ +#ifndef TVM_RELAX_ATTRS_CREATE_H_ +#define TVM_RELAX_ATTRS_CREATE_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operators */ +struct InitAttrs : public tvm::AttrsNode { + DataType dtype; + + TVM_DECLARE_ATTRS(InitAttrs, "relax.attrs.InitAttrs") { + TVM_ATTR_FIELD(dtype).describe("The data type of the created tensor."); + } +}; // struct InitAttrs + +/*! \brief Attributes used in tril and triu operator */ +struct TriluAttrs : public tvm::AttrsNode { + int k; + + TVM_DECLARE_ATTRS(TriluAttrs, "relax.attrs.TriluAttrs") { + TVM_ATTR_FIELD(k).describe( + "The number of diagonals above or below the main diagonal to exclude or include."); + } +}; // struct TriluAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_CREATE_H_ diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 6c6fffc7c65e..97d08c0946a0 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -20,6 +20,7 @@ # Operators from .base import * from .binary import * +from .create import * from .datatype import * from .index import * from .manipulate import * diff --git a/python/tvm/relax/op/create.py b/python/tvm/relax/op/create.py new file mode 100644 index 000000000000..a6643a8633e4 --- /dev/null +++ b/python/tvm/relax/op/create.py @@ -0,0 +1,209 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Creation operators.""" +from typing import Optional, Tuple, Union + +from tvm import DataType +from tvm.ir.expr import PrimExpr + +from . import _ffi_api +from ..expr import Expr, ShapeExpr + +PrimExprLike = Union[int, PrimExpr] + + +def full( + shape: Union[Tuple[PrimExprLike], Expr], + fill_value: Expr, + dtype: Optional[Union[str, DataType]] = None, +) -> Expr: + """Fill array with scalar value. + + Parameters + ---------- + shape : Union[Tuple[PrimExprLike], Expr] + The shape of the created tensor. + + fill_value : relax.Expr + The value to fill. Must be a scalar tensor. + + dtype : Optional[Union[str, DataType]] + The data type of the created tensor. + If dtype is not given, it will by default use the dtype of fill_value. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + return _ffi_api.full(shape, fill_value, dtype) # type: ignore + + +def full_like(x: Expr, fill_value: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr: + """Construct a tensor such that + - its shape is the same as the input data tensor's shape, + - its value is filled with the input scalar fill value. + + Parameters + ---------- + x : relax.Expr + The input tensor, which provides the shape, and dtype + when the `dtype` field is not specified. + + fill_value : relax.Expr + The value to fill. Must be a scalar tensor. + + dtype : Optional[Union[str, DataType]] + The data type of the created tensor. + If dtype is not given, it will by default use the dtype of the input tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + return _ffi_api.full_like(x, fill_value, dtype) # type: ignore + + +def ones(shape: Union[Tuple[PrimExprLike], Expr], dtype: Union[str, DataType]) -> Expr: + """Construct a tensor of all ones, with the input shape and dtype. + + Parameters + ---------- + shape : Union[Tuple[PrimExprLike], Expr] + The shape of the created tensor. + + dtype : Union[str, DataType] + The data type of the created tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + if isinstance(shape, (tuple, list)): + shape = ShapeExpr(shape) + return _ffi_api.ones(shape, dtype) # type: ignore + + +def ones_like(x: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr: + """Construct a tensor with all ones, with shape of the input tensor shape. + + Parameters + ---------- + x : relax.Expr + The input tensor, which provides the shape, and dtype + when the `dtype` field is not specified. + + dtype : Optional[Union[str, DataType]] + The data type of the created tensor. + If dtype is not given, it will by default use the dtype of the input tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + return _ffi_api.ones_like(x, dtype) # type: ignore + + +def zeros(shape: Union[Tuple[PrimExprLike], Expr], dtype: Union[str, DataType]) -> Expr: + """Construct a tensor of all zeros, with the input shape and dtype. + + Parameters + ---------- + shape : Union[Tuple[PrimExprLike], Expr] + The shape of the created tensor. + + dtype : Union[str, DataType] + The data type of the created tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + if isinstance(shape, (tuple, list)): + shape = ShapeExpr(shape) + return _ffi_api.zeros(shape, dtype) # type: ignore + + +def zeros_like(x: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr: + """Construct a tensor with all zeros, with shape of the input tensor shape. + + Parameters + ---------- + x : relax.Expr + The input tensor, which provides the shape, and dtype + when the `dtype` field is not specified. + + dtype : Optional[Union[str, DataType]] + The data type of the created tensor. + If dtype is not given, it will by default use the dtype of the input tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + return _ffi_api.zeros_like(x, dtype) # type: ignore + + +def tril(x: Expr, k: int = 0) -> Expr: + """Return the lower triangular part of a matrix or a batch of matrices. + + Parameters + ---------- + x : relax.Expr + The tensor that tril will be applied to. + It is required to have at least two dimensions. + + k : int + The index indicating the diagonal above which to zero elements. + If k = 0, the diagonal is the main diagonal. + If k < 0, the diagonal is below the main diagonal. + If k > 0, the diagonal is above the main diagonal. + + Returns + ------- + ret : relax.Expr + The result tensor. + """ + return _ffi_api.tril(x, k) # type: ignore + + +def triu(x: Expr, k: int = 0) -> Expr: + """Return the upper triangular part of a matrix or a batch of matrices. + + Parameters + ---------- + x : relax.Expr + The tensor that triu will be applied to. + It is required to have at least two dimensions. + + k : int + The index indicating the diagonal below which to zero elements. + If k = 0, the diagonal is the main diagonal. + If k < 0, the diagonal is below the main diagonal. + If k > 0, the diagonal is above the main diagonal. + + Returns + ------- + ret : relax.Expr + The result tensor. + """ + return _ffi_api.triu(x, k) # type: ignore diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 68f84b3514a9..ac6714d940d3 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -19,6 +19,16 @@ import tvm._ffi +@tvm._ffi.register_object("relax.attrs.InitAttrs") +class InitAttrs(Attrs): + """Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operator""" + + +@tvm._ffi.register_object("relax.attrs.TriluAttrs") +class TriluAttrs(Attrs): + """Attributes used in tril and triu operator""" + + @tvm._ffi.register_object("relax.attrs.AstypeAttrs") class AstypeAttrs(Attrs): """Attributes used in astype operator""" diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 1f0e31428c63..118790372a35 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -52,6 +52,8 @@ exp, floor, floor_divide, + full, + full_like, greater, greater_equal, image, @@ -71,6 +73,8 @@ negative, not_equal, null_value, + ones, + ones_like, print, prod, reshape, @@ -92,7 +96,11 @@ take, tan, tanh, + tril, + triu, unique, + zeros, + zeros_like, nn, ) from tvm.relax.struct_info import StructInfo @@ -480,6 +488,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "exp", "floor", "floor_divide", + "full", + "full_like", "func_attr", "func_name", "func_ret_struct_info", @@ -504,6 +514,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "negative", "not_equal", "null_value", + "ones", + "ones_like", "output", "prim_value", "print", @@ -528,8 +540,12 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "take", "tan", "tanh", + "tril", + "triu", "tuple", "variance", "unique", + "zeros", + "zeros_like", "nn", ] diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc new file mode 100644 index 000000000000..e8374d198109 --- /dev/null +++ b/src/relax/op/tensor/create.cc @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file create.cc + * \brief Creation operators. + */ + +#include "create.h" + +#include + +namespace tvm { +namespace relax { + +/* Initialization operators */ +TVM_REGISTER_NODE_TYPE(InitAttrs); + +/* relax.full */ +Expr full(ObjectRef shape, Expr fill_value, DataType dtype) { + Expr shape_in_expr{nullptr}; + if (const auto* expr = shape.as()) { + shape_in_expr = GetRef(expr); + } else if (const auto* _array = shape.as()) { + shape_in_expr = ShapeExpr(GetRef>(_array)); + } else { + LOG(FATAL) << "Full only expects the input shape to be either an Expr or an Array of PrimExpr. " + "However, the given one is " + << shape->GetTypeKey(); + } + + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + + static const Op& op = Op::Get("relax.full"); + return Call(op, {std::move(shape_in_expr), std::move(fill_value)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.full").set_body_typed(full); + +StructInfo InferStructInfoFull(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 2) { + ctx->ReportFatal(Diagnostic::Error(call) << "Full op should have 2 arguments"); + } + const auto* shape_sinfo = GetStructInfoAs(call->args[0]); + const auto* fill_value_sinfo = GetStructInfoAs(call->args[1]); + if (shape_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Full requires the input shape to be a Shape. However, the given one is " + << call->args[0]->struct_info_->GetTypeKey()); + } + if (fill_value_sinfo == nullptr || fill_value_sinfo->ndim != 0) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "Full requires the input fill value to be zero rank Tensor. However, the given one is " + << call->args[1]->struct_info_); + } + + const auto* attrs = call->attrs.as(); + DataType out_dtype = attrs->dtype.is_void() ? fill_value_sinfo->dtype : attrs->dtype; + return TensorStructInfo(/*shape=*/call->args[0], out_dtype); +} + +TVM_REGISTER_OP("relax.full") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("shape", "Shape", "The shape of the created tensor.") + .add_argument("fill_value", "Tensor", "The scalar tensor, denoting the value to fill.") + .set_attr("FInferStructInfo", InferStructInfoFull); + +/* relax.full_like */ +Expr full_like(Expr x, Expr fill_value, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + static const Op& op = Op::Get("relax.full_like"); + return Call(op, {std::move(x), std::move(fill_value)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.full_like").set_body_typed(full_like); + +StructInfo InferStructInfoFullLike(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo data_sinfo = input_sinfo[0]; + TensorStructInfo fill_value_sinfo = input_sinfo[1]; + if (fill_value_sinfo->ndim != 0) { + ctx->ReportFatal(Diagnostic::Error(call) << "FullLike requires the input fill value to be zero " + "rank Tensor. However, the given one has ndim" + << fill_value_sinfo->ndim); + } + + const auto* attrs = call->attrs.as(); + if (attrs->dtype.is_void()) { + return data_sinfo; + } else { + auto output_sinfo = make_object(*data_sinfo.get()); + output_sinfo->dtype = attrs->dtype; + return TensorStructInfo(output_sinfo); + } +} + +TVM_REGISTER_OP("relax.full_like") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("fill_value", "Tensor", "The scalar value to fill.") + .set_attr("FInferStructInfo", InferStructInfoFullLike); + +// Structure info inference for ones and zeros +StructInfo InferStructInfoOnesZeros(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 1) { + ctx->ReportFatal(Diagnostic::Error(call) << "Ones/Zeros should have 1 argument"); + } + + const auto* shape_sinfo = GetStructInfoAs(call->args[0]); + if (shape_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "Ones/Zeros requires the input shape to be a Shape. However, the given one is " + << call->args[0]->struct_info_->GetTypeKey()); + } + const auto* attrs = call->attrs.as(); + return TensorStructInfo(/*shape=*/call->args[0], attrs->dtype); +} + +// Structure info inference for ones_like and zeros_like +StructInfo InferStructInfoOnesLikeZerosLike(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + if (attrs->dtype.is_void()) { + return data_sinfo; + } else { + auto output_sinfo = make_object(*data_sinfo.get()); + output_sinfo->dtype = attrs->dtype; + return TensorStructInfo(output_sinfo); + } +} + +/* relax.ones & relax.ones_like */ +Expr ones(Expr shape, DataType dtype) { + CHECK(!dtype.is_void()) << "Ones op expects the input dtype not to be void"; + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + + static const Op& op = Op::Get("relax.ones"); + return Call(op, {std::move(shape)}, Attrs(attrs), {}); +} + +Expr ones_like(Expr x, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + static const Op& op = Op::Get("relax.ones_like"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.ones").set_body_typed(ones); +TVM_REGISTER_GLOBAL("relax.op.ones_like").set_body_typed(ones_like); + +TVM_REGISTER_OP("relax.ones") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("shape", "Shape", "The shape of the created tensor.") + .set_attr("FInferStructInfo", InferStructInfoOnesZeros); + +TVM_REGISTER_OP("relax.ones_like") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike); + +/* relax.zeros & relax.zeros_like */ +Expr zeros(Expr shape, DataType dtype) { + CHECK(!dtype.is_void()) << "Zeros op expects the input dtype not to be void"; + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + + static const Op& op = Op::Get("relax.zeros"); + return Call(op, {std::move(shape)}, Attrs(attrs), {}); +} + +Expr zeros_like(Expr x, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + static const Op& op = Op::Get("relax.zeros_like"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.zeros").set_body_typed(zeros); +TVM_REGISTER_GLOBAL("relax.op.zeros_like").set_body_typed(zeros_like); + +TVM_REGISTER_OP("relax.zeros") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("shape", "Shape", "The shape of the created tensor.") + .set_attr("FInferStructInfo", InferStructInfoOnesZeros); + +TVM_REGISTER_OP("relax.zeros_like") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike); + +/* relax.tril & relax.triu */ +TVM_REGISTER_NODE_TYPE(TriluAttrs); + +Expr tril(Expr x, int k) { + ObjectPtr attrs = make_object(); + attrs->k = k; + + static const Op& op = Op::Get("relax.tril"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +Expr triu(Expr x, int k) { + ObjectPtr attrs = make_object(); + attrs->k = k; + + static const Op& op = Op::Get("relax.triu"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.tril").set_body_typed(tril); +TVM_REGISTER_GLOBAL("relax.op.triu").set_body_typed(triu); + +StructInfo InferStructInfoTrilTriu(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + if (!data_sinfo->IsUnknownNdim() && data_sinfo->ndim < 2) { + ctx->ReportFatal(Diagnostic::Error(call) << call->op + << " requires the input tensor to have at least two " + "dimensions. However, the given input has " + << data_sinfo->ndim << " dimension(s)."); + } + return data_sinfo; +} + +TVM_REGISTER_OP("relax.tril") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoTrilTriu); + +TVM_REGISTER_OP("relax.triu") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoTrilTriu); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h new file mode 100644 index 000000000000..c1ade470b4e8 --- /dev/null +++ b/src/relax/op/tensor/create.h @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file create.h + * \brief The functions to make Relax tensor-creation operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_CREATE_H_ +#define TVM_RELAX_OP_TENSOR_CREATE_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Fill array with scalar value. + * \param shape The shape of the created tensor. + * \param fill_value The value to fill. Must be a scalar tensor. + * \param dtype The data type of the created tensor. + * If dtype is not given, it will by default use the dtype of fill_value. + * \return The result tensor. + */ +Expr full(ObjectRef shape, Expr fill_value, DataType dtype); + +/*! + * \brief Construct a tensor such that + * - its shape is the same as the input data tensor's shape, + * - its value is filled with the input scalar fill value. + * \param x The input tensor, which provides the shape, and dtype + * when the input dtype is void. + * \param fill_value The value to fill. Must be a scalar tensor. + * \param dtype The data type of the created tensor. If it is + * void, the input tensor's dtype will be used. + * \return The result tensor. + */ +Expr full_like(Expr x, Expr fill_value, DataType dtype); + +/*! + * \brief Construct a tensor of all ones, with the input shape and dtype. + * \param shape The shape of the created tensor. + * \param dtype The data type of the created tensor. + * \return The result tensor. + */ +Expr ones(Expr shape, DataType dtype); + +/*! + * \brief Construct a tensor with all ones, with shape of the input tensor shape. + * \param x The input tensor, which provides the shape, and dtype + * when the input dtype is void. + * \param dtype The data type of the created tensor. If it is + * void, the input tensor's dtype will be used. + * \return The result tensor. + */ +Expr ones_like(Expr x, DataType dtype); + +/*! \brief Construct a tensor of all zeros, with the input shape and dtype. */ +Expr zeros(Expr shape, DataType dtype); + +/*! \brief Construct a tensor with all zeros, with shape of the input tensor shape. */ +Expr zeros_like(Expr x, DataType dtype); + +/*! \brief Return the lower triangular part of a matrix or a batch of matrices. */ +Expr tril(Expr x, int k); + +/*! \brief Return the upper triangular part of a matrix or a batch of matrices. */ +Expr triu(Expr x, int k); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_CREATE_H_ diff --git a/tests/python/relax/test_op_create.py b/tests/python/relax/test_op_create.py new file mode 100644 index 000000000000..6dd0a0d15ead --- /dev/null +++ b/tests/python/relax/test_op_create.py @@ -0,0 +1,638 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + fill_value = relax.Var("fill_value", R.Tensor((), "float32")) + assert relax.op.full((2, 3), fill_value).op == Op.get("relax.full") + assert relax.op.full_like(x, fill_value).op == Op.get("relax.full_like") + assert relax.op.ones((2, 3), "float32").op == Op.get("relax.ones") + assert relax.op.ones_like(x).op == Op.get("relax.ones_like") + assert relax.op.zeros((2, 3), "float32").op == Op.get("relax.zeros") + assert relax.op.zeros_like(x).op == Op.get("relax.zeros_like") + assert relax.op.tril(x).op == Op.get("relax.tril") + assert relax.op.triu(x).op == Op.get("relax.triu") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_full_infer_struct_info(): + bb = relax.BlockBuilder() + v0 = relax.Var("v", R.Tensor((), "float32")) + v1 = relax.Var("v", R.Tensor("float32", ndim=0)) + v2 = relax.Var("v", R.Tensor(())) + v3 = relax.Var("v", R.Tensor(ndim=0)) + s0 = relax.ShapeExpr((2, 3)) + s1 = relax.Var("s", relax.ShapeStructInfo((2, 3))) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s3 = relax.Var("s", relax.ShapeStructInfo()) + + _check_inference( + bb, relax.op.full((2, 3), v0, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full((2, 3), v0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.full(s0, v0, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full(s0, v0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.full(s1, v0, "float16"), relax.TensorStructInfo(s1, "float16")) + _check_inference(bb, relax.op.full(s1, v0), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.full(s2, v0, "float16"), relax.TensorStructInfo(s2, "float16")) + _check_inference(bb, relax.op.full(s2, v0), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.full(s3, v0, "float16"), relax.TensorStructInfo(s3, "float16")) + _check_inference(bb, relax.op.full(s3, v0), relax.TensorStructInfo(s3, "float32")) + _check_inference( + bb, relax.op.full((2, 3), v1, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full((2, 3), v1), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.full(s0, v1, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full(s0, v1), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.full(s1, v1, "float16"), relax.TensorStructInfo(s1, "float16")) + _check_inference(bb, relax.op.full(s1, v1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.full(s2, v1, "float16"), relax.TensorStructInfo(s2, "float16")) + _check_inference(bb, relax.op.full(s2, v1), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.full(s3, v1, "float16"), relax.TensorStructInfo(s3, "float16")) + _check_inference(bb, relax.op.full(s3, v1), relax.TensorStructInfo(s3, "float32")) + _check_inference( + bb, relax.op.full((2, 3), v2, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full((2, 3), v2), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference( + bb, relax.op.full(s0, v2, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full(s0, v2), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.full(s1, v2, "float16"), relax.TensorStructInfo(s1, "float16")) + _check_inference(bb, relax.op.full(s1, v2), relax.TensorStructInfo(s1, dtype="")) + _check_inference(bb, relax.op.full(s2, v2, "float16"), relax.TensorStructInfo(s2, "float16")) + _check_inference(bb, relax.op.full(s2, v2), relax.TensorStructInfo(s2, dtype="")) + _check_inference(bb, relax.op.full(s3, v2, "float16"), relax.TensorStructInfo(s3, "float16")) + _check_inference(bb, relax.op.full(s3, v2), relax.TensorStructInfo(s3, dtype="")) + _check_inference( + bb, relax.op.full((2, 3), v3, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full((2, 3), v3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference( + bb, relax.op.full(s0, v3, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full(s0, v3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.full(s1, v3, "float16"), relax.TensorStructInfo(s1, "float16")) + _check_inference( + bb, + relax.op.full( + s1, + v3, + ), + relax.TensorStructInfo(s1, dtype=""), + ) + _check_inference(bb, relax.op.full(s2, v3, "float16"), relax.TensorStructInfo(s2, "float16")) + _check_inference( + bb, + relax.op.full( + s2, + v3, + ), + relax.TensorStructInfo(s2, dtype=""), + ) + _check_inference(bb, relax.op.full(s3, v3, "float16"), relax.TensorStructInfo(s3, "float16")) + _check_inference(bb, relax.op.full(s3, v3), relax.TensorStructInfo(s3, dtype="")) + + +def test_full_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + v = relax.Var("v", R.Tensor((), "float32")) + s0 = relax.ShapeExpr((a, 3)) + s1 = relax.Var("s", relax.ShapeStructInfo((a, 3))) + + _check_inference( + bb, relax.op.full((a, 3), v, "float16"), relax.TensorStructInfo((a, 3), "float16") + ) + _check_inference(bb, relax.op.full((a, 3), v), relax.TensorStructInfo((a, 3), "float32")) + _check_inference(bb, relax.op.full(s0, v, "float16"), relax.TensorStructInfo((a, 3), "float16")) + _check_inference(bb, relax.op.full(s0, v), relax.TensorStructInfo((a, 3), "float32")) + _check_inference(bb, relax.op.full(s1, v, "float16"), relax.TensorStructInfo(s1, "float16")) + _check_inference(bb, relax.op.full(s1, v), relax.TensorStructInfo(s1, "float32")) + + +def test_full_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(())) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + v0 = relax.Var("v", relax.TensorStructInfo(s0, "float32")) + v1 = relax.Var("v", relax.TensorStructInfo(s1, "float32")) + + _check_inference( + bb, relax.op.full((2, 3), v0, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference( + bb, relax.op.full((2, 3), v1, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + + +def test_full_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + v0 = relax.Var("v", R.Tensor((), "float16")) + v1 = relax.Var("v", R.Tensor((), "int8")) + v2 = relax.Var("v", R.Tensor((), "int32")) + + _check_inference( + bb, relax.op.full((2, 3), v0, "float32"), relax.TensorStructInfo((2, 3), "float32") + ) + _check_inference(bb, relax.op.full((2, 3), v0), relax.TensorStructInfo((2, 3), "float16")) + _check_inference( + bb, relax.op.full((2, 3), v1, "int32"), relax.TensorStructInfo((2, 3), "int32") + ) + _check_inference(bb, relax.op.full((2, 3), v1), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, relax.op.full((2, 3), v2, "int8"), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, relax.op.full((2, 3), v2), relax.TensorStructInfo((2, 3), "int32")) + + +def test_full_infer_struct_info_fill_value_not_scalar_tensor(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((1,))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + v0 = relax.Var("v", R.Tensor((1,), "float32")) + v1 = relax.Var("v", R.Tensor("float32", ndim=1)) + v2 = relax.Var("v", R.Tensor("float32")) + v3 = relax.Var("v", relax.TensorStructInfo(s0, "float32")) + v4 = relax.Var("v", relax.TensorStructInfo(s1, "float32")) + v5 = relax.Var("v", relax.TensorStructInfo(s2, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v5)) + + +def test_full_shape_not_tuple(): + m = tir.Var("m", "int64") + v = relax.Var("v", R.Tensor((), "float32")) + + with pytest.raises(TVMError): + relax.op.full(4, v) + with pytest.raises(TVMError): + relax.op.full(m, v) + + +def test_full_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + v0 = relax.Var("v", R.Tensor((), "float32")) + v1 = relax.Var("v", relax.ShapeStructInfo(())) + v2 = relax.Var("v", relax.FuncStructInfo([], R.Tensor((), "float32"))) + s = relax.Var("s", R.Tensor((2, 3))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.full(s, v0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v2)) + + +def test_full_like_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor(ndim=2)) + x5 = relax.Var("x", R.Tensor()) + v0 = relax.Var("v", R.Tensor((), "float16")) + v1 = relax.Var("v", R.Tensor("float16", ndim=0)) + v2 = relax.Var("v", R.Tensor(())) + v3 = relax.Var("v", R.Tensor(ndim=0)) + + _check_inference(bb, relax.op.full_like(x0, v0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.full_like(x0, v2), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.full_like(x0, v3), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.full_like(x1, v0), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.full_like(x1, v1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.full_like(x1, v2), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.full_like(x1, v3), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.full_like(x2, v0), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.full_like(x2, v1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.full_like(x2, v2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.full_like(x2, v3), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.full_like(x3, v0), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.full_like(x3, v1), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.full_like(x3, v2), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.full_like(x3, v3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.full_like(x4, v0), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.full_like(x4, v1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.full_like(x4, v2), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.full_like(x4, v3), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.full_like(x5, v0), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.full_like(x5, v1), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.full_like(x5, v2), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.full_like(x5, v3), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.full_like(x0, v0, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference( + bb, relax.op.full_like(x0, v2, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference( + bb, relax.op.full_like(x3, v0, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference( + bb, relax.op.full_like(x3, v2, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + ) + + +def test_full_like_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((m, n))) + v = relax.Var("v", R.Tensor((), "float16")) + + _check_inference(bb, relax.op.full_like(x0, v), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.full_like(x1, v), relax.TensorStructInfo((m, n), dtype="")) + + +def test_full_like_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x3 = relax.Var("x", R.Tensor((2, 3), "float32")) + sv0 = relax.Var("sv", relax.ShapeStructInfo(())) + sv1 = relax.Var("sv", relax.ShapeStructInfo(ndim=0)) + v0 = relax.Var("v", relax.TensorStructInfo(sv0, "float16")) + v1 = relax.Var("v", relax.TensorStructInfo(sv1, "float16")) + v2 = relax.Var("v", R.Tensor((), "float16")) + + _check_inference(bb, relax.op.full_like(x0, v0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.full_like(x0, v2), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.full_like(x1, v0), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.full_like(x1, v1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.full_like(x1, v2), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.full_like(x2, v0), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.full_like(x2, v1), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.full_like(x2, v2), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.full_like(x3, v0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.full_like(x3, v1), relax.TensorStructInfo((2, 3), "float32")) + + +def test_full_like_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + v0 = relax.Var("v", R.Tensor((), "int32")) + v1 = relax.Var("v", R.Tensor((), "float64")) + + _check_inference(bb, relax.op.full_like(x0, v0), relax.TensorStructInfo((2, 3), "float16")) + _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorStructInfo((2, 3), "float16")) + _check_inference(bb, relax.op.full_like(x1, v0), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, relax.op.full_like(x1, v1), relax.TensorStructInfo((2, 3), "int8")) + + +def test_full_like_infer_struct_info_fill_value_not_scalar_tensor(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + s0 = relax.Var("s", relax.ShapeStructInfo((1,))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + v0 = relax.Var("v", R.Tensor((1,), "float32")) + v1 = relax.Var("v", R.Tensor("float32", ndim=1)) + v2 = relax.Var("v", R.Tensor("float32")) + v3 = relax.Var("v", relax.TensorStructInfo(s0, "float32")) + v4 = relax.Var("v", relax.TensorStructInfo(s1, "float32")) + v5 = relax.Var("v", relax.TensorStructInfo(s2, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x, v0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x, v1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x, v2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x, v3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x, v4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x, v5)) + + +def test_full_like_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((), "float32"))) + x2 = relax.Var("x", R.Tensor((2, 3))) + v0 = relax.Var("v", R.Tensor(())) + v1 = relax.Var("v", relax.ShapeStructInfo(())) + + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x0, v0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x1, v0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x2, v1)) + + +def test_ones_zeros_infer_struct_info(): + bb = relax.BlockBuilder() + s0 = relax.ShapeExpr((2, 3)) + s1 = relax.Var("s", relax.ShapeStructInfo((2, 3))) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s3 = relax.Var("s", relax.ShapeStructInfo()) + + _check_inference( + bb, relax.op.ones((2, 3), "float32"), relax.TensorStructInfo((2, 3), "float32") + ) + _check_inference(bb, relax.op.ones(s0, "float32"), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.ones(s1, "float32"), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.ones(s2, "float32"), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.ones(s3, "float32"), relax.TensorStructInfo(s3, "float32")) + _check_inference( + bb, relax.op.zeros((2, 3), "float32"), relax.TensorStructInfo((2, 3), "float32") + ) + _check_inference(bb, relax.op.zeros(s0, "float32"), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.zeros(s1, "float32"), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.zeros(s2, "float32"), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.zeros(s3, "float32"), relax.TensorStructInfo(s3, "float32")) + + +def test_ones_zeros_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + s0 = relax.ShapeExpr((m, n)) + s1 = relax.Var("s", relax.ShapeStructInfo((m, n))) + + _check_inference( + bb, relax.op.ones((m, n), "float32"), relax.TensorStructInfo((m, n), "float32") + ) + _check_inference(bb, relax.op.ones(s0, "float32"), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.ones(s1, "float32"), relax.TensorStructInfo(s1, "float32")) + _check_inference( + bb, relax.op.zeros((m, n), "float32"), relax.TensorStructInfo((m, n), "float32") + ) + _check_inference(bb, relax.op.zeros(s0, "float32"), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.zeros(s1, "float32"), relax.TensorStructInfo(s1, "float32")) + + +def test_ones_zeros_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + s0 = relax.ShapeExpr((2, 3)) + s1 = relax.Var("s", relax.ShapeStructInfo((2, 3))) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s3 = relax.Var("s", relax.ShapeStructInfo()) + + _check_inference(bb, relax.op.ones(s0, "float16"), relax.TensorStructInfo((2, 3), "float16")) + _check_inference(bb, relax.op.ones(s1, "int8"), relax.TensorStructInfo(s1, "int8")) + _check_inference(bb, relax.op.zeros(s2, "int32"), relax.TensorStructInfo(s2, "int32")) + _check_inference(bb, relax.op.zeros(s3, "float64"), relax.TensorStructInfo(s3, "float64")) + + +def test_ones_zeros_shape_not_tuple(): + m = tir.Var("m", "int64") + + with pytest.raises(TVMError): + relax.op.ones(10, "float32") + with pytest.raises(TVMError): + relax.op.zeros(m, "float32") + + +def test_ones_zeros_wrong_dtype(): + with pytest.raises(TypeError): + relax.op.ones((2, 3)) + with pytest.raises(TVMError): + relax.op.ones((2, 3), "") + with pytest.raises(TypeError): + relax.op.zeros((2, 3)) + with pytest.raises(TVMError): + relax.op.zeros((2, 3), "") + + +def test_ones_zeros_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", R.Tensor((2, 3))) + s1 = relax.Var("s", relax.FuncStructInfo([], R.Tensor((2, 3)))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.ones(s0, "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.zeros(s1, "float32")) + + +def test_ones_like_zeros_like_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor(ndim=2)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.ones_like(x2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.zeros_like(x3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.ones_like(x4), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.zeros_like(x5), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.ones_like(x0, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference( + bb, relax.op.zeros_like(x3, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + ) + + +def test_ones_like_zeros_like_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((m, n))) + + _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo((m, n), dtype="")) + + +def test_ones_like_zeros_like_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.zeros_like(x2), relax.TensorStructInfo(s2, "float32")) + + +def test_ones_like_zeros_like_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float64")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + + _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo((2, 3), "int8")) + + +def test_ones_like_zeros_like_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.ones_like(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.zeros_like(x1)) + + +def test_tril_triu_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.tril(x0, k=1), relax.TensorStructInfo((2, 3, 4), "float32")) + _check_inference(bb, relax.op.triu(x0, k=0), relax.TensorStructInfo((2, 3, 4), "float32")) + _check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo((2, 3, 4), "float32")) + _check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.tril(x2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.triu(x3), relax.TensorStructInfo((2, 3, 4), dtype="")) + _check_inference(bb, relax.op.tril(x4), relax.TensorStructInfo(dtype="", ndim=3)) + _check_inference(bb, relax.op.triu(x5), relax.TensorStructInfo(dtype="")) + + +def test_tril_triu_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((a, b, c), "float32")) + x1 = relax.Var("x", R.Tensor((a, b, c))) + + _check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo((a, b, c), "float32")) + _check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo((a, b, c), dtype="")) + + +def test_tril_triu_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.tril(x2), relax.TensorStructInfo(s2, "float32")) + + +def test_tril_triu_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 4), "int32")) + + _check_inference(bb, relax.op.triu(x0), relax.TensorStructInfo((2, 3, 4), "float16")) + _check_inference(bb, relax.op.tril(x1), relax.TensorStructInfo((2, 3, 4), "int8")) + _check_inference(bb, relax.op.triu(x2), relax.TensorStructInfo((2, 3, 4), "int32")) + + +def test_tril_triu_infer_struct_info_less_than_two_ndim(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2,))) + s1 = relax.Var("s", relax.ShapeStructInfo(())) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s3 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + x0 = relax.Var("x", R.Tensor((2,), "float32")) + x1 = relax.Var("x", R.Tensor((), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=1)) + x3 = relax.Var("x", R.Tensor("float32", ndim=0)) + x4 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x6 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x7 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.tril(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.triu(x1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.tril(x2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.triu(x3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.tril(x4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.triu(x5)) + with pytest.raises(TVMError): + bb.normalize(relax.op.tril(x6)) + with pytest.raises(TVMError): + bb.normalize(relax.op.triu(x7)) + + +def test_tril_triu_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.tril(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.triu(x1)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_create.py b/tests/python/relax/test_tvmscript_parser_op_create.py new file mode 100644 index 000000000000..6cbc0ebf906a --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_create.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_full(): + @R.function + def foo(v: R.Tensor((), "int32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.full((2, 3), v, dtype="float32") + return gv + + bb = relax.BlockBuilder() + v = relax.Var("v", R.Tensor((), "int32")) + with bb.function("foo", [v]): + gv = bb.emit(relax.op.full((2, 3), v, "float32")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_full_like(): + @R.function + def foo( + x: R.Tensor((2, 3), "float16"), v: R.Tensor((), "float32") + ) -> R.Tensor((2, 3), "float16"): + gv: R.Tensor((2, 3), "float16") = R.full_like(x, v) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float16")) + v = relax.Var("y", R.Tensor((), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, v]): + gv = bb.emit(relax.op.full_like(x, v)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_ones(): + @R.function + def foo(dumb_param: R.Tensor()) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.ones((2, 3), "float32") + return gv + + bb = relax.BlockBuilder() + dumb_param = relax.Var("dumb_param", R.Tensor()) + with bb.function("foo", [dumb_param]): + gv = bb.emit(relax.op.ones((2, 3), "float32")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_ones_like(): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.ones_like(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.ones_like(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_zeros(): + @R.function + def foo(dumb_param: R.Tensor()) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.zeros((2, 3), "float32") + return gv + + bb = relax.BlockBuilder() + dumb_param = relax.Var("dumb_param", R.Tensor()) + with bb.function("foo", [dumb_param]): + gv = bb.emit(relax.op.zeros((2, 3), "float32")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_zeros_like(): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.zeros_like(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.zeros_like(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_tril(): + @R.function + def foo(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv: R.Tensor((2, 3, 4), "float32") = R.tril(x, k=2) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.tril(x, k=2)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_triu(): + @R.function + def foo(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv: R.Tensor((2, 3, 4), "float32") = R.triu(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.triu(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() From b95a20a46eb315fdf31557a36994f55cc709378a Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 14 Feb 2023 15:07:33 -0500 Subject: [PATCH 21/81] [Unity] Relax op: linear algebra (#13988) This PR is about the high-level tensor computation operators in Relax. This PR includes the linear algebra operators. Co-authored-by: Siyuan Fneg --- include/tvm/relax/attrs/linear_algebra.h | 44 ++++ python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/linear_algebra.py | 90 +++++++ python/tvm/relax/op/op_attrs.py | 5 + python/tvm/script/ir_builder/relax/ir.py | 4 + src/relax/op/tensor/linear_algebra.cc | 123 +++++++++ src/relax/op/tensor/linear_algebra.h | 49 ++++ tests/python/relax/test_op_linear_algebra.py | 244 ++++++++++++++++++ ...test_tvmscript_parser_op_linear_algebra.py | 80 ++++++ 9 files changed, 640 insertions(+) create mode 100644 include/tvm/relax/attrs/linear_algebra.h create mode 100644 python/tvm/relax/op/linear_algebra.py create mode 100644 src/relax/op/tensor/linear_algebra.cc create mode 100644 src/relax/op/tensor/linear_algebra.h create mode 100644 tests/python/relax/test_op_linear_algebra.py create mode 100644 tests/python/relax/test_tvmscript_parser_op_linear_algebra.py diff --git a/include/tvm/relax/attrs/linear_algebra.h b/include/tvm/relax/attrs/linear_algebra.h new file mode 100644 index 000000000000..4b0e04298c9e --- /dev/null +++ b/include/tvm/relax/attrs/linear_algebra.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/linear_algebra.h + * \brief Attributes for linear algebra operators. + */ +#ifndef TVM_RELAX_ATTRS_LINEAR_ALGEBRA_H_ +#define TVM_RELAX_ATTRS_LINEAR_ALGEBRA_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes for matmul operator */ +struct MatmulAttrs : public tvm::AttrsNode { + DataType out_dtype; + + TVM_DECLARE_ATTRS(MatmulAttrs, "relax.attrs.MatmulAttrs") { + TVM_ATTR_FIELD(out_dtype).describe("The data type of the output tensor"); + } +}; // struct MatmulAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_LINEAR_ALGEBRA_H_ diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 97d08c0946a0..4b2f990eaa27 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -23,6 +23,7 @@ from .create import * from .datatype import * from .index import * +from .linear_algebra import * from .manipulate import * from .op_attrs import * from .statistical import * diff --git a/python/tvm/relax/op/linear_algebra.py b/python/tvm/relax/op/linear_algebra.py new file mode 100644 index 000000000000..940861a97227 --- /dev/null +++ b/python/tvm/relax/op/linear_algebra.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Relax linear algebra operators""" +from typing import Optional, Union + +from tvm import DataType + +from ..expr import Expr +from . import _ffi_api +from .manipulate import permute_dims + + +def matmul(x1: Expr, x2: Expr, out_dtype: Optional[Union[str, DataType]] = None) -> Expr: + """General matrix multiplication of two tensors, with broadcasting on batched dimensions. + + The semantics and output shape deduction rule is specified as + https://data-apis.org/array-api/latest/API_specification/generated/array_api.matmul.html. + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + + x2 : relax.Expr + The second input tensor. + + out_dtype: Optional[Union[str, DataType]] + The data type of the matmul result. + When it is not specified, the output dtype will be the the same as input dtype. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.matmul(x1, x2, out_dtype) # type: ignore + + +def linear( + data: Expr, + weight: Expr, + bias: Optional[Expr] = None, + out_dtype: Optional[Union[str, DataType]] = None, +) -> Expr: + """Applies a linear transformation to the incoming data: y = xA^T + b + + Parameters + ---------- + data : relax.Expr + The input data. + + weight : relax.Expr + The weight tensor. + + bias : Optional[Expr] + The bias tensor. + + out_dtype: Optional[Union[str, DataType]] + The data type of the matmul result. + When it is not specified, the output dtype will be the the same as input dtype. + + Notes + ----- + Relax does not regard the Linear Op as a primitive Op, + while combine the transpose, matmul and add op to implement it. + + Returns + ------- + result : relax.Expr + The computed result. + """ + + # Since weight can be 1D or 2D, we use `axes=None` to support both cases. + x = matmul(data, permute_dims(weight, axes=None), out_dtype=out_dtype) + return x + bias if bias is not None else x diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index ac6714d940d3..3a7ed427f9bf 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -44,6 +44,11 @@ class StridedSliceAttrs(Attrs): """Attributes used in strided_slice operator""" +@tvm._ffi.register_object("relax.attrs.MatmulAttrs") +class MatmulAttrs(Attrs): + """Attributes for matmul operator""" + + @tvm._ffi.register_object("relax.attrs.Conv2DAttrs") class Conv2DAttrs(Attrs): """Attributes for nn.conv2d""" diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 118790372a35..9f5fe03decfb 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -63,8 +63,10 @@ isnan, less, less_equal, + linear, log, make_closure, + matmul, max, mean, memory, @@ -504,8 +506,10 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "isnan", "less", "less_equal", + "linear", "log", "make_closure", + "matmul", "max", "mean", "memory", diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc new file mode 100644 index 000000000000..50b53d0c8e66 --- /dev/null +++ b/src/relax/op/tensor/linear_algebra.cc @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file linear_algebra.cc + * \brief Linear algebra operators. + */ + +#include "linear_algebra.h" + +#include +#include + +namespace tvm { +namespace relax { + +/* relax.matmul */ +TVM_REGISTER_NODE_TYPE(MatmulAttrs); + +Expr matmul(Expr x1, Expr x2, DataType out_dtype) { + ObjectPtr attrs = make_object(); + attrs->out_dtype = out_dtype; + + static const Op& op = Op::Get("relax.matmul"); + return Call(op, {std::move(x1), std::move(x2)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.matmul").set_body_typed(matmul); + +StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo x1_sinfo = input_sinfo[0]; + TensorStructInfo x2_sinfo = input_sinfo[1]; + + const auto* attrs = call->attrs.as(); + DataType out_dtype = attrs->out_dtype.is_void() + ? InferBinaryArithOpOutDtype(call, ctx, x1_sinfo, x2_sinfo) + : attrs->out_dtype; + + if (x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) { + return TensorStructInfo(out_dtype, kUnknownNDim); + } + int x1_ndim = x1_sinfo->ndim; + int x2_ndim = x2_sinfo->ndim; + if (x1_ndim == 0 || x2_ndim == 0) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Matmul requires both inputs to have at least 1 dimension. However, " + << (x1_ndim == 0 ? "x1" : "x2") << " is a 0-rank tensor."); + } + + int x1_prepended = 0; + int x2_appended = 0; + if (x1_ndim == 1) { + x1_ndim = 2; + x1_prepended = 1; + } + if (x2_ndim == 1) { + x2_ndim = 2; + x2_appended = 1; + } + int output_ndim = std::max(x1_ndim, x2_ndim) - x1_prepended - x2_appended; + + const auto* x1_shape = x1_sinfo->shape.as(); + const auto* x2_shape = x2_sinfo->shape.as(); + if (x1_shape == nullptr || x2_shape == nullptr) { + return TensorStructInfo(out_dtype, output_ndim); + } + + Array x1_shape_prefix{x1_shape->values.begin(), + x1_shape->values.end() - 2 + x1_prepended}; + Array x2_shape_prefix{x2_shape->values.begin(), + x2_shape->values.end() - 2 + x2_appended}; + Optional> output_shape_prefix = + InferBinaryBroadcastShape(call, ctx, x1_shape_prefix, x2_shape_prefix); + if (!output_shape_prefix.defined()) { + return TensorStructInfo(out_dtype, output_ndim); + } + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + PrimExpr x1_reduction_length = x1_shape->values[x1_sinfo->ndim - 1]; + PrimExpr x2_reduction_length = x2_shape->values[x2_ndim - 2]; + if (analyzer->CanProve(x1_reduction_length != x2_reduction_length)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Matmul requires the reduction length of x1 and x2 to be equal. However, " + "the reduction lengths of x1 and x2 are " + << x1_reduction_length << " and " << x2_reduction_length << " respectively."); + } + + Array output_shape = output_shape_prefix.value(); + if (!x1_prepended) { + output_shape.push_back(x1_shape->values[x1_ndim - 2]); + } + if (!x2_appended) { + output_shape.push_back(x2_shape->values[x2_ndim - 1]); + } + ICHECK_EQ(static_cast(output_shape.size()), output_ndim); + return TensorStructInfo(ShapeExpr(output_shape), out_dtype); +} + +TVM_REGISTER_OP("relax.matmul") + .set_num_inputs(2) + .add_argument("x1", "Tensor", "The first input tensor.") + .add_argument("x2", "Tensor", "The second input tensor.") + .set_attr("FInferStructInfo", InferStructInfoMatmul); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/linear_algebra.h b/src/relax/op/tensor/linear_algebra.h new file mode 100644 index 000000000000..af614c1f30d5 --- /dev/null +++ b/src/relax/op/tensor/linear_algebra.h @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file linear_algebra.h + * \brief The functions to make Relax linear algebra operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_LINEAR_ALGEBRA_H_ +#define TVM_RELAX_OP_TENSOR_LINEAR_ALGEBRA_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief General matrix multiplication of two tensors. + * The semantics and output shape deduction rule is specified as + * https://data-apis.org/array-api/latest/API_specification/generated/array_api.matmul.html. + * \param x1 The first input tensor. + * \param x2 The second input tensor. + * \param out_dtype The data type of the matmul result. + * When it is not specified, the output dtype will be the the same as input dtype. + * \return The computed result. + */ +Expr matmul(Expr x1, Expr x2, DataType out_dtype); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_LINEAR_ALGEBRA_H_ diff --git a/tests/python/relax/test_op_linear_algebra.py b/tests/python/relax/test_op_linear_algebra.py new file mode 100644 index 000000000000..5eb19cf2b420 --- /dev/null +++ b/tests/python/relax/test_op_linear_algebra.py @@ -0,0 +1,244 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((3, 4), "float32")) + assert relax.op.matmul(x, y).op == Op.get("relax.matmul") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_matmul_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((4,), "float32")) + x2 = relax.Var("x", R.Tensor((2, 3, 5, 4), "float32")) + x3 = relax.Var("x", R.Tensor((2, 1, 4, 5), "float32")) + x4 = relax.Var("x", R.Tensor((2, 1, 4, 5))) + x5 = relax.Var("x", R.Tensor("float32")) + x6 = relax.Var("x", R.Tensor((2, 1, 4, 5), "float16")) + y0 = relax.Var("y", R.Tensor((4, 5), "float32")) + y1 = relax.Var("y", R.Tensor((4,), "float32")) + y2 = relax.Var("y", R.Tensor((2, 3, 4, 5), "float32")) + y3 = relax.Var("y", R.Tensor((6, 1, 3, 5, 7), "float32")) + y4 = relax.Var("y", R.Tensor("float32", ndim=5)) + y5 = relax.Var("y", R.Tensor()) + + _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo((3, 5), "float32")) + _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo((), "float32")) + _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorStructInfo((2, 3, 5), "float32")) + _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorStructInfo((2, 3, 5), "float32")) + _check_inference( + bb, relax.op.matmul(x3, y3), relax.TensorStructInfo((6, 2, 3, 4, 7), "float32") + ) + _check_inference(bb, relax.op.matmul(x4, y3), relax.TensorStructInfo((6, 2, 3, 4, 7), "")) + _check_inference(bb, relax.op.matmul(x3, y4), relax.TensorStructInfo(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.matmul(x5, y3), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.matmul(x3, y5), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, + relax.op.matmul(x3, y3, out_dtype="float16"), + relax.TensorStructInfo((6, 2, 3, 4, 7), "float16"), + ) + _check_inference( + bb, + relax.op.matmul(x6, y3, out_dtype="float16"), + relax.TensorStructInfo((6, 2, 3, 4, 7), "float16"), + ) + + +def test_matmul_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + k0 = tir.Var("k0", "int64") + k1 = tir.Var("k1", "int64") + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + b1 = tir.Var("b", "int64") + c = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((m, k0), "float32")) + x1 = relax.Var("x", R.Tensor((k0,), "float32")) + x2 = relax.Var("x", R.Tensor((a, b, m, k0), "float32")) + x3 = relax.Var("x", R.Tensor((b, 1, m, k0), "float32")) + x4 = relax.Var("x", R.Tensor((b, 1, m, k1), "float32")) + y0 = relax.Var("y", R.Tensor((k0, n), "float32")) + y1 = relax.Var("y", R.Tensor((k0,), "float32")) + y2 = relax.Var("y", R.Tensor((a, b, k0, n), "float32")) + y3 = relax.Var("y", R.Tensor((a, 1, c, k0, n), "float32")) + y4 = relax.Var("y", R.Tensor((a, b1, c, k0, n), "float32")) + + _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo((), "float32")) + _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorStructInfo((a, b, n), "float32")) + _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorStructInfo((a, b, m), "float32")) + _check_inference( + bb, relax.op.matmul(x3, y3), relax.TensorStructInfo((a, b, c, m, n), "float32") + ) + _check_inference( + bb, relax.op.matmul(x4, y3), relax.TensorStructInfo((a, b, c, m, n), "float32") + ) + _check_inference(bb, relax.op.matmul(x3, y4), relax.TensorStructInfo(dtype="float32", ndim=5)) + + +def test_matmul_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s3", relax.ShapeStructInfo(ndim=1)) + s3 = relax.Var("s4", relax.ShapeStructInfo(ndim=1)) + s5 = relax.Var("s5", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s5, "float32")) + y0 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) + y1 = relax.Var("y", relax.TensorStructInfo(s2, "float32")) + y2 = relax.Var("y", relax.TensorStructInfo(s3, "float32")) + + _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference(bb, relax.op.matmul(x1, y0), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.matmul(x2, y0), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.matmul(x0, y1), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo(dtype="float32", ndim=0)) + _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorStructInfo(dtype="float32", ndim=0)) + + +def test_matmul_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4), "float16")) + y0 = relax.Var("y", R.Tensor((4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((3, 4), "int8")) + y1 = relax.Var("y", R.Tensor((4, 5), "int8")) + x2 = relax.Var("x", R.Tensor((3, 4), "int64")) + y2 = relax.Var("y", R.Tensor((4, 5), "int64")) + + _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo((3, 5), "float16")) + _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo((3, 5), "int8")) + _check_inference(bb, relax.op.matmul(x2, y2), relax.TensorStructInfo((3, 5), "int64")) + + +def test_matmul_infer_struct_info_mixed_precision(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4), "float16")) + y0 = relax.Var("y", R.Tensor((4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((3, 4), "int8")) + y1 = relax.Var("y", R.Tensor((4, 5), "int8")) + x2 = relax.Var("x", R.Tensor((3, 4))) + y2 = relax.Var("y", R.Tensor((4, 5))) + + _check_inference( + bb, + relax.op.matmul(x0, y0, out_dtype="float32"), + relax.TensorStructInfo((3, 5), "float32"), + ) + _check_inference( + bb, relax.op.matmul(x1, y1, out_dtype="int32"), relax.TensorStructInfo((3, 5), "int32") + ) + _check_inference( + bb, + relax.op.matmul(x2, y2, out_dtype="float32"), + relax.TensorStructInfo((3, 5), "float32"), + ) + + +def test_matmul_infer_struct_info_zero_rank_input(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((), "float32")) + y0 = relax.Var("y", R.Tensor((4, 5), "float32")) + y1 = relax.Var("y", R.Tensor((), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.matmul(x0, y1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.matmul(x1, y0)) + + +def test_matmul_infer_struct_info_not_broadcastable(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + y = relax.Var("y", R.Tensor((2, 8, 3, 5, 6), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.matmul(x, y)) + + +def test_matmul_infer_struct_info_unequal_reduction_length(): + bb = relax.BlockBuilder() + k = tir.Var("k", "int64") + x0 = relax.Var("x", R.Tensor((3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((3, k), "float32")) + y0 = relax.Var("y", R.Tensor((6, 5), "float32")) + y1 = relax.Var("y", R.Tensor((k + 1, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.matmul(x0, y0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.matmul(x1, y1)) + + +def test_linear(): + # Since linear is only a sugar for transpose + matmul + add, + # we only have brief tests here. + bb = relax.BlockBuilder() + x1 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x2 = relax.Var("x", R.Tensor("float32")) + w1 = relax.Var("w", R.Tensor((5, 4), "float32")) + w2 = relax.Var("w", R.Tensor((4,), "float32")) + w3 = relax.Var("w", R.Tensor("float32")) + b1 = relax.Var("b", R.Tensor((5,), "float32")) + b2 = relax.Var("b", R.Tensor((), "float32")) + + # Need a scope to normalize non-leaf nodes + with bb.function("func", [x1]): + _check_inference( + bb, relax.op.linear(x1, w1, b1), relax.TensorStructInfo((2, 3, 5), "float32") + ) + _check_inference( + bb, relax.op.linear(x1, w1, b2), relax.TensorStructInfo((2, 3, 5), "float32") + ) + with pytest.raises(TVMError): + bb.normalize(relax.op.linear(x1, w2, b1)) # error on Add with shape (2, 3, 5) and (4,) + _check_inference(bb, relax.op.linear(x1, w2, b2), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.linear(x1, w3, b1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x1, w3, b2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w1, b1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w1, b2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w2, b1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w2, b2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w3, b1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w3, b2), relax.TensorStructInfo(dtype="float32")) + + # Fake output + gv = bb.emit_func_output(relax.Tuple([])) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_linear_algebra.py b/tests/python/relax/test_tvmscript_parser_op_linear_algebra.py new file mode 100644 index 000000000000..1ed7fa9b917c --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_linear_algebra.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_matmul(): + @R.function + def foo( + x: R.Tensor((2, 3, 4, 5), "float32"), y: R.Tensor((6, 2, 3, 5, 7), "float32") + ) -> R.Tensor((6, 2, 3, 4, 7), "float32"): + gv: R.Tensor((6, 2, 3, 4, 7), "float32") = R.matmul(x, y) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + y = relax.Var("y", R.Tensor((6, 2, 3, 5, 7), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, y]): + gv = bb.emit(relax.op.matmul(x, y)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_linear(): + @R.function + def foo( + x: R.Tensor((2, 3, 4, 5), "float32"), + w: R.Tensor((3, 5), "float32"), + bias: R.Tensor((3,), "float32"), + ): + gv = R.linear(x, w, bias) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + w = relax.Var("y", R.Tensor((3, 5), "float32")) + bias = relax.Var("bias", R.Tensor((3,), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, w, bias]): + w_T = bb.emit(relax.op.permute_dims(w, axes=None)) + matmul = bb.emit(relax.op.matmul(x, w_T)) + out = matmul + bias + bb.emit_func_output(out) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() From 4577c986b8169ad718be10a1328e6312003d521b Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 14 Feb 2023 15:09:10 -0500 Subject: [PATCH 22/81] [Unity] Relax op: search (#13992) This PR is about the high-level tensor computation operators in Relax. This PR includes the search operators. --- python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/search.py | 50 ++++ python/tvm/script/ir_builder/relax/ir.py | 4 +- src/relax/op/tensor/search.cc | 99 +++++++ src/relax/op/tensor/search.h | 41 +++ tests/python/relax/test_op_search.py | 278 ++++++++++++++++++ .../relax/test_tvmscript_parser_op_search.py | 60 ++++ 7 files changed, 532 insertions(+), 1 deletion(-) create mode 100644 python/tvm/relax/op/search.py create mode 100644 src/relax/op/tensor/search.cc create mode 100644 src/relax/op/tensor/search.h create mode 100644 tests/python/relax/test_op_search.py create mode 100644 tests/python/relax/test_tvmscript_parser_op_search.py diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 4b2f990eaa27..39a645ffea54 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -27,6 +27,7 @@ from .manipulate import * from .op_attrs import * from .statistical import * +from .search import * from .set import * from .ternary import * from .unary import * diff --git a/python/tvm/relax/op/search.py b/python/tvm/relax/op/search.py new file mode 100644 index 000000000000..8252b0e1d851 --- /dev/null +++ b/python/tvm/relax/op/search.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Search operators.""" +from . import _ffi_api +from ..expr import Expr + + +def where(condition: Expr, x1: Expr, x2: Expr) -> Expr: + """Selecting elements from either the input tensors depending on the value of the + condition. + + For a given position, return the corresponding value in `x1` if `condition` is True, + and return the corresponding value in `x2` otherwise. + + Parameters + ---------- + condition : relax.Expr + When True, yield `x1`; otherwise, yield `x2`. + Must be broadcasting compatible with `x1` and `x2`. + Must have boolean dtype. + + x1 : relax.Expr + The first input tensor. + Must be broadcasting compatible with `condition` and `x2`. + + x2 : relax.Expr + The second input tensor. + Must be broadcasting compatible with `condition` and `x1`. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + return _ffi_api.where(condition, x1, x2) # type: ignore diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 9f5fe03decfb..b779bdac9c13 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -101,6 +101,7 @@ tril, triu, unique, + where, zeros, zeros_like, nn, @@ -547,8 +548,9 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "tril", "triu", "tuple", - "variance", "unique", + "variance", + "where", "zeros", "zeros_like", "nn", diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc new file mode 100644 index 000000000000..5191017ea17f --- /dev/null +++ b/src/relax/op/tensor/search.cc @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file search.cc + * \brief Searching operators. + */ + +#include "search.h" + +#include +#include + +namespace tvm { +namespace relax { + +/* relax.where */ +Expr where(Expr condition, Expr x1, Expr x2) { + static const Op& op = Op::Get("relax.where"); + return Call(op, {std::move(condition), std::move(x1), std::move(x2)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.where").set_body_typed(where); + +StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo cond_sinfo = input_sinfo[0]; + TensorStructInfo x1_sinfo = input_sinfo[1]; + TensorStructInfo x2_sinfo = input_sinfo[2]; + + if (!cond_sinfo->dtype.is_bool()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Where requires the input condition tensor to have boolean dtype. However, " + "the given condition dtype is " + << cond_sinfo->dtype); + } + DataType output_dtype = InferBinaryArithOpOutDtype(call, ctx, x1_sinfo, x2_sinfo); + + int output_ndim; + if (cond_sinfo->IsUnknownNdim() || x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) { + output_ndim = kUnknownNDim; + } else { + output_ndim = std::max(cond_sinfo->ndim, std::max(x1_sinfo->ndim, x2_sinfo->ndim)); + } + + const auto* cond_shape = cond_sinfo->shape.as(); + const auto* x1_shape = x1_sinfo->shape.as(); + const auto* x2_shape = x2_sinfo->shape.as(); + if (cond_shape && x1_shape && x2_shape) { + // Step 1. Compute the broadcasted shape of x1's and x2's + Optional> broadcasted_shape = + InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values); + if (!broadcasted_shape.defined()) { + return TensorStructInfo(output_dtype, output_ndim); + } + // Step 2. Compute the broadcasted shape of cond's and the previous broadcasted shape. + broadcasted_shape = + InferBinaryBroadcastShape(call, ctx, cond_shape->values, broadcasted_shape.value()); + if (!broadcasted_shape.defined()) { + return TensorStructInfo(output_dtype, output_ndim); + } + ICHECK_EQ(static_cast(broadcasted_shape.value().size()), output_ndim); + return TensorStructInfo(ShapeExpr(broadcasted_shape.value()), output_dtype); + } else if (cond_sinfo->shape.defined() && // + x1_sinfo->shape.defined() && // + x2_sinfo->shape.defined() && // + cond_sinfo->shape.same_as(x1_sinfo->shape) && // + cond_sinfo->shape.same_as(x2_sinfo->shape)) { + return TensorStructInfo(cond_sinfo->shape.value(), output_dtype); + } else { + return TensorStructInfo(output_dtype, output_ndim); + } +} + +TVM_REGISTER_OP("relax.where") + .set_num_inputs(3) + .add_argument("condition", "Tensor", "When True, yield `x1`; otherwise, yield `x2`.") + .add_argument("x1", "Tensor", "The first input tensor.") + .add_argument("x2", "Tensor", "The second input tensor.") + .set_attr("FInferStructInfo", InferStructInfoWhere); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/search.h b/src/relax/op/tensor/search.h new file mode 100644 index 000000000000..aeae4a7157b3 --- /dev/null +++ b/src/relax/op/tensor/search.h @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file search.h + * \brief The functions to make Relax searching operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_SEARCH_H_ +#define TVM_RELAX_OP_TENSOR_SEARCH_H_ + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Selecting elements from either the input tensors depending on the value of the + * condition. + */ +Expr where(Expr condition, Expr x1, Expr x2); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_SEARCH_H_ diff --git a/tests/python/relax/test_op_search.py b/tests/python/relax/test_op_search.py new file mode 100644 index 000000000000..a2f271671ba6 --- /dev/null +++ b/tests/python/relax/test_op_search.py @@ -0,0 +1,278 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + cond = relax.Var("cond", R.Tensor((2, 3), "bool")) + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("x", R.Tensor((2, 3), "float32")) + assert relax.op.where(cond, x, y).op == Op.get("relax.where") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_where_infer_struct_info(): + bb = relax.BlockBuilder() + cond0 = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool")) + cond1 = relax.Var("cond", R.Tensor("bool", ndim=5)) + cond2 = relax.Var("cond", R.Tensor("bool")) + x0 = relax.Var("x", R.Tensor((5, 1, 3, 2), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((5, 1, 3, 2))) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + y0 = relax.Var("y", R.Tensor((4, 3, 1), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=3)) + y2 = relax.Var("y", R.Tensor("float32")) + y3 = relax.Var("y", R.Tensor((4, 3, 1))) + y4 = relax.Var("y", R.Tensor(ndim=3)) + y5 = relax.Var("y", R.Tensor()) + + _check_inference( + bb, relax.op.where(cond0, x0, y0), relax.TensorStructInfo((6, 5, 4, 3, 2), "float32") + ) + _check_inference( + bb, relax.op.where(cond0, x1, y0), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond0, x2, y0), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.where(cond0, x3, y0), relax.TensorStructInfo((6, 5, 4, 3, 2), dtype="") + ) + _check_inference(bb, relax.op.where(cond0, x4, y0), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x5, y0), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.where(cond0, x1, y1), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond0, x2, y1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond0, x3, y1), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x4, y1), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x5, y1), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.where(cond0, x2, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond0, x3, y2), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.where(cond0, x4, y2), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.where(cond0, x5, y2), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.where(cond0, x3, y3), relax.TensorStructInfo((6, 5, 4, 3, 2), dtype="") + ) + _check_inference(bb, relax.op.where(cond0, x4, y3), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x5, y3), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.where(cond0, x4, y4), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x5, y4), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.where(cond0, x5, y5), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.where(cond1, x0, y0), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond1, x2, y0), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond2, x0, y0), relax.TensorStructInfo(dtype="float32")) + + +def test_where_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d0 = tir.Var("d", "int64") + d1 = tir.Var("d", "int64") + e = tir.Var("e", "int64") + cond = relax.Var("cond", R.Tensor((a, b, 1, d0, 1), "bool")) + x0 = relax.Var("x", R.Tensor((b, 1, d0, e), "float32")) + x1 = relax.Var("x", R.Tensor((b, 1, d1, e), "float32")) + x2 = relax.Var("x", R.Tensor((b, 1, d0, e))) + y0 = relax.Var("y", R.Tensor((c, d0, 1), "float32")) + y1 = relax.Var("y", R.Tensor((c, d0, 1))) + + _check_inference( + bb, relax.op.where(cond, x0, y0), relax.TensorStructInfo((a, b, c, d0, e), "float32") + ) + _check_inference( + bb, relax.op.where(cond, x1, y0), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.where(cond, x2, y0), relax.TensorStructInfo((a, b, c, d0, e), dtype="") + ) + _check_inference( + bb, relax.op.where(cond, x0, y1), relax.TensorStructInfo((a, b, c, d0, e), dtype="") + ) + _check_inference(bb, relax.op.where(cond, x1, y1), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference( + bb, relax.op.where(cond, x2, y1), relax.TensorStructInfo((a, b, c, d0, e), dtype="") + ) + + +def test_where_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + scond0 = relax.Var("scond", relax.ShapeStructInfo((6, 5, 1, 3, 1))) + scond1 = relax.Var("scond", relax.ShapeStructInfo(ndim=5)) + scond2 = relax.Var("scond", relax.ShapeStructInfo()) + sx0 = relax.Var("sx", relax.ShapeStructInfo((5, 1, 3, 2))) + sx1 = relax.Var("sx", relax.ShapeStructInfo(ndim=4)) + sx2 = relax.Var("sx", relax.ShapeStructInfo()) + sy0 = relax.Var("sy", relax.ShapeStructInfo((4, 3, 1))) + sy1 = relax.Var("sy", relax.ShapeStructInfo(ndim=3)) + sy2 = relax.Var("sy", relax.ShapeStructInfo()) + s0 = relax.Var("s", relax.ShapeStructInfo((6, 5, 4, 3, 2))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + cond0 = relax.Var("cond", relax.TensorStructInfo(scond0, "bool")) + cond1 = relax.Var("cond", relax.TensorStructInfo(scond1, "bool")) + cond2 = relax.Var("cond", relax.TensorStructInfo(scond2, "bool")) + cond3 = relax.Var("cond", relax.TensorStructInfo(s0, "bool")) + cond4 = relax.Var("cond", relax.TensorStructInfo(s1, "bool")) + cond5 = relax.Var("cond", relax.TensorStructInfo(s2, "bool")) + x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(sx2, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + y0 = relax.Var("y", relax.TensorStructInfo(sy0, "float32")) + y1 = relax.Var("y", relax.TensorStructInfo(sy1, "float32")) + y2 = relax.Var("y", relax.TensorStructInfo(sy2, "float32")) + y3 = relax.Var("y", relax.TensorStructInfo(s0, "float32")) + y4 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) + y5 = relax.Var("y", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.where(cond0, x0, y0), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.where(cond0, x0, y1), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond0, x0, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.where(cond0, x1, y1), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond0, x1, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond0, x2, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.where(cond1, x1, y1), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond1, x1, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond1, x2, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond2, x2, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond3, x3, y3), relax.TensorStructInfo(s0, "float32")) + _check_inference( + bb, relax.op.where(cond3, x3, y4), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.where(cond3, x4, y3), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.where(cond4, x3, y3), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond4, x4, y4), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.where(cond4, x4, y5), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond4, x5, y4), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond5, x4, y4), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond5, x5, y5), relax.TensorStructInfo(s2, "float32")) + + +def test_where_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + cond = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool")) + x0 = relax.Var("x", R.Tensor((5, 1, 3, 2), "float16")) + y0 = relax.Var("y", R.Tensor((4, 3, 1), "float16")) + x1 = relax.Var("x", R.Tensor((5, 1, 3, 2), "int8")) + y1 = relax.Var("y", R.Tensor((4, 3, 1), "int8")) + x2 = relax.Var("x", R.Tensor((5, 1, 3, 2), "int32")) + y2 = relax.Var("y", R.Tensor((4, 3, 1), "int32")) + + _check_inference( + bb, relax.op.where(cond, x0, y0), relax.TensorStructInfo((6, 5, 4, 3, 2), "float16") + ) + _check_inference( + bb, relax.op.where(cond, x1, y1), relax.TensorStructInfo((6, 5, 4, 3, 2), "int8") + ) + _check_inference( + bb, relax.op.where(cond, x2, y2), relax.TensorStructInfo((6, 5, 4, 3, 2), "int32") + ) + + +def test_where_infer_struct_info_cond_not_boolean(): + bb = relax.BlockBuilder() + cond0 = relax.Var("cond", R.Tensor((2, 3), "float32")) + cond1 = relax.Var("cond", R.Tensor((2, 3))) + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond0, x, y)) + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond1, x, y)) + + +def test_where_infer_struct_info_shape_unequal_const_int(): + bb = relax.BlockBuilder() + cond0 = relax.Var("cond", R.Tensor((6, 5, 1, 4, 1), "bool")) + cond1 = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool")) + x0 = relax.Var("x", R.Tensor((5, 1, 4, 2), "float32")) + x1 = relax.Var("x", R.Tensor((5, 1, 3, 2), "float32")) + y0 = relax.Var("y", R.Tensor((4, 4, 1), "float32")) + y1 = relax.Var("y", R.Tensor((4, 3, 1), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond0, x1, y1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond1, x0, y1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond1, x1, y0)) + + +def test_where_infer_struct_info_dtype_mismatch(): + bb = relax.BlockBuilder() + cond = relax.Var("cond", R.Tensor((2, 3), "bool")) + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + y1 = relax.Var("y", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond, x0, y0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond, x1, y1)) + + +def test_where_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + cond0 = relax.Var("cond", relax.ShapeStructInfo((2, 3))) + cond1 = relax.Var("cond", R.Tensor((2, 3), "bool")) + x0 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + x1 = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", relax.TupleStructInfo([R.Tensor((2, 3), "float32")])) + y1 = relax.Var("y", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond0, x1, y1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond1, x0, y1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond1, x1, y0)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_search.py b/tests/python/relax/test_tvmscript_parser_op_search.py new file mode 100644 index 000000000000..a8eaa814aa2e --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_search.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_where(): + @R.function + def foo( + condition: R.Tensor((2, 1), "bool"), + x: R.Tensor((2, 3), "float32"), + y: R.Tensor((1, 3), "float32"), + ) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.where(condition, x, y) + return gv + + bb = relax.BlockBuilder() + condition = relax.Var("condition", R.Tensor((2, 1), "bool")) + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((1, 3), "float32")) + with bb.function("foo", [condition, x, y]): + gv = bb.emit(relax.op.where(condition, x, y)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() From 24704357dc90d0784b6dde5df67a2a6f935120a9 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 14 Feb 2023 15:10:58 -0500 Subject: [PATCH 23/81] [Unity] Relax op: manipulation (#13989) This PR is about the high-level tensor computation operators in Relax. This PR includes the tensor manipulation operators. Co-authored-by: Prakalp Srivastava --- include/tvm/relax/attrs/manipulate.h | 108 + python/tvm/relax/op/manipulate.py | 207 +- python/tvm/relax/op/op_attrs.py | 30 + python/tvm/script/ir_builder/relax/ir.py | 20 +- src/relax/op/tensor/manipulate.cc | 688 ++++- src/relax/op/tensor/manipulate.h | 73 + tests/python/relax/test_op_manipulate.py | 2373 +++++++++++++++++ .../test_tvmscript_parser_op_manipulate.py | 314 +++ 8 files changed, 3803 insertions(+), 10 deletions(-) create mode 100644 include/tvm/relax/attrs/manipulate.h create mode 100644 tests/python/relax/test_op_manipulate.py create mode 100644 tests/python/relax/test_tvmscript_parser_op_manipulate.py diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h new file mode 100644 index 000000000000..bd6ae17bcf1c --- /dev/null +++ b/include/tvm/relax/attrs/manipulate.h @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/manipulate.h + * \brief Attributes for tensor manipulation operators. + */ +#ifndef TVM_RELAX_ATTRS_MANIPULATE_H_ +#define TVM_RELAX_ATTRS_MANIPULATE_H_ + +#include +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in concat operators */ +struct ConcatAttrs : public tvm::AttrsNode { + Optional axis; + + TVM_DECLARE_ATTRS(ConcatAttrs, "relax.attrs.ConcatAttrs") { + TVM_ATTR_FIELD(axis).describe( + "The axis at which the input arrays are concatenated." + "Should lie in range `[-ndim, ndim)`."); + } +}; // struct ConcatAttrs + +/*! \brief Attributes used in expand_dims operators */ +struct ExpandDimsAttrs : public tvm::AttrsNode { + Array axis; + + TVM_DECLARE_ATTRS(ExpandDimsAttrs, "relax.attrs.ExpandDimsAttrs") { + TVM_ATTR_FIELD(axis).describe( + "The axes at which the input array are expanded. " + "All values are required to lie in range `[-data.ndim - 1, data.ndim]`, " + "with the convention of negative indexing."); + } +}; // struct ExpandDimsAttrs + +/*! \brief Attributes used in layout_transform operator */ +struct LayoutTransformAttrs : public tvm::AttrsNode { + tir::IndexMap index_map; + // pad_value is chosen to be of PrimValue type, as it represents constant TIR POD expression. This + // needs to be revisited in case PrimValue is evolved to represent symbolic expression in future. + Optional pad_value; + + TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") { + TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply."); + TVM_ATTR_FIELD(pad_value).describe( + "The specific value to be used to pad if the layout transform would result in implicit " + "padding. If not specified, the compiler is free to choose any value."); + } +}; // struct LayoutTransformAttrs + +/*! \brief Attributes used in permute_dims operator */ +struct PermuteDimsAttrs : public tvm::AttrsNode { + Optional> axes; + + TVM_DECLARE_ATTRS(PermuteDimsAttrs, "relax.attrs.PermuteDimsAttrs") { + TVM_ATTR_FIELD(axes).describe("The target axes order, reverse order if not specified."); + } +}; // struct PermuteDimsAttrs + +/*! \brief Attributes used in split operator */ +struct SplitAttrs : public tvm::AttrsNode { + ObjectRef indices_or_sections; + int axis; + + TVM_DECLARE_ATTRS(SplitAttrs, "relax.attrs.SplitAttrs") { + TVM_ATTR_FIELD(indices_or_sections) + .describe("The input array of indices or the number of split sections."); + TVM_ATTR_FIELD(axis).describe("The axis to be splitted"); + } +}; // struct SplitAttrs + +/*! \brief Attributes used in squeeze operators */ +struct SqueezeAttrs : public tvm::AttrsNode { + Optional> axis; + + TVM_DECLARE_ATTRS(SqueezeAttrs, "relax.attrs.SqueezeAttrs") { + TVM_ATTR_FIELD(axis).describe( + "The axis to squeeze in the input tensor." + "If `axis = None`, all axis of dimension 1 get squeezed;" + "Else, the dimension in axes get squeezed." + "It is an error if an axis does not has dimension 1."); + } +}; // struct SqueezeAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_MANIPULATE_H_ diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index fa9c81522596..a46c62e1f12b 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -15,18 +15,161 @@ # specific language governing permissions and limitations # under the License. """Manipulation operators.""" -from typing import Tuple, Union +from typing import List, Optional, Tuple, Union, Callable from tvm.ir.expr import PrimExpr - +from tvm.tir import IntImm, FloatImm, IndexMap from . import _ffi_api -from ..expr import Expr +from ..expr import Expr, PrimValue, ShapeExpr, Tuple as RxTuple PrimExprLike = Union[int, PrimExpr] +def broadcast_to(x: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr: + """Broadcasts a tensor to a specified shape. + + Parameters + ---------- + x : relax.Expr + The input data to the operator. + + shape : Union[Tuple[PrimExprLike], Expr] + The target shape. + + Returns + ------- + result : relax.Expr + The broadcasted tensor. + """ + if isinstance(shape, (tuple, list)): + shape = ShapeExpr(shape) + return _ffi_api.broadcast_to(x, shape) # type: ignore + + +def concat(tensors: Union[Expr, List[Expr]], axis: Optional[int] = 0) -> Expr: + """Concatenate the input tensors along the given axis. + + Parameters + ---------- + tensors : Union[relax.Expr, List[relax.Expr]] + An Expr in Tuple type, containing the tensors to be concatenated, + or a list of Tensors. + + axis : Optional[int] + The axis along which the tensors are concatenated. + If `axis` is `None`, the input tensor is required to be flattened before concatenation. + + Returns + ------- + result: relax.Expr + The concatenated tensor. + """ + if isinstance(tensors, (list, tuple)): + tensors = RxTuple(tensors) + return _ffi_api.concat(tensors, axis) # type: ignore + + +def expand_dims(x: Expr, axis: Union[int, List[int]]) -> Expr: + """Insert new axes at the positions given by `axis`. + + Parameters + ---------- + x : relax.Expr + The input data to the operator. + + axis : Union[int, List[int]] + The axes at which the input array are expanded. + All values are required to lie in range `[-data.ndim - 1, data.ndim]`, with the convention + of negative indexing. + + Returns + ------- + result : relax.Expr + The transformed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.expand_dims(x, axis) # type: ignore + + +def flatten(x: Expr) -> Expr: + """Flatten all the tensor dimensions into one. + + Parameters + ---------- + x : relax.Expr + The input data to the operator. + + Returns + ------- + result : relax.Expr + The flattened result. + """ + return _ffi_api.flatten(x) # type: ignore + + +def layout_transform( + x: Expr, + index_map: Union[Callable, IndexMap], + pad_value: Optional[Union[int, float, PrimValue]] = None, +): + """Modifies the layout of a tensor. + + Parameters + ---------- + x : relax.Expr + The input tensor to the operator. + + index_map : Union[Callable, IndexMap] + The transformation to apply. + + pad_value : Optional[Union[int, float, PrimValue]] + The value used for padding if the transformation results in implicit padding. + If not specified, any value can be used. + + Returns + ------- + result : relax.Expr + The transformed tensor. + """ + if callable(index_map): + index_map = IndexMap.from_func(index_map) + x_dtype = x.checked_type.dtype + + # Explicitly convert python int/float pad_value to the x's type. If the default behavior + # is applied, it would be converted to int32/float32, which may not match the x's type. + if pad_value is None: + pass + elif not isinstance(pad_value, PrimValue): + if "int" in x_dtype and isinstance(pad_value, int): + pad_value = IntImm(x_dtype, pad_value) + elif "float" in x_dtype and (isinstance(pad_value, (int, float))): + pad_value = FloatImm(x_dtype, float(pad_value)) + pad_value = PrimValue(pad_value) + return _ffi_api.layout_transform(x, index_map, pad_value) # type: ignore + + +def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr: + """Permutes the dimensions of an array. + + Parameters + ---------- + x : relax.Expr + The input data to the operator. + + axes : Optional[List[int]] + The target axes order, reverse order if not specified. + + Returns + ------- + result : relax.Expr + The transposed result. + """ + return _ffi_api.permute_dims(x, axes) # type: ignore + + def reshape(x: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr: """Reshape the input array. @@ -60,3 +203,61 @@ def reshape(x: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr: compile-time, an error will be thrown. """ return _ffi_api.reshape(x, shape) # type: ignore + + +def split( + x: Expr, + indices_or_sections: Union[int, List[PrimExprLike]], + axis: int = 0, +) -> Expr: + """Split input tensor along axis by sections or indices. + + If indices_or_sections is an integer, the input will be divided equally + along given axis (if possible). Last section will be smaller if the tensor + size along the given dimension is not divisible by the integer. + + If indices_or_sections is a tuple of mixture of int or PrimExpr, + the entries indicate the indices where along axis the array is split. + + Parameters + ---------- + x : relax.Expr + The tensor to be split. + + indices_or_sections : Union[int, List[PrimExprLike]] + Indices or sections to split into. Accepts an int or a list. + + axis : int + The axis over which to split. + + Returns + ------- + ret : relax.Expr + The computed result. + """ + if isinstance(indices_or_sections, int): + indices_or_sections = IntImm("int64", indices_or_sections) + return _ffi_api.split(x, indices_or_sections, axis) # type: ignore + + +def squeeze(x: Expr, axis: Optional[Union[int, List[int]]] = None) -> Expr: + """Squeeze axes in the array. + + Parameters + ---------- + x : relax.Expr + The input data to the operator. + + axis : Optional[Union[int, List[int]] + The set of axes to remove. + If axis = None, remove all axis of dimensions 1. + If any specified axis has dimension that does not equal 1, it is an error. + + Returns + ------- + result : relax.Expr + The squeezed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.squeeze(x, axis) # type: ignore diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 3a7ed427f9bf..efad5d98f01a 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -89,6 +89,36 @@ class StatisticalAttrs(Attrs): """Attributes used in statistical operator""" +@tvm._ffi.register_object("relax.attrs.ConcatAttrs") +class ConcatAttrs(Attrs): + """Attributes for concat operator""" + + +@tvm._ffi.register_object("relax.attrs.ExpandDimsAttrs") +class ExpandDimsAttrs(Attrs): + """Attributes for expand_dims operator""" + + +@tvm._ffi.register_object("relax.attrs.PermuteDimsAttrs") +class PermuteDimsAttrs(Attrs): + """Attributes for permute_dims operator""" + + +@tvm._ffi.register_object("relax.attrs.SplitAttrs") +class SplitAttrs(Attrs): + """Attributes used in split operator""" + + +@tvm._ffi.register_object("relax.attrs.SqueezeAttrs") +class SqueezeAttrs(Attrs): + """Attributes for squeeze operator""" + + +@tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs") +class LayoutTransformAttrs(Attrs): + """Attributes used in layout_transform operator""" + + @tvm._ffi.register_object("relax.attrs.Resize2DAttrs") class Resize2DAttrs(Attrs): """Attributes used in image resize2d operator""" diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index b779bdac9c13..7298b8c6e54f 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -39,17 +39,21 @@ add, assert_op, astype, + broadcast_to, builtin, call_builtin_with_ctx, call_tir, ceil, clip, + concat, cos, cosh, divide, equal, ewise_fma, exp, + expand_dims, + flatten, floor, floor_divide, full, @@ -61,6 +65,7 @@ isfinite, isinf, isnan, + layout_transform, less, less_equal, linear, @@ -75,6 +80,7 @@ negative, not_equal, null_value, + permute_dims, ones, ones_like, print, @@ -91,11 +97,11 @@ sign, sin, sinh, + split, square, + squeeze, sqrt, - strided_slice, subtract, - take, tan, tanh, tril, @@ -472,12 +478,14 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "arg", "assert_op", "astype", + "broadcast_to", "builtin", "call_packed", "call_tir", "call_builtin_with_ctx", "ceil", "clip", + "concat", "cos", "cosh", "const", @@ -489,6 +497,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "equal", "ewise_fma", "exp", + "expand_dims", + "flatten", "floor", "floor_divide", "full", @@ -505,6 +515,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "isfinite", "isinf", "isnan", + "layout_transform", "less", "less_equal", "linear", @@ -522,6 +533,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "ones", "ones_like", "output", + "permute_dims", "prim_value", "print", "prod", @@ -537,7 +549,9 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "sign", "sin", "sinh", + "split", "square", + "squeeze", "sqrt", "str", "strided_slice", @@ -553,5 +567,5 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "where", "zeros", "zeros_like", - "nn", + "nn", ] diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 2088a8306e7a..8ce2a541da53 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -32,6 +32,310 @@ namespace tvm { namespace relax { +/* relax.broadcast_to */ +Expr broadcast_to(Expr x, Expr shape) { + static const Op& op = Op::Get("relax.broadcast_to"); + return Call(op, {std::move(x), std::move(shape)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.broadcast_to").set_body_typed(broadcast_to); + +StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 2) { + ctx->ReportFatal(Diagnostic::Error(call) << "broadcast_to should take 2 arguments."); + } + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* tgt_shape_sinfo = GetStructInfoAs(call->args[1]); + if (data_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "broadcast_to requires the input data to be Tensor. However, the given one is " + << call->args[0]->struct_info_->GetTypeKey()); + } + if (tgt_shape_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "broadcast_to requires the input new shape to be Shape. However, the given one is " + << call->args[1]->struct_info_->GetTypeKey()); + } + + if (!data_sinfo->IsUnknownNdim() && !tgt_shape_sinfo->IsUnknownNdim() && + tgt_shape_sinfo->ndim < data_sinfo->ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "broadcast_to expects the input shape to have the number of ndim at least " + "as the input tensor's. However, the given tensor has ndim " + << data_sinfo->ndim << " while the target shape has ndim " + << tgt_shape_sinfo->ndim); + } + + // Trust the input target shape when there is no possibility to do any compile-time check. + if (!data_sinfo->shape.defined()) { + return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype); + } + ShapeStructInfo shape_sinfo = Downcast(data_sinfo->shape.value()->struct_info_); + if (!shape_sinfo->values.defined() || !tgt_shape_sinfo->values.defined()) { + return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype); + } + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + Array old_shape_value = shape_sinfo->values.value(); + Array tgt_shape_value = tgt_shape_sinfo->values.value(); + int old_ndim = old_shape_value.size(); + int tgt_ndim = tgt_shape_value.size(); + for (int i = 0; i < old_ndim; ++i) { + PrimExpr old_len = old_shape_value[old_ndim - i - 1]; + PrimExpr tgt_len = tgt_shape_value[tgt_ndim - i - 1]; + const auto* old_len_int = old_len.as(); + if (old_len_int != nullptr && old_len_int->value == 1) { + continue; + } else if (analyzer->CanProve(old_len != tgt_len)) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "broadcast_to expects the input tensor shape is broadcastable to the target shape. " + "The target shape at dim " + << tgt_ndim - i - 1 << " is " << tgt_len << " while the input tensor shape at dim " + << old_ndim - i - 1 << " is " << old_len << ", which are not equal."); + } + // Todo(relax-team): revisit here for better check on if the tensor length + // is consistent with the length in the given shape. + } + return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.broadcast_to") + .set_num_inputs(2) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("shape", "Shape", "The target shape.") + .set_attr("FInferStructInfo", InferStructInfoBroadcastTo); + +/* relax.concat */ +TVM_REGISTER_NODE_TYPE(ConcatAttrs); + +Expr concat(Expr tensors, Optional axis) { + ObjectPtr attrs = make_object(); + attrs->axis = std::move(axis); + + static const Op& op = Op::Get("relax.concat"); + return Call(op, {std::move(tensors)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.concat").set_body_typed(concat); + +Array GetTensorSInfoFromTuple(const Call& call, const BlockBuilder& ctx, + const Expr& expr) { + const auto* tuple_sinfo = GetStructInfoAs(expr); + if (tuple_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << call->op + << " expects the input to be a Tuple of Tensors. However, the given input is " + << expr->struct_info_->GetTypeKey()); + } + + Array tensor_sinfo; + tensor_sinfo.reserve(tuple_sinfo->fields.size()); + for (StructInfo field_sinfo : tuple_sinfo->fields) { + const auto* field_tensor_sinfo = field_sinfo.as(); + if (field_tensor_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << call->op << " expects the input to be a Tuple of Tensors. However, the given input is " + << expr->struct_info_); + } + tensor_sinfo.push_back(GetRef(field_tensor_sinfo)); + } + return tensor_sinfo; +} + +Optional> CheckConcatOutputShape(const Call& call, const BlockBuilder& ctx, + const std::vector>& shape_values, + int axis) { + bool shape_unknown = false; + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + PrimExpr concat_sum = IntImm(DataType::Int(64), 0); + for (int d = 0; d < static_cast(shape_values[0].size()); ++d) { + // For the specified axis, we compute the sum of shape value over each tensor. + if (d == axis) { + for (Array shape_value : shape_values) { + concat_sum += shape_value[d]; + } + continue; + } + + // For other axes, we check the equality of all tensors' shape values, to ensure safety. + for (int i = 1; i < static_cast(shape_values.size()); ++i) { + if (analyzer->CanProve(shape_values[i][d] != shape_values[0][d])) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Concat expects the input tensors to have the same shape on every " + "dimension except the one indicated by the input axis. However, the " + "input contains tensors whose shapes on dimension " + << d << " is " << shape_values[0][d] << " and " << shape_values[i][d]); + } else if (!analyzer->CanProveEqual(shape_values[i][d], shape_values[0][d])) { + shape_unknown = true; + } + } + } + + if (shape_unknown) { + return NullOpt; + } + Array output_shape = shape_values[0]; + output_shape.Set(axis, concat_sum); + return output_shape; +} + +StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 1) { + ctx->ReportFatal(Diagnostic::Error(call) << "Concat op should have 1 argument"); + } + Array tensor_sinfo = GetTensorSInfoFromTuple(call, ctx, call->args[0]); + if (tensor_sinfo.empty()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Concat op expects at least one tensor in the input Tuple. However, the " + "given input Tuple is empty."); + } + + const auto* attrs = call->attrs.as(); + int output_ndim = attrs->axis.defined() ? kUnknownNDim : 1; + DataType output_dtype = DataType::Void(); + bool shape_unknown = false; + bool is_void_dtype = false; + std::vector> shape_values; + shape_values.reserve(tensor_sinfo.size()); + + for (TensorStructInfo sinfo : tensor_sinfo) { + // Update the output dtype. + if (sinfo->dtype.is_void()) { + is_void_dtype = true; + } else if (output_dtype.is_void()) { + output_dtype = sinfo->dtype; + } else if (sinfo->dtype != output_dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Concat expects all input tensors to have the same dtype. However, the " + "input contains tensors with dtype " + << output_dtype << " and " << sinfo->dtype); + } + + // Update the output ndim. + // Todo(relax-team): revisit here for better check on if the input tensor has + // ndim 1 when the input axis is undefined. + if (output_ndim == kUnknownNDim) { + output_ndim = sinfo->ndim; + } else if (sinfo->ndim != kUnknownNDim && sinfo->ndim != output_ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Concat expects all input tensors to have same ndim. However, the " + "input contains tensors with ndim " + << output_ndim << " and " << sinfo->ndim); + } + + // Update the shape values for best effort check. + const auto* shape_expr = sinfo->shape.as(); + if (shape_expr != nullptr) { + shape_values.push_back(shape_expr->values); + continue; + } + shape_unknown = true; + + if (!sinfo->shape.defined()) { + continue; + } + // Keep the shape value for equality check. + ShapeStructInfo shape_sinfo = Downcast(sinfo->shape.value()->struct_info_); + if (shape_sinfo->values.defined()) { + shape_values.push_back(shape_sinfo->values.value()); + } + } + + if (is_void_dtype) { + output_dtype = DataType::Void(); + } + if (output_ndim == kUnknownNDim) { + return tensor_sinfo.size() == 1 ? tensor_sinfo[0] : TensorStructInfo(output_dtype, output_ndim); + } + + int axis = + attrs->axis.defined() ? NormalizeAxis(call, ctx, output_ndim, attrs->axis.value()->value) : 0; + // If there is only one input tensor, no action is needed. + if (tensor_sinfo.size() == 1) { + return tensor_sinfo[0]; + } + if (shape_values.empty()) { + return TensorStructInfo(output_dtype, output_ndim); + } + + // As long as the there is known shape value, we will do the best effort check to ensure safety. + Optional> output_shape = CheckConcatOutputShape(call, ctx, shape_values, axis); + + if (shape_unknown || !output_shape.defined()) { + return TensorStructInfo(output_dtype, output_ndim); + } else { + return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype); + } +} + +TVM_REGISTER_OP("relax.concat") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") + .set_attr("FInferStructInfo", InferStructInfoConcat); + +/* relax.expand_dims */ +TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs); + +Expr expand_dims(Expr x, Array axis) { + ObjectPtr attrs = make_object(); + attrs->axis = std::move(axis); + + static const Op& op = Op::Get("relax.expand_dims"); + return Call(op, {std::move(x)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.expand_dims").set_body_typed(expand_dims); + +StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + if (attrs->axis.empty()) { + return data_sinfo; + } + + if (data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + + int n_new_dim = attrs->axis.size(); + int output_ndim = data_sinfo->ndim + n_new_dim; + std::vector axes = NormalizeAxes(call, ctx, output_ndim, attrs->axis); + + const auto* data_shape = data_sinfo->shape.as(); + if (data_shape == nullptr) { + return TensorStructInfo(data_sinfo->dtype, output_ndim); + } + + std::vector output_shape; + output_shape.resize(output_ndim, PrimExpr()); + for (int i = 0; i < n_new_dim; ++i) { + output_shape[axes[i]] = IntImm(DataType::Int(64), 1); + } + + int i_data_shape = 0; + for (int i = 0; i < output_ndim; ++i) { + if (output_shape[i].defined()) { + continue; + } + ICHECK_LT(i_data_shape, data_sinfo->ndim); + output_shape[i] = data_shape->values[i_data_shape]; + ++i_data_shape; + } + ICHECK_EQ(i_data_shape, data_sinfo->ndim); + return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.expand_dims") + .set_num_inputs(1) + .set_attrs_type() + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoExpandDims); + // Helper function for flatten and reshape. PrimExpr ComputeShapeProduct(const Array& shape_values) { PrimExpr shape_prod = IntImm(DataType::Int(64), 1); @@ -41,6 +345,172 @@ PrimExpr ComputeShapeProduct(const Array& shape_values) { return shape_prod; } +/* relax.flatten */ +Expr flatten(Expr x) { + static const Op& op = Op::Get("relax.flatten"); + return Call(op, {std::move(x)}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.flatten").set_body_typed(flatten); + +StructInfo InferStructInfoFlatten(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + if (data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, /*ndim=*/1); + } else if (data_sinfo->ndim == 0) { + return TensorStructInfo(ShapeExpr({1}), data_sinfo->dtype); + } else if (data_sinfo->ndim == 1) { + return data_sinfo; + } + + const auto* data_shape = data_sinfo->shape.as(); + if (data_shape == nullptr) { + return TensorStructInfo(data_sinfo->dtype, /*ndim=*/1); + } + PrimExpr shape_prod = ComputeShapeProduct(data_shape->values); + return TensorStructInfo(ShapeExpr({std::move(shape_prod)}), data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.flatten") + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoFlatten); + +/* relax.layout_transform */ +TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs); + +Expr layout_transform(Expr x, tir::IndexMap index_map, Optional pad_value) { + ObjectPtr attrs = make_object(); + attrs->index_map = std::move(index_map); + attrs->pad_value = std::move(pad_value); + + static const Op& op = Op::Get("relax.layout_transform"); + return Call(op, {std::move(x)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.layout_transform").set_body_typed(layout_transform); + +StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + tir::IndexMap index_map = attrs->index_map; + Optional optional_pad_value = attrs->pad_value; + + // Check pad_value has same dtype as input. + if (optional_pad_value.defined()) { + PrimExpr padded_value = optional_pad_value.value()->value; + if (padded_value->dtype != data_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "layout_transform pad_value dtype (" << padded_value->dtype + << ") and input dtype (" << data_sinfo->dtype << ") must be the same"); + } + } + + if (data_sinfo->IsUnknownNdim()) { + // Todo(relax-team): revisit here for better check on if the input tensor has desired ndim. + return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size()); + } + + // If rank is known, check that it is compatible with the index_map, i.e., #dims match. + if (index_map->initial_indices.size() != static_cast(data_sinfo->ndim)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "number of dimensions in input must match the number of source dimensions " + "in index map, but got " + << data_sinfo->ndim << " != " << index_map->initial_indices.size()); + } + + if (!data_sinfo->shape.defined()) { + return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size()); + } + + ShapeStructInfo shape_sinfo = Downcast(data_sinfo->shape.value()->struct_info_); + if (!shape_sinfo->values.defined()) { + return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size()); + } + + Array output_shape = index_map->MapShape(shape_sinfo->values.value()); + return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.layout_transform") + .set_num_inputs(1) + .set_attrs_type() + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoLayoutTransform); + +/* relax.permute_dims */ +TVM_REGISTER_NODE_TYPE(PermuteDimsAttrs); + +Expr permute_dims(Expr x, Optional> axes) { + ObjectPtr attrs = make_object(); + attrs->axes = std::move(axes); + + static const Op& op = Op::Get("relax.permute_dims"); + return Call(op, {std::move(x)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.permute_dims").set_body_typed(permute_dims); + +bool IsIdentityPermutation(const std::vector& permutation) { + for (int i = 0; i < static_cast(permutation.size()); ++i) { + if (permutation[i] != i) { + return false; + } + } + return true; +} + +StructInfo InferStructInfoPermuteDims(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + + const auto* attrs = call->attrs.as(); + + // Todo(relax-team): revisit here for better check on if the input tensor has + // ndim same as the number of input axes. + if (!attrs->axes.defined() && data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + + if (attrs->axes.defined()) { + int n_axis = attrs->axes.value().size(); + if (!data_sinfo->IsUnknownNdim() && n_axis != data_sinfo->ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "PermuteDims expects the number of input axes to equal the ndim of the " + "input tensor. However, the tensor ndim is " + << data_sinfo->ndim << " while the given number of axes is " << n_axis); + } + } + + std::vector axes; + if (attrs->axes.defined()) { + axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes.value()); + } else { + // Construct the reverse permutation via std::iota + axes.resize(data_sinfo->ndim); + std::iota(axes.rbegin(), axes.rend(), 0); + } + if (IsIdentityPermutation(axes)) { + return data_sinfo; + } + + const auto* data_shape = data_sinfo->shape.as(); + if (data_shape == nullptr) { + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim); + } + std::vector new_shape; + new_shape.reserve(data_sinfo->ndim); + for (int i = 0; i < data_sinfo->ndim; ++i) { + new_shape.push_back(data_shape->values[axes[i]]); + } + return TensorStructInfo(ShapeExpr(new_shape), data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.permute_dims") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoPermuteDims); + /* relax.reshape */ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { if (const auto* e = shape.as()) { @@ -115,18 +585,18 @@ TVM_REGISTER_GLOBAL("relax.op.reshape").set_body_typed(reshape); StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { - ctx->ReportFatal(Diagnostic::Error(call->span) << "Reshape op should take 2 arguments"); + ctx->ReportFatal(Diagnostic::Error(call) << "Reshape op should take 2 arguments"); } const auto* data_sinfo = GetStructInfoAs(call->args[0]); const auto* new_shape_sinfo = GetStructInfoAs(call->args[1]); if (data_sinfo == nullptr) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << "Reshape requires the input data to be Tensor. However, the given one is " << call->args[0]->struct_info_->GetTypeKey()); } if (new_shape_sinfo == nullptr) { ctx->ReportFatal( - Diagnostic::Error(call->span) + Diagnostic::Error(call) << "Reshape requires the input new shape to be Shape. However, the given one is " << call->args[1]->struct_info_->GetTypeKey()); } @@ -142,7 +612,7 @@ StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { PrimExpr new_shape_prod = ComputeShapeProduct(new_shape_sinfo->values.value()); PrimExpr old_shape_prod = ComputeShapeProduct(old_shape_values.value()); if (ctx->GetAnalyzer()->CanProve(old_shape_prod != new_shape_prod)) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << "Reshape expects the new shape to be convertible from the old shape. " "However, the old shape is " << data_sinfo->shape << ", with product " << old_shape_prod @@ -159,5 +629,215 @@ TVM_REGISTER_OP("relax.reshape") .add_argument("shape", "Shape", "The input new shape.") .set_attr("FInferStructInfo", InferStructInfoReshape); +/* relax.split */ +TVM_REGISTER_NODE_TYPE(SplitAttrs); + +Expr split(Expr x, ObjectRef indices_or_sections, int axis) { + ObjectPtr attrs = make_object(); + if (const auto* indices = indices_or_sections.as()) { + for (int i = 0; i < static_cast(indices->size()); ++i) { + const auto* idx = indices->at(i).as(); + CHECK(idx != nullptr) << "Split op only accepts an array of integers as the indices. " + "However, the given indices " + << indices_or_sections << " contains some non-integer."; + } + indices_or_sections = ConvertIntImmToInt64(GetRef>(indices)); + } else if (const auto* n_section = indices_or_sections.as()) { + CHECK_GT(n_section->value, 0) << "Split op expects the input number of sections to be a " + "positive integer. However, the given number of sections is " + << n_section->value; + indices_or_sections = IntImm(DataType::Int(64), n_section->value); + } else { + LOG(FATAL) << "Split op expects the input indices_or_sections to be either an Array of " + "PrimExpr or an integer. However, the given one is " + << indices_or_sections->GetTypeKey(); + } + attrs->indices_or_sections = indices_or_sections; + attrs->axis = axis; + + static const Op& op = Op::Get("relax.split"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.split").set_body_typed(split); + +StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + const auto* data_shape = data_sinfo->shape.as(); + int axis = + data_sinfo->IsUnknownNdim() ? -1 : NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis); + + if (const auto* p_indices = attrs->indices_or_sections.as()) { + // When there is not index, return the input tensor's struct info. + if (p_indices->size() == 0) { + return TupleStructInfo({data_sinfo}); + } + // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape. + if (data_shape == nullptr) { + return TupleStructInfo(Array( + p_indices->size() + 1, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim))); + } + + ICHECK_NE(axis, -1); + const auto* axis_length = data_shape->values[axis].as(); + // Fall back to unknown shape when the input tensor shape at the given axis is symbolic. + if (axis_length == nullptr) { + return TupleStructInfo(Array( + p_indices->size() + 1, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim))); + } + + // Only do output shape inference when all the indices and the total length are integers. + Array indices = GetRef>(p_indices); + IntImm zero(DataType::Int(64), /*value=*/0); + indices.insert(indices.begin(), zero); + indices.insert(indices.end(), Downcast(data_shape->values[axis])); + + std::vector output_sinfo; + output_sinfo.reserve(indices.size() - 1); + for (int i = 0; i + 1 < static_cast(indices.size()); ++i) { + PrimExpr l = tvm::max(zero, indices[i]); + PrimExpr r = tvm::min(data_shape->values[axis], indices[i + 1]); + + Array shape = data_shape->values; + shape.Set(axis, tvm::max(zero, r - l)); + output_sinfo.push_back(TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype)); + } + return TupleStructInfo(output_sinfo); + } else if (const auto* p_n_section = attrs->indices_or_sections.as()) { + ICHECK_GT(p_n_section->value, 0); + int n_section = p_n_section->value; + // When the number of section is one, return the input tensor's struct info. + if (n_section == 1) { + return TupleStructInfo({data_sinfo}); + } + // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape. + if (data_shape == nullptr) { + return TupleStructInfo( + Array(n_section, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim))); + } + ICHECK_NE(axis, -1); + PrimExpr split_len = ceildiv(data_shape->values[axis], n_section); + + // Construct struct info for tensors except the last one. + Array shape = data_shape->values; + shape.Set(axis, split_len); + std::vector output_sinfo(n_section - 1, + TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype)); + + // Construct struct info for the last tensor. + shape.Set(axis, data_shape->values[axis] - split_len * (n_section - 1)); + output_sinfo.push_back(TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype)); + return TupleStructInfo(output_sinfo); + } + ICHECK(false) << "Cannot reach here."; + throw; +} + +TVM_REGISTER_OP("relax.split") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoSplit); + +/* relax.squeeze */ +TVM_REGISTER_NODE_TYPE(SqueezeAttrs); + +Expr squeeze(Expr x, Optional> axis) { + ObjectPtr attrs = make_object(); + attrs->axis = std::move(axis); + + static const Op& op = Op::Get("relax.squeeze"); + return Call(op, {std::move(x)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.squeeze").set_body_typed(squeeze); + +StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + if (attrs->axis.defined() && attrs->axis.value().empty()) { + return data_sinfo; + } + + if (data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + + Optional> shape_value; + if (data_sinfo->shape.defined()) { + shape_value = Downcast(data_sinfo->shape.value()->struct_info_)->values; + } + + std::vector axis_removal_mask; + axis_removal_mask.resize(data_sinfo->ndim, /*value=*/false); + + if (attrs->axis.defined()) { + std::vector axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value()); + + if (!shape_value.defined()) { + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim - axes.size()); + } + for (int i = 0; i < static_cast(axes.size()); ++i) { + // Todo(relax-team): revisit here for better check on if the axis being squeezed has length 1. + // When `axis` is given, the dim lengths at the axes must be integer 1 when it is not symbolic + const auto* int_len = shape_value.value()[axes[i]].as(); + if (int_len != nullptr && int_len->value != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Squeeze expects the input tensor shape values at the given axis " + "positions to be all 1. However, the tensor shape at axis " + << axes[i] << " is " << shape_value.value()[axes[i]] + << " which is not 1. If it is symbolic, please use MatchCast to cast it " + "to 1 before doing Squeeze."); + } + axis_removal_mask[axes[i]] = true; + } + } else { + // When `axis` is not defined, squeeze all unit-length dimensions. + // Note: This is a less well-defined path in Array API standard's squeeze + // (https://data-apis.org/array-api/latest/API_specification/generated/array_api.squeeze.html). + // Consider discourage usage later. + if (!shape_value.defined()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + for (int i = 0; i < data_sinfo->ndim; ++i) { + // Whenever a dimension length is symbolic, fall back to unknown ndim. + const auto* int_len = shape_value.value()[i].as(); + if (int_len == nullptr) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + if (int_len->value == 1) { + axis_removal_mask[i] = true; + } + } + } + + std::vector output_shape; + output_shape.reserve(data_sinfo->ndim - axis_removal_mask.size()); + for (int i = 0; i < data_sinfo->ndim; ++i) { + if (!axis_removal_mask[i]) { + output_shape.push_back(shape_value.value()[i]); + } + } + + if (data_sinfo->shape.value()->IsInstance()) { + if (static_cast(output_shape.size()) == data_sinfo->ndim) { + return data_sinfo; + } else if (attrs->axis.defined()) { + return TensorStructInfo(data_sinfo->dtype, output_shape.size()); + } else { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + } else { + return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype); + } +} + +TVM_REGISTER_OP("relax.squeeze") + .set_num_inputs(1) + .set_attrs_type() + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoSqueeze); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 1a3eb0547d7f..6a2b23ecbdbb 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -24,11 +24,59 @@ #ifndef TVM_RELAX_OP_TENSOR_MANIPULATE_H_ #define TVM_RELAX_OP_TENSOR_MANIPULATE_H_ +#include + #include "../op_common.h" namespace tvm { namespace relax { +/*! \brief Broadcasts a tensor to a specified shape. */ +Expr broadcast_to(Expr x, Expr shape); + +/*! + * \brief Concatenate the input tensors along the given axis. + * \param tensors An Expr in Tuple type, containing the tensors to be concatenated, + * or a list of tensors + * \param axis The axis along which the tensors are concatenated. + * If it is `NullOpt`, the input tensor is required to be flattened before concatenation. + * \return The concatenated tensor. + */ +Expr concat(Expr tensors, Optional axis); + +/*! + * \brief Insert new axes at the positions given by `axis`. + * \param x The input data to the operator. + * \param axis The axes at which the input array are expanded. + * \return The transformed result. + */ +Expr expand_dims(Expr x, Array axis); + +/*! + * \brief Flatten all the tensor dimensions into one. + * \param x The input data to the operator. + * \return The flattened result. + */ +Expr flatten(Expr x); + +/*! + * \brief Transform layout of a tensor. + * \param x The input data to the operator. + * \param index_map The transformation to apply. + * \param pad_value The value used for padding if the transformation results in implicit padding. If + * not specified, any value can be used. + * \return The transformed result. + */ +Expr layout_transform(Expr x, tir::IndexMap index_map, Optional pad_value); + +/*! + * \brief Permutes the dimensions of an array. + * \param x The input data to the operator. + * \param axes The target axes order, reverse order if not specified. + * \return The transposed result. + */ +Expr permute_dims(Expr x, Optional> axes); + /*! * \brief Reshape the input array, supporting `-1` inference in the new * shape when the new shape is given as an Array of PrimExpr. @@ -39,6 +87,31 @@ namespace relax { */ Expr reshape(Expr x, ObjectRef shape); +/*! + * \brief Split input tensor along axis by sections or indices. + * - If indices_or_sections is an integer, the input will be divided equally + * along given axis (if possible). Last section will be smaller if the tensor + * size along the given dimension is not divisible by the integer. + * - If indices_or_sections is a tuple of mixture of int or PrimExpr, + * the entries indicate the indices where along axis the array is split. + * \param x The tensor to be split. + * \param indices_or_sections Indices or sections to split into. + * It is required to be an Array of PrimExpr or an integer. + * \param axis The axis over which to split. + * \return The computed result. + */ +Expr split(Expr x, ObjectRef indices_or_sections, int axis); + +/*! + * \brief Squeeze axes in the array. + * \param x The input data to the operator. + * \param axis The set of axes to remove. + * If it is `NullOpt`, remove all axis of dimensions 1. + * If any specified axis has dimension that does not equal 1, it is an error. + * \return The squeezed result. + */ +Expr squeeze(Expr x, Optional> axis); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py new file mode 100644 index 000000000000..92d4bb26760a --- /dev/null +++ b/tests/python/relax/test_op_manipulate.py @@ -0,0 +1,2373 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + assert relax.op.broadcast_to(x, (3, 3, 4, 5)).op == Op.get("relax.broadcast_to") + assert relax.op.concat([x]).op == Op.get("relax.concat") + assert relax.op.expand_dims(x, axis=[]).op == Op.get("relax.expand_dims") + assert relax.op.flatten(x).op == Op.get("relax.flatten") + assert relax.op.permute_dims(x).op == Op.get("relax.permute_dims") + assert relax.op.reshape(x, (4, 5, 3)).op == Op.get("relax.reshape") + assert relax.op.split(x, indices_or_sections=1).op == Op.get("relax.split") + assert relax.op.squeeze(x).op == Op.get("relax.squeeze") + assert relax.op.layout_transform(x, index_map=lambda a, b, c: (b, c, a)).op == Op.get( + "relax.layout_transform" + ) + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_reshape_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4, 5))) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + s0 = relax.Var("s", R.Shape((3, 8, 5))) + s1 = relax.Var("s", R.Shape(ndim=3)) + s2 = relax.Var("s", R.Shape()) + + _check_inference( + bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference( + bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference(bb, relax.op.reshape(x0, (-1,)), relax.TensorStructInfo((120,), "float32")) + _check_inference( + bb, relax.op.reshape(x1, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference( + bb, relax.op.reshape(x2, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference( + bb, relax.op.reshape(x3, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") + ) + _check_inference( + bb, relax.op.reshape(x3, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") + ) + _check_inference( + bb, relax.op.reshape(x4, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") + ) + _check_inference( + bb, relax.op.reshape(x5, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") + ) + _check_inference(bb, relax.op.reshape(x0, s0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.reshape(x1, s0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.reshape(x2, s0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.reshape(x3, s0), relax.TensorStructInfo(s0, dtype="")) + _check_inference(bb, relax.op.reshape(x4, s0), relax.TensorStructInfo(s0, dtype="")) + _check_inference(bb, relax.op.reshape(x5, s0), relax.TensorStructInfo(s0, dtype="")) + _check_inference(bb, relax.op.reshape(x0, s1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.reshape(x1, s1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.reshape(x2, s1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.reshape(x3, s1), relax.TensorStructInfo(s1, dtype="")) + _check_inference(bb, relax.op.reshape(x4, s1), relax.TensorStructInfo(s1, dtype="")) + _check_inference(bb, relax.op.reshape(x5, s1), relax.TensorStructInfo(s1, dtype="")) + _check_inference(bb, relax.op.reshape(x0, s2), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.reshape(x1, s2), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.reshape(x2, s2), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.reshape(x3, s2), relax.TensorStructInfo(s2, dtype="")) + _check_inference(bb, relax.op.reshape(x4, s2), relax.TensorStructInfo(s2, dtype="")) + _check_inference(bb, relax.op.reshape(x5, s2), relax.TensorStructInfo(s2, dtype="")) + + +def test_reshape_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d = tir.Var("d", "int64") + x = relax.Var("x", R.Tensor((a, b, c, d), "float32")) + s0 = relax.Var("s", R.Shape((c, a, d, b))) + s1 = relax.Var("s", R.Shape()) + + _check_inference( + bb, relax.op.reshape(x, (c, a, d, b)), relax.TensorStructInfo((c, a, d, b), "float32") + ) + _check_inference( + bb, + relax.op.reshape(x, (d, c, b, -1)), + relax.TensorStructInfo((d, c, b, tir.floordiv(a * b * c * d, d * c * b)), "float32"), + ) + _check_inference( + bb, + relax.op.reshape(x, (1, -1, 1)), + relax.TensorStructInfo((1, a * b * c * d, 1), "float32"), + ) + _check_inference( + bb, + relax.op.reshape(x, (2, -1, a)), + relax.TensorStructInfo((2, tir.floordiv(a * b * c * d, a * 2), a), "float32"), + ) + _check_inference( + bb, + relax.op.reshape(x, (c, -1, d, b)), + relax.TensorStructInfo((c, tir.floordiv(a * b * c * d, c * d * b), d, b), "float32"), + ) + _check_inference( + bb, + relax.op.reshape(x, (c, a * d, b)), + relax.TensorStructInfo((c, a * d, b), "float32"), + ) + _check_inference( + bb, + relax.op.reshape(x, (c, a * b * d, -1)), + relax.TensorStructInfo( + (c, a * b * d, tir.floordiv(a * b * c * d, c * (a * b * d))), "float32" + ), + ) + _check_inference(bb, relax.op.reshape(x, s0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.reshape(x, s1), relax.TensorStructInfo(s1, "float32")) + + +def test_reshape_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4, 5))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + ns0 = relax.Var("ns", relax.ShapeStructInfo((3, 8, 5))) + ns1 = relax.Var("ns", relax.ShapeStructInfo()) + + _check_inference( + bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference( + bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference(bb, relax.op.reshape(x0, ns0), relax.TensorStructInfo(ns0, "float32")) + _check_inference(bb, relax.op.reshape(x0, ns1), relax.TensorStructInfo(ns1, "float32")) + _check_inference( + bb, relax.op.reshape(x1, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference(bb, relax.op.reshape(x1, ns0), relax.TensorStructInfo(ns0, "float32")) + _check_inference(bb, relax.op.reshape(x1, ns1), relax.TensorStructInfo(ns1, "float32")) + _check_inference( + bb, relax.op.reshape(x2, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference(bb, relax.op.reshape(x2, ns0), relax.TensorStructInfo(ns0, "float32")) + _check_inference(bb, relax.op.reshape(x2, ns1), relax.TensorStructInfo(ns1, "float32")) + + +def test_reshape_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) + + _check_inference(bb, relax.op.reshape(x0, (120,)), relax.TensorStructInfo((120,), "float16")) + _check_inference(bb, relax.op.reshape(x1, (120,)), relax.TensorStructInfo((120,), "int8")) + + +def test_reshape_infer_struct_info_unequal_shape_prod(): + bb = relax.BlockBuilder() + s = relax.Var("s", relax.ShapeStructInfo((2, 3, 4, 5))) + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) + ns = relax.Var("ns", relax.ShapeStructInfo((4, 4, 1, 5))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x0, (4, 4, 1, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x1, (4, 4, 1, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x0, (4, 4, -1, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x1, (4, 4, -1, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x0, ns)) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x1, ns)) + + +def test_reshape_infer_struct_info_inference_not_deducible(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor("float32", ndim=4)) + x1 = relax.Var("x", R.Tensor("float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x0, (2, 3, -1))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x1, (2, 3, -1))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x2, (2, 3, -1))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x3, (2, 3, -1))) + + +def test_reshape_new_shape_not_tuple(): + m = tir.Var("m", "int64") + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + + with pytest.raises(TVMError): + relax.op.reshape(x, 120) + with pytest.raises(TVMError): + relax.op.reshape(x, m) + + +def test_reshape_infer_struct_info_new_shape_not_integer(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (2.0, 3, 4, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (2, 3, -1.0))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (2, 3, 4.0, -1))) + + +def test_reshape_infer_struct_info_multiple_dim_inference(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (2, -1, -1, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (-1, -1, -1, -1))) + + +def test_reshape_infer_struct_info_non_positive_new_shape(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (2, 0, 4, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (-2, -3, -4, -5))) + + +def test_reshape_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) + x2 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + ns = relax.Var("ns", relax.TensorStructInfo((120,), "float32")) + pv = relax.Var("pv", relax.PrimStructInfo("int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x0, (2, 3, 4, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x1, (2, 3, 4, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x2, ns)) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x2, [pv])) + + +def test_permute_dims_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((1, 2, 3, 4))) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + x6 = relax.Var("x", R.Tensor((1,), "float32")) + x7 = relax.Var("x", R.Tensor((), "float32")) + + _check_inference( + bb, relax.op.permute_dims(x0, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "float32") + ) + _check_inference( + bb, relax.op.permute_dims(x0, axes=None), relax.TensorStructInfo((4, 3, 2, 1), "float32") + ) + _check_inference( + bb, + relax.op.permute_dims(x0, [-2, -3, 3, -4]), + relax.TensorStructInfo((3, 2, 4, 1), "float32"), + ) + _check_inference( + bb, relax.op.permute_dims(x1, [2, 3, 1, 0]), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x1, axes=None), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x2, axes=None), relax.TensorStructInfo(dtype="float32") + ) + _check_inference( + bb, relax.op.permute_dims(x3, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), dtype="") + ) + _check_inference( + bb, relax.op.permute_dims(x3, axes=None), relax.TensorStructInfo((4, 3, 2, 1), dtype="") + ) + _check_inference( + bb, + relax.op.permute_dims(x3, [-2, -3, 3, -4]), + relax.TensorStructInfo((3, 2, 4, 1), dtype=""), + ) + _check_inference( + bb, relax.op.permute_dims(x4, [2, 3, 1, 0]), relax.TensorStructInfo(dtype="", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x4, axes=None), relax.TensorStructInfo(dtype="", ndim=4) + ) + _check_inference(bb, relax.op.permute_dims(x5, axes=None), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.permute_dims(x6, axes=None), relax.TensorStructInfo((1,), "float32") + ) + _check_inference( + bb, relax.op.permute_dims(x7, axes=None), relax.TensorStructInfo((), "float32") + ) + + +def test_permute_dims_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d = tir.Var("d", "int64") + x = relax.Var("x", R.Tensor((a, b, c, d), "float32")) + + _check_inference( + bb, relax.op.permute_dims(x, [2, 3, 1, 0]), relax.TensorStructInfo((c, d, b, a), "float32") + ) + _check_inference( + bb, relax.op.permute_dims(x, axes=None), relax.TensorStructInfo((d, c, b, a), "float32") + ) + _check_inference( + bb, + relax.op.permute_dims(x, [-2, -3, 3, -4]), + relax.TensorStructInfo((c, b, d, a), "float32"), + ) + + +def test_permute_dims_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((1, 2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.permute_dims(x0, [0, 1, 2, 3]), relax.TensorStructInfo(s0, "float32") + ) + _check_inference( + bb, relax.op.permute_dims(x0, [-4, -3, -2, -1]), relax.TensorStructInfo(s0, "float32") + ) + _check_inference( + bb, relax.op.permute_dims(x0, [2, 3, 0, 1]), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x0, axes=None), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x1, [0, 1, 2, 3]), relax.TensorStructInfo(s1, "float32") + ) + _check_inference( + bb, relax.op.permute_dims(x1, [2, 3, 0, 1]), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x1, axes=None), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x2, axes=None), relax.TensorStructInfo(dtype="float32") + ) + + +def test_permute_dims_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((1, 2, 3, 4), "int8")) + x2 = relax.Var("x", R.Tensor((1, 2, 3, 4), "int32")) + + _check_inference( + bb, relax.op.permute_dims(x0, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "float16") + ) + _check_inference( + bb, relax.op.permute_dims(x1, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "int8") + ) + _check_inference( + bb, relax.op.permute_dims(x2, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "int32") + ) + + +def test_permute_dims_infer_struct_info_unknown_ndim_with_axes(): + bb = relax.BlockBuilder() + s = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor("float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [2, 3, 1, 0])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [2, 3, 1, 0])) + + +def test_permute_dims_infer_struct_info_wrong_number_axes(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((1, 2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [0, 2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [1, 2, 4, 0, 3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [0, 2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [1, 2, 4, 0, 3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x2, [0, 2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x2, [1, 2, 4, 0, 3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x3, [0, 2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x3, [1, 2, 4, 0, 3])) + + +def test_permute_dims_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [0, 3, 4, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [0, -5, 1, 3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [0, 3, 4, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [0, -5, 1, 3])) + + +def test_permute_dims_infer_struct_info_repetitive_axes(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [0, 2, 2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [0, 2, -2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [0, 2, 2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [0, 2, -2, 1])) + + +def test_permute_dims_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((1, 2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((1, 2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1)) + + +def test_expand_dims_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "float32") + ) + _check_inference( + bb, + relax.op.expand_dims(x0, [-1, 1, -6, 3, 5]), + relax.TensorStructInfo((2, 1, 1, 1, 3, 1, 4, 1), "float32"), + ) + _check_inference(bb, relax.op.expand_dims(x0, []), relax.TensorStructInfo((2, 3, 4), "float32")) + _check_inference( + bb, relax.op.expand_dims(x1, [1, 3]), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.expand_dims(x1, []), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference(bb, relax.op.expand_dims(x2, [1, 3]), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.expand_dims(x2, []), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.expand_dims(x3, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), dtype="") + ) + _check_inference( + bb, + relax.op.expand_dims(x3, [-1, 1, -6, 3, 5]), + relax.TensorStructInfo((2, 1, 1, 1, 3, 1, 4, 1), dtype=""), + ) + _check_inference(bb, relax.op.expand_dims(x3, []), relax.TensorStructInfo((2, 3, 4), dtype="")) + _check_inference(bb, relax.op.expand_dims(x4, [1, 3]), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.expand_dims(x4, []), relax.TensorStructInfo(dtype="", ndim=3)) + _check_inference(bb, relax.op.expand_dims(x5, [1, 3]), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.expand_dims(x5, []), relax.TensorStructInfo(dtype="")) + + +def test_expand_dims_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x = relax.Var("x", R.Tensor((a, 4, b), "float32")) + + _check_inference( + bb, relax.op.expand_dims(x, [1, 3]), relax.TensorStructInfo((a, 1, 4, 1, b), "float32") + ) + _check_inference( + bb, + relax.op.expand_dims(x, [-1, 1, -6, 3, 5]), + relax.TensorStructInfo((a, 1, 1, 1, 4, 1, b, 1), "float32"), + ) + _check_inference(bb, relax.op.expand_dims(x, []), relax.TensorStructInfo((a, 4, b), "float32")) + + +def test_expand_dims_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.expand_dims(x0, []), relax.TensorStructInfo(s0, "float32")) + _check_inference( + bb, relax.op.expand_dims(x1, [1, 3]), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.expand_dims(x1, []), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.expand_dims(x2, [1, 3]), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.expand_dims(x2, []), relax.TensorStructInfo(s2, "float32")) + + +def test_expand_dims_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 4), "int32")) + + _check_inference( + bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "float16") + ) + _check_inference( + bb, relax.op.expand_dims(x1, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "int8") + ) + _check_inference( + bb, relax.op.expand_dims(x2, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "int32") + ) + + +def test_expand_dims_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", relax.TensorStructInfo(s0)) + x3 = relax.Var("x", relax.TensorStructInfo(s1)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x0, [1, 5])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x0, [-6, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x1, [1, 5])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x1, [-6, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x2, [1, 5])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x2, [-6, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x3, [1, 5])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x3, [-6, 1])) + + +def test_expand_dims_infer_struct_info_repetitive_axes(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", relax.TensorStructInfo(s0)) + x3 = relax.Var("x", relax.TensorStructInfo(s1)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x0, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x0, [1, -4])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x1, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x1, [1, -4])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x2, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x2, [1, -4])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x3, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x3, [1, -4])) + + +def test_expand_dims_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x0, axis=[])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x1, axis=[])) + + +def test_layout_transform_infer_struct_info(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((10, 20, 30), "float32")) + + transpose_transform = lambda a, b, c: (a, c, b) + _check_inference( + bb, + relax.op.layout_transform(x, index_map=transpose_transform), + relax.TensorStructInfo((10, 30, 20), "float32"), + ) + + tiling_transform = lambda a, b, c: (a, b // 2, c, b % 2) + _check_inference( + bb, + relax.op.layout_transform(x, index_map=tiling_transform), + relax.TensorStructInfo((10, 10, 30, 2), "float32"), + ) + + implicit_padding_transform = lambda a, b, c: (a, c, b // 3, b % 3) + _check_inference( + bb, + relax.op.layout_transform(x, index_map=implicit_padding_transform, pad_value=2), + relax.TensorStructInfo((10, 30, 7, 3), "float32"), + ) + + flatten_transform = lambda a, b, c: (a * 600 + b * 30 + c) + _check_inference( + bb, + relax.op.layout_transform(x, index_map=flatten_transform), + relax.TensorStructInfo((6000,), "float32"), + ) + + +def test_layout_transform_infer_struct_info_mismatch_dtype(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((10, 20, 30), "int32")) + + transpose_transform = lambda a, b, c: (a, c, b) + with pytest.raises(TVMError): + bb.normalize(relax.op.layout_transform(x, index_map=transpose_transform, pad_value=2.2)) + + +def test_layout_transform_infer_struct_info_unknown_shape(): + bb = relax.BlockBuilder() + tiling_transform = lambda a, b: (a, b // 2, b % 2) + + x_unknown_shape = relax.Var("x", R.Tensor("float32", ndim=2)) + _check_inference( + bb, + relax.op.layout_transform(x_unknown_shape, index_map=tiling_transform), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + x_unknown_rank_dtype = relax.Var("x", R.Tensor()) + _check_inference( + bb, + relax.op.layout_transform(x_unknown_rank_dtype, index_map=tiling_transform), + relax.TensorStructInfo(dtype="", ndim=3), + ) + + +def test_layout_transform_infer_struct_info_symbolic_shape(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x0 = relax.Var("x", R.Tensor((a, b), "float32")) + + tiling_transform = lambda a, b: (a, b // 3, b % 3) + _check_inference( + bb, + relax.op.layout_transform(x0, index_map=tiling_transform), + relax.TensorStructInfo((a, (b - b % (-3)) // 3, 3), "float32"), + ) + + +def test_layout_transform_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + + s = relax.Var("s", relax.ShapeStructInfo((30, 20))) + x = relax.Var("x", relax.TensorStructInfo(s, "float32")) + tiling_padding_transform = lambda a, b: (a, b // 3, b % 3) + _check_inference( + bb, + relax.op.layout_transform(x, index_map=tiling_padding_transform), + relax.TensorStructInfo((30, 7, 3), "float32"), + ) + + s_unknown_shape = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + x_unknown_shape = relax.Var("x", relax.TensorStructInfo(s_unknown_shape, "float32")) + _check_inference( + bb, + relax.op.layout_transform(x_unknown_shape, index_map=tiling_padding_transform), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + s_unknown_rank = relax.Var("s", relax.ShapeStructInfo()) + x_unknown_rank = relax.Var("x", relax.TensorStructInfo(s_unknown_rank, "float32")) + _check_inference( + bb, + relax.op.layout_transform(x_unknown_rank, index_map=tiling_padding_transform), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + s_symbolic_shape = relax.Var("s", relax.ShapeStructInfo((a, b))) + x_symbolic_shape = relax.Var("x", relax.TensorStructInfo(s_symbolic_shape, "float32")) + _check_inference( + bb, + relax.op.layout_transform(x_symbolic_shape, index_map=tiling_padding_transform), + relax.TensorStructInfo((a, (b - b % (-3)) // 3, 3), "float32"), + ) + + +def test_layout_transform_infer_struct_info_invalid_index_map(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((10, 20, 30), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.layout_transform(x, index_map=lambda a, b: (b, a))) + + +def test_squeeze_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=6)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4))) + x4 = relax.Var("x", R.Tensor(ndim=6)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, relax.op.squeeze(x0, [1, 4]), relax.TensorStructInfo((2, 3, 1, 4), "float32") + ) + _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo((2, 3, 4), "float32")) + _check_inference( + bb, relax.op.squeeze(x1, [1, 4]), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x2, [1, 4]), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x2), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.squeeze(x3, [1, 4]), relax.TensorStructInfo((2, 3, 1, 4), dtype="") + ) + _check_inference(bb, relax.op.squeeze(x3), relax.TensorStructInfo((2, 3, 4), dtype="")) + _check_inference(bb, relax.op.squeeze(x4, [1, 4]), relax.TensorStructInfo(dtype="", ndim=4)) + _check_inference(bb, relax.op.squeeze(x4), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.squeeze(x5, [1, 4]), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.squeeze(x5), relax.TensorStructInfo(dtype="")) + + +def test_squeeze_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x0 = relax.Var("x", R.Tensor((a, 1, b), "float32")) + x1 = relax.Var("x", R.Tensor((a, 1, b))) + + _check_inference(bb, relax.op.squeeze(x0, [1]), relax.TensorStructInfo((a, b), "float32")) + _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x1, [1]), relax.TensorStructInfo((a, b), dtype="")) + _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo(dtype="")) + + +def test_squeeze_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3, 1, 1, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s2 = relax.Var("s", relax.ShapeStructInfo((a, 1, b))) + s3 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) + s4 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s4, "float32")) + + _check_inference( + bb, relax.op.squeeze(x0, [1, 4]), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference(bb, relax.op.squeeze(x0, []), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x1, []), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo(s1, dtype="float32")) + _check_inference(bb, relax.op.squeeze(x2, [1]), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.squeeze(x2, []), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.squeeze(x2), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.squeeze(x3, [1, 4]), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference(bb, relax.op.squeeze(x3, []), relax.TensorStructInfo(s3, "float32")) + _check_inference(bb, relax.op.squeeze(x3), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x4, [1, 4]), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x4, []), relax.TensorStructInfo(s4, "float32")) + _check_inference(bb, relax.op.squeeze(x4), relax.TensorStructInfo(dtype="float32")) + + +def test_squeeze_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "int8")) + x2 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "int32")) + + _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo((2, 3, 4), "float16")) + _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo((2, 3, 4), "int8")) + _check_inference(bb, relax.op.squeeze(x2), relax.TensorStructInfo((2, 3, 4), "int32")) + + +def test_squeeze_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3, 1, 1, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) + x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=6)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x0, [6])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x0, [-7])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x1, [6])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x1, [-7])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x2, [6])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x2, [-7])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x3, [6])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x3, [-7])) + + +def test_squeeze_infer_struct_info_repetitive_axes(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3, 1, 1, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) + x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=6)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x0, [3, -3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x0, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x1, [3, -3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x1, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x2, [3, -3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x2, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x3, [3, -3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x3, [1, 1])) + + +def test_squeeze_infer_struct_info_axis_length_not_one(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo((a, 3, 4))) + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((a, 3, 4), "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x0, [0])) + _check_inference(bb, relax.op.squeeze(x1, [0]), relax.TensorStructInfo((3, 4), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x2, [0])) + _check_inference(bb, relax.op.squeeze(x3, [0]), relax.TensorStructInfo(dtype="float32", ndim=2)) + + +def test_squeeze_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x1)) + + +def test_flatten_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor((3,), "float32")) + x2 = relax.Var("x", R.Tensor((), "float32")) + x3 = relax.Var("x", R.Tensor("float32", ndim=3)) + x4 = relax.Var("x", R.Tensor("float32", ndim=1)) + x5 = relax.Var("x", R.Tensor("float32", ndim=0)) + x6 = relax.Var("x", R.Tensor("float32")) + x7 = relax.Var("x", R.Tensor((3, 4, 5))) + x8 = relax.Var("x", R.Tensor((3,))) + x9 = relax.Var("x", R.Tensor(())) + x10 = relax.Var("x", R.Tensor(ndim=3)) + x11 = relax.Var("x", R.Tensor(ndim=1)) + x12 = relax.Var("x", R.Tensor(ndim=0)) + x13 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo((60,), "float32")) + _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo((3,), "float32")) + _check_inference(bb, relax.op.flatten(x2), relax.TensorStructInfo((1,), "float32")) + _check_inference(bb, relax.op.flatten(x3), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x4), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x5), relax.TensorStructInfo((1,), "float32")) + _check_inference(bb, relax.op.flatten(x6), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x7), relax.TensorStructInfo((60,), dtype="")) + _check_inference(bb, relax.op.flatten(x8), relax.TensorStructInfo((3,), dtype="")) + _check_inference(bb, relax.op.flatten(x9), relax.TensorStructInfo((1,), dtype="")) + _check_inference(bb, relax.op.flatten(x10), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.flatten(x11), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.flatten(x12), relax.TensorStructInfo((1,), dtype="")) + _check_inference(bb, relax.op.flatten(x13), relax.TensorStructInfo(dtype="", ndim=1)) + + +def test_flatten_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x0 = relax.Var("x", R.Tensor((a, b), "float32")) + x1 = relax.Var("x", R.Tensor((a, b))) + + _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo((a * b,), "float32")) + _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo((a * b,), dtype="")) + + +def test_flatten_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((3, 4, 5))) + s1 = relax.Var("s", relax.ShapeStructInfo((3,))) + s2 = relax.Var("s", relax.ShapeStructInfo(())) + s3 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s4 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s5 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + s6 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s4, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s5, "float32")) + x6 = relax.Var("x", relax.TensorStructInfo(s6, "float32")) + + _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.flatten(x2), relax.TensorStructInfo((1,), "float32")) + _check_inference(bb, relax.op.flatten(x3), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x4), relax.TensorStructInfo(s4, "float32")) + _check_inference(bb, relax.op.flatten(x5), relax.TensorStructInfo((1,), "float32")) + _check_inference(bb, relax.op.flatten(x6), relax.TensorStructInfo(dtype="float32", ndim=1)) + + +def test_flatten_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((3, 4, 5), "int8")) + x2 = relax.Var("x", R.Tensor((3, 4, 5), "int32")) + + _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo((60,), "float16")) + _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo((60,), "int8")) + _check_inference(bb, relax.op.flatten(x2), relax.TensorStructInfo((60,), "int32")) + + +def test_flatten_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((3, 4, 5))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((3, 4, 5), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.flatten(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.flatten(x1)) + + +def test_flatten_wrong_input_number(): + x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + y = relax.Var("y", R.Tensor((2, 3, 4), "float32")) + + with pytest.raises(TypeError): + relax.op.flatten(x, y) + + +def test_concat_infer_struct_info_with_axis(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + y0 = relax.Var("y", R.Tensor((2, 4, 4), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=3)) + y2 = relax.Var("y", R.Tensor("float32")) + y3 = relax.Var("y", R.Tensor((2, 4, 4))) + y4 = relax.Var("y", R.Tensor(ndim=3)) + y5 = relax.Var("y", R.Tensor()) + z0 = relax.Var("z", R.Tensor((2, 5, 4), "float32")) + z1 = relax.Var("z", R.Tensor("float32", ndim=3)) + z2 = relax.Var("z", R.Tensor("float32")) + z3 = relax.Var("z", R.Tensor((2, 5, 4))) + z4 = relax.Var("z", R.Tensor(ndim=3)) + z5 = relax.Var("z", R.Tensor()) + + _check_inference( + bb, relax.op.concat([x0, y0, z0], axis=1), relax.TensorStructInfo((2, 12, 4), "float32") + ) + _check_inference( + bb, relax.op.concat([x0, y0, z0], axis=-2), relax.TensorStructInfo((2, 12, 4), "float32") + ) + _check_inference( + bb, relax.op.concat([x1, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x2, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x3, y0, z0], axis=1), relax.TensorStructInfo((2, 12, 4), dtype="") + ) + _check_inference( + bb, relax.op.concat([x3, y0, z0], axis=-2), relax.TensorStructInfo((2, 12, 4), dtype="") + ) + _check_inference( + bb, relax.op.concat([x4, y0, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x5, y0, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x1, y1, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x2, y1, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x3, y1, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x5, y1, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x2, y2, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x3, y2, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x5, y5, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x1, y1, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x2, y2, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x3, y1, z1], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x2, y2, z2], axis=1), relax.TensorStructInfo(dtype="float32", ndim=-1) + ) + _check_inference( + bb, relax.op.concat([x3, y2, z2], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x4, y4, z2], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x5, y5, z2], axis=1), relax.TensorStructInfo(dtype="", ndim=-1) + ) + _check_inference( + bb, relax.op.concat([x3, y3, z3], axis=1), relax.TensorStructInfo((2, 12, 4), dtype="") + ) + _check_inference( + bb, relax.op.concat([x3, y3, z3], axis=-2), relax.TensorStructInfo((2, 12, 4), dtype="") + ) + _check_inference( + bb, relax.op.concat([x4, y3, z3], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x5, y5, z3], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x4, y4, z4], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x5, y5, z4], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference(bb, relax.op.concat([x5, y5, z5], axis=1), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, + relax.op.concat(relax.Tuple([x0, y0, z0]), axis=1), + relax.TensorStructInfo((2, 12, 4), "float32"), + ) + + +def test_concat_infer_struct_info_with_axis_shape_symbolic(): + bb = relax.BlockBuilder() + a0 = tir.Var("a0", "int64") + a1 = tir.Var("a1", "int64") + b0 = tir.Var("b0", "int64") + b1 = tir.Var("b1", "int64") + b2 = tir.Var("b2", "int64") + c = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((a0, b0, c), "float32")) + x1 = relax.Var("x", R.Tensor((a1, b0, c), "float32")) + y = relax.Var("y", R.Tensor((a0, b1, c), "float32")) + z = relax.Var("z", R.Tensor((a0, b2, c), "float32")) + + _check_inference( + bb, + relax.op.concat([x0, y, z], axis=1), + relax.TensorStructInfo((a0, b0 + b1 + b2, c), "float32"), + ) + _check_inference( + bb, + relax.op.concat([x0, y, z], axis=-2), + relax.TensorStructInfo((a0, b0 + b1 + b2, c), "float32"), + ) + _check_inference( + bb, relax.op.concat([x1, y, z], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, + relax.op.concat(relax.Tuple([x0, y, z]), axis=1), + relax.TensorStructInfo((a0, b0 + b1 + b2, c), "float32"), + ) + + +def test_concat_infer_struct_info_with_axis_shape_var(): + bb = relax.BlockBuilder() + a0 = tir.Var("a0", "int64") + a1 = tir.Var("a1", "int64") + b0 = tir.Var("b0", "int64") + b1 = tir.Var("b1", "int64") + b2 = tir.Var("b2", "int64") + c = tir.Var("c", "int64") + sx0 = relax.Var("sx", relax.ShapeStructInfo((2, 3, 4))) + sx1 = relax.Var("sx", relax.ShapeStructInfo((a0, b0, c))) + sx2 = relax.Var("sx", relax.ShapeStructInfo((a1, b0, c))) + sx3 = relax.Var("sx", relax.ShapeStructInfo(ndim=3)) + sx4 = relax.Var("sx", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(sx2, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(sx3, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(sx4, "float32")) + y0 = relax.Var("y", R.Tensor((2, 4, 4), "float32")) + y1 = relax.Var("y", R.Tensor((a0, b1, c), "float32")) + z0 = relax.Var("z", R.Tensor((2, 5, 4), "float32")) + z1 = relax.Var("z", R.Tensor((a0, b2, c), "float32")) + + _check_inference( + bb, relax.op.concat([x0, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x1, y1, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x2, y1, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x3, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x4, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, + relax.op.concat(relax.Tuple([x0, y0, z0]), axis=1), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + +def test_concat_infer_struct_info_without_axis(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3,), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=1)) + x2 = relax.Var("x", R.Tensor((3,))) + x3 = relax.Var("x", R.Tensor(ndim=1)) + y0 = relax.Var("y", R.Tensor((4,), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=1)) + z0 = relax.Var("z", R.Tensor((5,), "float32")) + z1 = relax.Var("z", R.Tensor("float32", ndim=1)) + + _check_inference( + bb, relax.op.concat([x0, y0, z0], axis=None), relax.TensorStructInfo((12,), "float32") + ) + _check_inference( + bb, + relax.op.concat([x1, y0, z0], axis=None), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, relax.op.concat([x2, y0, z0], axis=None), relax.TensorStructInfo((12,), dtype="") + ) + _check_inference( + bb, relax.op.concat([x3, y0, z0], axis=None), relax.TensorStructInfo(dtype="", ndim=1) + ) + _check_inference( + bb, + relax.op.concat([x1, y1, z0], axis=None), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, relax.op.concat([x2, y1, z0], axis=None), relax.TensorStructInfo(dtype="", ndim=1) + ) + _check_inference( + bb, + relax.op.concat([x1, y1, z1], axis=None), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.concat(relax.Tuple([x0, y0, z0]), axis=None), + relax.TensorStructInfo((12,), "float32"), + ) + + +def test_concat_infer_struct_info_without_axis_shape_symbolic(): + bb = relax.BlockBuilder() + a0 = tir.Var("a0", "int64") + a1 = tir.Var("a1", "int64") + x0 = relax.Var("x", R.Tensor((a0,), "float32")) + x1 = relax.Var("x", R.Tensor((a0,), "")) + y0 = relax.Var("y", R.Tensor((a1,), "float32")) + y1 = relax.Var("y", R.Tensor((a1,), "")) + + _check_inference( + bb, relax.op.concat([x0, y0], axis=None), relax.TensorStructInfo((a0 + a1,), "float32") + ) + _check_inference( + bb, relax.op.concat([x0, y1], axis=None), relax.TensorStructInfo((a0 + a1,), dtype="") + ) + _check_inference( + bb, relax.op.concat([x1, y0], axis=None), relax.TensorStructInfo((a0 + a1,), dtype="") + ) + _check_inference( + bb, relax.op.concat([x1, y1], axis=None), relax.TensorStructInfo((a0 + a1,), dtype="") + ) + _check_inference( + bb, + relax.op.concat(relax.Tuple([x0, y0]), axis=None), + relax.TensorStructInfo((a0 + a1,), "float32"), + ) + + +def test_concat_infer_struct_info_without_axis_shape_var(): + bb = relax.BlockBuilder() + sx0 = relax.Var("sx", relax.ShapeStructInfo((3,))) + sx1 = relax.Var("sx", relax.ShapeStructInfo(ndim=1)) + sy0 = relax.Var("sy", relax.ShapeStructInfo((4,))) + x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32")) + y0 = relax.Var("y", relax.TensorStructInfo(sy0, "float32")) + + _check_inference( + bb, relax.op.concat([x0, y0], axis=None), relax.TensorStructInfo(dtype="float32", ndim=1) + ) + _check_inference( + bb, relax.op.concat([x1, y0], axis=None), relax.TensorStructInfo(dtype="float32", ndim=1) + ) + _check_inference( + bb, + relax.op.concat(relax.Tuple([x0, y0]), axis=None), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + + +def test_concat_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3,), "float16")) + y0 = relax.Var("y", R.Tensor((4,), "float16")) + x1 = relax.Var("x", R.Tensor((3,), "int8")) + y1 = relax.Var("y", R.Tensor((4,), "int8")) + x2 = relax.Var("x", R.Tensor((3,), "int32")) + y2 = relax.Var("y", R.Tensor((4,), "int32")) + + _check_inference( + bb, relax.op.concat([x0, y0], axis=None), relax.TensorStructInfo((7,), "float16") + ) + _check_inference(bb, relax.op.concat([x1, y1], axis=None), relax.TensorStructInfo((7,), "int8")) + _check_inference( + bb, relax.op.concat([x2, y2], axis=None), relax.TensorStructInfo((7,), "int32") + ) + + +def test_concat_infer_struct_info_tuple_var(): + bb = relax.BlockBuilder() + a = tir.Var("a0", "int64") + b0 = tir.Var("b0", "int64") + b1 = tir.Var("b1", "int64") + t0 = relax.Var( + "t", + relax.TupleStructInfo( + [relax.TensorStructInfo((a, b0), "float32"), relax.TensorStructInfo((a, b1), "float32")] + ), + ) + t1 = relax.Var( + "t", + relax.TupleStructInfo( + [ + relax.TensorStructInfo((a, b0), "float32"), + relax.TensorStructInfo(dtype="float32", ndim=2), + ] + ), + ) + t2 = relax.Var( + "t", + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32", ndim=2), + ] + ), + ) + t3 = relax.Var( + "t", + relax.TupleStructInfo( + [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="float32")] + ), + ) + t4 = relax.Var( + "t", + relax.TupleStructInfo( + [relax.TensorStructInfo((a, b0), "float32"), relax.TensorStructInfo((a, b1))] + ), + ) + t5 = relax.Var( + "t", + relax.TupleStructInfo( + [relax.TensorStructInfo((a, b0), dtype=""), relax.TensorStructInfo((a, b1), dtype="")] + ), + ) + t6 = relax.Var( + "t", + relax.TupleStructInfo( + [relax.TensorStructInfo(dtype="", ndim=2), relax.TensorStructInfo(dtype="")] + ), + ) + t7 = relax.Var( + "t", + relax.TupleStructInfo([relax.TensorStructInfo(dtype=""), relax.TensorStructInfo(dtype="")]), + ) + + _check_inference( + bb, relax.op.concat(t0, axis=1), relax.TensorStructInfo((a, b0 + b1), "float32") + ) + _check_inference( + bb, relax.op.concat(t1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.concat(t2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.concat(t3, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.concat(t4, axis=1), relax.TensorStructInfo((a, b0 + b1), "float32") + ) + _check_inference( + bb, relax.op.concat(t5, axis=1), relax.TensorStructInfo((a, b0 + b1), dtype="") + ) + _check_inference(bb, relax.op.concat(t6, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.concat(t7, axis=1), relax.TensorStructInfo(dtype="")) + + +def test_concat_infer_struct_info_single_input_tensor(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((3, a))) + s1 = relax.Var("s", relax.ShapeStructInfo((a,))) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s3 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s4 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor((3, a), "float32")) + x1 = relax.Var("x", R.Tensor((a,), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=3)) + x3 = relax.Var("x", R.Tensor("float32", ndim=1)) + x4 = relax.Var("x", R.Tensor("float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x6 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x7 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x8 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + x9 = relax.Var("x", relax.TensorStructInfo(s4, "float32")) + + _check_inference(bb, relax.op.concat([x0], axis=1), relax.TensorStructInfo((3, a), "float32")) + _check_inference(bb, relax.op.concat([x1], axis=0), relax.TensorStructInfo((a,), "float32")) + _check_inference(bb, relax.op.concat([x1], axis=None), relax.TensorStructInfo((a,), "float32")) + _check_inference( + bb, relax.op.concat([x2], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x3], axis=0), relax.TensorStructInfo(dtype="float32", ndim=1) + ) + _check_inference( + bb, relax.op.concat([x3], axis=None), relax.TensorStructInfo(dtype="float32", ndim=1) + ) + _check_inference(bb, relax.op.concat([x4], axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.concat([x5], axis=1), relax.TensorStructInfo(s0, dtype="float32")) + _check_inference(bb, relax.op.concat([x6], axis=0), relax.TensorStructInfo(s1, dtype="float32")) + _check_inference( + bb, relax.op.concat([x6], axis=None), relax.TensorStructInfo(s1, dtype="float32") + ) + _check_inference(bb, relax.op.concat([x7], axis=1), relax.TensorStructInfo(s2, dtype="float32")) + _check_inference(bb, relax.op.concat([x8], axis=0), relax.TensorStructInfo(s3, dtype="float32")) + _check_inference( + bb, relax.op.concat([x8], axis=None), relax.TensorStructInfo(s3, dtype="float32") + ) + _check_inference(bb, relax.op.concat([x9], axis=1), relax.TensorStructInfo(s4, dtype="float32")) + + +def test_concat_infer_struct_info_zero_rank_input_tensor(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(())) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + x0 = relax.Var("x", R.Tensor((), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=0)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x0], axis=0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x1], axis=0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x2], axis=None)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x3], axis=None)) + + +def test_concat_infer_struct_info_no_input_tensor(): + bb = relax.BlockBuilder() + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([], axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([], axis=None)) + + +def test_concat_infer_struct_info_without_axis_but_tensor_not_one_dimensional(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor((3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x0], axis=None)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x1], axis=None)) + _check_inference(bb, relax.op.concat([x2], axis=None), relax.TensorStructInfo(dtype="float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x3], axis=None)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x4], axis=None)) + _check_inference(bb, relax.op.concat([x5], axis=None), relax.TensorStructInfo(s2, "float32")) + + +def test_concat_infer_struct_info_inconsistent_dtype(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((3,))) + y = relax.Var("y", R.Tensor((4,), "float32")) + z = relax.Var("z", R.Tensor((5,), "int8")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x, y, z], axis=0)) + + +def test_concat_infer_struct_info_inconsistent_ndim(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((4, 5))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + x = relax.Var("x", R.Tensor((3,), "float32")) + y0 = relax.Var("y", R.Tensor((4, 5), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=2)) + y2 = relax.Var("y", relax.TensorStructInfo(s0, "float32")) + y3 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) + z = relax.Var("z", R.Tensor((5,), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x, y0, z], axis=0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x, y1, z], axis=0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x, y2, z], axis=0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x, y3, z], axis=0)) + + +def test_concat_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((3,))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + x0 = relax.Var("x", R.Tensor((3,), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=1)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x0], axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x1], axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x2], axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x3], axis=1)) + + +def test_concat_infer_struct_info_unequal_shape(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo((3, a + 2))) + x0 = relax.Var("x", R.Tensor((3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((3, a + 2), "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + y0 = relax.Var("y", R.Tensor((3, 3), "float32")) + y1 = relax.Var("y", R.Tensor((3, a), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x0, y0])) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x2, y0])) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x1, y1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x3, y1])) + + +def test_concat_infer_struct_info_input_not_tuple(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((3,), "float32")) + s = relax.Var("s", relax.ShapeStructInfo((3,))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat(x)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat(s)) + + +def test_concat_infer_struct_info_input_tuple_field_not_tensor(): + bb = relax.BlockBuilder() + s = relax.Var("s", relax.ShapeStructInfo((3,))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([s])) + + +def test_split_infer_struct_info_by_indices(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 10, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, + relax.op.split(x0, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 4), "float32"), + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 3, 4), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x0, [3, 7], axis=-2), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 4), "float32"), + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 3, 4), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x1, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x2, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x3, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 4), dtype=""), + relax.TensorStructInfo((2, 4, 4), dtype=""), + relax.TensorStructInfo((2, 3, 4), dtype=""), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x4, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x5, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype=""), + relax.TensorStructInfo(dtype=""), + relax.TensorStructInfo(dtype=""), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x0, [-2, 2, 6, 4, 8, 12, 9], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 0, 4), "float32"), + relax.TensorStructInfo((2, 2, 4), "float32"), + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 0, 4), "float32"), + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 2, 4), "float32"), + relax.TensorStructInfo((2, 0, 4), "float32"), + relax.TensorStructInfo((2, 1, 4), "float32"), + ] + ), + ) + + +def test_split_infer_struct_info_by_indices_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x = relax.Var("x", R.Tensor((a, b), "float32")) + + _check_inference( + bb, + relax.op.split(x, [10, 20], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=2), + relax.TensorStructInfo(dtype="float32", ndim=2), + relax.TensorStructInfo(dtype="float32", ndim=2), + ] + ), + ) + + +def test_split_infer_struct_info_by_indices_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 10, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, + relax.op.split(x0, [3], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x1, [3], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x2, [3], axis=1), + relax.TupleStructInfo( + [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="float32")] + ), + ) + + +def test_split_infer_struct_info_by_n_section(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 10, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, + relax.op.split(x0, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 2, 4), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x0, 2, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 5, 4), "float32"), + relax.TensorStructInfo((2, 5, 4), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x0, 3, axis=-2), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 2, 4), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x1, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x2, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x3, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 4, 4), dtype=""), + relax.TensorStructInfo((2, 4, 4), dtype=""), + relax.TensorStructInfo((2, 2, 4), dtype=""), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x4, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x5, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype=""), + relax.TensorStructInfo(dtype=""), + relax.TensorStructInfo(dtype=""), + ] + ), + ) + + +def test_split_infer_struct_info_by_n_section_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x = relax.Var("x", R.Tensor((a, b), "float32")) + + _check_inference( + bb, + relax.op.split(x, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((a, tir.ceildiv(b, 3)), "float32"), + relax.TensorStructInfo((a, tir.ceildiv(b, 3)), "float32"), + relax.TensorStructInfo((a, b - tir.ceildiv(b, 3) * 2), "float32"), + ] + ), + ) + + +def test_split_infer_struct_info_by_n_section_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 10, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, + relax.op.split(x0, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x1, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x2, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32"), + ] + ), + ) + + +def test_split_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 10, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 10, 4), "int8")) + + _check_inference( + bb, + relax.op.split(x0, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 4), "float16"), + relax.TensorStructInfo((2, 4, 4), "float16"), + relax.TensorStructInfo((2, 3, 4), "float16"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x1, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 4), "int8"), + relax.TensorStructInfo((2, 4, 4), "int8"), + relax.TensorStructInfo((2, 3, 4), "int8"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x0, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 4, 4), "float16"), + relax.TensorStructInfo((2, 4, 4), "float16"), + relax.TensorStructInfo((2, 2, 4), "float16"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x1, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 4, 4), "int8"), + relax.TensorStructInfo((2, 4, 4), "int8"), + relax.TensorStructInfo((2, 2, 4), "int8"), + ] + ), + ) + + +def test_split_infer_struct_info_single_output(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((a, b))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor((a, b), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, + relax.op.split(x0, [], axis=1), + relax.TupleStructInfo([relax.TensorStructInfo((a, b), "float32")]), + ) + _check_inference( + bb, + relax.op.split(x1, [], axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32", ndim=2)]), + ) + _check_inference( + bb, + relax.op.split(x2, [], axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32")]), + ) + _check_inference( + bb, + relax.op.split(x3, [], axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(s0, "float32")]), + ) + _check_inference( + bb, + relax.op.split(x4, [], axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(s1, "float32")]), + ) + _check_inference( + bb, + relax.op.split(x5, [], axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(s2, "float32")]), + ) + _check_inference( + bb, + relax.op.split(x0, 1, axis=1), + relax.TupleStructInfo([relax.TensorStructInfo((a, b), "float32")]), + ) + _check_inference( + bb, + relax.op.split(x1, 1, axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32", ndim=2)]), + ) + _check_inference( + bb, + relax.op.split(x2, 1, axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32")]), + ) + _check_inference( + bb, + relax.op.split(x3, 1, axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(s0, "float32")]), + ) + _check_inference( + bb, + relax.op.split(x4, 1, axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(s1, "float32")]), + ) + _check_inference( + bb, + relax.op.split(x5, 1, axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(s2, "float32")]), + ) + + +def test_split_indices_or_sections_int64(): + x = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + split0 = relax.op.split(x, [3, 6], axis=1) + split1 = relax.op.split(x, 4, axis=1) + + assert split0.attrs.indices_or_sections[0].dtype == "int64" + assert split0.attrs.indices_or_sections[1].dtype == "int64" + assert split1.attrs.indices_or_sections.dtype == "int64" + + +def test_split_infer_struct_info_non_integer_indices(): + bb = relax.BlockBuilder() + a = tir.Var("c", "int64") + b = tir.Var("d", "int64") + x = relax.Var("x", R.Tensor((3, 4), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.split(x, [a, b], axis=1)) + + +def test_split_invalid_n_section(): + n = tir.Var("n", "int64") + x = relax.Var("x", R.Tensor((3, 4), "float32")) + + with pytest.raises(TVMError): + relax.op.split(x, 0, axis=1) + with pytest.raises(TVMError): + relax.op.split(x, -1, axis=1) + with pytest.raises(TVMError): + relax.op.split(x, n, axis=1) + + +def test_split_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.split(x0, [], axis=2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.split(x0, [], axis=-3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.split(x1, 1, axis=2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.split(x1, 1, axis=-3)) + + +def test_split_infer_invalid_struct_info_indices(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + v = relax.Var("v", relax.PrimStructInfo("int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.split(x0, [v], axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.split(x0, v, axis=1)) + + +def test_split_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.split(x0, 1, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.split(x1, 1, axis=1)) + + +def test_broadcast_to_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 1, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 1, 3))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") + ) + _check_inference( + bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") + ) + _check_inference( + bb, relax.op.broadcast_to(x2, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") + ) + _check_inference( + bb, relax.op.broadcast_to(x3, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), dtype="") + ) + _check_inference( + bb, relax.op.broadcast_to(x4, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), dtype="") + ) + _check_inference( + bb, relax.op.broadcast_to(x5, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), dtype="") + ) + + +def test_broadcast_to_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d = tir.Var("d", "int64") + x0 = relax.Var("x", R.Tensor((b, 1, 1, d), "float32")) + x1 = relax.Var("x", R.Tensor((b, 1, 1, d))) + + _check_inference( + bb, + relax.op.broadcast_to(x0, (a, b, 1, c, d)), + relax.TensorStructInfo((a, b, 1, c, d), "float32"), + ) + _check_inference( + bb, + relax.op.broadcast_to(x1, (a, b, 1, c, d)), + relax.TensorStructInfo((a, b, 1, c, d), dtype=""), + ) + + +def test_broadcast_to_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") + ) + _check_inference( + bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") + ) + _check_inference( + bb, relax.op.broadcast_to(x2, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32") + ) + + +def test_broadcast_to_infer_struct_info_tgt_shape_var(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d = tir.Var("d", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((b, 1, 1, d))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor((b, 1, 1, d), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + stgt0 = relax.Var("stgt", relax.ShapeStructInfo((a, b, 1, c, d))) + stgt1 = relax.Var("stgt", relax.ShapeStructInfo(ndim=5)) + stgt2 = relax.Var("stgt", relax.ShapeStructInfo()) + + _check_inference(bb, relax.op.broadcast_to(x0, stgt0), relax.TensorStructInfo(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, stgt0), relax.TensorStructInfo(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x2, stgt0), relax.TensorStructInfo(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x3, stgt0), relax.TensorStructInfo(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x4, stgt0), relax.TensorStructInfo(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x5, stgt0), relax.TensorStructInfo(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x0, stgt1), relax.TensorStructInfo(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, stgt1), relax.TensorStructInfo(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x2, stgt1), relax.TensorStructInfo(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x3, stgt1), relax.TensorStructInfo(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x4, stgt1), relax.TensorStructInfo(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x5, stgt1), relax.TensorStructInfo(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x0, stgt2), relax.TensorStructInfo(stgt2, "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, stgt2), relax.TensorStructInfo(stgt2, "float32")) + _check_inference(bb, relax.op.broadcast_to(x2, stgt2), relax.TensorStructInfo(stgt2, "float32")) + _check_inference(bb, relax.op.broadcast_to(x3, stgt2), relax.TensorStructInfo(stgt2, "float32")) + _check_inference(bb, relax.op.broadcast_to(x4, stgt2), relax.TensorStructInfo(stgt2, "float32")) + _check_inference(bb, relax.op.broadcast_to(x5, stgt2), relax.TensorStructInfo(stgt2, "float32")) + + +def test_broadcast_to_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 1, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 1, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 1, 3), "int32")) + + _check_inference( + bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float16") + ) + _check_inference( + bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "int8") + ) + _check_inference( + bb, relax.op.broadcast_to(x2, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "int32") + ) + + +def test_broadcast_to_infer_struct_info_tgt_ndim_less_than_old_ndim(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 1))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + x0 = relax.Var("x", R.Tensor((2, 1), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + stgt0 = relax.Var("stgt", relax.ShapeStructInfo((2,))) + stgt1 = relax.Var("stgt", relax.ShapeStructInfo(ndim=1)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x0, (2,))) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x0, stgt0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x0, stgt1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x1, (2,))) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x1, stgt0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x1, stgt1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x2, (2,))) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x2, stgt0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x2, stgt1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x3, (2,))) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x3, stgt0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x3, stgt1)) + + +def test_broadcast_to_infer_struct_info_not_broadcastable_static(): + bb = relax.BlockBuilder() + s = relax.Var("s", relax.ShapeStructInfo((2, 1, 3))) + x0 = relax.Var("x", R.Tensor((2, 1, 3), "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) + stgt = relax.Var("stgt", relax.ShapeStructInfo((2, 1, 6))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x0, (2, 1, 6))) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x0, stgt)) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x1, (2, 1, 6))) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x1, stgt)) + + +def test_broadcast_to_infer_struct_info_not_broadcastable_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + s = relax.Var("s", relax.ShapeStructInfo((2, a))) + x0 = relax.Var("x", R.Tensor((2, a), "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) + stgt0 = relax.Var("stgt", relax.ShapeStructInfo((2, b))) + stgt1 = relax.Var("stgt", relax.ShapeStructInfo((2, 1))) + stgt2 = relax.Var("stgt", relax.ShapeStructInfo((b, a))) + + _check_inference( + bb, relax.op.broadcast_to(x0, (2, b)), relax.TensorStructInfo((2, b), "float32") + ) + _check_inference( + bb, relax.op.broadcast_to(x0, (2, 1)), relax.TensorStructInfo((2, 1), "float32") + ) + _check_inference( + bb, relax.op.broadcast_to(x0, (b, a)), relax.TensorStructInfo((b, a), "float32") + ) + _check_inference(bb, relax.op.broadcast_to(x0, stgt0), relax.TensorStructInfo(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x0, stgt1), relax.TensorStructInfo(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x0, stgt2), relax.TensorStructInfo(stgt2, "float32")) + _check_inference( + bb, relax.op.broadcast_to(x1, (2, b)), relax.TensorStructInfo((2, b), "float32") + ) + _check_inference( + bb, relax.op.broadcast_to(x1, (2, 1)), relax.TensorStructInfo((2, 1), "float32") + ) + _check_inference( + bb, relax.op.broadcast_to(x1, (b, a)), relax.TensorStructInfo((b, a), "float32") + ) + _check_inference(bb, relax.op.broadcast_to(x1, stgt0), relax.TensorStructInfo(stgt0, "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, stgt1), relax.TensorStructInfo(stgt1, "float32")) + _check_inference(bb, relax.op.broadcast_to(x1, stgt2), relax.TensorStructInfo(stgt2, "float32")) + + +def test_broadcast_to_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 1, 3))) + x1 = relax.Var("x", R.Tensor((2, 1, 3), "float32")) + stgt = relax.Var("stgt", relax.TensorStructInfo((4, 2, 5, 3), dtype="")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x0, (4, 2, 5, 3))) + with pytest.raises(TVMError): + bb.normalize(relax.op.broadcast_to(x1, stgt)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py b/tests/python/relax/test_tvmscript_parser_op_manipulate.py new file mode 100644 index 000000000000..27f089ee67c1 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py @@ -0,0 +1,314 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_broadcast_to(): + @R.function + def foo(x: R.Tensor((2, 1, 3), "float32")) -> R.Tensor((4, 2, 5, 3), "float32"): + gv: R.Tensor((4, 2, 5, 3), "float32") = R.broadcast_to(x, (4, 2, 5, 3)) + return gv + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 1, 3), "float32")) + with bb.function("foo", [x]): + gv = bb.emit(relax.op.broadcast_to(x, (4, 2, 5, 3))) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_concat(): + @R.function + def foo( + x1: R.Tensor((1, 2, 3), "float32"), + x2: R.Tensor((1, 3, 3), "float32"), + x3: R.Tensor((1, 4, 3), "float32"), + ) -> R.Tensor((1, 9, 3), "float32"): + gv: R.Tensor((1, 9, 3), "float32") = R.concat((x1, x2, x3), axis=1) + return gv + + x1 = relax.Var("x1", R.Tensor((1, 2, 3), "float32")) + x2 = relax.Var("x2", R.Tensor((1, 3, 3), "float32")) + x3 = relax.Var("x3", R.Tensor((1, 4, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x1, x2, x3]): + gv = bb.emit(relax.op.concat((x1, x2, x3), axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_concat_without_specified_axis(): + @R.function + def foo( + x1: R.Tensor((2,), "float32"), x2: R.Tensor((3,), "float32"), x3: R.Tensor((4,), "float32") + ) -> R.Tensor((9,), "float32"): + gv: R.Tensor((9,), "float32") = R.concat((x1, x2, x3), axis=None) + return gv + + x1 = relax.Var("x1", R.Tensor((2,), "float32")) + x2 = relax.Var("x2", R.Tensor((3,), "float32")) + x3 = relax.Var("x3", R.Tensor((4,), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x1, x2, x3]): + gv = bb.emit(relax.op.concat((x1, x2, x3), axis=None)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_expand_dims(): + @R.function + def foo(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 1, 1, 1, 3, 1, 4, 1), "float32"): + gv: R.Tensor((2, 1, 1, 1, 3, 1, 4, 1), "float32") = R.expand_dims(x, axis=[-1, 1, -6, 3, 5]) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.expand_dims(x, axis=[-1, 1, -6, 3, 5])) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_flatten(): + @R.function + def foo(x: R.Tensor((3, 4, 5), "float32")) -> R.Tensor((60,), "float32"): + gv: R.Tensor((60,), "float32") = R.flatten(x) + return gv + + x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.flatten(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_layout_transform(): + transformation = lambda n, c, h, w: (n, h, w, c) + + @R.function + def foo(x: R.Tensor((2, 3, 4, 5), "float32")): + gv: R.Tensor((2, 4, 5, 3), "float32") = R.layout_transform(x, index_map=transformation) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.layout_transform(x, index_map=transformation)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_layout_transform_with_padding(): + transformation = lambda n, c, h, w: (n, c // 3, h, w, c % 3) + + @R.function + def foo(x: R.Tensor((10, 20, 2, 2), "float32")): + gv: R.Tensor((10, 7, 2, 2, 3), "float32") = R.layout_transform( + x, index_map=transformation, pad_value=2 + ) + return gv + + x = relax.Var("x", R.Tensor((10, 20, 2, 2), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.layout_transform(x, index_map=transformation, pad_value=2)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_permute_dims(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((2, 4, 3, 1), "float32"): + gv: R.Tensor((2, 4, 3, 1), "float32") = R.permute_dims(x, axes=[1, -1, 2, -4]) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.permute_dims(x, axes=[1, -1, 2, -4])) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_permute_dims_none_arg(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((4, 3, 2, 1), "float32"): + gv: R.Tensor((4, 3, 2, 1), "float32") = R.permute_dims(x) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.permute_dims(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_reshape(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((8, 3), "float32"): + gv: R.Tensor((8, 3), "float32") = R.reshape(x, (8, 3)) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.reshape(x, shape=(8, 3))) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_reshape_infer_dim(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((8, 1, 3), "float32"): + gv: R.Tensor((8, 1, 3), "float32") = R.reshape(x, (8, -1, 3)) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.reshape(x, shape=(8, -1, 3))) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_split_by_indices(): + @R.function + def foo( + x: R.Tensor((2, 10, 4), dtype="float32") + ) -> R.Tuple( + R.Tensor((2, 0, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 4, 4), dtype="float32"), + R.Tensor((2, 0, 4), dtype="float32"), + R.Tensor((2, 4, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 0, 4), dtype="float32"), + R.Tensor((2, 1, 4), dtype="float32"), + ): + gv: R.Tuple( + R.Tensor((2, 0, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 4, 4), dtype="float32"), + R.Tensor((2, 0, 4), dtype="float32"), + R.Tensor((2, 4, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 0, 4), dtype="float32"), + R.Tensor((2, 1, 4), dtype="float32"), + ) = R.split(x, indices_or_sections=[-2, 2, 6, 4, 8, 12, 9], axis=1) + return gv + + x = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.split(x, indices_or_sections=[-2, 2, 6, 4, 8, 12, 9], axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_split_by_n_section(): + @R.function + def foo( + x: R.Tensor((2, 10, 4), dtype="float32") + ) -> R.Tuple( + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + ): + gv: R.Tuple( + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + R.Tensor((2, 2, 4), dtype="float32"), + ) = R.split(x, indices_or_sections=5, axis=1) + return gv + + x = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.split(x, indices_or_sections=5, axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_squeeze(): + @R.function + def foo(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv: R.Tensor((2, 3, 4), "float32") = R.squeeze(x) + return gv + + x = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.squeeze(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_squeeze_with_indices(): + @R.function + def foo(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2, 3, 1, 4), "float32"): + gv: R.Tensor((2, 3, 1, 4), "float32") = R.squeeze(x, axis=[3, -5]) + return gv + + x = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.squeeze(x, axis=[3, -5])) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() From fe81ddaf60b5027e59f012bb869821016955f151 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 14 Feb 2023 16:58:14 -0500 Subject: [PATCH 24/81] [Unity] NestedMsg Support utility (#13995) This PR introduce NestedMsg to robustly handle nested-tuple analysis. Relax support nested tuple structures in the IR. Nested tuple structure is important to support advanced groupings in cases such as gradient calculation and other scenarios. The possible presence of nested tuple does mean that we need to to robustly handle analysis that contains nested tuple structures in a dataflow graph. This PR introduces a NestedMsg class that corresponds to a possibly nested message tuple for a given leaf message class T. We also introduces various helper functions to compose and decompose messages. Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Yixin Dong Co-authored-by: Ruihang Lai --- include/tvm/relax/nested_msg.h | 536 +++++++++++++++++++++++++++++++++ tests/cpp/nested_msg_test.cc | 318 +++++++++++++++++++ 2 files changed, 854 insertions(+) create mode 100644 include/tvm/relax/nested_msg.h create mode 100644 tests/cpp/nested_msg_test.cc diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h new file mode 100644 index 000000000000..93fc9a36c5dc --- /dev/null +++ b/include/tvm/relax/nested_msg.h @@ -0,0 +1,536 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/nested_msg.h + * \brief Helper container to store nested message for robust tuple-aware analysis. + * + * Please see NestedMsg for description of usage. + * + * \sa NestedMsg + */ +#ifndef TVM_RELAX_NESTED_MSG_H_ +#define TVM_RELAX_NESTED_MSG_H_ + +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Container that stores possibly nested message with leaf message type T. + * + * NestedMsg is a helper structure to store intermediate + * message state in pass analysis so we can robustly handle message + * passing with the presence of nested tuple types. + * + * Under the hood, NestedMsg[T] = Union[T, NullOpt, Array[NestedMsg[T]]]. + * Each nested message corresponds to the same nesting structure as + * the nested tuple types when we encounter them in analysis. + * + * Relax support nested tuple structures in the IR. Nested tuple structure + * is important to support advanced groupings in cases such as gradient calculation + * and other scenarios. + * + * The possible presence of nested tuple does mean that we need to + * to robustly handle analysis that contains nested tuple structures + * in a dataflow graph. + * + * \code + * + * v1 = relu(v0) + * v2 = exp(v0) + * t = ((v0, v1), (v2,), v0) + * t1 = t[0] + * v3 = concat(t1) + * v4 = t[2] + * v5 = add(v4, v3) + * + * \endcode + * + * Consider the above code sequence that contains a mixture of tuple + * nesting and normal operations. A common message-passing-based analysis + * will track messages attached to each intermediate variable. + * + * Because the intermediate value can contain nested-tuples, we need to have + * abilities to nest messages according to tuple structure and propagate them + * along the way. In python, this simply corresponds to using a tuple to hold + * nested messages. This class provides a helper wrapper in C++ to present such + * possibly nested message for a given leaf message. + * + * This design pattern is necessary to handle tuple values regardless of + * the normal form design of the IR to enable different messages for each + * tuple component without enforcing all tuple elements to have the same message. + * + * Please consider the following patterns in our pass: + * + * On a forward propagation message passing analysis: + * - Create map [leafnode=>NestedMsg], scan forward + * - input_msg = [MapToNestedMsg(x, lookup_map) for x in call->args] + * - output_msg = ForwardProp[call->op](input_msg, call) + * - map[binding->var] = output_msg + * - Use MapToNestedMsg to remap the remaining body. + * + * On a backward propagation message passing analysis: + * - Create map [leafnode=>NestedMsg], scan backward + * - output_msg = lookup map(binding->var) + * - handle case when output_msg is null + * - input_msg = BackProp[call->op](out_msg, call) + * - for arg, msg in zip(call->args, input_msg), + * DecomposeNestedMessage(arg, msg, lambda node, m: update_map(node, m)) + * - update_map(node, m) => CombineNestedMessage(map[node], m) + * + * Here leafnode is a node that you would like to propagate messages to + * such as constant, var and should not include tuple. + * + * We also recommend writing unit-test cases that involve nested tuple composition + * and decomposition. + * + * \sa MapToNestedMsg, DecomposeNestedMsg, CombineNestedMsg, ForEachLeaf, Equal + * + * \note If you want to write robust message passing-based analysis for + * programs that can contain nested tuples, you likely need to + * use this class or logic of a similar kind. + */ +template +class NestedMsg : public ObjectRef { + public: + // default constructors. + NestedMsg() = default; + NestedMsg(const NestedMsg&) = default; + NestedMsg(NestedMsg&&) = default; + NestedMsg& operator=(const NestedMsg&) = default; + NestedMsg& operator=(NestedMsg&&) = default; + /*! + * \brief Construct from an ObjectPtr + * whose type already satisfies the constraint + * \param ptr + */ + explicit NestedMsg(ObjectPtr ptr) : ObjectRef(ptr) {} + /*! \brief Nullopt handling */ + NestedMsg(runtime::NullOptType) {} // NOLINT(*) + // nullptr handling. + // disallow implicit conversion as 0 can be implicitly converted to nullptr_t + explicit NestedMsg(std::nullptr_t) {} + NestedMsg& operator=(std::nullptr_t) { + data_ = nullptr; + return *this; + } + // normal value handling. + NestedMsg(T other) // NOLINT(*) + : ObjectRef(std::move(other)) {} + NestedMsg& operator=(T other) { + ObjectRef::operator=(std::move(other)); + return *this; + } + // Array> handling + NestedMsg(Array, void> other) // NOLINT(*) + : ObjectRef(std::move(other)) {} + NestedMsg& operator=(Array, void> other) { + ObjectRef::operator=(std::move(other)); + return *this; + } + + // initializer list handling + NestedMsg(std::initializer_list> other) // NOLINT(*) + : NestedMsg(Array, void>(other)) {} + NestedMsg& operator=(std::initializer_list> other) { + return operator=(Array, void>(other)); + } + + // delete the int constructor + // since NestedMsg(0) is ambiguous + // 0 can be implicitly casted to nullptr_t + explicit NestedMsg(int val) = delete; + NestedMsg& operator=(int val) = delete; + // operator overloadings + bool operator==(std::nullptr_t) const { return data_ == nullptr; } + bool operator!=(std::nullptr_t) const { return data_ != nullptr; } + + /*! \return Whether the nested message is not-null leaf value */ + bool IsLeaf() const { return data_ != nullptr && data_->IsInstance(); } + + /*! \return Whether the nested message is null */ + bool IsNull() const { return data_ == nullptr; } + + /*! \return Whether the nested message is nested */ + bool IsNested() const { return data_ != nullptr && data_->IsInstance(); } + + /*! + * \return The underlying leaf value. + * \note This function checks if the msg is leaf. + */ + T LeafValue() const { + ICHECK(IsLeaf()); + return T(data_); + } + + /*! + * \return a corresponding nested array. + * \note This checks if the underlying data type is array. + */ + Array, void> NestedArray() const { + ICHECK(IsNested()); + return Array, void>(data_); + } + + using ContainerType = Object; + using LeafContainerType = typename T::ContainerType; + + static_assert(std::is_base_of::value, "NestedMsg is only defined for ObjectRef."); + + static constexpr bool _type_is_nullable = true; +}; + +/*! + * \brief Apply fvisit for each leaf elements in the nested message. + * \param fvisit The visit callback. + * \param msg The input nested message. + * \tparam T the content type of nested msg + * \tparam FType the visitor type with signature void fvisit(T) + */ +template +void ForEachLeaf(const NestedMsg& msg, FType fvisit) { + if (msg == nullptr) return; + if (msg.IsLeaf()) { + fvisit(msg.LeafValue()); + } else { + for (NestedMsg x : msg.NestedArray()) { + ForEachLeaf(x, fvisit); + } + } +} + +/*! + * \brief Recursively compare two nested messages. + * + * \param lhs The left operand. + * \param rhs The right operand. + * \param fequal The equal functor with signature bool fequal(T, T) + * \tparam T the content type of nested msg + * \tparam FType the equal comparator type + */ +template +bool Equal(const NestedMsg& lhs, const NestedMsg& rhs, FType fequal) { + if (lhs.IsNull()) return rhs.IsNull(); + if (rhs.IsNull()) return lhs.IsNull(); + if (lhs.IsLeaf()) { + return rhs.IsLeaf() && fequal(lhs.LeafValue(), rhs.LeafValue()); + } else { + if (!rhs.IsNested()) return false; + Array> arr_lhs = lhs.NestedArray(); + Array> arr_rhs = rhs.NestedArray(); + if (arr_lhs.size() != arr_rhs.size()) return false; + for (size_t i = 0; i < arr_lhs.size(); ++i) { + if (!Equal(arr_lhs[i], arr_rhs[i], fequal)) return false; + } + return true; + } +} + +/*! + * \brief Map expr with possible nested-tuple to nested message. + * + * This function will unpack recursive tuples and run fmapleaf for each leaf, + * then recursively combines the results together into a NestedMsg. + * + * The nesting structure will corresponds to the tuple structure. + * + * \param expr The input expression. + * \param fmapleaf The mapping function for each leaf with signature `NestedMsg fmap(Expr)` + * \tparam T the content type of nested msg + * \tparam FType The mapping function type + */ +template +NestedMsg MapToNestedMsg(Expr expr, FType fmapleaf) { + if (auto* tuple = expr.as()) { + Array> res; + res.reserve(tuple->fields.size()); + for (Expr x : tuple->fields) { + res.push_back(MapToNestedMsg(x, fmapleaf)); + } + return res; + } else { + return fmapleaf(expr); + } +} + +/*! + * \brief Map structinfo with possible nested-sinfo to nested message. + * + * This function will unpack recursive sinfo and run fmapleaf for each leaf, + * then recursively combines the results together into a NestedMsg. + * + * The nesting structure will corresponds to the tuple structure. + * + * \param sinfo The input struct info. + * \param fmapleaf The mapping function for each leaf with signature `NestedMsg fmap(StructInfo)` + * \tparam T the content type of nested msg + * \tparam FType The mapping function type + */ +template +NestedMsg MapToNestedMsg(StructInfo sinfo, FType fmapleaf) { + if (auto* tuple = sinfo.as()) { + Array> res; + res.reserve(tuple->fields.size()); + for (StructInfo x : tuple->fields) { + res.push_back(MapToNestedMsg(x, fmapleaf)); + } + return res; + } else { + return fmapleaf(sinfo); + } +} + +/*! + * \brief Map expr with possible nested-tuple to nested message. + * + * This function will unpack recursive expr by its struct info and + * run fmapleaf for each leaf, then recursively combines the results + * together into a NestedMsg. + * + * The nesting structure will corresponds to the struct info of expr. + * + * \param expr The input expression which should have struct info. + * \param fmapleaf The mapping function for each leaf with signature `NestedMsg fmapleaf(Expr)` + * \tparam T the content type of nested msg + * \tparam FType The mapping function type + */ +template +NestedMsg MapToNestedMsgBySInfo(Expr expr, FType fmapleaf) { + auto sinfo = GetStructInfo(expr); + if (auto* tuple = sinfo.as()) { + Array> res; + res.reserve(tuple->fields.size()); + for (size_t i = 0; i < tuple->fields.size(); ++i) { + Expr field; + if (const auto* expr_tuple = expr.as()) { + field = expr_tuple->fields[i]; + } else { + field = TupleGetItem(expr, i); + UpdateStructInfo(field, tuple->fields[i]); + } + res.push_back(MapToNestedMsgBySInfo(field, fmapleaf)); + } + return res; + } else { + return fmapleaf(expr); + } +} + +/*! + * \brief Map nested message back to the expr. + * + * This function will decompose the nested message and + * run fmapleaf for each leaf message and get the leaf expr, + * then recursively combines the results as tuple expr. + * + * \param msg The input nested message. + * \param fmapleaf The mapping function for each leaf with signature `Expr fmapleaf(Optional)`. + * \tparam T the content type of nested msg. + * \tparam FType The mapping function type. + */ +template +Expr NestedMsgToExpr(NestedMsg msg, FType fmapleaf) { + if (msg.IsNull()) { + return fmapleaf(NullOpt); + } else if (msg.IsLeaf()) { + return fmapleaf(msg.LeafValue()); + } else { + ICHECK(msg.IsNested()); + Array> arr = msg.NestedArray(); + Array subexpr; + subexpr.reserve(arr.size()); + for (size_t i = 0; i < arr.size(); ++i) { + subexpr.push_back(NestedMsgToExpr(arr[i], fmapleaf)); + } + Optional simplified_tuple; + bool simplified_flag = false; + if (subexpr.size() >= 1) { + simplified_flag = true; + for (size_t i = 0; i < subexpr.size() && simplified_flag; ++i) { + auto* node = subexpr[i].as(); + if (node == nullptr || node->index != static_cast(i)) { + simplified_flag = false; + } else { + if (simplified_tuple.defined()) { + simplified_flag &= (simplified_tuple == node->tuple); + } else { + simplified_tuple = node->tuple; + ICHECK(simplified_tuple.defined()); + } + } + } + } + return simplified_flag ? simplified_tuple.value() : Tuple(subexpr); + } +} + +/*! + * \brief Recursively combine two nested message into one. + * + * This function requires the two messages to be compatible with each other. + * The combination rule is as follows: + * - combine(null, msg) => msg + * - combine(leaf1, leaf2) => fcombine(leaf1, leaf2) + * - combine(array1, array2) => [combine(x, y) for x, y in zip(array1, array2)] + * - This function will throw an error if array have different size + * + * \param lhs The left operand. + * \param rhs The right operand. + * \param fcombine with signature T fcombine(T lhs, T rhs) + * \tparam T the content type of nested msg + * \tparam FType combine function type. + */ +template +NestedMsg CombineNestedMsg(NestedMsg lhs, NestedMsg rhs, FType fcombine) { + if (lhs.IsNull()) return rhs; + if (rhs.IsNull()) return lhs; + + if (lhs.IsLeaf()) { + ICHECK(rhs.IsLeaf()) << "Cannot combine leaf with nested"; + return NestedMsg(fcombine(lhs.LeafValue(), rhs.LeafValue())); + } else { + ICHECK(lhs.IsNested()); + ICHECK(rhs.IsNested()) << "Cannot combine leaf with nested"; + Array> arr_lhs = lhs.NestedArray(); + Array> arr_rhs = rhs.NestedArray(); + ICHECK_EQ(arr_lhs.size(), arr_rhs.size()) + << "Cannot combine two nested array with different sizes"; + Array> res; + res.reserve(arr_lhs.size()); + for (size_t i = 0; i < arr_lhs.size(); ++i) { + res.push_back(CombineNestedMsg(arr_lhs[i], arr_rhs[i], fcombine)); + } + return NestedMsg(res); + } +} + +/*! + * \brief Recursively map a nested message to another one, with leaf mapped by the input fmapleaf. + * \param msg The nested message to be mapped. + * \param fmapleaf The leaf map function, with signature NestedMsg fmapleaf(T msg) + * \tparam T The content type of nested message. + * \tparam FType The leaf map function type. + * \return The new nested message. + */ +template +NestedMsg MapNestedMsg(NestedMsg msg, FType fmapleaf) { + if (msg.IsNull()) { + return msg; + } else if (msg.IsLeaf()) { + return fmapleaf(msg.LeafValue()); + } else { + ICHECK(msg.IsNested()); + Array> arr = msg.NestedArray(); + Array> res; + res.reserve(arr.size()); + for (int i = 0; i < static_cast(arr.size()); ++i) { + res.push_back(MapNestedMsg(arr[i], fmapleaf)); + } + return NestedMsg(res); + } +} + +/*! + * \brief Recursively decompose the tuple structure in expr and msg along with it. + * + * This function will call fvisitleaf for each leaf expression in expr. + * This function will throw an error if the nesting structure in msg does not + * match the tuple nesting structure in expr. + * + * \param expr The input expression to be decomposed. + * \param msg The input nested message. + * \param fvisitleaf with signature fvisitleaf(Expr expr, NestedMsg msg) + * \tparam T the content type of nested msg + * \tparam FType The visit function type. + */ +template +void DecomposeNestedMsg(Expr expr, NestedMsg msg, FType fvisitleaf) { + if (auto* tuple = expr.as()) { + ICHECK(msg.IsNested()) << "Expected nested to match tuple"; + Array> arr = msg.NestedArray(); + ICHECK_EQ(arr.size(), tuple->fields.size()) << "Expected nested array size to match tuple size"; + for (size_t i = 0; i < arr.size(); ++i) { + DecomposeNestedMsg(tuple->fields[i], arr[i], fvisitleaf); + } + } else { + fvisitleaf(expr, msg); + } +} + +/*! + * \brief Recursively transform the tuple structure in expr and msgs along with it. + * + * This function will call ftransleaf for each leaf expression in expr. + * This function will throw an error if the nesting structure in msg does not + * match the tuple nesting structure in expr. + * + * \param expr The input expression to be transform.  + * \param msgs The input messages to guide the transformation. + * \param ftransleaf with signature ftransleaf(Expr, Array>)->Expr + * \tparam T the content type of nested msg + * \tparam N the number of messages + * \tparam FType The visit function type. + */ +template +Expr TransformTupleLeaf(Expr expr, std::array, N> msgs, FType ftransleaf) { + StructInfo sinfo = GetStructInfo(expr); + if (const auto* tuple = sinfo.as()) { + std::array>, N> msg_arrays; + for (size_t i = 0; i < N; ++i) { + ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple"; + msg_arrays[i] = msgs[i].NestedArray(); + } + bool same = true; + Array fields; + fields.reserve(tuple->fields.size()); + for (size_t i = 0; i < tuple->fields.size(); ++i) { + Expr field; + if (const auto* expr_tuple = expr.as()) { + field = expr_tuple->fields[i]; + } else { + field = TupleGetItem(expr, i); + UpdateStructInfo(field, tuple->fields[i]); + } + std::array, N> sub_msgs; + for (size_t j = 0; j < N; ++j) { + sub_msgs[j] = msg_arrays[j][i]; + } + fields.push_back(TransformTupleLeaf(field, std::move(sub_msgs), ftransleaf)); + same &= (fields.back().same_as(field)); + } + return same ? expr : Tuple(fields); + } else { + for (const auto& msg : msgs) { + ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple"; + } + return ftransleaf(expr, msgs); + } +} + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_NESTED_MSG_H_ diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc new file mode 100644 index 000000000000..48af552007fd --- /dev/null +++ b/tests/cpp/nested_msg_test.cc @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace tvm; +using namespace tvm::runtime; +using namespace tvm::relax; + +TEST(NestedMsg, Basic) { + // start with no annotation + relax::Var x("x", NullOpt), y("y", NullOpt); + + // constructor from array, T and nullopt. + NestedMsg msg({x, NullOpt, x}); + + EXPECT_TRUE(msg.IsNested()); + EXPECT_FALSE(msg.IsLeaf()); + EXPECT_TRUE(msg != nullptr); + + EXPECT_ANY_THROW(msg.LeafValue()); + + auto arr = msg.NestedArray(); + EXPECT_TRUE(arr[0].same_as(x)); + EXPECT_TRUE(arr[1] == nullptr); + EXPECT_TRUE(arr[1].IsNull()); + + EXPECT_TRUE(arr[2].LeafValue().same_as(x)); + + auto a0 = arr[0]; + EXPECT_TRUE(a0.IsLeaf()); + + // assignment + // assign null + a0 = NullOpt; + EXPECT_TRUE(a0 == nullptr); + + // assign array + a0 = {x, {x, NullOpt, y}}; + EXPECT_TRUE(a0.IsNested()); + auto t0 = a0.NestedArray()[1]; + EXPECT_TRUE(t0.IsNested()); + EXPECT_TRUE(t0.NestedArray()[2].same_as(y)); + + // assign leaf + a0 = x; + + EXPECT_TRUE(a0.IsLeaf()); + EXPECT_TRUE(a0.same_as(x)); +} + +TEST(NestedMsg, ForEachLeaf) { + relax::Var x("x", NullOpt), y("y", NullOpt); + NestedMsg msg = {x, {x, y}, NullOpt, {x, {x, y}}}; + + int x_count = 0, y_count = 0; + + ForEachLeaf(msg, [&](const Expr& v) { + if (v.same_as(x)) ++x_count; + if (v.same_as(y)) ++y_count; + }); + EXPECT_EQ(x_count, 4); + EXPECT_EQ(y_count, 2); +} + +TEST(NestedMsg, Equal) { + relax::Var x("x", NullOpt), y("y", NullOpt); + relax::Var z("z", NullOpt); + + auto fequal = [](Expr lhs, Expr rhs) { return lhs.same_as(rhs); }; + + using M = NestedMsg; + + EXPECT_TRUE(Equal(M(NullOpt), M(NullOpt), fequal)); + + EXPECT_TRUE(Equal(M(x), M(x), fequal)); + + EXPECT_TRUE(Equal(M({x, y}), M({x, y}), fequal)); + + EXPECT_TRUE(Equal(M({x, NullOpt}), M({x, NullOpt}), fequal)); + + EXPECT_TRUE(Equal(M({x, {NullOpt, y}}), M({x, {NullOpt, y}}), fequal)); + + EXPECT_TRUE(Equal(M({x, {NullOpt, y}, {x, z}}), M({x, {NullOpt, y}, {x, z}}), fequal)); + + // type mismatch + EXPECT_FALSE(Equal(M({x, {NullOpt, y}, x}), M({x, {NullOpt, y}, {x, z}}), fequal)); + + EXPECT_FALSE(Equal(M({x, {NullOpt, y}, {x, NullOpt}}), M({x, {NullOpt, y}, {x, z}}), fequal)); + + EXPECT_FALSE(Equal(M({x, {NullOpt, y}}), M({x, {NullOpt, y}, {x, z}}), fequal)); + + EXPECT_FALSE(Equal(M(x), M(NullOpt), fequal)); + + EXPECT_FALSE(Equal(M(NullOpt), M(x), fequal)); + + EXPECT_FALSE(Equal(M(x), M(Array({x})), fequal)); + + EXPECT_FALSE(Equal(M(Array({x})), M(x), fequal)); +} + +TEST(NestedMsg, MapAndDecompose) { + relax::Var x("x", PrimStructInfo(runtime::DataType::Int(16))); + relax::Var y("y", PrimStructInfo(runtime::DataType::Int(32))); + relax::Var z("z", PrimStructInfo(runtime::DataType::Int(64))); + + BlockBuilder bb = BlockBuilder::Create(NullOpt); + relax::Expr t0 = bb->Normalize(Tuple({x, y})); + relax::Expr t1 = bb->Normalize(Tuple({t0, x, z, t0})); + + auto c0 = Integer(0); + auto c1 = Integer(1); + auto c2 = Integer(2); + + auto output = MapToNestedMsg(t1, [&](Expr value) { + if (value.same_as(x)) return c0; + if (value.same_as(y)) return c1; + return c2; + }); + + NestedMsg expected = {{c0, c1}, c0, c2, {c0, c1}}; + + EXPECT_TRUE(Equal(output, expected, + [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; })); + + auto output2 = + MapToNestedMsg(GetStructInfo(t1), [&](StructInfo sinfo) -> NestedMsg { + const auto* prim_sinfo = sinfo.as(); + if (prim_sinfo == nullptr) return NullOpt; + int bits = prim_sinfo->dtype.bits(); + if (bits == 16) return c0; + if (bits == 32) return c1; + if (bits == 64) return c2; + return NullOpt; + }); + + EXPECT_TRUE(Equal(output2, expected, + [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; })); + + int x_count = 0, y_count = 0, z_count = 0; + + DecomposeNestedMsg(t1, expected, [&](Expr value, NestedMsg msg) { + if (value.same_as(x)) { + EXPECT_TRUE(msg.same_as(c0)); + ++x_count; + } else if (value.same_as(y)) { + EXPECT_TRUE(msg.same_as(c1)); + ++y_count; + } else { + EXPECT_TRUE(msg.same_as(c2)); + ++z_count; + } + }); + EXPECT_EQ(x_count, 3); + EXPECT_EQ(y_count, 2); + EXPECT_EQ(z_count, 1); +} + +TEST(NestedMsg, MapToNestedMsgBySInfo) { + auto sf0 = TensorStructInfo(DataType::Float(32), /*ndim=*/0); + auto sf1 = TupleStructInfo({sf0, sf0}); + auto sf2 = TupleStructInfo({sf0, sf0}); + auto x = relax::Var("x", TupleStructInfo({sf1, sf2, sf0})); + + auto msg = MapToNestedMsgBySInfo(x, [](Expr value) { return value; }); + + EXPECT_TRUE(msg.IsNested()); + auto arr = msg.NestedArray(); + + EXPECT_TRUE(arr[1].IsNested()); + auto arr1 = arr[1].NestedArray(); + + EXPECT_TRUE(arr1[0].IsLeaf()); + EXPECT_TRUE(StructuralEqual()(arr1[0].LeafValue(), TupleGetItem(TupleGetItem(x, 1), 0))); + + EXPECT_TRUE(arr[2].IsLeaf()); + EXPECT_TRUE(StructuralEqual()(arr[2].LeafValue(), TupleGetItem(x, 2))); +} + +TEST(NestedMsg, NestedMsgToExpr) { + auto sf0 = TensorStructInfo(DataType::Float(32), /*ndim=*/0); + auto sf1 = TupleStructInfo({sf0, sf0}); + + auto c0 = Integer(0); + auto c1 = Integer(1); + auto c2 = Integer(2); + + relax::Var x("x", sf0), y("y", sf0), z("z", sf0); + + NestedMsg msg = {c0, {c0, c1}, {c0, {c1, c2}}}; + auto expr = NestedMsgToExpr(msg, [&](Optional leaf) { + ICHECK(leaf.defined()); + int value = leaf.value().IntValue(); + switch (value) { + case 0: + return x; + case 1: + return y; + default: + return z; + } + }); + + Expr expected = Tuple({x, Tuple({x, y}), Tuple({x, Tuple({y, z})})}); + EXPECT_TRUE(StructuralEqual()(expr, expected)); + + // test simplified + relax::Var t("t", sf1); + NestedMsg msg1 = {TupleGetItem(t, 0), TupleGetItem(t, 1)}; + auto expr1 = NestedMsgToExpr(msg1, [](Optional leaf) { return leaf.value(); }); + EXPECT_TRUE(StructuralEqual()(expr1, t)); +} + +TEST(NestedMsg, CombineNestedMsg) { + auto c0 = Integer(0); + auto c1 = Integer(1); + auto c2 = Integer(2); + + NestedMsg lhs = {c0, {c0, c1}, NullOpt, {c0, {c1, c2}}}; + NestedMsg rhs = {c1, {c2, NullOpt}, NullOpt, {c1, {c2, c2}}}; + NestedMsg expected = {c1, {c2, c1}, NullOpt, {c1, {c2, c2}}}; + + auto output = CombineNestedMsg(lhs, rhs, [](Integer x, Integer y) { + if (x->value > y->value) return x; + return y; + }); + + EXPECT_TRUE(Equal(output, expected, + [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; })); +} + +TEST(NestedMsg, MapNestedMsg) { + auto c0 = Integer(0); + auto c1 = Integer(1); + auto c2 = Integer(2); + auto c3 = Integer(3); + + NestedMsg msg = {c0, {c0, c1}, NullOpt, {c0, {c2, c1}}}; + NestedMsg expected = {c3, {c3, NullOpt}, NullOpt, {c3, {c2, NullOpt}}}; + + auto output = MapNestedMsg(msg, [](Integer x) { + if (x->value == 0) { + return NestedMsg(Integer(3)); + } else if (x->value == 1) { + return NestedMsg(); + } else { + return NestedMsg(x); + } + }); + + EXPECT_TRUE(Equal(output, expected, + [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; })); +} + +TEST(NestedMsg, TransformTupleLeaf) { + auto c0 = Integer(0); + auto c1 = Integer(1); + auto c2 = Integer(2); + using NInt = NestedMsg; + + NInt msg1 = {c0, {c0, c1}, c2, {c0, {c1, c2}}}; + NInt msg2 = {c1, {c2, c0}, c2, {c1, {c2, c0}}}; + + PrimStructInfo s = PrimStructInfo(runtime::DataType::Int(32)); + relax::Var x("x", s), y("y", s), z("z", s); + BlockBuilder bb = BlockBuilder::Create(NullOpt); + Expr expr = bb->Normalize(Tuple({x, Tuple({x, x}), x, Tuple({x, Tuple({x, x})})})); + + auto ftransleaf = [&](Expr value, std::array msgs) -> Expr { + int lhs = Downcast(msgs[0].LeafValue())->value; + int rhs = Downcast(msgs[1].LeafValue())->value; + if (lhs > rhs) + return z; + else if (lhs == rhs) + return value; + else + return y; + }; + + Expr expected = Tuple({y, Tuple({y, z}), x, Tuple({y, Tuple({y, z})})}); + + EXPECT_TRUE(StructuralEqual()( + TransformTupleLeaf(expr, std::array({msg1, msg2}), ftransleaf), expected)); + + EXPECT_TRUE( + expr.same_as(TransformTupleLeaf(expr, std::array({msg1, msg1}), ftransleaf))); +} From 75b905796e387303682a3fbbfd78579047b419be Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 15 Feb 2023 21:35:14 +0800 Subject: [PATCH 25/81] [Unity][Pass] Operator Fusion Passes (#14001) [Unity][Pass] Operator fusion passes This PR introduces three passes for operator fusion: 1. AnnotateTIROpPattern: analysis the operator kind from PrimFunc. 2. FuseOps: fuse operators for Relax functions, which adds a new fused relax primitive function. 3. FuseTIR: fuse corresponding TIR PrimFuncs for the fused relax. --- include/tvm/relax/analysis.h | 11 + include/tvm/tir/buffer.h | 14 +- python/tvm/relax/transform/transform.py | 43 + .../transform/annotate_tir_op_pattern.cc | 55 ++ src/relax/transform/fuse_ops.cc | 909 ++++++++++++++++++ src/relax/transform/fuse_tir.cc | 728 ++++++++++++++ .../test_transform_annotate_tir_op_pattern.py | 360 +++++++ tests/python/relax/test_transform_fuse_ops.py | 759 +++++++++++++++ tests/python/relax/test_transform_fuse_tir.py | 563 +++++++++++ tests/python/relax/test_tvmscript_parser.py | 1 - 10 files changed, 3441 insertions(+), 2 deletions(-) create mode 100644 src/relax/transform/annotate_tir_op_pattern.cc create mode 100644 src/relax/transform/fuse_ops.cc create mode 100644 src/relax/transform/fuse_tir.cc create mode 100644 tests/python/relax/test_transform_annotate_tir_op_pattern.py create mode 100644 tests/python/relax/test_transform_fuse_ops.py create mode 100644 tests/python/relax/test_transform_fuse_tir.py diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 24cfe5b9bf11..a55fe6797d45 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -260,6 +260,17 @@ TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived, TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, arith::Analyzer* ana = nullptr); +/*! + * \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps. + * + * \param func The PrimFunc to be analyzed. + * \return The Op Pattern Kind. + * + * \note This analysis applies on TIR function but is primarily used by relax passes. + * As a result we place it under the relax namespace. + */ +TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func); + /*! * \brief Check if the given PrimFunc is essentially doing a reshape operation. * The reshape operation also includes expand_dims, squeeze, flatten, etc. diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index d7a2aec0b972..e3a853e4c7ea 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -34,6 +34,18 @@ namespace tvm { namespace tir { +#ifndef TVM_INDEX_DEFAULT_I64 +#define TVM_INDEX_DEFAULT_I64 1 +#endif +/*! \brief if TVM_INDEX_DEFAULT_I64 is set, return int64, otherwise return int32 */ +inline DataType DefaultIndexType() { +#if TVM_INDEX_DEFAULT_I64 + return DataType::Int(64); +#else + return DataType::Int(32); +#endif +} + // forward declare Stmt class Stmt; @@ -135,7 +147,7 @@ class BufferNode : public Object { /*! \return preferred index type for this buffer node */ DataType DefaultIndexType() const { - return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32); + return shape.size() != 0 ? shape[0].dtype() : tvm::tir::DefaultIndexType(); } /*! \brief Determine the offset in the buffer of the given index. diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index cab18797c672..0f973db290f8 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -105,6 +105,49 @@ def AttachGlobalSymbol() -> tvm.ir.transform.Pass: return _ffi_api.AttachGlobalSymbol() # type: ignore +def AnnotateTIROpPattern() -> tvm.ir.transform.Pass: + """Annotate Op Pattern Kind for TIR functions + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.AnnotateTIROpPattern() # type: ignore + + +def FuseOps(fuse_opt_level=-1) -> tvm.ir.transform.Pass: + """This pass groups bindings in a dataflow block of Relax functions and generate a new grouped + Relax function for each group, according to the fusion algorithm described in the pass + implementation. By grouping bindings into new Relax functions, we substitute the bindings in + the function being manipulated into function calls to the new grouped function. + + A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function. + + Parameters + ---------- + fuse_opt_level : int + The level of fuse optimization. -1 indicates that the level will be + inferred from pass context. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for operator fusion. + """ + return _ffi_api.FuseOps(fuse_opt_level) # type: ignore + + +def FuseTIR() -> tvm.ir.transform.Pass: + """Fuse primitive relax function into a larger TIR function if possible + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for tir fusion. + """ + return _ffi_api.FuseTIR() # type: ignore + + def _wrap_class_function_pass(pass_cls, pass_info): """Wrap a python class as function pass.""" diff --git a/src/relax/transform/annotate_tir_op_pattern.cc b/src/relax/transform/annotate_tir_op_pattern.cc new file mode 100644 index 000000000000..b1c1ed29aff3 --- /dev/null +++ b/src/relax/transform/annotate_tir_op_pattern.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/annotate_tir_op_pattern.cc + * \brief Annotate Op Pattern for TIR functions. It is a pass works on TIR PrimFuncs, + * but they are needed for relax fusion. So we put them in the relax namespace. + */ +#include +#include +#include + +namespace tvm { +namespace relax { + +tir::PrimFunc AnnotateOpPattern(tir::PrimFunc f) { + if (f->HasNonzeroAttr("op_pattern")) { + return f; + } else { + relay::OpPatternKind kind = AnalyzeOpPatternKind(f); + return WithAttr(std::move(f), "op_pattern", Integer(static_cast(kind))); + } +} + +namespace transform { + +Pass AnnotateTIROpPattern() { + auto pass_func = [=](tir::PrimFunc f, IRModule m, PassContext ctx) { + return AnnotateOpPattern(std::move(f)); + }; + return tir::transform::CreatePrimFuncPass(pass_func, 0, "AnnotateTIROpPattern", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.AnnotateTIROpPattern").set_body_typed(AnnotateTIROpPattern); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc new file mode 100644 index 000000000000..f3559b72da3f --- /dev/null +++ b/src/relax/transform/fuse_ops.cc @@ -0,0 +1,909 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/fuse_ops.cc + * \brief This file contains a pass which groups bindings in a dataflow block of Relax + * functions and generate a new grouped Relax function for each group, according to the fusion + * algorithm described below. By grouping bindings into new Relax functions, we substitute the + * bindings in the function being manipulated into function calls to the new grouped function. + * + * A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function. + */ + +#include +#include +#include +#include +#include + +#include + +#include "../../relay/analysis/graph_partitioner.h" +#include "../../support/arena.h" + +namespace tvm { +namespace relax { + +/* + Note on Fusing algorithm: + + The main challenge of general fusor is to handle possible diamond shape branches, + in the following graph, conv2d can be fused to elemwise add. + + conv2d + / | \ + / | \ + op op op + \ | / + \ | / + elemwise add + | + + However, at the point of conv2d we do not necessarily know that all the future paths + will merge at the elemwise add. The fusion algorithm applies post-dominator analysis. + + The immediate post-dominator of a node defined by the closest node where all the future path goes + into. In the above case, the elemwise add is the post-dominator of conv2d. The general algorithm + is as follows: + + - Construct a DAG of dataflow graph for dominator analysis + - Construct a post-dominator tree which gives immediate post dominator of each node. + - Run fusion algorithm with the given post-dominator information. + + Note that, because we run analysis on a DAG, we use a single pass post-dominator + tree construction algorithm via LCA, which is simpler than the full version that handles cycles. + + The fusion algorithm traverses from each node and checks if it can be fused to its + immediate post dominator. It has to check the following things: + + - CheckPath: check all the path between a node and its immediate post-dominator + satisfies the fuse condition. + - Note that these intermediate node can already be fused with another nodes, the algorithm + will still run correctly. + - CommitFuse: mark all the nodes between source and post-dominator as the same group. + - We use an Union-Find data structure to manage the groups. +*/ + +using relay::GraphPartitioner; +using relay::IndexedForwardGraph; +using relay::OpPatternKind; +using support::LinkNode; + +constexpr uint32_t kMaxFusedOps = 256; + +TVM_REGISTER_PASS_CONFIG_OPTION("relax.FuseOps.max_depth", Integer); + +class GraphCreator : public ExprVisitor { + public: + /*! + * \brief Create a IndexedForwardGraph according to the input module. The graph will be used for + * graph partition and operator fusion. + * \param mod The module which the creation accords to + * \param arena The allocator of all the internal node objects + * \return The created IndexedForwardGraph + */ + static IndexedForwardGraph Create(IRModule mod, support::Arena* arena) { + // Since cross-function call is not supported yet, FuseOps only serves the entry function, whose + // name is "main". + auto relax_func = Downcast(mod->Lookup("main")); + GraphCreator creator(mod, arena); + creator(relax_func); + + // The algorithm of the graph creator ensures that each created node will be added to the + // post-dfs order and will be set its op pattern. Thus we check whether all these containers + // have the same size. + size_t n_nodes = creator.graph_.node_map.size(); + ICHECK_EQ(n_nodes, creator.graph_.post_dfs_order.size()); + ICHECK_EQ(n_nodes, creator.initialized_nodes_.size()); + + return creator.graph_; + } + + private: + explicit GraphCreator(IRModule mod, support::Arena* arena) + : mod_(std::move(mod)), arena_(arena) {} + + void VisitExpr_(const FunctionNode* func) final { + for (const Var& param : func->params) { + IndexedForwardGraph::Node* param_node = CreateNode(param.get()); + // The parameter is passed in from the outside, and thus it's marked as an external reference, + // and it's pattern is `kOpaque`. + MarkAsExternRef(param_node); + SetNodePattern(param_node, OpPatternKind::kOpaque); + AddToPostDFSOrder(param_node, param.get()); + } + ExprVisitor::VisitExpr_(func); + } + + void VisitBindingBlock(const BindingBlock& block) final { + if (const auto* df_block = block.as()) { + VisitBindingBlock_(df_block); + } + // We skip ordinary binding blocks since they might be impure (with side effect or control flow) + } + + // TODO(tvm-team): how to deal with MatchCast binding here + + void VisitBinding_(const VarBindingNode* binding) final { + IndexedForwardGraph::Node* node = CreateNode(binding->var.get()); + + // If the variable is not a dataflow variable, it must be the output variable of this dataflow + // block + if (!binding->var->IsInstance()) { + this->MarkAsExternRef(node); + } + if (const auto* call = binding->value.as()) { + // Case 1. The expression is a CallNode + VisitCall(call, node); + } else if (const auto* tuple_get_item = binding->value.as()) { + // Case 2. The expression is a TupleGetItemNode + VisitTupleGetItem(tuple_get_item, node); + } else { + VisitUnsupportedNode(binding->value, node); + // Case 3. The type of the expression is not fusion-supported. + // In this case, we skip adding edges, adding an empty node into graph. + } + AddToPostDFSOrder(node, binding->var.get()); + } + + /********** Non-Leaf Expression Nodes **********/ + + void VisitCall(const CallNode* call, IndexedForwardGraph::Node* binding_var_node) { + ICHECK_NOTNULL(binding_var_node); + + static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + OpPatternKind pattern = OpPatternKind::kOpaque; + Array args = call->args; + + // - If the op being called is a TIR PrimFunc, we get the function op pattern directly from the + // function attribute and visit the arguments one by one. + // - Otherwise, the pattern of the current binding variable node is set to `kOpaque`, and we + // recurse into the call expression. + const auto* op = call->op.as(); + if (op == call_tir_op_.get()) { + const GlobalVar& global_var = Downcast(call->args[0]); + tir::PrimFunc func = Downcast(mod_->Lookup(global_var)); + + // Override args for call_tir + args = Downcast(call->args[1])->fields; + + // TODO(tvm-team): handle the shape argument (args[3]) + Optional opt_pattern = func->GetAttr("op_pattern"); + if (opt_pattern.defined()) { + pattern = static_cast(Downcast(opt_pattern)->value); + } else { + pattern = OpPatternKind::kOpaque; + } + } + // The pattern of the current binding variable node is set to the pattern of this operator. + SetNodePattern(binding_var_node, pattern); + // Visit all call args + for (const Expr& arg : args) { + ICHECK(IsLeaf(arg)); + VisitLeaf(arg, binding_var_node, pattern); + } + } + + void VisitTupleGetItem(const TupleGetItemNode* tuple_item, + IndexedForwardGraph::Node* binding_var_node) { + ICHECK_NOTNULL(binding_var_node); + + SetNodePattern(binding_var_node, OpPatternKind::kInjective); + VisitLeaf(tuple_item->tuple, binding_var_node, OpPatternKind::kInjective); + } + + void VisitUnsupportedNode(const Expr& expr, IndexedForwardGraph::Node* binding_var_node) { + ICHECK_NOTNULL(binding_var_node); + SetNodePattern(binding_var_node, OpPatternKind::kOpaque); + + auto visit_leaves = [this, &binding_var_node](const Expr& e) { + if (e->IsInstance() || e->IsInstance()) { + VisitLeaf(e, binding_var_node, OpPatternKind::kOpaque); + } + }; + PostOrderVisit(expr, visit_leaves); + } + + /********** Leaf Expression Nodes **********/ + + void VisitLeaf(const Expr& leaf_expr, IndexedForwardGraph::Node* binding_var_node, + const OpPatternKind& pattern) { + ICHECK_NOTNULL(binding_var_node); + + // Recursive visit if it's Tuple + if (const auto* tuple = leaf_expr.as()) { + for (const Expr& expr : tuple->fields) { + VisitLeaf(expr, binding_var_node, pattern); + } + return; + } + + auto it = graph_.node_map.find(leaf_expr.get()); + IndexedForwardGraph::Node* leaf_node = nullptr; + if (it != graph_.node_map.end()) { + leaf_node = it->second; + } else if (leaf_expr->IsInstance()) { + leaf_node = CreateNode(leaf_expr.get()); + // Since we never fuse constants, the pattern of the constant is set to `kOpaque`. + SetNodePattern(leaf_node, OpPatternKind::kOpaque); + AddToPostDFSOrder(leaf_node, leaf_expr.get()); + } else { + LOG(FATAL) << "The leaf Expr is supposed to be defined before, but got: " << leaf_expr + << " used before definition."; + } + AddEdge(leaf_node, binding_var_node, pattern); + } + + /********** Helper Functions **********/ + + /*! + * \brief Check whether the expression is a leaf expression + * \param expr The expression to be checked + * \return Whether the expression is a leaf expression + * \note In order to avoid too much refactor, this method is a simple copy-paste of the is-leaf + * check in "block_builder.cc". And it should be refactored in the future. + * \sa src/relax/ir/block_builder.cc + */ + static bool IsLeaf(const Expr& expr) { + // NOTE: Tuples are treated as leaf nodes for ergonomics + return expr.as() || expr.as() || expr.as() || + expr.as() || expr.as() || expr.as() || + expr.as(); + } + + /*! + * \brief Create a graph node corresponding to the input key + * \param key The object which is used to create the graph node + * \return The created graph node + * \note The node corresponding to each key is supposed to be created for only once + */ + IndexedForwardGraph::Node* CreateNode(const Object* key) { + ICHECK(graph_.node_map.find(key) == graph_.node_map.end()) + << "The node corresponding to the input key is not supposed to be created before"; + auto* node = arena_->make(); + graph_.node_map[key] = node; + return node; + } + + /*! + * \brief Append the input node to the post-dfs order of the graph + * \param node The node to be appended + * \param key The key corresponding to the node + * \note Each node is supposed to be appended to the post-dfs order for only once + */ + void AddToPostDFSOrder(IndexedForwardGraph::Node* node, const Object* key) { + auto it = graph_.node_map.find(key); + ICHECK(it != graph_.node_map.end() && it->second == node) + << "The node must have been created before adding to the post-dfs order"; + + // We only set the reference of the node when adding it to the post-dfs order. Thus, if the + // reference of a node is already set, it must have been appended to the post-dfs order. + ICHECK(node->ref == nullptr) + << "The node is not supposed to be added into the post-dfs order before"; + + node->ref = key; + node->index = graph_.post_dfs_order.size(); + graph_.post_dfs_order.push_back(node); + } + + /*! + * \brief Add an edge from the input start to the input end in the graph, with specific pattern + * \param start The start of the edge + * \param end The end of the edge + * \param pattern The pattern of this edge + */ + void AddEdge(IndexedForwardGraph::Node* start, IndexedForwardGraph::Node* end, + OpPatternKind pattern) { + auto* link = arena_->make>(); + link->value.node = end; + link->value.pattern = pattern; + start->outputs.Push(link); + } + + /*! + * \brief Mark a given node as "external reference", which means the node cannot be fused as an + * intermediate node + * \param node The graph node to be marked + */ + void MarkAsExternRef(IndexedForwardGraph::Node* node) { node->extern_ref = true; } + + /*! + * \brief Set the pattern of the input node + * \param node The graph node to be set + * \param pattern The pattern of the node + */ + void SetNodePattern(IndexedForwardGraph::Node* node, OpPatternKind pattern) { + ICHECK(initialized_nodes_.find(node) == initialized_nodes_.end()) + << "The input node is supposed to be set pattern for only once"; + initialized_nodes_.insert(node); + node->pattern = pattern; + } + + private: + /*! \brief The IRModule from which the indexed forward graph is created */ + IRModule mod_; + /*! \brief The allocator of all the internal node objects */ + support::Arena* arena_; + /*! \brief The created indexed forward graph */ + IndexedForwardGraph graph_; + /*! \brief The graph nodes whose patterns are set */ + std::unordered_set initialized_nodes_; +}; + +/*! + * \brief The ExprMutator used to create a new grouped function + * \details The workflow of this ExprMutator is: + * - The bindings in the function will be added by OperatorFusor via `AppendBinding(...)`. + * - When adding a new binding through `AppendBinding(...)`, we check whether the variables and + * constants used by the binding are defined by some previous added binding. And for the undefined + * variables and constants, we add them to the argument list and created new variables as the + * corresponding parameters. + * - When `CreateFunction()` is called, we go through each binding and update the binding with the + * new parameters. After that we wrap all bindings with a DataflowBlock and a Function. + */ +class FunctionCreator : public ExprMutator { + public: + explicit FunctionCreator(bool lift_constant) : lift_constant_(lift_constant) {} + /*! + * \brief Append a new binding to this function and possibly create new parameters for the + * function accordingly + * \param binding The binding to be appended + * \note Allowed bindings are: + * - VarBinding with value being a call node calling `relax.call_tir`. + * - VarBinding with value being a tuple-get-item node. + * // TODO(tvm-team): handle match shape + */ + void AppendBinding(const Binding& binding) { + ICHECK(!function_.defined()) + << "The `function_` is supposed to be uncreated when adding bindings"; + + if (const auto* var_binding = binding.as()) { + if (const auto* call = var_binding->value.as()) { + if (call->op == Op::Get("relax.call_tir")) { + // Update the name of the function. + name_hint_ = name_hint_ + "_" + Downcast(call->args[0])->name_hint; + + const Tuple& args = Downcast(call->args[1]); + for (const Expr& arg : args->fields) { + CheckDefAndUpdateParam(arg); + } + // TODO(tvm-team): handle shape expr + } else { + if (call->op->IsInstance()) { + name_hint_ = name_hint_ + "_" + Downcast(call->op)->name; + } else if (call->op->IsInstance()) { + std::string gvar_name = Downcast(call->op)->name_hint; + if (auto pos = gvar_name.find("fused_"); pos == 0) { + name_hint_ = name_hint_ + "_" + gvar_name.substr(std::string("fused_").size()); + } else { + name_hint_ = name_hint_ + "_" + gvar_name; + } + } + + for (const Expr& arg : call->args) { + CheckDefAndUpdateParam(arg); + } + } + } else { + const auto* tuple_item = var_binding->value.as(); + ICHECK(tuple_item != nullptr); + CheckDefAndUpdateParam(tuple_item->tuple); + } + + // Mark the binding variable as defined. + defined_vars_.insert(var_binding->var.get()); + // Set var as output true if the binding is not a dataflow variable + if (!var_binding->var->IsInstance()) { + AppendOutput(var_binding->var); + } + } else { + // TODO(tvm-team): handle match_cast + } + bindings_.push_back(binding); + } + + /*! \brief Set a var defined in the group as output. */ + size_t AppendOutput(const Var& var) { + ICHECK(defined_vars_.count(var.get())); + auto output_idx = GetOutputIndex(var); + if (output_idx) { + return *output_idx; + } + output_vars_.push_back(var.get()); + return output_vars_.size() - 1; + } + + /*! + * \brief Create the grouped function according according to the collected bindings and parameters + * \param composite_name The name to identify the pattern this function is created from, if any. + * It will become the value of the kComposite attribute of the created function. + * \note The created function won't be returned immediately. It's stored in the `function_` field. + */ + void CreateFunction(Map group_attrs) { + // Step 1. Start constructing a new dataflow block. + builder_->BeginDataflowBlock(); + + // Step 2. Visit each binding and collect outputs one by one. + Array outputs(output_vars_.size(), Expr()); + for (const Binding& binding : bindings_) { + if (auto output_idx = GetOutputIndex(binding->var)) { + // Case 1. It is an output binding + // We only allow VarBinding as output. + const auto* var_binding = binding.as(); + ICHECK_NOTNULL(var_binding); + Var output_var = builder_->EmitOutput(VisitExpr(var_binding->value)); + var_remap_[var_binding->var->vid] = output_var; + outputs.Set(*output_idx, output_var); + } else { + // Case 2. It is an internel binding, add it to the binding list. + VisitBinding(binding); + } + } + + // Step 3. Finish constructing the new block. + BindingBlock new_block = builder_->EndBlock(); + ICHECK(!outputs.empty()) << "At least one output is required."; + Expr body = outputs.size() == 1 ? outputs[0] : Tuple(outputs); + body = builder_->Normalize(body); + body = builder_->Normalize(SeqExpr({new_block}, body)); + group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1)); + function_ = Function(/*params=*/params_, // + /*body=*/body, // + /*ret_struct_info=*/NullOpt, // + /*attrs=*/DictAttrs(group_attrs)); + } + + /*! \brief The original bindings of the function */ + Array bindings_; + /*! \brief The parameters of the function */ + Array params_; + /*! \brief The arguments to call the function on the caller side */ + Array arguments_; + /*! \brief The name for the fused function */ + String name_hint_ = "fused"; + /*! \brief The constructed Relax function */ + Function function_{nullptr}; + + private: + std::optional GetOutputIndex(Var v) { + auto it = std::find(output_vars_.begin(), output_vars_.end(), v.get()); + if (it != output_vars_.end()) { + return std::distance(output_vars_.begin(), it); + } + return std::nullopt; + } + + /*! + * \brief Check whether the input expression is defined within this function. If not, create a new + * parameter for the expression. + * \param expr The expression to be checked + */ + void CheckDefAndUpdateParam(const Expr& expr) { + // If the expression has already served as an argument, no need to create another one for it. + if (std::find(arguments_.begin(), arguments_.end(), expr) != arguments_.end()) { + return; + } + + // If the expression is not a variable or is a undefined variable, it should be populated as a + // parameter of the relax function. + const auto* var = expr.as(); + if ((var == nullptr || defined_vars_.count(var) == 0) && + (lift_constant_ || !expr->IsInstance())) { + String name{nullptr}; + if (var != nullptr) { + name = var->name_hint(); + } else { + name = String("param_" + std::to_string(n_param_for_const_++)); + } + + Var param(std::move(name), GetStructInfo(expr)); + arguments_.push_back(expr); + params_.push_back(param); + } + } + + Expr VisitExpr(const Expr& expr) final { + // If the expression serves as an argument, return its correspondng parameter. + auto it = std::find(arguments_.begin(), arguments_.end(), expr); + if (it != arguments_.end()) { + return params_[it - arguments_.begin()]; + } + // Otherwise, recurse into this expression. + return ExprMutator::VisitExpr(expr); + } + + private: + /*! \brief The variables defined in this function */ + std::unordered_set defined_vars_; + /*! \brief The number of parameters reserved for constants */ + int n_param_for_const_ = 0; + /*! \brief The output vars */ + std::vector output_vars_; + /*! \brief Whether or not to lift bound constants to parameters */ + bool lift_constant_; +}; + +/*! + * \brief The ExprMutator used to fuse the operators in Relax functions + * \details Given the partition results on the indexed-forward graph, for each group whose size is + * larger than one, we create a new grouped function for it, containing all bindings in that group. + * And we substitute the bindings in a group with a single function call to the newly created + * grouped function. The workflow of this ExprMutator is: for each dataflow block, + * - we go through the bindings one by one. For each binding, if it is in a group whose size is + * larger than one, we add the binding to the function of the group it is in and update the + * parameters and arguments of that function; + * - then we finalize all the grouped functions by updating their bindings using BlockBuilder; + * - lastly, we go through the bindings again and substitute the bindings in a group with a single + * call to the corresponding grouped function. + * + * After transforming a Relax function, we update the function in the IRModule. Besides, we add all + * newly created grouped function to the IRModule. + */ +class OperatorFusor : public ExprMutator { + public: + using Group = GraphPartitioner::Group; + using GroupMap = std::unordered_map; + + OperatorFusor(IRModule mod, const GroupMap& obj2group, bool lift_constants = true) + : ExprMutator(mod), + mod_(std::move(mod)), + obj2group_(obj2group), + lift_constants_(lift_constants) {} + + /*! + * \brief Construct a new operator fusor. Given the indexed-forward graph and the graph partition + * result on that graph, the constructor creates a mapping from each leaf AST object + * (e.g. parameters, variables, constants) to the group of the node corresponding to the object + * in the graph. + * \param mod The IRModule to be transformed + * \param graph The indexed-forward graph of the input IRModule + * \param groups The grouped result of the group partition on the input indexed-forward graph. + */ + OperatorFusor(IRModule mod, const IndexedForwardGraph& graph, const std::vector& groups, + bool lift_constant = true) + : OperatorFusor(mod, CreateGroupMap(graph, groups), lift_constant) {} + + /*! + * \brief The main transformation on the IRModule + * \return The new IRModule after transformation + */ + IRModule Transform() { + for (const auto& [gv, func] : mod_->functions) { + // Only visit Relax function without attr kPrimitive. + if (func->IsInstance() && !func->HasNonzeroAttr(attr::kPrimitive)) { + auto updated_func = Downcast(VisitExpr(func)); + builder_->UpdateFunction(gv, updated_func); + } + } + return builder_->GetContextIRModule(); + } + + private: + static GroupMap CreateGroupMap(const IndexedForwardGraph& graph, + const std::vector& groups) { + GroupMap obj2group; + for (int nid = 0; nid < static_cast(graph.post_dfs_order.size()); ++nid) { + Group* group_root = groups[nid]->FindRoot(); + ICHECK(group_root != nullptr); + ICHECK(graph.post_dfs_order[nid]->ref != nullptr); + obj2group[graph.post_dfs_order[nid]->ref] = group_root; + } + return obj2group; + } + + bool IsTupleOutput(Function f) { + auto sinfo = GetStructInfo(f).as(); + ICHECK(sinfo); + return sinfo->ret->IsInstance(); + } + + BindingBlock VisitBindingBlock(const BindingBlock& block) final { + if (const auto* df_block = block.as()) { + return VisitBindingBlock_(df_block); + } + // We skip ordinary binding blocks since they might be impure (with side effect or control flow) + return block; + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final { + group2func_.clear(); + + // Step 1. Collect the bindings for each grouped function. + CollectFuncBindings(block->bindings); + + // Step 2. Collect all group's boundary (i.e. the output vars for each group) + CollectFuncBoundary(block->bindings); + + // Step 3. Create the grouped function for each group. + for (auto& [g, creator] : group2func_) { + creator.CreateFunction(g->attrs); + } + + // Step 4. Start generating the new binding block. + // - For groups with single binding, we directly recurse into the binding and emit the new one. + // - For groups with multiple bindings, we emit the call to the grouped function only when + // visiting the last binding of the group, because only by doing this we don't break the + // dependencies among the bindings of different groups. And therefore, we will skip all but the + // last binding of the group. + builder_->BeginDataflowBlock(); + + // For each group, record which variables need to be remapped to the output of TupleGetItem. + // Only relevant when the output of the grouped function is a tuple. + std::unordered_map> pending_tuple_get; + + // A grouped function which returns a tuple requires attaching TupleGetItem to each element and + // remapping variables in earlier bindings approriately. Thus, a binding whose value depends on + // some elements of a tuple from other group's function must be emitted after a call to the + // tuple-producing function is emitted and remapping is done. + // To guarantee this, we process bindings in the order of the topological sort of the group + // dependency relations. + for (const auto& binding : TopoSortByGroupDep(block->bindings)) { + // Case 1. If the binding is the only binding in its group, recurse into it and emit the + // transformed binding as usual. + Group* group = GetGroupFromBinding(binding); + if (group->num_nodes == 1 && group->attrs.empty()) { + VisitBinding(binding); + continue; + } + + const auto& it_creator = group2func_.find(group); + ICHECK(it_creator != group2func_.end()); + const FunctionCreator& func_info = it_creator->second; + + // If this binding belongs to a group whose output is a tuple, the original bound variable + // needs to be remapped to the output of TupleGetItem after the corresponding tuple is + // emitted. + if (IsTupleOutput(func_info.function_) && tuple_get_indices_.count(binding->var.get())) { + pending_tuple_get[group].push_back(binding->var); + } + + // Case 2. If the binding is not the last binding of the group, we skip it. + if (!func_info.bindings_.back().same_as(binding)) { + continue; + } + + // Case 3. The binding is the last binding of the group. + const auto* var_binding = binding.as(); + ICHECK(var_binding != nullptr) << "The last binding of a group whose size is larger than 1 " + "is supposed to be a variable binding"; + + // Step a. Add the grouped function to the IRModule + GlobalVar gv = builder_->AddFunction(func_info.function_, func_info.name_hint_); + + // Step b. Create the call to the deduplicated function, and then emit the call. + // - If this binding is an output binding, emit an output variable. + // - Otherwise, emit a dataflow variable. + Var new_var; + Call call_to_emit = Call(gv, UpdateArgs(func_info.arguments_)); + + if (var_binding->var->IsInstance()) { + new_var = builder_->Emit(call_to_emit); + } else { + new_var = builder_->EmitOutput(call_to_emit); + } + + // Step c. Update the mapping used for the remapping of the binding variables. + if (IsTupleOutput(func_info.function_)) { + // If the output is a tuple, attach TupleGetItem to all tuple elements, and + // remap variables approriately. + // The variables that need to be remapped and the corresponding tuple indices are + // available in pending_tuple_get and tuple_get_indices_ respectively. + for (const auto& var : pending_tuple_get[group]) { + auto tuple_get = TupleGetItem(new_var, tuple_get_indices_[var.get()]); + var_remap_[var->vid] = builder_->Emit(tuple_get); + } + } else { + var_remap_[var_binding->var->vid] = new_var; + } + } + // Step 5. Finish the binding block generation. + return builder_->EndBlock(); + } + + /*! + * \brief Collect the bindings for each grouped function and update the information of the grouped + * function + * \param bindings The bindings to be collected + * \note The function update is done by `AppendBinding(...)` + */ + void CollectFuncBindings(const Array& bindings) { + for (const Binding& binding : bindings) { + // If the binding is the only binding in its group, there is no need to create a new function. + Group* group = GetGroupFromBinding(binding); + if (group->num_nodes == 1 && group->attrs.empty()) { + continue; + } + // Add the binding to the grouped function it's in, and update the function information + // accordingly. + if (!group2func_.count(group)) { + group2func_.emplace(group, lift_constants_); + } + group2func_.find(group)->second.AppendBinding(binding); + } + } + + void CollectFuncBoundary(const Array& bindings) { + for (const Binding& binding : bindings) { + // Step 1. Get current binding's group + Group* cur_group = GetGroupFromBinding(binding); + + // Step 2. Collect all used vars in the binding value and update bondary. + // - If the var's group is same as the binding's, the var is defined in the same group + // - If the var's group is different with the binding's, the var must be the output from + // another group. Mark it to be the group output. + auto update_boundary = [this, binding, &cur_group](const Expr& e) { + if (e->IsInstance()) { + const Var& used_var = Downcast(e); + Group* producer_group = GetGroupFromVar(used_var); + // Only check those group defined before. + // Skip the vars from input or groups with single binding. + if (producer_group != cur_group) { + ICHECK(!group_deps_[producer_group].count(cur_group)) + << "A cyclic dependency detected between the groups " << binding->var->name_hint() + << " and " << used_var->name_hint() << " are in."; + group_deps_[cur_group].insert(producer_group); + } + + if (auto producer = group2func_.find(producer_group); + producer_group != cur_group && producer != group2func_.end()) { + auto output_index = producer->second.AppendOutput(used_var); + tuple_get_indices_[used_var.get()] = output_index; + } + } + }; + + if (const auto* var_binding = binding.as()) { + PostOrderVisit(var_binding->value, update_boundary); + } else { + const auto* match_cast = binding.as(); + ICHECK_NOTNULL(match_cast); + PostOrderVisit(match_cast->value, update_boundary); + } + } + } + + /*! + * \brief Get the group which the input binding is in + * \param binding The binding to be queried + * \return The pointer to the group which the input binding is in + */ + Group* GetGroupFromBinding(const Binding& binding) { + Var var = binding->var; + return GetGroupFromVar(var); + } + + /*! + * \brief Get the group which the input var is in + * \param Var The var to be queried + * \return The pointer to the group which the input var is in + */ + Group* GetGroupFromVar(const Var& var) { + const auto& it_group = obj2group_.find(var.get()); + ICHECK(it_group != obj2group_.end()); + Group* group = it_group->second; + return group->FindRoot(); + } + + /*! + * \brief Update the pre-stored arguments according to the variable remapping of the fusor, by + * recursing into each argument + * \param args The arguments to be updated + * \return The updated arguments + */ + Array UpdateArgs(const Array& args) { + Array new_args; + new_args.reserve(args.size()); + for (const Expr& arg : args) { + new_args.push_back(VisitExpr(arg)); + } + return new_args; + } + + private: + // Topologically sort bindings according to the group dependency relations. + Array TopoSortByGroupDep(const Array& bindings) { + std::unordered_map> bindings_per_group; + // The order to visit groups should respect the original order of bindings as much as possible. + std::vector group_order; + for (const auto& binding : bindings) { + auto g = GetGroupFromBinding(binding); + group_order.push_back(g); // Duplication does not matter since each group is visited once. + bindings_per_group[g].push_back(binding); + } + + std::unordered_set visited; + + std::function)> dfs_visit; + dfs_visit = [this, &visited, &dfs_visit](Group* g, auto leaf_fun) { + if (!visited.count(g)) { + visited.insert(g); + for (auto dep : group_deps_[g]) { + dfs_visit(dep, leaf_fun); + } + leaf_fun(g); + } + }; + + Array sorted; + + for (auto g : group_order) { + dfs_visit(g, [&sorted, &bindings_per_group](Group* leaf) { + for (const auto& binding : bindings_per_group[leaf]) { + sorted.push_back(binding); + } + }); + } + + return sorted; + } + + /*! \brief The IRModule. */ + IRModule mod_; + /*! \brief Internal arena. */ + support::Arena arena_; + /*! \brief The group assignment map. */ + GroupMap obj2group_; + /*! \brief Internal function information map. */ + std::unordered_map group2func_; + /*! \brief Record the index for TupleGetItem if the variable needs to be remapped to an output + * tuple element after fusion. */ + std::unordered_map tuple_get_indices_; + /*! \brief A map from a group to its dependent groups, used to detect cyclic dependencies. */ + std::unordered_map> group_deps_; + /*! \brief Whether or not to lift bound constants to parameters of the grouped function. */ + bool lift_constants_{true}; +}; + +IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) { + support::Arena arena; + + // Step 1. Create the indexed-forward graph according to the input IRModule. + IndexedForwardGraph graph = GraphCreator::Create(mod, &arena); + + // Step 2. Partition the graph by applying the fusion algorithm. + std::vector groups = + GraphPartitioner(&arena, opt_level, max_fuse_depth).Partition(graph); + + // Step 3. Transform the IRModule by fusing the operators in accordance with the graph partition + // results. + return OperatorFusor(mod, graph, groups, /*lift_constants*/ true).Transform(); +} + +namespace transform { + +Pass FuseOps(int fuse_opt_level) { + runtime::TypedPackedFunc pass_func = // + [=](IRModule m, PassContext pc) { + int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; + auto max_fuse_depth = pc->GetConfig("relax.FuseOps.max_depth", Integer(kMaxFusedOps)); + return relax::FuseOps(m, opt_level, max_fuse_depth.value().IntValue()); + }; + return CreateModulePass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"FuseOps", // + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc new file mode 100644 index 000000000000..fa5c296d278e --- /dev/null +++ b/src/relax/transform/fuse_tir.cc @@ -0,0 +1,728 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include + +#include "../../relay/analysis/graph_partitioner.h" +#include "../../support/arena.h" +#include "../../tir/ir/functor_common.h" + +namespace tvm { +namespace tir { + +// TODO(Siyuan): move it to somewhere under tir folder +/*! + * \brief Substitute a given source buffer with a given target buffer in statements or expressions. + */ +class FuseTIRBufferSubstitor : private StmtExprMutator { + public: + static Stmt Substitute(const Map& buffer_map, Stmt stmt) { + return FuseTIRBufferSubstitor(buffer_map)(std::move(stmt)); + } + + private: + explicit FuseTIRBufferSubstitor(const Map& buffer_map) { + for (const auto& kv : buffer_map) { + const Buffer& src = kv.first; + const Buffer& tgt = kv.second; + buffer_var_map_[src->data.get()] = tgt; + } + } + + PrimExpr VisitExpr_(const VarNode* _op) final { + auto it = buffer_var_map_.find(_op); + if (it != buffer_var_map_.end()) { + return it->second->data; + } else { + return GetRef(_op); + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* _op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_op)); + auto it = buffer_var_map_.find(load->buffer->data.get()); + if (it != buffer_var_map_.end()) { + auto n = make_object(*load.get()); + n->buffer = it->second; + return BufferLoad(n); + } else { + return std::move(load); + } + } + + Stmt VisitStmt_(const BufferStoreNode* _op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_op)); + auto it = buffer_var_map_.find(store->buffer->data.get()); + if (it != buffer_var_map_.end()) { + auto n = CopyOnWrite(store.get()); + n->buffer = it->second; + return BufferStore(n); + } else { + return std::move(store); + } + } + + PrimExpr VisitExpr_(const LoadNode* _op) final { + Load load = Downcast(StmtExprMutator::VisitExpr_(_op)); + auto it = buffer_var_map_.find(load->buffer_var.get()); + if (it != buffer_var_map_.end()) { + auto n = make_object(*load.get()); + n->buffer_var = it->second->data; + return Load(n); + } else { + return std::move(load); + } + } + + Stmt VisitStmt_(const StoreNode* _op) final { + Store store = Downcast(StmtExprMutator::VisitStmt_(_op)); + auto it = buffer_var_map_.find(store->buffer_var.get()); + if (it != buffer_var_map_.end()) { + auto n = CopyOnWrite(store.get()); + n->buffer_var = it->second->data; + return Store(n); + } else { + return std::move(store); + } + } + + Stmt VisitStmt_(const BlockNode* _op) final { + Block block = Downcast(StmtMutator::VisitStmt_(_op)); + + // Define the mutation functions. + auto f_mutate_match_buffers = [this](const MatchBufferRegion& match_buffer) { + const Buffer& src_buffer = match_buffer->source->buffer; + auto it = buffer_var_map_.find(src_buffer->data.get()); + if (it != buffer_var_map_.end()) { + return MatchBufferRegion(match_buffer->buffer, + BufferRegion(it->second, match_buffer->source->region)); + } else { + return match_buffer; + } + }; + + auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) { + auto it = buffer_var_map_.find(buffer_region->buffer->data.get()); + return it == buffer_var_map_.end() ? buffer_region + : BufferRegion(it->second, buffer_region->region); + }; + + // Step 1. Mutate `match_buffers`. + Array match_buffers = + MutateArray(block->match_buffers, f_mutate_match_buffers); + // Step 2. Mutate the read/write region. + Array reads = MutateArray(block->reads, f_mutate_read_write_region); + Array writes = MutateArray(block->writes, f_mutate_read_write_region); + + reads = UnionAccessRegion(reads); + writes = UnionAccessRegion(writes); + + if (reads.same_as(block->reads) && // + writes.same_as(block->writes) && // + match_buffers.same_as(block->match_buffers)) { + return std::move(block); + } else { + auto n = CopyOnWrite(block.get()); + n->reads = std::move(reads); + n->writes = std::move(writes); + n->match_buffers = std::move(match_buffers); + return Block(n); + } + } + + private: + /*! \brief Mapping from src buffer.data to tgt buffer. */ + std::unordered_map buffer_var_map_; + /*! \brief The structural equality checker */ + StructuralEqual structural_equal_; + + Array UnionAccessRegion(const Array& regions) const { + // For now we only allow Buffer access the same elements. + // e.g. `[A[vi, vj], A[vi, vj]]` is a legal pattern but need to union to `A[vi, vj]` + // However, `A[vi, vj], A[vi, vj + 1]` is not allow for now. + // Note: the order of return region should remain the same as the first occurance of the region + Array ret; + std::unordered_map buffer_region_set; + + for (const BufferRegion& region : regions) { + auto it = buffer_region_set.find(region->buffer.get()); + if (it == buffer_region_set.end()) { + ret.push_back(region); + buffer_region_set[region->buffer.get()] = region->region; + } else { + ICHECK(structural_equal_(region->region, it->second)); + } + } + + if (ret.size() == regions.size()) { + return regions; + } else { + return ret; + } + } +}; + +/*! \brief A mutator which detect block name duplication and deduplicate the names. */ +class BlockNameDeduplicator : public tir::StmtMutator { + private: + Stmt VisitStmt_(const BlockNode* op) final { + Block block = Downcast(tir::StmtMutator::VisitStmt_(op)); + + String name = GetUniqueName(block->name_hint); + + if (name == block->name_hint) { + return std::move(block); + } else { + ObjectPtr n = CopyOnWrite(block.get()); + n->name_hint = std::move(name); + return Stmt(n); + } + } + + String GetUniqueName(const String& prefix) { + String unique_prefix = prefix; + auto it = name_count_.find(prefix); + while (name_count_.count(unique_prefix)) { + unique_prefix = prefix + "_" + std::to_string(++it->second); + } + name_count_[unique_prefix] = 0; + return unique_prefix; + } + + // TODO(relax-team): It should detects the number suffix and do renaming properly + // e.g. GetUniqueName("name1") should return "name2" instead of "name10". + /*! \brief The count map to make block name unique. */ + std::unordered_map name_count_; +}; + +} // namespace tir + +namespace relax { + +class FusedTIRConstructor : public ExprVisitor { + public: + /*! + * \brief Construct a fused TIR PrimFunc from a relax sub-function + * \param mod The IRModule + * \param gv The global var of relax subfunction to be fused into one PrimFunc + * \return The fused TIR PrimFunc + */ + static tir::PrimFunc GetFusedTIR(const IRModule& mod, const GlobalVar& gv) { + FusedTIRConstructor visitor(mod, gv->name_hint); + BaseFunc f = mod->Lookup(gv); + CHECK(f->IsInstance()) + << "Expected relax functions, but got: " << f->GetTypeKey(); + CHECK(f->HasNonzeroAttr(relax::attr::kPrimitive)) + << "Expected a function with attr `kPrimitive`"; + visitor(Downcast(f)); + return visitor.fused_tir_; + } + + private: + explicit FusedTIRConstructor(const IRModule& mod, const String& func_name) + : mod_(mod), func_name_(func_name) {} + + void VisitExpr_(const FunctionNode* func) final { + // Step 1. Create buffers for function params + for (const Var& relax_param : func->params) { + auto ret = CreateParamsAndBuffers(GetStructInfo(relax_param), // + relax_param->name_hint()); + const Array& params = ret.first; + const Array& buffers = ret.second; + ICHECK_EQ(params.size(), buffers.size()); + for (size_t i = 0; i < params.size(); ++i) { + func_info_.buffer_map.Set(params[i], buffers[i]); + func_info_.params.push_back(params[i]); + } + func_info_.expr2buffers.Set(relax_param, buffers); + } + + // Step 2. Visit Function body and create intermediate buffers + ExprVisitor::VisitExpr_(func); + + // Step 3. Create and remap buffers for function output + ICHECK(func->body->IsInstance()) + << "Function body is expected to be a SeqExpr, but got: " << func->body->GetTypeKey(); + Expr body = Downcast(func->body)->body; + auto it = func_info_.expr2buffers.find(body); + ICHECK(it != func_info_.expr2buffers.end()) + << "Fail to detect output buffers for function body"; + const Array& buffers = (*it).second; + for (size_t i = 0; i < buffers.size(); ++i) { + tir::Var param = tir::Var("p_output" + std::to_string(i), PrimType(DataType::Handle())); + func_info_.buffer_map.Set(param, buffers[i]); + func_info_.params.push_back(param); + func_info_.output_buffers.insert(buffers[i].get()); + } + + // Step 4. Create PrimFunc + fused_tir_ = ConstructFunc(); + } + + void VisitBinding_(const VarBindingNode* binding) final { + // Update expr2buffers by visiting values. + this->VisitExpr(binding->value); + auto it = func_info_.expr2buffers.find(binding->value); + if (it != func_info_.expr2buffers.end()) { + // assign binding var to the buffers of the value + func_info_.expr2buffers.Set(binding->var, (*it).second); + } else { + LOG(FATAL) << "Unsupported binding value: " << binding->value; + } + } + + void VisitBinding_(const MatchCastNode* match_cast) final { + LOG(FATAL) << "MatchCast is unsupported in primitive functions"; + } + + void VisitExpr_(const CallNode* call) final { + ExprVisitor::VisitExpr_(call); + static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + ICHECK(call->op == call_tir_op_) + << "Only call_tir is supported in primitive function, but got: " << GetRef(call); + + // Step 1. Get Global var and PrimFunc + GlobalVar gv = Downcast(call->args[0]); + Optional prim_func_ = GetPrimFunc(gv); + ICHECK(prim_func_.defined()) << "Cannot find the prim_func of the call_tir in the module: " + << gv; + // Step 2. Renew all vars/buffer definitions and blocks to avoid duplication + tir::PrimFunc prim_func = tir::RenewDefs(prim_func_.value()); + + // Step 3. Check functions are all schedulable funcs. i.e. the body of func is root block + // TODO(Siyuan): support un-schedulable functions. + ICHECK(prim_func->body->IsInstance()) + << "Only schedulable functions (whose body is the root block) can be fused"; + const tir::BlockRealize& root_realize = Downcast(prim_func->body); + const tir::Block& root_block = root_realize->block; + + // Step 4. Add all the original alloc_buffers and body to the fused function. + func_info_.alloc_buffers.insert(func_info_.alloc_buffers.end(), + root_block->alloc_buffers.begin(), + root_block->alloc_buffers.end()); + func_info_.bodies.push_back(root_block->body); + + // Step 5. Map input arguments to buffer + MapInputBuffer(prim_func, call->args[1]); + size_t num_output_buffers = GetCallTIROutputSize(call); + AllocateIntermediateBuffer(GetRef(call), prim_func, num_output_buffers); + // Update fused func name + func_info_.global_name += "_" + gv->name_hint; + } + + void VisitExpr_(const TupleGetItemNode* tuple_get_item) final { + ExprVisitor::VisitExpr_(tuple_get_item); + auto it = func_info_.expr2buffers.find(tuple_get_item->tuple); + if (it != func_info_.expr2buffers.end()) { + int begin_buf_idx = 0; + int end_buf_idx = 0; + const TupleType& tuple_type = Downcast(tuple_get_item->tuple->checked_type()); + for (int i = 0; i < tuple_get_item->index; ++i) { + begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]); + } + end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_type->fields[tuple_get_item->index]); + func_info_.expr2buffers.Set( + GetRef(tuple_get_item), + {(*it).second.begin() + begin_buf_idx, (*it).second.begin() + end_buf_idx}); + } + } + + void VisitExpr_(const TupleNode* tuple) final { + ExprVisitor::VisitExpr_(tuple); + Array buffers; + for (const Expr& expr : tuple->fields) { + auto it = func_info_.expr2buffers.find(expr); + if (it != func_info_.expr2buffers.end()) { + buffers.insert(buffers.end(), (*it).second.begin(), (*it).second.end()); + } + } + if (!buffers.empty()) { + func_info_.expr2buffers.Set(GetRef(tuple), buffers); + } + } + + void VisitExpr_(const ConstantNode* op) final { + LOG(FATAL) << "Relax.Constant is not supported in primitive functions."; + } + + /********** Helper Functions **********/ + + /*! + * \brief Pattern match op to a TIR function and look it up. + * \return The TIR function, or NullOpt if patter match fails. + */ + Optional GetPrimFunc(const GlobalVar& global_var) { + // NOTE: as check works for nullptr(returns null) + Optional base_func = mod_->functions.Get(global_var); + if (auto* pfunc = base_func.as()) { + return GetRef(pfunc); + } else { + return NullOpt; + } + } + + /*! + * \brief Get the number of outputs for a call_tir node. + * \return The number of outputs. + */ + static size_t GetCallTIROutputSize(const CallNode* call) { + static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + ICHECK(call->op.same_as(call_tir_op_)); + ICHECK_EQ(call->sinfo_args.size(), 1); + if (const auto* tuple_sinfo = call->sinfo_args[0].as()) { + return tuple_sinfo->fields.size(); + } else { + return 1; + } + } + + /*! \brief Map old TIR func param buffer to new buffer, and then update `buffer_subst_map` */ + void MapArgsToBuffer(const Array args, const Array& buffers) { + size_t buffer_idx = 0; + for (const Expr& arg : args) { + if (const auto* v = arg.as()) { + auto it = func_info_.expr2buffers.find(GetRef(v)); + // Substitute the buffer with the already allocated one if it is an intermediate var + if (it != func_info_.expr2buffers.end()) { + for (const tir::Buffer& target_buffer : (*it).second) { + ICHECK_LT(buffer_idx, buffers.size()); + const tir::Buffer& buffer = buffers[buffer_idx]; + // TODO(relax-team): Add support for symbolic shape fusion + for (const PrimExpr& shape_expr : buffer->shape) { + ICHECK(shape_expr.as()) << "Only support constant shape fusion for now"; + } + func_info_.buffer_subst_map.Set(buffer, target_buffer); + buffer_idx++; + } + } + } + } + // Make sure every buffers are maped. + ICHECK_EQ(buffer_idx, buffers.size()); + } + + /*! + * \brief Update buffer mapping `func_info_.buffer_subst_map` for input args + * \param func The old TIR PrimFunc + * \param output_size The number of output params. All output params are at the end of param list. + */ + void MapInputBuffer(const tir::PrimFunc& func, const relax::Expr& args) { + Array arg_list; + Array buffer_list; + if (const auto* arg_tuple = args.as()) { + arg_list = arg_tuple->fields; + } else { + arg_list = {args}; + } + + ICHECK_GE(func->params.size(), arg_list.size()); + for (size_t i = 0; i < arg_list.size(); ++i) { + const tir::Var& param = func->params[i]; + const tir::Buffer& buffer = func->buffer_map.at(param); + buffer_list.push_back(buffer); + } + + MapArgsToBuffer(arg_list, buffer_list); + } + + /*! + * \brief Allocate buffer(s) and update `func_info.expr2buffers` if the PrimFunc output(s) are + * intermediate results. + * \param expr The relax Expr, which can be binding vars or binding values. + * \param func The old TIR PrimFunc + * \param output_size The number of output params. All output params are at the end of param list. + */ + void AllocateIntermediateBuffer(const Expr& expr, const tir::PrimFunc& func, size_t output_size) { + size_t n = func->params.size(); + ICHECK_GE(n, output_size); + // Allocate intermediate buffer + Array alloc_buffers; + for (size_t i = 0; i < output_size; ++i) { + const tir::Var& param = func->params[n - output_size + i]; + const tir::Buffer& buffer = func->buffer_map.at(param); + func_info_.alloc_buffers.push_back(buffer); + alloc_buffers.push_back(buffer); + } + // Update expr2buffers + func_info_.expr2buffers.Set(expr, alloc_buffers); + } + + /*! + * \brief Create an TIR func params and buffers with specified relax type and shape + * \param struct_info The struct info + * \param name_hint The name hint for params and buffers + * \param index The index used for unique name_hint if type is Tuple. + * -1 means no need to add postfix since the relax param is not a Tuple. + * \return The created TIR func params and buffers + */ + static std::pair, Array> CreateParamsAndBuffers( + StructInfo struct_info, const String& name_hint, int index = -1) { + Array params; + Array buffers; + if (const auto* tensor = struct_info.as()) { + // Case 1. the relax param is a DynTensor, we directly create a tir var and buffer + const auto* shape_expr = tensor->shape.as(); + ICHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with symbolic shape."; + + String name = index == -1 ? name_hint : name_hint + "_" + std::to_string(index); + DataType dtype = tensor->dtype; + tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name); + // Differentiate buffer name and param name by adding prefix `v_` to param + // Every symbol should be unique in TVMScript, and Buffer is used more than param + // So we decide to make sure buffer names have better readability. + tir::Var param = tir::Var("p_" + name, PrimType(DataType::Handle())); + params.push_back(std::move(param)); + buffers.push_back(std::move(buffer)); + } else if (const auto* tuple = struct_info.as()) { + // Case 2. the relax param is a Tuple, we recursively visit each field until it's a DynTensor + // Enable postfix + if (index == -1) index = 0; + for (size_t i = 0; i < tuple->fields.size(); ++i) { + auto ret = CreateParamsAndBuffers(tuple->fields[i], name_hint, index); + const Array& ret_params = ret.first; + const Array& ret_buffers = ret.second; + ICHECK_EQ(ret_params.size(), ret_buffers.size()); + // Adding tuple field results to the end of params and buffers. + params.insert(params.end(), ret_params.begin(), ret_params.end()); + buffers.insert(buffers.end(), ret_buffers.begin(), ret_buffers.end()); + index += ret_params.size(); + } + } else { + ICHECK(false) << "shapes are expected to be ShapeExprNode or TupleNode"; + } + return std::make_pair(params, buffers); + } + + /*! + * \brief Construct fused TIR func with collected FuseFuncInfo + * \return The fused TIR + */ + tir::PrimFunc ConstructFunc() { + Map attr_map; + attr_map.Set("tir.noalias", tir::const_true()); + ICHECK(func_info_.global_name != "fused"); + // Remove output buffers from func_info_.alloc_buffers + Array alloc_buffers; + for (const tir::Buffer& buf : func_info_.alloc_buffers) { + if (func_info_.output_buffers.count(buf.get()) == 0) { + alloc_buffers.push_back(buf); + } + } + tir::Stmt body = tir::BlockNameDeduplicator()(tir::SeqStmt::Flatten(func_info_.bodies)); + body = tir::FuseTIRBufferSubstitor::Substitute(func_info_.buffer_subst_map, body); + body = tir::Block({}, {}, {}, "root", std::move(body), NullOpt, alloc_buffers); + body = tir::BlockRealize({}, Bool(true), Downcast(body)); + tir::PrimFunc func(func_info_.params, body, VoidType(), func_info_.buffer_map, + DictAttrs(attr_map)); + return func; + } + + /*! \brief Get DynTensor numbers from recursive Tuples. */ + static size_t GetTotalTensorSize(const Type& type) { + if (type.as()) { + return 1; + } else if (const auto* tuple_type = type.as()) { + size_t num = 0; + for (const Type& type : tuple_type->fields) { + num += GetTotalTensorSize(type); + } + return num; + } else { + LOG(FATAL) << "DynTensorType and TupleType are expect, but got: " << type; + return 0; + } + } + + /********** Function Info **********/ + + /*! \brief auxiliary information for FuseTIR */ + struct FuseFuncInfo { + /*! \brief The arguments for calling prim_func */ + Array arguments; + /*! + * \brief The map from each dataflow var (intermediate var) to the corresponding buffers + * allocated in the fused func + */ + Map> expr2buffers; + /*! \brief The buffers to allocate in the fused func*/ + Array alloc_buffers; + /*! \brief The bodies of the original funcs, which is also the body of the fused func. */ + Array bodies; + /*! \brief The params of the fused function*/ + Array params; + /*! + * \brief The map from buffer in original functions to corresponding buffer in the fused + * function + */ + Map buffer_subst_map; + /*! \brief The `buffer_map` in the fused function*/ + Map buffer_map; + /*! \brief The output buffers in the function buffer_map*/ + std::unordered_set output_buffers; + /*! \brief The name of the fused function */ + std::string global_name = "fused"; + }; + + /*! \brief The IRModule */ + const IRModule& mod_; + /*! \brief The name hint for the input func. */ + String func_name_; + /*! \brief The helper info to fuse TIR prim_func */ + FuseFuncInfo func_info_; + /*! \brief The tir function after fusion*/ + tir::PrimFunc fused_tir_; +}; + +/*! + * \brief The helper class to fuse TIR functions and build a new module which calls the fused TIR. + */ +class TIRFuseMutator : public ExprMutator { + public: + static IRModule Transform(const IRModule& mod) { + // Since TIRFuseMutator will delete bunch of PrimFunc, we create an empty block builder. + TIRFuseMutator mutator(mod); + // Step 1. Fuse all primitive relax functions, store the result in `fused_tir_funcs_` + for (const auto& kv : mod->functions) { + const GlobalVar& gv = kv.first; + const BaseFunc& func = kv.second; + // Only fuse primitive relax functions + if (func->IsInstance() && func->HasNonzeroAttr(attr::kPrimitive)) { + tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(mod, gv); + mutator.fused_tir_funcs_.Set(gv, fused_tir); + } + } + + // Step 2. Update all non-primitive relax functions and add it, with the dependent function, + // into the new IRModule + for (const auto& kv : mod->functions) { + const GlobalVar& gv = kv.first; + const BaseFunc& func = kv.second; + if (func->IsInstance() && !func->HasNonzeroAttr(attr::kPrimitive)) { + relax::Function update_func = Downcast(mutator.VisitExpr(func)); + mutator.builder_->AddFunction(update_func, gv->name_hint); + } + } + return mutator.builder_->GetContextIRModule(); + } + + private: + explicit TIRFuseMutator(const IRModule& mod) : mod_(mod) {} + + using ExprMutator::VisitExpr_; + + // Get shape from call tir + static Expr GetCallTIRShape(StructInfo sinfo) { + if (auto* tuple = sinfo.as()) { + Array fields = tuple->fields.Map([&](StructInfo x) { return GetCallTIRShape(x); }); + return Tuple(fields); + } else { + auto* tensor = sinfo.as(); + ICHECK(tensor) << "FuseTIR can only take tensor or tuple type"; + auto* shape_expr = tensor->shape.as(); + ICHECK(shape_expr) << "FuseTIR requires all intermediate values have shape"; + return GetRef(shape_expr); + } + } + + Expr VisitExpr_(const CallNode* op) final { + static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + Call call = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(op))); + + if (call->op->IsInstance()) { + // Case 1. It is a relax cross function call + GlobalVar old_gv = Downcast(call->op); + auto it = fused_tir_funcs_.find(old_gv); + if (it != fused_tir_funcs_.end()) { + const tir::PrimFunc& fused_tir = (*it).second; + // Case 1.1. It calls a primitive relax function, update the call into a call_tir + GlobalVar fused_tir_gv = this->builder_->AddFunction(fused_tir, old_gv->name_hint); + // Step a. Flatten all args since call_tir does not support Tuple value. + Array arg_list; + for (const Expr& arg : call->args) { + Array flattened = FlattenArg(arg); + arg_list.insert(arg_list.end(), flattened.begin(), flattened.end()); + } + // Step b. Create call_tir + Array call_args = {fused_tir_gv, Tuple(arg_list)}; + return Call(call_tir_op_, call_args, call->attrs, {GetStructInfo(call)}); + } else { + // Case 1.2. The callee function is not primitive, nothing to do. + return call; + } + } else if (call->op == call_tir_op_) { + // Case 2. It is a call_tir, re-emit the PrimFunc. + GlobalVar gv = Downcast(call->args[0]); + tir::PrimFunc func = Downcast(mod_->Lookup(gv)); + GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint); + return Call(call->op, {new_gv, call->args[1]}, call->attrs, call->sinfo_args, call->span); + } else { + // Case 3. CallNode in other types. Leave it as it is. + return call; + } + } + + /********** Helper Functions **********/ + + /*! \brief Flatten the call args if it's Tuple by emitting `TupleGetItem`. */ + Array FlattenArg(const Expr& arg) { + if (const auto* tuple_sinfo = GetStructInfoAs(arg)) { + Array arg_list; + for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { + Expr new_arg = builder_->Emit(TupleGetItem(arg, i)); + Array flattened = FlattenArg(new_arg); + arg_list.insert(arg_list.end(), flattened.begin(), flattened.end()); + } + return arg_list; + } else { + return {arg}; + } + } + + private: + /*! \brief The IRModule */ + const IRModule& mod_; + /*! \brief The map from global var of primitive relax function to generated prim func. */ + Map fused_tir_funcs_; +}; + +IRModule FuseTIR(IRModule mod) { + mod = TIRFuseMutator::Transform(mod); + return mod; +} + +namespace transform { + +Pass FuseTIR() { + runtime::TypedPackedFunc pass_func = // + [=](IRModule m, PassContext pc) { return relax::FuseTIR(m); }; + return CreateModulePass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"FuseTIR", // + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_transform_annotate_tir_op_pattern.py b/tests/python/relax/test_transform_annotate_tir_op_pattern.py new file mode 100644 index 000000000000..73c65378693a --- /dev/null +++ b/tests/python/relax/test_transform_annotate_tir_op_pattern.py @@ -0,0 +1,360 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import enum + +import tvm +import tvm.script +import tvm.testing +from tvm import relax +from tvm.script import tir as T + + +class OpPatternKind(enum.IntEnum): + kElemWise = 0 + kBroadcast = 1 + kInjective = 2 + kCommReduce = 3 + kOutEWiseFusable = 4 + kTuple = 7 + kOpaque = 8 + + +def test_annotate_opkind_outewisefusable(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.var("int32") + n = T.var("int32") + k = T.var("int32") + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["tir_matmul"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable + + +def test_annotate_opkind_outewisefusable_int_var_signature(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle, m: T.int64, n: T.int64, k: T.int64): + T.func_attr({"global_symbol": "tir_matmul"}) + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["tir_matmul"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable + + +def test_annotate_opkind_reduce(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def sum(x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "elemwise"}) + A = T.match_buffer(x, (16, 16)) + B = T.match_buffer(y, (16,)) + + for i, j in T.grid(16, 16): + with T.block("matmul"): + vi, vj = T.axis.remap("SR", [i, j]) + with T.init(): + B[vi] = 0.0 + B[vi] += A[vi, vj] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["sum"].attrs["op_pattern"] == OpPatternKind.kCommReduce + + +def test_annotate_opkind_ewise(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def elemwise(x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "elemwise"}) + A = T.match_buffer(x, (16, 16)) + B = T.match_buffer(y, (16, 16)) + + for i, j in T.grid(16, 16): + with T.block("matmul"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + 1.0 + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["elemwise"].attrs["op_pattern"] == OpPatternKind.kElemWise + + +def test_annotate_opkind_broadcast(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def broadcast(x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "elemwise"}) + A = T.match_buffer(x, (16, 16)) + B = T.match_buffer(y, (16, 16, 16, 16)) + + for i0, j0, i1, j1 in T.grid(16, 16, 16, 16): + with T.block("matmul"): + vi0, vj0, vi1, vj1 = T.axis.remap("SSSS", [i0, j0, i1, j1]) + B[vi0, vj0, vi1, vj1] = A[vj0, vj1] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["broadcast"].attrs["op_pattern"] == OpPatternKind.kBroadcast + + +def test_annotate_opkind_injective(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def injective(x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "elemwise"}) + A = T.match_buffer(x, (4, 4, 4, 4)) + B = T.match_buffer(y, (16, 16)) + + for i, j in T.grid(16, 16): + with T.block("matmul"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi // 4, vj // 4, vi % 4, vj % 4] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["injective"].attrs["op_pattern"] == OpPatternKind.kInjective + + +def test_annotate_opkind_bias_add(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_bias_add( + A: T.Buffer((1, 1000), "float32"), + B: T.Buffer((1000,), "float32"), + C: T.Buffer((1, 1000), "float32"), + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "tir_bias_add", "tir.noalias": True}) + # body + # with T.block("root") + for i0, i1 in T.grid(1, 1000): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[ax0, ax1], B[ax1]) + T.writes(C[ax0, ax1]) + C[ax0, ax1] = A[ax0, ax1] + B[ax1] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["tir_bias_add"].attrs["op_pattern"] == OpPatternKind.kElemWise + + +def test_annotate_opkind_add_broadcast_with_unit_shape(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def add_with_unit_dim_len_broadcast( + A: T.Buffer((1, 64, 112, 112), "float32"), + B: T.Buffer((64, 1, 1), "float32"), + C: T.Buffer((1, 64, 112, 112), "float32"), + ) -> None: + T.func_attr({"global_symbol": "add5", "tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(1, 64, 112, 112): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(A[ax0, ax1, ax2, ax3], B[ax1, 0, 0]) + T.writes(C[ax0, ax1, ax2, ax3]) + C[ax0, ax1, ax2, ax3] = A[ax0, ax1, ax2, ax3] + B[ax1, 0, 0] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["add_with_unit_dim_len_broadcast"].attrs["op_pattern"] == OpPatternKind.kElemWise + + +def test_annotate_opkind_add_zero_dim_element_wise(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def add_zero_dim( + A: T.Buffer((128,), "float32"), + B: T.Buffer((), "float32"), + C: T.Buffer((128,), "float32"), + ) -> None: + T.func_attr({"global_symbol": "add8", "tir.noalias": True}) + for i0 in T.serial(128): + with T.block("T_add"): + ax0 = T.axis.spatial(128, i0) + T.reads(A[ax0], B[()]) + T.writes(C[ax0]) + C[ax0] = A[ax0] + B[()] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["add_zero_dim"].attrs["op_pattern"] == OpPatternKind.kElemWise + + +def test_annotate_opkind_pooling(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def max_pool2d( + rxplaceholder_1: T.Buffer((1, 64, 112, 112), "float32"), + tensor_1: T.Buffer((1, 64, 56, 56), "float32"), + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "max_pool2d", "T.noalias": True}) + # body + # with T.block("root") + pad_temp_1 = T.alloc_buffer([1, 64, 114, 114], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 64, 114, 114): + with T.block("pad_temp"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1]) + T.writes(pad_temp_1[ax0, ax1, ax2, ax3]) + pad_temp_1[ax0, ax1, ax2, ax3] = T.if_then_else( + 1 <= ax2 and ax2 < 113 and 1 <= ax3 and ax3 < 113, + rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1], + T.float32(-3.4028234663852886e38), + dtype="float32", + ) + for i0, i1, i2, i3, i4, i5 in T.grid(1, 64, 56, 56, 3, 3): + with T.block("tensor"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads( + tensor_1[ax0, ax1, ax2, ax3], + pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1], + ) + T.writes(tensor_1[ax0, ax1, ax2, ax3]) + with T.init(): + tensor_1[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e38) + tensor_1[ax0, ax1, ax2, ax3] = T.max( + tensor_1[ax0, ax1, ax2, ax3], + pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1], + ) + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["max_pool2d"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable + + +def test_annotate_opkind_softmax(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def softmax( + rxplaceholder_1: T.Buffer((16, 16), "float32"), + T_softmax_norm_1: T.Buffer((16, 16), "float32"), + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "softmax", "T.noalias": True}) + # body + # with T.block("root") + T_softmax_maxelem_1 = T.alloc_buffer([16], dtype="float32") + T_softmax_exp_1 = T.alloc_buffer([16, 16], dtype="float32") + T_softmax_expsum_1 = T.alloc_buffer([16], dtype="float32") + for i0_7, i1_3 in T.grid(16, 16): + with T.block("T_softmax_maxelem"): + i0_8, k = T.axis.remap("SR", [i0_7, i1_3]) + T.reads(T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k]) + T.writes(T_softmax_maxelem_1[i0_8]) + with T.init(): + T_softmax_maxelem_1[i0_8] = T.float32(-3.4028234663852886e38) + T_softmax_maxelem_1[i0_8] = T.max( + T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k] + ) + for i0_9, i1_4 in T.grid(16, 16): + with T.block("T_softmax_exp"): + i0_10, i1_5 = T.axis.remap("SS", [i0_9, i1_4]) + T.reads(rxplaceholder_1[i0_10, i1_5], T_softmax_maxelem_1[i0_10]) + T.writes(T_softmax_exp_1[i0_10, i1_5]) + T_softmax_exp_1[i0_10, i1_5] = T.exp( + rxplaceholder_1[i0_10, i1_5] - T_softmax_maxelem_1[i0_10], dtype="float32" + ) + for i0_11, i1_6 in T.grid(16, 16): + with T.block("T_softmax_expsum"): + i0_12, k = T.axis.remap("SR", [i0_11, i1_6]) + T.reads(T_softmax_expsum_1[i0_12], T_softmax_exp_1[i0_12, k]) + T.writes(T_softmax_expsum_1[i0_12]) + with T.init(): + T_softmax_expsum_1[i0_12] = T.float32(0) + T_softmax_expsum_1[i0_12] = ( + T_softmax_expsum_1[i0_12] + T_softmax_exp_1[i0_12, k] + ) + for i0_13, i1_7 in T.grid(16, 16): + with T.block("T_softmax_norm"): + i0_14, i1_8 = T.axis.remap("SS", [i0_13, i1_7]) + T.reads(T_softmax_exp_1[i0_14, i1_8], T_softmax_expsum_1[i0_14]) + T.writes(T_softmax_norm_1[i0_14, i1_8]) + T.block_attr({"axis": 1}) + T_softmax_norm_1[i0_14, i1_8] = ( + T_softmax_exp_1[i0_14, i1_8] / T_softmax_expsum_1[i0_14] + ) + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["softmax"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable + + +def test_multiple_bufer_stores_fallback(): + @tvm.script.ir_module + class CumsumModule: + @T.prim_func + def cumsum(var_rxplaceholder: T.handle, out_buf: T.Buffer(160, "float32")): + rxplaceholder = T.match_buffer( + var_rxplaceholder, [10, 16], dtype="float32", offset_factor=1 + ) + with T.block("cumsum_generic"): + T.reads(rxplaceholder[0:10, 0:16]) + T.writes(out_buf[0:160]) + for fused in T.parallel(1): + out_buf[fused * 160] = rxplaceholder[fused * 160 // 16, fused * 160 % 16] + for v_k in T.serial(159): + out_buf[fused * 160 + (v_k + 1)] = ( + out_buf[fused * 160 + (v_k + 1 - 1)] + + rxplaceholder[ + (fused * 160 + (v_k + 1)) // 16, + (fused * 160 + (v_k + 1)) % 16, + ] + ) + + mod = CumsumModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["cumsum"].attrs["op_pattern"] == OpPatternKind.kOpaque + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py new file mode 100644 index 000000000000..1a228bb268fa --- /dev/null +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -0,0 +1,759 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm import relax, topi +from tvm.script import relax as R + + +def _check(mod_actual, mod_expected): + mod_actual = relax.transform.AnnotateTIROpPattern()(mod_actual) + mod_actual = relax.transform.FuseOps()(mod_actual) + mod_expected = relax.transform.AnnotateTIROpPattern()(mod_expected) + tvm.ir.assert_structural_equal(mod_actual, mod_expected) + + +def test_fuse_simple(): + """Simple testcase.""" + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze") + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output( + relax.Call(fused_add_exp_squeeze, [x, relax.const(1, "float32")]) + ) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_conv2d_fuse(): + """Test fusion case of conv2d""" + + def before(dtype): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype)) + w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype)) + w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype)) + with bb.function("main", [x, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype)) + lv1 = bb.emit_te(topi.nn.conv2d, lv0, w1, strides=1, padding=1, dilation=1) + # this is the next dominator. + lv2 = bb.emit_te(topi.add, relax.const(1, dtype), lv1) + lv3 = bb.emit_te(topi.add, lv1, lv2) + # second path + lv4 = bb.emit_te(topi.nn.conv2d, lv3, w2, strides=1, padding=0, dilation=1) + lv5 = bb.emit_te(topi.nn.conv2d, lv3, w3, strides=1, padding=1, dilation=1) + gv = bb.emit_output(bb.call_te(topi.add, lv4, lv5)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(dtype): + bb = relax.BlockBuilder() + + # Grouped function 1 + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w = relax.Var("w", R.Tensor((16, 16, 3, 3), dtype)) + p0 = relax.Var("p0", R.Tensor((), dtype)) + with bb.function("fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=1, + dilation=1, + primfunc_name_hint="conv2d", + ) + lv1 = bb.emit_te(topi.add, p0, lv0, primfunc_name_hint="add1") + gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1, primfunc_name_hint="add2")) + bb.emit_func_output(gv) + + # Grouped function 2 + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w = relax.Var("w", R.Tensor((16, 16, 1, 1), dtype)) + y = relax.Var("y", R.Tensor((1, 16, 64, 64), dtype)) + with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=0, + dilation=1, + primfunc_name_hint="conv2d1", + ) + gv = bb.emit_output(bb.call_te(topi.add, lv0, y, primfunc_name_hint="add2")) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + mod = bb.get() + fused_conv2d_add1_add2 = mod.get_global_var("fused_conv2d_add1_add2") + fused_conv2d1_add2 = mod.get_global_var("fused_conv2d1_add2") + + # Main function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype)) + w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype)) + w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype)) + with bb.function("main", [x, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype)) + lv1 = bb.emit(relax.Call(fused_conv2d_add1_add2, [lv0, w1, relax.const(1, dtype)])) + lv2 = bb.emit_te( + topi.nn.conv2d, + lv1, + w3, + strides=1, + padding=1, + dilation=1, + ) + gv = bb.emit_output(relax.Call(fused_conv2d1_add2, [lv1, w2, lv2])) + bb.emit_func_output(gv) + + return bb.get() + + _check(before("float32"), expected("float32")) + _check(before("float16"), expected("float16")) + _check(before("int8"), expected("int8")) + + +def test_concatenate(): + """Test fusion case involving concat op and Tuple node""" + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.pool2d, + x, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + padding=(0, 0, 0, 0), + pool_type="max", + ) + lv1 = bb.emit_te(topi.nn.upsampling, lv0, scale_h=2.0, scale_w=2.0) + lv2 = bb.emit_te(topi.concatenate, (lv1, x), axis=1) + gv = bb.emit_output(bb.call_te(topi.add, lv2, relax.const(1, "float32"))) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + w = relax.Var("w", R.Tensor((1, 16, 32, 32), "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + with bb.function("fused_upsampling_concatenate_add", [w, x, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.upsampling, w, scale_h=2.0, scale_w=2.0) + lv1 = bb.emit_te(topi.concatenate, (lv0, x), axis=1) + gv = bb.emit_output(bb.call_te(topi.add, lv1, p0)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_upsampling_concatenate_add = bb.get().get_global_var( + "fused_upsampling_concatenate_add" + ) + + # Main function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.pool2d, + x, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + padding=(0, 0, 0, 0), + pool_type="max", + ) + gv = bb.emit_output( + relax.Call( + fused_upsampling_concatenate_add, (lv0, x, relax.const(1, "float32")) + ) + ) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_tuple_root(): + """Test fusion case where Tuple node is the root in its group""" + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.pool2d, + x, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + padding=(0, 0, 0, 0), + pool_type="max", + ) + lv1 = bb.emit_te(topi.nn.upsampling, lv0, scale_h=2.0, scale_w=2.0) + gv = bb.emit_output((lv1, x)) + bb.emit_func_output(gv) + + return bb.get() + + # The fusion is supposed to make no change. + _check(before(), before()) + + +def test_fuse_tuple_get_elemwise(): + def before(dim: int): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((1, dim), "float32")) + w = relax.Var("w", R.Tensor((3 * dim, dim), "float32")) + with bb.function("main", [x, w]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.dense, x, w) + lv1 = bb.emit_te(topi.split, lv0, indices_or_sections=3, axis=1) + lv2 = bb.emit(relax.TupleGetItem(lv1, 0)) + lv3 = bb.emit_te(topi.sigmoid, lv2) + lv4 = bb.emit(relax.TupleGetItem(lv1, 1)) + lv5 = bb.emit_te(topi.tanh, lv4) + lv6 = bb.emit(relax.TupleGetItem(lv1, 2)) + lv7 = bb.emit_te(topi.exp, lv6) + lv8 = bb.emit_te(topi.multiply, lv5, lv7) + gv = bb.emit_output(bb.call_te(topi.add, lv3, lv8)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(dim: int): + bb = relax.BlockBuilder() + + # Grouped function + dense = relax.Var("dense", R.Tensor((1, 3 * dim), "float32")) + with bb.function( + "fused_split_sigmoid_tanh_exp_multiply_add", [dense], attrs={"Primitive": 1} + ): + with bb.dataflow(): + lv0 = bb.emit_te(topi.split, dense, indices_or_sections=3, axis=1) + lv1 = bb.emit(relax.TupleGetItem(lv0, 0)) + lv2 = bb.emit_te(topi.sigmoid, lv1) + lv3 = bb.emit(relax.TupleGetItem(lv0, 1)) + lv4 = bb.emit_te(topi.tanh, lv3) + lv5 = bb.emit(relax.TupleGetItem(lv0, 2)) + lv6 = bb.emit_te(topi.exp, lv5) + lv7 = bb.emit_te(topi.multiply, lv4, lv6) + gv = bb.emit_output(bb.call_te(topi.add, lv2, lv7)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_split_sigmoid_tanh_exp_multiply_add = bb.get().get_global_var( + "fused_split_sigmoid_tanh_exp_multiply_add" + ) + + # Main function + x = relax.Var("x", R.Tensor((1, dim), "float32")) + w = relax.Var("w", R.Tensor((3 * dim, dim), "float32")) + with bb.function("main", [x, w]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.dense, x, w) + gv = bb.emit_output(relax.Call(fused_split_sigmoid_tanh_exp_multiply_add, (lv0,))) + bb.emit_func_output(gv) + + return bb.get() + + dim = 10 + _check(before(dim), expected(dim)) + + +def test_tuple_get_root(): + def before(dim: int): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((1, 3 * dim), "float32")) + w = relax.Var("w", R.Tensor((dim, dim), "float32")) + with bb.function("main", [x, w]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.split, x, indices_or_sections=3, axis=1) + lv1 = bb.emit(relax.TupleGetItem(lv0, 0)) + gv = bb.emit_output(bb.call_te(topi.nn.dense, lv1, w)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(dim: int): + bb = relax.BlockBuilder() + + # Grouped function + x = relax.Var("x", R.Tensor((1, 3 * dim), "float32")) + with bb.function("fused_split", [x], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.split, x, indices_or_sections=3, axis=1) + gv = bb.emit_output(relax.TupleGetItem(lv0, 0)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_split = bb.get().get_global_var("fused_split") + + # Main function + x = relax.Var("x", R.Tensor((1, 3 * dim), "float32")) + w = relax.Var("w", R.Tensor((dim, dim), "float32")) + with bb.function("main", [x, w]): + with bb.dataflow(): + lv0 = bb.emit(relax.Call(fused_split, (x,))) + gv = bb.emit_output(bb.call_te(topi.nn.dense, lv0, w)) + bb.emit_func_output(gv) + + return bb.get() + + dim = 10 + _check(before(dim), expected(dim)) + + +def test_tuple_intermediate(): + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.squeeze, x) + lv1 = bb.emit_te(topi.add, lv0, relax.const(1, "float32")) + lv2 = bb.emit_te(topi.squeeze, lv0) + lv3 = bb.emit_te(topi.add, lv2, relax.const(1, "float32")) + lv4 = bb.emit_te(topi.add, lv3, relax.const(1, "float32")) + lv5 = bb.emit_te(topi.add, lv0, relax.const(1, "float32")) + lv6 = bb.emit_te(topi.concatenate, (lv1, lv4, lv5), axis=1) + lv7 = bb.emit_te(topi.squeeze, lv6) + gv = bb.emit_output(bb.call_te(topi.add, lv7, relax.const(1, "float32"))) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + p1 = relax.Var("p1", R.Tensor((), "float32")) + p2 = relax.Var("p2", R.Tensor((), "float32")) + p3 = relax.Var("p3", R.Tensor((), "float32")) + p4 = relax.Var("p4", R.Tensor((), "float32")) + with bb.function( + "fused_squeeze_add_squeeze1_add_add_add_concatenate_squeeze2_add1", + [x, p0, p1, p2, p3, p4], + attrs={"Primitive": 1}, + ): + with bb.dataflow(): + lv0 = bb.emit_te(topi.squeeze, x) + lv1 = bb.emit_te(topi.add, lv0, p0) + lv2 = bb.emit_te(topi.squeeze, lv0) + lv3 = bb.emit_te(topi.add, lv2, p1) + lv4 = bb.emit_te(topi.add, lv3, p2) + lv5 = bb.emit_te(topi.add, lv0, p3) + lv6 = bb.emit_te(topi.concatenate, (lv1, lv4, lv5), axis=1) + lv7 = bb.emit_te(topi.squeeze, lv6) + gv = bb.emit_output(bb.call_te(topi.add, lv7, p4)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_func = bb.get().get_global_var( + "fused_squeeze_add_squeeze1_add_add_add_concatenate_squeeze2_add1" + ) + + # Main func + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output( + relax.Call( + fused_func, + ( + x, + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + ), + ) + ) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_tuple_consecutive(): + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv1 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv2 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv3 = bb.emit_te(topi.concatenate, (lv0, lv1, lv2), axis=1) + lv4 = bb.emit_te(topi.add, lv3, relax.const(1, "float32")) + lv5 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv6 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv7 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv8 = bb.emit_te(topi.concatenate, (lv5, lv6, lv7), axis=1) + lv9 = bb.emit_te(topi.add, lv8, relax.const(1, "float32")) + lv10 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv11 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv12 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv13 = bb.emit_te(topi.concatenate, (lv10, lv11, lv12), axis=1) + lv14 = bb.emit_te(topi.add, lv13, relax.const(1, "float32")) + lv15 = bb.emit_te(topi.concatenate, (lv4, lv9, lv14), axis=1) + lv16 = bb.emit_te( + topi.nn.pool2d, + lv15, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + padding=(0, 0, 0, 0), + pool_type="max", + ) + lv17 = bb.emit_te(topi.add, lv16, relax.const(1, "float32")) + lv18 = bb.emit_te(topi.add, lv17, relax.const(1, "float32")) + gv = bb.emit_output((lv17, lv18)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function 1 + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + p1 = relax.Var("p1", R.Tensor((), "float32")) + p2 = relax.Var("p2", R.Tensor((), "float32")) + p3 = relax.Var("p3", R.Tensor((), "float32")) + p4 = relax.Var("p4", R.Tensor((), "float32")) + p5 = relax.Var("p5", R.Tensor((), "float32")) + p6 = relax.Var("p6", R.Tensor((), "float32")) + p7 = relax.Var("p7", R.Tensor((), "float32")) + p8 = relax.Var("p8", R.Tensor((), "float32")) + p9 = relax.Var("p9", R.Tensor((), "float32")) + p10 = relax.Var("p10", R.Tensor((), "float32")) + p11 = relax.Var("p11", R.Tensor((), "float32")) + with bb.function( + "fused_add_add_add_concatenate_add1_add_add_add_concatenate_add1_add_add_add_concatenate_add1_concatenate1", + [x, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11], + attrs={"Primitive": 1}, + ): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.add, x, p1) + lv2 = bb.emit_te(topi.add, x, p2) + lv3 = bb.emit_te(topi.concatenate, (lv0, lv1, lv2), axis=1) + lv4 = bb.emit_te(topi.add, lv3, p3) + lv5 = bb.emit_te(topi.add, x, p4) + lv6 = bb.emit_te(topi.add, x, p5) + lv7 = bb.emit_te(topi.add, x, p6) + lv8 = bb.emit_te(topi.concatenate, (lv5, lv6, lv7), axis=1) + lv9 = bb.emit_te(topi.add, lv8, p7) + lv10 = bb.emit_te(topi.add, x, p8) + lv11 = bb.emit_te(topi.add, x, p9) + lv12 = bb.emit_te(topi.add, x, p10) + lv13 = bb.emit_te(topi.concatenate, (lv10, lv11, lv12), axis=1) + lv14 = bb.emit_te(topi.add, lv13, p11) + gv = bb.emit_output(bb.call_te(topi.concatenate, (lv4, lv9, lv14), axis=1)) + bb.emit_func_output(gv) + + # Grouped function 2 + concat = relax.Var("concat", R.Tensor((1, 144, 64, 64), "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + with bb.function("fused_pool2d_add2", [concat, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.pool2d, + concat, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + padding=(0, 0, 0, 0), + pool_type="max", + ) + gv = bb.emit_output(bb.call_te(topi.add, lv0, p0)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + mod = bb.get() + fused_func1 = mod.get_global_var( + "fused_add_add_add_concatenate_add1_add_add_add_concatenate_add1_add_add_add_concatenate_add1_concatenate1" + ) + fused_func2 = mod.get_global_var("fused_pool2d_add2") + + # Main function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit( + relax.Call( + fused_func1, + ( + x, + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + ), + ) + ) + lv1 = bb.emit(relax.Call(fused_func2, (lv0, relax.const(1, "float32")))) + lv2 = bb.emit_te(topi.add, lv1, relax.const(1, "float32")) + gv = bb.emit_output((lv1, lv2)) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_inception_like(): + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + w0 = relax.Var("w0", R.Tensor((16, 16, 3, 3), "float32")) + w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), "float32")) + w2 = relax.Var("w2", R.Tensor((16, 32, 3, 3), "float32")) + w3 = relax.Var("w3", R.Tensor((16, 32, 3, 3), "float32")) + with bb.function("main", [x, w0, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.conv2d, x, w0, strides=1, padding=1, dilation=1) + lv1 = bb.emit_te(topi.nn.relu, lv0) + lv2 = bb.emit_te(topi.nn.conv2d, x, w1, strides=1, padding=1, dilation=1) + lv3 = bb.emit_te(topi.nn.relu, lv2) + lv4 = bb.emit_te(topi.concatenate, (lv1, lv3), axis=1) + lv5 = bb.emit_te(topi.nn.conv2d, lv4, w2, strides=1, padding=1, dilation=1) + lv6 = bb.emit_te(topi.nn.relu, lv5) + lv7 = bb.emit_te(topi.nn.conv2d, lv4, w3, strides=1, padding=1, dilation=1) + lv8 = bb.emit_te(topi.nn.relu, lv7) + gv = bb.emit_output(bb.call_te(topi.concatenate, (lv6, lv8), axis=1)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function 1 + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + w = relax.Var("w", R.Tensor((16, 16, 3, 3), "float32")) + with bb.function("fused_conv2d_relu", [x, w], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=1, + dilation=1, + primfunc_name_hint="conv2d", + ) + gv = bb.emit_output(bb.call_te(topi.nn.relu, lv0)) + bb.emit_func_output(gv) + + # Grouped function 2 + x = relax.Var("x", R.Tensor((1, 32, 64, 64), "float32")) + w = relax.Var("w", R.Tensor((16, 32, 3, 3), "float32")) + with bb.function("fused_conv2d1_relu", [x, w], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=1, + dilation=1, + primfunc_name_hint="conv2d1", + ) + gv = bb.emit_output(bb.call_te(topi.nn.relu, lv0)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + mod = bb.get() + fused_conv2d_relu1 = mod.get_global_var("fused_conv2d_relu") + fused_conv2d_relu2 = mod.get_global_var("fused_conv2d1_relu") + + # Main function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) + w0 = relax.Var("w0", R.Tensor((16, 16, 3, 3), "float32")) + w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), "float32")) + w2 = relax.Var("w2", R.Tensor((16, 32, 3, 3), "float32")) + w3 = relax.Var("w3", R.Tensor((16, 32, 3, 3), "float32")) + with bb.function("main", [x, w0, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit(relax.Call(fused_conv2d_relu1, (x, w0))) + lv1 = bb.emit(relax.Call(fused_conv2d_relu1, (x, w1))) + lv2 = bb.emit_te(topi.concatenate, (lv0, lv1), axis=1) + lv3 = bb.emit(relax.Call(fused_conv2d_relu2, (lv2, w2))) + lv4 = bb.emit(relax.Call(fused_conv2d_relu2, (lv2, w3))) + gv = bb.emit_output(bb.call_te(topi.concatenate, (lv3, lv4), axis=1)) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_fuse_parallel_injective(): + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor((10, 20), "int32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, "int32")) + lv1 = bb.emit_te(topi.squeeze, lv0) + lv2 = bb.emit_te(topi.transpose, lv0, axes=[1, 0]) + lv3 = bb.emit_te(topi.transpose, lv2, axes=[1, 0]) + gv = bb.emit_output(bb.call_te(topi.left_shift, lv1, lv3)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function + x = relax.Var("x", R.Tensor((10, 20), "int32")) + p0 = relax.Var("p0", R.Tensor((), "int32")) + with bb.function( + "fused_add_squeeze_transpose_transpose1_left_shift", [x, p0], attrs={"Primitive": 1} + ): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.squeeze, lv0) + lv2 = bb.emit_te(topi.transpose, lv0, axes=[1, 0]) + lv3 = bb.emit_te(topi.transpose, lv2, axes=[1, 0], primfunc_name_hint="transpose1") + gv = bb.emit_output(bb.call_te(topi.left_shift, lv1, lv3)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_func = bb.get().get_global_var("fused_add_squeeze_transpose_transpose1_left_shift") + + # Main function + x = relax.Var("x", R.Tensor((10, 20), "int32")) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_func, (x, relax.const(1, "int32")))) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_softmax(): + """Test if softmax can be fused with following ops.""" + + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor((16, 16), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.softmax, x) + gv = bb.emit_output(bb.call_te(topi.cast, lv0, dtype="float16")) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function + x = relax.Var("x", R.Tensor((16, 16), "float32")) + with bb.function("fused_softmax_cast", [x], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.softmax, x) + gv = bb.emit_output(bb.call_te(topi.cast, lv0, dtype="float16")) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_func = bb.get().get_global_var("fused_softmax_cast") + + # Main function + x = relax.Var("x", R.Tensor((16, 16), "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_func, (x,))) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py new file mode 100644 index 000000000000..91edab2bbb98 --- /dev/null +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -0,0 +1,563 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm import relax, topi +from tvm.script import relax as R + + +def _check(mod_before, mod_expected): + mod = relax.transform.FuseTIR()(mod_before) + tvm.ir.assert_structural_equal(mod, mod_expected) + + +def test_simple(): + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) + + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze") + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) + with bb.function("main", [x, p0]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_add_exp_squeeze, [x, p0])) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + def fused_add_exp_squeeze(x, p0): + add = topi.add(x, p0) + exp = topi.exp(add) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) + with bb.function("main", [x, p0]): + with bb.dataflow(): + gv = bb.emit_output(bb.call_te(fused_add_exp_squeeze, x, p0)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_conv2d_fuse(): + def before(dtype): + bb = relax.BlockBuilder() + + # Grouped function 1 + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w = relax.Var("w", R.Tensor((16, 16, 3, 3), dtype)) + p0 = relax.Var("p0", R.Tensor((), dtype)) + with bb.function("fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=1, + dilation=1, + primfunc_name_hint="conv2d", + ) + lv1 = bb.emit_te(topi.add, p0, lv0, primfunc_name_hint="add1") + gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1, primfunc_name_hint="add2")) + bb.emit_func_output(gv) + + # Grouped function 2 + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w = relax.Var("w", R.Tensor((16, 16, 1, 1), dtype)) + y = relax.Var("y", R.Tensor((1, 16, 64, 64), dtype)) + with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=0, + dilation=1, + primfunc_name_hint="conv2d1", + ) + gv = bb.emit_output(bb.call_te(topi.add, lv0, y, primfunc_name_hint="add2")) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + mod = bb.get() + fused_conv2d_add1_add2 = mod.get_global_var("fused_conv2d_add1_add2") + fused_conv2d1_add2 = mod.get_global_var("fused_conv2d1_add2") + + # Main function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype)) + w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype)) + w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype)) + with bb.function("main", [x, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype)) + lv1 = bb.emit(relax.Call(fused_conv2d_add1_add2, [lv0, w1, relax.const(1, dtype)])) + lv2 = bb.emit_te( + topi.nn.conv2d, + lv1, + w3, + strides=1, + padding=1, + dilation=1, + ) + gv = bb.emit_output(relax.Call(fused_conv2d1_add2, [lv1, w2, lv2])) + bb.emit_func_output(gv) + + return bb.get() + + def expected(dtype): + def fused_conv2d_add1_add2(x, w, p): + conv = topi.nn.conv2d(x, w, strides=1, padding=1, dilation=1) + add = topi.add(p, conv) + return topi.add(conv, add) + + def fused_conv2d1_add2(x, w, p): + conv = topi.nn.conv2d(x, w, strides=1, padding=0, dilation=1) + return topi.add(conv, p) + + bb = relax.BlockBuilder() + + # Main function + x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) + w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype)) + w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype)) + w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype)) + with bb.function("main", [x, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype)) + lv1 = bb.emit_te(fused_conv2d_add1_add2, lv0, w1, relax.const(1, dtype)) + lv2 = bb.emit_te( + topi.nn.conv2d, + lv1, + w3, + strides=1, + padding=1, + dilation=1, + ) + gv = bb.emit_output(bb.call_te(fused_conv2d1_add2, lv1, w2, lv2)) + bb.emit_func_output(gv) + + return bb.get() + + _check(before("float32"), expected("float32")) + + +def test_two_subfunction(): + def before(): + bb = relax.BlockBuilder() + x1 = relax.Var("x1", R.Tensor([10, 20], "float32")) + with bb.function("fused_exp_squeeze", [x1], attrs={"Primitive": True}): + with bb.dataflow(): + lv1 = bb.emit_te(topi.exp, x1) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_squeeze") + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit(relax.Call(func_gv, [x])) + lv2 = bb.emit(relax.Call(func_gv, [lv])) + gv = bb.emit_output(lv2) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_squeeze(x): + exp = topi.exp(x) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit_te(fused_exp_squeeze, x) + lv2 = bb.emit_te(fused_exp_squeeze, lv) + gv = bb.emit_output(lv2) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_same_primfunc(): + def before(): + bb = relax.BlockBuilder() + x1 = relax.Var("x1", R.Tensor([10, 20], "float32")) + with bb.function("fused_exp_exp_squeeze", [x1], attrs={"Primitive": True}): + with bb.dataflow(): + lv1 = bb.emit_te(topi.exp, x1) + lv2 = bb.emit_te(topi.exp, lv1) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv2)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_exp_squeeze") + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit(relax.Call(func_gv, [x])) + gv = bb.emit_output(lv) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_exp_squeeze(x): + exp = topi.exp(x) + exp = topi.exp(exp) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit_te(fused_exp_exp_squeeze, x) + gv = bb.emit_output(lv) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_tuple_as_param(): + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")])) + with bb.function("fused_exp_add", [x], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit(relax.TupleGetItem(x, 0)) + lv1 = bb.emit(relax.TupleGetItem(x, 1)) + lv2 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.add, lv2, lv1)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_add") + x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")])) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(func_gv, [x])) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_add(x1, x2): + exp = topi.exp(x1) + return topi.add(exp, x2) + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")])) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit(relax.TupleGetItem(x, 0)) + lv1 = bb.emit(relax.TupleGetItem(x, 1)) + gv = bb.emit_output(bb.call_te(fused_exp_add, lv0, lv1)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_nested_tuple_as_param(): + tuple_struct_info = R.Tuple( + [R.Tensor([10], "float32"), R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")])] + ) + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", tuple_struct_info) + with bb.function("fused_exp_add_add", [x], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit(relax.TupleGetItem(x, 0)) + lv0_exp = bb.emit_te(topi.exp, lv0) + lv1 = bb.emit(relax.TupleGetItem(x, 1)) + lv1_0 = bb.emit(relax.TupleGetItem(lv1, 0)) + lv1_1 = bb.emit(relax.TupleGetItem(lv1, 1)) + lv2 = bb.emit_te(topi.add, lv1_0, lv1_1) + gv = bb.emit_output(bb.call_te(topi.add, lv0_exp, lv2)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_add_add") + x = relax.Var("x", tuple_struct_info) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(func_gv, [x])) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_add_add(x1, x2, x3): + exp = topi.exp(x1) + add = topi.add(x2, x3) + return topi.add(exp, add) + + bb = relax.BlockBuilder() + x = relax.Var("x", tuple_struct_info) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit(relax.TupleGetItem(x, 0)) + lv1 = bb.emit(relax.TupleGetItem(x, 1)) + lv2 = bb.emit(relax.TupleGetItem(lv1, 0)) + lv3 = bb.emit(relax.TupleGetItem(lv1, 1)) + gv = bb.emit_output(bb.call_te(fused_exp_add_add, lv0, lv2, lv3)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_call_tir_in_main(): + def before(): + bb = relax.BlockBuilder() + x1 = relax.Var("x1", R.Tensor([10, 20], "float32")) + with bb.function("fused_exp_squeeze", [x1], attrs={"Primitive": True}): + with bb.dataflow(): + lv = bb.emit_te(topi.exp, x1) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_squeeze") + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit(relax.Call(func_gv, [x])) + lv1 = bb.emit_te(topi.add, lv0, relax.const(1, "float32")) + gv = bb.emit_output(lv1) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_squeeze(x): + exp = topi.exp(x) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit_te(fused_exp_squeeze, x) + lv2 = bb.emit_te(topi.add, lv, relax.const(1, "float32")) + gv = bb.emit_output(lv2) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_const_in_argument(): + def before(): + bb = relax.BlockBuilder() + x1 = relax.Var("x1", R.Tensor([10, 20], "float32")) + x2 = relax.Var("x2", R.Tensor([], "float32")) + with bb.function("fused_add_exp_squeeze", [x1, x2], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x1, x2) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_add_exp_squeeze") + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit(relax.Call(func_gv, [x, relax.const(1, "float32")])) + gv = bb.emit_output(lv) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_add_exp_squeeze(x, y): + add = topi.add(x, y) + exp = topi.exp(add) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit_te(fused_add_exp_squeeze, x, relax.const(1, "float32")) + gv = bb.emit_output(lv) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_tuple_output(): + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) + + with bb.function("fused_add_exp", [x, p0], attrs={"Primitive": True}): + with bb.dataflow(): + gv0 = bb.emit_output(bb.call_te(topi.add, x, p0)) + gv1 = bb.emit_output(bb.call_te(topi.exp, gv0)) + bb.emit_func_output(relax.Tuple([gv0, gv1])) + fused_add_exp = bb.get().get_global_var("fused_add_exp") + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) + with bb.function("main", [x, p0]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_add_exp, [x, p0])) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + def fused_add_exp(x, p0): + add = topi.add(x, p0) + exp = topi.exp(add) + return add, exp + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) + with bb.function("main", [x, p0]): + with bb.dataflow(): + gv = bb.emit_output(bb.call_te(fused_add_exp, x, p0)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_immediate_tuple(): + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + y = relax.Var("y", R.Tensor([10, 20], "float32")) + + with bb.function("fused_add", [x, y], attrs={"Primitive": True}): + with bb.dataflow(): + lv_tuple = bb.emit(relax.Tuple([x, relax.Tuple([x, y])])) + lv_x = bb.emit(relax.TupleGetItem(lv_tuple, 0)) + lv0 = bb.emit(relax.TupleGetItem(lv_tuple, 1)) + lv_y = bb.emit(relax.TupleGetItem(lv0, 1)) + gv = bb.emit_output(bb.call_te(topi.add, lv_x, lv_y)) + bb.emit_func_output(gv) + fused_add = bb.get().get_global_var("fused_add") + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + y = relax.Var("y", R.Tensor([10, 20], "float32")) + with bb.function("main", [x, y]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_add, [x, y])) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + y = relax.Var("y", R.Tensor([10, 20], "float32")) + with bb.function("main", [x, y]): + with bb.dataflow(): + gv = bb.emit_output(bb.call_te(topi.add, x, y, primfunc_name_hint="fused_add")) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_return_partial_result(): + def te_argmax_idx_val(val): + from tvm import te + + def f_combine(x, y): + lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0]) + rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1]) + return lhs, rhs + + def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType): + return tvm.tir.const(-1, dtype0), tvm.te.min_value(dtype1) + + argmax = te.comm_reducer(f_combine, f_identity, name="argmax") + m, n = val.shape + k = te.reduce_axis((0, n), "k") + max_idx, max_val = te.compute( + (m,), lambda i: argmax((k.var, val[i, k]), axis=k), name="argmax" + ) + return max_idx, max_val + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + offset = relax.Var("offset", R.Tensor([10], "int32")) + with bb.function("fused_argmax_add", [x, offset], attrs={"Primitive": True}): + with bb.dataflow(): + lv = bb.emit_te(te_argmax_idx_val, x) + idx = bb.emit(relax.TupleGetItem(lv, 0)) + gv = bb.emit_output(bb.call_te(topi.add, idx, offset)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_argmax_add") + x = relax.Var("x", R.Tensor([10, 20], "float32")) + offset = relax.Var("x", R.Tensor([10], "int32")) + with bb.function("main", [x, offset]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(func_gv, [x, offset])) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_argmax_add(x, offset): + idx, value = te_argmax_idx_val(x) + idx = topi.add(idx, offset) + return idx + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + offset = relax.Var("offset", R.Tensor([10], "int32")) + with bb.function("main", [x, offset]): + with bb.dataflow(): + gv = bb.emit_output(bb.call_te(fused_argmax_add, x, offset)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index f6d2e4c20e48..6e9e14d3dc47 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1073,5 +1073,4 @@ def mul_add(x: R.Tensor) -> R.Tensor: if __name__ == "__main__": - test_cross_function_call() tvm.testing.main() From 6475d9884a123035ce349e139c3e4b274ba5e1e7 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Thu, 16 Feb 2023 12:09:47 -0800 Subject: [PATCH 26/81] [Unity][Pass] LambdaLift pass (#14012) --- include/tvm/relax/analysis.h | 57 ++++ python/tvm/relax/transform/transform.py | 10 + src/relax/analysis/analysis.cc | 173 ++++++++++ src/relax/transform/lambda_lift.cc | 266 +++++++++++++++ src/relax/utils.cc | 45 +++ .../relax/test_transform_lambda_lift.py | 304 ++++++++++++++++++ 6 files changed, 855 insertions(+) create mode 100644 src/relax/analysis/analysis.cc create mode 100644 src/relax/transform/lambda_lift.cc create mode 100644 tests/python/relax/test_transform_lambda_lift.py diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index a55fe6797d45..ff576d4ebb6a 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -260,6 +260,63 @@ TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived, TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, arith::Analyzer* ana = nullptr); +//----------------------------------- +// General IR analysis +//----------------------------------- +/*! + * \brief Get all bound variables from expression expr. + * + * Bound variables are all variables that are declared in the expr. + * They only have meaning inside that expr, and can only be used in it. + * + * \param expr the expression. + * + * \return List of bound vars, in the PostDFS order in the expression. + */ +TVM_DLL tvm::Array BoundVars(const Expr& expr); + +/*! + * \brief Get free type parameters from expression expr. + * + * Free variables are variables that are not bound by a + * varbinding or a function parameter in the context. + * + * \param expr the expression. + * + * \return List of free vars, in the PostDFS order in the expression. + */ +TVM_DLL tvm::Array FreeVars(const Expr& expr); + +/*! + * \brief Get all variables from expression expr. + * + * \param expr the expression. + * + * \return List of all vars, in the PostDFS order in the expression. + */ +TVM_DLL tvm::Array AllVars(const Expr& expr); + +/*! + * \brief Get all global variables used in calls in expression expr. + * + * \param expr the expression. + * + * \return List of all global variables called in expr. + */ +TVM_DLL tvm::Array CalledGlobalVars(const Expr& expr); + +/*! + * \brief Get all global variables from expression expr. + * + * AllVars is a superset of BoundVars and FreeVars. + * The union of BoundVars and FreeVars is Allvars. + * + * \param expr the expression. + * + * \return List of all global variables, in the PostDFS order in the expression. + */ +TVM_DLL tvm::Array AllGlobalVars(const Expr& expr); + /*! * \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps. * diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 0f973db290f8..1a525431dd48 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -47,6 +47,16 @@ def ToNonDataflow() -> tvm.ir.transform.Pass: return _ffi_api.ToNonDataflow() # type: ignore +def LambdaLift(): + """A pass that lifts local functions into global. + + Returns + ------- + ret : tvm.ir.transform.Pass + """ + return _ffi_api.LambdaLift() + + def CallTIRRewrite() -> tvm.ir.transform.Pass: """Perform explicit tensor allocation for call_tir. diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc new file mode 100644 index 000000000000..33197308fa1b --- /dev/null +++ b/src/relax/analysis/analysis.cc @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file analysis.cc + * + * \brief Analysis functions for Relax. + */ + +#include +#include +#include + +namespace tvm { +namespace relax { + +template +struct InsertionSet { + std::unordered_set set; + std::vector data; + void Insert(const T& t) { + if (set.count(t) == 0) { + set.insert(t); + data.push_back(t); + } + } +}; + +class VarVisitor : protected ExprVisitor { + public: + Array Free(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : vars_.data) { + if (bound_vars_.set.count(v) == 0) { + ret.push_back(v); + } + } + return ret; + } + + Array Collect() { + Array ret; + for (const auto& v : bound_vars_.data) { + ret.push_back(v); + } + return ret; + } + + Array Bound(const Expr& expr) { + this->VisitExpr(expr); + return Collect(); + } + + Array All(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : vars_.data) { + ret.push_back(v); + } + return ret; + } + + Array AllGlobalVars(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : global_vars_.data) { + ret.push_back(v); + } + return ret; + } + + Array CalledGlobalVars(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : called_global_vars_.data) { + ret.push_back(v); + } + return ret; + } + + void MarkBounded(const Var& v) { + bound_vars_.Insert(v); + vars_.Insert(v); + } + + void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef(var)); } + + void VisitExpr_(const FunctionNode* op) final { + for (const auto& param : op->params) { + MarkBounded(param); + } + VisitExpr(op->body); + } + + void VisitExpr_(const GlobalVarNode* op) final { global_vars_.Insert(GetRef(op)); } + + void VisitExpr_(const CallNode* call_node) final { + VisitSpan(call_node->span); + VisitExpr(call_node->op); + + for (StructInfo sinfo_arg : call_node->sinfo_args) { + VisitExprDepStructInfoField(sinfo_arg); + } + + for (Expr arg : call_node->args) { + VisitExpr(arg); + } + + if (const GlobalVarNode* global_var_node = call_node->op.as()) { + called_global_vars_.Insert(GetRef(global_var_node)); + } + } + + void VisitBinding_(const VarBindingNode* binding) final { + MarkBounded(binding->var); + VisitExpr(binding->value); + VisitVarDef(binding->var); + } + + void VisitBinding_(const MatchCastNode* binding) final { + MarkBounded(binding->var); + ExprVisitor::VisitBinding_(binding); + } + + private: + InsertionSet vars_; + InsertionSet bound_vars_; + InsertionSet global_vars_; + InsertionSet called_global_vars_; +}; + +tvm::Array FreeVars(const Expr& expr) { return VarVisitor().Free(expr); } + +tvm::Array BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); } + +tvm::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } + +tvm::Array AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); } + +tvm::Array CalledGlobalVars(const Expr& expr) { + return VarVisitor().CalledGlobalVars(expr); +} + +TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars); + +TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars); + +TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars); + +TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars); + +TVM_REGISTER_GLOBAL("relax.analysis.called_global_vars").set_body_typed(CalledGlobalVars); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc new file mode 100644 index 000000000000..f08499036b1c --- /dev/null +++ b/src/relax/transform/lambda_lift.cc @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/lambda_lift.cc + * \brief Lift local functions into global functions. + */ + +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +/* The goal of this class is to lift out any nested functions into top-level + * functions. + * + * We will lift a function out into a global which takes the set of the free + * vars and then return the new created function. + */ +class LambdaLifter : public ExprMutator { + public: + explicit LambdaLifter(const IRModule& module) : ExprMutator(module) { mod_ = module; } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call_node) final { + auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + if (auto const* var = call_node->op.as()) { + bool has_closure = HasClosure(GetRef(var)); + auto val = builder_->LookupBinding(GetRef(var)); + // Call "relax.invoke_closure" to invoke closure + if (has_closure && val.as()) { + Var clo_arg = GetRef(var); + if (this->var_remap_.find(var->vid) != this->var_remap_.end()) { + clo_arg = this->var_remap_.at(var->vid); + } + return Call(invoke_closure_op_, {clo_arg, Tuple(call_node->args)}, {}, + {GetStructInfo(GetRef(call_node))}); + } + } + if (auto global_var_node = call_node->op.as()) { + String rec_name = global_var_node->name_hint; + auto global_var = GetRef(global_var_node); + auto it = lambda_map_.find(global_var); + if (it != lambda_map_.end()) { + // flatten nested call, e.g. call(y)(x) -> call(x, y)) + Array new_args; + for (const auto arg : call->args) { + new_args.push_back(arg); + } + if (const auto* nest_call = it->second.as()) { + for (const auto arg : nest_call->args) { + new_args.push_back(arg); + } + return Call(nest_call->op, new_args, call_node->attrs, call_node->sinfo_args); + } + return Call(it->second, call->args, call_node->attrs, call_node->sinfo_args); + } + } + return std::move(call); + } + + Expr VisitExpr_(const FunctionNode* func_node) final { + auto func = GetRef(func_node); + + // TODO(@yongwww): consider appending inner func name into the lifted func name + String lift_func_name = "lifted_func_" + std::to_string(lift_func_num_++); + auto global = GlobalVar(lift_func_name); + Array captured_vars = FreeVars(func); + recur_vars_ = CalledGlobalVars(func); + auto all_global_vars = AllGlobalVars(func); + + Array typed_captured_vars; + Map rebinding_map; + for (auto free_var : captured_vars) { + Var var = Var(free_var->name_hint(), GetStructInfo(free_var), free_var->span); + typed_captured_vars.push_back(var); + rebinding_map.Set(free_var, var); + } + + // recursive call + if (!recur_vars_.empty()) { + if (!captured_vars.empty()) { + Array fvs; + for (auto fv : captured_vars) { + fvs.push_back(fv); + } + lambda_map_.emplace(recur_vars_.back(), Call(global, fvs)); + } else { + if (recur_vars_.size() > 0) { + lambda_map_.emplace(recur_vars_.back(), global); + } + } + } + + tvm::Array params; + bool all_params_unchanged = true; + for (Var param : func_node->params) { + Var new_param = this->VisitVarDef(param); + params.push_back(new_param); + all_params_unchanged &= param.same_as(new_param); + } + + Expr body = this->VisitWithNewScope(func_node->body); + Expr visited_func; + + if (all_params_unchanged && body.same_as(func_node->body)) { + visited_func = GetRef(func_node); + } else if (const auto& body_sinfo = MatchStructInfo(body)) { + visited_func = Function(params, body, body_sinfo.value(), func_node->attrs); + } else { + visited_func = Function(params, body, func_node->ret_struct_info, func_node->attrs); + } + auto new_func = Downcast(visited_func); + + Function lifted_func; + bool is_closure = IsClosure(captured_vars); + if (!is_closure) { + lifted_func = Function( + /*params=*/new_func->params, + /*body=*/new_func->body, + /*ret_struct_info=*/new_func->ret_struct_info, + /*attrs=*/new_func->attrs, + /*span=*/new_func->span); + } else { + // Flatten the Closure + std::vector closure_params; + closure_params.reserve(func->params.size() + typed_captured_vars.size()); + for (size_t i = 0; i < func->params.size(); ++i) { + closure_params.emplace_back(func->params[i]); + } + for (size_t i = 0; i < typed_captured_vars.size(); ++i) { + closure_params.emplace_back(typed_captured_vars[i]); + } + + lifted_func = Function(/*params=*/closure_params, + /*body=*/Bind(new_func->body, rebinding_map), + /*ret_struct_info=*/new_func->ret_struct_info, + /*attrs=*/new_func->attrs, + /*span=*/func->span); + + Array param_types; + for (Var param : closure_params) { + CHECK(param->checked_type_.defined()) + << "relax.Function requires params to contain checked_type_"; + param_types.push_back(param->checked_type_); + } + } + + ICHECK(lifted_func.defined()); + + // Add the lifted function to the module. + UpdateStructInfo(global, GetStructInfo(lifted_func)); + builder_->UpdateFunction(global, lifted_func); + + if (!is_closure) { + return std::move(global); + } else { + // If we need to allocate a closure, + // we pass the variables in its environment here. + Array fvs; + for (auto fv : captured_vars) { + fvs.push_back(fv); + } + // Call make_closure intrinsic + return Call(make_closure_op_, {global, Tuple(fvs)}, {}, {}); + } + } + + bool HasClosure(const Var& var) { + auto val = builder_->LookupBinding(var); + if (const auto* value = val.as()) { + IRModule ctx_mod = builder_->GetContextIRModule(); + ICHECK(ctx_mod->functions.size() > 0); + BaseFunc func = ctx_mod->Lookup(GetRef(value)); + if (const auto* func_node = func.as()) { + if (const auto* call_node = func_node->body.as()) { + if (call_node->op == make_closure_op_) { + return true; + } + } else if (const auto* seq_expr_node = func_node->body.as()) { + // the return var points to a make_closure intrinsic + if (const auto* var = seq_expr_node->body.as()) { + return HasClosure(GetRef(var)); + } + } + } + } else if (const auto* func_node = val.as()) { + if (const auto* call_node = func_node->body.as()) { + if (call_node->op == make_closure_op_) { + return true; + } + } + } else if (const auto* call_node = val.as()) { + // recursive call + auto op = call_node->op; + if (make_closure_op_ == op) { + return true; + } + if (const auto* lv = op.as()) { + return HasClosure(GetRef(lv)); + } + } + return false; + } + + bool IsClosure(const Array& captured_vars) { return captured_vars.size() > 0; } + + IRModule Lift() { + auto glob_funcs = mod_->functions; + for (auto pair : glob_funcs) { + if (auto* n = pair.second.as()) { + auto func = GetRef(n); + func = Function(func->params, VisitExpr(func->body), func->ret_struct_info, func->attrs); + builder_->UpdateFunction(pair.first, func); + } + } + return builder_->GetContextIRModule(); + } + + private: + std::unordered_map lambda_map_; + Array recur_vars_; + IRModule mod_; + size_t lift_func_num_ = 0; + /*! \brief Cache ops that would be used later to reduce lookup overhead. */ + const Op& make_closure_op_ = Op::Get("relax.make_closure"); + const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); +}; + +namespace transform { + +Pass LambdaLift() { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return relax::LambdaLifter(m).Lift(); }; + return CreateModulePass(pass_func, 1, "LambdaLift", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.LambdaLift").set_body_typed(LambdaLift); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 5846f8116df2..24414f250cbc 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -22,6 +22,51 @@ namespace tvm { namespace relax { +/*! \brief Helper to implement bind params.*/ +class ExprBinder : public ExprMutator { + public: + explicit ExprBinder(const tvm::Map& args_map) : args_map_(args_map) {} + + Expr VisitExpr_(const VarNode* op) final { + auto id = GetRef(op); + auto it = args_map_.find(id); + if (it != args_map_.end()) { + return (*it).second; + } else { + return ExprMutator::VisitExpr_(op); + } + } + + private: + const tvm::Map& args_map_; +}; + +/*! + * \brief Bind params on expr + * \param expr The expr where to bind params + * \param args_map The map from param var to the expr it binds to + * \return The result expr after bind params + */ +Expr Bind(const Expr& expr, const tvm::Map& args_map) { + if (const FunctionNode* func = expr.as()) { + Expr new_body = ExprBinder(args_map).VisitExpr(func->body); + Array new_params; + for (size_t i = 0; i < func->params.size(); ++i) { + if (!args_map.count(func->params[i])) { + new_params.push_back(func->params[i]); + } + } + if (new_body.same_as(func->body) && new_params.size() == func->params.size()) { + return expr; + } + // The checked_type_ of the new function is deduced from the function body + // TODO(@relax-team): Should infer the shape from the body as well + return Function(new_params, new_body, NullOpt, func->attrs); + } else { + return ExprBinder(args_map).VisitExpr(expr); + } +} + bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dtype) { const DynTensorTypeNode* tt = ty.as(); if (!tt) { diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py new file mode 100644 index 000000000000..fbdb1fbdcea9 --- /dev/null +++ b/tests/python/relax/test_transform_lambda_lift.py @@ -0,0 +1,304 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm +import tvm.testing +from tvm import relax +import tvm.script +from tvm.script import relax as R, tir as T +from tvm.relax import transform +from tvm.ir.base import assert_structural_equal + + +def _check_equal(x, y): + tvm.ir.assert_structural_equal(x, y) + tvm.ir.assert_structural_equal(y, x) + + xhash = tvm.ir.structural_hash(x, map_free_vars=True) + yhash = tvm.ir.structural_hash(y, map_free_vars=True) + assert xhash == yhash + + +def _check_save_roundtrip(x): + y = tvm.ir.load_json(tvm.ir.save_json(x)) + _check_equal(x, y) + + +def test_basic(): + # the target IRModule + @tvm.script.ir_module + class Expected: + @R.function + def lifted_func_0( + x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + s: R.Tensor((10, 5), "float32") = R.add(x2, y2) + return s + + @R.function + def main( + x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + inner = lifted_func_0 + gv1: R.Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + @tvm.script.ir_module + class Before: + @R.function + def main( + x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + @R.function + def inner( + x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + s: R.Tensor((10, 5), "float32") = R.add(x2, y2) + return s + + gv1: R.Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + before = Before + expected = Expected + # Perform Lambda Lifting + after = transform.LambdaLift()(before) + assert len(after.functions) == 2 + assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + +def test_closure(): + # the expected IRModule + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + outer_func = lifted_func_0 + in_call = outer_func(x) + res = R.invoke_closure(in_call, (y,), sinfo_args=(R.Tensor((2, 3), dtype="float32"))) + return res + + @R.function + def lifted_func_1(x1: R.Tensor((2, 3), "float32"), c1: R.Tensor((2, 3), "float32")): + r_1: R.Tensor((2, 3), "float32") = R.add(x1, c1) + return r_1 + + @R.function + def lifted_func_0(y: R.Tensor((2, 3), "float32")) -> R.Object: + inner_func = R.make_closure(lifted_func_1, (y,)) + return inner_func + + # IRModule to perform Lambda Lifting + @tvm.script.ir_module + class Before: + @R.function + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + @R.function + def outer_func(c1: R.Tensor((2, 3), "float32")): + @R.function + def inner_func(x1: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + s: R.Tensor((2, 3), "float32") = R.add(x1, c1) + return s + + return inner_func + + in_call = outer_func(x) + res = in_call(y) + return res + + before = Before + after = transform.LambdaLift()(before) + expected = Expected + assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + +@pytest.mark.skip(reason="Need fix after parser switch over") +def test_recursive(): + # the expected IRModule + @tvm.script.ir_module + class Expected: + @R.function + def lifted_func_0( + i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32"), x: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + cond: R.Tensor((), "bool") = R.call_packed( + "test.vm.less", i, R.const(10), sinfo_args=(R.Tensor((), dtype="bool")) + ) + c: R.Tensor((), "int32") = R.const(1, dtype="int32") + if cond: + new_i: R.Tensor((), "int32") = R.add(i, c) + new_s: R.Tensor((2, 3), "float32") = R.add(s, x) + r = lifted_func_0(new_i, new_s, x) + else: + r = s + return r + + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor: + while_loop = R.make_closure(lifted_func_0, (x,)) + gv = R.invoke_closure( + while_loop, + (relax.const(0), x), + sinfo_args=(R.Tensor(ndim=2, dtype="float32")), + ) + return gv + + # the IRModule to apply lambda lifting + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor: + @R.function + def while_loop( + i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + cond: R.Tensor((), "bool") = R.call_packed( + "test.vm.less", i, R.const(10), sinfo_args=(R.Tensor((), dtype="bool")) + ) + c: R.Tensor((), "int32") = R.const(1, dtype="int32") + if cond: + new_i: R.Tensor((), "int32") = R.add(i, c) + new_s: R.Tensor((2, 3), "float32") = R.add(s, x) + r: R.Tensor((2, 3), "float32") = while_loop(new_i, new_s) + else: + r: R.Tensor((2, 3), "float32") = s + return r + + gv: R.Tensor((2, 3), "float32") = while_loop(relax.const(0), x) + return gv + + before = Before + expected = Expected + # Perform Lamda Lifting + after = transform.LambdaLift()(before) + assert len(after.functions) == 2 + + assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + +@pytest.mark.skip(reason="Need fix after parser switch over") +def test_multi_func(): + # expected IRModule + @tvm.script.ir_module + class Expected: + @R.function + def glob_func_1( + x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + inner = lifted_func_1 + gv1 = inner(x1, y1) + return gv1 + + @R.function + def glob_func_2( + x11: R.Tensor((10, 5), "float32"), y11: R.Tensor((10, 5), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + inner1 = lifted_func_0 + gv11 = inner1(x11, y11) + return gv11 + + @R.function + def lifted_func_0( + x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + s: R.Tensor((10, 5), "float32") = R.add(x2, y2) + return s + + @R.function + def lifted_func_1( + x21: R.Tensor((10, 5), "float32"), y21: R.Tensor((10, 5), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + s1: R.Tensor((10, 5), "float32") = R.add(x21, y21) + return s1 + + # the IRModule to apply lambda lifting + @tvm.script.ir_module + class Before: + @R.function + def glob_func_1( + x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + @R.function + def inner( + x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + s: R.Tensor((10, 5), "float32") = R.add(x2, y2) + return s + + gv1: R.Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + @R.function + def glob_func_2( + x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + @R.function + def inner( + x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") + ) -> R.Tensor((10, 5), "float32"): + s: R.Tensor((10, 5), "float32") = R.add(x2, y2) + return s + + gv1: R.Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + before = Before + expected = Expected + # Perform Lamda Lifting + after = transform.LambdaLift()(before) + assert len(after.functions) == 4 + assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + +def test_no_local_func(): + @tvm.script.ir_module + class Before: + @T.prim_func + def sub( + A: T.Buffer[(16, 16), "float32"], + B: T.Buffer[(16, 16), "float32"], + C: T.Buffer[(16, 16), "float32"], + ) -> None: + for i, j in T.grid(16, 16): + with T.block("sub"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] - B[vi, vj] + + @R.function + def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor(dtype="float32", ndim=2)): + s = R.call_tir(sub, (c0, x), R.Tensor((16, 16), dtype="float32")) + return s + + before = Before + # Perform lambda lifting + after = transform.LambdaLift()(before) + # No local functions are lifted + assert_structural_equal(after, before, map_free_vars=True) + _check_save_roundtrip(after) + + +if __name__ == "__main__": + tvm.testing.main() From af63d19d7de5823d6f9aa35ebeb18b3848d74d48 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 16 Feb 2023 21:12:02 -0500 Subject: [PATCH 27/81] [Unity][VM] Supporting "compiled" exec mode. (#14015) [VM] Supporting "compiled" exec mode. This PR adds support of "compiled" mode to the VM. The compiled mode translate the relax function into TIR function and drive it through the TIR function. It is different from the micro AOT codegen, which generate TIR code that targets the micro C runtime environment and useful for resource limited settings with smaller set of features. Both leverages the low-level TIR build that is also shared with TensorIR. The current implementation targets full TVM (VM) runtime, that comes with PackedFunc, object, tuple, closure and all kinds of rich structure support. This also mean that we can leverage the full runtime support to handle things like allocation, dynamic shape, easy plugins and python interaction, which are not available in more limited runtime. The user directly use the same API to load the generated code regardless of compiled mode or bytecode. And just need to change one line ```python ex = relax.vm.build(mod, target, exec_mode="compiled") ``` The simplicity is thanks to the TVM runtime archiecture that allows us to compose things together in objects. The only difference is how the PackedFunc of high-level driving is being provided. In the case of bytecode it is normal interpretation and in the case of compiled mode it is TIR. It is a complete implementation Unit-testcases are added. All codegen build tests are updated to include two exec_modes and have passed locally. Co-authored-by: Junru Shao --- include/tvm/tir/builtin.h | 44 ++ python/tvm/script/ir_builder/tir/ir.py | 8 + python/tvm/tir/op.py | 68 +++ src/relax/backend/vm/codegen_vm_tir.cc | 511 +++++++++++++++++++++ src/runtime/library_module.cc | 5 +- src/target/llvm/codegen_cpu.cc | 6 +- src/tir/op/builtin.cc | 12 + src/tir/op/runtime.cc | 41 ++ src/tir/transforms/lower_tvm_builtin.cc | 169 +++---- tests/python/relax/test_vm_build.py | 2 +- tests/python/relax/test_vm_codegen_only.py | 2 +- tests/python/relax/test_vm_codegen_tir.py | 224 +++++++++ 12 files changed, 1006 insertions(+), 86 deletions(-) create mode 100644 src/relax/backend/vm/codegen_vm_tir.cc create mode 100644 src/tir/op/runtime.cc create mode 100644 tests/python/relax/test_vm_codegen_tir.py diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 708abde2cd31..35022e0e75f4 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -762,6 +762,50 @@ TVM_DLL const Op& start_profile_intrinsic(); */ TVM_DLL const Op& end_profile_intrinsic(); +/*! + * \brief Get a item from any list and return it. + * + * Any anylist_getitem(Handle anylist, + * int index) + * return anylist[index]; + * } + * + * \note This intrinsic is only applicable when appearing + * in call_packed and anylist_setitem_call_packed. + */ +TVM_DLL const Op& anylist_getitem(); + +/*! + * \brief Reset and clear a item in any list. + * + * void anylist_resetitem(Handle anylist, + * int index) + * anylist[index] = nullptr; + * } + * + * \note This intrinsic is only applicable when appearing + * in call_packed and anylist_setitem_call_packed. + */ +TVM_DLL const Op& anylist_resetitem(); + +/*! + * \brief Set an item into any list by running packed function call. + * + * void anylist_setitem_call_packed(Handle anylist, + * int index, + * name, *args) + * + * anylist[index] = call_packed(name, *args) + * } + * \note This intrinsic can be used in combination with anylist_getitem. + */ +TVM_DLL const Op& anylist_setitem_call_packed(); + +/*! + * \brief Same as anylist_setitem_call_packed but use C calling convention. + */ +TVM_DLL const Op& anylist_setitem_call_cpacked(); + /*! \brief The kind of structure field info used in intrinsic */ enum TVMStructFieldKind : int { // array head address diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 5f4e9d4f2cf0..601963565fff 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1713,6 +1713,10 @@ def wrapped(*args, **kwargs): TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace) start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic) end_profile_intrinsic = _op_wrapper(_tir_op.end_profile_intrinsic) +anylist_getitem = _op_wrapper(_tir_op.anylist_getitem) +anylist_resetitem = _op_wrapper(_tir_op.anylist_resetitem) +anylist_setitem_call_packed = _op_wrapper(_tir_op.anylist_setitem_call_packed) +anylist_setitem_call_cpacked = _op_wrapper(_tir_op.anylist_setitem_call_cpacked) def _dtype_forward(func): @@ -1988,6 +1992,10 @@ def wrapped(*args, **kwargs): "start_profile_intrinsic", "end_profile_intrinsic", "meta_var", + "anylist_getitem", + "anylist_resetitem", + "anylist_setitem_call_packed", + "anylist_setitem_call_cpacked", "llvm_lookup_intrinsic_id", "type_annotation", "broadcast", diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 0a9c4fdfaa52..14decca77e51 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -2931,6 +2931,74 @@ def TVMBackendFreeWorkspace(device_type, device_id, ptr): return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr) +def anylist_getitem(list_handle, index): + """Returns an item from any list. + list_handle: Var + The handle to anylist + index : int + The index + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("handle", "tir.anylist_getitem", list_handle, index) + + +def anylist_resetitem(list_handle, index): + """Reset an item from any list. + list_handle: Var + The handle to anylist + index : int + The index + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("int", "tir.anylist_resetitem", list_handle, index) + + +def anylist_setitem_call_packed(list_handle, index, func_name, *args): + """Set anylist item by result of packed call. + list_handle: Var + The handle to anylist + index : int + The index + func_name: str + The name of the function to be called. + args: + Extra arguments + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + "int", "tir.anylist_setitem_call_packed", list_handle, index, func_name, *args + ) + + +def anylist_setitem_call_cpacked(list_handle, index, func_name, *args): + """Set anylist item by result of packed call. + list_handle: Var + The handle to anylist + index : int + The index + func_name: str + The name of the function to be called. + args: + Extra arguments + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + "int", "tir.anylist_setitem_call_cpacked", list_handle, index, func_name, *args + ) + + # pylint: disable=unnecessary-lambda sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum") min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc new file mode 100644 index 000000000000..2f63a50d370f --- /dev/null +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -0,0 +1,511 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/vm/codegen_tir.cc + * \brief A codegen to generate VMTIR function(that can be compiled) from executable. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace relax { +namespace relax_vm { + +using vm::VMFuncInfo; + +/*! + * \brief A class to generate VMTIR for Relax functions. + * + * \note Skip CallPacked with special attrs for now, as they can be + * further simplified with PrimValue. + */ +class CodeGenVMTIR : public ExprFunctor(const Expr&)> { + public: + explicit CodeGenVMTIR(relax::ExecBuilder builder, IRModule ctx_mod) + : builder_(builder), ctx_mod_(ctx_mod) {} + + static IRModule Run(relax::ExecBuilder builder, IRModule mod) { + // create a new copy + IRModule res_mod = mod; + res_mod.CopyOnWrite(); + + CodeGenVMTIR codegen(builder, mod); + // Remove relax function and turn into TIR func. + for (auto& p : mod->functions) { + if (auto* func = p.second.as()) { + auto tir_func = codegen.Codegen(GetRef(func)); + auto gsymbol = tir_func->GetAttr(tvm::attr::kGlobalSymbol); + res_mod->Add(GlobalVar(gsymbol.value()), tir_func); + res_mod->Remove(p.first); + } + } + return res_mod; + } + + private: + int64_t NewRegister() { return registers_num_++; } + + static IntImm ConstInt64(int64_t value) { return IntImm(DataType::Int(64), value); } + + static IntImm ConstInt32(int64_t value) { return IntImm(DataType::Int(32), value); } + + PrimExpr RegListGet(int64_t slot) const { + // use 128 bits to represent any + return tir::Call(DataType::Handle(), tir::builtin::anylist_getitem(), + {reg_anylist_handle_, ConstInt32(slot)}); + } + + PrimExpr ConstListGet(int64_t slot) const { + // use 128 bits to represent any + return tir::Call(DataType::Handle(), tir::builtin::anylist_getitem(), + {const_anylist_handle_, ConstInt32(slot)}); + } + + PrimExpr FuncListGet(int64_t slot) const { + // use 128 bits to represent any + return tir::Call(DataType::Handle(), tir::builtin::anylist_getitem(), + {func_anylist_handle_, ConstInt32(slot)}); + } + + void EmitStmt(tir::Stmt stmt) { + ICHECK(!stmt_stack_.empty()); + stmt_stack_.back().emplace_back(stmt); + } + + void EmitCallPacked(String name, const Array& args, int64_t dst_anylist_slot = -1) { + Array all_args; + // negative index indicate return value can be discarded, emit call_packed + if (dst_anylist_slot >= 0) { + all_args = {reg_anylist_handle_, ConstInt32(dst_anylist_slot)}; + } + all_args.push_back(tir::StringImm(name)); + for (PrimExpr arg : args) { + all_args.push_back(arg); + } + if (dst_anylist_slot >= 0) { + this->EmitStmt(tir::Evaluate( + tir::Call(DataType::Int(32), tir::builtin::anylist_setitem_call_packed(), all_args))); + } else { + this->EmitStmt( + tir::Evaluate(tir::Call(DataType::Int(32), tir::builtin::tvm_call_packed(), all_args))); + } + } + + void EmitCallCPacked(const tir::PrimFunc& prim_func, const Array& args, + int64_t dst_anylist_slot = -1) { + Optional gsymbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(gsymbol.defined()) << "All functions must have global symbol at this phase"; + Array all_args; + // negative index indicate return value can be discarded, emit call_packed + if (dst_anylist_slot >= 0) { + all_args = {reg_anylist_handle_, ConstInt32(dst_anylist_slot)}; + } + all_args.push_back(tir::StringImm(gsymbol.value())); + for (PrimExpr arg : args) { + all_args.push_back(arg); + } + // push an empty handle to be compatible with current cpacked convention + // TODO(tqchen): revisit C Packed convention + all_args.push_back(tir::make_zero(DataType::Handle())); + if (dst_anylist_slot >= 0) { + this->EmitStmt(tir::Evaluate( + tir::Call(DataType::Int(32), tir::builtin::anylist_setitem_call_cpacked(), all_args))); + } else { + this->EmitStmt( + tir::Evaluate(tir::Call(DataType::Int(32), tir::builtin::tvm_call_cpacked(), all_args))); + } + } + + tir::PrimFunc Codegen(const Function& func) { + Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(gsymbol.defined()) << "there should be no local functions in Relax VM codegen phase. " + "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?"; + // initialize the state + stmt_stack_ = {}; + registers_num_ = 0; + var_map_.clear(); + ctx_ptr_ = tir::Var("ctx_ptr", DataType::Handle()); + reg_anylist_handle_ = tir::Var("r", DataType::Handle()); + func_anylist_handle_ = tir::Var("f", DataType::Handle()); + const_anylist_handle_ = tir::Var("c", DataType::Handle()); + + Array param_names; + for (Var param : func->params) { + param_names.push_back(param->name_hint()); + } + // declare this function. + builder_->DeclareFunction(gsymbol.value(), vm::VMFuncInfo::FuncKind::kVMTIRFunc); + + for (size_t i = 0; i < func->params.size(); ++i) { + int64_t r = NewRegister(); + ICHECK_EQ(static_cast(r), i); + this->var_map_.insert({func->params[i], RegListGet(r)}); + } + size_t ret_reg = NewRegister(); + + tir::Stmt body = WithNewScope([&]() { + Optional ret = ExprFunctor::VisitExpr(func->body); + if (ret.defined()) { + this->EmitCallPacked("vm.builtin.copy", {ret.value()}, ret_reg); + } + }); + + // Mark the function entry internally. + builder_->EmitFunction(gsymbol.value(), param_names.size(), param_names, + VMFuncInfo::FuncKind::kVMTIRFunc, registers_num_); + builder_->EndFunction(gsymbol.value()); + + Type ret_type = VoidType(); + Array tir_params = {ctx_ptr_, reg_anylist_handle_, const_anylist_handle_, + func_anylist_handle_}; + String tir_func_name = "__vmtir__" + gsymbol.value(); + tir::PrimFunc tir_func(tir_params, body, ret_type, {}); + tir_func = WithAttr(tir_func, "global_symbol", tir_func_name); + registers_num_ = 0; + var_map_.clear(); + stmt_stack_.clear(); + return tir_func; + } + + Optional VisitExpr_(const SeqExprNode* op) final { + for (auto block : op->blocks) { + for (Binding binding : block->bindings) { + Optional value; + if (auto* var_binding = binding.as()) { + value = this->VisitExpr(var_binding->value); + } else if (auto* match_cast = binding.as()) { + value = this->VisitExpr(match_cast->value); + } else { + LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey(); + } + this->var_map_.insert({binding->var, value}); + } + } + return this->VisitExpr(op->body); + } + + Optional VisitExpr_(const CallNode* call_node) final { + Call call = GetRef(call_node); + + if (call_node->op == null_value_op_) { + return tir::Call(DataType::Handle(), tir::builtin::reinterpret(), + {IntImm(DataType::Int(64), 0)}); + } + int64_t dst_reg = HasVoidStructInfo(call) ? -1 : NewRegister(); + if (call->op.as()) { + if (call_node->op == call_builtin_with_ctx_op_) { + EmitCallBuiltinWithCtx(call, dst_reg); + } else if (call_node->op == alloc_storage_op_) { + EmitAllocStorage(call, dst_reg); + } else if (call_node->op == alloc_tensor_op_) { + EmitAllocTensor(call, dst_reg); + } else { + // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those + // ops are handled in a pass when lowering them to TIR. + LOG(FATAL) << "CodeGenVMTIR cannot handle this intrinsic now:\n" << call_node->op; + } + } else { + EmitNormalCall(call, dst_reg); + } + if (dst_reg >= 0) { + return RegListGet(dst_reg); + } else { + return NullOpt; + } + } + + Optional VisitExpr_(const IfNode* op) final { + // Reserve a register for return + size_t merge_register = NewRegister(); + PrimExpr cond_value = this->VisitExpr(op->cond).value(); + + // turn ndarray cond value into scalar. + cond_value = tir::Cast(DataType::Bool(), + tir::Call(DataType::Int(32), tir::builtin::tvm_call_packed(), + {tir::StringImm("vm.builtin.read_if_cond"), cond_value})); + + tir::Stmt true_branch = WithNewScope([&]() { + PrimExpr true_value = this->VisitExpr(op->true_branch).value(); + this->EmitCallPacked("vm.builtin.copy", {true_value}, merge_register); + }); + tir::Stmt false_branch = WithNewScope([&]() { + PrimExpr false_value = this->VisitExpr(op->false_branch).value(); + this->EmitCallPacked("vm.builtin.copy", {false_value}, merge_register); + }); + this->EmitStmt(tir::IfThenElse(cond_value, true_branch, false_branch)); + return RegListGet(merge_register); + } + + Optional VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + auto it = this->var_map_.find(var); + ICHECK(it != this->var_map_.end()) << "Var " << var << " is not defined"; + return it->second; + } + + Optional VisitExpr_(const ConstantNode* op) final { + return ConstListGet(builder_->ConvertConstant(op->data).value()); + } + + Optional VisitExpr_(const ShapeExprNode* op) final { + std::vector shape; + for (PrimExpr e : op->values) { + if (auto* int_value = e.as()) { + shape.push_back(int_value->value); + } else { + LOG(FATAL) << "Should only use constant shape after shape lowering: " << op->values; + } + } + return ConstListGet(builder_->ConvertConstant(ShapeTuple(shape)).value()); + } + + Optional VisitExpr_(const PrimValueNode* op) final { return op->value; } + + Optional VisitExpr_(const StringImmNode* op) final { + return ConstListGet(builder_->ConvertConstant(op->value).value()); + } + + Optional VisitExpr_(const DataTypeImmNode* op) final { + return ConstListGet(builder_->ConvertConstant(op->value).value()); + } + + Optional VisitExpr_(const TupleNode* op) final { + Tuple tuple = GetRef(op); + Array args; + for (auto arg : tuple->fields) { + args.push_back(this->VisitExpr(arg).value()); + } + int32_t dst_register = NewRegister(); + this->EmitCallPacked("vm.builtin.make_tuple", args, dst_register); + return RegListGet(dst_register); + } + + Optional VisitExpr_(const TupleGetItemNode* op) final { + TupleGetItem expr = GetRef(op); + Array args = {this->VisitExpr(expr->tuple).value()}; + + args.push_back(ConstInt64(expr->index)); + + int64_t dst_register = NewRegister(); + this->EmitCallPacked("vm.builtin.tuple_getitem", args, dst_register); + return RegListGet(dst_register); + } + + // Lookup the function and see if it matches + Optional LookupFunction(const Expr& expr, VMFuncInfo::FuncKind* kind) { + if (auto* ext_func = expr.as()) { + *kind = VMFuncInfo::FuncKind::kPackedFunc; + return ext_func->global_symbol; + } else if (auto* gvar_ptr = expr.as()) { + GlobalVar gvar = GetRef(gvar_ptr); + // Run a look up in the env to see if it maps to an extern func. + auto it = ctx_mod_->functions.find(gvar); + if (it != ctx_mod_->functions.end()) { + BaseFunc func = (*it).second; + if (auto* efunc = func.as()) { + *kind = VMFuncInfo::FuncKind::kPackedFunc; + return efunc->global_symbol; + } else if (func.as()) { + *kind = VMFuncInfo::FuncKind::kVMTIRFunc; + return gvar->name_hint; + } else if (func.as()) { + *kind = VMFuncInfo::FuncKind::kPackedFunc; + return gvar->name_hint; + } else { + *kind = VMFuncInfo::FuncKind::kPackedFunc; + return gvar->name_hint; + } + } + LOG(WARNING) << "Undefined global var " << gvar->name_hint; + // undefined global var, consider eliminate later. + *kind = VMFuncInfo::FuncKind::kPackedFunc; + return gvar->name_hint; + } else { + return NullOpt; + } + } + // Lookup PrimFunc in the same module + // We can do direct PrimFunc call in such cases + Optional LookupPrimFunc(const String& name) { + if (!ctx_mod_->ContainGlobalVar(name)) return NullOpt; + + GlobalVar gvar = ctx_mod_->GetGlobalVar(name); + auto it = ctx_mod_->functions.find(gvar); + if (it != ctx_mod_->functions.end()) { + BaseFunc func = (*it).second; + if (auto* prim_func = func.as()) { + return GetRef(prim_func); + } + } + return NullOpt; + } + + Optional VisitExpr_(const GlobalVarNode* op) final { + VMFuncInfo::FuncKind kind; + auto symbol = LookupFunction(GetRef(op), &kind); + ICHECK(symbol.defined()); + builder_->DeclareFunction(symbol.value(), kind); + return FuncListGet(builder_->GetFunction(symbol.value()).value()); + } + + Optional VisitExpr_(const ExternFuncNode* op) final { + builder_->DeclareFunction(op->global_symbol, VMFuncInfo::FuncKind::kPackedFunc); + return FuncListGet(builder_->GetFunction(op->global_symbol).value()); + } + + void EmitAllocStorage(const Call& call_node, int64_t dst_reg) { + // Handle args of the call + Array args; + args.push_back(ctx_ptr_); + for (Expr arg : call_node->args) { + args.push_back(this->VisitExpr(arg).value()); + } + this->EmitCallPacked("vm.builtin.alloc_storage", args, dst_reg); + } + + void EmitAllocTensor(const Call& call_node, int64_t dst_reg) { + ICHECK_EQ(call_node->args.size(), 4); + Array args; + args.reserve(4); + for (Expr arg : call_node->args) { + args.push_back(this->VisitExpr(arg).value()); + } + this->EmitCallPacked("vm.builtin.alloc_tensor", args, dst_reg); + } + + void EmitCallBuiltinWithCtx(const Call& call_node, int64_t dst_reg) { + Array args; + // if context is required, pass as first argument. + args.push_back(ctx_ptr_); + auto* func = call_node->args[0].as(); + ICHECK(func) << "CallBuiltin comes with extern func"; + + auto tuple_arg = Downcast(call_node->args[1]); + + // Handle args of the call + for (Expr arg : tuple_arg->fields) { + args.push_back(this->VisitExpr(arg).value()); + } + + this->EmitCallPacked(func->global_symbol, args, dst_reg); + } + + void EmitNormalCall(const Call& call_node, int64_t dst_reg) { + Array args = VisitArray(call_node->args); + // A function can be a closure that comes from parent + // Do call closure to be safe. + VMFuncInfo::FuncKind kind; + auto symbol = LookupFunction(call_node->op, &kind); + + if (symbol.defined() && kind == VMFuncInfo::FuncKind::kPackedFunc) { + // primfunc in the same module. + // use cpacked to directly invoke without named based lookup + if (Optional prim_func = LookupPrimFunc(symbol.value())) { + this->EmitCallCPacked(prim_func.value(), args, dst_reg); + } else { + this->EmitCallPacked(symbol.value(), args, dst_reg); + } + } else { + // Default path, leverage function table and invoke as closure + Array all_args; + all_args.push_back(ctx_ptr_); + all_args.push_back(this->VisitExpr(call_node->op).value()); + for (auto arg : args) { + all_args.push_back(arg); + } + this->EmitCallPacked("vm.builtin.invoke_closure", all_args, dst_reg); + } + } + + template + tir::Stmt WithNewScope(const FLambda& callback) { + stmt_stack_.push_back({}); + callback(); + tir::Stmt stmt = tir::SeqStmt::Flatten(stmt_stack_.back()); + stmt_stack_.pop_back(); + return stmt; + } + + Array VisitArray(const Array& arr) { + Array ret; + for (size_t i = 0; i < arr.size(); ++i) { + ret.push_back(this->VisitExpr(arr[i]).value()); + } + return ret; + } + /*! \brief Internal ExecBuilder. */ + relax::ExecBuilder builder_; + /*! \brief List to ctx_ptr */ + tir::Var ctx_ptr_; + /*! \brief List to store temp object registers */ + tir::Var reg_anylist_handle_; + /*! \brief List to store closures */ + tir::Var func_anylist_handle_; + /*! \brief List to store constants */ + tir::Var const_anylist_handle_; + /*! + * \brief Total number of virtual registers allocated. + * \note The first two registers are reserved for special registers. + */ + int64_t registers_num_ = 0; + /*! \brief Stack to build up statements */ + std::vector> stmt_stack_; + /*! \brief Map from var to Expr. */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> var_map_; + /*! \brief the context module. */ + IRModule ctx_mod_; + /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */ + const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage"); + const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor"); + const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); + const Op& null_value_op_ = Op::Get("relax.null_value"); +}; + +/*! + * \brief Create the Relax VM executable from all relax.Function in mod. + * and add them to exec_builder. Create extra TIR functions. + * + * \param exec_builder Builder to collect executables. + * \param mod Input module. + * \return Extra TIR module created. + */ +IRModule VMTIRCodeGen(ExecBuilder exec_builder, IRModule mod) { + return CodeGenVMTIR::Run(exec_builder, mod); +} + +TVM_REGISTER_GLOBAL("relax.VMTIRCodeGen").set_body_typed(VMTIRCodeGen); + +} // namespace relax_vm +} // namespace relax +} // namespace tvm diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index d6c2f791deb9..17dfbec0d054 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -77,7 +77,10 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& int ret_type_code = kTVMNullptr; int ret = (*faddr)(const_cast(args.values), const_cast(args.type_codes), args.num_args, &ret_value, &ret_type_code, nullptr); - ICHECK_EQ(ret, 0) << TVMGetLastError(); + // NOTE: important to keep the original error message. + if (ret != 0) { + LOG(FATAL) << TVMGetLastError(); + } if (ret_type_code != kTVMNullptr) { *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code); } diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 21d2c6ebe0a5..10aa2688a846 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -905,8 +905,10 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage, func_name, module_.get()); } - - nargs -= 1; + // NOTE: This is a bugfix to a previous coupled convention(in lower_tvm_builtin) + // The begin, end should correspond to the right location in cpacked excluding resource handle. + // TODO(tqchen): upstream the fix. + // nargs -= 1; call_args.insert(call_args.end(), { builder_->CreateBitCast(arg_value, t_void_p_), arg_tcode.addr, diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 680202751f12..f9d522804260 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -318,6 +318,18 @@ TIR_DEFINE_BUILTIN_FUNC(start_profile_intrinsic) TIR_DEFINE_BUILTIN_FUNC(end_profile_intrinsic) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); +TIR_DEFINE_BUILTIN_FUNC(anylist_getitem) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kReadState)); + +TIR_DEFINE_BUILTIN_FUNC(anylist_resetitem) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TGlobalSymbol", "TVMBackendAnyListResetItem"); + +TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_packed) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_cpacked) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); } // namespace builtin } // namespace tir } // namespace tvm diff --git a/src/tir/op/runtime.cc b/src/tir/op/runtime.cc new file mode 100644 index 000000000000..9ee6c67ec96b --- /dev/null +++ b/src/tir/op/runtime.cc @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/op/runtime.cc + * \brief TIR ops for runtime functions. + */ +#include +#include + +namespace tvm { +namespace tir { + +TVM_REGISTER_OP("tir.TVMBackendAnyListSetPackedArg") + .set_num_inputs(5) + .set_attr("TGlobalSymbol", "TVMBackendAnyListSetPackedArg") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TVM_REGISTER_OP("tir.TVMBackendAnyListMoveFromPackedReturn") + .set_num_inputs(3) + .set_attr("TGlobalSymbol", "TVMBackendAnyListMoveFromPackedReturn") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 082a54f9c73d..b0a87a3056b4 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -302,13 +302,21 @@ class BuiltinLower : public StmtExprMutator { return Stmt(n); } } + PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_call_packed())) { - return MakeCallPacked(op, /* use_string_lookup */ true); + return MakeCallPackedGeneric(op, 0, builtin::tvm_call_packed_lowered(), + /* use_string_lookup */ true); } else if (op->op.same_as(builtin::tvm_call_cpacked())) { - return MakeCallPacked(op, /* use_string_lookup */ false); + return MakeCallPackedGeneric(op, 0, builtin::tvm_call_cpacked_lowered(), + /* use_string_lookup */ false); } else if (op->op.same_as(builtin::tvm_call_trace_packed())) { - return MakeCallTracePacked(op); + return MakeCallPackedGeneric(op, 0, builtin::tvm_call_trace_packed_lowered(), + /* use_string_lookup */ true); + } else if (op->op.same_as(builtin::anylist_setitem_call_packed())) { + return MakeAnyListSetItemCallPacked(op, builtin::tvm_call_packed_lowered(), true); + } else if (op->op.same_as(builtin::anylist_setitem_call_cpacked())) { + return MakeAnyListSetItemCallPacked(op, builtin::tvm_call_cpacked_lowered(), false); } else if (op->op.same_as(builtin::tvm_stack_make_shape())) { return MakeShape(op); } else if (op->op.same_as(builtin::tvm_stack_make_array())) { @@ -418,8 +426,68 @@ class BuiltinLower : public StmtExprMutator { cast(DataType::Int(32), device_type_))); return TVMStructGet(DataType::Handle(), scope.stack_array, idx, builtin::kArrAddr); } - // call packed. - PrimExpr MakeCallPacked(const CallNode* op, bool use_string_lookup) { + + void SetPackedArg(PrimExpr arg, const Var& value_stack, const Buffer& tcode_stack, + size_t stack_offset, std::vector* prep_seq) { + auto* call_pattern = arg.as(); + if (call_pattern && call_pattern->op.same_as(builtin::anylist_getitem())) { + // call runtime function to set anylist + prep_seq->emplace_back( + Evaluate(Call(DataType::Int(32), Op::Get("tir.TVMBackendAnyListSetPackedArg"), + {call_pattern->args[0], call_pattern->args[1], value_stack, + tcode_stack->data, ConstInt32(stack_offset)}))); + } else { + DataType api_type = APIType(arg.dtype()); + if (arg.dtype() != api_type) { + arg = Cast(api_type, arg); + } + prep_seq->emplace_back( + TVMStructSet(value_stack, stack_offset, builtin::kTVMValueContent, arg)); + int arg_tcode = api_type.code(); + if (api_type.is_handle() && arg.as()) { + arg_tcode = kTVMStr; + } else if (IsArrayHandle(arg)) { + arg_tcode = kTVMDLTensorHandle; + } + // opaque handle need to set the kind properly + if (arg_tcode == kTVMOpaqueHandle) { + prep_seq->emplace_back(IfThenElse( + Call(DataType::Bool(), builtin::isnullptr(), {arg}), + BufferStore(tcode_stack, ConstInt32(kTVMNullptr), {ConstInt32(stack_offset)}), + BufferStore(tcode_stack, ConstInt32(arg_tcode), {ConstInt32(stack_offset)}))); + } else { + prep_seq->emplace_back( + BufferStore(tcode_stack, ConstInt32(arg_tcode), {ConstInt32(stack_offset)})); + } + } + } + + PrimExpr MakeAnyListSetItemCallPacked(const CallNode* op, const Op& lowered_op, + bool use_string_lookup) { + PrimExpr list_handle = op->args[0]; + PrimExpr list_index = op->args[1]; + + Call call = MakeCallPackedGeneric(op, 2, lowered_op, use_string_lookup); + PrimExpr value_stack = call->args[1]; + PrimExpr tcode_stack = call->args[2]; + // The stack offset of return value stack_end + PrimExpr ret_offset = call->args[4]; + auto& prep_seq = prep_seq_stack_.back(); + prep_seq.emplace_back(Evaluate(call)); + return Call(DataType::Int(32), Op::Get("tir.TVMBackendAnyListMoveFromPackedReturn"), + {list_handle, list_index, value_stack, tcode_stack, ret_offset}); + } + /*! + * \brief Generic tool to make low-level + * packed_call(other_args..., func_name, packed_arg0, packed_arg1...) + * + * \param op The call + * \param name_offset The beginning of function name and call packed section. + * \param lowered_packed_op The target lowered op. + * \param use_string_lookup Whether to lookup function by string. + */ + Call MakeCallPackedGeneric(const CallNode* op, size_t name_offset, const Op& lowered_packed_op, + bool use_string_lookup) { auto& scope = alloca_scope_.back(); auto& prep_seq = prep_seq_stack_.back(); @@ -427,34 +495,24 @@ class BuiltinLower : public StmtExprMutator { size_t restore_array_stack = scope.run_sizes.array_stack; size_t arg_stack_begin = scope.run_sizes.arg_stack; - size_t arg_count = op->args.size(); + size_t args_begin = name_offset + 1; + size_t args_end = op->args.size(); // cpacked expects a resource_handle parameter if (!use_string_lookup) { - arg_count--; + --args_end; } + size_t num_args = args_end - args_begin; - scope.run_sizes.arg_stack += arg_count; + // The extra one slot is for return value. + scope.run_sizes.arg_stack += num_args + 1; // Specially handle the buffer packed intrinsic PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - for (size_t i = 1; i < arg_count; ++i) { - PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1); - PrimExpr arg = op->args[i]; - DataType t = arg.dtype(); - DataType api_type = APIType(t); - if (t != api_type) { - arg = Cast(api_type, arg); - } - prep_seq.emplace_back(TVMStructSet(scope.stack_value, - static_cast(arg_stack_begin + i - 1), - builtin::kTVMValueContent, arg)); - int arg_tcode = api_type.code(); - if (api_type.is_handle() && arg.as()) { - arg_tcode = kTVMStr; - } - if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle; - prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index})); + + for (size_t i = 0; i < num_args; ++i) { + this->SetPackedArg(op->args[args_begin + i], scope.stack_value, scope.stack_tcode, + arg_stack_begin + i, &prep_seq); } // Verify stack size matches earlier value. if (is_precheck_) { @@ -465,13 +523,12 @@ class BuiltinLower : public StmtExprMutator { scope.run_sizes.shape_stack = restore_shape_stack; scope.run_sizes.array_stack = restore_array_stack; scope.run_sizes.arg_stack = arg_stack_begin; - Array packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data, - ConstInt32(arg_stack_begin), - ConstInt32(arg_stack_begin + op->args.size() - 1)}; - + Array packed_args = {op->args[name_offset], scope.stack_value, + scope.stack_tcode->data, ConstInt32(arg_stack_begin), + ConstInt32(arg_stack_begin + num_args)}; // cpacked call resource_handle if (!use_string_lookup) { - PrimExpr last_arg = op->args[arg_count]; + PrimExpr last_arg = op->args[args_end]; const VarNode* var_node = last_arg.as(); if (var_node != nullptr) { tir::Var resource_handle = GetRef(var_node); @@ -480,57 +537,7 @@ class BuiltinLower : public StmtExprMutator { packed_args.push_back(last_arg); } } - - auto builtin_call = use_string_lookup ? builtin::tvm_call_packed_lowered() - : builtin::tvm_call_cpacked_lowered(); - return Call(op->dtype, builtin_call, packed_args); - } - - PrimExpr MakeCallTracePacked(const CallNode* op) { - ICHECK(!alloca_scope_.empty()); - auto& scope = alloca_scope_.back(); - auto& prep_seq = prep_seq_stack_.back(); - - int64_t restore_shape_stack = scope.run_sizes.shape_stack; - size_t restore_array_stack = scope.run_sizes.array_stack; - size_t arg_stack_begin = scope.run_sizes.arg_stack; - scope.run_sizes.arg_stack += op->args.size(); - size_t args_size = op->args.size(); - ICHECK_GT(args_size, 0); - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - for (size_t i = 1; i < op->args.size(); ++i) { - PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1); - PrimExpr arg = op->args[i]; - DataType t = arg.dtype(); - DataType api_type = APIType(t); - if (t != api_type) { - arg = Cast(api_type, arg); - } - prep_seq.emplace_back(TVMStructSet(scope.stack_value, - static_cast(arg_stack_begin + i - 1), - builtin::kTVMValueContent, arg)); - int arg_tcode = api_type.code(); - ICHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers"; - prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index})); - } - // Verify stack size matches earlier value. - if (is_precheck_) { - scope.UpdateMax(); - } else { - scope.AssertMaxIsValid(); - } - scope.run_sizes.shape_stack = restore_shape_stack; - scope.run_sizes.array_stack = restore_array_stack; - // Update the top of the stack, so we can use more than one - // packed function's arguments with the one stack. - scope.run_sizes.arg_stack = arg_stack_begin + args_size - 1; - Array packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data, - ConstInt32(arg_stack_begin), - ConstInt32(arg_stack_begin + op->args.size() - 1), - // Pass traced value. - op->args[args_size - 1]}; - return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args); + return Call(op->dtype, lowered_packed_op, packed_args); } Stmt MakeNdMemAllocWithScope(const LetStmtNode* let, const CallNode* call) { diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index 0a881691accc..d57efd8b9992 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -30,7 +30,7 @@ from tvm.script import relax as R, tir as T from tvm.relax.testing.vm import check_saved_func -EXEC_MODE = ["bytecode"] +EXEC_MODE = ["bytecode", "compiled"] @pytest.mark.parametrize("exec_mode", EXEC_MODE) diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py index 4b79ecf70fa1..600d2456174e 100644 --- a/tests/python/relax/test_vm_codegen_only.py +++ b/tests/python/relax/test_vm_codegen_only.py @@ -28,7 +28,7 @@ from tvm.script import relax as R from tvm.script import tir as T -EXEC_MODE = ["bytecode"] +EXEC_MODE = ["bytecode", "compiled"] def codegen(mod, target, exec_mode="bytecode"): diff --git a/tests/python/relax/test_vm_codegen_tir.py b/tests/python/relax/test_vm_codegen_tir.py new file mode 100644 index 000000000000..6f3bced38581 --- /dev/null +++ b/tests/python/relax/test_vm_codegen_tir.py @@ -0,0 +1,224 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test the TIR codegen path of VM compiled mode. + +Restrictions: all shape lowered, explicit allocation. +""" +import tvm +import tvm.testing +from tvm import relax +from tvm.ir import assert_structural_equal +from tvm.script import relax as R +from tvm.script import tir as T + + +def get_tir_mod(mod): + builder = relax.ExecBuilder() + return relax.vm._vmcodegen(builder, mod, exec_mode="compiled") + + +def test_add(): + @tvm.script.ir_module + class Before: + @R.function + def foo(x: R.Tensor): + R.func_attr({"global_symbol": "foo"}) + z = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor)) + return z + + @tvm.script.ir_module + class Expected: + @T.prim_func + def __vmtir__foo(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): + T.func_attr({"global_symbol": "__vmtir__foo"}) + T.anylist_setitem_call_packed( + r, + T.int32(2), + "test.vm.add", + T.anylist_getitem(r, T.int32(0)), + T.anylist_getitem(r, T.int32(0)), + ) + T.anylist_setitem_call_packed( + r, T.int32(1), "vm.builtin.copy", T.anylist_getitem(r, T.int32(2)) + ) + + before = Before + expected = Expected + after = get_tir_mod(before) + assert_structural_equal(expected, after) + + +def test_tir_call(): + @tvm.script.ir_module + class Before: + @T.prim_func + def shape_func(H: T.Buffer(T.int64(4), "int64")): + T.func_attr({"global_symbol": "shape_func"}) + # generated compute function + H[T.int64(0)] = H[T.int64(0)] + T.int64(1) + + @R.function + def foo(x: R.Tensor): + R.func_attr({"global_symbol": "foo"}) + _ = shape_func(x) + return x + + @tvm.script.ir_module + class Expected: + @T.prim_func + def shape_func(H: T.Buffer(T.int64(4), "int64")): + T.func_attr({"global_symbol": "shape_func"}) + # generated compute function + H[T.int64(0)] = H[T.int64(0)] + T.int64(1) + + @T.prim_func + def __vmtir__foo(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): + T.func_attr({"global_symbol": "__vmtir__foo"}) + T.call_cpacked( + "shape_func", T.anylist_getitem(r, T.int32(0)), T.reinterpret("handle", T.uint64(0)) + ) + T.anylist_setitem_call_packed( + r, T.int32(1), "vm.builtin.copy", T.anylist_getitem(r, T.int32(0)) + ) + + before = Before + expected = Expected + after = get_tir_mod(before) + assert_structural_equal(expected, after) + + +def test_if_cond(): + @tvm.script.ir_module + class Before: + @R.function + def ife(cond: R.Tensor((), "bool"), x: R.Tensor) -> R.Tensor: + R.func_attr({"global_symbol": "ife"}) + if cond: + w = R.call_packed("test.vm.add", x, x, sinfo_args=(R.Tensor)) + else: + w = R.call_packed("test.vm.mul", x, x, sinfo_args=(R.Tensor)) + return w + + @tvm.script.ir_module + class Expected: + @T.prim_func + def __vmtir__ife(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): + T.func_attr({"global_symbol": "__vmtir__ife"}) + if T.cast( + T.tvm_call_packed("vm.builtin.read_if_cond", T.anylist_getitem(r, T.int32(0))), + "bool", + ): + T.anylist_setitem_call_packed( + r, + T.int32(4), + "test.vm.add", + T.anylist_getitem(r, T.int32(1)), + T.anylist_getitem(r, T.int32(1)), + ) + T.anylist_setitem_call_packed( + r, T.int32(3), "vm.builtin.copy", T.anylist_getitem(r, T.int32(4)) + ) + else: + T.anylist_setitem_call_packed( + r, + T.int32(5), + "test.vm.mul", + T.anylist_getitem(r, T.int32(1)), + T.anylist_getitem(r, T.int32(1)), + ) + T.anylist_setitem_call_packed( + r, T.int32(3), "vm.builtin.copy", T.anylist_getitem(r, T.int32(5)) + ) + T.anylist_setitem_call_packed( + r, T.int32(2), "vm.builtin.copy", T.anylist_getitem(r, T.int32(3)) + ) + + before = Before + expected = Expected + after = get_tir_mod(before) + assert_structural_equal(expected, after) + + +def test_const(): + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor): + R.func_attr({"global_symbol": "main"}) + y = R.const([1, 2]) + z = (y, R.const([3, 4]), x) + return z + + @tvm.script.ir_module + class Expected: + @T.prim_func + def __vmtir__main(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): + # function attr dict + T.func_attr({"global_symbol": "__vmtir__main"}) + # body + T.anylist_setitem_call_packed( + r, + T.int32(2), + "vm.builtin.make_tuple", + T.anylist_getitem(c, T.int32(0)), + T.anylist_getitem(c, T.int32(1)), + T.anylist_getitem(r, T.int32(0)), + ) + T.anylist_setitem_call_packed( + r, T.int32(1), "vm.builtin.copy", T.anylist_getitem(r, T.int32(2)) + ) + + before = Before + expected = Expected + after = get_tir_mod(before) + assert_structural_equal(expected, after) + + +def test_const_call(): + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor): + R.func_attr({"global_symbol": "main"}) + y = R.const([1, 2]) + z = R.call_packed("test.vm.add", x, y, sinfo_args=(R.Tensor)) + return z + + @tvm.script.ir_module + class Expected: + @T.prim_func + def __vmtir__main(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): + # function attr dict + T.func_attr({"global_symbol": "__vmtir__main"}) + # body + T.anylist_setitem_call_packed( + r, + 2, + "test.vm.add", + T.anylist_getitem(r, 0), + T.anylist_getitem(c, 0), + ) + T.anylist_setitem_call_packed(r, 1, "vm.builtin.copy", T.anylist_getitem(r, 2)) + + before = Before + expected = Expected + after = get_tir_mod(before) + assert_structural_equal(expected, after) + + +if __name__ == "__main__": + tvm.testing.main() From b06d77929ad66cb99b622b77b1629c669b99c9b0 Mon Sep 17 00:00:00 2001 From: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Date: Thu, 16 Feb 2023 19:35:03 -0800 Subject: [PATCH 28/81] [Unity][Pass] BindParams pass, FoldConstant pass (#14016) This PR introduces FoldConstant/BindParam passes. Co-authored-by: Yong Wu Co-Authored-by: Hongyi Jin <3231950289@qq.com> Co-Authored-by: Siyuan Feng --- include/tvm/ir/function.h | 133 ++++++--- include/tvm/relax/transform.h | 15 + python/tvm/relax/transform/transform.py | 62 +++- src/relax/transform/bind_params.cc | 113 +++++++ src/relax/transform/fold_constant.cc | 230 ++++++++++++++ .../relax/test_transform_bind_params.py | 75 +++++ .../relax/test_transform_fold_constant.py | 280 ++++++++++++++++++ 7 files changed, 861 insertions(+), 47 deletions(-) create mode 100644 src/relax/transform/bind_params.cc create mode 100644 src/relax/transform/fold_constant.cc create mode 100644 tests/python/relax/test_transform_bind_params.py create mode 100644 tests/python/relax/test_transform_fold_constant.py diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 1493544e7324..381ea6b8d6d3 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -65,6 +65,68 @@ enum class CallingConv : int { kDeviceKernelLaunch = 2, }; +/*! + * \brief Supported linkage types. + */ +enum class LinkageType : int { + /*! + * \brief Internal linkage. + */ + kInternal = 0, + /*! + * \brief External linkage. + - Function with external linkage should have a global symbol attached to it. + */ + kExternal = 1 +}; + +/*! + * \brief Generic attribute names that can be attached to any function. + * + * \sa tvm::tir::attr, tvm::relay::attr + */ +namespace attr { +/*! + * \brief Indicates the special calling convention. + * + * Type: Integer + * + * \sa tvm::CallingConv + */ +constexpr const char* kCallingConv = "calling_conv"; + +/*! + * \brief Compilation target of the function. + * + * Type: Target + * + * \sa tvm::Target + */ +constexpr const char* kTarget = "target"; + +/*! + * \brief Global linker symbol of the function in generated code. + * + * This option forces the code generator to name the + * function with the given. + * + * For example, we could set a global_symbol of a function + * early to make sure that we can always refer to it by + * the symbol name in the generated DLL. + * + * We should not set the attribute for local functions, + * so that the compiler can freely rename them. + * + * A unique global symbol will be automatically assigned + * to each function in the module before the target code + * generation phase. + * + * Type: String + */ +constexpr const char* kGlobalSymbol = "global_symbol"; + +} // namespace attr + /*! * \brief Base node of all functions. * @@ -130,6 +192,31 @@ class BaseFuncNode : public RelayExprNode { * \endcode */ bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); } + /*! + * \brief Get the type of the linkage. + * + * Currently, we only consider external/internal linkage. + * This can be extended in the future when necessary. + * + * \return Linkage type. + * + * \code + * + * void Example(const BaseFunc& f) { + * if (f->GetLinkageType() == tvm::LinkageType::kExternal) { + * // Do not remove a function with external linkage + * } + * } + * + * \endcode + */ + + LinkageType GetLinkageType() const { + if (GetAttr(attr::kGlobalSymbol)) + return LinkageType::kExternal; + else + return LinkageType::kInternal; + } static constexpr const char* _type_key = "BaseFunc"; static constexpr const uint32_t _type_child_slots = 2; @@ -145,51 +232,5 @@ class BaseFunc : public RelayExpr { TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode); }; -/*! - * \brief Generic attribute names that can be attached to any function. - * - * \sa tvm::tir::attr, tvm::relay::attr - */ -namespace attr { -/*! - * \brief Indicates the special calling convention. - * - * Type: Integer - * - * \sa tvm::CallingConv - */ -constexpr const char* kCallingConv = "calling_conv"; - -/*! - * \brief Compilation target of the function. - * - * Type: Target - * - * \sa tvm::Target - */ -constexpr const char* kTarget = "target"; - -/*! - * \brief Global linker symbol of the function in generated code. - * - * This option forces the code generator to name the - * function with the given. - * - * For example, we could set a global_symbol of a function - * early to make sure that we can always refer to it by - * the symbol name in the generated DLL. - * - * We should not set the attribute for local functions, - * so that the compiler can freely rename them. - * - * A unique global symbol will be automatically assigned - * to each function in the module before the target code - * generation phase. - * - * Type: String - */ -constexpr const char* kGlobalSymbol = "global_symbol"; - -} // namespace attr } // namespace tvm #endif // TVM_IR_FUNCTION_H_ diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index ff98b16d251e..dab062588a82 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -99,7 +99,22 @@ TVM_DLL Pass RewriteDataflowReshape(); * \return The Pass. */ TVM_DLL Pass AttachGlobalSymbol(); +/*! + * \brief Bind params of function of the module to constant tensors. + * + * \param func_name The name of the function to bind parameters. + * \param params The parameters to bind. + * + * \return The Pass. + */ +TVM_DLL Pass BindParams(String func_name, Map params); +/*! + * \brief Fold constant expressions. + * + * \return The Pass. + */ +TVM_DLL Pass FoldConstant(); } // namespace transform } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 1a525431dd48..745a26a4dac4 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -19,7 +19,8 @@ import functools import inspect import types -from typing import Callable, Union +from typing import Callable, Dict, Union, Optional, List +import numpy as np # type: ignore import tvm.ir from . import _ffi_api @@ -115,6 +116,65 @@ def AttachGlobalSymbol() -> tvm.ir.transform.Pass: return _ffi_api.AttachGlobalSymbol() # type: ignore +def BindParams( + func_name: str, + params: Dict[str, Union[tvm.runtime.NDArray, np.ndarray]], +) -> tvm.ir.transform.Pass: + """Bind params of function of the module to constant tensors. + + Parameters + ---------- + + func_name: str + The function name to be bound + + params : Dict[str, Union[tvm.runtime.NDArray, np.ndarray]] + The map from param name to constant tensors. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + tvm_params = {} + for k, v in params.items(): + if isinstance(v, np.ndarray): + v = tvm.nd.array(v) + assert isinstance( + v, tvm.runtime.NDArray + ), f"param values are expected to be TVM.NDArray or numpy.ndarray, but got {type(v)}" + tvm_params[k] = v + + return _ffi_api.BindParams(func_name, tvm_params) # type: ignore + + +def RemoveUnusedFunctions(entry_functions: Optional[List[str]] = None) -> tvm.ir.transform.Pass: + """Remove unused relax/prim functions without external linkage in a IRModule. + + Parameters + ---------- + entry_functions: Optional[List[str]] + The set of entry functions to start from. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass to remove unused functions. + """ + if entry_functions is None: + entry_functions = ["main"] + return _ffi_api.RemoveUnusedFunctions(entry_functions) # type: ignore + + +def FoldConstant() -> tvm.ir.transform.Pass: + """Fold constant expressions. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.FoldConstant() # type: ignore + + def AnnotateTIROpPattern() -> tvm.ir.transform.Pass: """Annotate Op Pattern Kind for TIR functions diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc new file mode 100644 index 000000000000..1de8d94461cf --- /dev/null +++ b/src/relax/transform/bind_params.cc @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Bind params to function by using name + * \param func Relax function + * \param params params dict + * \return Function + */ +inline Function BindParamsByName(Function func, const Map& params) { + std::unordered_map name_dict; + std::unordered_set repeat_var; + for (auto arg : func->params) { + const auto& name = arg->name_hint(); + if (name_dict.count(name)) { + repeat_var.insert(name_dict[name]); + } else { + name_dict[name] = arg; + } + } + + std::unordered_map bind_dict; + for (auto& kv : params) { + if (name_dict.count(kv.first) == 0) { + continue; + } + auto arg = name_dict.at(kv.first); + if (repeat_var.count(arg)) { + LOG(FATAL) << "ValueError: Multiple args in the function have name " << kv.first; + } + bind_dict[arg] = Constant(kv.second); + } + Expr bound_expr = Bind(func, bind_dict); + Function ret = Downcast(bound_expr); + ICHECK(ret.defined()) << "The returning type is expected to be a Relax Function." + << "\n"; + return ret; +} + +/*! + * \brief Bind params to a specific function in a module + * \param m The module + * \param func_name The name of the specific function + * \param param The param dict + * \return The module after binding params. + */ +IRModule BindParam(IRModule m, String func_name, Map param) { + IRModuleNode* new_module = m.CopyOnWrite(); + Map functions = m->functions; + for (const auto& func_pr : functions) { + if (const auto* relax_f = func_pr.second.as()) { + if (relax_f->GetLinkageType() == LinkageType::kExternal) { + // Use global_symbol if it's external linkage + Optional gsymbol = relax_f->GetAttr(tvm::attr::kGlobalSymbol); + if (gsymbol.defined() && gsymbol.value() == func_name) { + Function f_after_bind = BindParamsByName(GetRef(relax_f), param); + new_module->Update(func_pr.first, f_after_bind); + } + } else { + // Use global var's name_hint if it's internal linkage + if (func_pr.first->name_hint == func_name) { + Function f_after_bind = BindParamsByName(GetRef(relax_f), param); + new_module->Update(func_pr.first, f_after_bind); + } + } + } + } + return GetRef(new_module); +} + +namespace transform { + +Pass BindParams(String func_name, Map params) { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return BindParam(std::move(mod), func_name, params); }; + return CreateModulePass(pass_func, 0, "BindParams", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.BindParams").set_body_typed(BindParams); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc new file mode 100644 index 000000000000..aa55ee7f7e3d --- /dev/null +++ b/src/relax/transform/fold_constant.cc @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +class ConstantFolder : public ExprMutator { + public: + explicit ConstantFolder(IRModule ctx_module) : ctx_module_(ctx_module) {} + + private: + /*! + * \brief Pattern match the shape inside the given struct info to a + * constant shape and get runtime shape tuple from it. + * \param struct_info The given struct info whose shape inside is to be casted. + * \return The runtime shape tuple, or nullopt if it is not a constant shape. + * \note Only TensorStructInfo is supported at this moment. Return NullOpt + * if the input struct info is not TensorStructInfo. + */ + static Optional MatchConstShape(const StructInfo& struct_info) { + // Only support single output for call_tir at this moment. + const auto* tensor_sinfo = struct_info.as(); + if (tensor_sinfo == nullptr) { + return NullOpt; + } + + const auto* shape = tensor_sinfo->shape.as(); + ICHECK(shape != nullptr) << "struct info given by call_tir should have ShapeExpr shape"; + + std::vector shape_values; + for (const auto v : shape->values) { + auto* ptr = v.as(); + if (!ptr) return NullOpt; + shape_values.push_back(ptr->value); + } + return runtime::ShapeTuple(shape_values.begin(), shape_values.end()); + } + + /*! + * \brief Pattern match op to constant array arguments. + * \return The constant array arguments, or nullopt if match fails. + */ + static Optional> MatchConstArrayArgs(const Array& args) { + Array res; + for (auto arg : args) { + auto* ptr = arg.as(); + if (!ptr) return NullOpt; + res.push_back(ptr->data); + } + return res; + } + + /*! + * \brief Pattern match op to a TIR function and look it up. + * \return The TIR function, or nullopt if pattern match fails. + */ + Optional MatchPrimFunc(const Expr& op) { + if (auto* ptr = op.as()) { + // NOTE: as check works for nullptr(returns null) + Optional base_func = ctx_module_->functions.Get(GetRef(ptr)); + if (auto* pfunc = base_func.as()) { + return GetRef(pfunc); + } + } + return NullOpt; + } + + /*! + * \brief Get a cached build version of func + * \return The cached func, nullopt if func cannot be built. + */ + Optional GetCachedBuild(tir::PrimFunc func) { + // TODO(tvm-team): consider another way of bulk extract and build PrimFunc once + // would be helpful for future cases where PrimFunc recursively call into each other + Target eval_cpu_target{"llvm"}; + + auto it = func_build_cache_.find(func); + if (it != func_build_cache_.end()) { + return it->second; + } + Optional build_func = NullOpt; + + try { + // Not all the primfunc can be directly built via llvm, for example, if a function is + // already scheduled to only work on GPU, we will need to skip this in the const folder for + // now + // TODO(Hongyi): further check and narrow the scope of foldable function + runtime::Module rt_module = + build(LowerPrimFunc(func, "tir_function"), eval_cpu_target, eval_cpu_target); + build_func = rt_module.GetFunction("tir_function"); + } catch (const tvm::Error& err) { + // build failure may happen in which case we skip + DLOG(WARNING) << "Build failure for function " << func << ", Error message: " << err.what(); + } + func_build_cache_[func] = build_func; + return build_func; + } + + // Try constant evaluate the function call + // if failed return NullOpt + Optional ConstEvaluateCallTIR(tir::PrimFunc tir_func, Array arr_args, + runtime::ShapeTuple shape, DataType ret_type) { + // obtain function from the cache. + Optional func = GetCachedBuild(tir_func); + if (!func) return NullOpt; + + // here the vector size has an additional + 1 because we need to put ret_tensor at the end + std::vector values(arr_args.size() + 1); + std::vector type_codes(arr_args.size() + 1); + + DLDevice cpu_dev = {DLDeviceType::kDLCPU, 0}; + runtime::NDArray ret_tensor = runtime::NDArray::Empty(shape, ret_type, cpu_dev); + + // avoid set rvalue ref which get de-allocated later, store args in a vector + // where temp_args[i] are lvalue ref that is stable + std::vector temp_args(arr_args.begin(), arr_args.end()); + + size_t arg_offset = 0; + for (; arg_offset < arr_args.size(); ++arg_offset) { + runtime::TVMArgsSetter(values.data(), type_codes.data())(arg_offset, temp_args[arg_offset]); + } + // set return value + runtime::TVMArgsSetter(values.data(), type_codes.data())(arg_offset++, ret_tensor); + + TVMRetValue ret; + // invoke + func.value().CallPacked(TVMArgs(values.data(), type_codes.data(), values.size()), &ret); + return Constant(ret_tensor); + } + + Expr VisitCallTIR(Call call) { + // call_tir needs to have at least three arguments + ICHECK_GE(call->args.size(), 2); + Optional func = MatchPrimFunc(call->args[0]); + ICHECK(call->args[1].as()) << "call_tir.args[1] must be Tuple"; + Optional> arr_args = + MatchConstArrayArgs(call->args[1].as()->fields); + ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir should have exactly one sinfo arg"; + Optional shape = MatchConstShape(call->sinfo_args[0]); + bool output_not_tuple = call->sinfo_args.size() == 1; + // Pattern 0: call constant function, const argument with const shape. + if (func && arr_args && shape && output_not_tuple) { + DynTensorType ret_type = Downcast(call->checked_type()); + // value_or will return value if it is not null, otherwise return or + return ConstEvaluateCallTIR(func.value(), arr_args.value(), shape.value(), ret_type->dtype) + .value_or(call); + } + // TODO(hongyi): support const-fold tuple outputs + return std::move(call); + } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call) final { + // post-order mutation + Call post_call = Downcast(VisitExprPostOrder_(call)); + static const Op& call_tir_op = Op::Get("relax.call_tir"); + + if (call->op.same_as(call_tir_op)) { + return VisitCallTIR(post_call); + } + return std::move(post_call); + } + + Expr VisitExpr_(const DataflowVarNode* op) final { + Optional opt = LookupBinding(GetRef(op)); + // `as` check checks if opt is not null and is instance of constant + if (opt.as()) { + return opt.value(); + } + return ExprMutator::VisitExpr_(op); + } + + Expr VisitExpr_(const VarNode* op) final { + Optional opt = LookupBinding(GetRef(op)); + // `as` check checks if opt is not null and is instance of constant + if (opt.as()) { + return opt.value(); + } + return ExprMutator::VisitExpr_(op); + } + + // the context module to lookup functions + IRModule ctx_module_; + // cache for function build, via structural equality + std::unordered_map, StructuralHash, StructuralEqual> + func_build_cache_; +}; + +namespace transform { + +Pass FoldConstant() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + ConstantFolder folder(m); + return Downcast(folder(f)); + }; + return CreateFunctionPass(pass_func, 0, "FoldConstant", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.FoldConstant").set_body_typed(FoldConstant); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_transform_bind_params.py b/tests/python/relax/test_transform_bind_params.py new file mode 100644 index 000000000000..b96fb89e6c0a --- /dev/null +++ b/tests/python/relax/test_transform_bind_params.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import tvm +import tvm.script +import tvm.testing +from tvm import relax +from tvm.script import relax as R +from tvm.script import tir as T + +use_np_array = tvm.testing.parameter(False, True) + + +def test_bind_params(use_np_array): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + A = T.match_buffer(x, (16, 16)) + B = T.match_buffer(y, (16, 16)) + C = T.match_buffer(z, (16, 16)) + for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4): + with T.block("matmul"): + vi = T.axis.S(16, i0 * 4 + i1) + vj = T.axis.S(16, j) + vk = T.axis.R(16, k0 * 4 + k1) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def main( + x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): + gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((16, 16), dtype="float32")) + return gv0 + + x_np = np.random.rand(16, 16).astype(np.float32) + w_np = np.random.rand(16, 16).astype(np.float32) + x_tvm = tvm.nd.array(x_np) + w_tvm = tvm.nd.array(w_np) + params_dict = {"w": w_np if use_np_array else w_tvm} + mod = relax.transform.BindParams("main", params_dict)(InputModule) + assert len(mod["main"].params) == 1 + + target = tvm.target.Target("llvm") + ex_after = relax.vm.build(mod, target) + vm_after = relax.VirtualMachine(ex_after, tvm.cpu()) + res_after = vm_after["main"](x_tvm) + + ex_before = relax.vm.build(InputModule, target) + vm_before = relax.VirtualMachine(ex_before, tvm.cpu()) + res_before = vm_before["main"](x_tvm, w_tvm) + + tvm.testing.assert_allclose(res_before.numpy(), res_after.numpy()) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py new file mode 100644 index 000000000000..32ee3e700080 --- /dev/null +++ b/tests/python/relax/test_transform_fold_constant.py @@ -0,0 +1,280 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm import relax +import numpy as np + +import tvm.script +from tvm.script import tir as T, relax as R + + +def gen_mod(mod, name, binding): + """Select relax function with name, rename to main and and bind constant. + + Parameters + ---------- + mod: IRModule + The input module + + name: str + The name of relax function to preserve and rename to main + + binding: Dict[str, array] + The const parameter bindings + """ + funcs = {} + binding = {k: tvm.nd.array(v) for k, v in binding.items()} + + for k, v in mod.functions.items(): + if isinstance(v, tvm.relax.Function): + if k.name_hint == name: + # rename to main + gv = tvm.ir.GlobalVar("main") + funcs[gv] = tvm.relax.Function(v.params, v.body, v.ret_struct_info).with_attr( + "global_symbol", "main" + ) + else: + funcs[k] = v + mod = tvm.IRModule(funcs) + return relax.transform.BindParams("main", binding)(mod) + + +def test_one_fold_addone(): + # put before after in a single module + @tvm.script.ir_module + class Module: + @T.prim_func + def addone(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"]) -> None: + for i, j in T.grid(16, 16): + with T.block("addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.float32(1) + + @R.function + def before(c0: R.Tensor((16, 16), "float32")): + lv0 = relax.call_tir(addone, (c0,), R.Tensor((16, 16), dtype="float32")) + return lv0 + + @R.function + def expected(c1: R.Tensor((16, 16), "float32")): + lv0 = c1 + return c1 + + c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) + c1_np = c0_np + 1 + before = gen_mod(Module, "before", {"c0": c0_np}) + expected = gen_mod(Module, "expected", {"c1": c1_np}) + + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_one_fold_transpose(): + # put before after in a single module + @tvm.script.ir_module + class Module: + @T.prim_func + def func(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(3, 2), "float32"]) -> None: + for i, j in T.grid(3, 2): + with T.block("transpose"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vj, vi] + + @R.function + def before(c0: R.Tensor((2, 3), "float32")): + lv0 = relax.call_tir(func, (c0,), R.Tensor((3, 2), dtype="float32")) + return lv0 + + @R.function + def expected(c1: R.Tensor((3, 2), "float32")): + lv0 = c1 + return c1 + + c0_np = np.arange(2 * 3).astype("float32").reshape(2, 3) + c1_np = c0_np.T + before = gen_mod(Module, "before", {"c0": c0_np}) + expected = gen_mod(Module, "expected", {"c1": c1_np}) + + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_two_hop_addone(): + @tvm.script.ir_module + class Module: + @T.prim_func + def addone(A: T.Buffer[(2, 2), "float32"], B: T.Buffer[(2, 2), "float32"]) -> None: + for i, j in T.grid(2, 2): + with T.block("addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.float32(1) + + @R.function + def before(c0: R.Tensor((2, 2), "float32")): + lv0 = relax.call_tir(addone, (c0,), R.Tensor((2, 2), dtype="float32")) + lv1 = relax.call_tir(addone, (lv0,), R.Tensor((2, 2), dtype="float32")) + return lv1 + + @R.function + def expected(c1: R.Tensor((2, 2), "float32"), c2: R.Tensor((2, 2), "float32")): + lv0 = c1 + lv1 = c2 + return c2 + + c0_np = np.arange((2 * 2)).astype("float32").reshape(2, 2) + c1_np = c0_np + 1 + c2_np = c1_np + 1 + before = gen_mod(Module, "before", {"c0": c0_np}) + expected = gen_mod(Module, "expected", {"c1": c1_np, "c2": c2_np}) + + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_dataflow_fold(): + @tvm.script.ir_module + class Module: + @T.prim_func + def identity(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"]) -> None: + for i, j in T.grid(16, 16): + with T.block("identity"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + + @R.function + def before(c0: R.Tensor((16, 16), "float32")): + with R.dataflow(): + gv0 = relax.call_tir(identity, (c0,), R.Tensor((16, 16), dtype="float32")) + R.output(gv0) + return gv0 + + @R.function + def expected(c1: R.Tensor((16, 16), "float32")): + with R.dataflow(): + gv0 = c1 + R.output(gv0) + return c1 + + c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) + c1_np = c0_np + before = gen_mod(Module, "before", {"c0": c0_np}) + expected = gen_mod(Module, "expected", {"c1": c1_np}) + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_fold_mixed_case(): + @tvm.script.ir_module + class Module: + # TIR function can handle different cases. + @T.prim_func + def addone(a: T.handle, b: T.handle) -> None: + n = T.var("int32") + m = T.var("int32") + A = T.match_buffer(a, (n, m)) + B = T.match_buffer(b, (n, m)) + for i, j in T.grid(n, m): + with T.block("addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.float32(1) + + @T.prim_func + def sub( + A: T.Buffer[(16, 16), "float32"], + B: T.Buffer[(16, 16), "float32"], + C: T.Buffer[(16, 16), "float32"], + ) -> None: + for i, j in T.grid(16, 16): + with T.block("sub"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] - B[vi, vj] + + @R.function + def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor("float32", ndim=2)): + n, m = T.var("int64"), T.var("int64") + x0 = R.match_cast(x, R.Tensor((n, m), "float32")) + # this line cannot be folded because n is unknown + lv0 = relax.call_tir(addone, (c0,), R.Tensor((n, 16), dtype="float32")) + # this line can be folded + lv1 = relax.call_tir(addone, (c0,), R.Tensor((16, 16), dtype="float32")) + # this line can be folded because all inputs are const + lv2 = relax.call_tir(sub, (c0, lv1), R.Tensor((16, 16), dtype="float32")) + # this line can not be folded because x's shape is unknown + lv3 = relax.call_tir(sub, (lv2, x), R.Tensor((16, 16), dtype="float32")) + return lv3 + + @R.function + def expected( + c0: R.Tensor((16, 16), "float32"), + c1: R.Tensor((16, 16), "float32"), + c2: R.Tensor((16, 16), "float32"), + x: R.Tensor("float32", ndim=2), + ) -> R.Tensor: + n, m = T.var("int64"), T.var("int64") + x0 = R.match_cast(x, R.Tensor((n, m), "float32")) + # this line cannot be folded because n is unknown + lv0 = relax.call_tir(addone, (c0,), R.Tensor((n, 16), dtype="float32")) + # this line can be folded + lv1 = c1 + # this line can be folded because all inputs are const + lv2 = c2 + # this line can not be folded because x's shape is unknown + lv3 = relax.call_tir(sub, (c2, x), R.Tensor((16, 16), dtype="float32")) + return lv3 + + c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) + c1_np = c0_np + 1 + c2_np = c0_np - c1_np + + before = gen_mod(Module, "before", {"c0": c0_np}) + expected = gen_mod(Module, "expected", {"c0": c0_np, "c1": c1_np, "c2": c2_np}) + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_int32_fold(): + @tvm.script.ir_module + class Module: + @T.prim_func + def addone(A: T.Buffer[(16, 16), "int32"], B: T.Buffer[(16, 16), "int32"]) -> None: + for i, j in T.grid(16, 16): + with T.block("addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.int32(1) + + @R.function + def before(c0: R.Tensor((16, 16), "int32")): + lv0 = relax.call_tir(addone, (c0,), R.Tensor((16, 16), dtype="int32")) + return lv0 + + @R.function + def expected(c1: R.Tensor((16, 16), "int32")): + lv0 = c1 + return c1 + + c0_np = np.arange((16 * 16)).astype("int32").reshape(16, 16) + c1_np = c0_np + 1 + before = gen_mod(Module, "before", {"c0": c0_np}) + expected = gen_mod(Module, "expected", {"c1": c1_np}) + + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +if __name__ == "__main__": + tvm.testing.main() From 2aed16966ffa978cc3b14d8e3eff9fe3a9ff28ba Mon Sep 17 00:00:00 2001 From: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Date: Fri, 17 Feb 2023 00:20:44 -0800 Subject: [PATCH 29/81] [Unity][Pass][TuningAPI] Introduce TuningAPI and MetaSchedule pass (#14014) Add TuningAPI and MetaSchedule tuning pass --- CMakeLists.txt | 1 + include/tvm/ir/transform.h | 54 +- include/tvm/relax/transform.h | 22 +- include/tvm/relax/tuning_api.h | 396 +++++++++ include/tvm/relay/transform.h | 2 +- include/tvm/tir/transform.h | 2 +- python/tvm/ir/transform.py | 95 ++- python/tvm/meta_schedule/__init__.py | 1 + python/tvm/meta_schedule/relax_integration.py | 352 ++++++++ python/tvm/meta_schedule/tir_integration.py | 89 ++ python/tvm/meta_schedule/tune_context.py | 3 +- python/tvm/relax/transform/transform.py | 69 +- .../relax/transform/tuning_api/__init__.py | 22 + .../relax/transform/tuning_api/_ffi_api.py | 19 + .../relax/transform/tuning_api/database.py | 273 ++++++ .../transform/tuning_api/default_functions.py | 306 +++++++ .../relax/transform/tuning_api/primitives.py | 419 ++++++++++ python/tvm/tir/transform/function_pass.py | 3 +- src/ir/transform.cc | 84 +- src/relax/backend/task_extraction.cc | 114 +++ src/relax/ir/transform.cc | 8 +- src/relax/transform/meta_schedule.cc | 171 ++++ src/relax/transform/tuning_api/database.cc | 350 ++++++++ src/relax/transform/tuning_api/primitives.cc | 273 ++++++ src/relay/ir/transform.cc | 4 +- src/relay/transforms/type_infer.cc | 2 +- src/tir/ir/transform.cc | 4 +- .../test_transform_meta_schedule_tuning.py | 115 +++ tests/python/relax/test_tuning_api.py | 781 ++++++++++++++++++ 29 files changed, 3987 insertions(+), 47 deletions(-) create mode 100644 include/tvm/relax/tuning_api.h create mode 100644 python/tvm/meta_schedule/relax_integration.py create mode 100644 python/tvm/relax/transform/tuning_api/__init__.py create mode 100644 python/tvm/relax/transform/tuning_api/_ffi_api.py create mode 100644 python/tvm/relax/transform/tuning_api/database.py create mode 100644 python/tvm/relax/transform/tuning_api/default_functions.py create mode 100644 python/tvm/relax/transform/tuning_api/primitives.py create mode 100644 src/relax/backend/task_extraction.cc create mode 100644 src/relax/transform/meta_schedule.cc create mode 100644 src/relax/transform/tuning_api/database.cc create mode 100644 src/relax/transform/tuning_api/primitives.cc create mode 100644 tests/python/relax/test_transform_meta_schedule_tuning.py create mode 100644 tests/python/relax/test_tuning_api.py diff --git a/CMakeLists.txt b/CMakeLists.txt index d0470677e128..18be118832ef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -294,6 +294,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/relax/analysis/*.cc src/relax/transform/*.cc src/relax/backend/vm/*.cc + src/relax/backend/task_extraction.cc src/relax/utils.cc ) diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 473e6291685d..ff54a6b5eacd 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -32,18 +32,18 @@ * - Reducing the effort required to implement new passes for compiler * developers, etc. * - * Similar to LLVM's pass manager, we designed the Relay pass manager to work + * Similar to LLVM's pass manager, we designed the Relay/Relax pass manager to work * different granularity, i.e. module level, function level, and even sequential * passe that contains a host of passes. * * However, we also extend the functionality of the traditional pass manager * with the consideration of requirements/convention from deep learning - * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass + * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay/Relax pass * manager performs the IRModule -> IRModule transformation. All * different types of passes, including the sequential-level pass object, are * essentially pass objects. This design, therefore, effectively provides users * a consistent and convenient interface, i.e. Pass, to play with. It offers a - * means to ease the development and testing of Relay passes. For example, with + * means to ease the development and testing of Relay/Relax passes. For example, with * the pass manager, external users will be able to have custom passes correctly * scheduled without having to modify a single handcrafted pass order. * @@ -90,7 +90,16 @@ class PassContextNode : public Object { /*! \brief A list of pass instrument implementations. */ Array instruments; - + // TODO(@sunggg): Fix dependency issue in the header file and correct the types + // e.g., relax::trace, relax::database in tvm/relax/tuning_api.h + /*! \brief Trace stack for relax pass infra. */ + mutable Array trace_stack; + /*! \brief List of passes to be traced. If not defined, make every pass traceable. */ + Optional> make_traceable; + /*! \brief Number of evaluations conducted in the pass pipeline. */ + mutable int num_evals{0}; + /*! \brief Database for tuning API. */ + Optional tuning_api_database; PassContextNode() = default; /*! @@ -130,7 +139,27 @@ class PassContextNode : public Object { v->Visit("instruments", &instruments); v->Visit("config", &config); v->Visit("diag_ctx", &diag_ctx); + v->Visit("trace_stack", &trace_stack); + v->Visit("make_traceable", &make_traceable); + v->Visit("num_evals", &num_evals); + v->Visit("tuning_api_daatabase", &tuning_api_database); + } + + Array GetTraceStack() { return trace_stack; } + void PushTrace(ObjectRef new_trace) { trace_stack.push_back(new_trace); } + void PopTrace() { + ICHECK(GetTraceStackSize()) << "Trace stack is currently empty. Please double check."; + trace_stack.pop_back(); } + int GetTraceStackSize() { return trace_stack.size(); } + ObjectRef GetCurrentTrace() { + ICHECK(GetTraceStackSize()) << "Trace stack is currently empty. Please double check."; + return trace_stack.back(); + } + void SetNumEvals(int _num_evals) { num_evals = _num_evals; } + void IncNumEvals(int _num_evals) { num_evals += _num_evals; } + + Optional GetTuningAPIDatabase() { return tuning_api_database; } static constexpr const char* _type_key = "transform.PassContext"; static constexpr bool _type_has_method_sequal_reduce = false; @@ -287,6 +316,9 @@ class PassInfoNode : public Object { /*! \brief The name of an optimization/analysis pass. */ String name; + /*! \brief Boolean that tells whether this pass will be traced or not. */ + bool traceable; + /*! \brief The passes that are required to perform the current pass. */ Array required; @@ -296,6 +328,7 @@ class PassInfoNode : public Object { v->Visit("opt_level", &opt_level); v->Visit("name", &name); v->Visit("required", &required); + v->Visit("traceable", &traceable); } static constexpr const char* _type_key = "transform.PassInfo"; @@ -314,8 +347,9 @@ class PassInfo : public ObjectRef { * \param opt_level The optimization level * \param name Name of the pass. * \param required The passes that are required to perform the current pass. + * \param traceable Boolean that tells whether the pass is traceable. */ - TVM_DLL PassInfo(int opt_level, String name, Array required); + TVM_DLL PassInfo(int opt_level, String name, Array required, bool traceable); TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode); }; @@ -323,7 +357,7 @@ class PassInfo : public ObjectRef { /*! * \brief PassNode is the base type of differnt types of optimization passes. * It is designed as a pure class and implemented by different pass subclasses - * at different granularity of Relay nodes. + * at different granularity of Relay/Relax nodes. */ class PassNode : public Object { public: @@ -396,7 +430,7 @@ class Pass : public ObjectRef { }; /*! - * \brief The SequentialNode contains a set of passes that transform Relay + * \brief The SequentialNode contains a set of passes that transform Relay/Relax * programs from one AST to another semantically equivalent one. * * One example of this level of pass is that the pass manager needs to correctly @@ -489,9 +523,9 @@ class Sequential : public Pass { * * \return The created module pass. */ -TVM_DLL Pass -CreateModulePass(const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, Array required); +TVM_DLL Pass CreateModulePass( + const runtime::TypedPackedFunc& pass_func, int opt_level, + String name, Array required, bool traceable = false); /*! * \brief A special trace pass that prints the header and IR to LOG(INFO). diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index dab062588a82..e9f63ee9dbc9 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -44,12 +44,13 @@ using DataflowBlock = tvm::relax::DataflowBlock; * \param opt_level The optimization level of the function pass. * \param name The name of the function pass. * \param required The list of the passes that the function pass is dependent on. + * \param traceable Boolean variable whether the dataflowblock pass is traceable. * * \return The created function pass. */ TVM_DLL Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required); + int opt_level, String name, tvm::Array required, bool traceable = false); /*! * \brief Create a dataflowblock pass. @@ -58,12 +59,13 @@ TVM_DLL Pass CreateFunctionPass( * \param opt_level The optimization level of the dataflowblock pass. * \param name The name of the dataflowblock pass. * \param required The list of the passes that the dataflowblock pass is dependent on. + * \param traceable Boolean variable whether the dataflowblock pass is traceable. * * \return The created dataflowblock pass. */ TVM_DLL Pass CreateDataflowBlockPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required); + int opt_level, String name, tvm::Array required, bool traceable = false); /*! * \brief Transform all dataflow structure to non-dataflow version. @@ -93,6 +95,22 @@ TVM_DLL Pass CallTIRRewrite(); */ TVM_DLL Pass RewriteDataflowReshape(); +/*! + * \brief Bind params of function of the module to constant tensors. + * + * \param func_name The name of the function to bind parameters. + * \param params The parameters to bind. + * + * \return The Pass. + */ +TVM_DLL Pass BindParams(String func_name, Map params); + +/*! + * \brief Fold constant expressions. + * + * \return The Pass. + */ +TVM_DLL Pass FoldConstant(); /*! * \brief Attach global_symbol to Relax functions and TIR Primfuncs for codegen. * diff --git a/include/tvm/relax/tuning_api.h b/include/tvm/relax/tuning_api.h new file mode 100644 index 000000000000..b6224a6d6d9e --- /dev/null +++ b/include/tvm/relax/tuning_api.h @@ -0,0 +1,396 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/tuning_api.h + * \brief Relax Tuning Pass APIs. + */ +#ifndef TVM_RELAX_TUNING_API_H_ +#define TVM_RELAX_TUNING_API_H_ +#include +#include +#include + +#include +namespace tvm { +namespace relax { + +/*! \brief Helper function to unpack arguments in the array as parameters for the given packed + * function. */ +TVM_ALWAYS_INLINE TVMRetValue CallPackedWithArgsInArray(const runtime::PackedFunc f, + const Array& args) { + size_t num_args = args.size(); + std::vector values(num_args); + std::vector codes(num_args); + runtime::TVMArgsSetter setter(values.data(), codes.data()); + const ObjectRef* ptr = args.template as()->begin(); + for (size_t i = 0; i < num_args; ++i) { + setter(i, *(ptr + i)); + } + + TVMRetValue rv; + f.CallPacked(TVMArgs(values.data(), codes.data(), num_args), &rv); + return rv; +} + +/*! \brief Choice manages a set of keys for transformation and constraint functions. */ +class ChoiceNode : public runtime::Object { + public: + /*! \brief ffi key for transformation function. */ + String transform_func_key; + /*! \brief ffi key for constraint function. */ + String constr_func_key; + Array transform_func_args; + Array constr_func_args; + + /*! \brief The default destructor. */ + virtual ~ChoiceNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("transform_func_key", &transform_func_key); + v->Visit("transform_func_args", &transform_func_args); + v->Visit("constr_func_key", &constr_func_key); + v->Visit("constr_func_args", &constr_func_args); + } + + /*! \brief Getter for constr_func. */ + const runtime::PackedFunc GetConstrFunc() { + const auto* constr_func = tvm::runtime::Registry::Get(constr_func_key); + ICHECK(constr_func != nullptr) << "constr_func_key is not registered: " << constr_func_key; + return *constr_func; + } + + /*! \brief Getter for transform_func. */ + const runtime::PackedFunc GetTransformFunc() { + auto* transform_func = tvm::runtime::Registry::Get(transform_func_key); + ICHECK(transform_func != nullptr) + << "transform_func_key is not registered: " << transform_func_key; + return *transform_func; + } + + /*! \brief Perform constr_func. */ + bool CheckConstr(const IRModule& mod) { + Array args(constr_func_args); + args.insert(args.begin(), mod); + return CallPackedWithArgsInArray(GetConstrFunc(), args); + } + + /*! \brief Perform transform_func. */ + IRModule ApplyTransformFunc(IRModule mod) { + // Apply transformation when constraint is satisfied. + if (CheckConstr(mod)) { + Array args(transform_func_args); + args.insert(args.begin(), GetRef(mod.CopyOnWrite())); + return CallPackedWithArgsInArray(GetTransformFunc(), args); + } + return mod; + } + + /*! + * \brief Serialize Choice as a JSON-style object + * \return The JSON-style object + */ + ObjectRef AsJSON() const; + + static constexpr const char* _type_key = "relax.tuning_api.Choice"; + TVM_DECLARE_BASE_OBJECT_INFO(ChoiceNode, Object); +}; + +/*! \brief Managed reference to ChoiceNode */ +class Choice : public runtime::ObjectRef { + public: + TVM_DLL explicit Choice(String transform_func_key, Array transform_func_args, + String constr_func_key, Array constr_func_args); + /*! \brief Deserialize JSON-style object into Choice */ + TVM_DLL static Choice FromJSON(const ObjectRef& json_obj); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Choice, ObjectRef, ChoiceNode); +}; + +/*! \brief Knob manages a set of valid choices for an optimization. */ +class KnobNode : public runtime::Object { + public: + /*! \brief Name of the knob. */ + String name; + /*! \brief Decision space. */ + Map choices; + + /*! \brief The default destructor. */ + virtual ~KnobNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("choices", &choices); + } + + /*! \brief Check if a decision is valid. */ + bool IsValidDecision(String decision) { return choices.count(decision) > 0; } + + /*! \brief Apply decision if the constraint is satisfied. + Otherwise, return the original IRModule. + */ + IRModule Apply(IRModule mod, String decision) { + ICHECK(IsValidDecision(decision)) << "Invalid choice for this knob: " << decision; + return choices[decision]->ApplyTransformFunc(mod); + } + + /*! + * \brief Serialize Knob as a JSON-style object + * \return The JSON-style object + */ + ObjectRef AsJSON() const; + + static constexpr const char* _type_key = "relax.tuning_api.Knob"; + TVM_DECLARE_BASE_OBJECT_INFO(KnobNode, Object); +}; + +/*! \brief Managed reference to KnobNode */ +class Knob : public runtime::ObjectRef { + public: + TVM_DLL explicit Knob(String name, Map choices); + /*! \brief Deserialize JSON-style object into Knob */ + TVM_DLL static Knob FromJSON(const ObjectRef& json_obj); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Knob, ObjectRef, KnobNode); +}; + +/*! \brief Trace manages history of optimization decisions. */ +class TraceNode : public runtime::Object { + public: + /*! \brief Input IRModule. */ + IRModule in_mod; + /*! \brief Output IRModule. */ + mutable IRModule out_mod; + // TODO(sunggg): can we move knobs and decisions into private? + /*! \brief Knobs that are applied so far. */ + Array knobs; + /*! \brief Decisions made for the knobs. */ + Array decisions; + /*! \brief Performance of out_mod. */ + mutable double perf = -1; + /*! \brief Length of the decision history. */ + mutable int size = 0; + /*! \brief The default destructor. */ + virtual ~TraceNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("in_mod", &in_mod); + v->Visit("out_mod", &out_mod); + v->Visit("knobs", &knobs); + v->Visit("decisions", &decisions); + v->Visit("perf", &perf); + v->Visit("size", &size); + } + + /*! \brief Verify current decision history. */ + bool Verify() const { + if (knobs.size() != decisions.size()) return false; + int n = knobs.size(); + for (int i = 0; i < n; i++) { + if (!knobs[i]->IsValidDecision(decisions[i])) return false; + } + return true; + } + + /*! \brief Add a knob and its decision to the current trace. */ + IRModule Add(Knob knob, String decision) { + out_mod = knob->Apply(out_mod, decision); + knobs.push_back(knob); + decisions.push_back(decision); + // perf number should be initialized after new decision is applied. + perf = -1; + // increment history size. + size++; + return out_mod; + } + + /*! + * \brief Serialize Trace as a JSON-style object + * \param include_in_mod Boolean config to include input IRModule in the output. + * \return The JSON-style object + */ + ObjectRef AsJSON(bool include_in_mod = true) const; + + /*! \brief Set the performance. */ + void SetPerf(double _perf) { perf = _perf; } + /*! \brief Set output module. */ + void SetOutMod(IRModule mod_) { out_mod = mod_; } + + static constexpr const char* _type_key = "relax.tuning_api.Trace"; + TVM_DECLARE_BASE_OBJECT_INFO(TraceNode, Object); +}; + +/*! \brief Managed reference to TraceNode */ +class Trace : public runtime::ObjectRef { + public: + /*! \brief Default constructor. Creating an empty trace. */ + Trace(); + /*! + * \brief Constructor. Creating a trace from existing knobs and their decisions + * \param in_mod Input IRModule + * \param knobs The knobs used + * \param decisions The decisions made in sampling + */ + TVM_DLL explicit Trace(IRModule in_mod, Array knobs, Array decisions); + /*! \brief Deserialize JSON-style object into Trace */ + TVM_DLL static Trace FromJSON(const ObjectRef& json_obj); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Trace, ObjectRef, TraceNode); +}; + +/*! \brief The class of tuning records. */ +class TuningRecordNode : public runtime::Object { + public: + /*! \brief The trace tuned. */ + Trace trace; + /*! \brief The measurement record in seconds. */ + Optional> run_secs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("trace", &trace); + v->Visit("run_secs", &run_secs); + } + + static constexpr const char* _type_key = "relax.tuning_api.TuningRecord"; + TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object); + + /*! + * \brief Export the tuning record to a JSON string. + * \param include_irmod Boolean config to include IRModules in the output. + * \return JSON object + */ + ObjectRef AsJSON(bool include_irmod = false) const; +}; + +/*! + * \brief The managed reference of TuningRecordNode. + * \sa TuningRecordNode + */ +class TuningRecord : public runtime::ObjectRef { + public: + /*! + \brief Constructor of a tuning record. + \param trace The trace of the tuning record. + \param run_secs The running time of the tuning record. + */ + TVM_DLL explicit TuningRecord(Trace trace, Optional> run_secs); + /*! + * \brief Create a tuning record from a json object. + * \param json_obj The json object. + * \return The tuning record created. + */ + TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TuningRecord, runtime::ObjectRef, TuningRecordNode); +}; + +/*! \brief The equality check for Workload */ +struct WorkloadEqual { + bool operator()(const meta_schedule::Workload& a, const meta_schedule::Workload& b) const { + return a->shash == b->shash && tvm::StructuralEqual()(a->mod, b->mod); + } +}; + +/* \brief The abstract interface of database. */ +class DatabaseNode : public runtime::Object { + public: + /*! \brief Default destructor */ + virtual ~DatabaseNode() = default; + /*! + * \brief Check if the database has the given workload. + * \param mod The IRModule to be searched for. + * \return Whether the database has the given workload. + */ + virtual bool HasWorkload(const IRModule& mod) = 0; + /*! + * \brief Check if the database has a measurement record for the given workload and target pair. + * \param workload The workload to be searched for. + * \param target The target to be searched for. + * \return Whether the database has the measurement record for given workload and target pair. + */ + virtual bool HasMeasurementRecord(const meta_schedule::Workload& workload, + const Target& target) = 0; + /*! + * \brief Check if the database has a tuning record for the given workload and target pair. + * \param workload The workload to be searched for. + * \param target The target to be searched for. + * \return Whether the database has the tuning record for the given workload and target pair. + */ + virtual bool HasTuningRecord(const meta_schedule::Workload& workload, const Target& target) = 0; + /*! + * \brief Look up or add workload to the database if missing. + * \param mod The IRModule to be searched for or added. + * \return The workload corresponding to the given IRModule. + */ + virtual meta_schedule::Workload CommitWorkload(const IRModule& mod) = 0; + /*! + * \brief Add a measurement record for a given pair of target and workload to the database. + * \param workload Workload to be searched for. + * \param target Target to be searched for. + * \param record Measurement record to be added. + */ + virtual void CommitMeasurementRecord(const meta_schedule::Workload& workload, + const Target& target, const Array& record) = 0; + /*! + * \brief Add a tuning record for a given pair of target and workload to the database. + * \param workload Workload to be searched for. + * \param target Target to be searched for. + * \param record Tuning record to be added. + */ + virtual void CommitTuningRecord(const meta_schedule::Workload& workload, const Target& target, + const TuningRecord& record) = 0; + /*! + * \brief Get the top K tuning records of given workload and target from the database. + * \param workload The workload to be searched for. + * \param target Target to be searched for. + * \param top_k The number of top records to be returned. + * \return An array of top K tuning records for the given workload. + */ + virtual Array GetTopK(const meta_schedule::Workload& workload, const Target& target, + int top_k) = 0; + /*! + * \brief Get the measurement record of given workload and target from the database. + * \param workload The workload to be searched for. + * \param target Target to be searched for. + * \return Measurement. + */ + virtual Array GetMeasurementRecord(const meta_schedule::Workload& workload, + const Target target) = 0; + + static constexpr const char* _type_key = "relax.tuning_api.Database"; + TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object); +}; + +/*! + * \brief Managed reference to DatabaseNode. + * \sa DatabaseNode + */ +class Database : public runtime::ObjectRef { + public: + /*! + * \brief Create a default database that uses JSON file for tuning records. + * \param path_workload The path to the workload table. + * \param path_tuning_record The path to the tuning record table. + * \param path_measurement_record The path to the measurement_record table. + * \param allow_missing Whether to create new file when the given path is not found. + */ + TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record, + String path_measurement_record, bool allow_missing); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode); +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_TUNING_API_H_ diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 43a0f89d95c1..256f1a64dd87 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -60,7 +60,7 @@ using Sequential = tvm::transform::Sequential; */ TVM_DLL Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required); + int opt_level, String name, tvm::Array required, bool traceable = false); /*! \brief Remove let-bound expressions which do not effect the program result. * diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 85b381a52950..fee5db087589 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -56,7 +56,7 @@ using tvm::transform::Sequential; */ TVM_DLL Pass CreatePrimFuncPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required); + int opt_level, String name, tvm::Array required, bool traceable = false); /*! * \brief Inject prefetch instructions into stmt. diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 17995bfa7850..21f5d41d862a 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -45,8 +45,10 @@ class PassInfo(tvm.runtime.Object): The list of passes that are required by a certain pass. """ - def __init__(self, opt_level, name, required=None): - self.__init_handle_by_constructor__(_ffi_transform_api.PassInfo, opt_level, name, required) + def __init__(self, opt_level, name, required=None, traceable=False): + self.__init_handle_by_constructor__( + _ffi_transform_api.PassInfo, opt_level, name, required, traceable + ) @tvm._ffi.register_object("transform.PassContext") @@ -70,6 +72,20 @@ class PassContext(tvm.runtime.Object): config : Optional[Dict[str, Object]] Additional configurations for specific passes. + + trace: Optional[relax.tuning.Trace] + Initial trace for trace mode. + + trace_stack: Optional[List[relax.tuning_api.Trace]] + Initial trace stack for trace mode. + + make_traceable: Optional[List[str]] + List of passes to make traceable. + + num_evals: int + initial number of evaluations conducted in the pipeline. + + tuning_api_database: Optional[relax.tuning_api.JSONDatabase] """ def __init__( @@ -79,6 +95,11 @@ def __init__( disabled_pass=None, instruments=None, config=None, + trace=None, + trace_stack=None, + make_traceable=None, + num_evals=0, + tuning_api_database=None, ): required = list(required_pass) if required_pass else [] if not isinstance(required, (list, tuple)): @@ -92,9 +113,25 @@ def __init__( if not isinstance(instruments, (list, tuple)): raise TypeError("instruments is expected to be the type of " + "list/tuple/set.") + # Convert to Map + # TODO(sunggg): Replace this to Set equivalent if exists + make_traceable = {name: True for name in make_traceable} if make_traceable else None + + if not trace_stack: + trace_stack = [trace] if trace else [] + config = config if config else None self.__init_handle_by_constructor__( - _ffi_transform_api.PassContext, opt_level, required, disabled, instruments, config + _ffi_transform_api.PassContext, + opt_level, + required, + disabled, + instruments, + config, + trace_stack, + make_traceable, + num_evals, + tuning_api_database, ) def __enter__(self): @@ -131,6 +168,47 @@ def list_configs(): """ return _ffi_transform_api.ListConfigs() + def push_trace(self, trace): + """Push a trace into the stack.""" + return _ffi_transform_api.PushTrace(self, trace) + + def pop_trace(self, return_current=True): + """Pop a topmost trace from the stack. + Returns + ------- + Trace : Optional[relax.tuning.Trace] + """ + if return_current: + cur_trace = self.get_current_trace() + _ffi_transform_api.PopTrace(self) + return cur_trace + + return _ffi_transform_api.PopTrace(self) + + def get_trace_stack(self): + """Get the current trace stack.""" + return _ffi_transform_api.GetTraceStack(self) + + def get_trace_stack_size(self): + """Get the size of current stack.""" + return _ffi_transform_api.GetTraceStackSize(self) + + def get_current_trace(self): + """Get the trace on the top of the stack.""" + return _ffi_transform_api.GetCurrentTrace(self) + + def set_num_evals(self, num: int): + """Set the number of evaluations conducted in the pipeline.""" + return _ffi_transform_api.SetNumEvals(self, num) + + def inc_num_evals(self, num: int): + """Increment the number of evaluations conducted in the pipeline.""" + return _ffi_transform_api.IncNumEvals(self, num) + + def get_tuning_api_database(self): + """Get tuning api database.""" + return _ffi_transform_api.GetTuningAPIDatabase(self) + @tvm._ffi.register_object("transform.Pass") class Pass(tvm.runtime.Object): @@ -199,7 +277,7 @@ class Sequential(Pass): The list of passes that the sequential pass is dependent on. """ - def __init__(self, passes=None, opt_level=0, name="sequential", required=None): + def __init__(self, passes=None, opt_level=0, name="sequential", required=None, traceable=False): passes = passes if passes else [] if not isinstance(passes, (list, tuple)): raise TypeError("passes must be a list of Pass objects.") @@ -209,7 +287,7 @@ def __init__(self, passes=None, opt_level=0, name="sequential", required=None): raise TypeError("Required is expected to be the type of list/tuple.") self.__init_handle_by_constructor__( - _ffi_transform_api.Sequential, passes, opt_level, name, required + _ffi_transform_api.Sequential, passes, opt_level, name, required, traceable ) @@ -245,7 +323,7 @@ def __getattr__(self, name): return PyModulePass -def module_pass(pass_func=None, opt_level=None, name=None, required=None): +def module_pass(pass_func=None, opt_level=None, name=None, required=None, traceable=False): """Decorate a module pass. This function returns a callback when pass_func is provided. @@ -270,6 +348,9 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None): required : Optional[List[str]] The list of passes that the module pass is dependent on. + traceable: Boolean + Boolean variable whether the module pass is traceable + Returns ------- create_module_pass : Union[Callable, ModulePass] @@ -337,7 +418,7 @@ def transform(mod, ctx): def create_module_pass(pass_arg): """Internal function that creates a module pass""" fname = name if name else pass_arg.__name__ - info = PassInfo(opt_level, fname, required) + info = PassInfo(opt_level, fname, required, traceable) if inspect.isclass(pass_arg): return _wrap_class_module_pass(pass_arg, info) if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 30a4fc6d9467..21a11ff9e84d 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -25,6 +25,7 @@ mutator, postproc, relay_integration, + relax_integration, runner, schedule, schedule_rule, diff --git a/python/tvm/meta_schedule/relax_integration.py b/python/tvm/meta_schedule/relax_integration.py new file mode 100644 index 000000000000..a82d8996858b --- /dev/null +++ b/python/tvm/meta_schedule/relax_integration.py @@ -0,0 +1,352 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta schedule integration with high-level IR""" +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +# isort: off +from typing_extensions import Literal + +# isort: on + +from tvm._ffi import get_global_func, register_func +from tvm.ir import IRModule +from tvm.ir.transform import PassContext +from tvm.runtime import NDArray +from tvm.target import Target +from tvm.tir.expr import IntImm + +from .builder import Builder +from .cost_model import CostModel +from .database import Database +from .extracted_task import ExtractedTask +from .logging import get_loggers_from_work_dir +from .measure_callback import MeasureCallback +from .runner import Runner +from .search_strategy import SearchStrategy +from .space_generator import SpaceGenerator +from .task_scheduler import TaskScheduler +from .tune import tune_tasks +from .tune_context import TuneContext +from .utils import fork_seed + +if TYPE_CHECKING: + from tvm import relax + +_extract_task_func = get_global_func( # pylint: disable=invalid-name + "relax.backend.MetaScheduleExtractTask", + allow_missing=False, +) + + +def extract_tasks( + mod: Union[IRModule, "relax.Function"], + target: Target, + params: Optional[Dict[str, NDArray]] = None, +) -> List[ExtractedTask]: + """Extract tuning tasks from a relax program. + + Parameters + ---------- + mod : Union[IRModule, relax.Function] + The module or function to tune + target : tvm.target.Target + The compilation target + + Returns + ------- + tasks: List[ExtractedTask] + The tasks extracted from this module + """ + # pylint: disable=import-outside-toplevel + from tvm.relax.expr import Function as RelaxFunc + from tvm.relax.transform import BindParams + + # pylint: enable=import-outside-toplevel + if isinstance(mod, RelaxFunc): + mod = IRModule({"main": mod}) + if not isinstance(target, Target): + target = Target(target) + if params: + mod = BindParams("main", params)(mod) + return list(_extract_task_func(mod, target)) + + +def extracted_tasks_to_tune_contexts( + extracted_tasks: List[ExtractedTask], + work_dir: str, + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + num_threads: Union[Literal["physical", "logical"], int] = "physical", + seed: Optional[int] = None, +) -> Tuple[List[TuneContext], List[float]]: + """Convert ExtractedTask to TuneContext. + + Parameters + ---------- + tasks : List[ExtractedTask] + The tasks to be converted + work_dir : str + The working directory to store logs and databases + space : SpaceGenerator.SpaceGeneratorType + The space generator to use. + strategy : SearchStrategy.SearchStrategyType + The search strategy to use. + num_threads : Union[Literal["physical", "logical"], int] + The number of threads to use in multi-threaded search algorithm. + seed : Optional[int] + The random seed to use. + + Returns + ------- + tasks : List[TuneContext] + The converted tasks + task_weights : List[float] + The weights of the tasks + """ + tasks: List[TuneContext] = [] + task_weights: List[float] = [] + for task, logger, rand_state in zip( + extracted_tasks, + get_loggers_from_work_dir(work_dir, [t.task_name for t in extracted_tasks]), + fork_seed(seed, n=len(extracted_tasks)), + ): + tasks.append( + TuneContext( + mod=task.dispatched[0], + target=task.target, + space_generator=space, + search_strategy=strategy, + task_name=task.task_name, + logger=logger, + rand_state=rand_state, + num_threads=num_threads, + ).clone() + ) + task_weights.append(task.weight) + return tasks, task_weights + + +def tune_relax( + mod: Union[IRModule, "relax.Function"], + params: Dict[str, NDArray], + target: Union[str, Target], + work_dir: str, + max_trials_global: int, + *, + max_trials_per_task: Optional[int] = None, + num_trials_per_iter: int = 64, + builder: Builder.BuilderType = "local", + runner: Runner.RunnerType = "local", + database: Database.DatabaseType = "json", + cost_model: CostModel.CostModelType = "xgb", + measure_callbacks: MeasureCallback.CallbackListType = "default", + task_scheduler: TaskScheduler.TaskSchedulerType = "gradient", + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + seed: Optional[int] = None, +) -> Database: + """Tune a Relax program. + + Parameters + ---------- + mod : Union[IRModule, relax.Function] + The module or function to tune + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + target : Union[Target, str] + The compilation target + work_dir : str + The working directory to store the tuning records + max_trials_global : int + The maximum number of trials to run + max_trials_per_task : Optional[int] + The maximum number of trials to run for each task + num_trials_per_iter : int + The number of trials to run per iteration + builder : BuilderType + The builder to use + runner : RunnerType + The runner to use + database : DatabaseType + The database to use + cost_model : CostModelType + The cost model to use + measure_callbacks : CallbackListType + The measure callbacks to use + task_scheduler : TaskSchedulerType + The task scheduler to use + space : SpaceGeneratorType + The space generator to use + strategy : SearchStrategyType + The search strategy to use + seed : Optional[int] + The random seed + + Returns + ------- + database : Database + The database that contains the tuning records + """ + tasks, task_weights = extracted_tasks_to_tune_contexts( + extracted_tasks=extract_tasks(mod, target, params), + work_dir=work_dir, + space=space, + strategy=strategy, + seed=seed, + ) + return tune_tasks( + tasks=tasks, + task_weights=task_weights, + work_dir=work_dir, + max_trials_global=max_trials_global, + max_trials_per_task=max_trials_per_task, + num_trials_per_iter=num_trials_per_iter, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + task_scheduler=task_scheduler, + ) + + +@register_func("tvm.meta_schedule.tune_relax") +def _tune_relax( + mod: Union[IRModule, "relax.Function"], + params: Dict[str, NDArray], + target: Union[str, Target], + work_dir: str, + max_trials_global: int, + *, + max_trials_per_task: Optional[int] = None, + num_trials_per_iter: int = 64, + builder: Builder.BuilderType = "local", + runner: Runner.RunnerType = "local", + database: Database.DatabaseType = "json", + cost_model: CostModel.CostModelType = "xgb", + measure_callbacks: MeasureCallback.CallbackListType = "default", + task_scheduler: TaskScheduler.TaskSchedulerType = "gradient", + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + seed: Optional[int] = None, +) -> Database: + """Interface with tuning api to tune a Relax program. + + Parameters + ---------- + mod : Union[IRModule, relax.Function] + The module or function to tune + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + target : Union[Target, str] + The compilation target + work_dir : str + The working directory to store the tuning records + max_trials_global : int + The maximum number of trials to run + max_trials_per_task : Optional[int] + The maximum number of trials to run for each task + num_trials_per_iter : int + The number of trials to run per iteration + builder : BuilderType + The builder to use + runner : RunnerType + The runner to use + database : DatabaseType + The database to use + cost_model : CostModelType + The cost model to use + measure_callbacks : CallbackListType + The measure callbacks to use + task_scheduler : TaskSchedulerType + The task scheduler to use + space : SpaceGeneratorType + The space generator to use + strategy : SearchStrategyType + The search strategy to use + seed : Optional[int] + The random seed + + Returns + ------- + ret_mod : IRModule + IRModule + """ + if isinstance(max_trials_global, IntImm): + max_trials_global = int(max_trials_global) + + tune_relax( + mod, + params, + target, + work_dir, + max_trials_global, + max_trials_per_task=max_trials_per_task, + num_trials_per_iter=num_trials_per_iter, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + task_scheduler=task_scheduler, + space=space, + strategy=strategy, + seed=seed, + ) + # Return original IRModule + # This pass only makes optimization decision + return mod + + +def compile_relax( + database: Database, + mod: IRModule, + target: Union[Target, str], + params: Optional[Dict[str, NDArray]], +) -> "relax.vm.Executable": + """Compile a relax program with a MetaSchedule database. + + Parameters + ---------- + database : Database + The database to use + mod : IRModule + The Relax program to be compiled + target : tvm.target.Target + The compilation target + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + + Returns + ------- + lib : relax.vm.Executable + The built runtime module or vm Executable for the given relax workload. + """ + # pylint: disable=import-outside-toplevel + from tvm.relax.transform import BindParams, MetaScheduleApplyDatabase + from tvm.relax.vm import build as relax_build + + # pylint: enable=import-outside-toplevel + if not isinstance(target, Target): + target = Target(target) + if params: + mod = BindParams("main", params)(mod) + + with target, database, PassContext(opt_level=3): + relax_mod = MetaScheduleApplyDatabase()(mod) + relax_ex = relax_build(relax_mod, target=target) + return relax_ex diff --git a/python/tvm/meta_schedule/tir_integration.py b/python/tvm/meta_schedule/tir_integration.py index f3d505c28b0e..d5f5ee86e0b8 100644 --- a/python/tvm/meta_schedule/tir_integration.py +++ b/python/tvm/meta_schedule/tir_integration.py @@ -22,7 +22,9 @@ # isort: on from tvm import ir, tir +from tvm._ffi import register_func from tvm.target import Target +from tvm.tir.expr import IntImm from .builder import Builder from .cost_model import CostModel @@ -128,6 +130,93 @@ def tune_tir( ) +@register_func("tvm.meta_schedule.tune_tir") +def _tune_tir( + mod: Union[ir.IRModule, tir.PrimFunc], + target: Union[str, Target], + work_dir: str, + max_trials_global: int, + *, + num_trials_per_iter: int = 64, + builder: Builder.BuilderType = "local", + runner: Runner.RunnerType = "local", + database: Database.DatabaseType = "json", + cost_model: CostModel.CostModelType = "xgb", + measure_callbacks: MeasureCallback.CallbackListType = "default", + task_scheduler: TaskScheduler.TaskSchedulerType = "round-robin", + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + task_name: str = "main", + num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical", + seed: Optional[int] = None, +) -> Database: + """Interface with tuning api to tune a TIR program. + + Parameters + ---------- + mod : Union[ir.IRModule, tir.PrimFunc] + The TIR function to tune. + target : Union[str, Target] + The target to tune for. + work_dir : str + The working directory. + max_trials_global : int + The maximum number of trials to run globally. + num_trials_per_iter : int + The number of trials to run per iteration + builder : Builder.BuilderType + The builder. + runner : Runner.RunnerType + The runner. + database : Database.DatabaseType + The database. + cost_model : CostModel.CostModelType + The cost model. + measure_callbacks : MeasureCallback.CallbackListType + The measure callbacks. + task_scheduler : TaskScheduler.TaskSchedulerType + The task scheduler. + space : SpaceGenerator.SpaceGeneratorType + The space generator. + strategy : SearchStrategy.SearchStrategyType + The search strategy. + task_name : str + The name of the task. + num_tuning_cores : Union[Literal["physical", "logical"], int] + The number of CPU cores to use during tuning. + seed : Optional[int] + The seed for the random number generator. + + Returns + ------- + ret_mod : IRModule + IRModule + """ + if isinstance(max_trials_global, IntImm): + max_trials_global = int(max_trials_global) + tune_tir( + mod, + target, + work_dir, + max_trials_global, + num_trials_per_iter=num_trials_per_iter, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + task_scheduler=task_scheduler, + space=space, + strategy=strategy, + task_name=task_name, + num_tuning_cores=num_tuning_cores, + seed=seed, + ) + # Return original IRModule + # This pass only makes optimization decision + return mod + + def compile_tir( database: Database, mod: Union[ir.IRModule, tir.PrimFunc], diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 38a46ebe757e..8c4f4ce864ab 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -24,7 +24,7 @@ # isort: on from tvm import IRModule -from tvm._ffi import register_object +from tvm._ffi import register_object, register_func from tvm.runtime import Object from tvm.target import Target from tvm.tir import PrimFunc, Schedule @@ -41,6 +41,7 @@ from .space_generator import SpaceGenerator +@register_func("tvm.meta_schedule.normalize_mod") def _normalize_mod(mod: Union[PrimFunc, IRModule]) -> IRModule: """Normalize the input to an IRModule""" if isinstance(mod, PrimFunc): diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 745a26a4dac4..c0ac180ff165 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -21,8 +21,8 @@ import types from typing import Callable, Dict, Union, Optional, List import numpy as np # type: ignore - import tvm.ir +from tvm.runtime import NDArray from . import _ffi_api @@ -218,6 +218,60 @@ def FuseTIR() -> tvm.ir.transform.Pass: return _ffi_api.FuseTIR() # type: ignore +def MetaScheduleApplyDatabase( + work_dir: Optional[str] = None, +) -> tvm.ir.transform.Pass: + """Apply the best schedule from tuning database. + work_dir : Optional[str] + work directory to deduce default database if database is not provided + (it will be ignored when an user passes database) + Returns + ------- + ret : tvm.transform.Pass + The registered pass + """ + return _ffi_api.MetaScheduleApplyDatabase(work_dir) # type: ignore + + +def MetaScheduleTuneTIR( + work_dir: str, + max_trials_global: int, +) -> tvm.ir.transform.Pass: + """Tune TIR with MetaSchedule. + Parameters + ---------- + work_dir: str + work directory + max_trials_gloabl: int + maximum number of total trials allowed for tuning + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.MetaScheduleTuneTIR(work_dir, max_trials_global) # type: ignore + + +def MetaScheduleTuneIRMod( + params: Dict[str, NDArray], + work_dir: str, + max_trials_global: int, +) -> tvm.ir.transform.Pass: + """Tune Relax IRModule with MetaSchedule. + Parameters + ---------- + params: Dict[str, NDArray] + model params + work_dir: str + work directory + max_trials_gloabl: int + maximum number of total trials allowed for tuning + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.MetaScheduleTuneIRMod(params, work_dir, max_trials_global) # type: ignore + + def _wrap_class_function_pass(pass_cls, pass_info): """Wrap a python class as function pass.""" @@ -255,6 +309,7 @@ def function_pass( opt_level=None, name=None, required=None, + traceable=False, ) -> Union[Callable, FunctionPass]: """Decorate a function pass. @@ -277,6 +332,9 @@ def function_pass( required : Optional[List[str]] The list of passes that the function pass is dependent on. + traceable: Boolean + Boolean variable whether the function pass is traceable + Returns ------- create_function_pass : Union[Callable, FunctionPass] @@ -350,7 +408,7 @@ def transform(func, mod, ctx): def create_function_pass(pass_arg): """Internal function that creates a function pass""" fname = name if name else pass_arg.__name__ - info = tvm.transform.PassInfo(opt_level, fname, required) + info = tvm.transform.PassInfo(opt_level, fname, required, traceable) if inspect.isclass(pass_arg): return _wrap_class_function_pass(pass_arg, info) if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): @@ -395,7 +453,7 @@ def __getattr__(self, name): def dataflowblock_pass( - pass_func=None, opt_level=None, name=None, required=None + pass_func=None, opt_level=None, name=None, required=None, traceable=False ) -> Union[Callable, DataflowBlockPass]: """Decorate a dataflowblock pass. @@ -418,6 +476,9 @@ def dataflowblock_pass( required : Optional[List[str]] The list of passes that the dataflowblock pass is dependent on. + traceable: Boolean + Boolean variable whether the dataflowblock pass is traceable + Returns ------- create_dataflowblock_pass : Union[Callable, DataflowBlockPass] @@ -499,7 +560,7 @@ def transform(block, mod, ctx): def create_dataflowblock_pass(pass_arg): """Internal function that creates a dataflowblock pass""" fname = name if name else pass_arg.__name__ - info = tvm.transform.PassInfo(opt_level, fname, required) + info = tvm.transform.PassInfo(opt_level, fname, required, traceable) if inspect.isclass(pass_arg): return _wrap_class_dataflowblock_pass(pass_arg, info) if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): diff --git a/python/tvm/relax/transform/tuning_api/__init__.py b/python/tvm/relax/transform/tuning_api/__init__.py new file mode 100644 index 000000000000..6c39d5c5359e --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax Tunign Pass API""" + +from .primitives import * +from .default_functions import * +from .database import * diff --git a/python/tvm/relax/transform/tuning_api/_ffi_api.py b/python/tvm/relax/transform/tuning_api/_ffi_api.py new file mode 100644 index 000000000000..f31522d02595 --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for relax.tuning_api""" +import tvm._ffi + +tvm._ffi._init_api("relax.tuning_api", __name__) diff --git a/python/tvm/relax/transform/tuning_api/database.py b/python/tvm/relax/transform/tuning_api/database.py new file mode 100644 index 000000000000..9477e142bad4 --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/database.py @@ -0,0 +1,273 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax Tuning Pass API default functions""" +from typing import List, Optional +import logging + +from tvm.runtime import Object +from tvm.ir.module import IRModule +from tvm.meta_schedule.utils import _json_de_tvm +from tvm.meta_schedule.database import Workload +from tvm.tir.schedule.trace import JSON_TYPE +from tvm.target import Target +from tvm._ffi import register_object +from .primitives import Trace +from . import _ffi_api + +logger = logging.getLogger("TuningAPI") # pylint: disable=invalid-name + + +@register_object("relax.tuning_api.TuningRecord") +class TuningRecord(Object): + """The class of tuning records. + + Parameters + ---------- + trace : tvm.relax.transform.tuning_api.Trace + The trace of the tuning record. + run_secs : Optional[List[float]] + The run-time of the tuning record. + """ + + trace: Trace + run_secs: Optional[List[float]] + + def __init__( # type: ignore # pylint: disable=too-many-arguments + self, + trace: Trace, + run_secs: Optional[List[float]] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.TuningRecord, # type: ignore # pylint: disable=no-member + trace, + run_secs, + ) + + def as_json(self, include_irmod: bool = False) -> JSON_TYPE: + """Export the tuning record to a JSON string. + Parameters + ---------- + include_irmod: bool + Decides whether to serialize in_mod as well. + + Returns + ------- + json_str : str + The JSON string exported. + """ + return _json_de_tvm(_ffi_api.TuningRecordAsJSON(self, include_irmod)) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_json(json_obj: JSON_TYPE) -> "TuningRecord": + """Create a tuning record from a json object. + + Parameters + ---------- + json_obj : JSON_TYPE + The json object to parse. + + Returns + ------- + tuning_record : TuningRecord + The parsed tuning record. + """ + return _ffi_api.TuningRecordFromJSON(json_obj) # type: ignore # pylint: disable=no-member + + +@register_object("relax.tuning_api.Database") +class Database(Object): + """The abstract database interface.""" + + def has_workload(self, mod: IRModule) -> bool: + """Check if the database has the given workload. + Parameters + ---------- + mod : IRModule + The IRModule to be searched for. + + Returns + ------- + result : bool + Whether the given workload is committed. + """ + return _ffi_api.DatabaseHasWorkload(self, mod) # type: ignore # pylint: disable=no-member + + def has_measurement_record(self, workload: Workload, target: Target) -> bool: + """Check if the database has a measurement record for the given workload and target pair. + Parameters + ---------- + workload: Workload + The workload to be searched for. + target: Target + The target to be searched for. + + Returns + ------- + result : bool + Whether the given workload and target pair is committed for the measurement record. + """ + return _ffi_api.DatabaseHasMeasurementRecord(self, workload, target) # type: ignore # pylint: disable=no-member + + def has_tuning_record(self, workload: Workload, target: Target) -> bool: + """Check if the database has a tuning record for the given workload and target pair. + Parameters + ---------- + workload: Workload + The workload to be searched for. + target: Target + The target to be searched for. + + Returns + ------- + result : bool + Whether the given workload and target pair is committed for the tuning record. + """ + return _ffi_api.DatabaseHasTuningRecord(self, workload, target) # type: ignore # pylint: disable=no-member + + def commit_workload(self, mod: IRModule) -> Workload: + """Commit a workload to the database if missing. + + Parameters + ---------- + mod : IRModule + The IRModule to be searched for or added. + + Returns + ------- + workload : Workload + The workload corresponding to the given IRModule. + """ + return _ffi_api.DatabaseCommitWorkload(self, mod) # type: ignore # pylint: disable=no-member + + def commit_measurement_record( + self, workload: Workload, target: Target, run_secs: List[float] + ) -> None: + """Commit a measurement record to the database. + A pair of workload and target will be used as a key. + + Parameters + ---------- + workload: Workload + The workload to be searched for. + target: Target + The target to be searched for. + run_secs : Optional[List[float]] + The measurement record to add. + """ + _ffi_api.DatabaseCommitMeasurementRecord(self, workload, target, run_secs) # type: ignore # pylint: disable=no-member + + def commit_tuning_record( + self, workload: Workload, target: Target, record: TuningRecord + ) -> None: + """Commit a tuning record to the database. + A pair of workload and target will be used as a key. + + Parameters + ---------- + workload: Workload + The workload to be searched for. + target: Target + The target to be searched for. + record : TuningRecord + The tuning record to add. + """ + _ffi_api.DatabaseCommitTuningRecord(self, workload, target, record) # type: ignore # pylint: disable=no-member + + def get_measurement_record(self, workload: Workload, target: Target) -> Optional[List[float]]: + """Get the measurement record of given workload and target from the database. + + Parameters + ---------- + workload : Workload + The workload to be searched for. + target: Target + The target to be searched for. + + Returns + ------- + measurement_record : Optional[List[float]] + Measurement record if exists. + """ + return _ffi_api.DatabaseGetMeasurementRecord(self, workload, target) # type: ignore # pylint: disable=no-member + + def get_top_k(self, workload: Workload, target: Target, top_k: int) -> List[TuningRecord]: + """Get the top K tuning records of given workload and target from the database. + + Parameters + ---------- + workload : Workload + The workload to be searched for. + target: Target + The target to be searched for. + top_k : int + The number of top records to get. + + Returns + ------- + top_k_records : List[TuningRecord] + The top K records. + """ + return _ffi_api.DatabaseGetTopK(self, workload, target, top_k) # type: ignore # pylint: disable=no-member + + +@register_object("relax.tuning_api.JSONDatabase") +class JSONDatabase(Database): + """The class of JSON database. + + Parameters + ---------- + path_workload : str + The path to the workload table. + path_tuning_record : str + The path to the tuning record table. + Manages pairs of + path_measurement_record : str + The path to the path_measurement_record table. + Manages pairs of + """ + + path_workload: str + path_tuning_record: str + path_measurement_record: str + + def __init__( + self, + path_workload: str, + path_tuning_record: str, + path_measurement_record: str, + allow_missing: bool = True, + ) -> None: + """Constructor. + + Parameters + ---------- + path_workload : str + The path to the workload table. + path_tuning_record : str + The path to the tuning record table. + path_measurement_record : str + The path to the path_measurement_record table. + allow_missing : bool + Whether to create new file when the given path is not found. + """ + self.__init_handle_by_constructor__( + _ffi_api.DatabaseJSONDatabase, # type: ignore # pylint: disable=no-member + path_workload, + path_tuning_record, + path_measurement_record, + allow_missing, + ) diff --git a/python/tvm/relax/transform/tuning_api/default_functions.py b/python/tvm/relax/transform/tuning_api/default_functions.py new file mode 100644 index 000000000000..b72b2f30ee2b --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/default_functions.py @@ -0,0 +1,306 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax Tuning Pass API default functions""" +from typing import Dict, List, Optional +import sys +import itertools +import logging +import numpy as np # type: ignore + +import tvm +from tvm.ir.module import IRModule +from tvm.ir.transform import PassContext, Pass +from tvm import meta_schedule +from tvm.meta_schedule.arg_info import TensorInfo +from tvm.meta_schedule.builder import BuilderInput, LocalBuilder +from tvm.meta_schedule.utils import get_global_func_with_default_on_worker +from tvm.meta_schedule.runner import ( + EvaluatorConfig, + LocalRunner, + RunnerInput, +) +from tvm._ffi.registry import register_func +from .primitives import Knob, Trace + +logger = logging.getLogger("TuningAPI") # pylint: disable=invalid-name + +# Default transform func that returns original IRModule. +@tvm.register_func("relax.tuning_api.Choice.default_transform_func") +def default_transform_func(mod): + return mod + + +# Default constraint func that always returns true. +@tvm.register_func("relax.tuning_api.Choice.default_constr_func") +def default_constr_func(mod: IRModule) -> bool: # pylint: disable=unused-argument + return True + + +@register_func("relax.tuning_api.default_generate_candidate") +def default_generate_candidate( + knobs: List[Knob], trace: Trace, eval_passes: Optional[List[Pass]] = None +) -> List[Trace]: + """ + Default function to generate the search space for a given trace by using registered choices. + This function simply expands candidate space as long as the knob's constraint satisfies. + To reduce the search space, a developer may expand each choice with smart search method. + (e.g., genetic search, multi-armed bandit) + Note, each pass generates candidates without worrying about the interaction with other passes. + i.e., it only uses its incoming trace/IRModule and Choices for candidate generation. + This will help alleviating the complexity of joint-optimization significantly. + - consideration of interaction between optimizations has known to be extremely difficult. + + Parameters + ---------- + knobs : List[Knob] + List of Knobs to consider to generate candidate for input trace. + trace: Trace + Input trace. + eval_passes: Optional[List[Pass]] + List of passes to consider to evaluate each candidate. + This will enable joint-optimization. + + Return + ---------- + candidates: List[Trace] + List of candidate traces + """ + + candidates = [trace] + # Iterate over every decision + for knob in knobs: + num = len(candidates) + for _ in range(num): + cur_trace = candidates.pop(0) + for decision in knob.choices.keys(): + choice = knob.choices[decision] + # Generate new candidate when this condition satisfies. + if choice.check_constr(cur_trace.out_mod): + new_trace = cur_trace.deepcopy() + new_trace.add(knob, decision) + candidates.append(new_trace) + + # Expand candidates by using eval passes if provided. This will enable joint-optimization. + if eval_passes: + candidates = default_consider_eval_passes(candidates, eval_passes) + return candidates + + +@register_func("relax.tuning_api.default_consider_eval_passes") +def default_consider_eval_passes( + init_candidates: List[Trace], eval_passes: Optional[List[Pass]] = None +) -> List[Trace]: + """ + Default function to update traces with eval passes. + It visits each eval_pass in dfs order in transform.Sequential() and + returns the best possible candidate trace for each candidate. + + Parameters + ---------- + init_candidates: List[Trace] + Initial candidates + eval_passes: Optional[List[Pass]] + List of passes to consider to evaluate each candidate. + This will enable joint-optimization. + Return + ---------- + candidates: List[Trace] + List of candidate traces + """ + if not eval_passes: + return init_candidates + + eval_passes = list(eval_passes) if not isinstance(eval_passes, list) else eval_passes + ctx = PassContext.current() + candidates = [] + + for trace in init_candidates: + ctx.push_trace(trace) + tvm.transform.Sequential(eval_passes)(trace.out_mod) + new_trace = ctx.pop_trace() + # A new trace contains the best decisions in eval_passes + candidates.append(new_trace) + + return candidates + + +@register_func("relax.tuning_api.default_evaluate") +def default_evaluate( + candidates: List[Trace], + target_str: str, + params: Optional[Dict[str, np.ndarray]] = None, + builder: Optional[meta_schedule.builder.Builder] = None, + runner: Optional[meta_schedule.runner.Runner] = None, +) -> None: + """ + Default function to evaluate a set of candidate traces by using MetaSchedule builder/runner. + + Parameters + ---------- + candidates: List[Trace] + List of traces to evaluate. + target_str: str, + Compilation target (e.g., llvm, cuda). + params: Optional[Dict[str, np.ndarray]] + Params to bind. + builder: Optional[meta_schedule.builder.Builder] + builder function. If not provided, default local builder will be used. + runner: Optional[meta_schedule.runner.Runner] + runner function. If not provided, default local runner will be used. + """ + + ctx = PassContext.current() + target = tvm.target.Target(target_str) + database = PassContext.current().get_tuning_api_database() + # Setup default local builder if not provided + if builder is None: + + def relax_build( + mod: IRModule, + target: tvm.target.Target, + params: Optional[Dict[str, np.ndarray]], + ): + if params: + mod = tvm.relax.transform.BindParams("main", params)(mod) + relax_exec = tvm.relax.vm.build(mod, target) + return relax_exec.mod + + builder = LocalBuilder(f_build=relax_build) + + # Setup default local runner if not provided + if runner is None: + + def relax_eval_func(rt_mod, device, evaluator_config, repeated_args): + relax_exec = tvm.relax.vm.Executable(rt_mod) + relax_vm = tvm.relax.VirtualMachine(exec=relax_exec, device=device) + + evaluator = relax_vm.module.time_evaluator( + func_name="main", + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + ) + repeated_costs: List[List[float]] = [] + for args in repeated_args: + profile_result = evaluator(*args) + repeated_costs.append(profile_result.results) + + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + + return costs + + runner = LocalRunner( + evaluator_config=EvaluatorConfig( + number=3, repeat=5, min_repeat_ms=100, enable_cpu_cache_flush=False + ), + f_run_evaluator=relax_eval_func, + ) + + # set up clean up function + f_clean_build = get_global_func_with_default_on_worker("meta_schedule.remove_build_dir", None) + assert f_clean_build + + # Keep track of number of evaluations (mostly for the debugging purpose) + num_evals = 0 + # Evaluation + for candidate in candidates: + # If this candidate is already evaluated, skip the measurement + if candidate.perf != -1: + continue + + # Evaluate candidates + num_evals += 1 + mod = candidate.out_mod + workload = database.commit_workload(mod) + + # If this workload and target pair has measured before, fetch its data. + if database.has_measurement_record(workload, target): + run_secs = database.get_measurement_record(workload, target) + # Otherwise, measure it. + else: + # Build candidate + (builder_result,) = builder.build([BuilderInput(mod, target, params)]) + + if builder_result.artifact_path is None: + # Build error + # Assign the worst performance and move on to the next candidate. + logger.warning(builder_result.error_msg) + run_secs = [1e100] + else: + # If build passes, set up runner input and measure the performance. + args_info = [ + TensorInfo( + shape=[int(i) for i in p.struct_info.shape], dtype=p.struct_info.dtype + ) + for p in mod["main"].params + ] # convert list[Var] to list[TensorInfo] + runner_input = RunnerInput( + builder_result.artifact_path, target_str, args_info=args_info + ) + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + + run_secs = runner_result.run_secs + # Runtime error + # Assign the worst performance and move on to the next candidate. + if runner_result.error_msg is not None: + logger.warning(runner_result.error_msg) + run_secs = [1e100] + + database.commit_measurement_record(workload, target, run_secs) + + # Clean up the artifact + f_clean_build(builder_result.artifact_path) + + # For valid measurments, compute the average and update the trace performance. + perfs = [] + for result in run_secs: + if isinstance(result, tvm.tir.FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + perfs.append(result) + + # Store the evaluation result + candidate.set_perf(np.mean(perfs)) + + ctx.inc_num_evals(num_evals) + + +def select_best_candidate(candidates: List[Trace]) -> Trace: + """ + Select the best trace. + + Parameters + ---------- + candidates: List[Trace] + Candidate traces + + Return + ---------- + best_trace: Trace + Trace with the best performance + """ + best_perf, best_trace = sys.maxsize, None + for candidate in candidates: + avg = candidate.perf + # Select best one + if best_perf > avg: + best_perf = avg + best_trace = candidate + return best_trace diff --git a/python/tvm/relax/transform/tuning_api/primitives.py b/python/tvm/relax/transform/tuning_api/primitives.py new file mode 100644 index 000000000000..67b81ba7e99c --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/primitives.py @@ -0,0 +1,419 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax Tuning Pass API primitives""" + +from typing import Callable, Union, Dict, List, Optional, Sequence +import logging +import tvm +from tvm.runtime import Object +from tvm.ir.module import IRModule +from tvm.relax import Expr +from tvm.tir.schedule.trace import JSON_TYPE, _json_from_tvm +from tvm._ffi import register_object +from . import _ffi_api + +logger = logging.getLogger("TuningAPI") # pylint: disable=invalid-name + + +@register_object("relax.tuning_api.Choice") +class Choice(Object): + """ + A TVM object Choice that maintains a set of transformation and constraint function keys. + Corresponding functions should be registered as PackedFunc with these keys. + Transformation function will be applied when constraint function returns true. + Parameters + ---------- + transform_func_key : Optional[str] + Key for transformation function. + transform_func_args : Optional[List] + Arguments for transformation function. + constr_func_key : Optional[str] + Key for constraint function. + constr_func_args : Optional[List] + Arguments for constraint function. + + Examples + -------- + The following code block defines a Choice. + + .. code-block:: python + @tvm.register_func("relax.tuning_api.test.transform_func") + def apply(mod): + return relax.tuning_api.FoldConstant()(mod) + @tvm.register_func("relax.tuning_api.test.constr_func") + def constr(mod): + return len(mod.functions) == 3 + # Define a choice to apply constant folding only when IRModule has three functions. + choice = Choice( + transform_func_key = "relax.tuning_api.test.transform_func", + constr_func_key = "relax.tuning_api.test.constr_func" + ) + """ + + def __init__( + self, + transform_func_key: Optional[str] = None, + transform_func_args: Optional[List] = None, + constr_func_key: Optional[str] = None, + constr_func_args: Optional[List] = None, + ): + """Constructor + Parameters + ---------- + transform_func_key : Optional[str] + Key for transformation function. + + f_tramsform_args: Optional[List] + Arguments for transformation function. + + constr_func_key : Optional[str] + Key for constraint function. + + constr_func_args: Optional[List] + Arguments for constraint function. + """ + + if transform_func_key is None: + transform_func_key = "relax.tuning_api.Choice.default_transform_func" + + if transform_func_args is None: + transform_func_args = [] + + if constr_func_key is None: + constr_func_key = "relax.tuning_api.Choice.default_constr_func" + + if constr_func_args is None: + constr_func_args = [] + + self.__init_handle_by_constructor__( + _ffi_api.Choice, # type: ignore + transform_func_key, + transform_func_args, + constr_func_key, + constr_func_args, # type: ignore # pylint: disable=no-member + ) + + def get_transform_func(self) -> Callable: + """Getter for transform_func + Returns + ------- + ret: Callable + registered transformation function + """ + return _ffi_api.ChoiceGetTransformFunc(self) # type: ignore + + def get_constr_func(self) -> Callable: + """Getter for constr_func + Returns + ------- + ret: Callable + registered constraint function + """ + return _ffi_api.ChoiceGetConstrFunc(self) # type: ignore + + def apply_transform_func(self, mod: IRModule) -> IRModule: + """Perform transform_func with its arguments + Returns + ------- + ret: IRModule + Transformed IRModule + """ + return _ffi_api.ChoiceApplyTransformFunc(self, mod) # type: ignore + + def check_constr(self, mod: IRModule) -> bool: + """Perform constr_func with its arguments + Returns + ------- + ret: bool + Returns whether the IRModule satisfies the constraint or not + """ + return _ffi_api.ChoiceCheckConstr(self, mod) # type: ignore + + def as_json(self) -> JSON_TYPE: + """Serialize the trace as a JSON-style object + Returns + ------- + json: JSON_TYPE + The JSON-style object + """ + return _ffi_api.ChoiceAsJSON(self) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_json(json_obj: JSON_TYPE) -> "Choice": + """Create Choice from JSON obj + + Parameters + ---------- + json_obj: JSON_TYPE + Choice serialized with JSON + + Return + ---------- + choice: Choice + Deserialized choice + """ + return _ffi_api.ChoiceFromJSON(json_obj) # type: ignore + + def deepcopy(self): + return Choice.from_json(self.as_json()) + + +@register_object("relax.tuning_api.Knob") +class Knob(Object): + """ + A TVM object Knob that maintains a set of valid Choices. + By using Knobs, a tuning pass can generate candidates and define the search space. + Parameters + ---------- + name : str + Name of the knob. + + choices: Union[List[Choice], Dict[str, Choice]] + A list of valid choices + + Examples + -------- + The following code block defines a Knob. + + .. code-block:: python + @tvm.register_func("relax.tuning_api.test.transform_func") + def apply(mod): + return relax.tuning_api.FoldConstant()(mod) + choices = {"apply": Choice("relax.tuning_api.test.transform_func"), "noapply": Choice()} + # A knob manages a set of its valid choices + knob = Knob("MockTuningKnob", choices) + """ + + def __init__(self, name: str, choices: Union[List[Choice], Dict[str, Choice]]): + """Constructor.""" + if isinstance(choices, list): + choices = {str(idx): val for idx, val in enumerate(choices)} + + self.__init_handle_by_constructor__( + _ffi_api.Knob, name, choices # type: ignore # pylint: disable=no-member + ) + + def verify(self, decision: Union[str, int]) -> bool: + """Verify if the decision is valid.""" + if isinstance(decision, int): + decision = str(decision) + return _ffi_api.KnobIsValidDecision(self, decision) # type: ignore + + def apply(self, mod: IRModule, decision: Union[str, int]) -> IRModule: + """Get choice if a decision is valid.""" + if isinstance(decision, int): + decision = str(decision) + return _ffi_api.KnobApply(self, mod, decision) # type: ignore + + def as_json(self) -> JSON_TYPE: + """Serialize the trace as a JSON-style object + Returns + ------- + json: JSON_TYPE + The JSON-style object + """ + return _ffi_api.KnobAsJSON(self) # type: ignore + + @staticmethod + def from_json(json_obj: JSON_TYPE) -> "Knob": + """Create Knob from JSON obj + + Parameters + ---------- + json_obj: JSON_TYPE + Knob serialized with JSON + + Return + ---------- + knob: Knob + Deserialized knob + """ + return _ffi_api.KnobFromJSON(json_obj) # type: ignore + + def __str__(self) -> str: + msg = f"{self.name} (# of choices: {len(self.choices)})\n" + for name, choice in self.choices.items(): + msg += f" - {name}: {choice}\n" + return msg + + def deepcopy(self): + return Knob.from_json(self.as_json()) + + +@register_object("relax.tuning_api.Trace") +class Trace(Object): + """ + A TVM object Trace logs the history of transformations (decisions). + Parameters + ---------- + in_mod : IRModule + Input IRModule. + knobs: Optional[List[Knob]] + A list of knobs applied in the trace. + decisions: Optional[Sequence[Union[str, int]]] + A list of decisions made for each knob + + Examples + -------- + The following code block defines a Trace. + + .. code-block:: python + + trace = Trace(mod, [knob1, knob2, knob3], ["c1", "c0", "c3"]) + assert trace.size == 3 # Length of history. + # 'out' contains IRModule that applies transformations in the trace. + out: IRModule = trace.add(knob4, "c2") + assert trace.size == 4 # Length of history. + trace.set_perf(0.03) # Set the performance number of the trace. + """ + + def __init__( + self, + in_mod: IRModule, + knobs: Optional[List[Knob]] = None, + decisions: Optional[Sequence[Union[str, int]]] = None, + ): + """Constructor.""" + knobs = knobs if knobs else list() + decisions = ( + [str(v) if isinstance(v, int) else v for v in decisions] if decisions else list() + ) + self.__init_handle_by_constructor__( + _ffi_api.Trace, in_mod, knobs, decisions # type: ignore # pylint: disable=no-member + ) + + def verify(self) -> bool: + """Verify if current history is valid.""" + return _ffi_api.TraceVerify() # type: ignore + + def add(self, knob: Knob, decision: Union[str, int]) -> IRModule: + """Add & Apply new decision (with knob).""" + if isinstance(decision, int): + decision = str(decision) + return _ffi_api.TraceAdd(self, knob, decision) # type: ignore + + def set_perf(self, perf: float) -> None: + """Set performance number for the trace.""" + return _ffi_api.TraceSetPerf(self, perf) # type: ignore + + def set_out_mod(self, mod: IRModule) -> None: + """Set out_mod for the trace.""" + return _ffi_api.TraceSetOutMod(self, mod) # type: ignore + + def as_json(self, include_irmod: bool = True) -> JSON_TYPE: + """Serialize the trace as a JSON-style object. + Parameters + ---------- + include_irmod: bool + Decides whether to serialize in_mod as well. + + Returns + ------- + json: JSON_TYPE + The JSON-style object. + """ + obj = _ffi_api.TraceAsJSON(self, include_irmod) # type: ignore + return _json_from_tvm(obj) + + @staticmethod + def from_json(json_obj: JSON_TYPE) -> "Trace": + """Create Trace from JSON obj. + + Parameters + ---------- + json_obj: JSON_TYPE + Trace serialized with JSON. + + Return + ---------- + trace: Trace + Deserialized trace. + """ + return _ffi_api.TraceFromJSON(json_obj) # type: ignore + + def __str__(self) -> str: + n = len(self.knobs) + msg = f"Trace length: {n}\n" + for idx in range(n): + msg += f"[{idx+1}] {self.knobs[idx].name}: {self.decisions[idx]}\n" + return msg + + def deepcopy(self) -> "Trace": + new_in_mod = deepcopy_irmodule(self.in_mod) + new_knobs = [knob.deepcopy() for knob in self.knobs] + new_decisions = [str(decision) for decision in self.decisions] + new_trace = Trace(new_in_mod, new_knobs, new_decisions) + new_out_mod = deepcopy_irmodule(self.out_mod) + new_trace.set_out_mod(new_out_mod) + return new_trace + + +def get_trace(in_: Union[Trace, IRModule, Expr]) -> Trace: + """ + Getter for a trace wrapper. + + Parameters + ---------- + in_: Union[Trace, IRModule, Expr] + Input entity + Return + ---------- + wrapped: Trace + Traced entity + """ + if isinstance(in_, Trace): + return in_ + if isinstance(in_, IRModule): + return Trace(in_) + if isinstance(in_, Expr): # type: ignore + return Trace(tvm.IRModule.from_expr(in_)) + + raise Exception(f"Invalid input type for trace: {type(in_)}") + + +@tvm.register_func("relax.tuning_api.deepcopy_irmodule") +def deepcopy_irmodule(mod: IRModule) -> IRModule: + """ + Deepcopy for an IRModule. + Parameters + ---------- + mod: IRModule + input IRModule + Return + ---------- + copied_mod: IRModule + deep-copied IRModule + """ + func_save_json = tvm.get_global_func("node.SaveJSON") + func_load_json = tvm.get_global_func("node.LoadJSON") + new_mod = None + # Handle external modules separately if exist + # TODO(tvm-team): + # Serialization of IRModule with external mods is tricky. + # (1) External mod is runtime module. + # (2) Currently, `export_library` does not support serialization of + # runtime module without the host module + # Therefore, we simply pass around the compiled external modules without copy for now. + # Revisit later when we have a better solution. + if mod.attrs and "external_mods" in mod.attrs: + tmp_mod = mod.without_attr("external_mods") + new_mod = func_load_json(func_save_json(tmp_mod)) + new_mod = new_mod.with_attr("external_mods", mod.attrs["external_mods"]) + else: + new_mod = func_load_json(func_save_json(mod)) + + return new_mod diff --git a/python/tvm/tir/transform/function_pass.py b/python/tvm/tir/transform/function_pass.py index 9450ade34e67..94d211a7fb4c 100644 --- a/python/tvm/tir/transform/function_pass.py +++ b/python/tvm/tir/transform/function_pass.py @@ -70,6 +70,7 @@ def prim_func_pass( opt_level: int = None, name: Optional[str] = None, required: Optional[List[str]] = None, + traceable=False, ) -> Union[Callable, PrimFuncPass]: """Decorate a function pass. @@ -148,7 +149,7 @@ def transform(func, mod, ctx): def create_function_pass(pass_arg): """Internal function that creates a function pass""" fname = name if name else pass_arg.__name__ - info = PassInfo(opt_level, fname, required) + info = PassInfo(opt_level, fname, required, traceable) if inspect.isclass(pass_arg): return _wrap_class_function_pass(pass_arg, info) if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 66b06e6b505d..619526d0b56b 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -341,11 +342,13 @@ class ModulePass : public Pass { TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode); }; -PassInfo::PassInfo(int opt_level, String name, tvm::Array required) { +PassInfo::PassInfo(int opt_level, String name, tvm::Array required, + bool traceable) { auto pass_info = make_object(); pass_info->opt_level = opt_level; pass_info->name = std::move(name); pass_info->required = std::move(required); + pass_info->traceable = std::move(traceable); data_ = std::move(pass_info); } @@ -401,7 +404,7 @@ Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { Sequential::Sequential(tvm::Array passes, String name) { auto n = make_object(); n->passes = std::move(passes); - PassInfo pass_info = PassInfo(0, std::move(name), {}); + PassInfo pass_info = PassInfo(0, std::move(name), {}, /* traceable */ false); n->pass_info = std::move(pass_info); data_ = std::move(n); } @@ -444,26 +447,61 @@ IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) c VLOG(0) << "skipping disabled pass '" << pass_info->name << "'"; continue; } + // resolve dependencies for (const auto& it : pass_info->required) { mod = GetPass(it)(std::move(mod), pass_ctx); } - mod = pass(std::move(mod), pass_ctx); + + // This handles passes that does not use Relax tuning API (untraceable passes). + // We make untraceable passes trackable when pass context has a trace (trace mode). + // When passes to trace (make_traceable) is provided from users, we only make them trackable. + if (pass_ctx->trace_stack.size() && !pass_info->traceable && + (!pass_ctx->make_traceable.defined() || + pass_ctx->make_traceable.value().count(pass_info->name))) { + // TODO(tvm-team): Currently, there are some inconsistency in the pass registration. + // 1. Some passes are not registered in ffi registry. + // 2. Some passes do not follow the name convention. (e.g., = + ) + + // Due to these problems, serialization with non-traceable passes is handled in a hacky way + // now. Find a systematic way to identify such inconsistencies and fix them. + + // In the future, we should pass the ffi key for a pass by deducing from its name. + String transform_func_key = "relax.tuning_api.Choice.default_transform_func"; + String constr_func_key = "relax.tuning_api.Choice.default_constr_func"; + + relax::Knob knob = relax::Knob( + pass_info->name, {{"Applied", relax::Choice(transform_func_key, Array(), + constr_func_key, Array())}}); + + // Add new decision to the trace at the top of the stack. + auto trace = Downcast(pass_ctx->trace_stack.back()); + trace->Add(knob, "Applied"); + // In the future, we should just have + // mod = trace->Add(knob, "enabled"); + // instead of the two lines below. + mod = pass(std::move(mod), pass_ctx); + trace->SetOutMod(mod); + + } else { + mod = pass(std::move(mod), pass_ctx); + } } return mod; } Pass CreateModulePass(const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required) { - PassInfo pass_info = PassInfo(opt_level, name, required); + int opt_level, String name, tvm::Array required, bool traceable) { + PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return ModulePass(pass_func, pass_info); } TVM_REGISTER_NODE_TYPE(PassInfoNode); TVM_REGISTER_GLOBAL("transform.PassInfo") - .set_body_typed([](int opt_level, String name, tvm::Array required) { - return PassInfo(opt_level, name, required); + .set_body_typed([](int opt_level, String name, tvm::Array required, bool traceable) { + return PassInfo(opt_level, name, required, traceable); }); TVM_REGISTER_GLOBAL("transform.Info").set_body([](TVMArgs args, TVMRetValue* ret) { @@ -514,7 +552,8 @@ TVM_REGISTER_GLOBAL("transform.Sequential").set_body([](TVMArgs args, TVMRetValu int opt_level = args[1]; std::string name = args[2]; tvm::Array required = args[3]; - PassInfo pass_info = PassInfo(opt_level, name, required); + bool traceable = args[4]; + PassInfo pass_info = PassInfo(opt_level, name, required, /* traceable */ traceable); *ret = Sequential(passes, pass_info); }); @@ -537,7 +576,9 @@ TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_GLOBAL("transform.PassContext") .set_body_typed([](int opt_level, Array required, Array disabled, Array instruments, - Optional> config) { + Optional> config, Array trace_stack, + Optional> make_traceable, int num_evals, + Optional tuning_api_database) { auto pctx = PassContext::Create(); pctx->opt_level = opt_level; @@ -547,6 +588,10 @@ TVM_REGISTER_GLOBAL("transform.PassContext") if (config.defined()) { pctx->config = config.value(); } + pctx->trace_stack = std::move(trace_stack); + pctx->make_traceable = std::move(make_traceable); + pctx->num_evals = std::move(num_evals); + pctx->tuning_api_database = std::move(tuning_api_database); PassConfigManager::Global()->Legalize(&(pctx->config)); return pctx; }); @@ -562,7 +607,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "\tdisabled passes: " << node->disabled_pass << "\n"; p->stream << "\tinstruments: " << node->instruments << "\n"; - p->stream << "\tconfig: " << node->config; + p->stream << "\tconfig: " << node->config << "\n"; + p->stream << "\ttrace stack: " << node->trace_stack; }); class PassContext::Internal { @@ -572,6 +618,22 @@ class PassContext::Internal { static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); } }; +TVM_REGISTER_GLOBAL("transform.GetTraceStack") + .set_body_method(&PassContextNode::GetTraceStack); +TVM_REGISTER_GLOBAL("transform.PushTrace") + .set_body_method(&PassContextNode::PushTrace); +TVM_REGISTER_GLOBAL("transform.PopTrace").set_body_method(&PassContextNode::PopTrace); +TVM_REGISTER_GLOBAL("transform.GetTraceStackSize") + .set_body_method(&PassContextNode::GetTraceStackSize); +TVM_REGISTER_GLOBAL("transform.GetCurrentTrace") + .set_body_method(&PassContextNode::GetCurrentTrace); +TVM_REGISTER_GLOBAL("transform.SetNumEvals") + .set_body_method(&PassContextNode::SetNumEvals); +TVM_REGISTER_GLOBAL("transform.IncNumEvals") + .set_body_method(&PassContextNode::IncNumEvals); +TVM_REGISTER_GLOBAL("transform.GetTuningAPIDatabase") + .set_body_method(&PassContextNode::GetTuningAPIDatabase); + TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current); TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::Internal::EnterScope); @@ -595,7 +657,7 @@ Pass PrintIR(String header, bool show_meta_data) { LOG(INFO) << "PrintIR(" << header << "):\n" << mod; return mod; }; - return CreateModulePass(pass_func, 0, "PrintIR", {}); + return CreateModulePass(pass_func, 0, "PrintIR", {}, /* traceable */ false); } TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR); diff --git a/src/relax/backend/task_extraction.cc b/src/relax/backend/task_extraction.cc new file mode 100644 index 000000000000..beb3950af1d1 --- /dev/null +++ b/src/relax/backend/task_extraction.cc @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { +namespace backend { + +using tvm::meta_schedule::ExtractedTask; + +/*! + * \brief Extract the Meta-Schedule tuning task from a given IRModule. + * \note + * 1. The task extractor is responsible for task deduplication. The + * deduplication is achieved by comparing structural hashes of PrimFuncs. + * 2. For a PrimFunc, the weight of its corresponding task is the number + * of times it called by op Call-TIR. Say in an IRModule there are three + * PrimFuncs `fn1`, `fn2` and `fn3` sharing the same structural hash. + * Suppose `fn1` is called by 5 Call-TIR ops among all Relax function, + * `fn2` is called by 3 Call-TIR and `fn3` is called by 5 Call-TIR. + * Then we will have a ExtractedTask for all three functions, whose weight + * is 5 + 3 + 2 = 10. + */ +class TaskExtractor : public ExprVisitor { + public: + static Array ExtractTask(IRModule mod, Target target) { + TaskExtractor extractor(mod, target); + // We go through each Relax function in the module. + for (const auto& kv : mod->functions) { + if (const auto* func = kv.second.as()) { + extractor(GetRef(func)); + } + } + return std::move(extractor.tasks_); + } + + private: + explicit TaskExtractor(IRModule mod, Target target) + : mod_(std::move(mod)), target_(std::move(target)) { + normalize_mod_func_ = runtime::Registry::Get("tvm.meta_schedule.normalize_mod"); + ICHECK(normalize_mod_func_) << "Normalization function is not found."; + } + + void VisitExpr_(const CallNode* call) final { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + + // TODO(@tvm-team): When we differentiate the call for tir function and packed function, + // this logic should be changed accordingly. + if (!call->op.same_as(call_tir_op)) { + // Since the Relax function is of A-normal form, the arguments of this call cannot be another + // Calls. And hence we do not need to recurse into this Call. + return; + } + + // Do not extract external function + if (call->args[0].as()) { + return; + } + + const GlobalVar& global_var = Downcast(call->args[0]); + const tir::PrimFunc& func = Downcast(mod_->Lookup(global_var)); + + auto it = func2task_.find(func); + if (it != func2task_.end()) { + it->second->weight += 1; + return; + } + + IRModule tir_mod = (*normalize_mod_func_)(func); + ExtractedTask task(/*task_name=*/global_var->name_hint, // + /*mod=*/tir_mod, // + /*target=*/target_, // + /*dispatched=*/{tir_mod}, // + /*weight=*/1); + tasks_.push_back(task); + func2task_.emplace(func, task); + } + + IRModule mod_; + Target target_; + Array tasks_; + std::unordered_map func2task_; + const runtime::PackedFunc* normalize_mod_func_; +}; + +TVM_REGISTER_GLOBAL("relax.backend.MetaScheduleExtractTask") + .set_body_typed([](IRModule mod, Target target) { + return TaskExtractor::ExtractTask(std::move(mod), std::move(target)); + }); + +} // namespace backend +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc index 1b077d8b887a..9f418bff5c6d 100644 --- a/src/relax/ir/transform.cc +++ b/src/relax/ir/transform.cc @@ -173,8 +173,8 @@ bool FunctionPassNode::SkipFunction(const Function& func) const { Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required) { - PassInfo pass_info = PassInfo(opt_level, name, required); + int opt_level, String name, tvm::Array required, bool traceable) { + PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return FunctionPass(pass_func, pass_info); } @@ -389,8 +389,8 @@ IRModule DataflowBlockPassNode::operator()(IRModule mod, const PassContext& pass Pass CreateDataflowBlockPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required) { - PassInfo pass_info = PassInfo(opt_level, name, required); + int opt_level, String name, tvm::Array required, bool traceable) { + PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return DataflowBlockPass(pass_func, pass_info); } diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc new file mode 100644 index 000000000000..d444ba16654f --- /dev/null +++ b/src/relax/transform/meta_schedule.cc @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/meta_schedule.cc + * \brief Pass for meta_schedule tuning + */ +#include +#include +#include +#include + +namespace tvm { +namespace relax { +namespace transform { + +class MetaScheduleTuner { + public: + explicit MetaScheduleTuner(Target target, String work_dir, Integer max_trials_global, + Map params = {}) + : target_(target), + work_dir_(work_dir), + max_trials_global_(max_trials_global), + params_(params) { + candgen_func_ = runtime::Registry::Get("relax.tuning_api.default_generate_candidate"); + ICHECK(candgen_func_) << "Default candidate generation function is not found."; + normalize_mod_func_ = runtime::Registry::Get("tvm.meta_schedule.normalize_mod"); + ICHECK(normalize_mod_func_) << "Normalization function is not found."; + } + + // TODO(@sunggg): Currently, only supports basic arguments. + IRModule TuneIRMod(IRModule mod, transform::PassContext ctx) { + Trace trace = Downcast(ctx->GetCurrentTrace()); + ctx->PopTrace(); + Choice choice("tvm.meta_schedule.tune_relax", {params_, target_, work_dir_, max_trials_global_}, + "relax.tuning_api.Choice.default_constr_func", {}); + Knob knob("meta_schedule.tune_irmod", {{"0", choice}}); + Array candidates = (*candgen_func_)(Array({knob}), trace); + ICHECK(candidates.size() == 1); + Trace best_trace = candidates[0]; + ctx->PushTrace(best_trace); + // since we separate tuning from application, return original IRModule + return mod; + } + + // TODO(@sunggg): Currently, only supports basic arguments. + tir::PrimFunc TuneTIR(tir::PrimFunc f, transform::PassContext ctx) { + // TODO(@sunggg): Whenever we tune tir, assume we start a new trace w/o pushing to the trace + // stack. Revisit later when we collect more usecases. + Trace trace = Trace((*normalize_mod_func_)(f), {}, {}); + + Choice choice("tvm.meta_schedule.tune_tir", {target_, work_dir_, max_trials_global_}, + "relax.tuning_api.Choice.default_constr_func", {}); + Knob knob("meta_schedule.tune_primfunc", {{"0", choice}}); + Array candidates = (*candgen_func_)(Array({knob}), trace); + ICHECK(candidates.size() == 1); + // since we separate tuning from application, return original IRModule + return f; + } + + private: + Target target_; + String work_dir_; + Integer max_trials_global_; + Map params_; + const runtime::PackedFunc* candgen_func_; + const runtime::PackedFunc* normalize_mod_func_; +}; + +Pass MetaScheduleApplyDatabase(Optional work_dir) { + using tvm::meta_schedule::Database; + Target target = Target::Current(false); + const runtime::PackedFunc* normalize_mod_func_ = + runtime::Registry::Get("tvm.meta_schedule.normalize_mod"); + ICHECK(normalize_mod_func_) << "Normalization function is not found."; + + runtime::TypedPackedFunc pass_func = [=](IRModule mod, + PassContext ctx) { + Database database{nullptr}; + if (Database::Current().defined()) { + database = Database::Current().value(); + } else { + ICHECK(work_dir.defined()); + String path_workload = work_dir.value() + "/database_workload.json"; + String path_tuning_record = work_dir.value() + "/database_tuning_record.json"; + LOG(WARNING) << "Creating JSONDatabase. Workload at: " << path_workload + << ", Tuning records at: " << path_tuning_record; + database = meta_schedule::Database::JSONDatabase(path_workload, path_tuning_record, true); + } + + Map result; + for (const auto& iter : mod->functions) { + GlobalVar gv = iter.first; + BaseFunc base_func = iter.second; + if (const auto* prim_func_node = base_func.as()) { + tir::PrimFunc prim_func = GetRef(prim_func_node); + + IRModule tir_mod = (*normalize_mod_func_)(prim_func); + if (Optional sch = database->QuerySchedule(tir_mod, target, gv->name_hint)) { + IRModule new_mod = sch.value()->mod(); + ICHECK_EQ(new_mod->functions.size(), 1); + BaseFunc new_base_func = (*new_mod->functions.begin()).second; + ICHECK(new_base_func->IsInstance()); + tir::PrimFunc new_prim_func = Downcast(new_base_func); + // copy the original attrs + new_prim_func = WithAttrs(std::move(new_prim_func), {prim_func->attrs->dict}); + result.Set(gv, new_prim_func); + continue; + } else { + LOG(WARNING) << "Tuning record is not found for primfunc: " << gv->name_hint; + } + } + result.Set(gv, base_func); + } + return IRModule(result, // functions + {}, // type_definitions + {}, // import_set + {}, // map + mod->attrs); // attrs); + }; + return CreateModulePass(pass_func, 0, "MetaScheduleApplyDatabase", {}); +} + +Pass MetaScheduleTuneIRMod(Map params, String work_dir, + Integer max_trials_global) { + Target target = Target::Current(false); + runtime::TypedPackedFunc pass_func = [=](IRModule m, + PassContext ctx) { + return MetaScheduleTuner(target, work_dir, max_trials_global, params).TuneIRMod(m, ctx); + }; + return CreateModulePass(/*pass function*/ pass_func, /*opt level*/ 0, + /*pass name*/ "MetaScheduleTuneIRModule", + /*required*/ {}, + /*traceable*/ true); +} + +Pass MetaScheduleTuneTIR(String work_dir, Integer max_trials_global) { + Target target = Target::Current(false); + runtime::TypedPackedFunc pass_func = + [=](tir::PrimFunc f, IRModule mod, PassContext ctx) { + return MetaScheduleTuner(target, work_dir, max_trials_global).TuneTIR(f, ctx); + }; + return tir::transform::CreatePrimFuncPass(/*pass function*/ pass_func, /*opt level*/ 0, + /*pass name*/ "MetaScheduleTuneTIR", + /*required*/ {}, + /*traceable*/ true); +} + +TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleApplyDatabase") + .set_body_typed(MetaScheduleApplyDatabase); +TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneIRMod").set_body_typed(MetaScheduleTuneIRMod); +TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneTIR").set_body_typed(MetaScheduleTuneTIR); +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/tuning_api/database.cc b/src/relax/transform/tuning_api/database.cc new file mode 100644 index 000000000000..0d239e5fbf81 --- /dev/null +++ b/src/relax/transform/tuning_api/database.cc @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/tuning_api/database.cc + * \brief Database of tuning APIs. + */ +#include + +#include +#include +#include + +#include "../../../meta_schedule/utils.h" + +namespace tvm { +namespace meta_schedule { + +void JSONFileAppendLine(const String& path, const std::string& line); +std::vector JSONFileReadLines(const String& path, int num_threads, bool allow_missing); + +} // namespace meta_schedule +} // namespace tvm + +namespace tvm { +namespace relax { + +TuningRecord::TuningRecord(Trace trace, Optional> run_secs) { + ObjectPtr n = make_object(); + n->trace = trace; + n->run_secs = run_secs; + this->data_ = n; +} + +ObjectRef TuningRecordNode::AsJSON(bool include_irmod) const { + return Array{trace->AsJSON(include_irmod), // + run_secs}; +} + +TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj) { + Trace trace{nullptr}; + Optional> run_secs{nullptr}; + try { + const ArrayNode* json_array = json_obj.as(); + CHECK(json_array && json_array->size() == 2); + // Load json[0] => trace + { + const ObjectRef& json_trace = json_array->at(0); + trace = Trace::FromJSON(json_trace); + } + + // Load json[1] => run_secs + if (json_array->at(1).defined()) { + run_secs = meta_schedule::AsFloatArray(json_array->at(1)); + } + } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + return TuningRecord(trace, run_secs); +} + +/*! \brief The struct defining comparison function of sorting by mean run seconds. */ +struct SortTuningRecordByMeanRunSecs { + static const constexpr double kMaxMeanTime = 1e10; + + static double Mean(const Array& a) { + if (a.empty()) { + return kMaxMeanTime; + } + double sum = 0.0; + for (const FloatImm& i : a) { + sum += i->value; + } + return sum / a.size(); + } + + bool operator()(const TuningRecord& a, const TuningRecord& b) const { + double a_time = Mean(a->run_secs.value_or({})); + double b_time = Mean(b->run_secs.value_or({})); + return a_time < b_time; + } +}; + +// TODO(tvm-team): Currently, we strictly treat each target separately. +// Since not every option in the target matters, this might be the overkill. +// Revisit this when we have better approach with target equality check. +inline std::string get_database_key(int workload_idx, Target target) { + return std::to_string(workload_idx) + "/" + target->str(); +} + +/*! \brief The default database implementation, which mimics two database tables with two files. + */ +class JSONDatabaseNode : public DatabaseNode { + public: + /*! \brief The path to the workload table */ + String path_workload; + /*! \brief The path to the tuning record table */ + String path_tuning_record; + /*! \brief The path to the measurement table */ + String path_measurement_record; + /*! \brief All the workloads in the database */ + std::unordered_map + workloads2idx_; + /*! \brief All the tuning records in the database */ + std::unordered_map> + tuning_records_; + + /*! \brief Measurement logs in the database */ + std::unordered_map> measurement_records_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("path_workload", &path_workload); + v->Visit("path_tuning_record", &path_tuning_record); + v->Visit("path_measurement_record", &path_measurement_record); + // `workloads2idx_` is not visited + // `tuning_records_` is not visited + // `measurement_records_` is not visited + } + + static constexpr const char* _type_key = "relax.tuning_api.JSONDatabase"; + TVM_DECLARE_FINAL_OBJECT_INFO(JSONDatabaseNode, DatabaseNode); + + public: + bool HasWorkload(const IRModule& mod) { + return workloads2idx_.find(meta_schedule::Workload(mod, tvm::StructuralHash()(mod))) != + workloads2idx_.end(); + } + + bool HasMeasurementRecord(const meta_schedule::Workload& workload, const Target& target) { + int workload_idx = this->workloads2idx_.at(workload); + std::string key = get_database_key(workload_idx, target); + return measurement_records_.count(key) > 0; + } + + bool HasTuningRecord(const meta_schedule::Workload& workload, const Target& target) { + int workload_idx = this->workloads2idx_.at(workload); + std::string key = get_database_key(workload_idx, target); + return tuning_records_.count(key) > 0; + } + + meta_schedule::Workload CommitWorkload(const IRModule& mod) { + // Try to insert `mod` into `workloads_` + decltype(this->workloads2idx_)::iterator it; + bool inserted = false; + std::tie(it, inserted) = + this->workloads2idx_.emplace(meta_schedule::Workload(mod, tvm::StructuralHash()(mod)), -1); + meta_schedule::Workload workload = it->first; + // If `mod` is new in `workloads2idx_`, append it to the workload file + if (inserted) { + it->second = static_cast(this->workloads2idx_.size()) - 1; + meta_schedule::JSONFileAppendLine(this->path_workload, + meta_schedule::JSONDumps(workload->AsJSON())); + } + return it->first; + } + + void CommitMeasurementRecord(const meta_schedule::Workload& workload, const Target& target, + const Array& run_secs) { + int workload_idx = this->workloads2idx_.at(workload); + std::string key = get_database_key(workload_idx, target); + + if (measurement_records_[key].size() == 0) { + measurement_records_[key] = run_secs; + meta_schedule::JSONFileAppendLine(this->path_measurement_record, + meta_schedule::JSONDumps(Array{ + Integer(workload_idx), target->Export(), + run_secs // + })); + } else { + LOG(WARNING) << "Measurement record for " << key + << " already exists. Use the existing one instead."; + } + } + + void CommitTuningRecord(const meta_schedule::Workload& workload, const Target& target, + const TuningRecord& record) { + int workload_idx = this->workloads2idx_.at(workload); + // There may exist multiple tuning records (with different traces) for a single key pair. + std::string key = get_database_key(workload_idx, target); + this->tuning_records_[key].insert(record); + + meta_schedule::JSONFileAppendLine( + this->path_tuning_record, meta_schedule::JSONDumps(Array{ + Integer(workload_idx), target->Export(), record->AsJSON()})); + } + + Array GetTopK(const meta_schedule::Workload& workload, const Target& target, + int top_k) { + CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative"; + if (top_k == 0) { + return {}; + } + Array results; + results.reserve(top_k); + int counter = 0; + int idx = this->workloads2idx_.at(workload); + std::string key = get_database_key(idx, target); + for (const TuningRecord& record : this->tuning_records_[key]) { + results.push_back(record); + if (++counter == top_k) { + break; + } + } + + return results; + } + + Array GetMeasurementRecord(const meta_schedule::Workload& workload, + const Target target) { + int workload_idx = this->workloads2idx_.at(workload); + return this->measurement_records_[get_database_key(workload_idx, target)]; + } +}; + +Database Database::JSONDatabase(String path_workload, String path_tuning_record, + String path_measurement_record, bool allow_missing) { + int num_threads = std::thread::hardware_concurrency(); + ObjectPtr n = make_object(); + // Load `n->workloads2idx_` from `path_workload` + std::vector workloads; + { + std::vector json_objs = + meta_schedule::JSONFileReadLines(path_workload, num_threads, allow_missing); + int n_objs = json_objs.size(); + n->workloads2idx_.reserve(n_objs); + workloads.reserve(n_objs); + for (int i = 0; i < n_objs; ++i) { + meta_schedule::Workload workload = meta_schedule::Workload::FromJSON(json_objs[i]); + n->workloads2idx_.emplace(workload, i); + workloads.push_back(workload); + } + } + // Load `n->tuning_records_` from `path_tuning_record` + { + std::vector json_objs = + meta_schedule::JSONFileReadLines(path_tuning_record, num_threads, allow_missing); + + std::vector workload_idxs; + std::vector targets; + std::vector records; + int size = json_objs.size(); + workload_idxs.resize(size, -1); + targets.resize(size, Target{nullptr}); + records.resize(size, TuningRecord{nullptr}); + support::parallel_for_dynamic( + 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { + const ObjectRef& json_obj = json_objs[task_id]; + try { + const ArrayNode* arr = json_obj.as(); + ICHECK_EQ(arr->size(), 3); + workload_idxs[task_id] = Downcast(arr->at(0)).IntValue(); + targets[task_id] = Target(Downcast>(arr->at(1))); + records[task_id] = TuningRecord::FromJSON(arr->at(2)); + } catch (std::runtime_error& e) { + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + }); + + for (int i = 0; i < size; i++) { + std::string key = get_database_key(workload_idxs[i], targets[i]); + n->tuning_records_[key].insert(records[i]); + } + } + + // Load `n->measuremet_log` from `path_measurement_record` + { + std::vector json_objs = + meta_schedule::JSONFileReadLines(path_measurement_record, num_threads, allow_missing); + std::vector workload_idxs; + std::vector targets; + std::vector> measurements; + int size = json_objs.size(); + workload_idxs.resize(size, -1); + targets.resize(size, Target{nullptr}); + measurements.resize(size, Array({})); + support::parallel_for_dynamic( + 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { + const ObjectRef& json_obj = json_objs[task_id]; + try { + const ArrayNode* arr = json_obj.as(); + ICHECK_EQ(arr->size(), 3); + workload_idxs[task_id] = Downcast(arr->at(0)).IntValue(); + targets[task_id] = Target(Downcast>(arr->at(1))); + measurements[task_id] = meta_schedule::AsFloatArray(arr->at(2)); + } catch (std::runtime_error& e) { + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + }); + for (int i = 0; i < size; i++) { + n->measurement_records_[get_database_key(workload_idxs[i], targets[i])] = measurements[i]; + } + } + + n->path_workload = path_workload; + n->path_tuning_record = path_tuning_record; + n->path_measurement_record = path_measurement_record; + return Database(n); +} + +/**************** FFI ****************/ +TVM_REGISTER_NODE_TYPE(TuningRecordNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.TuningRecord") + .set_body_typed([](Trace trace, Optional> run_secs) { + return TuningRecord(trace, run_secs); + }); +TVM_REGISTER_GLOBAL("relax.tuning_api.TuningRecordAsJSON") + .set_body_method(&TuningRecordNode::AsJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON); + +TVM_REGISTER_OBJECT_TYPE(DatabaseNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasWorkload") + .set_body_method(&DatabaseNode::HasWorkload); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasMeasurementRecord") + .set_body_method(&DatabaseNode::HasMeasurementRecord); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasTuningRecord") + .set_body_method(&DatabaseNode::HasTuningRecord); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitMeasurementRecord") + .set_body_method(&DatabaseNode::CommitMeasurementRecord); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitWorkload") + .set_body_method(&DatabaseNode::CommitWorkload); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitTuningRecord") + .set_body_method(&DatabaseNode::CommitTuningRecord); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseGetTopK") + .set_body_method(&DatabaseNode::GetTopK); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseGetMeasurementRecord") + .set_body_method(&DatabaseNode::GetMeasurementRecord); + +TVM_REGISTER_NODE_TYPE(JSONDatabaseNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseJSONDatabase").set_body_typed(Database::JSONDatabase); +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/tuning_api/primitives.cc b/src/relax/transform/tuning_api/primitives.cc new file mode 100644 index 000000000000..ef4a3d41bdf0 --- /dev/null +++ b/src/relax/transform/tuning_api/primitives.cc @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/tuning_api/primitives.cc + * \brief Primitives of tuning APIs. + */ + +#include + +#include "../../../meta_schedule/utils.h" +namespace tvm { +namespace relax { + +Choice::Choice(String transform_func_key, Array transform_func_args, + String constr_func_key, Array constr_func_args) { + ObjectPtr n = make_object(); + n->transform_func_key = std::move(transform_func_key); + n->transform_func_args = std::move(transform_func_args); + n->constr_func_key = std::move(constr_func_key); + n->constr_func_args = std::move(constr_func_args); + data_ = std::move(n); +} + +// TODO(sunggg): Currently, it only supports an array of primitive data types. +ObjectRef ChoiceNode::AsJSON() const { + Array json_transfrom_args, json_constr_args; + for (ObjectRef arg : this->transform_func_args) { + std::string json_arg = tvm::SaveJSON(arg); + std::string b64_arg = meta_schedule::Base64Encode(json_arg); + json_transfrom_args.push_back(String(b64_arg)); + } + for (ObjectRef arg : this->constr_func_args) { + std::string json_arg = tvm::SaveJSON(arg); + std::string b64_arg = meta_schedule::Base64Encode(json_arg); + json_constr_args.push_back(String(b64_arg)); + } + return Array{ + this->transform_func_key, + json_transfrom_args, + this->constr_func_key, + json_constr_args, + }; +} + +Choice Choice::FromJSON(const ObjectRef& json) { + // Parse `json` into `choice` + String transform_func_key, constr_func_key; + Array transform_func_args, constr_func_args; + try { + const ArrayNode* arr = json.as(); + ICHECK(arr && arr->size() == 4); + const auto* arr0 = arr->at(0).as(); + const auto* arr1 = arr->at(1).as(); + const auto* arr2 = arr->at(2).as(); + const auto* arr3 = arr->at(3).as(); + ICHECK(arr0 && arr1 && arr2 && arr3); + transform_func_key = GetRef(arr0); + { + transform_func_args.reserve(arr1->size()); + for (const ObjectRef& elem : *arr1) { + String b64_arg = Downcast(elem); + std::string json_arg = meta_schedule::Base64Decode(b64_arg); + ObjectRef arg = LoadJSON(json_arg); + transform_func_args.push_back(arg); + } + } + constr_func_key = GetRef(arr2); + { + constr_func_args.reserve(arr3->size()); + for (const ObjectRef& elem : *arr3) { + String b64_arg = Downcast(elem); + std::string json_arg = meta_schedule::Base64Decode(b64_arg); + ObjectRef arg = LoadJSON(json_arg); + constr_func_args.push_back(arg); + } + } + } catch (const tvm::Error& e) { + LOG(FATAL) + << "ValueError: The json entry of a choice should contain a set of two strings, but gets: " + << json; + throw; + } + return Choice(transform_func_key, transform_func_args, constr_func_key, constr_func_args); +} + +Knob::Knob(String name, Map choices) { + ObjectPtr n = make_object(); + n->name = std::move(name); + n->choices = std::move(choices); + data_ = std::move(n); +} + +ObjectRef KnobNode::AsJSON() const { + Map json_choices; + for (auto const& x : choices) { + json_choices.Set(x.first, x.second->AsJSON()); + } + return Array{ + /* 0: name */ std::move(name), + /* 1: choices */ std::move(json_choices), + }; +} + +Knob Knob::FromJSON(const ObjectRef& json) { + // Parse `json` into `name` and `choices` + String name; + Map choices; + try { + const ArrayNode* arr = json.as(); + ICHECK(arr && arr->size() == 2); + const auto* arr0 = arr->at(0).as(); + const auto* arr1 = arr->at(1).as(); + ICHECK(arr0 && arr1); + name = GetRef(arr0); + for (auto const& x : GetRef>(arr1)) { + String decision = x.first; + Choice choice = Choice::FromJSON(x.second); + choices.Set(decision, choice); + } + } catch (const tvm::Error& e) { + LOG(FATAL) + << "ValueError: The json entry of a choice should contain a set of two strings, but gets: " + << json; + throw; + } + return Knob(name, choices); +} + +Trace::Trace() { data_ = make_object(); } + +Trace::Trace(IRModule in_mod, Array knobs, Array decisions) { + ICHECK(knobs.size() == decisions.size()) << "Size of knobs and decisions should match"; + // Deep-copy IRModule + auto func_deepcopy = runtime::Registry::Get("relax.tuning_api.deepcopy_irmodule"); + ICHECK(func_deepcopy); + IRModule out_mod = (*func_deepcopy)(in_mod); + // Apply the decision history if provided + int size = knobs.size(); + for (int i = 0; i < size; i++) { + out_mod = knobs[i]->Apply(out_mod, decisions[i]); + } + + ObjectPtr n = make_object(); + n->in_mod = std::move(in_mod); + n->out_mod = std::move(out_mod); + n->knobs = std::move(knobs); + n->decisions = std::move(decisions); + n->size = std::move(size); + data_ = std::move(n); +} + +ObjectRef TraceNode::AsJSON(bool include_in_mod) const { + ICHECK(this->Verify()) << "Trace should be valid"; + + Array json_knobs; + Array json_decisions; + + int size = this->size; + json_knobs.reserve(size); + json_decisions.reserve(size); + + for (int i = 0; i < size; i++) { + const Knob& knob = this->knobs[i]; + const String& decision = this->decisions[i]; + + json_knobs.push_back(knob->AsJSON()); + json_decisions.push_back(decision); + } + if (include_in_mod) { + std::string json_mod = tvm::SaveJSON(this->in_mod); + std::string b64_mod = meta_schedule::Base64Encode(json_mod); + return Array{json_knobs, json_decisions, String(b64_mod)}; + } else { + return Array{json_knobs, json_decisions}; + } +} + +Trace Trace::FromJSON(const ObjectRef& json) { + // Parse `json` into `trace` + IRModule in_mod; + Array knobs; + Array decisions; + try { + const ArrayNode* arr = json.as(); + // A trace will have 2 or 3 entries depending on `include_irmod` parameter. + ICHECK(arr && (arr->size() == 2 || arr->size() == 3)); + + const auto* arr0 = arr->at(0).as(); + const auto* arr1 = arr->at(1).as(); + ICHECK(arr0 && arr1); + + for (const ObjectRef& elem : *arr0) { + knobs.push_back(Knob::FromJSON(elem)); + } + + for (const ObjectRef& elem : *arr1) { + decisions.push_back(Downcast(elem)); + } + + // When `include_irmod = true` + if (arr->size() == 3) { + const auto* arr2 = arr->at(2).as(); + String b64_mod = GetRef(arr2); + ICHECK(arr2); + std::string json_mod = meta_schedule::Base64Decode(b64_mod); + in_mod = Downcast(LoadJSON(json_mod)); + } + } catch (const tvm::Error& e) { + LOG(FATAL) << "ValueError: Malformed Trace format - " << json; + throw; + } + return Trace(in_mod, knobs, decisions); +} + +/**************** FFI ****************/ +TVM_REGISTER_NODE_TYPE(ChoiceNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.Choice") + .set_body_typed([](String transform_func_key, Array transform_func_args, + String constr_func_key, Array constr_func_args) { + return Choice(transform_func_key, transform_func_args, constr_func_key, constr_func_args); + }); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceAsJSON").set_body_method(&ChoiceNode::AsJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceFromJSON").set_body_typed(Choice::FromJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceGetTransformFunc") + .set_body_method(&ChoiceNode::GetTransformFunc); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceGetConstrFunc") + .set_body_method(&ChoiceNode::GetConstrFunc); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceApplyTransformFunc") + .set_body_method(&ChoiceNode::ApplyTransformFunc); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceCheckConstr") + .set_body_method(&ChoiceNode::CheckConstr); + +TVM_REGISTER_NODE_TYPE(KnobNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.Knob") + .set_body_typed([](String name, Map choices) { return Knob(name, choices); }); +TVM_REGISTER_GLOBAL("relax.tuning_api.KnobAsJSON").set_body_method(&KnobNode::AsJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.KnobFromJSON").set_body_typed(Knob::FromJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.KnobIsValidDecision") + .set_body_method(&KnobNode::IsValidDecision); +TVM_REGISTER_GLOBAL("relax.tuning_api.KnobApply").set_body_method(&KnobNode::Apply); + +TVM_REGISTER_NODE_TYPE(TraceNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.Trace") + .set_body_typed([](IRModule in_mod, Array knobs, Array decisions) { + return Trace(in_mod, knobs, decisions); + }); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceVerify").set_body_method(&TraceNode::Verify); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceAdd").set_body_method(&TraceNode::Add); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceSetPerf").set_body_method(&TraceNode::SetPerf); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceSetOutMod") + .set_body_method(&TraceNode::SetOutMod); + +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceAsJSON").set_body_method(&TraceNode::AsJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceFromJSON").set_body_typed(Trace::FromJSON); +} // namespace relax +} // namespace tvm diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index fc1f3a15077e..dd31a1f7367d 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -154,8 +154,8 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required) { - PassInfo pass_info = PassInfo(opt_level, name, required); + int opt_level, String name, tvm::Array required, bool traceable) { + PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return FunctionPass(pass_func, pass_info); } diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index d2eb48073f7d..a152bbe9c3cb 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -950,7 +950,7 @@ TVM_REGISTER_GLOBAL("relay._transform.InferTypeLocal").set_body_typed([](const E }); Pass InferType() { - auto pass_info = PassInfo(0, "InferType", {}); + auto pass_info = PassInfo(0, "InferType", {}, /* trace */ false); return tvm::transform::CreateModulePass( [=](IRModule mod, const PassContext& pass_ctx) { // Execute the pass function and return a new module. diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index 4c59a1767372..781a0ecd7c3d 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -115,8 +115,8 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) Pass CreatePrimFuncPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required) { - PassInfo pass_info = PassInfo(opt_level, name, required); + int opt_level, String name, tvm::Array required, bool traceable) { + PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return PrimFuncPass(pass_func, pass_info); } diff --git a/tests/python/relax/test_transform_meta_schedule_tuning.py b/tests/python/relax/test_transform_meta_schedule_tuning.py new file mode 100644 index 000000000000..ff695b9436a3 --- /dev/null +++ b/tests/python/relax/test_transform_meta_schedule_tuning.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tempfile + +import tvm +import tvm.testing +import tvm.meta_schedule as ms +from tvm import relax +from tvm.ir import transform +from tvm.ir.module import IRModule +from tvm.ir.transform import PassContext +from tvm.relax.transform.tuning_api import Trace +from tvm.script import relax as R +from tvm.script import tir as T + +target = tvm.target.Target("llvm --num-cores=16") + + +@tvm.script.ir_module +class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + k = T.var("int32") + A = T.match_buffer(x, (32, 32)) + B = T.match_buffer(y, (32, 32)) + C = T.match_buffer(z, (32, 32)) + + for (i0, j0, k0) in T.grid(32, 32, 32): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + with T.init(): + C[i, j] = 0.0 + C[i, j] += A[i, k] * B[j, k] + + @T.prim_func + def tir_relu(x: T.handle, y: T.handle): + T.func_attr({"global_symbol": "tir_relu"}) + A = T.match_buffer(x, (32, 32)) + B = T.match_buffer(y, (32, 32)) + for (i, j) in T.grid(32, 32): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = T.max(A[vi, vj], 0.0) + + @R.function + def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir(tir_relu, (lv0), R.Tensor((32, 32), dtype="float32")) + R.output(lv1) + return lv1 + + +# TODO(@sunggg): determine how to pass MS database object across different passes. +# PassContext might be an option, but we already have TuningAPI database. +# (MS database and TuningAPI database will be unified in the future) +# For now, we only support default JSON database config. +def test_ms_tuning_irmodule(): + + mod = InputModule + assert isinstance(mod, IRModule) + + with tempfile.TemporaryDirectory() as work_dir: + with target, transform.PassContext(trace=Trace(mod), opt_level=0): + tuning_pass = relax.transform.MetaScheduleTuneIRMod( + params={}, work_dir=work_dir, max_trials_global=4 + ) + out_mod = tuning_pass(mod) + assert PassContext.current().get_trace_stack_size() == 1 + assert PassContext.current().get_current_trace().size == 1 + tvm.ir.assert_structural_equal(mod, out_mod) + + application_pass = relax.transform.MetaScheduleApplyDatabase(work_dir) + + out_mod = application_pass(mod) + assert not tvm.ir.structural_equal(mod, out_mod) + + +def test_ms_tuning_primfunc(): + mod = InputModule + assert isinstance(mod, IRModule) + with tempfile.TemporaryDirectory() as work_dir: + with target, transform.PassContext(trace=Trace(mod), opt_level=0): + tuning_pass = relax.transform.MetaScheduleTuneTIR( + work_dir=work_dir, max_trials_global=4 + ) + out_mod = tuning_pass(mod) + assert PassContext.current().get_trace_stack_size() == 1 + # TODO (@sunggg): Need to determine how to track subgraph-level tuning traces. + # Currently, we don't track this so the trace size. Revisit this later. + tvm.ir.assert_structural_equal(mod, out_mod) + + application_pass = relax.transform.MetaScheduleApplyDatabase(work_dir) + out_mod = application_pass(mod) + assert not tvm.ir.structural_equal(mod, out_mod) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tuning_api.py b/tests/python/relax/test_tuning_api.py new file mode 100644 index 000000000000..b12ff016705d --- /dev/null +++ b/tests/python/relax/test_tuning_api.py @@ -0,0 +1,781 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import numpy as np +import os.path as osp +import tempfile +from typing import List +from math import isclose + +import tvm +from tvm import ir +from tvm.ir import transform +from tvm.ir.transform import PassContext +from tvm.ir.module import IRModule +from tvm.script import tir as T, relax as R +from tvm import relax +from tvm.relax.expr import Expr, DataflowBlock, Function +from tvm.relax.transform.tuning_api import ( + Choice, + Knob, + Trace, + TuningRecord, + JSONDatabase, + default_generate_candidate, + default_consider_eval_passes, + default_evaluate, + select_best_candidate, + get_trace, +) + + +@tvm.script.ir_module +class TestModule: + @T.prim_func + def addone(A: T.Buffer[(16, 16), "int32"], B: T.Buffer[(16, 16), "int32"]) -> None: + T.func_attr(({"global_symbol": "addone"})) + for i, j in T.grid(16, 16): + with T.block("addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.int32(1) + + # Input IRModule. + @R.function + def before(c0: R.Tensor((16, 16), "int32")): + lv0 = R.call_tir(addone, (c0,), R.Tensor((16, 16), dtype="int32")) + return lv0 + + # Expected IRModule after transformation. + @R.function + def expected(c1: R.Tensor((16, 16), "int32")): + lv0 = c1 + return c1 + + +def gen_mod(mod, name, binding): + funcs = {} + binding = {k: tvm.nd.array(v) for k, v in binding.items()} + + for k, v in mod.functions.items(): + if isinstance(v, tvm.relax.Function): + if k.name_hint == name: + # rename to main. + gv = tvm.ir.GlobalVar("main") + funcs[gv] = tvm.relax.Function(v.params, v.body, v.ret_struct_info).with_attr( + "global_symbol", "main" + ) + else: + funcs[k] = v + mod = tvm.IRModule(funcs) + return relax.transform.BindParams("main", binding)(mod) + + +# Setup for simple testing with IRModule. +def setup_test(): + mod = TestModule + assert isinstance(mod, tvm.IRModule) + return gen_mod(mod, "before", {}) + + +# Setup for testing with constant folding. +def setup_test_const_folding(): + mod = TestModule + assert isinstance(mod, tvm.IRModule) + # Test setup. + c0_np = np.arange((16 * 16)).astype("int32").reshape(16, 16) + c1_np = c0_np + 1 + before = gen_mod(mod, "before", {"c0": c0_np}) + expected = gen_mod(mod, "expected", {"c1": c1_np}) + + return before, expected + + +# Define a choice by using FoldConstant pass. +@tvm.register_func("testing.apply_fold_constant") +def apply_fold_constant(mod): + return relax.transform.FoldConstant()(mod) + + +@tvm.register_func("testing.add_global_symbol") +def add_global_symbol(mod, func_name, global_symbol): + mod[func_name] = mod[func_name].with_attr("global_symbol", global_symbol) + return mod + + +@tvm.register_func("testing.check_num_functions") +def check_num_funcs(mod, N): + # Explicit type specification is necessary. + # Otherwise, PackedFunc cannot derive the return type correctly. + # e.g., Check failed: type_code_ == kDLInt (8 vs. 0) : expected int but got Object + return bool(len(mod.functions) == N) + + +def test_choice(): + # Test setup. + ( + before, + expected, + ) = setup_test_const_folding() + + # Without any argument, default setting will be used for both transformation and constraint functions. + # default transformation function will return the original IRModule without any change. + choice = Choice( + # - transform_func_key="relax.tuning_api.Choice.default_transform_func" + # - constr_func_key="relax.tuning_api.Choice.default_constr_func") + ) + # Load transformation function from the choice and apply it. + after = choice.apply_transform_func(before) + tvm.ir.assert_structural_equal(after, before) + + choice = Choice("testing.apply_fold_constant") + # Load transformation function from the choice and apply it. + after = choice.apply_transform_func(before) + tvm.ir.assert_structural_equal(after, expected) + + # Create a choice that tags global symbol onto target function. + choice = Choice("testing.add_global_symbol", ["addone", "test-symbol"]) + after = choice.apply_transform_func(before) + assert after["addone"].attrs["global_symbol"] == "test-symbol" + # The transformation should be applied with Copy-On-Write. + # So, the original module should be unchanged. + assert before["addone"].attrs["global_symbol"] == "addone" + + # Test choice with impossible constraint + choice = Choice( + transform_func_key="testing.add_global_symbol", + transform_func_args=["addone", "test-symbol"], + constr_func_key="testing.check_num_functions", + constr_func_args=[1000], + ) + # Since the constraint is not met, it should return the original function + after = choice.apply_transform_func(before) + assert after["addone"].attrs["global_symbol"] == "addone" + + # Test choice with the proper constraint + choice = Choice( + transform_func_key="testing.add_global_symbol", + transform_func_args=["addone", "test-symbol"], + constr_func_key="testing.check_num_functions", + constr_func_args=[2], + ) + # Since the constraint is not met, it should return the original function + after = choice.apply_transform_func(before) + assert after["addone"].attrs["global_symbol"] == "test-symbol" + # The original module should be unchanged. + assert before["addone"].attrs["global_symbol"] == "addone" + + # Test roundtrip. + # Export as JSON. + json_obj = choice.as_json() + # Import JSON. + new_choice = Choice.from_json(json_obj) + # Test imported choice + after = new_choice.apply_transform_func(before) + assert after["addone"].attrs["global_symbol"] == "test-symbol" + # The original module should be unchanged. + assert before["addone"].attrs["global_symbol"] == "addone" + + +def test_knob(): + # Test setup. + before, expected = setup_test_const_folding() + + # Users can define a set of choices with list. + choices = [ + Choice("testing.apply_fold_constant"), + Choice(), + ] + + # Define knob. + knob = Knob("TestKnob", choices) + # Check the sanity of decision space. + assert knob.verify(0) + assert knob.verify(1) + assert not knob.verify(3) + + # Check the sanity of each decision. + after_apply = knob.apply(before, 0) + after_noapply = knob.apply(before, 1) + + tvm.ir.assert_structural_equal(after_apply, expected) + tvm.ir.assert_structural_equal(after_noapply, before) + + # Users can define a set of choices with dict. + choices = { + "apply": Choice("testing.apply_fold_constant"), + "noapply": Choice(), + "apply_with_impossible_constr": Choice( + transform_func_key="testing.apply_fold_constant", + constr_func_key="testing.check_num_functions", + constr_func_args=[1000], + ), + } + # Define knob. + knob = Knob("TestKnob", choices) + assert knob.verify("apply") + assert knob.verify("noapply") + assert knob.verify("apply_with_impossible_constr") + assert not knob.verify("INVLAID") + + after_apply = knob.apply(before, "apply") + after_noapply = knob.apply(before, "noapply") + # Because constr was not satisfied, it will return the original IRModule + after_apply_with_constr = knob.apply(before, "apply_with_impossible_constr") + tvm.ir.assert_structural_equal(after_apply, expected) + tvm.ir.assert_structural_equal(after_noapply, before) + tvm.ir.assert_structural_equal(after_apply_with_constr, before) + + # Test roundtrip. + # Export as JSON. + json_obj = knob.as_json() + # Import JSON. + new_knob = Knob.from_json(json_obj) + assert new_knob.name == knob.name + # Test imported knob + assert new_knob.verify("apply") + assert new_knob.verify("noapply") + assert new_knob.verify("apply_with_impossible_constr") + assert not new_knob.verify("INVLAID") + + after_apply = new_knob.apply(before, "apply") + after_noapply = new_knob.apply(before, "noapply") + # Because constr was not satisfied, it will return the original IRModule + after_apply_with_constr = knob.apply(before, "apply_with_impossible_constr") + tvm.ir.assert_structural_equal(after_apply, expected) + tvm.ir.assert_structural_equal(after_noapply, before) + tvm.ir.assert_structural_equal(after_apply_with_constr, before) + + +def test_trace(): + before, expected = setup_test_const_folding() + + # Define choices and its knob. + choices = { + "apply": Choice( + transform_func_key="testing.apply_fold_constant", + transform_func_args=[], + constr_func_key="testing.check_num_functions", + constr_func_args=[2], + ), + "noapply": Choice(), + } + knob = Knob("TestKnob", choices) + + # Define a Trace with empty decision (transformation) history. + trace = Trace(before) + assert trace.size == 0 + + # Define a Trace with single decision (transformation) history. + trace = Trace(before, [knob], ["noapply"]) + assert trace.size == 1 + tvm.ir.assert_structural_equal(trace.in_mod, before) + tvm.ir.assert_structural_equal(trace.out_mod, before) + + # Add a new knob and its decision to the trace. + # It will update the current trace and returns its new output IRModule. + out: IRModule = trace.add(knob, "noapply") + assert trace.size == 2 + tvm.ir.assert_structural_equal(trace.in_mod, before) + tvm.ir.assert_structural_equal(trace.out_mod, before) + tvm.ir.assert_structural_equal(out, before) + # Assume we assign arbitrary performance number. + trace.set_perf(100) + assert trace.perf == 100 + + # Add a new knob and its decision to the trace. + out: IRModule = trace.add(knob, "apply") + tvm.ir.assert_structural_equal(trace.in_mod, before) + tvm.ir.assert_structural_equal(trace.out_mod, expected) + tvm.ir.assert_structural_equal(out, expected) + + assert trace.size == 3 + # Should be initalized when new knob is applied. + assert trace.perf == -1 + + # Test roundtrip. + # Export as JSON. + json_obj = trace.as_json() + # Import JSON. + new_trace = Trace.from_json(json_obj) + tvm.ir.assert_structural_equal(trace.in_mod, new_trace.in_mod) + assert str(trace) == str(new_trace) + assert new_trace.size == 3 + tvm.ir.assert_structural_equal(trace.out_mod, new_trace.out_mod) + + +def test_trace_wrapper(): + mod = setup_test() + assert isinstance(mod, tvm.IRModule) + assert isinstance(Trace(mod), Trace) + assert isinstance(get_trace(mod), Trace) + assert isinstance(get_trace(mod["main"]), Trace) + assert isinstance(get_trace(mod["addone"]), Trace) + + +def create_tmp_database(tmpdir: str) -> JSONDatabase: + path_workload = osp.join(tmpdir, "workloads.json") + path_tuning_record = osp.join(tmpdir, "tuning_records.json") + path_measurement_record = osp.join(tmpdir, "measurement_records.json") + return JSONDatabase(path_workload, path_tuning_record, path_measurement_record) + + +def test_database(): + def equal_measurement_record(a: List[float], b: List[float]): + assert len(a) == len(b) + for i in range(len(a)): + assert isclose(a[i], b[i], rel_tol=1e-5) + + def equal_tuning_record(a: TuningRecord, b: TuningRecord): + assert str(a.trace) == str(b.trace) + equal_measurement_record(a.run_secs, b.run_secs) + + # Test setup. + ( + mod1, + mod2, + ) = setup_test_const_folding() + knob = Knob("test", {"noapply": Choice()}) + trace = Trace(mod1, [knob, knob], ["noapply", "noapply"]) + target = tvm.target.Target("llvm") + + # Test roundtrip + run_secs = [1.0, 0.9, 0.4] + tuning_record = TuningRecord( + trace, + run_secs, + ) + new_tuning_record = TuningRecord.from_json(json_obj=tuning_record.as_json()) + equal_tuning_record(tuning_record, new_tuning_record) + + with tempfile.TemporaryDirectory() as tmpdir: + database = create_tmp_database(tmpdir) + workload1 = database.commit_workload(mod1) + + database.commit_measurement_record(workload1, target, run_secs) + new_run_secs1 = database.get_measurement_record(workload1, target) + equal_measurement_record(run_secs, new_run_secs1) + workload2 = database.commit_workload(mod2) + new_run_secs2 = database.get_measurement_record(workload2, target) + assert len(new_run_secs2) == 0 + + database.commit_tuning_record(workload1, target, tuning_record) + new_tuning_records = database.get_top_k(workload1, target, top_k=1) + assert len(new_tuning_records) == 1 + equal_tuning_record(tuning_record, new_tuning_records[0]) + new_tuning_records = database.get_top_k(workload1, target, top_k=0) + assert len(new_tuning_records) == 0 + + +def test_default_functions(): + mod = setup_test() + assert isinstance(mod, tvm.IRModule) + + # Define choice, knob, trace. + choices = {"apply": Choice("testing.apply_fold_constant"), "noapply": Choice()} + knob = Knob("TestKnob", choices) + trace = Trace(mod) + + # Launch a pass pipeline in trace mode. + with tempfile.TemporaryDirectory() as tmpdir: + database = create_tmp_database(tmpdir) + with transform.PassContext(trace=trace, tuning_api_database=database): + # Default generation function expands every valid choice. + candidates = default_generate_candidate([knob], trace) + assert len(candidates) == 2 + + # Default evaluate function uses MetaSchedule builder/runner. + # Since builder/runner are not provided, local builder/runner will be used. + default_evaluate(candidates, "llvm --num-cores=16") + assert PassContext.current().num_evals == 2 + + # Because these candidates are already evaluated, num_evals stays the same. + default_evaluate(candidates, "llvm --num-cores=16") + assert PassContext.current().num_evals == 2 + + # Test with multiple knobs + candidates = default_generate_candidate([knob, knob], trace) + assert len(candidates) == 4 + + # Launch new pass pipeline in trace mode. + with transform.PassContext(trace=trace, tuning_api_database=database): + candidates = default_generate_candidate([knob], trace) + assert len(candidates) == 2 + # Provide tuning pass as an eval pass. + # Note that MockConstFoldingTuningPass() has its own generation function, evaluation function. + # Evaluation would be done in a tornament fashion. + # `default_consider_eval_passes` will convert candidates into the best version by considering eval_passes. + # For example, if we say candidates = [C1, C2] + # `default_consider_eval_passes` will return best form of C1 variant (C11 vs C12) and C2 variant (C21 vs C22) + # that can be generated by eval_passes. + # Assume C11 > C12, C21 < C22, + # new_candidates = [C11, C22] + new_candidates = default_consider_eval_passes( + candidates, [MockConstFoldingTuningPass(eval_passes=[])] + ) + + # len(candidates) == len(new candidates). + assert len(new_candidates) == 2 + # To find the best version of each candidate, it would take 4 evals (C11, C12, C21, C22). + assert PassContext.current().num_evals == 4 + + HeuristicPass = relax.transform.FoldConstant + with transform.PassContext(trace=trace, tuning_api_database=database): + candidates = default_generate_candidate([knob], trace) + assert len(candidates) == 2 + # Provide heuristic pass as an eval pass. + new_candidates = default_consider_eval_passes(candidates, [HeuristicPass()]) + # Since heuristic pass has single decision, it won't need any tornament. + # new_candidates = [C11, C21] + assert len(new_candidates) == 2 + # We only conduct evaluation when its necessary (e.g., choose better candidate in tuning pass). + # Heuristic pass won't conduct any evaluation. + assert PassContext.current().num_evals == 0 + + +# TODO(sunggg): Do we need to serialize pass context as well? +def test_pass_context(): + before, expected = setup_test_const_folding() + HeuristicPass = relax.transform.FoldConstant + # FoldConstant implicitly performs TIR passes (prob for constant evaluation). + # If make_traceable is not provided, the pass infra will make every non-traceable pass traceable by default. + seq = transform.Sequential([HeuristicPass()]) + with transform.PassContext( + trace=Trace(before), + ): + after = seq(before) + tvm.ir.assert_structural_equal(after, expected) + assert PassContext.current().get_trace_stack_size() == 1 + # The exact number of implicit passes might change as TVM develops more passes. + # As of today, this size returns 57. + assert PassContext.current().get_current_trace().size > 1 + + # We can explicitly specify which pass we want to keep track of. + with transform.PassContext(trace=Trace(before), make_traceable=["FoldConstant"]): + after = seq(before) + tvm.ir.assert_structural_equal(after, expected) + assert PassContext.current().get_trace_stack_size() == 1 + assert PassContext.current().get_current_trace().size == 1 + + # Check the functionality of trace stack. + with transform.PassContext(trace=Trace(before)): + assert PassContext.current().get_trace_stack_size() == 1 + PassContext.current().push_trace(Trace(before)) + assert PassContext.current().get_trace_stack_size() == 2 + PassContext.current().pop_trace() + assert PassContext.current().get_trace_stack_size() == 1 + PassContext.current().pop_trace() + assert PassContext.current().get_trace_stack_size() == 0 + + +# Mock evaluation pass for testing. +# Assigns arbitrary performance number to each candidate. +def mock_evaluate(candidates: List[Trace], target_str: str, ctx: PassContext): + num_evals = 0 + # Evaluation + for candidate in candidates: + # If this candidate is already evaluated, skip the measurement. + if candidate.perf != -1: + continue + + num_evals += 1 + # Assign arbitrary performance. + mock_perf = 100 - (ctx.num_evals + num_evals) + candidate.set_perf(mock_perf) + # Update number of evals for testing. + ctx.inc_num_evals(num_evals) + + +# Mock tuning pass that determines whether to apply relax.transform.FoldConstant(). +# Each pass invocation will generate two candidates for the incoming IRModule. +# In relax pass infra, each pass will define its own way of generating candidates and evaluating them without needing to know how other passes generate its candidate and evaluate them. +# This will significantly alleviate the development process since it is known to be HARD problem to consider the interaction with (potentially hundreds of) other passes. +@ir.transform.module_pass(opt_level=0, traceable=True) +class MockConstFoldingTuningPass(transform.Pass): + def __init__( + self, + f_generate_candidate=None, + f_evaluate=mock_evaluate, + eval_passes: List[transform.Pass] = None, + required: List[transform.Pass] = [], + ): + self.f_generate_candidate = ( + f_generate_candidate if f_generate_candidate else default_generate_candidate + ) + self.f_evaluate = f_evaluate if f_evaluate else default_evaluate + self.eval_passes = eval_passes + self.required = required + + def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule: + trace = ctx.pop_trace() + + # Create mock choices for testing. + choices = {"apply": Choice("testing.apply_fold_constant"), "noapply": Choice()} + # Tuning pass manages a set of transformation functions registered via knob. + knob = Knob("MockTuningKnob", choices) + + candidates = self.f_generate_candidate([knob], trace, self.eval_passes) + self.f_evaluate(candidates, "llvm", ctx) + best_trace = select_best_candidate(candidates) + + ctx.push_trace(best_trace) + return best_trace.out_mod + + +def test_module_pass(): + mod = setup_test() + assert isinstance(mod, tvm.IRModule) + # Test setup + c0 = np.arange((16 * 16)).astype("int32").reshape(16, 16) + mod = relax.transform.BindParams("main", {"c0": tvm.nd.array(c0)})(mod) + HeuristicPass = relax.transform.FoldConstant + + # Tuning pass without any eval_pass. + mock_pass = MockConstFoldingTuningPass(eval_passes=[]) + with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): + _ = mock_pass(mod) + assert PassContext.current().num_evals == 2 + assert PassContext.current().get_trace_stack_size() == 1 + assert PassContext.current().get_current_trace().size == 1 + + # Heuristic pass should not affect the number of candidates. + mock_pass = MockConstFoldingTuningPass(eval_passes=[HeuristicPass()]) + with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): + _ = mock_pass(mod) + assert PassContext.current().num_evals == 2 + assert PassContext.current().get_trace_stack_size() == 1 + assert PassContext.current().get_current_trace().size == 2 + + # Joint-optimization will increase the search space in the combinatorial way + mock_pass = MockConstFoldingTuningPass(eval_passes=[MockConstFoldingTuningPass(eval_passes=[])]) + with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): + _ = mock_pass(mod) + assert PassContext.current().num_evals == 2 * 2 + assert PassContext.current().get_trace_stack_size() == 1 + assert PassContext.current().get_current_trace().size == 2 + + # Joint-optimization can be nested. + mock_pass = MockConstFoldingTuningPass( + eval_passes=[ + MockConstFoldingTuningPass(eval_passes=[MockConstFoldingTuningPass(eval_passes=[])]) + ] + ) + with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): + _ = mock_pass(mod) + assert PassContext.current().num_evals == 2 * 2 * 2 + assert PassContext.current().get_trace_stack_size() == 1 + assert PassContext.current().get_current_trace().size == 3 + + # Tuning pass and heuritic passes can be used together. + # Note that heuristic pass won't increate the search space (num_evals). + # It only increases the length of the trace. + mock_pass = MockConstFoldingTuningPass( + eval_passes=[ + HeuristicPass(), + MockConstFoldingTuningPass( + eval_passes=[ + MockConstFoldingTuningPass(eval_passes=[HeuristicPass(), HeuristicPass()]) + ] + ), + ] + ) + with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): + _ = mock_pass(mod) + assert PassContext.current().num_evals == 2 * 2 * 2 + assert PassContext.current().get_trace_stack_size() == 1 + assert PassContext.current().get_current_trace().size == 6 + + # Users can mix-use sequential application and joint-application. + mock_pass = MockConstFoldingTuningPass( + eval_passes=[ + MockConstFoldingTuningPass(eval_passes=[]), + MockConstFoldingTuningPass(eval_passes=[]), + MockConstFoldingTuningPass(eval_passes=[]), + ] + ) + with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): + _ = mock_pass(mod) + assert PassContext.current().num_evals == 2 * (2 + 2 + 2) + assert PassContext.current().get_trace_stack_size() == 1 + assert PassContext.current().get_current_trace().size == 4 + + +def test_sequential(): + mod = setup_test() + assert isinstance(mod, tvm.IRModule) + # Test setup. + c0 = np.arange((16 * 16)).astype("int32").reshape(16, 16) + mod = relax.transform.BindParams("main", {"c0": tvm.nd.array(c0)})(mod) + HeuristicPass = relax.transform.FoldConstant + + # Sequential with a single tuning pass should behave same with a single pass. + seq = transform.Sequential([MockConstFoldingTuningPass(eval_passes=[])]) + with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): + _ = seq(mod) + assert PassContext.current().num_evals == 2 + assert PassContext.current().get_trace_stack_size() == 1 + assert PassContext.current().get_current_trace().size == 1 + + # Sequential pass should increase search space (num_evals) in additive manner. + seq = transform.Sequential( + [ + MockConstFoldingTuningPass(eval_passes=[]), + MockConstFoldingTuningPass(eval_passes=[]), + MockConstFoldingTuningPass(eval_passes=[]), + ] + ) + with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): + _ = seq(mod) + assert PassContext.current().num_evals == 2 + 2 + 2 + assert PassContext.current().get_trace_stack_size() == 1 + assert PassContext.current().get_current_trace().size == 3 + + # Heuristic pass will not increase the search space. Just increase trace length. + seq = transform.Sequential( + [ + MockConstFoldingTuningPass(eval_passes=[]), + HeuristicPass(), + MockConstFoldingTuningPass(eval_passes=[]), + MockConstFoldingTuningPass(eval_passes=[]), + HeuristicPass(), + ] + ) + + with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): + _ = seq(mod) + assert PassContext.current().num_evals == 2 + 2 + 2 + assert PassContext.current().get_trace_stack_size() == 1 + assert PassContext.current().get_current_trace().size == 5 + + # Users can mix-use sequential application and joint-application. + seq = transform.Sequential( + [ + HeuristicPass(), + MockConstFoldingTuningPass( + eval_passes=[ + MockConstFoldingTuningPass( + eval_passes=[ + MockConstFoldingTuningPass( + eval_passes=[ + HeuristicPass(), + ] + ) + ] + ), + ] + ), + MockConstFoldingTuningPass(eval_passes=[]), + HeuristicPass(), + ] + ) + + with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]): + _ = seq(mod) + assert PassContext.current().num_evals == (2 * 2 * 2) + 2 + assert PassContext.current().get_trace_stack_size() == 1 + assert PassContext.current().get_current_trace().size == 7 + + +def test_passes_with_mixed_granularities(): + @tvm.script.ir_module + class MockModule: + @R.function + def f1(x: R.Tensor(("m", "n"), "float32")): + with R.dataflow(): + lv0 = R.multiply(x, x) + gv0 = R.add(x, x) + R.output(gv0) + return gv0 + + @R.function + def main(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")): + with R.dataflow(): + lv0 = R.multiply(x, y) + gv0 = R.add(lv0, y) + R.output(gv0) + gv1 = R.multiply(x, y) + gv2 = R.add(gv1, y) + return (gv0, gv1, gv2) + + mod = MockModule + assert isinstance(mod, tvm.IRModule) + + # Helper function for tuning + def pass_func( + mod: IRModule, ctx: PassContext, eval_passes: List[transform.Pass] = None + ) -> IRModule: + trace = ctx.pop_trace() + + # Create mock choices for testing + choices = [Choice(), Choice(), Choice()] + # Tuning pass manages a set of transformation functions registered via knob. + knob = Knob("MockTuningKnob", choices) + + candidates = default_generate_candidate([knob], trace, eval_passes) + mock_evaluate(candidates, "llvm", ctx) + best_trace = select_best_candidate(candidates) + + ctx.push_trace(best_trace) + return best_trace.out_mod + + @ir.transform.module_pass(opt_level=0, traceable=True) + def MockModulePass(mod: IRModule, ctx: PassContext) -> IRModule: + # Input granularity == Candidate granularity. + return pass_func(mod, ctx) + + @relax.transform.function_pass(opt_level=0, traceable=True) + def MockFunctionPass(func: Expr, mod: IRModule, ctx: PassContext) -> Function: + # Input granularity > Candidate granularity. + # Start trace with smaller granularity: IRModule->Function. + ctx.push_trace(Trace(IRModule.from_expr(func))) + # Do something. + pass_func(mod, ctx) + # Pop tuned trace and recover the previous trace. + ctx.pop_trace() + return func + + @relax.transform.dataflowblock_pass(opt_level=0, traceable=True) + def MockDataflowBlockPass( + block: DataflowBlock, mod: IRModule, ctx: PassContext + ) -> DataflowBlock: + # TODO(sunggg): figure out how to create IRModule from DataflowBlock + # Provide random binding for now + x = relax.Var("x", R.Tensor([tvm.tir.Var("n", "int64")], "float32")) + seq_expr = relax.SeqExpr([block], x) + func = relax.Function([x], seq_expr, R.Tensor("float32", ndim=-1)) + ctx.push_trace(Trace(IRModule.from_expr(func))) + # Do something + pass_func(mod, ctx) + ctx.pop_trace() + return block + + seq = transform.Sequential( + [ + MockModulePass, + MockFunctionPass, + MockDataflowBlockPass, + ] + ) + + with transform.PassContext(trace=Trace(mod), make_traceable=[]): + _ = seq(mod) + # Trace length and num eval can be different depending on how each function/dataflow block is treated. + assert PassContext.current().get_trace_stack_size() == 1 + + +if __name__ == "__main__": + pytest.main([__file__]) From eeb40ac34835083724075a7f734560f67b861828 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Fri, 17 Feb 2023 11:37:12 -0800 Subject: [PATCH 30/81] [Unity] Relay -> Relax translator (#14026) This PR implements a Relay to Relax translator, which allows us to import Relay workloads to Relax for benchmarking and development purposes (tests and examples are added). --- apps/relax_examples/e2e_auto_tir.py | 253 ++++++++++++++++ apps/relax_examples/mlp.py | 57 ++++ apps/relax_examples/nn_module.py | 69 +++++ apps/relax_examples/resnet.py | 53 ++++ python/tvm/relax/testing/__init__.py | 1 + python/tvm/relax/testing/relay_translator.py | 251 ++++++++++++++++ python/tvm/relax/testing/transform.py | 125 ++++++++ src/relay/backend/utils.cc | 7 + tests/python/relax/test_relay_translator.py | 300 +++++++++++++++++++ 9 files changed, 1116 insertions(+) create mode 100644 apps/relax_examples/e2e_auto_tir.py create mode 100644 apps/relax_examples/mlp.py create mode 100644 apps/relax_examples/nn_module.py create mode 100644 apps/relax_examples/resnet.py create mode 100644 python/tvm/relax/testing/relay_translator.py create mode 100644 python/tvm/relax/testing/transform.py create mode 100644 tests/python/relax/test_relay_translator.py diff --git a/apps/relax_examples/e2e_auto_tir.py b/apps/relax_examples/e2e_auto_tir.py new file mode 100644 index 000000000000..92cda16f7927 --- /dev/null +++ b/apps/relax_examples/e2e_auto_tir.py @@ -0,0 +1,253 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import datetime +import os +import csv +import json +import argparse +import logging +from typing import Dict +import numpy as np # type: ignore + +import tvm +from tvm import relay, relax, runtime, transform +from tvm.ir.module import IRModule +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.relay_workload import get_network +from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc +from tvm.relax.testing import relay_translator +from tvm.target.target import Target + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--workload", + type=str, + required=True, + ) + args.add_argument( + "--input-shape", + type=str, + required=True, + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--num-trials", + type=int, + required=True, + ) + args.add_argument( + "--rpc-host", + type=str, + default=None, + ) + args.add_argument( + "--rpc-port", + type=int, + default=None, + ) + args.add_argument( + "--rpc-key", + type=str, + default=None, + ) + args.add_argument( + "--work-dir", + type=str, + required=True, + ) + args.add_argument( + "--cache-dir", + type=str, + default=None, + ) + args.add_argument( + "--rpc-timeout-sec", + type=int, + default=180, + ) + args.add_argument("--num-measurement-repeats", type=int, default=5) + args.add_argument("--num-measurements", type=int, default=10) + args.add_argument("--results-file", type=str, required=False, default=None) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + parsed.input_shape = json.loads(parsed.input_shape) + if parsed.target.attrs.get("mtriple", None) == "aarch64-linux-gnu": + parsed.alloc_repeat = 3 + else: + parsed.alloc_repeat = 1 + if parsed.rpc_host and parsed.rpc_port and parsed.rpc_key: + parsed.rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=parsed.rpc_timeout_sec, + ) + parsed.workers = parsed.rpc_config.count_num_servers(allow_missing=False) + else: + # check all rpc configs are None + assert ( + (parsed.rpc_host is None) and (parsed.rpc_port is None) and (parsed.rpc_key is None) + ), "Please set all 'rpc_host', 'rpc_port' and 'rpc_key' to use PRC server" + parsed.rpc_config = None + parsed.workers = 1 + return parsed + + +logging.basicConfig() +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) +ARGS = _parse_args() + + +def apply_opt_before_tuning( + relay_mod: IRModule, params: Dict[str, runtime.NDArray], target: Target +): + with transform.PassContext(opt_level=3): + main_func = relay_mod["main"] + bind_main_func = relay.build_module.bind_params_by_name(main_func, params) + relay_mod = IRModule.from_expr(bind_main_func) + relay_mod = relay.transform.SimplifyInference()(relay_mod) + relay_mod = relay.transform.FoldConstant()(relay_mod) + relay_mod = relay.transform.FoldScaleAxis()(relay_mod) + relay_mod = relay.transform.CanonicalizeOps()(relay_mod) + relay_mod = relay.transform.AlterOpLayout()(relay_mod) + relay_mod = relay.transform.FoldConstant()(relay_mod) + + relax_mod = relay_translator.from_relay(relay_mod["main"], target=target) + relax_mod = relax.transform.AnnotateTIROpPattern()(relax_mod) + relax_mod = relax.transform.FuseOps()(relax_mod) + relax_mod = relax.transform.FuseTIR()(relax_mod) + return relax_mod + + +def f_measurement( + rt_mod: runtime.Module, device: runtime.ndarray.Device, input_data: Dict[str, runtime.NDArray] +): + vm = relax.vm.VirtualMachine(exec=rt_mod, device=device) + vm.save_function("main", "measure_func", **input_data, include_return=False) + evaluator = vm.time_evaluator( + func_name="measure_func", + dev=device, + repeat=ARGS.num_measurement_repeats, + number=ARGS.num_measurements, + min_repeat_ms=500, + ) + return evaluator() + + +def get_runner(): + runner_config = { + "evaluator_config": ms.runner.EvaluatorConfig( + number=3, + repeat=1, + min_repeat_ms=100, + enable_cpu_cache_flush=False, + ), + "alloc_repeat": ARGS.alloc_repeat, + } + if ARGS.rpc_config: + runner = ms.runner.RPCRunner( + rpc_config=ARGS.rpc_config, max_workers=ARGS.workers, **runner_config + ) + else: + runner = ms.runner.LocalRunner(**runner_config) + + return runner + + +def main(): + relay_mod, params, (input_name, input_shape, input_dtype) = get_network( + ARGS.workload, + ARGS.input_shape, + cache_dir=ARGS.cache_dir, + ) + input_info = {input_name: input_shape} + input_data = {} + for input_name, input_shape in input_info.items(): + print(f" input_name: {input_name}") + print(f" input_shape: {input_shape}") + print(f" input_dtype: {input_dtype}") + + # translate the ResNet model from Relay to Relax + relax_mod = apply_opt_before_tuning(relay_mod, params, target=ARGS.target) + assert isinstance(relax_mod, tvm.IRModule) + + db = ms.relax_integration.tune_relax( + mod=relax_mod, + target=ARGS.target, + params=params, + num_trials_per_iter=64, + max_trials_per_task=ARGS.num_trials, + max_trials_global=ARGS.num_trials, + runner=get_runner(), + work_dir=ARGS.work_dir, + ) + executable = ms.relax_integration.compile_relax( + db, + mod=relax_mod, + target=ARGS.target, + params=params, + ) + + for input_name, input_shape in input_info.items(): + if input_dtype.startswith("float"): + input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype) + else: + input_data[input_name] = np.random.randint( + low=0, high=10000, size=input_shape, dtype=input_dtype + ) + + # for documentation purposes + start_time = datetime.datetime.now() + + if ARGS.rpc_config: + result = run_module_via_rpc( + rpc_config=ARGS.rpc_config, + lib=executable.mod, + dev_type=ARGS.target.kind.name, + args=input_data, + continuation=f_measurement, + ) + else: + dev = tvm.device(ARGS.target.kind.name) + result = f_measurement(executable.mod, dev, input_data) + + print(result) + + if not ARGS.results_file: + return + + out_path = os.path.abspath(os.path.expanduser(ARGS.results_file)) + with open(out_path, "w") as out_file: + writer = csv.writer(out_file) + # write experiment parameters at the top as a record + writer.writerow(["start", str(start_time)]) + writer.writerow(["workload", ARGS.workload]) + writer.writerow(["input_shape", ARGS.input_shape]) + writer.writerow(["target", ARGS.target]) + writer.writerow(["num_measurement_repeats", ARGS.num_measurement_repeats]) + for res in result.results: + writer.writerow([str(res)]) + + +if __name__ == "__main__": + main() diff --git a/apps/relax_examples/mlp.py b/apps/relax_examples/mlp.py new file mode 100644 index 000000000000..02e17dc3041a --- /dev/null +++ b/apps/relax_examples/mlp.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Example code on creating, compiling, and running an MLP model in relax + + +import tvm +from tvm import relax, tir, topi +import numpy as np + + +def build_mlp(data, weight): + bb = relax.BlockBuilder() + + with bb.function("mlp", [data, weight]): + gv0 = bb.emit_te(tvm.contrib.cblas.matmul, data, weight, transa=False, transb=False) + gv1 = bb.emit_te(topi.nn.relu, gv0) + bb.emit_func_output(gv1) + + mod = bb.get() + return mod + + +if __name__ == "__main__": + # symbolic dimensions + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + # create data and weight variables + data = relax.Var("data", relax.TensorStructInfo([n, m], "float32")) + weight = relax.Var("weight", relax.TensorStructInfo([m, n], "float32")) + + # construct a mlp model + mod = build_mlp(data, weight) + + # build and create vm executor + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + # run the mlp model on relax vm + data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) + weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) + res = vm["mlp"](data, weight) + print(res) diff --git a/apps/relax_examples/nn_module.py b/apps/relax_examples/nn_module.py new file mode 100644 index 000000000000..b57cb00685ae --- /dev/null +++ b/apps/relax_examples/nn_module.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Example code on creating, compiling, and running a neural network with pytorch-like API + + +import tvm +from tvm.relay import Call +from tvm import relax, tir +from tvm.relax.testing import nn +from tvm.script import relax as R +import numpy as np + + +if __name__ == "__main__": + builder = relax.BlockBuilder() + + # a symbolic variable to represent minibatch size + n = tir.Var("n", "int64") + input_size = 784 + hidden_sizes = [128, 32] + output_size = 10 + + # build a three linear-layer neural network for a classification task + with builder.function("main"): + model = nn.Sequential( + nn.Linear(input_size, hidden_sizes[0]), + nn.ReLU(), + nn.Linear(hidden_sizes[0], hidden_sizes[1]), + nn.ReLU(), + nn.Linear(hidden_sizes[1], output_size), + nn.LogSoftmax(), + ) + data = nn.Placeholder((n, input_size), name="data") + output = model(data) + params = [data] + model.parameters() + builder.emit_func_output(output, params=params) + + # get and print the IRmodule being built + mod = builder.get() + mod.show() + + # build the IRModule and create relax vm + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + # init parameters + params = nn.init_params(mod) + + # run the model on relax vm + # the input data has a minibatch size of 3 + data = tvm.nd.array(np.random.rand(3, input_size).astype(np.float32)) + res = vm["main"](data, *params) + print(res) diff --git a/apps/relax_examples/resnet.py b/apps/relax_examples/resnet.py new file mode 100644 index 000000000000..df0cab02f19c --- /dev/null +++ b/apps/relax_examples/resnet.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Example ResNet workload by translating the Relay program to Relax""" + +import tvm +import tvm.testing +from tvm.relay import testing +from tvm import relax, relay +from tvm.relax.testing import relay_translator, nn +from tvm.runtime import vm as vm_rt +from tvm.script import relax as R +import numpy as np + +if __name__ == "__main__": + relay_mod, _ = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype="float32") + + # translate the ResNet model from Relay to Relax + target = tvm.target.Target("llvm", host="llvm") + relax_mod = relay_translator.from_relay(relay_mod["main"], target) + + # print the ResNet IRmodule got translated + relax_mod.show() + + # build the IRModule and create relax vm + ex = relax.vm.build(relax_mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + # init weights and run the model on relax vm + shape = (1, 3, 224, 224) + data = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + params = nn.init_params(relax_mod) + res = vm["main"](data, *params) + + # check correctness by comparing with relay result + exe = relay.vm.compile(relay_mod, target) + relay_vm = vm_rt.VirtualMachine(exe, tvm.cpu()) + inputs = [data] + params + expected_output = relay_vm.run(*inputs) + tvm.testing.assert_allclose(res.numpy(), expected_output.numpy(), rtol=1e-4, atol=1e-4) diff --git a/python/tvm/relax/testing/__init__.py b/python/tvm/relax/testing/__init__.py index ab1dd6f5155e..7344798f70dc 100644 --- a/python/tvm/relax/testing/__init__.py +++ b/python/tvm/relax/testing/__init__.py @@ -18,3 +18,4 @@ """The Relax testing namespace containing nn and translator.""" from .nn import * +from .relay_translator import * diff --git a/python/tvm/relax/testing/relay_translator.py b/python/tvm/relax/testing/relay_translator.py new file mode 100644 index 000000000000..fd5aab89fa76 --- /dev/null +++ b/python/tvm/relax/testing/relay_translator.py @@ -0,0 +1,251 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument, invalid-name, no-else-return, too-many-nested-blocks +"""Relay to Relax translator.""" + +from typing import Any, Dict, List, Optional + +import tvm +from tvm import relax, relay +from tvm.ir.module import IRModule +from tvm.relax.testing import nn +from tvm.relay.backend.te_compiler import select_implementation +from tvm.runtime import NDArray +from tvm.target import Target +from tvm.meta_schedule.relay_integration import _autotvm_silencer + + +def from_relay( + func: relay.Function, + target: Target, + relay_params: Optional[Dict[str, NDArray]] = None, + *, + opt_level: int = 3, + pass_config: Optional[Dict[str, Any]] = None, + disabled_pass: Optional[List[str]] = None, + translate_op_with_tir: Optional[Dict[str, tvm.tir.PrimFunc]] = None, +) -> IRModule: + """Convert a Relay function into a Relax program. + + Parameters + ---------- + func : relay.Function + Relay function to be converted. + + target: Target + The target to compile the model, used for selecting topi functions. + + relay_params: Optional[Dict[str, NDArray]] + Parameters to bind. + + opt_level: int + The optimization level. + + pass_config: Optional[Dict[str, Any]] + Pass configuration. + + disabled_pass: Optional[List[str]] + Passes to disable. + + translate_op_with_tir: Optional[Dict[str, tvm.tir.PrimFunc]] + Dict that maps op names to user-defined PrimFuncs. + Takes relay operator names and forces them to user-defined PrimFuncs during translation. + + Returns + ------- + mod : tvm.IRModule + The Relax IRModule for compilation + """ + # A map to store the mapping of Relay Expr to its corresponding Relax var + var_map = {} + # The output of the function + output_var = None + + if not isinstance(target, Target): + target = Target(target) + if disabled_pass is None: + disabled_pass = [] + if pass_config is None: + pass_config = { + "relay.FuseOps.max_depth": 1, # Disable relay fusion + "relay.backend.use_meta_schedule": True, + "relay.backend.use_meta_schedule_dispatch": True, + } + + if relay_params: + func = relay.build_module.bind_params_by_name(func, relay_params) + + params = [] + tir_var_map: Dict[tvm.tir.Var, tvm.tir.PrimExpr] = dict() + + def convert_shape(shape: List[tvm.tir.PrimExpr]) -> List[tvm.tir.PrimExpr]: + """Convert the relay shape to relax shape by changing Any dim to symbolic dim""" + ret = [] + for dim in shape: + if isinstance(dim, tvm.tir.IntImm): + ret.append(tvm.tir.IntImm("int64", int(dim))) + elif isinstance(dim, tvm.tir.Any): + ret.append(tvm.tir.Var("d", "int64")) + else: + ret.append(dim) + return ret + + def _copy_undefined_var_in_shape(sinfo: relax.TensorStructInfo): + def _visit_expr(e: tvm.tir.PrimExpr): + if isinstance(e, tvm.tir.Var) and e not in tir_var_map: + new_var = tvm.tir.Var(e.name, e.dtype) + tir_var_map[e] = new_var + + assert isinstance( + sinfo.shape, relax.ShapeExpr + ), "arg with TensorStructInfo in Relay translator must have ShapeExpr shape" + for shape_value in sinfo.shape.values: + tvm.tir.stmt_functor.post_order_visit(shape_value, _visit_expr) + + def visit_func(node): + nonlocal output_var + if isinstance(node, relay.Var): + if isinstance(node.type_annotation, relay.TensorType): + var_map[node] = nn.Placeholder( + tuple(convert_shape(node.type_annotation.shape)), + node.type_annotation.dtype, + node.name_hint, + ) + params.append(var_map[node]) + else: + raise TypeError("The type of relay.Var to be translated must be of TensorType.") + elif isinstance(node, relay.Call): + args = node.args + new_args = [] + te_inputs = [] + for arg in args: + if arg in var_map: + arg_expr = var_map[arg] + if isinstance(arg_expr.struct_info, relax.TensorStructInfo): + _copy_undefined_var_in_shape(arg_expr.struct_info) + new_args.append(arg_expr) + te_inputs.append(tvm.relax.expr.te_tensor(arg_expr, tir_var_map)) + elif isinstance(arg_expr.struct_info, relax.TupleStructInfo): + n_tensor = len(arg_expr.struct_info.fields) + bound_tuple = bb.lookup_binding(arg_expr) + if isinstance(bound_tuple, relax.Tuple): + assert len(bound_tuple) == n_tensor + for i in range(n_tensor): + if isinstance(bound_tuple, relax.Tuple): + item = bb.emit(bound_tuple[i]) + else: + item = bb.emit(relax.TupleGetItem(arg_expr, i)) + + assert isinstance(item.struct_info, relax.TensorStructInfo), ( + "Relay translator doesn't support Call " + "argument being nested Tensor tuple." + ) + _copy_undefined_var_in_shape(item.struct_info) + new_args.append(item) + te_inputs.append(tvm.relax.expr.te_tensor(item, tir_var_map)) + else: + raise TypeError( + f"CallTIR argument type being {type(arg_expr.checked_type)} is not " + "supported." + ) + + op_name = node.op.name + attrs = node.attrs + out_type = node.checked_type + + if translate_op_with_tir and op_name in translate_op_with_tir: + tir_gvar = bb.add_func(translate_op_with_tir[op_name], op_name) + call = relax.call_tir( + tir_gvar, new_args, relax.TensorStructInfo(out_type.shape, out_type.dtype) + ) + var = bb.emit(call) + else: + with target: + best_impl, outputs = select_implementation( + node.op, + attrs, + te_inputs, + out_type, + target, + use_autotvm=False, + ) + compute_func = best_impl.compute + name_hint = op_name.split(".")[-1] + var = bb.emit_te( + compute_func, + attrs, + new_args, + node.checked_type, + primfunc_name_hint=name_hint, + ) + + output_var = var + var_map[node] = var + elif isinstance(node, relay.Constant): + # fill the shape and checked_type fields of the Constant + new_constant = relax.Constant(node.data) + var_map[node] = new_constant + elif isinstance(node, relay.Tuple): + new_fields = [] + for field in node.fields: + if field in var_map: + new_fields.append(var_map[field]) + else: + raise RuntimeError("field is not in var_map.") + new_tuple = relax.Tuple(new_fields) + new_tuple_var = relax.BlockBuilder.current().emit(new_tuple) + var_map[node] = new_tuple_var + output_var = new_tuple_var + elif isinstance(node, relay.TupleGetItem): + if node.tuple_value in var_map: + new_tuple = var_map[node.tuple_value] + new_tuple_get_item_node = relax.TupleGetItem(new_tuple, node.index) + new_tuple_get_item_var = relax.BlockBuilder.current().emit(new_tuple_get_item_node) + var_map[node] = new_tuple_get_item_var + output_var = new_tuple_get_item_var + else: + raise RuntimeError("tuple is not in var_map") + elif isinstance(node, relay.Function): + cur_bb = relax.BlockBuilder.current() + gv = cur_bb.emit_output(output_var) + df_block = cur_bb._end_block() + cur_bb._blocks.append(df_block) + cur_bb.emit_func_output(gv, params) + elif isinstance(node, tvm.ir.Op): + pass + else: + raise TypeError("{} is not supported yet.".format(str(type(node)))) + + # List of subset of relay->relay optimizations + # See src/relay/backend/utils.cc::GetPassPrefix() for full list + seq = tvm.get_global_func("relay.backend.GetPassPrefixSeq")(True, True) + + # Since optimization passes and OpStrategy are highly context-dependent, + # we match the exact same context with `extract_task_from_relay()` env + with _autotvm_silencer(), tvm.transform.PassContext( + opt_level=opt_level, + config=pass_config, + disabled_pass=disabled_pass, + ): + mod = tvm.IRModule.from_expr(func) + mod = seq(mod) + bb = relax.BlockBuilder() + with bb.function("main"): + bb._begin_dataflow_block() + relay.analysis.post_order_visit(mod["main"], visit_func) + + return bb.get() diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py new file mode 100644 index 000000000000..c8ca618d4c1a --- /dev/null +++ b/python/tvm/relax/testing/transform.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument, invalid-name, no-else-return, abstract-method, arguments-differ +"""Relax transformation passes for testing""" + +from tvm import ir +from tvm import relax +from tvm.ir.module import IRModule +from tvm.ir.transform import PassContext +from tvm.target import Target +from tvm.ir import transform +from tvm.relax import PyExprMutator +from tvm.relax.expr import Call +from tvm.relay.backend.te_compiler import select_implementation + + +@ir.transform.module_pass(opt_level=0) +class LowerWithRelayOpStrategyPass(transform.Pass): + """Lower Relax Op into TIR by using Relay OpStrategy. + + Since operators like conv2d, add, matmul are relay-, relax- independent, + this pass assumes we can always find relay op equivalent for such relax ops, + and use Relay Op Strategy (legacy) to perform lowering and find the TOPI implementation. + + Parameters + ---------- + target : Target + target info + + Returns + ------- + pass : transform.Pass + lowering pass + """ + + def __init__(self, target: Target): + self.target = target + + def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule: + """Implement lowering mechanism. + + Parameters + ---------- + mod : IRModule + Input IRModule with Relax ops + + ctx: PassContext + Pass context + + Returns + ------- + out_mod : IRModule + Output IRModule with lowered TIR functions + """ + target = self.target + + @relax.expr_functor.mutator + class Lowerer(PyExprMutator): + """Mutator that performs lowering.""" + + def visit_call_(self, call_node: Call): + # Ignore function calls + # We only target calls for operators + if isinstance(call_node.op, (relax.GlobalVar, relax.expr.ExternFunc)): + return call_node + + # Current relax op name simply adds "relax." prefix to relay op name. + # Thus, remove "relax." prefix to deduce relay op name. + relay_op_name = call_node.op.name[6:] + # Check if equivalent relay op exists. If not, return the original call. + if relay_op_name in ir.Op.list_op_names(): + relay_op = ir.Op.get(relay_op_name) + + # Todo(relax-team): to be revisited - support dyn shape or deprecate. + tir_var_map = dict() + te_inputs = [relax.expr.te_tensor(arg, tir_var_map) for arg in call_node.args] + best_impl_tuple = select_implementation( + relay_op, + call_node.attrs, + te_inputs, + call_node.checked_type, + target, + use_autotvm=False, + ) + compute_func = best_impl_tuple[0].compute + # Extract the name of the operator without the prefix + # e.g., for relay op "nn.conv2d", name_hint would be conv2d + name_hint = relay_op_name.split(".")[-1] + + return self.builder_.call_te( + compute_func, + call_node.attrs, + call_node.args, + call_node.attrs, + primfunc_name_hint=name_hint, + ) + else: + return call_node + + # TOOD(@team): transform() wapper is necessary to include TIR functions. + # IMO, this is bit unintuitive. Can we improve this? + def transform(self): + for gv, func in mod.functions.items(): + if isinstance(func, relax.Function): + updated_func = self.visit_expr(func) + self.builder_.update_func(gv, updated_func) + new_mod = self.builder_.get() + new_mod = new_mod.with_attrs(mod.attrs) if mod.attrs else new_mod + return new_mod + + return Lowerer().transform() diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 4ff8a59b349e..3fb1c89c286e 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -443,6 +443,13 @@ TVM_REGISTER_GLOBAL("relay.backend.tir_converter.allow_extern") return DefaultTIRConverterImpl(args, constants, true); }); +TVM_REGISTER_GLOBAL("relay.backend.GetPassPrefixSeq") + .set_body_typed([](bool is_homogeneous, bool is_vm) { + auto pass_seqs = GetPassPrefix(is_homogeneous, is_vm); + transform::Sequential seq(pass_seqs); + return seq; + }); + } // namespace backend } // namespace relay } // namespace tvm diff --git a/tests/python/relax/test_relay_translator.py b/tests/python/relax/test_relay_translator.py new file mode 100644 index 000000000000..5f7e05b02d3a --- /dev/null +++ b/tests/python/relax/test_relay_translator.py @@ -0,0 +1,300 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tempfile + +import numpy as np +import pytest +import tvm +import tvm.testing +from tvm import meta_schedule as ms +from tvm import relax, relay, tir, topi +from tvm.ir.base import assert_structural_equal +from tvm.relax.testing import relay_translator +from tvm.relay import testing +from tvm.runtime import vm +from tvm.script import tir as T +from tvm.target import Target + + +def get_resnet(batch_size, dtype, layout, image_shape): + relay_mod, params = testing.resnet.get_workload( + num_layers=18, + batch_size=batch_size, + dtype=dtype, + layout=layout, + image_shape=image_shape, + ) + + return relay_mod, params + + +def relay_build_and_run(mod, target, dev, params, data): + with tempfile.TemporaryDirectory() as work_dir: + db = ms.relay_integration.tune_relay( + mod=mod, + params=params, + target=target, + num_trials_per_iter=32, + max_trials_per_task=32, + max_trials_global=1024, + task_scheduler="round-robin", + work_dir=work_dir, + ) + ex = ms.relay_integration.compile_relay( + db, + mod=mod, + target=target, + params=params, + ) + rt_mod = tvm.contrib.graph_executor.GraphModule(ex["default"](dev)) + rt_mod.set_input("data", data) + rt_mod.run() + out = rt_mod.get_output(0).numpy() + return ex, rt_mod, out + + +def relax_build_and_run(mod, target, dev, params, data): + mod = relax.transform.BindParams("main", params)(mod) + with tempfile.TemporaryDirectory() as work_dir: + db = ms.relax_integration.tune_relax( + mod=mod, + target=target, + task_scheduler="round-robin", + num_trials_per_iter=32, + max_trials_per_task=32, + max_trials_global=1024, + work_dir=work_dir, + ) + ex = ms.relax_integration.compile_relax( + db, + mod=mod, + target=target, + params=params, + ) + vm = relax.VirtualMachine(ex, dev) + res = vm["main"](data) + out = res.numpy() + return ex, vm, out + + +def verify_e2e_translation(target_str, layout, batch_size, image_shape): + target = Target(target_str) + dev = tvm.device(str(target), dev_id=0) + relay_mod, params = get_resnet(batch_size, "float32", layout, image_shape) + input_shape = (1, *image_shape) + data = tvm.nd.array(np.random.rand(*input_shape).astype(np.float32), dev) + relax_mod = relay_translator.from_relay(relay_mod["main"], target, params) + assert relax_mod["main"].attrs["global_symbol"] == "main" + + _, _, relay_out = relay_build_and_run(relay_mod, target, dev, params, data) + _, _, relax_out = relax_build_and_run(relax_mod, target, dev, params, data) + tvm.testing.assert_allclose(relay_out, relax_out, atol=1e-5, rtol=1e-5) + + +@pytest.mark.skip(reason="take too much time") +@pytest.mark.parametrize( + "layout, batch_size, image_shape", [("NCHW", 1, (3, 224, 224)), ("NHWC", 1, (224, 224, 3))] +) +def test_verify_e2e_translation_cpu(layout, batch_size, image_shape): + verify_e2e_translation("llvm --num-cores=16", layout, batch_size, image_shape) + + +@pytest.mark.skip(reason="take too much time") +@tvm.testing.requires_gpu +@pytest.mark.parametrize( + "layout, batch_size, image_shape", [("NCHW", 1, (3, 224, 224)), ("NHWC", 1, (224, 224, 3))] +) +def test_verify_e2e_translation_gpu(layout, batch_size, image_shape): + verify_e2e_translation("cuda", layout, batch_size, image_shape) + + +def verify_extracted_tasks(target_str, layout, batch_size, image_shape): + target = Target(target_str) + relay_mod, params = get_resnet(batch_size, "float32", layout, image_shape) + relax_mod = relay_translator.from_relay( + relay_mod["main"], + target, + params, + pass_config={ + "relay.backend.use_meta_schedule": True, + "relay.FuseOps.max_depth": 1, # Disable relay fusion + }, + ) + relay_tasks = ms.relay_integration.extract_tasks( + relay_mod, + target=target, + params=params, + pass_config={ + "relay.backend.use_meta_schedule": True, + "relay.FuseOps.max_depth": 1, # Disable relay fusion + }, + ) + relax_tasks = ms.relax_integration.extract_tasks( + relax_mod, + target=target, + params=params, + ) + # TODO (yongwww, yuchen): tophub guides relay passes, which causes inconsistent tasks + # assert len(relay_tasks) == len(relax_tasks) + # TODO: Can we compare extracted tasks as well? + + +@pytest.mark.parametrize( + "layout, batch_size, image_shape", + [ + ("NCHW", 1, (3, 224, 224)), + ("NHWC", 1, (224, 224, 3)), + ], +) +def test_verify_extracted_tasks_cpu(layout, batch_size, image_shape): + verify_extracted_tasks("llvm --num-cores=16", layout, batch_size, image_shape) + + +@tvm.testing.requires_gpu +@pytest.mark.parametrize( + "layout, batch_size, image_shape", [("NCHW", 1, (3, 224, 224)), ("NHWC", 1, (224, 224, 3))] +) +def test_verify_extracted_tasks_gpu(layout, batch_size, image_shape): + verify_extracted_tasks("cuda", layout, batch_size, image_shape) + + +def translate_and_build_vms(relay_mod, target_str="llvm", translate_op_with_tir=None): + target = tvm.target.Target(target_str) + + # build the relay IRModule and create relay vm + relay_ex = relay.vm.compile(relay_mod, target) + relay_vm = vm.VirtualMachine(relay_ex, tvm.cpu()) + + # build the relax IRModule and create relax vm + relax_mod = relay_translator.from_relay( + relay_mod["main"], target, translate_op_with_tir=translate_op_with_tir + ) + relax_ex = relax.vm.build(relax_mod, target) + relax_vm = relax.VirtualMachine(relax_ex, tvm.cpu()) + + return relay_vm, relax_vm, relax_mod + + +def verify_vm_outputs( + input_shape, + relay_vm, + relax_vm, + extra_args=[], +): + input = tvm.nd.array(np.random.rand(*input_shape).astype(np.float32)) + + # check correctness by comparing relax and relay result + args = [input] + extra_args + relax_output = relax_vm["main"](*args) + relay_output = relay_vm.run(*args) + tvm.testing.assert_allclose(relay_output.numpy(), relax_output.numpy()) + + +def test_single_dynamic_dim(): + wx, wy = 64, 128 + # create relay module: y = data * weights + bias with dynamic batch dimension + data = relay.var("data", shape=(relay.Any(), wx)) + weights = relay.var("weights", shape=(wx, wy)) + bias = relay.var("bias", shape=(wy,)) + y = relay.nn.matmul(data, weights) + relay_mod = tvm.IRModule.from_expr(relay.Function([data, weights, bias], y + bias)) + + relay_vm, relax_vm, _ = translate_and_build_vms(relay_mod) + weights = tvm.nd.array(np.random.rand(wx, wy).astype(np.float32)) + bias = tvm.nd.array(np.random.rand(wy).astype(np.float32)) + # verify for different batch sizes + verify_vm_outputs([10, wx], relay_vm, relax_vm, [weights, bias]) + verify_vm_outputs([32, wx], relay_vm, relax_vm, [weights, bias]) + + +def test_multiple_dynamic_dims(): + # create relay module: y = a + a, where a has shape = (?, 5, ?) + shape = (relay.Any(), 5, relay.Any()) + a = relay.var("a", shape=shape) + + relay_mod = tvm.IRModule.from_expr(relay.Function([a], a + a)) + relay_vm, relax_vm, _ = translate_and_build_vms(relay_mod) + # verify for different shapes + verify_vm_outputs([2, 5, 10], relay_vm, relax_vm) + verify_vm_outputs([12, 5, 24], relay_vm, relax_vm) + + +def test_layout_transform(): + shape = (1, 3, 224, 224) + a = relay.var("a", shape=shape) + b = relay.layout_transform(a, "NCHW", "NHWC") + relay_mod = tvm.IRModule.from_expr(relay.Function([a], b)) + + relay_vm, relax_vm, _ = translate_and_build_vms(relay_mod) + verify_vm_outputs([1, 3, 224, 224], relay_vm, relax_vm) + + +def test_translate_op_with_tir(): + @T.prim_func + def tir_matmul( + A: T.Buffer((512, 512), "float32"), + B: T.Buffer((512, 512), "float32"), + C: T.Buffer((512, 512), "float32"), + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "multiply", "tir.noalias": True}) + # body + # with T.block("root") + for i0, i1, i2 in T.grid(512, 512, 512): + with T.block("C"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(C[i, j], A[i, k], B[k, j]) + T.writes(C[i, j]) + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + shape = (512, 512) + a = relay.var("a", shape=shape) + + relay_mod = tvm.IRModule.from_expr(relay.Function([a], a * a)) + _, _, relax_mod = translate_and_build_vms( + relay_mod, translate_op_with_tir={"multiply": tir_matmul} + ) + assert_structural_equal(relax_mod["multiply"], tir_matmul) + + +def test_translate_tuple_arg(): + x = relay.var("x", shape=(10, 16)) + y = relay.var("y", shape=(10, 16)) + relay_mod = tvm.IRModule.from_expr(relay.Function([x, y], relay.concatenate((x, y), axis=-1))) + relax_mod = relay_translator.from_relay(relay_mod["main"], target="llvm") + + # Construct the expected module + bb = relax.BlockBuilder() + x_relax = relax.Var("x", relax.TensorStructInfo([10, 16], "float32")) + y_relax = relax.Var("y", relax.TensorStructInfo([10, 16], "float32")) + with bb.function("main", [x_relax, y_relax]): + with bb.dataflow(): + _ = bb.emit(relax.Tuple((x_relax, y_relax))) + lv1 = bb.emit(x_relax) + lv2 = bb.emit(y_relax) + lv3 = bb.emit_te(topi.x86.concatenate, (lv1, lv2), axis=-1) + gv = bb.emit_output(lv3) + bb.emit_func_output(gv) + + assert_structural_equal(relax_mod, bb.get()) + + +if __name__ == "__main__": + pytest.main([__file__]) From 47722e3b6bd22c960f1e04250a3ea2b393929b87 Mon Sep 17 00:00:00 2001 From: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Date: Fri, 17 Feb 2023 16:25:37 -0800 Subject: [PATCH 31/81] [Unity][Pass] Normalize Pass (#14031) This PR implements relax `Normalize` Pass, which allows users to transform Relax IR to normal form, i.e., the expressions are normalized (no nesting and hence the AST is in ANF), and all `checked_type_` and `shape_` of expressions are available. (tests are added). Co-Authored-by: Yuchen Jin Co-Authored-by: Ruihang Lai Co-Authored-by: Siyuan Feng Co-Authored-by: Tianqi Chen --- include/tvm/relax/transform.h | 9 + python/tvm/relax/transform/transform.py | 11 + src/relax/transform/normalize.cc | 186 ++++++ .../python/relax/test_transform_normalize.py | 554 ++++++++++++++++++ 4 files changed, 760 insertions(+) create mode 100644 src/relax/transform/normalize.cc create mode 100644 tests/python/relax/test_transform_normalize.py diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index e9f63ee9dbc9..7a4054d41405 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -133,6 +133,15 @@ TVM_DLL Pass BindParams(String func_name, Map params); * \return The Pass. */ TVM_DLL Pass FoldConstant(); + +/*! + * \brief Transform Relax IR to normal form: transform AST to A-normal form, and fill the + * checked_type_ and shape_ of expressions. + * + * \return The Pass. + */ +TVM_DLL Pass Normalize(); + } // namespace transform } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index c0ac180ff165..7fcf0b1121d2 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -68,6 +68,17 @@ def CallTIRRewrite() -> tvm.ir.transform.Pass: return _ffi_api.CallTIRRewrite() # type: ignore +def Normalize() -> tvm.ir.transform.Pass: + """Transforming Relax IR to normal form, i.e., the expressions are normalized(no nesting + and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are available. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.Normalize() # type: ignore + + def RewriteDataflowReshape() -> tvm.ir.transform.Pass: """Convert all reshape-like call_tir to VM reshape operator call. The VM reshape operator calls will be further lowered to a CreateView diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc new file mode 100644 index 000000000000..915498178f0f --- /dev/null +++ b/src/relax/transform/normalize.cc @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/normalize.cc + * \brief Pass for transforming Relax IR to normal form, i.e., the expressions are normalized(no + * nesting and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are + * available. + */ + +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +// TODO(@altanh): LCA binding lifting +class NormalizeMutator : public ExprMutatorBase { + public: + NormalizeMutator() { builder_ = BlockBuilder::Create(NullOpt); } + + Expr VisitExpr(const Expr& expr) override { + return builder_->Normalize(ExprMutatorBase::VisitExpr(expr)); + } + + Expr VisitExpr_(const FunctionNode* op) final { + Expr body = this->VisitWithNewScope(op->body, op->params); + + if (body.same_as(op->body)) { + return GetRef(op); + } else { + return Function(op->params, body, op->ret_struct_info, op->attrs); + } + } + + Expr VisitExpr_(const IfNode* op) final { + Expr guard = this->VisitExpr(op->cond); + Expr true_b = this->VisitWithNewScope(op->true_branch); + Expr false_b = this->VisitWithNewScope(op->false_branch); + if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && + op->false_branch.same_as(false_b)) { + return GetRef(op); + } else { + return If(guard, true_b, false_b, op->span); + } + } + + Expr VisitWithNewScope(const Expr& expr, Optional> params = NullOpt) { + builder_->BeginBindingBlock(); + builder_->BeginScope(params); + Expr ret = this->VisitExpr(expr); + BindingBlock prologue = builder_->EndBlock(); + if (!prologue->bindings.empty()) { + ret = SeqExpr({prologue}, ret); + } + builder_->EndScope(); + return ret; + } + + Expr VisitExpr_(const SeqExprNode* op) final { + bool all_blocks_unchanged = true; + Array blocks; + for (auto block : op->blocks) { + BindingBlock new_block = this->VisitBindingBlock(block); + if (!new_block->bindings.empty()) { + blocks.push_back(new_block); + } + all_blocks_unchanged &= block.same_as(new_block); + } + + builder_->BeginBindingBlock(); + Expr body = this->VisitExpr(op->body); + BindingBlock prologue = builder_->EndBlock(); + if (!prologue->bindings.empty()) { + blocks.push_back(prologue); + all_blocks_unchanged = false; + } + + if (all_blocks_unchanged && body.same_as(op->body)) { + return GetRef(op); + } else { + return SeqExpr(blocks, body); + } + } + + BindingBlock VisitBindingBlock(const BindingBlock& block) final { + BindingBlock ret; + if (const auto* node = block.as()) { + ret = VisitBindingBlock_(node); + } else if (const auto* node = block.as()) { + ret = VisitBindingBlock_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + return ret; + } + + BindingBlock VisitBindingBlock_(const BindingBlockNode* block) { + builder_->BeginBindingBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) { + builder_->BeginDataflowBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); + } + + void VisitBinding(const Binding& binding) { + if (const auto* node = binding.as()) { + VisitBinding_(node); + } else if (const auto* node = binding.as()) { + VisitBinding_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } + } + + void VisitBinding_(const VarBindingNode* binding) { + Expr new_value = this->VisitExpr(binding->value); + if (!binding->var->struct_info_.defined()) { + UpdateStructInfo(binding->var, GetStructInfo(new_value)); + } + + if (new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + } else { + builder_->EmitNormalized(VarBinding(binding->var, new_value)); + } + } + + void VisitBinding_(const MatchCastNode* binding) { + Expr new_value = this->VisitExpr(binding->value); + + if (new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + } else { + builder_->EmitNormalized( + MatchCast(binding->var, builder_->NormalizeArgument(new_value), binding->struct_info)); + } + } + + private: + /*! \brief Internal block builder to emit bindings during rewriting. */ + BlockBuilder builder_; +}; // namespace relax + +Expr Normalize(const Expr& e) { return NormalizeMutator().VisitExpr(e); } + +namespace transform { + +Pass Normalize() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(Normalize(f)); }; + return CreateFunctionPass(pass_func, 1, "Normalize", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.Normalize").set_body_typed(Normalize); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_transform_normalize.py b/tests/python/relax/test_transform_normalize.py new file mode 100644 index 000000000000..9e9533a5ed23 --- /dev/null +++ b/tests/python/relax/test_transform_normalize.py @@ -0,0 +1,554 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +import tvm.testing +from tvm import relax +from tvm import tir +from tvm.ir.base import assert_structural_equal + +import tvm.script +from tvm.script import tir as T, relax as R + + +def test_normalize_function(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = relax.Var("x", R.Tensor([m, n], "float16")) + + # Note: the parser automatically normalize the IR written in TVMScript, + # so we manually construct the function here. + mul_add = relax.Function( + [x], + relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)), + ret_struct_info=R.Tensor("float16", ndim=2), + ) + + # Note: from_expr api names private function (function without global_symbol) as "main" + before_mod = tvm.IRModule.from_expr(mul_add) + + after_mod = relax.transform.Normalize()(before_mod) + + @R.function + def expected(x: R.Tensor(("m", "n"), "float16")) -> R.Tensor(dtype="float16", ndim=2): + gv = R.add(x, x) + gv1 = R.add(x, x) + return R.multiply(gv, gv1) + + assert_structural_equal(after_mod["main"], expected) + + +def test_normalize_if(): + cond = relax.Var("cond", R.Tensor([], "bool")) + x = relax.Var("x", R.Tensor([1], "float32")) + # TODO(relax-team): add type and shape inference for IfNode + y = relax.Var("y") + + # Note: the parser automatically normalize the IR written in TVMScript, + # so we manually construct the function and If here. + f = relax.Function( + [cond, x], + relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding( + y, + relax.If( + cond, + relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)), + relax.op.add(relax.op.multiply(x, x), relax.op.multiply(x, x)), + ), + ) + ] + ) + ], + y, + ), + ret_struct_info=R.Tensor("float32", ndim=1), + ) + + before_mod = tvm.IRModule.from_expr(f) + after_mod = relax.transform.Normalize()(before_mod) + + @R.function + def expected( + cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32") + ) -> R.Tensor(dtype="float32", ndim=1): + if cond: + gv = R.add(x, x) + gv1 = R.add(x, x) + y = R.multiply(gv, gv1) + else: + gv = R.multiply(x, x) + gv1 = R.multiply(x, x) + y = R.add(gv, gv1) + return y + + assert_structural_equal(after_mod["main"], expected) + + +def test_normalize_no_op(): + # the normalize pass should be no-op for IR in ANF + @tvm.script.ir_module + class ANFMod1: + @R.function + def f(x: R.Tensor(dtype="float32")): + gv = R.add(x, x) + gv1 = R.add(gv, gv) + gv2 = R.add(gv, gv1) + return (gv, gv2) + + before_mod = ANFMod1 + after_mod = relax.transform.Normalize()(before_mod) + assert_structural_equal(before_mod, after_mod, map_free_vars=True) + + @tvm.script.ir_module + class ANFMod2: + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + m, n = T.var("int64"), T.var("int64") + with R.dataflow(): + lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32")) + R.output(gv0) + return gv0 + + mod = ANFMod2 + mod_post = relax.transform.Normalize()(mod) + + assert_structural_equal(mod, mod_post) + + +def test_normalize_seq_body(): + # a seq expression with a non-leaf body should bind the body to a var as well + x = relax.Var("x", R.Tensor([], "int32")) + y = relax.Var("y", R.Tensor([], "int32")) + seq = relax.SeqExpr([], relax.op.add(x, y)) + f = relax.Function( + [x, y], + seq, + ret_struct_info=R.Tensor([], "int32"), + ) + + before_mod = tvm.IRModule.from_expr(f) + after_mod = relax.transform.Normalize()(before_mod) + + @R.function + def expected( + x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32") + ) -> R.Tensor(ndim=0, dtype="int32"): + # normalization inserts a binding like this + z = R.add(x, y) + return z + + assert_structural_equal(after_mod["main"], expected) + + +def test_normalize_func_body(): + # a function with a body that is not a seq expr should have it wrapped in a seq expr + x = relax.Var("x", R.Tensor([], "int32")) + y = relax.Var("y", R.Tensor([], "int32")) + f = relax.Function( + [x, y], + relax.op.add(x, y), + ret_struct_info=R.Tensor([], "int32"), + ) + + before_mod = tvm.IRModule.from_expr(f) + after_mod = relax.transform.Normalize()(before_mod) + + @R.function + def expected( + x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32") + ) -> R.Tensor(ndim=0, dtype="int32"): + # result will be a seq expr where the body is a var + z = R.add(x, y) + return z + + assert_structural_equal(after_mod["main"], expected) + + +def test_normalize_if_branches(): + # an if node's branches must be seq exprs + x = relax.Var("x", R.Tensor([], "int32")) + y = relax.Var("y", R.Tensor([], "int32")) + # TODO(@relax-team): z has a shape of () and type of DynTensorType(ndim=0), + # but normalization fails to infer these even though it should + z = relax.Var("z") + cond = relax.Var("cond", R.Tensor([], "bool")) + plus = relax.op.add(x, y) + mult = relax.op.multiply(x, y) + if_node = relax.If(cond, plus, mult) + seq = relax.SeqExpr([relax.BindingBlock([relax.VarBinding(z, if_node)])], z) + f = relax.Function( + [cond, x, y], + seq, + ret_struct_info=R.Tensor([], "int32"), + ) + + before_mod = tvm.IRModule.from_expr(f) + after_mod = relax.transform.Normalize()(before_mod) + + @R.function + def expected( + cond: R.Tensor((), dtype="bool"), + x: R.Tensor((), dtype="int32"), + y: R.Tensor((), dtype="int32"), + ) -> R.Tensor(ndim=0, dtype="int32"): + # the bodies of the branches will be seq exprs with a binding + if cond: + w = R.add(x, y) + z = w + else: + w = R.multiply(x, y) + z = w + return z + + assert_structural_equal(after_mod["main"], expected) + + +def test_normalize_if_condition(): + cond = relax.Var("cond", R.Tensor([], "bool")) + x = relax.Var("x", R.Tensor([1], "float32")) + # TODO(relax-team): add type and shape inference for IfNode + y = relax.Var("y") + + # The condition is wrapped in a tuple and then indexed + f = relax.Function( + [cond, x], + relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding( + y, + relax.If( + relax.TupleGetItem(relax.Tuple([cond]), 0), + relax.op.add(x, x), + relax.op.multiply(x, x), + ), + ) + ] + ) + ], + y, + ), + ret_struct_info=R.Tensor("float32", ndim=1), + ) + + before_mod = tvm.IRModule.from_expr(f) + after_mod = relax.transform.Normalize()(before_mod) + + @R.function + def expected( + cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32") + ) -> R.Tensor(dtype="float32", ndim=1): + c = R.TupleGetItem(R.tuple(cond), 0) + if c: + gv = R.add(x, x) + y = gv + else: + gv = R.multiply(x, x) + y = gv + return y + + assert_structural_equal(after_mod["main"], expected) + + +def test_normalize_tuple_get_item(): + x = relax.Var("x", R.Tensor([], "int32")) + f = relax.Function( + [x], + relax.TupleGetItem( + relax.TupleGetItem( + relax.Tuple([relax.Tuple([x])]), + 0, + ), + 0, + ), + ret_struct_info=R.Tensor([], "int32"), + ) + + before_mod = tvm.IRModule.from_expr(f) + after_mod = relax.transform.Normalize()(before_mod) + + # TODO: Revisit once we canonicalize SeqExprs (part of normalization?) + # Not using the parser this time because writing it out correctly results in + # *one* binding block, whereas the normalized version has *two* + idx_var = relax.Var("idx_var", R.Tuple([R.Tensor([], "int32")])) + ret_var = relax.Var("ret", R.Tensor([], "int32")) + expected_f = relax.Function( + [x], + relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding( + idx_var, relax.TupleGetItem(relax.Tuple([relax.Tuple([x])]), 0) + ) + ] + ), + relax.BindingBlock([relax.VarBinding(ret_var, relax.TupleGetItem(idx_var, 0))]), + ], + ret_var, + ), + ret_struct_info=R.Tensor([], "int32"), + ) + expected_mod = tvm.IRModule.from_expr(expected_f) + # apply normalization to fill in type and shape annotations (tedious otherwise) + final_mod = relax.transform.Normalize()(expected_mod) + + assert_structural_equal(after_mod, final_mod) + + +def test_normalize_combine_nearby_blocks(): + x = relax.Var("x", R.Tensor([], "int32")) + v0 = relax.Var("v0", R.Tensor([], "int32")) + v1 = relax.Var("v1", R.Tensor([], "int32")) + v2 = relax.Var("v2", R.Tensor([], "int32")) + v3 = relax.Var("v3", R.Tensor([], "int32")) + f = relax.Function( + [x], + relax.SeqExpr( + [ + relax.DataflowBlock([relax.VarBinding(v0, x)]), + relax.DataflowBlock([relax.VarBinding(v1, v0)]), + relax.BindingBlock([relax.VarBinding(v2, v1)]), + relax.BindingBlock([relax.VarBinding(v3, v2)]), + ], + v3, + ), + ret_struct_info=R.Tensor([], "int32"), + ) + + after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) + + @R.function + def expected(x: R.Tensor((), "int32")): + with R.dataflow(): + v0 = x + v1 = v0 + R.output(v0, v1) + v2 = v1 + v3 = v2 + return v3 + + assert_structural_equal(after_mod["main"], expected) + + +def test_normalize_nested_seq(): + x = relax.Var("x", R.Tensor([], "int32")) + y = relax.Var("y", R.Tensor([], "int32")) + z = relax.Var("z", R.Tensor([], "int32")) + seq = relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding(x, relax.const(1)), + relax.VarBinding( + y, + relax.SeqExpr( + [relax.BindingBlock([relax.VarBinding(z, relax.const(2))])], + z, + ), + ), + ] + ) + ], + y, + ) + + f = relax.Function( + [], + seq, + ret_struct_info=R.Tensor([], "int32"), + ) + after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) + + @R.function + def expected(): + x = relax.const(1) + z = relax.const(2) + y = z + return y + + assert_structural_equal(after_mod["main"], expected) + + +def test_normalize_nested_seq_dataflow(): + x = relax.Var("x", R.Tensor([], "int32")) + y = relax.Var("y", R.Tensor([], "int32")) + z = relax.Var("z", R.Tensor([], "int32")) + q = relax.Var("u", R.Tensor([], "int32")) + w = relax.DataflowVar("w", R.Tensor([], "int32")) + u = relax.Var("u", R.Tensor([], "int32")) + seq = relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding(x, relax.const(1)), + relax.VarBinding( + y, + relax.SeqExpr( + [ + relax.BindingBlock([relax.VarBinding(q, relax.const(2))]), + relax.DataflowBlock( + [ + relax.VarBinding(w, q), + relax.VarBinding(u, w), + ] + ), + relax.BindingBlock([relax.VarBinding(z, u)]), + ], + z, + ), + ), + ] + ) + ], + y, + ) + + f = relax.Function( + [], + seq, + ret_struct_info=R.Tensor([], "int32"), + ) + after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) + + @R.function + def expected(): + x = relax.const(1) + q = relax.const(2) + with R.dataflow(): + w = q + u = w + R.output(u) + z = u + y = z + return y + + assert_structural_equal(after_mod["main"], expected) + + +def test_normalize_deeply_nested_seq(): + x = relax.Var("x", R.Tensor([], "int32")) + y = relax.Var("y", R.Tensor([], "int32")) + z = relax.Var("z", R.Tensor([], "int32")) + u = relax.Var("u", R.Tensor([], "int32")) + v = relax.Var("v", R.Tensor([], "int32")) + w = relax.Var("w", R.Tensor([], "int32")) + _ = relax.Var("w", R.Tensor([], "int32")) + seq = relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding(x, relax.const(1)), + relax.VarBinding( + y, + relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding( + z, + relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding(u, relax.const(2)), + relax.MatchCast( + _, u, R.Tensor([], "int32") + ), + relax.VarBinding(v, u), + relax.MatchCast( + w, v, R.Tensor([], "int32") + ), + ] + ) + ], + w, + ), + ) + ] + ) + ], + z, + ), + ), + ] + ) + ], + y, + ) + + f = relax.Function( + [], + seq, + ret_struct_info=R.Tensor([], "int32"), + ) + after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) + + @R.function + def expected(): + x = relax.const(1) + u = relax.const(2) + _ = R.match_cast(u, R.Tensor((), "int32")) + v = u + w = R.match_cast(v, R.Tensor((), "int32")) + z = w + y = z + return y + + assert_structural_equal(after_mod["main"], expected) + + +@pytest.mark.xfail() +def test_nesting_non_dataflow_in_dataflow_error(): + x = relax.DataflowVar("x", R.Tensor([], "int32")) + y = relax.Var("y", R.Tensor([], "int32")) + z = relax.Var("z", R.Tensor([], "int32")) + seq = relax.SeqExpr( + [ + relax.DataflowBlock( + [ + relax.VarBinding(x, relax.const(1)), + relax.VarBinding( + y, + relax.SeqExpr( + [relax.BindingBlock([relax.VarBinding(z, relax.const(2))])], + z, + ), + ), + ] + ) + ], + y, + ) + f = relax.Function( + [], + seq, + ret_struct_info=R.Tensor([], "int32"), + ) + relax.transform.Normalize()(tvm.IRModule.from_expr(f)) + # should fail due to a normal binding block being inside a dataflowblock + + +if __name__ == "__main__": + tvm.testing.main() From bccae02c5267fd3e69447e3df961cb38ce424a88 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 17 Feb 2023 19:26:37 -0500 Subject: [PATCH 32/81] [Unity][BlockBuilder] CallTE convert PrimValue args (#14028) Prior to this PR, the `call_te` of BlockBuilder is not capable of converting PrimValue arguments and directly rejects PrimValues instead. This PR fixes this behavior with PrimValue conversion support and one regression test. Co-authored-by: Siyuan Feng --- python/tvm/relax/block_builder.py | 4 +++- tests/python/relax/test_blockbuilder.py | 24 ++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index 77b45fdf5519..783700847909 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -34,7 +34,7 @@ BaseFunc, Binding, ) -from .struct_info import ShapeStructInfo, StructInfo, TensorStructInfo +from .struct_info import PrimStructInfo, ShapeStructInfo, StructInfo, TensorStructInfo from .op.base import call_tir from . import _ffi_api @@ -256,6 +256,8 @@ def _convert_te_arg_helper(arg): arg, ShapeExpr ), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr" return [_convert_te_arg_helper(val) for val in arg.values] + elif isinstance(arg.struct_info, PrimStructInfo): + return arg.value elif isinstance(arg, (list, tvm.ir.Array)): return [_convert_te_arg_helper(x) for x in arg] elif isinstance(arg, tuple): diff --git a/tests/python/relax/test_blockbuilder.py b/tests/python/relax/test_blockbuilder.py index 36a22f9712ea..e54e2b7bf943 100644 --- a/tests/python/relax/test_blockbuilder.py +++ b/tests/python/relax/test_blockbuilder.py @@ -23,6 +23,7 @@ from tvm import relax as rx, relay from tvm.ir.base import assert_structural_equal from tvm.relax import ExternFunc +from tvm.script import relax as R from tvm.tir.function import PrimFunc @@ -462,6 +463,29 @@ def test_emit_te_extern(): assert call_node.sinfo_args[0].shape[1] == n +def test_emit_te_prim_value(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", R.Tensor([n, m], "float32")) + a_min = rx.PrimValue(0) + a_max = rx.PrimValue(6) + + with bb.function("rx_clip", [x]): + out = bb.emit_te(topi.clip, x, a_min, a_max) + bb.emit_func_output(out) + + rx_func = bb.get()["rx_clip"] + + # check Relax function calls TIR function with call_tir call + assert rx_func.params[0] == x + assert len(rx_func.body.blocks) == 1 + call_node = rx_func.body.blocks[0].bindings[0].value + assert isinstance(call_node, rx.Call) + assert call_node.op == relay.op.get("relax.call_tir") + assert len(call_node.args) == 2 + assert call_node.args[1][0] == x + + def test_nested_function_fail(): m = tir.Var("m", "int64") n = tir.Var("n", "int64") From 2d9fcfa595692fbaba5441a5253f7920d1150e3c Mon Sep 17 00:00:00 2001 From: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Date: Fri, 17 Feb 2023 18:11:08 -0800 Subject: [PATCH 33/81] [Unity][Pass] Wellformed Analysis (#14032) This PR implements relax wellformed analysis, which checks if the IRModule is well-formed. (tests and examples are added). Co-Authored-by: Ruihang Lai Co-Authored-by: Siyuan Feng Co-Authored-by: Tianqi Chen Co-authored-by: Steven S. Lyubomirsky Co-authored-by: Yong Wu Co-Authored-by: Yuchen Jin Co-Authored-by: Yixin Dong Co-Authored-by: Chaofan Lin Co-Authored-by: Prakalp Srivastava Co-Authored-by: Junru Shao --- include/tvm/relax/analysis.h | 13 + python/tvm/relax/analysis/analysis.py | 27 + python/tvm/relax/ir/instrument.py | 37 ++ src/relax/analysis/well_formed.cc | 465 ++++++++++++++++++ tests/python/relax/conftest.py | 23 + .../python/relax/test_analysis_well_formed.py | 438 +++++++++++++++++ 6 files changed, 1003 insertions(+) create mode 100644 python/tvm/relax/ir/instrument.py create mode 100644 src/relax/analysis/well_formed.cc create mode 100644 tests/python/relax/conftest.py create mode 100644 tests/python/relax/test_analysis_well_formed.py diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index ff576d4ebb6a..f9896efdf272 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -343,6 +343,19 @@ TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func); */ TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func); +/*! + * \brief Check if the IRModule is well formed. + * + * \param m the IRModule to check. + * \param check_struct_info A boolean flag indicating if the property "every Expr + * must have defined structure info" will be checked. + * \return true if the IRModule is well formed, false if not. + * \note By default the structure info is always checked. It is only in test cases + * where `check_struct_info` might be false, so that other well-formed requirements + * will be well tested and will not be blocked by not having structure info. + */ +TVM_DLL bool WellFormed(IRModule m, bool check_struct_info = true); + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 27416c3a7919..710788347829 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -25,6 +25,7 @@ from enum import IntEnum from tvm import tir +from tvm import IRModule from tvm.relax.ty import Type from tvm.relax.struct_info import StructInfo, FuncStructInfo from tvm.relax.expr import Var, Expr, Call @@ -207,3 +208,29 @@ def has_reshape_pattern(func: tir.PrimFunc) -> bool: of this function. """ return _ffi_api.has_reshape_pattern(func) # type: ignore + + +def well_formed(mod: IRModule, check_struct_info: bool = True) -> bool: + """Check if the IRModule is well formed. + + Parameters + ---------- + mod : tvm.IRModule + The input IRModule. + + check_struct_info : bool + A boolean flag indicating if the property "every Expr must + have defined structure info" will be checked. + + Returns + ------- + ret: bool + True if the IRModule is well formed, False if not. + + Note + ---- + By default the structure info is always checked. It is only in test cases + where `check_struct_info` might be false, so that other well-formed requirements + will be well tested and will not be blocked by not having structure info. + """ + return _ffi_api.well_formed(mod, check_struct_info) # type: ignore diff --git a/python/tvm/relax/ir/instrument.py b/python/tvm/relax/ir/instrument.py new file mode 100644 index 000000000000..fc51a796a7a6 --- /dev/null +++ b/python/tvm/relax/ir/instrument.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Common relax pass instrumentation across IR variants.""" +import tvm +from tvm import relax + + +@tvm.instrument.pass_instrument +class WellFormedInstrument: + """An instrument that checks the input/output IRModule of the Pass + is well formed. It will skip specific passes, like Normalize. + """ + + def __init__(self): + self.skip_pass_name = ["Normalize", "ResolveGlobals"] + + def run_before_pass(self, mod, pass_info): + if pass_info.name not in self.skip_pass_name: + assert relax.analysis.well_formed(mod) + + def run_after_pass(self, mod, pass_info): + if pass_info.name not in self.skip_pass_name: + assert relax.analysis.well_formed(mod) diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc new file mode 100644 index 000000000000..e7ec237fd577 --- /dev/null +++ b/src/relax/analysis/well_formed.cc @@ -0,0 +1,465 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/analysis/well_formed.cc + * \brief Check if the IRModule is well-formed. + * + * This pass is supposed to be applied to normalized Relax AST. + * If it's malformed, messages will be logged as Warning. + * This pass will check: + * 1. Each Expr should have `struct_info_` field already populated, when + * `check_struct_info` is true. + * 2. GlobalVars are defined before use. + * 3. When a Function has a corresponding GlobalVar and a `global_symbol` + * attribute, the name of the GlobalVar must equal the value of the + * `global_symbol` attribute value. + * 4. Any variable cannot used as different function parameters in the same IRModule + * 5. Vars are defined before use. + * 6. Vars are defined exactly once. + * 7. Symbolic Vars are defined before use. + * 8. DataflowVars cannot be defined inside BindingBlock. + * 9. Vars defined in IfNode, except the return Var, are invisible + * out of the If body.(May change for new AST designs) + * 10. SeqExpr only serves as function body, or in the true and + * false branches in IfNode. + * 11. The IR is in ANF: + * (a) Expressions cannot contain nested complex expressions. + * Here are the expressions that may be nested inside other expressions: + * Var, DataflowVar, GlobalVar, Constant, ShapeExpr, + * Op, Tuple (we call these "leaf" expressions). + * (b) The right-hand side of a binding may contain a non-leaf expression + * (where all expressions nested in it are leaf expressions), + * other than SeqExprs (see rule 6) + * (c) Exceptions: The body of a Function node and the true branch + * and false branch of If nodes *must* be SeqExprs. + * (d) Places where non-leaf expressions cannot appear: + * * The tuple_value field of TupleGetItem nodes + * * The cond field of If nodes + * * The op or args fields of Call nodes + * * Inside the fields of Tuple nodes + * 12. Expr always has checked_type_ (with the exception of Op). + */ +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +// TODO(relax-team): Consider further refactor using +// Scope Frame to store manage the var context. +// +/*! \brief Helper to implement well formed check.*/ +class WellFormedChecker : public relax::ExprVisitor, + public relax::StructInfoVisitor, + public tir::ExprVisitor { + public: + static bool Check(IRModule mod, bool check_struct_info) { + WellFormedChecker well_formed_checker = WellFormedChecker(mod, check_struct_info); + + for (const auto& it : mod->functions) { + // visit relax.Function + if (auto* n = it.second.as()) { + Function func = GetRef(n); + well_formed_checker.CheckGlobalVarAndGsymbolConsistency(it.first, func); + well_formed_checker.VisitExpr(func); + } + } + return well_formed_checker.well_formed_; + } + + private: + explicit WellFormedChecker(IRModule mod, bool check_struct_info) + : mod_(std::move(mod)), check_struct_info_(check_struct_info) {} + + using relax::ExprVisitor::VisitExpr_; + using tir::ExprVisitor::VisitExpr; + using tir::ExprVisitor::VisitExpr_; + + // Possible mode of visitor + enum class VisitMode { + /*! + * \brief Check all vars are well-defined + */ + kDefault, + /*! + * \brief Match define the vars on first occurance. + * Do not check the well-defined property of composite expr. + */ + kMatchVarDef + }; + + void Malformed(Diagnostic diag) { + well_formed_ = false; + LOG(WARNING) << "This IR is not well formed: " << diag->message; + } + + void CheckGlobalVarAndGsymbolConsistency(GlobalVar var, Function func) { + // check name in global var and gsymbol + Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); + if (gsymbol.defined() && gsymbol != var->name_hint) { + Malformed(Diagnostic::Error(func->span) + << "Name in GlobalVar is not equal to name in gsymbol: " << var->name_hint + << " != " << gsymbol.value()); + } + } + + void VisitExpr(const Expr& expr) final { + if (!expr.as() && !expr->checked_type_.defined()) { + Malformed(Diagnostic::Error(expr) << "The checked_type_ of Expr " << expr << " is nullptr."); + } + relax::ExprVisitor::VisitExpr(expr); + } + + void VisitExpr_(const GlobalVarNode* op) final { + GlobalVar var = GetRef(op); + if (!(mod_->ContainGlobalVar(var->name_hint) && + mod_->GetGlobalVar(var->name_hint).same_as(var))) { + Malformed(Diagnostic::Error(var) << "GlobalVar " << op->name_hint << " is not defined."); + } + + if (op->checked_type_.defined()) { + if ((!op->checked_type_->IsInstance()) && + (!op->checked_type_->IsInstance())) { + Malformed(Diagnostic::Error(var) << "The checked_type_ of GlobalVar " << op->name_hint + << " must be either FuncType or PackedFuncType."); + } + } + + CheckStructInfo(op); + } + + void VisitExpr_(const TupleNode* op) final { + for (size_t i = 0; i < op->fields.size(); i++) { + Expr expr = op->fields[i]; + if (IsLeafOrTuple(expr)) { + this->VisitExpr(expr); + } else { + Malformed(Diagnostic::Error(expr) + << "Tuple is not in ANF form, field " << i << " gets " << expr->GetTypeKey()); + } + } + + CheckStructInfo(op); + } + + void VisitExpr_(const TupleGetItemNode* op) final { + if (IsLeafOrTuple(op->tuple)) { + this->VisitExpr(op->tuple); + } else { + Malformed(Diagnostic::Error(op) + << "The tuple value in a TupleGetItem node must be a leaf expression."); + } + CheckStructInfo(op); + } + + void VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + if (var_set_.count(var) == 0) { + Malformed(Diagnostic::Error(var) << "Var " << op->name_hint() << " is not defined."); + } + CheckStructInfo(op); + } + + void VisitExpr_(const DataflowVarNode* op) final { + DataflowVar var = GetRef(op); + if (!is_dataflow_) { + Malformed(Diagnostic::Error(var) + << "DataflowVar " << op->name_hint() << " is used outside DataflowBlock."); + } + if (dataflow_var_set_.count(var) == 0) { + Malformed(Diagnostic::Error(var) << "DataflowVar " << op->name_hint() << " is not defined."); + } + CheckStructInfo(op); + } + + void VisitExpr_(const FunctionNode* op) final { + // save the var_set_ for local function + auto prev_var_set = var_set_; + auto prev_dataflow_var_set = dataflow_var_set_; + auto prev_symbolic_var_set = symbolic_var_set_; + bool old_dataflow_state = is_dataflow_; + // symbolic var is not captured across function boundaries + symbolic_var_set_.clear(); + is_dataflow_ = false; + + // first populate defs in params + WithMode(VisitMode::kMatchVarDef, [&]() { + ICHECK(mode_ == VisitMode::kMatchVarDef); + for (Var param : op->params) { + relax::StructInfoVisitor::VisitStructInfo(GetStructInfo(param)); + } + }); + + // check all expr are well defined. + for (Var param : op->params) { + this->VisitVarDef(param); + + if (param_var_func_map_.count(param) == 1) { + // TODO(relax-team): Complete this error info after we integrate printer + Malformed(Diagnostic::Error(param->span) + << "Relax variable " << param->name_hint() + << " is repeatedly used as parameters in function."); + } + param_var_func_map_.insert({param, GetRef(op)}); + } + + if (auto seq = op->body.as()) { + this->VisitSeqExpr(seq); + } else { + Malformed(Diagnostic::Error(op) << "Function bodies must be sequence expressions"); + } + + is_dataflow_ = old_dataflow_state; + dataflow_var_set_ = prev_dataflow_var_set; + var_set_ = prev_var_set; + symbolic_var_set_ = prev_symbolic_var_set; + } + + void VisitExpr_(const CallNode* op) final { + if (IsLeafOrTuple(op->op)) { + this->VisitExpr(op->op); + } else { + Malformed(Diagnostic::Error(op) << "The called expression must be a leaf expression"); + } + for (size_t i = 0; i < op->args.size(); i++) { + Expr arg = op->args[i]; + if (IsLeafOrTuple(arg)) { + this->VisitExpr(arg); + } else { + Malformed(Diagnostic::Error(arg->span) + << "Call is not in ANF form, arg " << i << " gets " << arg->GetTypeKey()); + } + } + + for (const StructInfo& sinfo_arg : op->sinfo_args) { + this->VisitStructInfo(sinfo_arg); + } + + CheckStructInfo(op); + } + + void VisitExpr_(const IfNode* op) final { + if (IsLeafOrTuple(op->cond)) { + this->VisitExpr(op->cond); + } else { + Malformed(Diagnostic::Error(op) << "The condition for an if node must be a leaf expression."); + } + auto true_seq = op->true_branch.as(); + auto false_seq = op->false_branch.as(); + if (true_seq && false_seq) { + std::unordered_set previous_var_set = var_set_; + std::unordered_set previous_symbolic_var_set = + symbolic_var_set_; + this->VisitSeqExpr(true_seq); + var_set_ = previous_var_set; + symbolic_var_set_ = previous_symbolic_var_set; + this->VisitSeqExpr(false_seq); + var_set_ = previous_var_set; + symbolic_var_set_ = previous_symbolic_var_set; + } else { + Malformed(Diagnostic::Error(op) << "If node branches must be seq exprs"); + } + CheckStructInfo(op); + } + + void VisitExpr_(const ShapeExprNode* op) final { + for (PrimExpr expr : op->values) { + // check if the symbolic vars in the expr are defined, e.g, 2 * m + tir::ExprVisitor::VisitExpr(expr); + if (!expr.dtype().is_int()) { + Malformed(Diagnostic::Error(expr) + << "Shape expressions must be of integer type, but got " << expr.dtype()); + } + } + CheckStructInfo(op); + } + + void VisitExpr_(const SeqExprNode* op) final { + Malformed(Diagnostic::Error(op) << "SeqExpr only serves as the function body in FunctionNode, " + "or the true/false branch body in IfNode."); + } + + void VisitSeqExpr(const SeqExprNode* op) { + // a special call only if SeqExpr is the function body + // in FunctionNode or the true/false branch body in IfNode + for (BindingBlock block : op->blocks) { + this->VisitBindingBlock(block); + } + if (!IsLeafOrTuple(op->body)) { + Malformed(Diagnostic::Error(op) << "SeqExpr bodies must be leaf expressions."); + } + this->VisitExpr(op->body); + CheckStructInfo(op); + } + + void VisitBinding_(const VarBindingNode* binding) final { + this->VisitExpr(binding->value); + this->VisitVarDef(binding->var); + } + + void VisitBinding_(const MatchCastNode* binding) final { + this->VisitExpr(binding->value); + // define the vars + WithMode(VisitMode::kMatchVarDef, [&]() { this->VisitStructInfo(binding->struct_info); }); + + this->VisitStructInfo(binding->struct_info); + this->VisitVarDef(binding->var); + } + + void VisitBindingBlock_(const DataflowBlockNode* block) final { + bool old_is_dataflow_ = is_dataflow_; + is_dataflow_ = true; + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + is_dataflow_ = old_is_dataflow_; + dataflow_var_set_.clear(); + } + + void VisitVarDef_(const DataflowVarNode* var) final { + if (!is_dataflow_) { + Malformed(Diagnostic::Error(var) + << "DataflowVar " << var->name_hint() << " is defined outside DataflowBlock."); + } + DataflowVar lv = GetRef(var); + if (dataflow_var_set_.count(lv) == 1) { + Malformed(Diagnostic::Error(var) + << "DataflowVar " << lv->name_hint() << " is defined more than once."); + } + // register DataflowVar + dataflow_var_set_.insert(lv); + CheckStructInfo(var); + } + + void VisitVarDef_(const VarNode* var) final { + Var gv = GetRef(var); + if (var_set_.count(gv) == 1) { + Malformed(Diagnostic::Error(var) + << "Var " << gv->name_hint() << " is defined more than once."); + } + // register Var + var_set_.insert(gv); + CheckStructInfo(var); + } + + void VisitVarDef(const Var& var) final { + if (const DataflowVarNode* lv_node = var.as()) { + VisitVarDef_(lv_node); + } else if (const VarNode* gv_node = var.as()) { + VisitVarDef_(gv_node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + } + } + + void VisitExpr_(const tir::VarNode* op) final { + tir::Var var = GetRef(op); + // default mode, check defined. + if (symbolic_var_set_.count(var) == 0) { + this->Malformed(Diagnostic::Error(var) + << "Symbolic Var " << var->name_hint << " is not defined."); + } + } + + void VisitStructInfoExprField(const Expr& expr) final { + if (mode_ == VisitMode::kMatchVarDef) { + // populate symbolic var in first occurrence + if (auto* op = expr.as()) { + auto var = GetRef(op); + if (var_set_.count(var) == 0) { + var_set_.insert(var); + } + } + if (auto* shape = expr.as()) { + for (auto val : shape->values) { + this->VisitStructInfoExprField(val); + } + } + } else { + relax::ExprVisitor::VisitExpr(expr); + } + } + + void VisitStructInfoExprField(const PrimExpr& expr) final { + if (mode_ == VisitMode::kMatchVarDef) { + // populate symbolic var in first occurrence + if (auto* op = expr.as()) { + auto var = GetRef(op); + if (symbolic_var_set_.count(var) == 0) { + symbolic_var_set_.insert(var); + } + } + } else { + tir::ExprVisitor::VisitExpr(expr); + } + } + + void CheckStructInfo(const ExprNode* op) { + if (!check_struct_info_) { + return; + } + + auto* sinfo = op->struct_info_.as(); + if (sinfo != nullptr) { + this->VisitStructInfo(GetRef(sinfo)); + } else { + Malformed(Diagnostic::Error(op) << "Expr must have struct_info populated. " + << " Expr.type_key=" << op->GetTypeKey()); + } + } + + // Run callback with mode. + template + void WithMode(VisitMode mode, FType callback) { + std::swap(mode_, mode); + callback(); + std::swap(mode_, mode); + } + + IRModule mod_; + const bool check_struct_info_; + bool well_formed_ = true; + bool is_dataflow_; + // Current visit mode. + VisitMode mode_ = VisitMode::kDefault; + // set of context variables. + std::unordered_set var_set_; + std::unordered_set dataflow_var_set_; + std::unordered_set symbolic_var_set_; + std::unordered_map param_var_func_map_; +}; + +bool WellFormed(IRModule m, bool check_struct_info) { + return WellFormedChecker::Check(std::move(m), check_struct_info); +} + +TVM_REGISTER_GLOBAL(("relax.analysis.well_formed")) + .set_body_typed([](IRModule m, bool check_struct_info) { + return WellFormed(m, check_struct_info); + }); + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/conftest.py b/tests/python/relax/conftest.py new file mode 100644 index 000000000000..f1b1187066e6 --- /dev/null +++ b/tests/python/relax/conftest.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations + +import pytest + +import tvm +from tvm.relax.ir.instrument import WellFormedInstrument + + +tvm.transform.PassContext.current().override_instruments([WellFormedInstrument()]) diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py new file mode 100644 index 000000000000..cc0de84d53af --- /dev/null +++ b/tests/python/relax/test_analysis_well_formed.py @@ -0,0 +1,438 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm import relax as rx +from tvm.script import relax as R + +m = tir.Var("m", "int64") +n = tir.Var("n", "int64") +x = rx.Var("x", R.Tensor([m, n], "float32")) +cond = rx.Var("cond", R.Tensor([], "bool")) + + +def build_function(blocks, params=[]): + """Returns relax.function with given blocks""" + seq_expr = rx.SeqExpr(blocks, blocks[-1].bindings[-1].var) + func = rx.Function([x, cond] + params, seq_expr, R.Tensor("float32")).with_attr( + "global_symbol", "foo" + ) + return func + + +def test_var(): + # Error: Var gv0 is not defined + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + gv1 = rx.Var("gv1", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, gv0) + bindings = [rx.VarBinding(gv1, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # Error: Var gv0 is defined more than once + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, x) + call_node2 = rx.op.multiply(x, x) + bindings = [rx.VarBinding(gv0, call_node), rx.VarBinding(gv0, call_node2)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_dataflow_var(): + # Error: DataflowVar lv0 is not defined + lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32")) + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, lv0) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.DataflowBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # Error: DataflowVar gv0 is defined more than once + lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, x) + call_node2 = rx.op.multiply(x, x) + bindings = [rx.VarBinding(lv0, call_node), rx.VarBinding(lv0, call_node2)] + blocks = [rx.DataflowBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # Error: DataflowVar lv0 is defined outside DataflowBlock + lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, x) + bindings = [rx.VarBinding(lv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # Error: DataflowVar lv0 is used outside DataflowBlock + lv0 = rx.DataflowVar("lv0", R.Tensor([m, n], "float32")) + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + call_node = rx.op.add(lv0, x) + bindings = [rx.VarBinding(lv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_param_var(): + v0 = rx.Var("v0", R.Tensor([m, n], "float32")) + v1 = rx.Var("v1", R.Tensor([m, n], "float32")) + v2 = rx.Var("v2", R.Tensor([m, n], "float32")) + bb = rx.BlockBuilder() + with bb.function("func1", [v0, v1]): + gv0 = bb.emit(rx.op.add(v0, v1)) + bb.emit_func_output(gv0) + with bb.function("func2", [v0, v2]): + gv0 = bb.emit(rx.op.add(v2, v1)) + bb.emit_func_output(gv0) + mod = bb.get() + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_global_var(): + # Error: GlobalVar GlobalVar0 is not defined + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + globalvar = rx.GlobalVar("GlobalVar0") + call_node = rx.Call( + op=tvm.ir.Op.get("relax.call_tir"), + args=[globalvar, rx.Tuple([x]), rx.ShapeExpr([m, n])], + ) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_symbolic_var(): + # Error: Symbolic Var new_s is not defined + new_s = tir.Var("new_s", "int64") + gv0 = rx.Var("gv0", R.Tensor([m, new_s], "int64")) + call_node = rx.op.add(x, x) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_symbolic_var_invalid_type(): + with pytest.raises( + tvm.TVMError, match="the value in ShapeStructInfo can only have dtype of int64" + ): + dim = tir.Var("dim", "float32") + y = rx.Var("y", R.Tensor([dim], "float32")) + gv0 = rx.Var("gv0", R.Tensor([dim], "float32")) + call_node = rx.op.add(y, y) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks, [y]) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_seq_expr(): + # Error: SeqExpr in VarBinding + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + # build a SeqExpr + gv1 = rx.Var("gv1", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, gv0) + _bindings = [rx.VarBinding(gv1, call_node)] + _blocks = [rx.BindingBlock(_bindings)] + _seq_expr = rx.SeqExpr(_blocks, gv1) + # build a Binding with the SeqExpr as value + bindings = [rx.VarBinding(gv0, _seq_expr)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_if(): + # Error: Var defined in true/false branch is invisible in the outer scope + # except the return Var, i.e the var in the last stmt + # v_in_if is invisible in the outer scope + v_in_if = rx.Var("v_in_if", R.Tensor([m, n], "float32")) + # gv0 is visible in the outer scope + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + # build true branch + true_bindings = [ + rx.VarBinding(v_in_if, rx.op.add(x, x)), + rx.VarBinding(gv0, rx.op.multiply(x, x)), + ] + true_blocks = [rx.BindingBlock(true_bindings)] + true_seq_expr = rx.SeqExpr(true_blocks, true_blocks[-1].bindings[-1].var) + # build false branch + false_bindings = [ + rx.VarBinding(v_in_if, rx.op.multiply(x, x)), + rx.VarBinding(gv0, rx.op.add(x, x)), + ] + false_blocks = [rx.BindingBlock(false_bindings)] + false_seq_expr = rx.SeqExpr(false_blocks, false_blocks[-1].bindings[-1].var) + # build If node + if_node = rx.If(cond=cond, true_branch=true_seq_expr, false_branch=false_seq_expr) + gv1 = rx.Var("gv1", R.Tensor([m, n], "float32")) + # try to call v_in_if defined in the true/false branch + bindings = [rx.VarBinding(gv0, if_node), rx.VarBinding(gv1, v_in_if)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=True) + + +def test_if_non_seq_body(): + # Error: If node has a body that is not a seq node + if_node = rx.If(cond=cond, true_branch=x, false_branch=x) + blocks = [ + rx.BindingBlock( + [ + rx.VarBinding( + rx.Var("gv1", R.Tensor([m, n], "float32")), + if_node, + ) + ] + ) + ] + func = build_function(blocks) + mod = tvm.IRModule.from_expr(func) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # on the other hand, if they're wrapped in a seq node, it's fine + seq = rx.SeqExpr([], x) + new_if_node = rx.If(cond=cond, true_branch=seq, false_branch=seq) + new_blocks = [ + rx.BindingBlock( + [ + rx.VarBinding( + rx.Var("gv1", R.Tensor([m, n], "float32")), + new_if_node, + ) + ] + ) + ] + new_func = build_function(new_blocks) + new_mod = tvm.IRModule.from_expr(new_func) + # apply normalization to fill in checked_type_ + normalized = rx.transform.Normalize()(new_mod) + assert rx.analysis.well_formed(normalized, check_struct_info=True) + + +def test_if_complex_condition(): + # Error: If condition must be a leaf expression + cond_tuple = rx.Tuple([cond]) + cond_idx = rx.TupleGetItem(cond_tuple, 0) + if_node = rx.If(cond_idx, rx.SeqExpr([], x), rx.SeqExpr([], x)) + blocks = [ + rx.BindingBlock( + [ + rx.VarBinding( + rx.Var("gv1", R.Tensor([m, n], "float32")), + if_node, + ) + ] + ) + ] + func = build_function(blocks) + mod = tvm.IRModule.from_expr(func) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + cond_var = rx.Var("q", R.Tensor([], "bool")) + new_if = rx.If(cond_var, rx.SeqExpr([], x), rx.SeqExpr([], x)) + blocks = [ + rx.BindingBlock( + [ + rx.VarBinding(cond_var, cond_idx), + rx.VarBinding( + rx.Var("gv1", R.Tensor([m, n], "float32")), + new_if, + ), + ] + ) + ] + func = build_function(blocks) + mod = tvm.IRModule.from_expr(func) + # apply normalization to fill in checked_type_ + normalized = rx.transform.Normalize()(mod) + assert rx.analysis.well_formed(normalized, check_struct_info=True) + + +def test_tuple_get_item_nested(): + # Error: The tuple value in tuple get item must be a leaf expression + nested_tup = rx.Var( + "t", rx.TupleStructInfo([rx.TupleStructInfo([rx.TensorStructInfo([], "int32")])]) + ) + double_idx = rx.TupleGetItem(rx.TupleGetItem(nested_tup, 0), 0) + ret_var = rx.Var("r", R.Tensor([], "int32")) + f = rx.Function( + [nested_tup], + rx.SeqExpr([rx.BindingBlock([rx.VarBinding(ret_var, double_idx)])], ret_var), + ret_struct_info=R.Tensor(ndim=0, dtype="int32"), + ) + f = f.with_attr("global_symbol", "f") + mod = tvm.IRModule.from_expr(f) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # okay with an intermediate binding + first_idx = rx.TupleGetItem(nested_tup, 0) + idx_var = rx.Var("v", rx.TupleStructInfo([rx.TensorStructInfo([], "int32")])) + second_idx = rx.TupleGetItem(idx_var, 0) + new_f = rx.Function( + [nested_tup], + rx.SeqExpr( + [ + rx.BindingBlock( + [rx.VarBinding(idx_var, first_idx), rx.VarBinding(ret_var, second_idx)] + ) + ], + ret_var, + ), + ret_struct_info=R.Tensor(ndim=0, dtype="int32"), + ) + new_f = new_f.with_attr("global_symbol", "new_f") + mod = tvm.IRModule.from_expr(new_f) + # normalize in order to fill in checked type + normalized = rx.transform.Normalize()(mod) + assert rx.analysis.well_formed(normalized, check_struct_info=True) + + +def test_complex_seq_body(): + # Error: seq expr with a body that is not a leaf expression is not permitted + x = rx.Var("x", R.Tensor([], "int32")) + y = rx.Var("y", R.Tensor([], "int32")) + func = rx.Function( + [x, y], + rx.SeqExpr([], rx.op.add(x, y)), + R.Tensor(ndim=0, dtype="int32"), + ).with_attr("global_symbol", "foo") + mod = tvm.IRModule.from_expr(func) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # but if the result is bound, then it's okay + z = rx.Var("z", R.Tensor([], "int32")) + new_func = rx.Function( + [x, y], + rx.SeqExpr( + [ + rx.BindingBlock( + [ + rx.VarBinding( + var=z, + value=rx.op.add(x, y), + ) + ] + ) + ], + z, + ), + R.Tensor(ndim=0, dtype="int32"), + ).with_attr("global_symbol", "foo") + new_mod = tvm.IRModule.from_expr(new_func) + # normalize in order to fill in checked type + normalized = rx.transform.Normalize()(new_mod) + assert rx.analysis.well_formed(normalized, check_struct_info=True) + + +def test_ANF(): + # Error: Nested Call + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + call_node = rx.op.add(x, rx.op.add(x, x)) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + # Error: Call Node in Tuple + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + bindings = [rx.VarBinding(gv0, rx.Tuple((x, rx.op.add(x, x))))] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_global_var_vs_gsymbol(): + # Error: gsymbol "main1" not equals to the name in global var "main" + gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) + bindings = [rx.VarBinding(gv0, x)] + blocks = [rx.DataflowBlock(bindings)] + func = rx.Function( + [x], + rx.SeqExpr(blocks, gv0), + R.Tensor(ndim=2, dtype="float32"), + ).with_attr("global_symbol", "main1") + mod = tvm.IRModule({rx.GlobalVar("main"): func}) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_nested_dataflow(): + scalar_struct_info = rx.TensorStructInfo(shape=[], dtype="int32") + gv0 = rx.Var("gv0", scalar_struct_info) + f = rx.DataflowVar("f", rx.FuncStructInfo([], scalar_struct_info)) + x0 = rx.DataflowVar("x0", scalar_struct_info) + x1 = rx.DataflowVar("x1", scalar_struct_info) + x2 = rx.DataflowVar("x2", scalar_struct_info) + y = rx.Var("y", scalar_struct_info) + inner_block = rx.DataflowBlock([rx.VarBinding(x0, rx.const(2, "int32")), rx.VarBinding(y, x0)]) + inner_func = rx.Function([], rx.SeqExpr([inner_block], y), scalar_struct_info) + outer_block = rx.DataflowBlock( + [ + rx.VarBinding(x1, rx.const(1, "int32")), + rx.VarBinding(f, inner_func), + rx.VarBinding(x2, rx.op.add(x1, rx.Call(f, []))), + rx.VarBinding(gv0, x2), + ] + ) + func = rx.Function([], rx.SeqExpr([outer_block], gv0), scalar_struct_info) + mod = tvm.IRModule.from_expr(func) + normalized = rx.transform.Normalize()(mod) + assert rx.analysis.well_formed(normalized) + + +def test_sinfo_args_tir_var_used_before_define_call_packed(): + # Error: Symbolic Var m1, n1 are not defined + m1 = tir.Var("m1", "int64") + n1 = tir.Var("n1", "int64") + call = R.call_packed("my_func", x, sinfo_args=R.Tensor((m1, n1), "float32")) + func = build_function([rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])]) + mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func)) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +def test_sinfo_args_tir_var_used_before_define_call_tir(): + # Error: Symbolic Var m1, n1 are not defined + m1 = tir.Var("m1", "int64") + n1 = tir.Var("n1", "int64") + call = R.call_tir("my_func", x, out_sinfo=R.Tensor((m1, n1), "float32")) + func = build_function([rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])]) + mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func)) + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + +if __name__ == "__main__": + tvm.testing.main() From 8d05dce13d0ebeda6ed32f1dffbeeaa96b8c9a1e Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 17 Feb 2023 23:06:52 -0500 Subject: [PATCH 34/81] [Unity][TVMScript] Move tir/relax import in script out of __init__.py (#14033) Prior to this PR, `python/tvm/script/__init__.py` imports both tir and relax submodules. This leads to the phenomenum that when people does ```python from tvm.script import tir as T ``` , the relax submodule will be implicitly visited by `__init__.py` as well. Since TIR does not rely on Relax, it is good not to import both of them at the same time. (This can prevent cyclic imports sometimes.) This PR does this decoupling by introducing two files * `python/tvm/script/relax.py` * `python/tvm/script/tir.py` and removing the imports from `python/tvm/script/__init__.py` and `python/tvm/script/parser/__init__.py`. With this change, we force people to manually do `from tvm.script import tir` and `from tvm.script import relax` to use TVMScript parser, which is right our conventional way. --- python/tvm/script/__init__.py | 2 -- python/tvm/script/parser/__init__.py | 6 ++---- python/tvm/script/parser/tir/__init__.py | 2 +- python/tvm/script/relax.py | 18 ++++++++++++++++++ python/tvm/script/tir.py | 18 ++++++++++++++++++ 5 files changed, 39 insertions(+), 7 deletions(-) create mode 100644 python/tvm/script/relax.py create mode 100644 python/tvm/script/tir.py diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py index 6d92c68367b3..f5ee692cbb8f 100644 --- a/python/tvm/script/__init__.py +++ b/python/tvm/script/__init__.py @@ -17,5 +17,3 @@ """TVM Script APIs of TVM Python Package""" from .parser import ir, ir_module from .parser import parse as from_source -from .parser import tir -from .parser import relax diff --git a/python/tvm/script/parser/__init__.py b/python/tvm/script/parser/__init__.py index 678297799e6d..ba7f085c08a4 100644 --- a/python/tvm/script/parser/__init__.py +++ b/python/tvm/script/parser/__init__.py @@ -13,10 +13,8 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations -# under the Licens. +# under the License. """The parser""" -from . import _core, ir, tir, relax +from . import _core, ir from ._core import parse from .ir import ir_module -from .tir import prim_func -from .relax import function diff --git a/python/tvm/script/parser/tir/__init__.py b/python/tvm/script/parser/tir/__init__.py index ad16821a89a3..e44b6b521b27 100644 --- a/python/tvm/script/parser/tir/__init__.py +++ b/python/tvm/script/parser/tir/__init__.py @@ -32,4 +32,4 @@ else: from .entry import prim_func -__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"] +__all__ = _tir.__all__ + ["Buffer", "Ptr", "bool", "prim_func"] diff --git a/python/tvm/script/relax.py b/python/tvm/script/relax.py new file mode 100644 index 000000000000..2301463059e3 --- /dev/null +++ b/python/tvm/script/relax.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script APIs of TVM Python Package for Relax""" +from .parser.relax import * # pylint: disable=redefined-builtin,unused-wildcard-import,wildcard-import diff --git a/python/tvm/script/tir.py b/python/tvm/script/tir.py new file mode 100644 index 000000000000..49f3ecd42c50 --- /dev/null +++ b/python/tvm/script/tir.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script APIs of TVM Python Package for TIR""" +from .parser.tir import * # pylint: disable=redefined-builtin,unused-wildcard-import,wildcard-import From 7ccda2571c1b2143856bcf895fe183a1514a31d0 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 17 Feb 2023 23:16:45 -0500 Subject: [PATCH 35/81] [Unity][Pass] Operator legalization (#14029) This PR is the operator legalization pass, which transforms high-level operator calls to `call_tir`s of corresponding low-level TIR PrimFuncs. - The legalization pass provides customizability, which enables people to pass in a customized legalization map to override the default legalization method. - The legalization supports symbolic shape. (At this moment only pooling does not support symbolic shape, as TOPI pooling does not support. This needs to be fixed in followup PRs.) Co-authored-by: Chaofan Lin Co-authored-by: Yixin Dong Co-authored-by: Siyuan Feng --- include/tvm/relax/op_attr_types.h | 9 + include/tvm/relax/transform.h | 23 + python/tvm/relax/__init__.py | 21 +- python/tvm/relax/transform/__init__.py | 3 + .../relax/transform/legalize_ops/__init__.py | 28 + .../relax/transform/legalize_ops/binary.py | 55 + .../relax/transform/legalize_ops/common.py | 104 ++ .../relax/transform/legalize_ops/creation.py | 64 + .../relax/transform/legalize_ops/datatype.py | 31 + .../tvm/relax/transform/legalize_ops/image.py | 39 + .../tvm/relax/transform/legalize_ops/index.py | 63 + .../transform/legalize_ops/linear_algebra.py | 87 ++ .../transform/legalize_ops/manipulate.py | 114 ++ python/tvm/relax/transform/legalize_ops/nn.py | 178 +++ .../relax/transform/legalize_ops/search.py | 21 + .../transform/legalize_ops/statistical.py | 83 ++ .../tvm/relax/transform/legalize_ops/unary.py | 32 + python/tvm/relax/transform/transform.py | 103 ++ src/relax/transform/legalize_ops.cc | 133 ++ .../relax/test_transform_legalize_ops.py | 160 +++ .../test_transform_legalize_ops_binary.py | 1251 +++++++++++++++++ ..._transform_legalize_ops_create_datatype.py | 806 +++++++++++ .../test_transform_legalize_ops_image.py | 103 ++ ...sform_legalize_ops_index_linear_algebra.py | 401 ++++++ .../test_transform_legalize_ops_manipulate.py | 789 +++++++++++ .../relax/test_transform_legalize_ops_nn.py | 1188 ++++++++++++++++ ...ansform_legalize_ops_search_statistical.py | 793 +++++++++++ .../test_transform_legalize_ops_unary.py | 693 +++++++++ 28 files changed, 7365 insertions(+), 10 deletions(-) create mode 100644 python/tvm/relax/transform/legalize_ops/__init__.py create mode 100644 python/tvm/relax/transform/legalize_ops/binary.py create mode 100644 python/tvm/relax/transform/legalize_ops/common.py create mode 100644 python/tvm/relax/transform/legalize_ops/creation.py create mode 100644 python/tvm/relax/transform/legalize_ops/datatype.py create mode 100644 python/tvm/relax/transform/legalize_ops/image.py create mode 100644 python/tvm/relax/transform/legalize_ops/index.py create mode 100644 python/tvm/relax/transform/legalize_ops/linear_algebra.py create mode 100644 python/tvm/relax/transform/legalize_ops/manipulate.py create mode 100644 python/tvm/relax/transform/legalize_ops/nn.py create mode 100644 python/tvm/relax/transform/legalize_ops/search.py create mode 100644 python/tvm/relax/transform/legalize_ops/statistical.py create mode 100644 python/tvm/relax/transform/legalize_ops/unary.py create mode 100644 src/relax/transform/legalize_ops.cc create mode 100644 tests/python/relax/test_transform_legalize_ops.py create mode 100644 tests/python/relax/test_transform_legalize_ops_binary.py create mode 100644 tests/python/relax/test_transform_legalize_ops_create_datatype.py create mode 100644 tests/python/relax/test_transform_legalize_ops_image.py create mode 100644 tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py create mode 100644 tests/python/relax/test_transform_legalize_ops_manipulate.py create mode 100644 tests/python/relax/test_transform_legalize_ops_nn.py create mode 100644 tests/python/relax/test_transform_legalize_ops_search_statistical.py create mode 100644 tests/python/relax/test_transform_legalize_ops_unary.py diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h index e171a8d47b0d..a34cf251dc33 100644 --- a/include/tvm/relax/op_attr_types.h +++ b/include/tvm/relax/op_attr_types.h @@ -49,6 +49,15 @@ using FInferStructInfo = */ using FCallPacked = String; +/*! + * \brief The function type of a legalization function, which takes a + * BlockBuilder and the Call to be legalized, and outputs the legalization + * result Expr. + * \param bb The BlockBuilder context. + * \param call The call to be legalized. + */ +using FLegalize = runtime::TypedPackedFunc; + struct PrintAttrs : public tvm::AttrsNode { std::string format; TVM_DECLARE_ATTRS(PrintAttrs, "relax.attrs.PrintAttrs") { diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 7a4054d41405..8b7c7880b9b6 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -142,6 +142,29 @@ TVM_DLL Pass FoldConstant(); */ TVM_DLL Pass Normalize(); +/*! + * \brief Legalize high-level operator calls in Relax functions to call_tir + * with corresponding low-level TIR PrimFuncs. + * + * For each high-level operator, we register the way of legalizing it as a + * function, which takes a context BlockBuilder and the Call being legalized + * as input, and returns the legalized call. Here the input BlockBuilder is + * mainly used for adding the PrimFunc created by call_te into the context + * IRModule. + * + * The legalization function for each operator is registered as an attribute (with + * attribute key `FLegalize`) of the operator. + * + * For customizability, the user can pass their own legalization by an optional customized map, + * with the key to be the operator name and value to be the legalization function. + * The default legalization function will be overridden by the customized one. + * + * \param cmap The customized operator legalization function map. The customized function + * will override the default one. + * \return The Pass. + */ +TVM_DLL Pass LegalizeOps(Optional> cmap); + } // namespace transform } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index a6306b788e5a..cfcf7876dc9f 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -16,16 +16,6 @@ # under the License. # pylint: disable=invalid-name, wrong-import-position """The Relax IR namespace containing the IR, type, operator, builder, vm, etc.""" -from . import exec_builder -from . import expr -from . import ty -from . import analysis -from . import transform -from . import vm -from . import block_builder -from . import op -from . import struct_info - # Expr from .expr import ( Expr, @@ -82,3 +72,14 @@ TupleStructInfo, FuncStructInfo, ) + +# Import submodules in the last to avoid dependency +from . import exec_builder +from . import expr +from . import ty +from . import analysis +from . import transform +from . import vm +from . import block_builder +from . import op +from . import struct_info diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index eb4d5f710c53..78f450b25ce2 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -18,3 +18,6 @@ """Relax transformations. """ from .transform import * + +# Import to register the legalization functions. +from . import legalize_ops diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py b/python/tvm/relax/transform/legalize_ops/__init__.py new file mode 100644 index 000000000000..3e57b815dbd8 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/__init__.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Legalize high-level operator calls in Relax functions to call_tir.""" +from . import binary +from . import creation +from . import datatype +from . import image +from . import index +from . import linear_algebra +from . import manipulate +from . import nn +from . import search +from . import statistical +from . import unary diff --git a/python/tvm/relax/transform/legalize_ops/binary.py b/python/tvm/relax/transform/legalize_ops/binary.py new file mode 100644 index 000000000000..55b832021a5a --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/binary.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Default legalization function for binary operators.""" +from tvm import topi +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import TEFunc, LegalizeFunc, _try_convert_to_scalar_const, register_legalize + + +def _binary(te_func: TEFunc) -> LegalizeFunc: + """A common wrapper util for the legalization of binary operators. + + It detects if one of the binary op arguments is a constant scalar. It so, + it extracts the scalar value to simplify the generated PrimFunc. + """ + + def binary_call_te(bb: BlockBuilder, call: Call) -> Expr: + # To simplify the created PrimFunc, we first check if arg1 is a constant scalar. + # If it is not, we then check if arg0 is a constant scalar. + arg0 = call.args[0] + arg1 = _try_convert_to_scalar_const(call.args[1]) + if isinstance(arg1, Expr): # type: ignore + arg0 = _try_convert_to_scalar_const(arg0) + return bb.call_te(te_func, arg0, arg1) + + return binary_call_te + + +register_legalize("relax.add", _binary(topi.add)) +register_legalize("relax.divide", _binary(topi.divide)) +register_legalize("relax.floor_divide", _binary(topi.floor_divide)) +register_legalize("relax.multiply", _binary(topi.multiply)) +register_legalize("relax.subtract", _binary(topi.subtract)) +register_legalize("relax.equal", _binary(topi.equal)) + +register_legalize("relax.greater", _binary(topi.greater)) +register_legalize("relax.greater_equal", _binary(topi.greater_equal)) +register_legalize("relax.less", _binary(topi.less)) +register_legalize("relax.less_equal", _binary(topi.less_equal)) +register_legalize("relax.not_equal", _binary(topi.not_equal)) diff --git a/python/tvm/relax/transform/legalize_ops/common.py b/python/tvm/relax/transform/legalize_ops/common.py new file mode 100644 index 000000000000..85d7fba85c5b --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/common.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Common functionality for legalization.""" +from typing import Callable, Optional, Union + +import tvm +from tvm import te +from ...block_builder import BlockBuilder +from ...expr import Call, Expr, Constant + + +##################### Types ##################### + + +# The function type of a TE function, which accepts TE Tensors and +# other attributes, and returns the output TE Tensor. +TEFunc = Callable[..., te.Tensor] + +# The function type of a legalization function, which takes a +# BlockBuilder and the Call to be legalized, and outputs the legalization +# result Expr. +LegalizeFunc = Callable[[BlockBuilder, Call], Expr] + + +##################### Utilities ##################### + + +def _try_convert_to_scalar_const(expr: Expr) -> Union[Expr, bool, float, int]: + """Check if the input Expr is a scalar constant. + If it is, return its plain value. + If it is not, return the input expr. + + Parameters + ---------- + expr : Expr + The expr to be checked and converted. + + Returns + --–---- + ret : Union[Expr, bool, float, int] + Return a Python native value (int/float/bool) if the given + expr is a scalar constant. Or return the input itself + if it is not. + """ + if isinstance(expr, Constant) and expr.struct_info.ndim == 0: + return expr.data.numpy()[()].item() + else: + return expr + + +def _call_topi_without_attr(te_func: TEFunc, primfunc_name: Optional[str] = None) -> LegalizeFunc: + """A common wrapper util for the ops who has no attributes and whose + legalization is simply passing its arguments to some TE function. + + Parameters + ---------- + te_func : TEFunc + The input TE function which is to be converted to PrimFunc. + + primfunc_name : Optional[str] + The name of the generated PrimFunc. + If it is not specified, the name of `te_func` will be used by default. + + Returns + ------- + func : LegalizeFunc + The legalization wrapper function, which wraps the input TE function. + """ + if primfunc_name is None: + primfunc_name = te_func.__name__ + return lambda bb, call: bb.call_te(te_func, *call.args, primfunc_name_hint=primfunc_name) + + +##################### Decorators ##################### + +_LEGALIZE_ATTR_NAME = "FLegalize" + + +def register_legalize(op_name: str, legal_func: LegalizeFunc = None): + """Register legal transformation function for a Relax op. + + Parameters + ---------- + op_name : str + The name of the operator + + legal_func: function (bb: BlockBuilder, call: Call) -> new_expr: Expr + The function for transforming an expr to another expr. + """ + return tvm.ir.register_op_attr(op_name, _LEGALIZE_ATTR_NAME, legal_func) diff --git a/python/tvm/relax/transform/legalize_ops/creation.py b/python/tvm/relax/transform/legalize_ops/creation.py new file mode 100644 index 000000000000..38ce8427b7a6 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/creation.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Default legalization function for creation operators.""" +from typing import Optional + +from tvm import topi, tir +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import LegalizeFunc, register_legalize, _try_convert_to_scalar_const + + +def _full(is_like: bool, fill_value: Optional[float], primfunc_name: str) -> LegalizeFunc: + def full_call_te(bb: BlockBuilder, call: Call) -> Expr: + _fill_value = ( + _try_convert_to_scalar_const(call.args[1]) if fill_value is None else fill_value + ) + + return bb.call_te( + topi.full, + call.args[0].struct_info.shape if is_like else call.args[0], + call.struct_info.dtype, + _fill_value, + primfunc_name_hint=primfunc_name, + ) + + return full_call_te + + +def _tril_triu(is_upper: bool, primfunc_name: str) -> LegalizeFunc: + def tril_triu_call_te(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.trilu, + call.args[0], + tir.const(call.attrs.k, "int32"), + upper=is_upper, + primfunc_name_hint=primfunc_name, + ) + + return tril_triu_call_te + + +register_legalize("relax.full", _full(is_like=False, fill_value=None, primfunc_name="full")) +register_legalize("relax.full_like", _full(is_like=True, fill_value=None, primfunc_name="full")) +register_legalize("relax.ones", _full(is_like=False, fill_value=1.0, primfunc_name="ones")) +register_legalize("relax.ones_like", _full(is_like=True, fill_value=1.0, primfunc_name="ones")) +register_legalize("relax.zeros", _full(is_like=False, fill_value=0.0, primfunc_name="zeros")) +register_legalize("relax.zeros_like", _full(is_like=True, fill_value=0.0, primfunc_name="zeros")) +register_legalize("relax.tril", _tril_triu(is_upper=False, primfunc_name="tril")) +register_legalize("relax.triu", _tril_triu(is_upper=True, primfunc_name="triu")) diff --git a/python/tvm/relax/transform/legalize_ops/datatype.py b/python/tvm/relax/transform/legalize_ops/datatype.py new file mode 100644 index 000000000000..a71e8ca15ee4 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/datatype.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Default legalization function for datatype operators.""" +from tvm import topi, relax +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import _try_convert_to_scalar_const, register_legalize + + +@register_legalize("relax.astype") +def _astype(bb: BlockBuilder, call: Call) -> Expr: + arg = _try_convert_to_scalar_const(call.args[0]) + if isinstance(arg, Expr): # type: ignore + return bb.call_te(topi.cast, arg, call.attrs.dtype) + else: + return relax.const(arg, call.attrs.dtype) diff --git a/python/tvm/relax/transform/legalize_ops/image.py b/python/tvm/relax/transform/legalize_ops/image.py new file mode 100644 index 000000000000..1b2a342b0b53 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/image.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Default legalization function for image operators.""" +from tvm import topi +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import register_legalize + + +@register_legalize("relax.image.resize2d") +def _image_resize2d(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.image.resize2d, + call.args[0], + roi=call.attrs.roi, + size=call.args[1], + layout=call.attrs.layout, + method=call.attrs.method, + coordinate_transformation_mode=call.attrs.coordinate_transformation_mode, + rounding_method=call.attrs.rounding_method, + bicubic_alpha=call.attrs.cubic_alpha, + bicubic_exclude=call.attrs.cubic_exclude, + extrapolation_value=call.attrs.extrapolation_value, + ) diff --git a/python/tvm/relax/transform/legalize_ops/index.py b/python/tvm/relax/transform/legalize_ops/index.py new file mode 100644 index 000000000000..9ee6b2813010 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/index.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Default legalization function for index operators.""" +import logging + +from tvm import topi, tir +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import register_legalize + + +@register_legalize("relax.take") +def _take(bb: BlockBuilder, call: Call) -> Expr: + # Currently Relax `take` operator doesn't provide the mode choices and + # requires input indices to be in range. + # We use fast mode, which leads to runtime error whenever some index is + # out of bound. + return bb.call_te(topi.take, call.args[0], call.args[1], call.attrs.axis, mode="fast") + + +@register_legalize("relax.strided_slice") +def _strided_slice(bb: BlockBuilder, call: Call) -> Expr: + if not all( + [ + isinstance(call.args[0].struct_info.shape.values[i.value], tir.IntImm) + for i in call.attrs.axes + ] + ): + logging.info( + "Cases where an axis with symbolic length is sliced are not able " + "to be legalized through TOPI" + ) + return call + + strides = ( + [tir.IntImm("int64", 1)] * len(call.attrs.axes) + if call.attrs.strides is None + else call.attrs.strides + ) + return bb.call_te( + topi.strided_slice, + call.args[0], + call.attrs.begin, + call.attrs.end, + strides, + call.attrs.axes, + slice_mode="end", + ) diff --git a/python/tvm/relax/transform/legalize_ops/linear_algebra.py b/python/tvm/relax/transform/legalize_ops/linear_algebra.py new file mode 100644 index 000000000000..abe21d9fffee --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/linear_algebra.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Default legalization function for linear algebra operators.""" +from tvm import te, relax, tir +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import register_legalize + + +@register_legalize("relax.matmul") +def _matmul(bb: BlockBuilder, call: Call) -> Expr: + def te_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor: + a_shape = list(a.shape) + b_shape = list(b.shape) + a_prepended = False + b_appended = False + if len(a_shape) == 1: + a_prepended = True + a_shape.insert(0, 1) + if len(b_shape) == 1: + b_appended = True + b_shape.append(1) + + is_a_larger = len(a_shape) > len(b_shape) + offset = len(a_shape) - len(b_shape) if is_a_larger else len(b_shape) - len(a_shape) + + a_relax = relax.Var("a", relax.TensorStructInfo(a.shape)) + b_relax = relax.Var("b", relax.TensorStructInfo(b.shape)) + f_infer_sinfo = call.op.get_attr("FInferStructInfo") + output_shape = f_infer_sinfo(relax.op.matmul(a_relax, b_relax), bb).shape + + def matmul_compute(*idx_spatial): + k = te.reduce_axis((0, a_shape[-1]), name="k") + + def multiply_compute(idx_reduce): + a_indices = [] + b_indices = [] + + for i in range(offset): + if is_a_larger: + a_indices.append(idx_spatial[i]) + else: + b_indices.append(idx_spatial[i]) + for i in range(offset, len(output_shape) - (2 - a_prepended - b_appended)): + a_dim = a_shape[i if is_a_larger else i - offset] + b_dim = b_shape[i if not is_a_larger else i - offset] + a_dim_is_one = isinstance(a_dim, tir.IntImm) and a_dim == 1 + b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim == 1 + a_indices.append(0 if a_dim_is_one else idx_spatial[i]) + b_indices.append(0 if b_dim_is_one else idx_spatial[i]) + if not a_prepended: + a_indices.append(idx_spatial[-2 + b_appended]) + a_indices.append(idx_reduce) + b_indices.append(idx_reduce) + if not b_appended: + b_indices.append(idx_spatial[-1]) + + dtype = call.attrs.out_dtype + if dtype != "": + return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype) + else: + return a(*a_indices) * b(*b_indices) + + return te.sum(multiply_compute(k), axis=k) + + return te.compute( + output_shape, + lambda *idx: matmul_compute(*idx), # pylint: disable=unnecessary-lambda + name="matmul", + ) + + return bb.call_te(te_matmul, call.args[0], call.args[1], primfunc_name_hint="matmul") diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py new file mode 100644 index 000000000000..76e3e74bab9b --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Default legalization function for manipulate operators.""" +import logging + +import tvm +from tvm import topi, tir, relax, te +from ...block_builder import BlockBuilder +from ...expr import Call, Expr, Var, Tuple, TupleGetItem +from .common import TEFunc, LegalizeFunc, register_legalize + + +def _reshape( + te_func: TEFunc, primfunc_name: str, is_collapse_sum_like: bool = False +) -> LegalizeFunc: + def reshape_call_te(bb: BlockBuilder, call: Call): + tgt_shape = call.args[1].struct_info.shape if is_collapse_sum_like else call.args[1] + return bb.call_te(te_func, call.args[0], tgt_shape, primfunc_name_hint=primfunc_name) + + return reshape_call_te + + +register_legalize("relax.broadcast_to", _reshape(topi.broadcast_to, "broadcast_to")) +register_legalize("relax.reshape", _reshape(topi.reshape, "reshape")) + + +@register_legalize("relax.concat") +def _concat(bb: BlockBuilder, call: Call) -> Expr: + t = call.args[0] + n_field = len(t.struct_info.fields) + while isinstance(t, Var): + binding = bb.lookup_binding(t) + if not isinstance(binding, (Tuple, Var)): + break + t = binding + + assert isinstance(t, (Tuple, Var)) + fields = ( + t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)] + ) + return bb.call_te( + topi.concatenate, fields, None if call.attrs.axis is None else call.attrs.axis.value + ) + + +@register_legalize("relax.expand_dims") +def _expand_dims(bb: BlockBuilder, call: Call) -> Expr: + def te_expand_dims(data, axis): + data_relax = relax.Var("data", relax.TensorStructInfo(data.shape)) + f_infer_sinfo = call.op.get_attr("FInferStructInfo") + output_shape = f_infer_sinfo(relax.op.expand_dims(data_relax, axis), bb).shape + output_ndim = len(output_shape) + + data_dims = [] + for i in range(output_ndim): + if i not in axis and (i - output_ndim) not in axis: + data_dims.append(i) + return te.compute( + output_shape, + lambda *idx: data(*[idx[dim] for dim in data_dims]), + name="expand_dims", + ) + + return bb.call_te( + te_expand_dims, call.args[0], call.attrs.axis, primfunc_name_hint="expand_dims" + ) + + +@register_legalize("relax.flatten") +def _flatten(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.reshape, call.args[0], call.struct_info.shape.values) + + +@register_legalize("relax.permute_dims") +def _permute_dims(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.transpose, call.args[0], call.attrs.axes) + + +@register_legalize("relax.split") +def _split(bb: BlockBuilder, call: Call) -> Expr: + if isinstance(call.attrs.indices_or_sections, tir.IntImm): + indices_or_sections = call.attrs.indices_or_sections.value + modulo = tvm.arith.Analyzer().simplify( + call.args[0].struct_info.shape.values[call.attrs.axis] % indices_or_sections + ) + if modulo != 0: + logging.info( + "Split cannot be legalized by TOPI when the axis being split has " + "length that not divisible by the input number of section." + ) + return call + else: + indices_or_sections = call.attrs.indices_or_sections + return bb.call_te(topi.split, call.args[0], indices_or_sections, call.attrs.axis) + + +@register_legalize("relax.squeeze") +def _squeeze(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.squeeze, call.args[0], call.attrs.axis) diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py new file mode 100644 index 000000000000..49f198306d14 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-argument +"""Default legalization function for neural network operators.""" +import logging + +from tvm import topi, tir, te +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import register_legalize, _call_topi_without_attr + + +@register_legalize("relax.nn.conv2d") +def _nn_conv2d(bb: BlockBuilder, call: Call) -> Expr: + if call.attrs.out_layout != call.attrs.data_layout: + logging.info( + "TOPI conv2d does not support different input-output " + "layouts, and thus cannot be legalized by TOPI" + ) + return call + if len(call.attrs.data_layout) != 4 or len(call.attrs.kernel_layout) != 4: + logging.info( + "Conv2D where data layout or kernel layout have channel chunk " + "cannot be legalized by TOPI at this moment." + ) + return call + if call.attrs.groups != 1: + data_layout = tir.layout(call.attrs.data_layout) + kernel_layout = tir.layout(call.attrs.kernel_layout) + ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")] + oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")] + if not isinstance(ic, tir.IntImm) or not isinstance(oc, tir.IntImm): + logging.info( + "Conv2D where number of groups is more than one and input or output " + "channel size is symbolic cannot be legalized by TOPI at this moment." + ) + return call + + return bb.call_te( + topi.nn.conv, + inp=call.args[0], + filt=call.args[1], + stride=call.attrs.strides, + padding=call.attrs.padding, + dilation=call.attrs.dilation, + groups=call.attrs.groups, + data_layout=call.attrs.data_layout, + kernel_layout=call.attrs.kernel_layout, + out_dtype=call.attrs.out_dtype if call.attrs.out_dtype != "" else None, + primfunc_name_hint="conv2d", + ) + + +@register_legalize("relax.nn.max_pool2d") +def _nn_max_pool2d(bb: BlockBuilder, call: Call) -> Expr: + if call.attrs.out_layout != call.attrs.layout: + logging.info( + "TOPI max_pool2d does not support different input-output " + "layouts, and thus cannot be legalized by TOPI" + ) + return call + + return bb.call_te( + topi.nn.pool2d, + call.args[0], + kernel=call.attrs.pool_size, + stride=call.attrs.strides, + dilation=call.attrs.dilation, + padding=call.attrs.padding, + pool_type="max", + ceil_mode=call.attrs.ceil_mode, + layout=call.attrs.layout, + primfunc_name_hint="max_pool2d", + ) + + +@register_legalize("relax.nn.adaptive_avg_pool2d") +def _nn_adaptive_avg_pool2d(bb: BlockBuilder, call: Call) -> Expr: + if call.attrs.out_layout != call.attrs.layout: + logging.info( + "TOPI adaptive_avg_pool2d does not support different input-output " + "layouts, and thus cannot be legalized by TOPI" + ) + return call + + def te_adaptive_avg_pool2d(data, output_size, layout_str): + if output_size is None: + layout = tir.layout(layout_str) + idx_H = layout.index_of("H") + idx_W = layout.index_of("W") + assert idx_H != -1 and idx_W != -1 + output_size = (data.shape[idx_H], data.shape[idx_W]) + + return topi.nn.adaptive_pool(data, output_size, "avg", layout_str) + + return bb.call_te( + te_adaptive_avg_pool2d, + call.args[0], + call.attrs.output_size, + call.attrs.layout, + primfunc_name_hint="adaptive_avg_pool2d", + ) + + +register_legalize("relax.nn.relu", _call_topi_without_attr(topi.nn.relu)) + + +@register_legalize("relax.nn.gelu") +def _nn_gelu(bb: BlockBuilder, call: Call) -> Expr: + def te_gelu(x: te.Tensor): + dtype = x.dtype + return x * ( + tir.const(0.5, dtype) + + topi.erf(x * tir.const(0.5**0.5, dtype)) * tir.const(0.5, dtype) + ) + + return bb.call_te(te_gelu, call.args[0], primfunc_name_hint="gelu") + + +@register_legalize("relax.nn.silu") +def _nn_silu(bb: BlockBuilder, call: Call) -> Expr: + def te_silu(x: te.Tensor): + return topi.multiply(x, topi.sigmoid(x)) + + return bb.call_te(te_silu, call.args[0], primfunc_name_hint="silu") + + +@register_legalize("relax.nn.softmax") +def _nn_softmax(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.nn.softmax, call.args[0], call.attrs.axis) + + +@register_legalize("relax.nn.batch_norm") +def _nn_batch_norm(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.nn.batch_norm, + data=call.args[0], + gamma=call.args[1], + beta=call.args[2], + moving_mean=call.args[3], + moving_var=call.args[4], + axis=call.attrs.axis, + epsilon=call.attrs.epsilon, + center=call.attrs.center, + scale=call.attrs.scale, + ) + + +@register_legalize("relax.nn.layer_norm") +def _nn_layer_norm(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.nn.layer_norm, + call.args[0], + call.args[1], + call.args[2], + axis=call.attrs.axes, + epsilon=call.attrs.epsilon, + ) + + +@register_legalize("relax.nn.dropout") +def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr: + logging.info("Dropout is handled by frontend translator at this moment and is not legalized.") + return call diff --git a/python/tvm/relax/transform/legalize_ops/search.py b/python/tvm/relax/transform/legalize_ops/search.py new file mode 100644 index 000000000000..fb38d099c496 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/search.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Default legalization function for search operators.""" +from tvm import topi +from .common import _call_topi_without_attr, register_legalize + +register_legalize("relax.where", _call_topi_without_attr(topi.where)) diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py b/python/tvm/relax/transform/legalize_ops/statistical.py new file mode 100644 index 000000000000..3307d49f219f --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/statistical.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Default legalization function for statistical operators.""" +from typing import List +from tvm import topi, tir, te +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import TEFunc, LegalizeFunc, register_legalize + + +def _statistical(te_func: TEFunc) -> LegalizeFunc: + def statistical_call_te(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(te_func, call.args[0], call.attrs.axis, call.attrs.keepdims) + + return statistical_call_te + + +def _compute_shape_prod(x: te.Tensor, axis: List[tir.IntImm]) -> tir.PrimExpr: + shape_prod = tir.const(1, "int32") + axes = [_axis.value for _axis in axis] if axis is not None else range(0, len(x.shape)) + for dim in axes: + shape_prod = shape_prod * x.shape[dim] + return shape_prod + + +def _te_mean(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> te.Tensor: + shape_prod = _compute_shape_prod(x, axis) + res_sum = topi.sum(x, axis, keepdims) + return topi.divide(res_sum, shape_prod) + + +def _te_variance(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> te.Tensor: + dev = x - _te_mean(x, axis, keepdims) + return _te_mean(dev * dev, axis, keepdims) + + +@register_legalize("relax.mean") +def _mean(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + _te_mean, call.args[0], call.attrs.axis, call.attrs.keepdims, primfunc_name_hint="mean" + ) + + +@register_legalize("relax.std") +def _std(bb: BlockBuilder, call: Call) -> Expr: + def te_std(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> te.Tensor: + return topi.sqrt(_te_variance(x, axis, keepdims)) + + return bb.call_te( + te_std, call.args[0], call.attrs.axis, call.attrs.keepdims, primfunc_name_hint="std" + ) + + +@register_legalize("relax.variance") +def _variance(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + _te_variance, + call.args[0], + call.attrs.axis, + call.attrs.keepdims, + primfunc_name_hint="variance", + ) + + +register_legalize("relax.max", _statistical(topi.max)) +register_legalize("relax.min", _statistical(topi.min)) +register_legalize("relax.prod", _statistical(topi.prod)) +register_legalize("relax.sum", _statistical(topi.sum)) diff --git a/python/tvm/relax/transform/legalize_ops/unary.py b/python/tvm/relax/transform/legalize_ops/unary.py new file mode 100644 index 000000000000..cd29182c4d93 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/unary.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Default legalization function for unary operators.""" +from tvm import topi +from .common import _call_topi_without_attr, register_legalize + +# To avoid conflict of IRModule function name and libc function name, we add +# "tir_" as the prefix of the generated PrimFunc name. +register_legalize("relax.abs", _call_topi_without_attr(topi.abs, "tir_abs")) +register_legalize("relax.cos", _call_topi_without_attr(topi.cos, "tir_cos")) +register_legalize("relax.log", _call_topi_without_attr(topi.log, "tir_log")) +register_legalize("relax.exp", _call_topi_without_attr(topi.exp, "tir_exp")) +register_legalize("relax.negative", _call_topi_without_attr(topi.negative, "tir_negative")) +register_legalize("relax.sigmoid", _call_topi_without_attr(topi.sigmoid, "tir_sigmoid")) +register_legalize("relax.sin", _call_topi_without_attr(topi.sin, "tir_sin")) +register_legalize("relax.sqrt", _call_topi_without_attr(topi.sqrt, "tir_sqrt")) +register_legalize("relax.tanh", _call_topi_without_attr(topi.tanh, "tir_tanh")) +register_legalize("relax.clip", _call_topi_without_attr(topi.clip, "tir_clip")) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 7fcf0b1121d2..4ba967935b52 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -24,6 +24,7 @@ import tvm.ir from tvm.runtime import NDArray from . import _ffi_api +from .legalize_ops.common import LegalizeFunc @tvm._ffi.register_object("relax.FunctionPass") @@ -229,6 +230,108 @@ def FuseTIR() -> tvm.ir.transform.Pass: return _ffi_api.FuseTIR() # type: ignore +def LegalizeOps(customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None): + """Legalize high-level operator calls in Relax functions to call_tir + with corresponding low-level TIR PrimFuncs. + + For each high-level operator, we register the way of legalizing it as a + function, which takes a context BlockBuilder and the Call being legalized + as input, and returns the legalized call. Here the input BlockBuilder is + mainly used for adding the PrimFunc created by call_te into the context + IRModule. + + The legalization function for each operator is registered as an attribute (with + attribute key `FLegalize`) of the operator. + + This pass provides customizability for users to use their own legalization + function for operators. The pass takes an optional customized map, + with the key to be the operator name (`str`) and value to be the function + (`LegalizeFunc`). The default legalization function will be overridden by the customized + one. + + Parameters + ---------- + customize_legalize_map : Optional[Dict[str, LegalizeFunc]] + The customized operator legalization function map. The customized function will override + the default one. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass + + Examples + -------- + The following code shows how to use this pass: + + .. code-block:: python + + # Define the pass input IRModule + @tvm.script.ir_module + class Module: + @R.function + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + z: R.Tensor((2, 3), "float32") = R.add(x, y) + r: R.Tensor((2, 3), "float32") = R.multiply(y, z) + return r + + # Define the customized legalization function for "relax.add" + def customize_legalize_add(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr: + from tvm import topi + return bb.call_te(topi.add, call.args[1], call.args[0]) + + # Apply the pass with the customized function to the module. + mod = LegalizeOps({"relax.add": customize_legalize_add})(Module) + + Print out the result by `mod.show()`, we can see the IRModule after + legalization becomes + + .. code-block:: python + + @tvm.script.ir_module + class Module: + @R.function + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + z = R.call_tir(add, (y, x), (2, 3), dtype="float32") + r = R.call_tir(multiply, (y, z), (2, 3), dtype="float32") + return r + + @T.prim_func + def add( + A: T.Buffer[(2, 3), "float32"], + B: T.Buffer[(2, 3), "float32"], + T_add: T.Buffer[(2, 3), "float32"], + ): + T.func_attr({"tir.noalias": True}) + for ax0, ax1 in T.grid(2, 3): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) + T.writes(T_add[v_ax0, v_ax1]) + T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1] + + @T.prim_func + def multiply( + A: T.Buffer[(2, 3), "float32"], + B: T.Buffer[(2, 3), "float32"], + T_multiply: T.Buffer[(2, 3), "float32"], + ): + T.func_attr({"tir.noalias": True}) + for ax0, ax1 in T.grid(2, 3): + with T.block("T_multiply"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) + T.writes(T_multiply[v_ax0, v_ax1]) + T_multiply[v_ax0, v_ax1] = A[v_ax0, v_ax1] * B[v_ax0, v_ax1] + """ + + return _ffi_api.LegalizeOps(customize_legalize_map) # type: ignore + + def MetaScheduleApplyDatabase( work_dir: Optional[str] = None, ) -> tvm.ir.transform.Pass: diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc new file mode 100644 index 000000000000..f9a84c536101 --- /dev/null +++ b/src/relax/transform/legalize_ops.cc @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/legalize_ops.cc + * \brief Legalize high-level operator calls in Relax functions to call_tir + * with corresponding low-level TIR PrimFuncs. + */ + +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Check if a given Tensor/Shape/TupleStructInfo contains shapes whose + * values are all known. + * \param sinfo The StructInfo to be checked. + * \return A boolean indicating the given struct info contains shape values that are all known. + */ +bool KnowAllShapeValues(const StructInfo& sinfo) { + if (const auto* tensor_sinfo = sinfo.as()) { + return tensor_sinfo->shape.defined() && + tensor_sinfo->shape.value()->IsInstance(); + } else if (const auto* shape_sinfo = sinfo.as()) { + return shape_sinfo->values.defined(); + } else if (const auto* tuple_sinfo = sinfo.as()) { + return std::all_of(tuple_sinfo->fields.begin(), tuple_sinfo->fields.end(), + [](StructInfo field_sinfo) { return KnowAllShapeValues(field_sinfo); }); + } else if (sinfo.as()) { + return true; + } else { + return false; + } +} + +class LegalizeMutator : public ExprMutator { + public: + explicit LegalizeMutator(const IRModule& mod, const Optional>& cmap) + : ExprMutator(mod), mod_(std::move(mod)), cmap_(std::move(cmap)) {} + + IRModule Transform() { + for (const auto& [gv, func] : mod_->functions) { + if (func->IsInstance()) { + auto updated_func = Downcast(this->VisitExpr(func)); + builder_->UpdateFunction(gv, Downcast(updated_func)); + } + } + return builder_->GetContextIRModule(); + } + + private: + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call) final { + Call visited_call = Downcast(this->VisitExprPostOrder_(call)); + static const auto& legalize_map = Op::GetAttrMap("FLegalize"); + static const Op& call_tir_op = Op::Get("relax.call_tir"); + auto* op_node = visited_call->op.as(); + + // Not an OpNode + if (op_node == nullptr) { + return visited_call; + } + + // Not all shape values are known + if (!std::all_of(visited_call->args.begin(), visited_call->args.end(), + [](Expr arg) { return KnowAllShapeValues(GetStructInfo(arg)); }) || + !KnowAllShapeValues(GetStructInfo(visited_call))) { + return visited_call; + } + + auto op = GetRef(op_node); + + // Priority: customize > default. + // Check if it has customize legalization registered. + if (cmap_.defined() && cmap_.value().count(op->name)) { + return cmap_.value()[op->name](this->builder_, visited_call); + } + // Check if it has default legalization registered. + if (legalize_map.count(op)) { + return legalize_map[op](this->builder_, visited_call); + } + + // No legalization. + if (op != call_tir_op) { + LOG(WARNING) << "No legalization func for " << op->name << " is found."; + } + return visited_call; + } + + /*! \brief The context IRModule. */ + IRModule mod_; + /*! \brief The customized legalization function map. */ + Optional> cmap_; +}; + +namespace transform { + +Pass LegalizeOps(Optional> cmap) { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return LegalizeMutator(mod, cmap).Transform(); }; + return CreateModulePass(/*pass_function=*/pass_func, + /*opt_level=*/0, + /*pass_name=*/"LegalizeOps", + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.LegalizeOps").set_body_typed(LegalizeOps); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_transform_legalize_ops.py b/tests/python/relax/test_transform_legalize_ops.py new file mode 100644 index 000000000000..91f8cb4259cf --- /dev/null +++ b/tests/python/relax/test_transform_legalize_ops.py @@ -0,0 +1,160 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import relax +from tvm.relax.transform import LegalizeOps +from tvm.relax.transform.legalize_ops.common import register_legalize +from tvm.script import relax as R, tir as T +import tvm.testing + + +def test_customize_legalize(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv: R.Tensor((4, 3, 2, 3), "float32") = R.add(x, y) + return gv + + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv = R.call_tir(add, (y, x), R.Tensor((4, 3, 2, 3), dtype="float32")) + return gv + + @T.prim_func + def add(rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_1[ax0, ax1, ax2, T.int64(0)], rxplaceholder[T.int64(0), ax2, ax3]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + rxplaceholder[T.int64(0), ax2, ax3] + # fmt: on + + def customize_legalize_add(bb: relax.BlockBuilder, call: relax.Call): + from tvm import topi # pylint: disable=import-outside-toplevel + + return bb.call_te(topi.add, call.args[1], call.args[0]) + + mod = LegalizeOps({"relax.add": customize_legalize_add})(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_legalize_multiple_types_of_call(): + # fmt: off + @tvm.script.ir_module + class Before: + @R.function + def mul2(x: R.Tensor((3, 3), "float32")): + gv = R.multiply(x, R.const(2.0, "float32")) + return gv + + @T.prim_func + def identity(rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_id: T.Buffer((T.int64(3), T.int64(3)), "float32")): + for ax0, ax1 in T.grid(T.int64(3), T.int64(3)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax0, v_ax1]) + T.writes(T_id[v_ax0, v_ax1]) + T_id[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + + @R.function + def main(x: R.Tensor((3, 3), "float32")): + gv: R.Tensor((3, 3), "float32") = mul2(x) + gv1 = R.call_tir(identity, gv, R.Tensor((3, 3), dtype="float32")) + gv2 = R.multiply(gv1, R.const(2.0, "float32")) + return gv2 + + @tvm.script.ir_module + class Expected: + @R.function + def mul2(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3), dtype="float32"): + gv = R.call_tir(multiply, (x,), R.Tensor((3, 3), dtype="float32")) + return gv + + @T.prim_func + def identity(rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_id: T.Buffer((T.int64(3), T.int64(3)), "float32")): + for ax0, ax1 in T.grid(T.int64(3), T.int64(3)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax0, v_ax1]) + T.writes(T_id[v_ax0, v_ax1]) + T_id[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + + @T.prim_func + def multiply(rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(3), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for ax0, ax1 in T.grid(T.int64(3), T.int64(3)): + with T.block("T_multiply"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax0, v_ax1]) + T.writes(T_multiply[v_ax0, v_ax1]) + T_multiply[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] * T.float32(2) + + @R.function + def main(x1: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3), dtype="float32"): + gv1: R.Tensor((3, 3), dtype="float32") = mul2(x1) + gv11 = R.call_tir(identity, gv1, R.Tensor((3, 3), dtype="float32")) + gv2 = R.call_tir(multiply, (gv11,), R.Tensor((3, 3), dtype="float32")) + return gv2 + # fmt: on + + After = LegalizeOps()(Before) + tvm.ir.assert_structural_equal(After, Expected) + + +def test_can_not_legalize(): + # case 1: does't have legalization + add_legalize = tvm.ir.Op.get("relax.add").get_attr("FLegalize") + # reset it for test + tvm.ir.Op.get("relax.add").reset_attr("FLegalize") + + # fmt: off + @tvm.script.ir_module + class Before0: + @R.function + def main(x: R.Tensor((3, 3), "float32")): + gv: R.Tensor((3, 3), "float32") = R.add(x, x) + return gv + # fmt: on + After0 = LegalizeOps()(Before0) + tvm.ir.assert_structural_equal(After0, Before0) + + register_legalize("relax.add", add_legalize) + + # case 2: don't know all shape + s = relax.Var("s", relax.ShapeStructInfo((3, 3))) + x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32")) + y = relax.Var("y", relax.TensorStructInfo(s, "float32")) + bb = relax.BlockBuilder() + with bb.function("main", [x, y]): + with bb.dataflow(): + gv = bb.emit_output(R.add(x, y)) + bb.emit_func_output(gv) + Before1 = bb.get() + After1 = LegalizeOps()(Before1) + tvm.ir.assert_structural_equal(After1, Before1) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_binary.py b/tests/python/relax/test_transform_legalize_ops_binary.py new file mode 100644 index 000000000000..c2db7e9ba1a1 --- /dev/null +++ b/tests/python/relax/test_transform_legalize_ops_binary.py @@ -0,0 +1,1251 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm.relax.transform import LegalizeOps +from tvm.script import relax as R, tir as T +import tvm.testing + + +##################### Binary arithmetic ##################### + + +def test_add(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv: R.Tensor((4, 3, 2, 3), "float32") = R.add(x, y) + return gv + + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv = R.call_tir(add, (x, y), R.Tensor((4, 3, 2, 3), dtype="float32")) + return gv + + @T.prim_func + def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_add_with_arg0_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), dtype="float32") = R.add(x, R.const(1, "float32")) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(add, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def add(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_add: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_add[ax0, ax1]) + T_add[ax0, ax1] = rxplaceholder[ax0, ax1] + T.float32(1) + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_add_with_arg1_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), dtype="float32") = R.add(R.const(1, "float32"), x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(add, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def add(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_add: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_add[ax0, ax1]) + T_add[ax0, ax1] = T.float32(1) + rxplaceholder[ax0, ax1] + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_add_symbolic(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv: R.Tensor((a, b, c, d), "float32") = R.add(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv = R.call_tir(add, (x, y), R.Tensor((a, b, c, d), dtype="float32")) + return gv + + @T.prim_func + def add(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_add: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_add = T.match_buffer(var_T_add, [a, b, c, d], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_divide(): + # fmt: off + @tvm.script.ir_module + class Divide: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv: R.Tensor((4, 3, 2, 3), "float32") = R.divide(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv = R.call_tir(divide, (x, y), R.Tensor((4, 3, 2, 3), dtype="float32")) + return gv + + @T.prim_func + def divide(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_divide: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_divide"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_divide[ax0, ax1, ax2, ax3]) + T_divide[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] / rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Divide) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_divide_with_arg0_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Divide: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), dtype="float32") = R.divide(x, R.const(1, "float32")) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(divide, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def divide(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_divide: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_divide"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_divide[ax0, ax1]) + T_divide[ax0, ax1] = rxplaceholder[ax0, ax1] / T.float32(1) + # fmt: on + + mod = LegalizeOps()(Divide) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_divide_with_arg1_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Divide: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), dtype="float32") = R.divide(R.const(1, "float32"), x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(divide, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def divide(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_divide: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_divide"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_divide[ax0, ax1]) + T_divide[ax0, ax1] = T.float32(1) / rxplaceholder[ax0, ax1] + # fmt: on + + mod = LegalizeOps()(Divide) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_divide_symbolic(): + # fmt: off + @tvm.script.ir_module + class Divide: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv: R.Tensor((a, b, c, d), "float32") = R.divide(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv = R.call_tir(divide, (x, y), R.Tensor((a, b, c, d), dtype="float32")) + return gv + + @T.prim_func + def divide(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_divide: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_divide = T.match_buffer(var_T_divide, [a, b, c, d], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_divide"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_divide[ax0, ax1, ax2, ax3]) + T_divide[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] / rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Divide) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_floor_divide(): + # fmt: off + @tvm.script.ir_module + class FloorDivide: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv: R.Tensor((4, 3, 2, 3), "float32") = R.floor_divide(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv = R.call_tir(floor_divide, (x, y), R.Tensor((4, 3, 2, 3), dtype="float32")) + return gv + + @T.prim_func + def floor_divide(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_floor_divide: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_floor_divide"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_floor_divide[ax0, ax1, ax2, ax3]) + T_floor_divide[ax0, ax1, ax2, ax3] = T.floor(rxplaceholder[T.int64(0), ax2, ax3] / rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + # fmt: on + + mod = LegalizeOps()(FloorDivide) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_floor_divide_with_arg0_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class FloorDivide: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), dtype="float32") = R.floor_divide(x, R.const(1, "float32")) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(floor_divide, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def floor_divide(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_floor_divide: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_floor_divide"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_floor_divide[ax0, ax1]) + T_floor_divide[ax0, ax1] = T.floor(rxplaceholder[ax0, ax1] / T.float32(1)) + # fmt: on + + mod = LegalizeOps()(FloorDivide) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_floor_divide_with_arg1_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class FloorDivide: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), dtype="float32") = R.floor_divide(R.const(1, "float32"), x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(floor_divide, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def floor_divide(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_floor_divide: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_floor_divide"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_floor_divide[ax0, ax1]) + T_floor_divide[ax0, ax1] = T.floor(T.float32(1) / rxplaceholder[ax0, ax1]) + # fmt: on + + mod = LegalizeOps()(FloorDivide) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_floor_divide_symbolic(): + # fmt: off + @tvm.script.ir_module + class FloorDivide: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv: R.Tensor((a, b, c, d), "float32") = R.floor_divide(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv = R.call_tir(floor_divide, (x, y), R.Tensor((a, b, c, d), dtype="float32")) + return gv + + @T.prim_func + def floor_divide(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_floor_divide: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_floor_divide = T.match_buffer(var_T_floor_divide, [a, b, c, d], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_floor_divide"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_floor_divide[ax0, ax1, ax2, ax3]) + T_floor_divide[ax0, ax1, ax2, ax3] = T.floor(rxplaceholder[T.int64(0), ax2, ax3] / rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + # fmt: on + + mod = LegalizeOps()(FloorDivide) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_multiply(): + # fmt: off + @tvm.script.ir_module + class Multiply: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv: R.Tensor((4, 3, 2, 3), "float32") = R.multiply(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv = R.call_tir(multiply, (x, y), R.Tensor((4, 3, 2, 3), dtype="float32")) + return gv + + @T.prim_func + def multiply(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_multiply: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_multiply[ax0, ax1, ax2, ax3]) + T_multiply[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] * rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Multiply) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_multiply_symbolic(): + # fmt: off + @tvm.script.ir_module + class Multiply: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv: R.Tensor((a, b, c, d), "float32") = R.multiply(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv = R.call_tir(multiply, (x, y), R.Tensor((a, b, c, d), dtype="float32")) + return gv + + @T.prim_func + def multiply(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_multiply: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_multiply = T.match_buffer(var_T_multiply, [a, b, c, d], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_multiply[ax0, ax1, ax2, ax3]) + T_multiply[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] * rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Multiply) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_subtract(): + # fmt: off + @tvm.script.ir_module + class Subtract: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv: R.Tensor((4, 3, 2, 3), "float32") = R.subtract(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"): + gv = R.call_tir(subtract, (x, y), R.Tensor((4, 3, 2, 3), dtype="float32")) + return gv + + @T.prim_func + def subtract(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_subtract: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] - rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Subtract) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_subtract_symbolic(): + # fmt: off + @tvm.script.ir_module + class Subtract: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv: R.Tensor((a, b, c, d), "float32") = R.subtract(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv = R.call_tir(subtract, (x, y), R.Tensor((a, b, c, d), dtype="float32")) + return gv + + @T.prim_func + def subtract(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_subtract: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_subtract = T.match_buffer(var_T_subtract, [a, b, c, d], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] - rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Subtract) + tvm.ir.assert_structural_equal(mod, Expected) + + +##################### Binary comparison ##################### + + +def test_equal(): + # fmt: off + @tvm.script.ir_module + class Equal: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv: R.Tensor((4, 3, 2, 3), "bool") = R.equal(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv = R.call_tir(equal, (x, y), R.Tensor((4, 3, 2, 3), dtype="bool")) + return gv + + @T.prim_func + def equal(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_equal: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_equal"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_equal[ax0, ax1, ax2, ax3]) + T_equal[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] == rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Equal) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_equal_with_arg0_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv: R.Tensor((2, 3), dtype="bool") = R.equal(x, R.const(1, "float32")) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv = R.call_tir(equal, (x,), R.Tensor((2, 3), dtype="bool")) + return gv + + @T.prim_func + def equal(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_equal: T.Buffer((T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_equal"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_equal[ax0, ax1]) + T_equal[ax0, ax1] = rxplaceholder[ax0, ax1] == T.float32(1) + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_equal_with_arg1_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv: R.Tensor((2, 3), dtype="bool") = R.equal(R.const(1, "float32"), x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv = R.call_tir(equal, (x,), R.Tensor((2, 3), dtype="bool")) + return gv + + @T.prim_func + def equal(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_equal: T.Buffer((T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_equal"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_equal[ax0, ax1]) + T_equal[ax0, ax1] = T.float32(1) == rxplaceholder[ax0, ax1] + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_equal_symbolic(): + # fmt: off + @tvm.script.ir_module + class Equal: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv: R.Tensor((a, b, c, d), "bool") = R.equal(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv = R.call_tir(equal, (x, y), R.Tensor((a, b, c, d), dtype="bool")) + return gv + + @T.prim_func + def equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_equal: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_equal = T.match_buffer(var_T_equal, [a, b, c, d], dtype="bool") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_equal"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_equal[ax0, ax1, ax2, ax3]) + T_equal[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] == rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Equal) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_greater(): + # fmt: off + @tvm.script.ir_module + class Greater: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv: R.Tensor((4, 3, 2, 3), "bool") = R.greater(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv = R.call_tir(greater, (x, y), R.Tensor((4, 3, 2, 3), dtype="bool")) + return gv + + @T.prim_func + def greater(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_greater: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_greater"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_1[ax0, ax1, ax2, T.int64(0)], rxplaceholder[T.int64(0), ax2, ax3]) + T.writes(T_greater[ax0, ax1, ax2, ax3]) + T_greater[ax0, ax1, ax2, ax3] = rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] < rxplaceholder[T.int64(0), ax2, ax3] + # fmt: on + + mod = LegalizeOps()(Greater) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_greater_with_arg0_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv: R.Tensor((2, 3), dtype="bool") = R.greater(x, R.const(1, "float32")) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv = R.call_tir(greater, (x,), R.Tensor((2, 3), dtype="bool")) + return gv + + @T.prim_func + def greater(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_greater: T.Buffer((T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_greater"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_greater[ax0, ax1]) + T_greater[ax0, ax1] = T.float32(1) < rxplaceholder[ax0, ax1] + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_greater_with_arg1_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv: R.Tensor((2, 3), dtype="bool") = R.greater(R.const(1, "float32"), x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv = R.call_tir(greater, (x,), R.Tensor((2, 3), dtype="bool")) + return gv + + @T.prim_func + def greater(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_greater: T.Buffer((T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_greater"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_greater[ax0, ax1]) + T_greater[ax0, ax1] = rxplaceholder[ax0, ax1] < T.float32(1) + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_greater_symbolic(): + # fmt: off + @tvm.script.ir_module + class Greater: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv: R.Tensor((a, b, c, d), "bool") = R.greater(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv = R.call_tir(greater, (x, y), R.Tensor((a, b, c, d), dtype="bool")) + return gv + + @T.prim_func + def greater(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_greater: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_greater = T.match_buffer(var_T_greater, [a, b, c, d], dtype="bool") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_greater"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_1[ax0, ax1, ax2, T.int64(0)], rxplaceholder[T.int64(0), ax2, ax3]) + T.writes(T_greater[ax0, ax1, ax2, ax3]) + T_greater[ax0, ax1, ax2, ax3] = rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] < rxplaceholder[T.int64(0), ax2, ax3] + # fmt: on + + mod = LegalizeOps()(Greater) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_greater_equal(): + # fmt: off + @tvm.script.ir_module + class GreaterEqual: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv: R.Tensor((4, 3, 2, 3), "bool") = R.greater_equal(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv = R.call_tir(greater_equal, (x, y), R.Tensor((4, 3, 2, 3), dtype="bool")) + return gv + + @T.prim_func + def greater_equal(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_greater_equal: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_greater_equal"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_1[ax0, ax1, ax2, T.int64(0)], rxplaceholder[T.int64(0), ax2, ax3]) + T.writes(T_greater_equal[ax0, ax1, ax2, ax3]) + T_greater_equal[ax0, ax1, ax2, ax3] = rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] <= rxplaceholder[T.int64(0), ax2, ax3] + # fmt: on + + mod = LegalizeOps()(GreaterEqual) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_greater_equal_symbolic(): + # fmt: off + @tvm.script.ir_module + class GreaterEqual: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv: R.Tensor((a, b, c, d), "bool") = R.greater_equal(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv = R.call_tir(greater_equal, (x, y), R.Tensor((a, b, c, d), dtype="bool")) + return gv + + @T.prim_func + def greater_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_greater_equal: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_greater_equal = T.match_buffer(var_T_greater_equal, [a, b, c, d], dtype="bool") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_greater_equal"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_1[ax0, ax1, ax2, T.int64(0)], rxplaceholder[T.int64(0), ax2, ax3]) + T.writes(T_greater_equal[ax0, ax1, ax2, ax3]) + T_greater_equal[ax0, ax1, ax2, ax3] = rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] <= rxplaceholder[T.int64(0), ax2, ax3] + # fmt: on + + mod = LegalizeOps()(GreaterEqual) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_less(): + # fmt: off + @tvm.script.ir_module + class Less: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv: R.Tensor((4, 3, 2, 3), "bool") = R.less(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv = R.call_tir(less, (x, y), R.Tensor((4, 3, 2, 3), dtype="bool")) + return gv + + @T.prim_func + def less(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_less: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_less"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_less[ax0, ax1, ax2, ax3]) + T_less[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] < rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Less) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_less_symbolic(): + # fmt: off + @tvm.script.ir_module + class Less: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv: R.Tensor((a, b, c, d), "bool") = R.less(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv = R.call_tir(less, (x, y), R.Tensor((a, b, c, d), dtype="bool")) + return gv + + @T.prim_func + def less(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_less: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_less = T.match_buffer(var_T_less, [a, b, c, d], dtype="bool") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_less"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_less[ax0, ax1, ax2, ax3]) + T_less[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] < rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(Less) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_less_equal(): + # fmt: off + @tvm.script.ir_module + class LessEqual: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv: R.Tensor((4, 3, 2, 3), "bool") = R.less_equal(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv = R.call_tir(less_equal, (x, y), R.Tensor((4, 3, 2, 3), dtype="bool")) + return gv + + @T.prim_func + def less_equal(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_less_equal: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_less_equal"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_less_equal[ax0, ax1, ax2, ax3]) + T_less_equal[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] <= rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(LessEqual) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_less_equal_with_arg0_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv: R.Tensor((2, 3), dtype="bool") = R.less_equal(x, R.const(1, "float32")) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv = R.call_tir(less_equal, (x,), R.Tensor((2, 3), dtype="bool")) + return gv + + @T.prim_func + def less_equal(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_less_equal: T.Buffer((T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_less_equal"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_less_equal[ax0, ax1]) + T_less_equal[ax0, ax1] = rxplaceholder[ax0, ax1] <= T.float32(1) + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_less_equal_with_arg1_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Add: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv: R.Tensor((2, 3), dtype="bool") = R.less_equal(R.const(1, "float32"), x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv = R.call_tir(less_equal, (x,), R.Tensor((2, 3), dtype="bool")) + return gv + + @T.prim_func + def less_equal(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_less_equal: T.Buffer((T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_less_equal"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_less_equal[ax0, ax1]) + T_less_equal[ax0, ax1] = T.float32(1) <= rxplaceholder[ax0, ax1] + # fmt: on + + mod = LegalizeOps()(Add) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_less_equal_symbolic(): + # fmt: off + @tvm.script.ir_module + class LessEqual: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv: R.Tensor((a, b, c, d), "bool") = R.less_equal(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv = R.call_tir(less_equal, (x, y), R.Tensor((a, b, c, d), dtype="bool")) + return gv + + @T.prim_func + def less_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_less_equal: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_less_equal = T.match_buffer(var_T_less_equal, [a, b, c, d], dtype="bool") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_less_equal"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_less_equal[ax0, ax1, ax2, ax3]) + T_less_equal[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] <= rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(LessEqual) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_not_equal(): + # fmt: off + @tvm.script.ir_module + class NotEqual: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv: R.Tensor((4, 3, 2, 3), "bool") = R.not_equal(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "bool"): + gv = R.call_tir(not_equal, (x, y), R.Tensor((4, 3, 2, 3), dtype="bool")) + return gv + + @T.prim_func + def not_equal(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_not_equal: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_not_equal"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_not_equal[ax0, ax1, ax2, ax3]) + T_not_equal[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] != rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(NotEqual) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_not_equal_symbolic(): + # fmt: off + @tvm.script.ir_module + class NotEqual: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv: R.Tensor((a, b, c, d), "bool") = R.not_equal(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv = R.call_tir(not_equal, (x, y), R.Tensor((a, b, c, d), dtype="bool")) + return gv + + @T.prim_func + def not_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_not_equal: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") + T_not_equal = T.match_buffer(var_T_not_equal, [a, b, c, d], dtype="bool") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_not_equal"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[T.int64(0), ax2, ax3], rxplaceholder_1[ax0, ax1, ax2, T.int64(0)]) + T.writes(T_not_equal[ax0, ax1, ax2, ax3]) + T_not_equal[ax0, ax1, ax2, ax3] = rxplaceholder[T.int64(0), ax2, ax3] != rxplaceholder_1[ax0, ax1, ax2, T.int64(0)] + # fmt: on + + mod = LegalizeOps()(NotEqual) + tvm.ir.assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_create_datatype.py b/tests/python/relax/test_transform_legalize_ops_create_datatype.py new file mode 100644 index 000000000000..2506e966345f --- /dev/null +++ b/tests/python/relax/test_transform_legalize_ops_create_datatype.py @@ -0,0 +1,806 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm.relax.transform import LegalizeOps +from tvm.script import relax as R, tir as T +import tvm.testing + + +##################### Creation ##################### + + +def test_full(): + # fmt: off + @tvm.script.ir_module + class Full: + @R.function + def main(v: R.Tensor((), "int32")) -> R.Tensor((2, 3), "int32"): + gv: R.Tensor((2, 3), "int32") = R.full((2, 3), v, dtype="int32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(v: R.Tensor((), "int32")) -> R.Tensor((2, 3), "int32"): + gv = R.call_tir(full, (v,), R.Tensor((2, 3), dtype="int32")) + return gv + + @T.prim_func + def full(rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = rxplaceholder[()] + # fmt: on + + mod = LegalizeOps()(Full) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_full_constant_scalar_fill_value(): + # fmt: off + @tvm.script.ir_module + class Full: + @R.function + def main() -> R.Tensor((2, 3), "int32"): + gv: R.Tensor((2, 3), "int32") = R.full((2, 3), R.const(3.5, "float32"), dtype="int32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main() -> R.Tensor((2, 3), "int32"): + gv = R.call_tir(full, R.tuple(), R.Tensor((2, 3), dtype="int32")) + return gv + + @T.prim_func + def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = 3 + # fmt: on + + mod = LegalizeOps()(Full) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_full_different_dtype(): + # fmt: off + @tvm.script.ir_module + class Full: + @R.function + def main(v: R.Tensor((), "int32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.full((2, 3), v, dtype="float32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(v: R.Tensor((), "int32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(full, (v,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def full(rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.Cast("float32", rxplaceholder[()]) + # fmt: on + + mod = LegalizeOps()(Full) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_full_symbolic(): + # fmt: off + @tvm.script.ir_module + class Full: + @R.function + def main(dumb_param: R.Tensor(("m", "n")), v: R.Tensor((), "int32")) -> R.Tensor(("m", "n"), "int32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "int32") = R.full((m, n), v, dtype="int32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(dumb_param: R.Tensor(("m", "n")), v: R.Tensor((), "int32")) -> R.Tensor(("m", "n"), "int32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(full, (v,), R.Tensor((m, n), dtype="int32")) + return gv + + @T.prim_func + def full(rxplaceholder: T.Buffer((), "int32"), var_T_full: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + T_full = T.match_buffer(var_T_full, [m, n], dtype="int32") + for i0, i1 in T.grid(m, n): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = rxplaceholder[()] + # fmt: on + + mod = LegalizeOps()(Full) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_full_like(): + # fmt: off + @tvm.script.ir_module + class FullLike: + @R.function + def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.full_like(x, v) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(full, (v,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = rxplaceholder[()] + # fmt: on + + mod = LegalizeOps()(FullLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_full_like_constant_scalar_fill_value(): + # fmt: off + @tvm.script.ir_module + class FullLike: + @R.function + def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.full_like(x, R.const(-5, "float32")) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(full, R.tuple(), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.float32(-5) + # fmt: on + + mod = LegalizeOps()(FullLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_full_like_different_dtype(): + # fmt: off + @tvm.script.ir_module + class FullLike: + @R.function + def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float64"): + gv: R.Tensor((2, 3), "float64") = R.full_like(x, v, dtype="float64") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float64"): + gv = R.call_tir(full, (v,), R.Tensor((2, 3), dtype="float64")) + return gv + + @T.prim_func + def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "float64")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.Cast("float64", rxplaceholder[()]) + # fmt: on + + mod = LegalizeOps()(FullLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_full_like_symbolic(): + # fmt: off + @tvm.script.ir_module + class FullLike: + @R.function + def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "float32") = R.full_like(x, v) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(full, (v,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def full(rxplaceholder: T.Buffer((), "float32"), var_T_full: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = rxplaceholder[()] + # fmt: on + + mod = LegalizeOps()(FullLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_ones(): + # fmt: off + @tvm.script.ir_module + class Ones: + @R.function + def main() -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.ones((2, 3), "float32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main() -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(ones, R.tuple(), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def ones(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.float32(1) + # fmt: on + + mod = LegalizeOps()(Ones) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_ones_symbolic(): + # fmt: off + @tvm.script.ir_module + class Ones: + @R.function + def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "float32") = R.ones((m, n), "float32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(ones, R.tuple(), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def ones(var_T_full: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.float32(1) + # fmt: on + + mod = LegalizeOps()(Ones) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_ones_like(): + # fmt: off + @tvm.script.ir_module + class OnesLike: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "int32"): + gv: R.Tensor((2, 3), "int32") = R.ones_like(x, "int32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "int32"): + gv = R.call_tir(ones, R.tuple(), R.Tensor((2, 3), dtype="int32")) + return gv + + @T.prim_func + def ones(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = 1 + # fmt: on + + mod = LegalizeOps()(OnesLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_ones_like_symbolic(): + # fmt: off + @tvm.script.ir_module + class OnesLike: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "float32") = R.ones_like(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(ones, R.tuple(), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def ones(var_T_full: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.float32(1) + # fmt: on + + mod = LegalizeOps()(OnesLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_zeros(): + # fmt: off + @tvm.script.ir_module + class Zeros: + @R.function + def main() -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.zeros((2, 3), "float32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main() -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(zeros, R.tuple(), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def zeros(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.float32(0) + # fmt: on + + mod = LegalizeOps()(Zeros) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_zeros_symbolic(): + # fmt: off + @tvm.script.ir_module + class Zeros: + @R.function + def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "float32") = R.zeros((m, n), "float32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(zeros, R.tuple(), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def zeros(var_T_full: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.float32(0) + # fmt: on + + mod = LegalizeOps()(Zeros) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_zeros_like(): + # fmt: off + @tvm.script.ir_module + class ZerosLike: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "int32"): + gv: R.Tensor((2, 3), "int32") = R.zeros_like(x, "int32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "int32"): + gv = R.call_tir(zeros, R.tuple(), R.Tensor((2, 3), dtype="int32")) + return gv + + @T.prim_func + def zeros(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = 0 + # fmt: on + + mod = LegalizeOps()(ZerosLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_zeros_like_symbolic(): + # fmt: off + @tvm.script.ir_module + class ZerosLike: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "float32") = R.zeros_like(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(zeros, R.tuple(), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def zeros(var_T_full: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.float32(0) + # fmt: on + + mod = LegalizeOps()(ZerosLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_tril(): + # fmt: off + @tvm.script.ir_module + class Tril: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv: R.Tensor((2, 3, 4), "float32") = R.tril(x, k=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv = R.call_tir(tril, (x,), R.Tensor((2, 3, 4), dtype="float32")) + return gv + + @T.prim_func + def tril(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), trilu: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): + with T.block("trilu"): + i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1]) + T.writes(trilu[i0_1, i1_1, i2_1]) + trilu[i0_1, i1_1, i2_1] = T.Select(i2_1 - T.int64(1) <= i1_1, rxplaceholder[i0_1, i1_1, i2_1], T.float32(0)) + # fmt: on + + mod = LegalizeOps()(Tril) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_tril_symbolic(): + # fmt: off + @tvm.script.ir_module + class Tril: + @R.function + def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", "k"), "int8"): + m = T.var("int64") + n = T.var("int64") + k = T.var("int64") + gv: R.Tensor((m, n, k), "int8") = R.tril(x, k=-2) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", "k"), "int8"): + m = T.var("int64") + n = T.var("int64") + k = T.var("int64") + gv = R.call_tir(tril, (x,), R.Tensor((m, n, k), dtype="int8")) + return gv + + @T.prim_func + def tril(var_rxplaceholder: T.handle, var_trilu: T.handle): + T.func_attr({"tir.noalias": True}) + k = T.var("int64") + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n, k], dtype="int8") + trilu = T.match_buffer(var_trilu, [m, n, k], dtype="int8") + for i0, i1, i2 in T.grid(m, n, k): + with T.block("trilu"): + i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1]) + T.writes(trilu[i0_1, i1_1, i2_1]) + trilu[i0_1, i1_1, i2_1] = T.Select(i2_1 + T.int64(2) <= i1_1, rxplaceholder[i0_1, i1_1, i2_1], T.int8(0)) + # fmt: on + + mod = LegalizeOps()(Tril) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_triu(): + # fmt: off + @tvm.script.ir_module + class Triu: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv: R.Tensor((2, 3, 4), "float32") = R.triu(x, k=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv = R.call_tir(triu, (x,), R.Tensor((2, 3, 4), dtype="float32")) + return gv + + @T.prim_func + def triu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), trilu: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): + with T.block("trilu"): + i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1]) + T.writes(trilu[i0_1, i1_1, i2_1]) + trilu[i0_1, i1_1, i2_1] = T.Select(i1_1 <= i2_1 - T.int64(1), rxplaceholder[i0_1, i1_1, i2_1], T.float32(0)) + # fmt: on + + mod = LegalizeOps()(Triu) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_triu_symbolic(): + # fmt: off + @tvm.script.ir_module + class Triu: + @R.function + def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", "k"), "int8"): + m = T.var("int64") + n = T.var("int64") + k = T.var("int64") + gv: R.Tensor((m, n, k), "int8") = R.triu(x, k=-2) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", "k"), "int8"): + m = T.var("int64") + n = T.var("int64") + k = T.var("int64") + gv = R.call_tir(triu, (x,), R.Tensor((m, n, k), dtype="int8")) + return gv + + @T.prim_func + def triu(var_rxplaceholder: T.handle, var_trilu: T.handle): + T.func_attr({"tir.noalias": True}) + k = T.var("int64") + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n, k], dtype="int8") + trilu = T.match_buffer(var_trilu, [m, n, k], dtype="int8") + for i0, i1, i2 in T.grid(m, n, k): + with T.block("trilu"): + i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1]) + T.writes(trilu[i0_1, i1_1, i2_1]) + trilu[i0_1, i1_1, i2_1] = T.Select(i1_1 <= i2_1 + T.int64(2), rxplaceholder[i0_1, i1_1, i2_1], T.int8(0)) + # fmt: on + + mod = LegalizeOps()(Triu) + tvm.ir.assert_structural_equal(mod, Expected) + + +##################### Datatype ##################### + + +def test_astype(): + # fmt: off + @tvm.script.ir_module + class Astype: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "int32"): + gv: R.Tensor((2, 3, 4), "int32") = R.astype(x, "int32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "int32"): + gv = R.call_tir(cast, (x,), R.Tensor((2, 3, 4), dtype="int32")) + return gv + + @T.prim_func + def cast(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "int32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): + with T.block("compute"): + i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1]) + T.writes(compute[i0_1, i1_1, i2_1]) + compute[i0_1, i1_1, i2_1] = T.Cast("int32", rxplaceholder[i0_1, i1_1, i2_1]) + # fmt: on + + mod = LegalizeOps()(Astype) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_astype_input_constant_scalar(): + # fmt: off + @tvm.script.ir_module + class Astype: + @R.function + def main() -> R.Tensor((), "int32"): + gv: R.Tensor((), "int32") = R.astype(R.const(1.5, "float32"), "int32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main() -> R.Tensor((), "int32"): + gv: R.Tensor((), "int32") = R.const(1, "int32") + return gv + # fmt: on + + mod = LegalizeOps()(Astype) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_astype_symbolic(): + # fmt: off + @tvm.script.ir_module + class Astype: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "int32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "int32") = R.astype(x, "int32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "int32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(cast, (x,), R.Tensor((m, n), dtype="int32")) + return gv + + @T.prim_func + def cast(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="int32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.Cast("int32", rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Astype) + tvm.ir.assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_image.py b/tests/python/relax/test_transform_legalize_ops_image.py new file mode 100644 index 000000000000..36c8ecdd7b25 --- /dev/null +++ b/tests/python/relax/test_transform_legalize_ops_image.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm.relax.transform import LegalizeOps +from tvm.script import relax as R, tir as T +import tvm.testing + + +def test_image_resize2d(): + # fmt: off + @tvm.script.ir_module + class Resize2D: + @R.function + def main(x: R.Tensor((2, 8, 8, 3), "float32")) -> R.Tensor((2, 16, 16, 3), "float32"): + gv: R.Tensor((2, 16, 16, 3), "float32") = R.image.resize2d(x, size=(16, 16), layout="NHWC", method="nearest_neighbor", coordinate_transformation_mode="asymmetric") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 8, 8, 3), "float32")) -> R.Tensor((2, 16, 16, 3), "float32"): + gv = R.call_tir(resize2d, (x,), R.Tensor((2, 16, 16, 3), dtype="float32")) + return gv + + @T.prim_func + def resize2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(8), T.int64(8), T.int64(3)), "float32"), resize: T.Buffer((T.int64(2), T.int64(16), T.int64(16), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(16), T.int64(16), T.int64(3)): + with T.block("resize"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, T.max(T.min(T.Div(i1_1, T.int64(2)), T.int64(7)), T.int64(0)), T.max(T.min(T.Div(i2_1, T.int64(2)), T.int64(7)), T.int64(0)), i3_1]) + T.writes(resize[i0_1, i1_1, i2_1, i3_1]) + resize[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, T.max(T.min(T.Div(i1_1, T.int64(2)), T.int64(7)), T.int64(0)), T.max(T.min(T.Div(i2_1, T.int64(2)), T.int64(7)), T.int64(0)), i3_1] + # fmt: on + + mod = LegalizeOps()(Resize2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_image_resize2d_symbolic(): + # fmt: off + @tvm.script.ir_module + class Resize2D: + @R.function + def main(dumb_param: R.Tensor(("oh", "ow")), x: R.Tensor(("n", "c", "h", "w", 16), "float32")) -> R.Tensor(("n", "c", "oh", "ow", 16), "float32"): + n = T.var("int64") + c = T.var("int64") + oh = T.var("int64") + ow = T.var("int64") + gv: R.Tensor((n, c, oh, ow, 16), "float32") = R.image.resize2d(x, size=(oh, ow), layout="NCHW16c", method="nearest_neighbor", coordinate_transformation_mode="asymmetric") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(dumb_param: R.Tensor(("oh", "ow")), x: R.Tensor(("n", "c", "h", "w", 16), "float32")) -> R.Tensor(("n", "c", "oh", "ow", 16), "float32"): + n = T.var("int64") + c = T.var("int64") + oh = T.var("int64") + ow = T.var("int64") + gv = R.call_tir(resize2d, (x,), R.Tensor((n, c, oh, ow, 16), dtype="float32")) + return gv + + @T.prim_func + def resize2d(var_rxplaceholder: T.handle, var_resize: T.handle): + T.func_attr({"tir.noalias": True}) + c = T.var("int64") + h = T.var("int64") + n = T.var("int64") + oh = T.var("int64") + ow = T.var("int64") + w = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [n, c, h, w, T.int64(16)], dtype="float32") + resize = T.match_buffer(var_resize, [n, c, oh, ow, T.int64(16)], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(n, c, oh, ow, T.int64(16)): + with T.block("resize"): + i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(rxplaceholder[i0_1, i1_1, T.int64(0) : T.max(h, T.int64(1)), T.int64(0) : T.max(w, T.int64(1)), i4_1]) + T.writes(resize[i0_1, i1_1, i2_1, i3_1, i4_1]) + resize[i0_1, i1_1, i2_1, i3_1, i4_1] = rxplaceholder[i0_1, i1_1, T.max(T.min(T.Cast("int64", T.round(T.Cast("float32", h) / T.Cast("float32", oh) * T.Cast("float32", i2_1), dtype="float32")), h - T.int64(1)), T.int64(0)), T.max(T.min(T.Cast("int64", T.round(T.Cast("float32", w) / T.Cast("float32", ow) * T.Cast("float32", i3_1), dtype="float32")), w - T.int64(1)), T.int64(0)), i4_1] + # fmt: on + + mod = LegalizeOps()(Resize2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py new file mode 100644 index 000000000000..8b6f9de981bc --- /dev/null +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -0,0 +1,401 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm.relax.transform import LegalizeOps +from tvm.script import relax as R, tir as T +import tvm.testing + + +##################### Indexing ##################### + + +def test_take(): + # fmt: off + @tvm.script.ir_module + class Take: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32"), indices: R.Tensor((4,), "int64")) -> R.Tensor((2, 4, 4), "float32"): + gv: R.Tensor((2, 4, 4), "float32") = R.take(x, indices, axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32"), indices: R.Tensor((4,), "int64")) -> R.Tensor((2, 4, 4), "float32"): + gv = R.call_tir(take, (x, indices), R.Tensor((2, 4, 4), dtype="float32")) + return gv + + @T.prim_func + def take(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), rxplaceholder_1: T.Buffer(T.int64(4), "int64"), T_take: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)): + with T.block("T_take"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, rxplaceholder_1[ax1], ax2], rxplaceholder_1[ax1]) + T.writes(T_take[ax0, ax1, ax2]) + T_take[ax0, ax1, ax2] = rxplaceholder[ax0, rxplaceholder_1[ax1], ax2] + # fmt: on + + mod = LegalizeOps()(Take) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_take_symbolic(): + # fmt: off + @tvm.script.ir_module + class Take: + @R.function + def main(x: R.Tensor(("m", "n"), "float32"), indices: R.Tensor(("i",), "int64")) -> R.Tensor(("m", "i"), "float32"): + m = T.var("int64") + i = T.var("int64") + gv: R.Tensor((m, i), "float32") = R.take(x, indices, axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32"), indices: R.Tensor(("i",), "int64")) -> R.Tensor(("m", "i"), "float32"): + m = T.var("int64") + i = T.var("int64") + gv = R.call_tir(take, (x, indices), R.Tensor((m, i), dtype="float32")) + return gv + + @T.prim_func + def take(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_take: T.handle): + T.func_attr({"tir.noalias": True}) + i = T.var("int64") + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [i], dtype="int64") + T_take = T.match_buffer(var_T_take, [m, i], dtype="float32") + for i0, i1 in T.grid(m, i): + with T.block("T_take"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, rxplaceholder_1[ax1]], rxplaceholder_1[ax1]) + T.writes(T_take[ax0, ax1]) + T_take[ax0, ax1] = rxplaceholder[ax0, rxplaceholder_1[ax1]] + # fmt: on + + mod = LegalizeOps()(Take) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_strided_slice(): + # fmt: off + @tvm.script.ir_module + class StridedSlice: + @R.function + def main(x: R.Tensor((8, 9, 10, 10), "float32")) -> R.Tensor((4, 9, 10, 3), "float32"): + gv: R.Tensor((4, 9, 10, 3), "float32") = R.strided_slice(x, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((8, 9, 10, 10), dtype="float32")) -> R.Tensor((4, 9, 10, 3), dtype="float32"): + gv = R.call_tir(strided_slice, (x,), R.Tensor((4, 9, 10, 3), dtype="float32")) + return gv + + @T.prim_func + def strided_slice(rxplaceholder: T.Buffer((T.int64(8), T.int64(9), T.int64(10), T.int64(10)), "float32"), T_strided_slice_with_axes: T.Buffer((T.int64(4), T.int64(9), T.int64(10), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(9), T.int64(10), T.int64(3)): + with T.block("T_strided_slice_with_axes"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0 * T.int64(2) + T.int64(1), ax1, ax2, T.int64(8) - ax3 * T.int64(3)]) + T.writes(T_strided_slice_with_axes[ax0, ax1, ax2, ax3]) + T_strided_slice_with_axes[ax0, ax1, ax2, ax3] = rxplaceholder[ax0 * T.int64(2) + T.int64(1), ax1, ax2, T.int64(8) - ax3 * T.int64(3)] + # fmt: on + + mod = LegalizeOps()(StridedSlice) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_strided_slice_no_strides(): + # fmt: off + @tvm.script.ir_module + class StridedSlice: + @R.function + def main(x: R.Tensor((8, 9, 10, 10), "float32")) -> R.Tensor((4, 9, 10, 3), "float32"): + gv: R.Tensor((4, 9, 10, 3), "float32") = R.strided_slice(x, axes=[0, 1, 3], begin=[1, 0, 2], end=[8, 9, 4]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((8, 9, 10, 10), dtype="float32")) -> R.Tensor((4, 9, 10, 3), dtype="float32"): + gv = R.call_tir(strided_slice, (x,), out_sinfo=R.Tensor((7, 9, 10, 2), dtype="float32")) + return gv + + @T.prim_func + def strided_slice(rxplaceholder: T.Buffer((T.int64(8), T.int64(9), T.int64(10), T.int64(10)), "float32"), T_strided_slice_with_axes: T.Buffer((T.int64(7), T.int64(9), T.int64(10), T.int64(2)), "float32")): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(T.int64(7), T.int64(9), T.int64(10), T.int64(2)): + with T.block("T_strided_slice_with_axes"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax0 + T.int64(1), v_ax1, v_ax2, v_ax3 + T.int64(2)]) + T.writes(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2, v_ax3]) + T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0 + T.int64(1), v_ax1, v_ax2, v_ax3 + T.int64(2)] + # fmt: on + + mod = LegalizeOps()(StridedSlice) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_strided_slice_symbolic_sliced_axis(): + # fmt: off + @tvm.script.ir_module + class StridedSlice: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor((2, "n"), "float32"): + n = T.var("int64") + gv: R.Tensor((2, n), "float32") = R.strided_slice(x, axes=[0], begin=[1], end=[8], strides=[3]) + return gv + # fmt: on + + mod = LegalizeOps()(StridedSlice) + tvm.ir.assert_structural_equal(mod, StridedSlice) + + +def test_strided_slice_symbolic(): + # fmt: off + @tvm.script.ir_module + class StridedSlice: + @R.function + def main(x: R.Tensor((10, "n"), "float32")) -> R.Tensor((3, "n"), "float32"): + n = T.var("int64") + gv: R.Tensor((3, n), "float32") = R.strided_slice(x, axes=[0], begin=[1], end=[8], strides=[3]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((10, "n"), dtype="float32")) -> R.Tensor((3, "n"), dtype="float32"): + n = T.var("int64") + gv = R.call_tir(strided_slice, (x,), R.Tensor((3, n), dtype="float32")) + return gv + + @T.prim_func + def strided_slice(var_rxplaceholder: T.handle, var_T_strided_slice_with_axes: T.handle): + T.func_attr({"tir.noalias": True}) + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(10), n], dtype="float32") + T_strided_slice_with_axes = T.match_buffer(var_T_strided_slice_with_axes, [T.int64(3), n], dtype="float32") + for i0, i1 in T.grid(T.int64(3), n): + with T.block("T_strided_slice_with_axes"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0 * T.int64(3) + T.int64(1), ax1]) + T.writes(T_strided_slice_with_axes[ax0, ax1]) + T_strided_slice_with_axes[ax0, ax1] = rxplaceholder[ax0 * T.int64(3) + T.int64(1), ax1] + # fmt: on + + mod = LegalizeOps()(StridedSlice) + tvm.ir.assert_structural_equal(mod, Expected) + + +##################### Linear algebra ##################### + + +def test_matmul_1_4(): + # fmt: off + @tvm.script.ir_module + class Matmul: + @R.function + def main(x: R.Tensor((4,), "float32"), y: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 3, 5), "float32"): + gv: R.Tensor((2, 3, 5), "float32") = R.matmul(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((4,), "float32"), y: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 3, 5), "float32"): + gv = R.call_tir(matmul, (x, y), R.Tensor((2, 3, 5), dtype="float32")) + return gv + + @T.prim_func + def matmul(rxplaceholder: T.Buffer(T.int64(4), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), matmul: T.Buffer((T.int64(2), T.int64(3), T.int64(5)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(5), T.int64(4)): + with T.block("matmul"): + i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[k], rxplaceholder_1[i0_1, i1_1, k, i2_1]) + T.writes(matmul[i0_1, i1_1, i2_1]) + with T.init(): + matmul[i0_1, i1_1, i2_1] = T.float32(0) + matmul[i0_1, i1_1, i2_1] = matmul[i0_1, i1_1, i2_1] + rxplaceholder[k] * rxplaceholder_1[i0_1, i1_1, k, i2_1] + # fmt: on + + mod = LegalizeOps()(Matmul) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_matmul_4_1(): + # fmt: off + @tvm.script.ir_module + class Matmul: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv: R.Tensor((2, 3, 4), "float32") = R.matmul(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32"), y: R.Tensor((5,), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv = R.call_tir(matmul, (x, y), R.Tensor((2, 3, 4), dtype="float32")) + return gv + + @T.prim_func + def matmul(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer(T.int64(5), "float32"), matmul: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("matmul"): + i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1, k], rxplaceholder_1[k]) + T.writes(matmul[i0_1, i1_1, i2_1]) + with T.init(): + matmul[i0_1, i1_1, i2_1] = T.float32(0) + matmul[i0_1, i1_1, i2_1] = matmul[i0_1, i1_1, i2_1] + rxplaceholder[i0_1, i1_1, i2_1, k] * rxplaceholder_1[k] + # fmt: on + + mod = LegalizeOps()(Matmul) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_matmul_1_1(): + # fmt: off + @tvm.script.ir_module + class Matmul: + @R.function + def main(x: R.Tensor((4,), "float32"), y: R.Tensor((4,), "float32")) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.matmul(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((4,), "float32"), y: R.Tensor((4,), "float32")) -> R.Tensor((), "float32"): + gv = R.call_tir(matmul, (x, y), R.Tensor((), dtype="float32")) + return gv + + @T.prim_func + def matmul(rxplaceholder: T.Buffer(T.int64(4), "float32"), rxplaceholder_1: T.Buffer(T.int64(4), "float32"), matmul: T.Buffer((), "float32")): + T.func_attr({"tir.noalias": True}) + for i0 in T.serial(T.int64(4)): + with T.block("matmul"): + k = T.axis.reduce(T.int64(4), i0) + T.reads(rxplaceholder[k], rxplaceholder_1[k]) + T.writes(matmul[()]) + with T.init(): + matmul[()] = T.float32(0) + matmul[()] = matmul[()] + rxplaceholder[k] * rxplaceholder_1[k] + # fmt: on + + mod = LegalizeOps()(Matmul) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_matmul_4_5(): + # fmt: off + @tvm.script.ir_module + class Matmul: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float16"), y: R.Tensor((6, 2, 3, 5, 7), "float16")) -> R.Tensor((6, 2, 3, 4, 7), "float32"): + gv: R.Tensor((6, 2, 3, 4, 7), "float32") = R.matmul(x, y, out_dtype="float32") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float16"), y: R.Tensor((6, 2, 3, 5, 7), "float16")) -> R.Tensor((6, 2, 3, 4, 7), "float32"): + gv = R.call_tir(matmul, (x, y), R.Tensor((6, 2, 3, 4, 7), dtype="float32")) + return gv + + @T.prim_func + def matmul(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"), rxplaceholder_1: T.Buffer((T.int64(6), T.int64(2), T.int64(3), T.int64(5), T.int64(7)), "float16"), matmul: T.Buffer((T.int64(6), T.int64(2), T.int64(3), T.int64(4), T.int64(7)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(6), T.int64(2), T.int64(3), T.int64(4), T.int64(7), T.int64(5)): + with T.block("matmul"): + i0_1, i1_1, i2_1, i3_1, i4_1, k = T.axis.remap("SSSSSR", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[i1_1, i2_1, i3_1, k], rxplaceholder_1[i0_1, i1_1, i2_1, k, i4_1]) + T.writes(matmul[i0_1, i1_1, i2_1, i3_1, i4_1]) + with T.init(): + matmul[i0_1, i1_1, i2_1, i3_1, i4_1] = T.float32(0) + matmul[i0_1, i1_1, i2_1, i3_1, i4_1] = matmul[i0_1, i1_1, i2_1, i3_1, i4_1] + T.Cast("float32", rxplaceholder[i1_1, i2_1, i3_1, k]) * T.Cast("float32", rxplaceholder_1[i0_1, i1_1, i2_1, k, i4_1]) + # fmt: on + + mod = LegalizeOps()(Matmul) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_matmul_4_5_symbolic(): + # fmt: off + @tvm.script.ir_module + class Matmul: + @R.function + def main(x: R.Tensor(("b", 1, "m", "k"), "float32"), y: R.Tensor(("a", 1, "c", "k", "n"), "float32")) -> R.Tensor(("a", "b", "c", "m", "n"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((a, b, c, m, n), "float32") = R.matmul(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("b", 1, "m", "k"), "float32"), y: R.Tensor(("a", 1, "c", "k", "n"), "float32")) -> R.Tensor(("a", "b", "c", "m", "n"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(matmul, (x, y), R.Tensor((a, b, c, m, n), dtype="float32")) + return gv + + @T.prim_func + def matmul(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmul: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + k = T.var("int64") + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [b, T.int64(1), m, k], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, T.int64(1), c, k, n], dtype="float32") + matmul = T.match_buffer(var_matmul, [a, b, c, m, n], dtype="float32") + for i0, i1, i2, i3, i4, i5 in T.grid(a, b, c, m, n, k): + with T.block("matmul"): + i0_1, i1_1, i2_1, i3_1, i4_1, k_1 = T.axis.remap("SSSSSR", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[i1_1, T.int64(0), i3_1, k_1], rxplaceholder_1[i0_1, T.int64(0), i2_1, k_1, i4_1]) + T.writes(matmul[i0_1, i1_1, i2_1, i3_1, i4_1]) + with T.init(): + matmul[i0_1, i1_1, i2_1, i3_1, i4_1] = T.float32(0) + matmul[i0_1, i1_1, i2_1, i3_1, i4_1] = matmul[i0_1, i1_1, i2_1, i3_1, i4_1] + rxplaceholder[i1_1, T.int64(0), i3_1, k_1] * rxplaceholder_1[i0_1, T.int64(0), i2_1, k_1, i4_1] + # fmt: on + + mod = LegalizeOps()(Matmul) + tvm.ir.assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py new file mode 100644 index 000000000000..53aa868ffefd --- /dev/null +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -0,0 +1,789 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm +from tvm.relax.transform import LegalizeOps +from tvm.script import relax as R, tir as T +import tvm.testing + + +##################### Manipulation ##################### + + +def test_broadcast_to(): + # fmt: off + @tvm.script.ir_module + class BroadcastTo: + @R.function + def main(x: R.Tensor((2, 1, 3), "float32")) -> R.Tensor((4, 2, 5, 3), "float32"): + gv: R.Tensor((4, 2, 5, 3), "float32") = R.broadcast_to(x, (4, 2, 5, 3)) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 1, 3), "float32")) -> R.Tensor((4, 2, 5, 3), "float32"): + gv = R.call_tir(broadcast_to, (x,), R.Tensor((4, 2, 5, 3), dtype="float32")) + return gv + + @T.prim_func + def broadcast_to(rxplaceholder: T.Buffer((T.int64(2), T.int64(1), T.int64(3)), "float32"), T_broadcast_to: T.Buffer((T.int64(4), T.int64(2), T.int64(5), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(2), T.int64(5), T.int64(3)): + with T.block("T_broadcast_to"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax1, T.int64(0), ax3]) + T.writes(T_broadcast_to[ax0, ax1, ax2, ax3]) + T_broadcast_to[ax0, ax1, ax2, ax3] = rxplaceholder[ax1, T.int64(0), ax3] + # fmt: on + + mod = LegalizeOps()(BroadcastTo) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_broadcast_to_symbolic(): + # fmt: off + @tvm.script.ir_module + class BroadcastTo: + @R.function + def main(dumb_param: R.Tensor(("a", "c")), x: R.Tensor(("b", 1, "d"), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv: R.Tensor((a, b, c, d), "float32") = R.broadcast_to(x, (a, b, c, d)) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(dumb_param: R.Tensor(("a", "c")), x: R.Tensor(("b", 1, "d"), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv = R.call_tir(broadcast_to, (x,), R.Tensor((a, b, c, d), dtype="float32")) + return gv + + @T.prim_func + def broadcast_to(var_rxplaceholder: T.handle, var_T_broadcast_to: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [b, T.int64(1), d], dtype="float32") + T_broadcast_to = T.match_buffer(var_T_broadcast_to, [a, b, c, d], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_broadcast_to"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax1, T.int64(0), ax3]) + T.writes(T_broadcast_to[ax0, ax1, ax2, ax3]) + T_broadcast_to[ax0, ax1, ax2, ax3] = rxplaceholder[ax1, T.int64(0), ax3] + # fmt: on + + mod = LegalizeOps()(BroadcastTo) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_concat(): + # fmt: off + @tvm.script.ir_module + class Concat: + @R.function + def main(x1: R.Tensor((1, 2, 3), "float32"), x2: R.Tensor((1, 3, 3), "float32"), x3: R.Tensor((1, 4, 3), "float32")) -> R.Tensor((1, 9, 3), "float32"): + gv: R.Tensor((1, 9, 3), "float32") = R.concat((x1, x2, x3), axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x1: R.Tensor((1, 2, 3), "float32"), x2: R.Tensor((1, 3, 3), "float32"), x3: R.Tensor((1, 4, 3), "float32")) -> R.Tensor((1, 9, 3), "float32"): + gv = R.call_tir(concatenate, (x1, x2, x3), R.Tensor((1, 9, 3), dtype="float32")) + return gv + + @T.prim_func + def concatenate(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(1), T.int64(3), T.int64(3)), "float32"), rxplaceholder_2: T.Buffer((T.int64(1), T.int64(4), T.int64(3)), "float32"), T_concat: T.Buffer((T.int64(1), T.int64(9), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(9), T.int64(3)): + with T.block("T_concat"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder_2[ax0, ax1 - T.int64(5), ax2], rxplaceholder_1[ax0, ax1 - T.int64(2), ax2], rxplaceholder[ax0, ax1, ax2]) + T.writes(T_concat[ax0, ax1, ax2]) + T_concat[ax0, ax1, ax2] = T.if_then_else(T.int64(5) <= ax1, rxplaceholder_2[ax0, ax1 - T.int64(5), ax2], T.if_then_else(T.int64(2) <= ax1, rxplaceholder_1[ax0, ax1 - T.int64(2), ax2], rxplaceholder[ax0, ax1, ax2])) + # fmt: on + + mod = LegalizeOps()(Concat) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_concat_input_tuple_var(): + # fmt: off + @tvm.script.ir_module + class Concat: + @R.function + def main(t: R.Tuple(R.Tensor((3, 4), "float32"), R.Tensor((3, 5), "float32"))) -> R.Tensor((3, 9), "float32"): + gv: R.Tensor((3, 9), "float32") = R.concat(t, axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(t: R.Tuple(R.Tensor((3, 4), "float32"), R.Tensor((3, 5), "float32"))) -> R.Tensor((3, 9), "float32"): + gv: R.Tensor((3, 4), dtype="float32") = t[0] + gv1: R.Tensor((3, 5), dtype="float32") = t[1] + gv2 = R.call_tir(concatenate, (gv, gv1), R.Tensor((3, 9), dtype="float32")) + return gv2 + + @T.prim_func + def concatenate(rxplaceholder: T.Buffer((T.int64(3), T.int64(4)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3), T.int64(5)), "float32"), T_concat: T.Buffer((T.int64(3), T.int64(9)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(3), T.int64(9)): + with T.block("T_concat"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder_1[ax0, ax1 - T.int64(4)], rxplaceholder[ax0, ax1]) + T.writes(T_concat[ax0, ax1]) + T_concat[ax0, ax1] = T.if_then_else(T.int64(4) <= ax1, rxplaceholder_1[ax0, ax1 - T.int64(4)], rxplaceholder[ax0, ax1]) + # fmt: on + + mod = LegalizeOps()(Concat) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_concat_input_tuple_var_symbolic(): + # fmt: off + @tvm.script.ir_module + class Concat: + @R.function + def main(t: R.Tuple(R.Tensor(("a", "b0"), "float32"), R.Tensor(("a", "b1"), "float32"), R.Tensor(("a", "b2"), "float32"))) -> R.Tensor(("a", "b0 + b1 + b2"), "float32"): + a = T.var("int64") + b0 = T.var("int64") + b1 = T.var("int64") + b2 = T.var("int64") + gv: R.Tensor((a, b0 + b1 + b2), "float32") = R.concat(t, axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(t: R.Tuple(R.Tensor(("a", "b0"), "float32"), R.Tensor(("a", "b1"), "float32"), R.Tensor(("a", "b2"), "float32"))) -> R.Tensor(("a", "b0 + b1 + b2"), "float32"): + a = T.var("int64") + b0 = T.var("int64") + b1 = T.var("int64") + b2 = T.var("int64") + gv: R.Tensor((a, b0), dtype="float32") = t[0] + gv1: R.Tensor((a, b1), dtype="float32") = t[1] + gv2: R.Tensor((a, b2), dtype="float32") = t[2] + gv3 = R.call_tir(concatenate, (gv, gv1, gv2), R.Tensor((a, ((b0 + b1) + b2)), dtype="float32")) + return gv3 + + @T.prim_func + def concatenate(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_concat: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b0 = T.var("int64") + b1 = T.var("int64") + b2 = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b0], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b1], dtype="float32") + rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [a, b2], dtype="float32") + T_concat = T.match_buffer(var_T_concat, [a, b0 + b1 + b2], dtype="float32") + for i0, i1 in T.grid(a, b0 + b1 + b2): + with T.block("T_concat"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder_2[ax0, ax1 - b0 - b1], rxplaceholder_1[ax0, ax1 - b0], rxplaceholder[ax0, ax1]) + T.writes(T_concat[ax0, ax1]) + T_concat[ax0, ax1] = T.if_then_else(T.int64(0) <= ax1 - b0 - b1, rxplaceholder_2[ax0, ax1 - b0 - b1], T.if_then_else(T.int64(0) <= ax1 - b0, rxplaceholder_1[ax0, ax1 - b0], rxplaceholder[ax0, ax1])) + # fmt: on + + mod = LegalizeOps()(Concat) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_expand_dims(): + # fmt: off + @tvm.script.ir_module + class ExpandDims: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 1, 1, 1, 3, 1, 4, 1), "float32"): + gv: R.Tensor((2, 1, 1, 1, 3, 1, 4, 1), "float32") = R.expand_dims(x, axis=[-1, 1, -6, 3, 5]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 1, 1, 1, 3, 1, 4, 1), "float32"): + gv = R.call_tir(expand_dims, (x,), R.Tensor((2, 1, 1, 1, 3, 1, 4, 1), dtype="float32")) + return gv + + @T.prim_func + def expand_dims(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), expand_dims: T.Buffer((T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(3), T.int64(1), T.int64(4), T.int64(1)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(3), T.int64(1), T.int64(4), T.int64(1)): + with T.block("expand_dims"): + i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1 = T.axis.remap("SSSSSSSS", [i0, i1, i2, i3, i4, i5, i6, i7]) + T.reads(rxplaceholder[i0_1, i4_1, i6_1]) + T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1]) + expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1] = rxplaceholder[i0_1, i4_1, i6_1] + # fmt: on + + mod = LegalizeOps()(ExpandDims) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_expand_dims_symbolic(): + # fmt: off + @tvm.script.ir_module + class ExpandDims: + @R.function + def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", 1, "b", 1, "c", 1), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + gv: R.Tensor((a, 1, b, 1, c, 1), "float32") = R.expand_dims(x, axis=[1, 3, 5]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", 1, "b", 1, "c", 1), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + gv = R.call_tir(expand_dims, (x,), R.Tensor((a, 1, b, 1, c, 1), dtype="float32")) + return gv + + @T.prim_func + def expand_dims(var_rxplaceholder: T.handle, var_expand_dims: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], dtype="float32") + expand_dims = T.match_buffer(var_expand_dims, [a, T.int64(1), b, T.int64(1), c, T.int64(1)], dtype="float32") + for i0, i1, i2, i3, i4, i5 in T.grid(a, T.int64(1), b, T.int64(1), c, T.int64(1)): + with T.block("expand_dims"): + i0_1, i1_1, i2_1, i3_1, i4_1, i5_1 = T.axis.remap("SSSSSS", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[i0_1, i2_1, i4_1]) + T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1]) + expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1] = rxplaceholder[i0_1, i2_1, i4_1] + # fmt: on + + mod = LegalizeOps()(ExpandDims) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_flatten(): + # fmt: off + @tvm.script.ir_module + class Flatten: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((24,), "float32"): + gv: R.Tensor((24,), "float32") = R.flatten(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((24,), "float32"): + gv = R.call_tir(reshape, (x,), R.Tensor((24,), dtype="float32")) + return gv + + @T.prim_func + def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), T_reshape: T.Buffer(T.int64(24), "float32")): + T.func_attr({"tir.noalias": True}) + for i0 in T.serial(T.int64(24)): + with T.block("T_reshape"): + ax0 = T.axis.spatial(T.int64(24), i0) + T.reads(rxplaceholder[ax0 % T.int64(24) // T.int64(12), ax0 % T.int64(12) // T.int64(4), ax0 % T.int64(4)]) + T.writes(T_reshape[ax0]) + T_reshape[ax0] = rxplaceholder[ax0 % T.int64(24) // T.int64(12), ax0 % T.int64(12) // T.int64(4), ax0 % T.int64(4)] + # fmt: on + + mod = LegalizeOps()(Flatten) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_flatten_zero_rank(): + # fmt: off + @tvm.script.ir_module + class Flatten: + @R.function + def main(x: R.Tensor((), "float32")) -> R.Tensor((1,), "float32"): + gv: R.Tensor((1,), "float32") = R.flatten(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((), "float32")) -> R.Tensor((1,), "float32"): + gv = R.call_tir(reshape, (x,), R.Tensor((1,), dtype="float32")) + return gv + + @T.prim_func + def reshape(rxplaceholder: T.Buffer((), "float32"), T_reshape: T.Buffer(T.int64(1), "float32")): + T.func_attr({"tir.noalias": True}) + for i0 in T.serial(T.int64(1)): + with T.block("T_reshape"): + ax0 = T.axis.spatial(T.int64(1), i0) + T.reads(rxplaceholder[()]) + T.writes(T_reshape[ax0]) + T_reshape[ax0] = rxplaceholder[()] + # fmt: on + + mod = LegalizeOps()(Flatten) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_flatten_symbolic(): + # fmt: off + @tvm.script.ir_module + class Flatten: + @R.function + def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a * b * c",), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + gv: R.Tensor((a * b * c,), "float32") = R.flatten(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a * b * c",), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + gv = R.call_tir(reshape, (x,), R.Tensor((((a * b) * c),), dtype="float32")) + return gv + + @T.prim_func + def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], dtype="float32") + T_reshape = T.match_buffer(var_T_reshape, [a * b * c], dtype="float32") + for i0 in T.serial(a * b * c): + with T.block("T_reshape"): + ax0 = T.axis.spatial(a * b * c, i0) + T.reads(rxplaceholder[ax0 // c // b % a, ax0 // c % b, ax0 % c]) + T.writes(T_reshape[ax0]) + T_reshape[ax0] = rxplaceholder[ax0 // c // b % a, ax0 // c % b, ax0 % c] + # fmt: on + + mod = LegalizeOps()(Flatten) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_permute_dims(): + # fmt: off + @tvm.script.ir_module + class PermuteDims: + @R.function + def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((2, 4, 3, 1), "float32"): + gv: R.Tensor((2, 4, 3, 1), "float32") = R.permute_dims(x, axes=[1, -1, 2, -4]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((2, 4, 3, 1), "float32"): + gv = R.call_tir(transpose, (x,), R.Tensor((2, 4, 3, 1), dtype="float32")) + return gv + + @T.prim_func + def transpose(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3), T.int64(4)), "float32"), T_transpose: T.Buffer((T.int64(2), T.int64(4), T.int64(3), T.int64(1)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(4), T.int64(3), T.int64(1)): + with T.block("T_transpose"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax3, ax0, ax2, ax1]) + T.writes(T_transpose[ax0, ax1, ax2, ax3]) + T_transpose[ax0, ax1, ax2, ax3] = rxplaceholder[ax3, ax0, ax2, ax1] + # fmt: on + + mod = LegalizeOps()(PermuteDims) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_permute_dims_symbolic(): + # fmt: off + @tvm.script.ir_module + class PermuteDims: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("b", "d", "c", "a"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + gv: R.Tensor((b, d, c, a), "float32") = R.permute_dims(x, axes=[1, -1, 2, -4]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), dtype="float32")) -> R.Tensor(("b", "d", "c", "a"), dtype="float32"): + b = T.var("int64") + d = T.var("int64") + c = T.var("int64") + a = T.var("int64") + gv = R.call_tir(transpose, (x,), R.Tensor((b, d, c, a), dtype="float32")) + return gv + + @T.prim_func + def transpose(var_rxplaceholder: T.handle, var_T_transpose: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") + T_transpose = T.match_buffer(var_T_transpose, [b, d, c, a], dtype="float32") + for i0, i1, i2, i3 in T.grid(b, d, c, a): + with T.block("T_transpose"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax3, ax0, ax2, ax1]) + T.writes(T_transpose[ax0, ax1, ax2, ax3]) + T_transpose[ax0, ax1, ax2, ax3] = rxplaceholder[ax3, ax0, ax2, ax1] + # fmt: on + + mod = LegalizeOps()(PermuteDims) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_reshape(): + # fmt: off + @tvm.script.ir_module + class Reshape: + @R.function + def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((8, 3), "float32"): + gv: R.Tensor((8, 3), "float32") = R.reshape(x, (8, 3)) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((8, 3), "float32"): + gv = R.call_tir(reshape, (x,), R.Tensor((8, 3), dtype="float32")) + return gv + + @T.prim_func + def reshape(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3), T.int64(4)), "float32"), T_reshape: T.Buffer((T.int64(8), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(8), T.int64(3)): + with T.block("T_reshape"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[T.int64(0), (ax0 * T.int64(3) + ax1) % T.int64(24) // T.int64(12), (ax0 * T.int64(3) + ax1) % T.int64(12) // T.int64(4), (ax0 * T.int64(3) + ax1) % T.int64(4)]) + T.writes(T_reshape[ax0, ax1]) + T_reshape[ax0, ax1] = rxplaceholder[T.int64(0), (ax0 * T.int64(3) + ax1) % T.int64(24) // T.int64(12), (ax0 * T.int64(3) + ax1) % T.int64(12) // T.int64(4), (ax0 * T.int64(3) + ax1) % T.int64(4)] + # fmt: on + + mod = LegalizeOps()(Reshape) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_reshape_symbolic(): + # fmt: off + @tvm.script.ir_module + class Reshape: + @R.function + def main(x: R.Tensor(("a", "b"), "float32")) -> R.Tensor(("a // 2", "b * 2"), "float32"): + a = T.var("int64") + b = T.var("int64") + gv: R.Tensor((a // 2, b * 2), "float32") = R.reshape(x, (a // 2, b * 2)) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b"), "float32")) -> R.Tensor(("a // 2", "b * 2"), "float32"): + a = T.var("int64") + b = T.var("int64") + gv = R.call_tir(reshape, (x,), R.Tensor(((a // 2), (b * 2)), dtype="float32")) + return gv + + @T.prim_func + def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b], dtype="float32") + T_reshape = T.match_buffer(var_T_reshape, [a // T.int64(2), b * T.int64(2)], dtype="float32") + for i0, i1 in T.grid(a // T.int64(2), b * T.int64(2)): + with T.block("T_reshape"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[(ax0 * (b * T.int64(2)) + ax1) // b % a, (ax0 * (b * T.int64(2)) + ax1) % b]) + T.writes(T_reshape[ax0, ax1]) + T_reshape[ax0, ax1] = rxplaceholder[(ax0 * (b * T.int64(2)) + ax1) // b % a, (ax0 * (b * T.int64(2)) + ax1) % b] + # fmt: on + + mod = LegalizeOps()(Reshape) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_split_by_indices(): + # fmt: off + @tvm.script.ir_module + class Split: + @R.function + def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 3, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 3, 4), "float32")]): + gv: R.Tuple([R.Tensor((2, 3, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 3, 4), "float32")]) = R.split(x, [3, 7], axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 3, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 3, 4), "float32")]): + gv = R.call_tir(split, (x,), [R.Tensor((2, 3, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 3, 4), "float32")]) + return gv + + @T.prim_func + def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), T_split_1: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32"), T_split_2: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): + with T.block("T_split"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1, ax2]) + T.writes(T_split[ax0, ax1, ax2]) + T_split[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] + for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)): + with T.block("T_split_1"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1 + T.int64(3), ax2]) + T.writes(T_split_1[ax0, ax1, ax2]) + T_split_1[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(3), ax2] + for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): + with T.block("T_split_2"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1 + T.int64(7), ax2]) + T.writes(T_split_2[ax0, ax1, ax2]) + T_split_2[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(7), ax2] + # fmt: on + + mod = LegalizeOps()(Split) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_split_by_indices_n_section_indivisible(): + # fmt: off + @tvm.script.ir_module + class Split: + @R.function + def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]): + gv: R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]) = R.split(x, 3, axis=1) + return gv + # fmt: on + + mod = LegalizeOps()(Split) + tvm.ir.assert_structural_equal(mod, Split) + + +def test_split_by_indices_n_section_divisible(): + # fmt: off + @tvm.script.ir_module + class Split: + @R.function + def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]): + gv: R.Tuple([R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]) = R.split(x, 2, axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]): + gv = R.call_tir(split, (x,), [R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]) + return gv + + @T.prim_func + def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split_sections: T.Buffer((T.int64(2), T.int64(5), T.int64(4)), "float32"), T_split_sections_1: T.Buffer((T.int64(2), T.int64(5), T.int64(4)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(5), T.int64(4)): + with T.block("T_split_sections"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1, ax2]) + T.writes(T_split_sections[ax0, ax1, ax2]) + T_split_sections[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] + for i0, i1, i2 in T.grid(T.int64(2), T.int64(5), T.int64(4)): + with T.block("T_split_sections_1"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1 + T.int64(5), ax2]) + T.writes(T_split_sections_1[ax0, ax1, ax2]) + T_split_sections_1[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(5), ax2] + # fmt: on + + mod = LegalizeOps()(Split) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_split_by_indices_n_section_divisible_symbolic(): + # fmt: off + @tvm.script.ir_module + class Split: + @R.function + def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "n * 3"), "float32")) -> R.Tuple([R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32")]): + m = T.var("int64") + n = T.var("int64") + gv: R.Tuple([R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32")]) = R.split(x, 3, axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "(n * 3)"), "float32")) -> R.Tuple(R.Tensor(("m", "((n * 3) // 3)"), "float32"), R.Tensor(("m", "((((n * 3) // 3) * 2) - ((n * 3) // 3))"), "float32"), R.Tensor(("m", "((n * 3) - (((n * 3) // 3) * 2))"), "float32")): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(split, (x,), [R.Tensor((m, ((n * 3) // 3)), "float32"), R.Tensor((m, ((((n * 3) // 3) * 2) - ((n * 3) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3) // 3) * 2))), "float32")], tir_vars=(n,)) + return gv + + @T.prim_func + def split(var_rxplaceholder: T.handle, var_T_split_sections: T.handle, var_T_split_sections_1: T.handle, var_T_split_sections_2: T.handle, n: T.int64): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n * T.int64(3)], dtype="float32") + T_split_sections = T.match_buffer(var_T_split_sections, [m, n * T.int64(3) // T.int64(3)], dtype="float32") + T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, n * T.int64(3) // T.int64(3) * T.int64(2) - n * T.int64(3) // T.int64(3)], dtype="float32") + T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n * T.int64(3) - n * T.int64(3) // T.int64(3) * T.int64(2)], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("T_split_sections"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_split_sections[ax0, ax1]) + T_split_sections[ax0, ax1] = rxplaceholder[ax0, ax1] + for i0, i1 in T.grid(m, n): + with T.block("T_split_sections_1"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, n + ax1]) + T.writes(T_split_sections_1[ax0, ax1]) + T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, n + ax1] + for i0, i1 in T.grid(m, n): + with T.block("T_split_sections_2"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, n * T.int64(2) + ax1]) + T.writes(T_split_sections_2[ax0, ax1]) + T_split_sections_2[ax0, ax1] = rxplaceholder[ax0, n * T.int64(2) + ax1] + # fmt: on + + mod = LegalizeOps()(Split) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_squeeze(): + # fmt: off + @tvm.script.ir_module + class Squeeze: + @R.function + def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2, 3, 1, 4), "float32"): + gv: R.Tensor((2, 3, 1, 4), "float32") = R.squeeze(x, [1, 4]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2, 3, 1, 4), "float32"): + gv = R.call_tir(squeeze, (x,), R.Tensor((2, 3, 1, 4), dtype="float32")) + return gv + + @T.prim_func + def squeeze(rxplaceholder: T.Buffer((T.int64(2), T.int64(1), T.int64(3), T.int64(1), T.int64(1), T.int64(4)), "float32"), T_squeeze: T.Buffer((T.int64(2), T.int64(3), T.int64(1), T.int64(4)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(1), T.int64(4)): + with T.block("T_squeeze"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, T.int64(0), ax1, ax2, T.int64(0), ax3]) + T.writes(T_squeeze[ax0, ax1, ax2, ax3]) + T_squeeze[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, T.int64(0), ax1, ax2, T.int64(0), ax3] + # fmt: on + + mod = LegalizeOps()(Squeeze) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_squeeze_no_axis(): + # fmt: off + @tvm.script.ir_module + class Squeeze: + @R.function + def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2, 3, 1, 4), "float32"): + gv: R.Tensor((2, 3, 1, 4), "float32") = R.squeeze(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2, 3, 1, 4), "float32"): + gv = R.call_tir(squeeze, (x,), R.Tensor((2, 3, 4), dtype="float32")) + return gv + + @T.prim_func + def squeeze(rxplaceholder: T.Buffer((T.int64(2), T.int64(1), T.int64(3), T.int64(1), T.int64(1), T.int64(4)), "float32"), T_squeeze: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): + with T.block("T_squeeze"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, T.int64(0), ax1, T.int64(0), T.int64(0), ax2]) + T.writes(T_squeeze[ax0, ax1, ax2]) + T_squeeze[ax0, ax1, ax2] = rxplaceholder[ax0, T.int64(0), ax1, T.int64(0), T.int64(0), ax2] + # fmt: on + + mod = LegalizeOps()(Squeeze) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_squeeze_symbolic(): + # fmt: off + @tvm.script.ir_module + class Squeeze: + @R.function + def main(x: R.Tensor(("a", 1, "b", 1), "float32")) -> R.Tensor(("a", "b", 1), "float32"): + a = T.var("int64") + b = T.var("int64") + gv: R.Tensor((a, b, 1), "float32") = R.squeeze(x, [1]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", 1, "b", 1), "float32")) -> R.Tensor(("a", "b", 1), "float32"): + a = T.var("int64") + b = T.var("int64") + gv = R.call_tir(squeeze, (x,), R.Tensor((a, b, 1), dtype="float32")) + return gv + + @T.prim_func + def squeeze(var_rxplaceholder: T.handle, var_T_squeeze: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, T.int64(1), b, T.int64(1)], dtype="float32") + T_squeeze = T.match_buffer(var_T_squeeze, [a, b, T.int64(1)], dtype="float32") + for i0, i1, i2 in T.grid(a, b, T.int64(1)): + with T.block("T_squeeze"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, T.int64(0), ax1, ax2]) + T.writes(T_squeeze[ax0, ax1, ax2]) + T_squeeze[ax0, ax1, ax2] = rxplaceholder[ax0, T.int64(0), ax1, ax2] + # fmt: on + + mod = LegalizeOps()(Squeeze) + tvm.ir.assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py new file mode 100644 index 000000000000..3f9f02c410e9 --- /dev/null +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -0,0 +1,1188 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm +from tvm.relax.transform import LegalizeOps +from tvm.script import relax as R, tir as T +import tvm.testing + + +##################### Neural network ##################### + + +def test_conv2d(): + # fmt: off + @tvm.script.ir_module + class Conv2d: + @R.function + def main(x: R.Tensor((2, 128, 28, 28), "float32"), w: R.Tensor((64, 16, 3, 3), "float32")) -> R.Tensor((2, 64, 13, 13), "float32"): + gv: R.Tensor((2, 4, 13, 13), "float32") = R.nn.conv2d(x, w, strides=(2, 2), padding=(1, 1), dilation=(2, 2), groups=8) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 128, 28, 28), "float32"), w: R.Tensor((64, 16, 3, 3), "float32")) -> R.Tensor((2, 64, 13, 13), "float32"): + gv = R.call_tir(conv2d, (x, w), R.Tensor((2, 64, 13, 13), dtype="float32")) + return gv + + @T.prim_func + def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(128), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(16), T.int64(3), T.int64(3)), "float32"), group_conv2d_nchw: T.Buffer((T.int64(2), T.int64(64), T.int64(13), T.int64(13)), "float32")): + T.func_attr({"tir.noalias": True}) + pad_temp = T.alloc_buffer([T.int64(2), T.int64(128), T.int64(30), T.int64(30)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(128), T.int64(30), T.int64(30)): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1 - T.int64(1), i3_1 - T.int64(1)]) + T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(T.int64(1) <= i2_1 and i2_1 < T.int64(29) and T.int64(1) <= i3_1 and i3_1 < T.int64(29), rxplaceholder[i0_1, i1_1, i2_1 - T.int64(1), i3_1 - T.int64(1)], T.float32(0), dtype="float32") + for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(2), T.int64(64), T.int64(13), T.int64(13), T.int64(16), T.int64(3), T.int64(3)): + with T.block("group_conv2d_nchw"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(pad_temp[nn, ff // T.int64(8) * T.int64(16) + rc, yy * T.int64(2) + ry * T.int64(2), xx * T.int64(2) + rx * T.int64(2)], rxplaceholder_1[ff, rc, ry, rx]) + T.writes(group_conv2d_nchw[nn, ff, yy, xx]) + with T.init(): + group_conv2d_nchw[nn, ff, yy, xx] = T.float32(0) + group_conv2d_nchw[nn, ff, yy, xx] = group_conv2d_nchw[nn, ff, yy, xx] + pad_temp[nn, ff // T.int64(8) * T.int64(16) + rc, yy * T.int64(2) + ry * T.int64(2), xx * T.int64(2) + rx * T.int64(2)] * rxplaceholder_1[ff, rc, ry, rx] + # fmt: on + + mod = LegalizeOps()(Conv2d) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_conv2d_with_out_dtype(): + # fmt: off + @tvm.script.ir_module + class Conv2d: + @R.function + def main(x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")) -> R.Tensor((2, 4, 26, 26), "float16"): + gv: R.Tensor((2, 4, 26, 26), "float16") = R.nn.conv2d(x, w, out_dtype="float16") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")) -> R.Tensor((2, 4, 26, 26), "float16"): + gv = R.call_tir(conv2d, (x, w), R.Tensor((2, 4, 26, 26), dtype="float16")) + return gv + + @T.prim_func + def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(2), T.int64(4), T.int64(26), T.int64(26)), "float16")): + T.func_attr({"tir.noalias": True}) + pad_temp = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(28), T.int64(28)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1, i3_1]) + T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, i1_1, i2_1, i3_1] + for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(26), T.int64(3), T.int64(3), T.int64(3)): + with T.block("conv2d_nchw"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(pad_temp[nn, rc, yy + ry, xx + rx], rxplaceholder_1[ff, rc, ry, rx]) + T.writes(conv2d_nchw[nn, ff, yy, xx]) + with T.init(): + conv2d_nchw[nn, ff, yy, xx] = T.float16(0) + conv2d_nchw[nn, ff, yy, xx] = conv2d_nchw[nn, ff, yy, xx] + T.Cast("float16", pad_temp[nn, rc, yy + ry, xx + rx]) * T.Cast("float16", rxplaceholder_1[ff, rc, ry, rx]) + # fmt: on + + mod = LegalizeOps()(Conv2d) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_conv2d_nhwc(): + # fmt: off + @tvm.script.ir_module + class Conv2d: + @R.function + def main(x: R.Tensor((2, 28, 28, 128), "float32"), w: R.Tensor((64, 128, 3, 3), "float32")) -> R.Tensor((2, 26, 26, 64), "float32"): + gv: R.Tensor((2, 26, 26, 64), "float32") = R.nn.conv2d(x, w, data_layout="NHWC") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 28, 28, 128), "float32"), w: R.Tensor((64, 128, 3, 3), "float32")) -> R.Tensor((2, 26, 26, 64), "float32"): + gv = R.call_tir(conv2d, (x, w), R.Tensor((2, 26, 26, 64), dtype="float32")) + return gv + + @T.prim_func + def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(28), T.int64(28), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(128), T.int64(3), T.int64(3)), "float32"), conv2d_nhwc: T.Buffer((T.int64(2), T.int64(26), T.int64(26), T.int64(64)), "float32")): + T.func_attr({"tir.noalias": True}) + pad_temp = T.alloc_buffer([T.int64(2), T.int64(28), T.int64(28), T.int64(128)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(28), T.int64(28), T.int64(128)): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1, i3_1]) + T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, i1_1, i2_1, i3_1] + for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(2), T.int64(26), T.int64(26), T.int64(64), T.int64(3), T.int64(3), T.int64(128)): + with T.block("conv2d_nhwc"): + nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(pad_temp[nn, yy + ry, xx + rx, rc], rxplaceholder_1[ff, rc, ry, rx]) + T.writes(conv2d_nhwc[nn, yy, xx, ff]) + with T.init(): + conv2d_nhwc[nn, yy, xx, ff] = T.float32(0) + conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + pad_temp[nn, yy + ry, xx + rx, rc] * rxplaceholder_1[ff, rc, ry, rx] + # fmt: on + + mod = LegalizeOps()(Conv2d) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_conv2d_symbolic(): + # fmt: off + @tvm.script.ir_module + class Conv2d: + @R.function + def main(x: R.Tensor(("n", "c", "h", "w"), "float32"), kernel: R.Tensor(("f", "c", "kh", "kw"), "float32")) -> R.Tensor(("n", "f", "h - kh + 1", "w - kw + 1"), "float32"): + n = T.var("int64") + h = T.var("int64") + w = T.var("int64") + f = T.var("int64") + kh = T.var("int64") + kw = T.var("int64") + gv: R.Tensor((n, f, h - kh + 1, w - kw + 1), "float32") = R.nn.conv2d(x, kernel) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("n", "c", "h", "w"), "float32"), kernel: R.Tensor(("f", "c", "kh", "kw"), "float32")) -> R.Tensor(("n", "f", "h - kh + 1", "w - kw + 1"), "float32"): + n = T.var("int64") + f = T.var("int64") + h = T.var("int64") + kh = T.var("int64") + w = T.var("int64") + kw = T.var("int64") + gv = R.call_tir(conv2d, (x, kernel), R.Tensor((n, f, ((h - kh) + 1), ((w - kw) + 1)), dtype="float32")) + return gv + + @T.prim_func + def conv2d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv2d_nchw: T.handle): + T.func_attr({"tir.noalias": True}) + c = T.var("int64") + f = T.var("int64") + h = T.var("int64") + kh = T.var("int64") + kw = T.var("int64") + n = T.var("int64") + w = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [n, c, h, w], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [f, c, kh, kw], dtype="float32") + conv2d_nchw = T.match_buffer(var_conv2d_nchw, [n, f, h - kh + T.int64(1), w - kw + T.int64(1)], dtype="float32") + pad_temp = T.alloc_buffer([n, c, h, w], dtype="float32") + for i0, i1, i2, i3 in T.grid(n, c, h, w): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1, i3_1]) + T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, i1_1, i2_1, i3_1] + for i0, i1, i2, i3, i4, i5, i6 in T.grid(n, f, h + T.int64(1) - kh, w + T.int64(1) - kw, c, kh, kw): + with T.block("conv2d_nchw"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(pad_temp[nn, rc, yy + ry, xx + rx], rxplaceholder_1[ff, rc, ry, rx]) + T.writes(conv2d_nchw[nn, ff, yy, xx]) + with T.init(): + conv2d_nchw[nn, ff, yy, xx] = T.float32(0) + conv2d_nchw[nn, ff, yy, xx] = conv2d_nchw[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * rxplaceholder_1[ff, rc, ry, rx] + # fmt: on + + mod = LegalizeOps()(Conv2d) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_max_pool2d(): + # fmt: off + @tvm.script.ir_module + class MaxPool2D: + @R.function + def main(x: R.Tensor((4, 112, 112, 6), "float32")) -> R.Tensor((4, 56, 56, 6), "float32"): + gv: R.Tensor((4, 56, 56, 6), "float32") = R.nn.max_pool2d(x, pool_size=[3, 3], strides=[2, 2], dilation=[1, 1], padding=[1, 1, 1, 1], layout="NHWC") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((4, 112, 112, 6), "float32")) -> R.Tensor((4, 56, 56, 6), "float32"): + gv = R.call_tir(max_pool2d, (x,), R.Tensor((4, 56, 56, 6), dtype="float32")) + return gv + + @T.prim_func + def max_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(112), T.int64(112), T.int64(6)), "float32"), pool_max: T.Buffer((T.int64(4), T.int64(56), T.int64(56), T.int64(6)), "float32")): + T.func_attr({"tir.noalias": True}) + pad_temp = T.alloc_buffer([T.int64(4), T.int64(114), T.int64(114), T.int64(6)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(114), T.int64(114), T.int64(6)): + with T.block("pad_temp"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1 - T.int64(1), ax2 - T.int64(1), ax3]) + T.writes(pad_temp[ax0, ax1, ax2, ax3]) + pad_temp[ax0, ax1, ax2, ax3] = T.if_then_else(T.int64(1) <= ax1 and ax1 < T.int64(113) and T.int64(1) <= ax2 and ax2 < T.int64(113), rxplaceholder[ax0, ax1 - T.int64(1), ax2 - T.int64(1), ax3], T.float32(-3.4028234663852886e+38), dtype="float32") + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(4), T.int64(56), T.int64(56), T.int64(6), T.int64(3), T.int64(3)): + with T.block("pool_max"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(pad_temp[ax0, ax1 * T.int64(2) + rv0, ax2 * T.int64(2) + rv1, ax3]) + T.writes(pool_max[ax0, ax1, ax2, ax3]) + T.block_attr({"schedule_rule":"meta_schedule.pool_max"}) + with T.init(): + pool_max[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e+38) + pool_max[ax0, ax1, ax2, ax3] = T.max(pool_max[ax0, ax1, ax2, ax3], pad_temp[ax0, ax1 * T.int64(2) + rv0, ax2 * T.int64(2) + rv1, ax3]) + # fmt: on + + mod = LegalizeOps()(MaxPool2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_max_pool2d_NCHW16c(): + # fmt: off + @tvm.script.ir_module + class MaxPool2D: + @R.function + def main(x: R.Tensor((4, 4, 112, 112, 16), "float32")) -> R.Tensor((4, 4, 110, 110, 16), "float32"): + gv: R.Tensor((4, 4, 110, 110, 16), "float32") = R.nn.max_pool2d(x, pool_size=[3, 3], layout="NCHW16c") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((4, 4, 112, 112, 16), "float32")) -> R.Tensor((4, 4, 110, 110, 16), "float32"): + gv = R.call_tir(max_pool2d, (x,), R.Tensor((4, 4, 110, 110, 16), dtype="float32")) + return gv + + @T.prim_func + def max_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(4), T.int64(112), T.int64(112), T.int64(16)), "float32"), pool_max: T.Buffer((T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16), T.int64(3), T.int64(3)): + with T.block("pool_max"): + ax0, ax1, ax2, ax3, ax4, rv0, rv1 = T.axis.remap("SSSSSRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(rxplaceholder[ax0, ax1, ax2 + rv0, ax3 + rv1, ax4]) + T.writes(pool_max[ax0, ax1, ax2, ax3, ax4]) + T.block_attr({"schedule_rule":"meta_schedule.pool_max"}) + with T.init(): + pool_max[ax0, ax1, ax2, ax3, ax4] = T.float32(-3.4028234663852886e+38) + pool_max[ax0, ax1, ax2, ax3, ax4] = T.max(pool_max[ax0, ax1, ax2, ax3, ax4], rxplaceholder[ax0, ax1, ax2 + rv0, ax3 + rv1, ax4]) + # fmt: on + + mod = LegalizeOps()(MaxPool2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_max_pool2d_ceil_mode(): + # fmt: off + @tvm.script.ir_module + class MaxPool2D: + @R.function + def main(x: R.Tensor((4, 6, 112, 112), "float32")) -> R.Tensor((4, 6, 38, 38), "float32"): + gv: R.Tensor((4, 6, 38, 38), "float32") = R.nn.max_pool2d(x, pool_size=[3, 3], strides=[3, 3], dilation=[1, 1], padding=[1, 1, 1, 1], ceil_mode=True) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((4, 6, 112, 112), dtype="float32")) -> R.Tensor((4, 6, 38, 38), dtype="float32"): + gv = R.call_tir(max_pool2d, (x,), R.Tensor((4, 6, 38, 38), dtype="float32")) + return gv + + @T.prim_func + def max_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(6), T.int64(112), T.int64(112)), "float32"), pool_max: T.Buffer((T.int64(4), T.int64(6), T.int64(38), T.int64(38)), "float32")): + T.func_attr({"tir.noalias": True}) + pad_temp = T.alloc_buffer([T.int64(4), T.int64(6), T.int64(116), T.int64(116)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(6), T.int64(116), T.int64(116)): + with T.block("pad_temp"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, ax2 - T.int64(1), ax3 - T.int64(1)]) + T.writes(pad_temp[ax0, ax1, ax2, ax3]) + pad_temp[ax0, ax1, ax2, ax3] = T.if_then_else(T.int64(1) <= ax2 and ax2 < T.int64(113) and T.int64(1) <= ax3 and ax3 < T.int64(113), rxplaceholder[ax0, ax1, ax2 - T.int64(1), ax3 - T.int64(1)], T.float32(-3.4028234663852886e+38), dtype="float32") + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(4), T.int64(6), T.int64(38), T.int64(38), T.int64(3), T.int64(3)): + with T.block("pool_max"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(pad_temp[ax0, ax1, ax2 * T.int64(3) + rv0, ax3 * T.int64(3) + rv1]) + T.writes(pool_max[ax0, ax1, ax2, ax3]) + T.block_attr({"schedule_rule":"meta_schedule.pool_max"}) + with T.init(): + pool_max[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e+38) + pool_max[ax0, ax1, ax2, ax3] = T.max(pool_max[ax0, ax1, ax2, ax3], pad_temp[ax0, ax1, ax2 * T.int64(3) + rv0, ax3 * T.int64(3) + rv1]) + # fmt: on + + mod = LegalizeOps()(MaxPool2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +@pytest.mark.skip("TOPI pooling casts every shape value to i32.") +def test_max_pool2d_symbolic(): + # fmt: off + @tvm.script.ir_module + class MaxPool2D: + @R.function + def main(dumb_param: R.Tensor(("kh", "kw")), x: R.Tensor(("n", "c", "h", "w"), "float32")) -> R.Tensor(("n", "c", "h - kh + 1", "w - kw + 1"), "float32"): + n = T.var("int64") + c = T.var("int64") + h = T.var("int64") + w = T.var("int64") + kh = T.var("int64") + kw = T.var("int64") + gv: R.Tensor((n, c, h - kh + 1, w - kw + 1), "float32") = R.nn.max_pool2d(x, pool_size=[kh, kw]) + return gv + + # fmt: on + + mod = LegalizeOps()(MaxPool2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_adaptive_avg_pool2d(): + # fmt: off + @tvm.script.ir_module + class AdaptiveAvgPool2D: + @R.function + def main(x: R.Tensor((2, 4, 7, 7, 16), "float32")) -> R.Tensor((2, 4, 1, 1, 16), "float32"): + gv: R.Tensor((2, 4, 1, 1, 16), "float32") = R.nn.adaptive_avg_pool2d(x, output_size=[1, 1], layout="NCHW16c") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 4, 7, 7, 16), "float32")) -> R.Tensor((2, 4, 1, 1, 16), "float32"): + gv = R.call_tir(adaptive_avg_pool2d, (x,), R.Tensor((2, 4, 1, 1, 16), dtype="float32")) + return gv + + @T.prim_func + def adaptive_avg_pool2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(7), T.int64(7), T.int64(16)), "float32"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(16)), "float32")): + T.func_attr({"tir.noalias": True}) + adaptive_pool_sum = T.alloc_buffer([T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(16)], dtype="float32") + for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(16), T.int64(7), T.int64(7)): + with T.block("adaptive_pool_sum"): + ax0, ax1, ax2, ax3, ax4, rv0, rv1 = T.axis.remap("SSSSSRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads(rxplaceholder[ax0, ax1, ax2 * T.int64(7) + rv0, ax3 * T.int64(7) + rv1, ax4]) + T.writes(adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4]) + with T.init(): + adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] = T.float32(0) + adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] = adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] + rxplaceholder[ax0, ax1, ax2 * T.int64(7) + rv0, ax3 * T.int64(7) + rv1, ax4] + for i0, i1, i2, i3, i4 in T.grid(T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(16)): + with T.block("adaptive_pool_avg"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4]) + T.writes(adaptive_pool_avg[ax0, ax1, ax2, ax3, ax4]) + T.block_attr({"schedule_rule":"meta_schedule.adaptive_pool_avg"}) + adaptive_pool_avg[ax0, ax1, ax2, ax3, ax4] = adaptive_pool_sum[ax0, ax1, ax2, ax3, ax4] * T.float32(0.020408163265306121) + # fmt: on + + mod = LegalizeOps()(AdaptiveAvgPool2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_adaptive_avg_pool2d_without_output_size(): + # fmt: off + @tvm.script.ir_module + class AdaptiveAvgPool2D: + @R.function + def main(x: R.Tensor((2, 16, 7, 7), "float32")) -> R.Tensor((2, 16, 7, 7), "float32"): + gv: R.Tensor((2, 16, 7, 7), "float32") = R.nn.adaptive_avg_pool2d(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 16, 7, 7), "float32")) -> R.Tensor((2, 16, 7, 7), "float32"): + gv = R.call_tir(adaptive_avg_pool2d, (x,), R.Tensor((2, 16, 7, 7), dtype="float32")) + return gv + + @T.prim_func + def adaptive_avg_pool2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(16), T.int64(7), T.int64(7)), "float32"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(16), T.int64(7), T.int64(7)), "float32")): + T.func_attr({"tir.noalias": True}) + adaptive_pool_sum = T.alloc_buffer([T.int64(2), T.int64(16), T.int64(7), T.int64(7)], dtype="float32") + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(2), T.int64(16), T.int64(7), T.int64(7), T.int64(1), T.int64(1)): + with T.block("adaptive_pool_sum"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[ax0, ax1, ax2 + rv0, ax3 + rv1]) + T.writes(adaptive_pool_sum[ax0, ax1, ax2, ax3]) + with T.init(): + adaptive_pool_sum[ax0, ax1, ax2, ax3] = T.float32(0) + adaptive_pool_sum[ax0, ax1, ax2, ax3] = adaptive_pool_sum[ax0, ax1, ax2, ax3] + rxplaceholder[ax0, ax1, ax2 + rv0, ax3 + rv1] + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(16), T.int64(7), T.int64(7)): + with T.block("adaptive_pool_avg"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(adaptive_pool_sum[ax0, ax1, ax2, ax3]) + T.writes(adaptive_pool_avg[ax0, ax1, ax2, ax3]) + T.block_attr({"schedule_rule":"meta_schedule.adaptive_pool_avg"}) + adaptive_pool_avg[ax0, ax1, ax2, ax3] = adaptive_pool_sum[ax0, ax1, ax2, ax3] + # fmt: on + + mod = LegalizeOps()(AdaptiveAvgPool2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +@pytest.mark.skip("TOPI pooling casts every shape value to i32.") +def test_adaptive_avg_pool2d_symbolic(): + # fmt: off + @tvm.script.ir_module + class AdaptiveAvgPool2D: + @R.function + def main(dumb_param: R.Tensor(("oh", "ow")), x: R.Tensor(("n", "c", "h", "w"), "float32")) -> R.Tensor(("n", "c", "oh", "ow"), "float32"): + n = T.var("int64") + c = T.var("int64") + oh = T.var("int64") + ow = T.var("int64") + gv: R.Tensor((n, c, oh, ow), "float32") = R.nn.adaptive_avg_pool2d(x, (oh, ow)) + return gv + # fmt: on + + mod = LegalizeOps()(AdaptiveAvgPool2D) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_relu(): + # fmt: off + @tvm.script.ir_module + class Relu: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.nn.relu(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(relu, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def relu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.max(rxplaceholder[i0_1, i1_1], T.float32(0)) + # fmt: on + + mod = LegalizeOps()(Relu) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_relu_symbolic(): + # fmt: off + @tvm.script.ir_module + class Relu: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "float32") = R.nn.relu(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(relu, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def relu(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.max(rxplaceholder[i0_1, i1_1], T.float32(0)) + # fmt: on + + mod = LegalizeOps()(Relu) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_gelu(): + # fmt: off + @tvm.script.ir_module + class Gelu: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.nn.gelu(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(gelu, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def gelu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + T_multiply_1 = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") + compute = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") + T_multiply_2 = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") + T_divide = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_multiply"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_multiply_1[ax0, ax1]) + T_multiply_1[ax0, ax1] = rxplaceholder[ax0, ax1] * T.float32(0.70710678118654757) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_multiply_1[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.erf(T_multiply_1[i0_1, i1_1], dtype="float32") + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_multiply_1"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(compute[ax0, ax1]) + T.writes(T_multiply_2[ax0, ax1]) + T_multiply_2[ax0, ax1] = compute[ax0, ax1] * T.float32(0.5) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_divide"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_multiply_2[ax0, ax1]) + T.writes(T_divide[ax0, ax1]) + T_divide[ax0, ax1] = T.float32(0.5) + T_multiply_2[ax0, ax1] + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_multiply_2"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1], T_divide[ax0, ax1]) + T.writes(T_multiply[ax0, ax1]) + T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * T_divide[ax0, ax1] + # fmt: on + + mod = LegalizeOps()(Gelu) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_gelu_symbolic(): + # fmt: off + @tvm.script.ir_module + class Gelu: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "float32") = R.nn.gelu(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(gelu, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def gelu(var_rxplaceholder: T.handle, var_T_multiply: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + T_multiply = T.match_buffer(var_T_multiply, [m, n], dtype="float32") + T_multiply_1 = T.alloc_buffer([m, n], dtype="float32") + compute = T.alloc_buffer([m, n], dtype="float32") + T_multiply_2 = T.alloc_buffer([m, n], dtype="float32") + T_add = T.alloc_buffer([m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("T_multiply"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_multiply_1[ax0, ax1]) + T_multiply_1[ax0, ax1] = rxplaceholder[ax0, ax1] * T.float32(0.70710678118654757) + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_multiply_1[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.erf(T_multiply_1[i0_1, i1_1], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("T_multiply_1"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(compute[ax0, ax1]) + T.writes(T_multiply_2[ax0, ax1]) + T_multiply_2[ax0, ax1] = compute[ax0, ax1] * T.float32(0.5) + for i0, i1 in T.grid(m, n): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_multiply_2[ax0, ax1]) + T.writes(T_add[ax0, ax1]) + T_add[ax0, ax1] = T.float32(0.5) + T_multiply_2[ax0, ax1] + for i0, i1 in T.grid(m, n): + with T.block("T_multiply_2"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1], T_add[ax0, ax1]) + T.writes(T_multiply[ax0, ax1]) + T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * T_add[ax0, ax1] + # fmt: on + + mod = LegalizeOps()(Gelu) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_silu(): + # fmt: off + @tvm.script.ir_module + class Silu: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.nn.silu(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(silu, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def silu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + compute = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.sigmoid(rxplaceholder[i0_1, i1_1]) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_multiply"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1], compute[ax0, ax1]) + T.writes(T_multiply[ax0, ax1]) + T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * compute[ax0, ax1] + # fmt: on + + mod = LegalizeOps()(Silu) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_silu_symbolic(): + # fmt: off + @tvm.script.ir_module + class Silu: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "float32") = R.nn.silu(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(silu, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def silu(var_rxplaceholder: T.handle, var_T_multiply: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + T_multiply = T.match_buffer(var_T_multiply, [m, n], dtype="float32") + compute = T.alloc_buffer([m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.sigmoid(rxplaceholder[i0_1, i1_1]) + for i0, i1 in T.grid(m, n): + with T.block("T_multiply"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1], compute[ax0, ax1]) + T.writes(T_multiply[ax0, ax1]) + T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * compute[ax0, ax1] + # fmt: on + + mod = LegalizeOps()(Silu) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_softmax(): + # fmt: off + @tvm.script.ir_module + class Softmax: + @R.function + def main(x: R.Tensor((2, 3, 16, 32), "float32")) -> R.Tensor((2, 3, 16, 32), "float32"): + gv: R.Tensor((2, 3, 16, 32), "float32") = R.nn.softmax(x, axis=-2) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 16, 32), "float32")) -> R.Tensor((2, 3, 16, 32), "float32"): + gv = R.call_tir(softmax, (x,), R.Tensor((2, 3, 16, 32), dtype="float32")) + return gv + + @T.prim_func + def softmax(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32"), T_softmax_norm: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32")): + T.func_attr({"tir.noalias": True}) + T_softmax_maxelem = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32") + T_softmax_exp = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(16), T.int64(32)], dtype="float32") + T_softmax_expsum = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)): + with T.block("T_softmax_maxelem"): + i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, i1_1, k, i2_1]) + T.writes(T_softmax_maxelem[i0_1, i1_1, i2_1]) + with T.init(): + T_softmax_maxelem[i0_1, i1_1, i2_1] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[i0_1, i1_1, i2_1] = T.max(T_softmax_maxelem[i0_1, i1_1, i2_1], rxplaceholder[i0_1, i1_1, k, i2_1]) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(16), T.int64(32)): + with T.block("T_softmax_exp"): + i0_2, i1_2, i2_2, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_2, i1_2, i2_2, i3_1], T_softmax_maxelem[i0_2, i1_2, i3_1]) + T.writes(T_softmax_exp[i0_2, i1_2, i2_2, i3_1]) + T_softmax_exp[i0_2, i1_2, i2_2, i3_1] = T.exp(rxplaceholder[i0_2, i1_2, i2_2, i3_1] - T_softmax_maxelem[i0_2, i1_2, i3_1], dtype="float32") + for i0_3, i1_3, i2_3, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)): + with T.block("T_softmax_expsum"): + i0_4, i1_4, i2_4, k = T.axis.remap("SSSR", [i0_3, i1_3, i2_3, i3]) + T.reads(T_softmax_exp[i0_4, i1_4, k, i2_4]) + T.writes(T_softmax_expsum[i0_4, i1_4, i2_4]) + with T.init(): + T_softmax_expsum[i0_4, i1_4, i2_4] = T.float32(0) + T_softmax_expsum[i0_4, i1_4, i2_4] = T_softmax_expsum[i0_4, i1_4, i2_4] + T_softmax_exp[i0_4, i1_4, k, i2_4] + for i0_5, i1_5, i2_5, i3 in T.grid(T.int64(2), T.int64(3), T.int64(16), T.int64(32)): + with T.block("T_softmax_norm"): + i0_6, i1_6, i2_6, i3_2 = T.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3]) + T.reads(T_softmax_exp[i0_6, i1_6, i2_6, i3_2], T_softmax_expsum[i0_6, i1_6, i3_2]) + T.writes(T_softmax_norm[i0_6, i1_6, i2_6, i3_2]) + T.block_attr({"axis":2}) + T_softmax_norm[i0_6, i1_6, i2_6, i3_2] = T_softmax_exp[i0_6, i1_6, i2_6, i3_2] / T_softmax_expsum[i0_6, i1_6, i3_2] + # fmt: on + + mod = LegalizeOps()(Softmax) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_softmax_symbolic(): + # fmt: off + @tvm.script.ir_module + class Softmax: + @R.function + def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", "b", "c"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + gv: R.Tensor((a, b, c), "float32") = R.nn.softmax(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", "b", "c"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + gv = R.call_tir(softmax, (x,), R.Tensor((a, b, c), dtype="float32")) + return gv + + @T.prim_func + def softmax(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], dtype="float32") + T_softmax_norm = T.match_buffer(var_T_softmax_norm, [a, b, c], dtype="float32") + T_softmax_maxelem = T.alloc_buffer([a, b], dtype="float32") + T_softmax_exp = T.alloc_buffer([a, b, c], dtype="float32") + T_softmax_expsum = T.alloc_buffer([a, b], dtype="float32") + for i0, i1, i2 in T.grid(a, b, c): + with T.block("T_softmax_maxelem"): + i0_1, i1_1, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(rxplaceholder[i0_1, i1_1, k]) + T.writes(T_softmax_maxelem[i0_1, i1_1]) + with T.init(): + T_softmax_maxelem[i0_1, i1_1] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem[i0_1, i1_1] = T.max(T_softmax_maxelem[i0_1, i1_1], rxplaceholder[i0_1, i1_1, k]) + for i0, i1, i2 in T.grid(a, b, c): + with T.block("T_softmax_exp"): + i0_2, i1_2, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[i0_2, i1_2, i2_1], T_softmax_maxelem[i0_2, i1_2]) + T.writes(T_softmax_exp[i0_2, i1_2, i2_1]) + T_softmax_exp[i0_2, i1_2, i2_1] = T.exp(rxplaceholder[i0_2, i1_2, i2_1] - T_softmax_maxelem[i0_2, i1_2], dtype="float32") + for i0_3, i1_3, i2 in T.grid(a, b, c): + with T.block("T_softmax_expsum"): + i0_4, i1_4, k = T.axis.remap("SSR", [i0_3, i1_3, i2]) + T.reads(T_softmax_exp[i0_4, i1_4, k]) + T.writes(T_softmax_expsum[i0_4, i1_4]) + with T.init(): + T_softmax_expsum[i0_4, i1_4] = T.float32(0) + T_softmax_expsum[i0_4, i1_4] = T_softmax_expsum[i0_4, i1_4] + T_softmax_exp[i0_4, i1_4, k] + for i0_5, i1_5, i2 in T.grid(a, b, c): + with T.block("T_softmax_norm"): + i0_6, i1_6, i2_2 = T.axis.remap("SSS", [i0_5, i1_5, i2]) + T.reads(T_softmax_exp[i0_6, i1_6, i2_2], T_softmax_expsum[i0_6, i1_6]) + T.writes(T_softmax_norm[i0_6, i1_6, i2_2]) + T.block_attr({"axis":2}) + T_softmax_norm[i0_6, i1_6, i2_2] = T_softmax_exp[i0_6, i1_6, i2_2] / T_softmax_expsum[i0_6, i1_6] + # fmt: on + + mod = LegalizeOps()(Softmax) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_batch_norm(): + # fmt: off + @tvm.script.ir_module + class BatchNorm: + @R.function + def main(x: R.Tensor((2, 3, 28, 28), "float32"), gamma: R.Tensor((3,), "float32"), beta: R.Tensor((3,), "float32"), moving_mean: R.Tensor((3,), "float32"), moving_var: R.Tensor((3,), "float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), "float32"), R.Tensor((3,), "float32"), R.Tensor((3,), "float32")): + gv: R.Tuple(R.Tensor((2, 3, 28, 28), "float32"), R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 28, 28), "float32"), gamma: R.Tensor((3,), "float32"), beta: R.Tensor((3,), "float32"), moving_mean: R.Tensor((3,), "float32"), moving_var: R.Tensor((3,), "float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), "float32"), R.Tensor((3,), "float32"), R.Tensor((3,), "float32")): + gv = R.call_tir(batch_norm, (x, gamma, beta, moving_mean, moving_var), [R.Tensor((2, 3, 28, 28), "float32"), R.Tensor((3,), "float32"), R.Tensor((3,), "float32")]) + return gv + + @T.prim_func + def batch_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer(T.int64(3), "float32"), rxplaceholder_2: T.Buffer(T.int64(3), "float32"), rxplaceholder_3: T.Buffer(T.int64(3), "float32"), rxplaceholder_4: T.Buffer(T.int64(3), "float32"), T_add: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), T_multiply: T.Buffer(T.int64(3), "float32"), T_multiply_1: T.Buffer(T.int64(3), "float32")): + T.func_attr({"tir.noalias": True}) + T_reshape = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(1), T.int64(1)], dtype="float32") + T_subtract = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(28), T.int64(28)], dtype="float32") + T_reshape_1 = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(1), T.int64(1)], dtype="float32") + T_add_1 = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(1), T.int64(1)], dtype="float32") + compute = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(1), T.int64(1)], dtype="float32") + T_divide = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(28), T.int64(28)], dtype="float32") + T_reshape_2 = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(1), T.int64(1)], dtype="float32") + T_multiply_2 = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(28), T.int64(28)], dtype="float32") + T_reshape_3 = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(1), T.int64(1)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): + with T.block("T_reshape"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_3[(ax1 + ax2 + ax3) % T.int64(3)]) + T.writes(T_reshape[ax0, ax1, ax2, ax3]) + T_reshape[ax0, ax1, ax2, ax3] = rxplaceholder_3[(ax1 + ax2 + ax3) % T.int64(3)] + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_reshape[T.int64(0), ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1, ax2, ax3] - T_reshape[T.int64(0), ax1, T.int64(0), T.int64(0)] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): + with T.block("T_reshape_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_4[(ax1 + ax2 + ax3) % T.int64(3)]) + T.writes(T_reshape_1[ax0, ax1, ax2, ax3]) + T_reshape_1[ax0, ax1, ax2, ax3] = rxplaceholder_4[(ax1 + ax2 + ax3) % T.int64(3)] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_reshape_1[ax0, ax1, ax2, ax3]) + T.writes(T_add_1[ax0, ax1, ax2, ax3]) + T_add_1[ax0, ax1, ax2, ax3] = T_reshape_1[ax0, ax1, ax2, ax3] + T.float32(1.0000000000000001e-05) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): + with T.block("compute"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add_1[i0_1, i1_1, i2_1, i3_1]) + T.writes(compute[i0_1, i1_1, i2_1, i3_1]) + compute[i0_1, i1_1, i2_1, i3_1] = T.sqrt(T_add_1[i0_1, i1_1, i2_1, i3_1], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("T_divide"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3], compute[T.int64(0), ax1, T.int64(0), T.int64(0)]) + T.writes(T_divide[ax0, ax1, ax2, ax3]) + T_divide[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] / compute[T.int64(0), ax1, T.int64(0), T.int64(0)] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): + with T.block("T_reshape_2"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_1[(ax1 + ax2 + ax3) % T.int64(3)]) + T.writes(T_reshape_2[ax0, ax1, ax2, ax3]) + T_reshape_2[ax0, ax1, ax2, ax3] = rxplaceholder_1[(ax1 + ax2 + ax3) % T.int64(3)] + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_divide[ax0, ax1, ax2, ax3], T_reshape_2[T.int64(0), ax1, T.int64(0), T.int64(0)]) + T.writes(T_multiply_2[ax0, ax1, ax2, ax3]) + T_multiply_2[ax0, ax1, ax2, ax3] = T_divide[ax0, ax1, ax2, ax3] * T_reshape_2[T.int64(0), ax1, T.int64(0), T.int64(0)] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): + with T.block("T_reshape_3"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_2[(ax1 + ax2 + ax3) % T.int64(3)]) + T.writes(T_reshape_3[ax0, ax1, ax2, ax3]) + T_reshape_3[ax0, ax1, ax2, ax3] = rxplaceholder_2[(ax1 + ax2 + ax3) % T.int64(3)] + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("T_add_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_multiply_2[ax0, ax1, ax2, ax3], T_reshape_3[T.int64(0), ax1, T.int64(0), T.int64(0)]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = T_multiply_2[ax0, ax1, ax2, ax3] + T_reshape_3[T.int64(0), ax1, T.int64(0), T.int64(0)] + for i0 in T.serial(T.int64(3)): + with T.block("T_multiply_1"): + ax0 = T.axis.spatial(T.int64(3), i0) + T.reads(rxplaceholder_3[ax0]) + T.writes(T_multiply[ax0]) + T_multiply[ax0] = rxplaceholder_3[ax0] + for i0 in T.serial(T.int64(3)): + with T.block("T_multiply_2"): + ax0 = T.axis.spatial(T.int64(3), i0) + T.reads(rxplaceholder_4[ax0]) + T.writes(T_multiply_1[ax0]) + T_multiply_1[ax0] = rxplaceholder_4[ax0] + # fmt: on + + mod = LegalizeOps()(BatchNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_batch_norm_symbolic(): + # fmt: off + @tvm.script.ir_module + class BatchNorm: + @R.function + def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), "float32"), beta: R.Tensor(("c",), "float32"), moving_mean: R.Tensor(("c",), "float32"), moving_var: R.Tensor(("c",), "float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), "float32"), R.Tensor(("c",), "float32"), R.Tensor(("c",), "float32")): + n = T.var("int64") + h = T.var("int64") + w = T.var("int64") + c = T.var("int64") + gv: R.Tuple(R.Tensor((n, h, w, c), "float32"), R.Tensor((c,), "float32"), R.Tensor((c,), "float32")) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=-1) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), "float32"), beta: R.Tensor(("c",), "float32"), moving_mean: R.Tensor(("c",), "float32"), moving_var: R.Tensor(("c",), "float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), "float32"), R.Tensor(("c",), "float32"), R.Tensor(("c",), "float32")): + n = T.var("int64") + h = T.var("int64") + w = T.var("int64") + c = T.var("int64") + gv = R.call_tir(batch_norm, (x, gamma, beta, moving_mean, moving_var), [R.Tensor((n, h, w, c), "float32"), R.Tensor((c,), "float32"), R.Tensor((c,), "float32")]) + return gv + + @T.prim_func + def batch_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_rxplaceholder_3: T.handle, var_rxplaceholder_4: T.handle, var_T_add: T.handle, var_T_multiply: T.handle, var_T_multiply_1: T.handle): + T.func_attr({"tir.noalias": True}) + c = T.var("int64") + h = T.var("int64") + n = T.var("int64") + w = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [n, h, w, c], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [c], dtype="float32") + rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [c], dtype="float32") + rxplaceholder_3 = T.match_buffer(var_rxplaceholder_3, [c], dtype="float32") + rxplaceholder_4 = T.match_buffer(var_rxplaceholder_4, [c], dtype="float32") + T_add = T.match_buffer(var_T_add, [n, h, w, c], dtype="float32") + T_multiply = T.match_buffer(var_T_multiply, [c], dtype="float32") + T_multiply_1 = T.match_buffer(var_T_multiply_1, [c], dtype="float32") + T_reshape = T.alloc_buffer([T.int64(1), T.int64(1), T.int64(1), c], dtype="float32") + T_subtract = T.alloc_buffer([n, h, w, c], dtype="float32") + T_reshape_1 = T.alloc_buffer([T.int64(1), T.int64(1), T.int64(1), c], dtype="float32") + T_add_1 = T.alloc_buffer([T.int64(1), T.int64(1), T.int64(1), c], dtype="float32") + compute = T.alloc_buffer([T.int64(1), T.int64(1), T.int64(1), c], dtype="float32") + T_divide = T.alloc_buffer([n, h, w, c], dtype="float32") + T_reshape_2 = T.alloc_buffer([T.int64(1), T.int64(1), T.int64(1), c], dtype="float32") + T_multiply_2 = T.alloc_buffer([n, h, w, c], dtype="float32") + T_reshape_3 = T.alloc_buffer([T.int64(1), T.int64(1), T.int64(1), c], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(1), c): + with T.block("T_reshape"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_3[((ax0 + ax1 + ax2) * c + ax3) % c]) + T.writes(T_reshape[ax0, ax1, ax2, ax3]) + T_reshape[ax0, ax1, ax2, ax3] = rxplaceholder_3[((ax0 + ax1 + ax2) * c + ax3) % c] + for i0, i1, i2, i3 in T.grid(n, h, w, c): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_reshape[T.int64(0), T.int64(0), T.int64(0), ax3]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1, ax2, ax3] - T_reshape[T.int64(0), T.int64(0), T.int64(0), ax3] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(1), c): + with T.block("T_reshape_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_4[((ax0 + ax1 + ax2) * c + ax3) % c]) + T.writes(T_reshape_1[ax0, ax1, ax2, ax3]) + T_reshape_1[ax0, ax1, ax2, ax3] = rxplaceholder_4[((ax0 + ax1 + ax2) * c + ax3) % c] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(1), c): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_reshape_1[ax0, ax1, ax2, ax3]) + T.writes(T_add_1[ax0, ax1, ax2, ax3]) + T_add_1[ax0, ax1, ax2, ax3] = T_reshape_1[ax0, ax1, ax2, ax3] + T.float32(1.0000000000000001e-05) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(1), c): + with T.block("compute"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add_1[i0_1, i1_1, i2_1, i3_1]) + T.writes(compute[i0_1, i1_1, i2_1, i3_1]) + compute[i0_1, i1_1, i2_1, i3_1] = T.sqrt(T_add_1[i0_1, i1_1, i2_1, i3_1], dtype="float32") + for i0, i1, i2, i3 in T.grid(n, h, w, c): + with T.block("T_divide"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3], compute[T.int64(0), T.int64(0), T.int64(0), ax3]) + T.writes(T_divide[ax0, ax1, ax2, ax3]) + T_divide[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] / compute[T.int64(0), T.int64(0), T.int64(0), ax3] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(1), c): + with T.block("T_reshape_2"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_1[((ax0 + ax1 + ax2) * c + ax3) % c]) + T.writes(T_reshape_2[ax0, ax1, ax2, ax3]) + T_reshape_2[ax0, ax1, ax2, ax3] = rxplaceholder_1[((ax0 + ax1 + ax2) * c + ax3) % c] + for i0, i1, i2, i3 in T.grid(n, h, w, c): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_divide[ax0, ax1, ax2, ax3], T_reshape_2[T.int64(0), T.int64(0), T.int64(0), ax3]) + T.writes(T_multiply_2[ax0, ax1, ax2, ax3]) + T_multiply_2[ax0, ax1, ax2, ax3] = T_divide[ax0, ax1, ax2, ax3] * T_reshape_2[T.int64(0), T.int64(0), T.int64(0), ax3] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(1), c): + with T.block("T_reshape_3"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_2[((ax0 + ax1 + ax2) * c + ax3) % c]) + T.writes(T_reshape_3[ax0, ax1, ax2, ax3]) + T_reshape_3[ax0, ax1, ax2, ax3] = rxplaceholder_2[((ax0 + ax1 + ax2) * c + ax3) % c] + for i0, i1, i2, i3 in T.grid(n, h, w, c): + with T.block("T_add_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_multiply_2[ax0, ax1, ax2, ax3], T_reshape_3[T.int64(0), T.int64(0), T.int64(0), ax3]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = T_multiply_2[ax0, ax1, ax2, ax3] + T_reshape_3[T.int64(0), T.int64(0), T.int64(0), ax3] + for i0 in T.serial(c): + with T.block("T_multiply_1"): + ax0 = T.axis.spatial(c, i0) + T.reads(rxplaceholder_3[ax0]) + T.writes(T_multiply[ax0]) + T_multiply[ax0] = rxplaceholder_3[ax0] + for i0 in T.serial(c): + with T.block("T_multiply_2"): + ax0 = T.axis.spatial(c, i0) + T.reads(rxplaceholder_4[ax0]) + T.writes(T_multiply_1[ax0]) + T_multiply_1[ax0] = rxplaceholder_4[ax0] + # fmt: on + + mod = LegalizeOps()(BatchNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_layer_norm(): + # fmt: off + @tvm.script.ir_module + class LayerNorm: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32"), gamma: R.Tensor((4, 5), "float32"), beta: R.Tensor((4, 5), "float32")) -> R.Tensor((2, 3, 4, 5), "float32"): + gv: R.Tensor((2, 3, 4, 5), "float32") = R.nn.layer_norm(x, gamma, beta, axes=[-2, -1]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32"), gamma: R.Tensor((4, 5), "float32"), beta: R.Tensor((4, 5), "float32")) -> R.Tensor((2, 3, 4, 5), "float32"): + gv = R.call_tir(layer_norm, (x, gamma, beta), R.Tensor((2, 3, 4, 5), dtype="float32")) + return gv + + @T.prim_func + def layer_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(5)), "float32"), rxplaceholder_2: T.Buffer((T.int64(4), T.int64(5)), "float32"), T_layer_norm: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")): + T.func_attr({"tir.noalias": True}) + rxplaceholder_red_temp_v0 = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") + rxplaceholder_red_temp_v1 = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("rxplaceholder_red_temp"): + ax0, ax1, k2, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, k2, k3]) + T.writes(rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1]) + with T.init(): + rxplaceholder_red_temp_v0[ax0, ax1] = T.float32(0) + rxplaceholder_red_temp_v1[ax0, ax1] = T.float32(0) + v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[ax0, ax1] + rxplaceholder[ax0, ax1, k2, k3] + v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[ax0, ax1] + rxplaceholder[ax0, ax1, k2, k3] * rxplaceholder[ax0, ax1, k2, k3] + rxplaceholder_red_temp_v0[ax0, ax1] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[ax0, ax1] = v_rxplaceholder_red_temp_v1 + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("T_layer_norm"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, ax2, ax3], rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1], rxplaceholder_1[ax2, ax3], rxplaceholder_2[ax2, ax3]) + T.writes(T_layer_norm[ax0, ax1, ax2, ax3]) + T_layer_norm[ax0, ax1, ax2, ax3] = (rxplaceholder[ax0, ax1, ax2, ax3] - rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05)) * T.rsqrt(rxplaceholder_red_temp_v1[ax0, ax1] * T.float32(0.05) - rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05) * (rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05)) + T.float32(1e-05), dtype="float32") * rxplaceholder_1[ax2, ax3] + rxplaceholder_2[ax2, ax3] + # fmt: on + mod = LegalizeOps()(LayerNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_layer_norm_symbolic(): + # fmt: off + @tvm.script.ir_module + class LayerNorm: + @R.function + def main(x: R.Tensor(("n", "s", "f"), "float32"), gamma: R.Tensor(("s", "f"), "float32"), beta: R.Tensor(("s", "f"), "float32")) -> R.Tensor(("n", "s", "f"), "float32"): + n = T.var("int64") + s = T.var("int64") + f = T.var("int64") + gv: R.Tensor((n, s, f), "float32") = R.nn.layer_norm(x, gamma, beta, axes=[1, 2]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("n", "s", "f"), "float32"), gamma: R.Tensor(("s", "f"), "float32"), beta: R.Tensor(("s", "f"), "float32")) -> R.Tensor(("n", "s", "f"), "float32"): + n = T.var("int64") + s = T.var("int64") + f = T.var("int64") + gv = R.call_tir(layer_norm, (x, gamma, beta), R.Tensor((n, s, f), dtype="float32")) + return gv + + @T.prim_func + def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_layer_norm: T.handle): + T.func_attr({"tir.noalias": True}) + f = T.var("int64") + n = T.var("int64") + s = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [n, s, f], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [s, f], dtype="float32") + rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [s, f], dtype="float32") + T_layer_norm = T.match_buffer(var_T_layer_norm, [n, s, f], dtype="float32") + rxplaceholder_red_temp_v0 = T.alloc_buffer([n], dtype="float32") + rxplaceholder_red_temp_v1 = T.alloc_buffer([n], dtype="float32") + for i0, i1, i2 in T.grid(n, s, f): + with T.block("rxplaceholder_red_temp"): + ax0, k1, k2 = T.axis.remap("SRR", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, k1, k2]) + T.writes(rxplaceholder_red_temp_v0[ax0], rxplaceholder_red_temp_v1[ax0]) + with T.init(): + rxplaceholder_red_temp_v0[ax0] = T.float32(0) + rxplaceholder_red_temp_v1[ax0] = T.float32(0) + v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[ax0] + rxplaceholder[ax0, k1, k2] + v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[ax0] + rxplaceholder[ax0, k1, k2] * rxplaceholder[ax0, k1, k2] + rxplaceholder_red_temp_v0[ax0] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[ax0] = v_rxplaceholder_red_temp_v1 + for i0, i1, i2 in T.grid(n, s, f): + with T.block("T_layer_norm"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1, ax2], rxplaceholder_red_temp_v0[ax0], rxplaceholder_red_temp_v1[ax0], rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, ax2]) + T.writes(T_layer_norm[ax0, ax1, ax2]) + T_layer_norm[ax0, ax1, ax2] = (rxplaceholder[ax0, ax1, ax2] - rxplaceholder_red_temp_v0[ax0] / (T.Cast("float32", s) * T.Cast("float32", f))) * T.rsqrt(rxplaceholder_red_temp_v1[ax0] / (T.Cast("float32", s) * T.Cast("float32", f)) - rxplaceholder_red_temp_v0[ax0] / (T.Cast("float32", s) * T.Cast("float32", f)) * (rxplaceholder_red_temp_v0[ax0] / (T.Cast("float32", s) * T.Cast("float32", f))) + T.float32(1e-05), dtype="float32") * rxplaceholder_1[ax1, ax2] + rxplaceholder_2[ax1, ax2] + # fmt: on + mod = LegalizeOps()(LayerNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py new file mode 100644 index 000000000000..4c31077d9c4b --- /dev/null +++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py @@ -0,0 +1,793 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm.relax.transform import LegalizeOps +from tvm.script import relax as R, tir as T +import tvm.testing + + +##################### Search ##################### + + +def test_where(): + # fmt: off + @tvm.script.ir_module + class Where: + @R.function + def main(condition: R.Tensor((3, 2, 1), "bool"), x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32")) -> R.Tensor((3, 2, 3), "float32"): + gv: R.Tensor((3, 2, 3), "float32") = R.where(condition, x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(condition: R.Tensor((3, 2, 1), "bool"), x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32")) -> R.Tensor((3, 2, 3), "float32"): + gv = R.call_tir(where, (condition, x, y), R.Tensor((3, 2, 3), dtype="float32")) + return gv + + @T.prim_func + def where(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(1)), "bool"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3)), "float32"), rxplaceholder_2: T.Buffer((T.int64(2), T.int64(1)), "float32"), T_where: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(3), T.int64(2), T.int64(3)): + with T.block("T_where"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1, T.int64(0)], rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)]) + T.writes(T_where[ax0, ax1, ax2]) + T_where[ax0, ax1, ax2] = T.Select(0 < T.Cast("int32", rxplaceholder[ax0, ax1, T.int64(0)]), rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)]) + # fmt: on + + mod = LegalizeOps()(Where) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_where_symbolic(): + # fmt: off + @tvm.script.ir_module + class Where: + @R.function + def main(condition: R.Tensor(("a", "b", 1), "bool"), x: R.Tensor(("b", "c"), "float32"), y: R.Tensor(("b", 1), "float32")) -> R.Tensor(("a", "b", "c"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + gv: R.Tensor((a, b, c), "float32") = R.where(condition, x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(condition: R.Tensor(("a", "b", 1), "bool"), x: R.Tensor(("b", "c"), "float32"), y: R.Tensor(("b", 1), "float32")) -> R.Tensor(("a", "b", "c"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + gv = R.call_tir(where, (condition, x, y), R.Tensor((a, b, c), dtype="float32")) + return gv + + @T.prim_func + def where(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_where: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, T.int64(1)], dtype="bool") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [b, c], dtype="float32") + rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [b, T.int64(1)], dtype="float32") + T_where = T.match_buffer(var_T_where, [a, b, c], dtype="float32") + for i0, i1, i2 in T.grid(a, b, c): + with T.block("T_where"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1, T.int64(0)], rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)]) + T.writes(T_where[ax0, ax1, ax2]) + T_where[ax0, ax1, ax2] = T.Select(0 < T.Cast("int32", rxplaceholder[ax0, ax1, T.int64(0)]), rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)]) + # fmt: on + + mod = LegalizeOps()(Where) + tvm.ir.assert_structural_equal(mod, Expected) + + +##################### Statistical ##################### + + +def test_max(): + # fmt: off + @tvm.script.ir_module + class Max: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 5), "float32"): + gv: R.Tensor((2, 5), "float32") = R.max(x, axis=[1, 2]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 5), "float32"): + gv = R.call_tir(max, (x,), R.Tensor((2, 5), dtype="float32")) + return gv + + @T.prim_func + def max(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((T.int64(2), T.int64(5)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(5), T.int64(3), T.int64(4)): + with T.block("rxplaceholder_red"): + ax0, ax1, k1, k2 = T.axis.remap("SSRR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, k1, k2, ax1]) + T.writes(rxplaceholder_red[ax0, ax1]) + with T.init(): + rxplaceholder_red[ax0, ax1] = T.min_value("float32") + rxplaceholder_red[ax0, ax1] = T.max(rxplaceholder_red[ax0, ax1], rxplaceholder[ax0, k1, k2, ax1]) + # fmt: on + + mod = LegalizeOps()(Max) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_max_symbolic(): + # fmt: off + @tvm.script.ir_module + class Max: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("a", "d"), "float32"): + a = T.var("int64") + d = T.var("int64") + gv: R.Tensor((a, d), "float32") = R.max(x, axis=[1, 2]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("a", "d"), "float32"): + a = T.var("int64") + d = T.var("int64") + gv = R.call_tir(max, (x,), R.Tensor((a, d), dtype="float32")) + return gv + + @T.prim_func + def max(var_rxplaceholder: T.handle, var_rxplaceholder_red: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") + rxplaceholder_red = T.match_buffer(var_rxplaceholder_red, [a, d], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, d, b, c): + with T.block("rxplaceholder_red"): + ax0, ax1, k1, k2 = T.axis.remap("SSRR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, k1, k2, ax1]) + T.writes(rxplaceholder_red[ax0, ax1]) + with T.init(): + rxplaceholder_red[ax0, ax1] = T.min_value("float32") + rxplaceholder_red[ax0, ax1] = T.max(rxplaceholder_red[ax0, ax1], rxplaceholder[ax0, k1, k2, ax1]) + # fmt: on + + mod = LegalizeOps()(Max) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_min(): + # fmt: off + @tvm.script.ir_module + class Min: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 1, 1, 5), "float32"): + gv: R.Tensor((2, 1, 1, 5), "float32") = R.min(x, axis=[1, 2], keepdims=True) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 1, 1, 5), "float32"): + gv = R.call_tir(min, (x,), R.Tensor((2, 1, 1, 5), dtype="float32")) + return gv + + @T.prim_func + def min(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((T.int64(2), T.int64(1), T.int64(1), T.int64(5)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(2), T.int64(1), T.int64(1), T.int64(5), T.int64(3), T.int64(4)): + with T.block("rxplaceholder_red"): + ax0, ax1, ax2, ax3, k1, k2 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[ax0, k1, k2, ax3]) + T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3]) + with T.init(): + rxplaceholder_red[ax0, ax1, ax2, ax3] = T.max_value("float32") + rxplaceholder_red[ax0, ax1, ax2, ax3] = T.min(rxplaceholder_red[ax0, ax1, ax2, ax3], rxplaceholder[ax0, k1, k2, ax3]) + # fmt: on + + mod = LegalizeOps()(Min) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_min_symbolic(): + # fmt: off + @tvm.script.ir_module + class Min: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("a", 1, 1, "d"), "float32"): + a = T.var("int64") + d = T.var("int64") + gv: R.Tensor((a, 1, 1, d), "float32") = R.min(x, axis=[1, 2], keepdims=True) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("a", 1, 1, "d"), "float32"): + a = T.var("int64") + d = T.var("int64") + gv = R.call_tir(min, (x,), R.Tensor((a, 1, 1, d), dtype="float32")) + return gv + + @T.prim_func + def min(var_rxplaceholder: T.handle, var_rxplaceholder_red: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") + rxplaceholder_red = T.match_buffer(var_rxplaceholder_red, [a, T.int64(1), T.int64(1), d], dtype="float32") + for i0, i1, i2, i3, i4, i5 in T.grid(a, T.int64(1), T.int64(1), d, b, c): + with T.block("rxplaceholder_red"): + ax0, ax1, ax2, ax3, k1, k2 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[ax0, k1, k2, ax3]) + T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3]) + with T.init(): + rxplaceholder_red[ax0, ax1, ax2, ax3] = T.max_value("float32") + rxplaceholder_red[ax0, ax1, ax2, ax3] = T.min(rxplaceholder_red[ax0, ax1, ax2, ax3], rxplaceholder[ax0, k1, k2, ax3]) + # fmt: on + + mod = LegalizeOps()(Min) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sum(): + # fmt: off + @tvm.script.ir_module + class Sum: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.sum(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((), "float32"): + gv = R.call_tir(sum, (x,), R.Tensor((), dtype="float32")) + return gv + + @T.prim_func + def sum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("rxplaceholder_red"): + k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[k0, k1, k2, k3]) + T.writes(rxplaceholder_red[()]) + with T.init(): + rxplaceholder_red[()] = T.float32(0) + rxplaceholder_red[()] = rxplaceholder_red[()] + rxplaceholder[k0, k1, k2, k3] + # fmt: on + + mod = LegalizeOps()(Sum) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sum_symbolic(): + # fmt: off + @tvm.script.ir_module + class Sum: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.sum(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((), "float32"): + gv = R.call_tir(sum, (x,), R.Tensor((), dtype="float32")) + return gv + + @T.prim_func + def sum(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((), "float32")): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("rxplaceholder_red"): + k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[k0, k1, k2, k3]) + T.writes(rxplaceholder_red[()]) + with T.init(): + rxplaceholder_red[()] = T.float32(0) + rxplaceholder_red[()] = rxplaceholder_red[()] + rxplaceholder[k0, k1, k2, k3] + # fmt: on + + mod = LegalizeOps()(Sum) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_prod(): + # fmt: off + @tvm.script.ir_module + class Prod: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((1, 1, 1, 1), "float32"): + gv: R.Tensor((1, 1, 1, 1), "float32") = R.prod(x, keepdims=True) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((1, 1, 1, 1), "float32"): + gv = R.call_tir(prod, (x,), R.Tensor((1, 1, 1, 1), dtype="float32")) + return gv + + @T.prim_func + def prod(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("rxplaceholder_red"): + ax0, ax1, ax2, ax3, k0, k1, k2, k3 = T.axis.remap("SSSSRRRR", [i0, i1, i2, i3, i4, i5, i6, i7]) + T.reads(rxplaceholder[k0, k1, k2, k3]) + T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3]) + with T.init(): + rxplaceholder_red[ax0, ax1, ax2, ax3] = T.float32(1) + rxplaceholder_red[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] * rxplaceholder[k0, k1, k2, k3] + # fmt: on + + mod = LegalizeOps()(Prod) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_prod_symbolic(): + # fmt: off + @tvm.script.ir_module + class Prod: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((1, 1, 1, 1), "float32"): + gv: R.Tensor((1, 1, 1, 1), "float32") = R.prod(x, keepdims=True) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((1, 1, 1, 1), "float32"): + gv = R.call_tir(prod, (x,), R.Tensor((1, 1, 1, 1), dtype="float32")) + return gv + + @T.prim_func + def prod(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1)), "float32")): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") + for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), a, b, c, d): + with T.block("rxplaceholder_red"): + ax0, ax1, ax2, ax3, k0, k1, k2, k3 = T.axis.remap("SSSSRRRR", [i0, i1, i2, i3, i4, i5, i6, i7]) + T.reads(rxplaceholder[k0, k1, k2, k3]) + T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3]) + with T.init(): + rxplaceholder_red[ax0, ax1, ax2, ax3] = T.float32(1) + rxplaceholder_red[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] * rxplaceholder[k0, k1, k2, k3] + # fmt: on + + mod = LegalizeOps()(Prod) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_mean(): + # fmt: off + @tvm.script.ir_module + class Mean: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((3, 4), "float32"): + gv: R.Tensor((3, 4), "float32") = R.mean(x, [0, 3]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((3, 4), "float32"): + gv = R.call_tir(mean, (x,), R.Tensor((3, 4), dtype="float32")) + return gv + + @T.prim_func + def mean(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), T_divide: T.Buffer((T.int64(3), T.int64(4)), "float32")): + T.func_attr({"tir.noalias": True}) + rxplaceholder_red = T.alloc_buffer([T.int64(3), T.int64(4)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(3), T.int64(4), T.int64(2), T.int64(5)): + with T.block("rxplaceholder_red"): + ax0, ax1, k0, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[k0, ax0, ax1, k3]) + T.writes(rxplaceholder_red[ax0, ax1]) + with T.init(): + rxplaceholder_red[ax0, ax1] = T.float32(0) + rxplaceholder_red[ax0, ax1] = rxplaceholder_red[ax0, ax1] + rxplaceholder[k0, ax0, ax1, k3] + for i0, i1 in T.grid(T.int64(3), T.int64(4)): + with T.block("T_divide"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder_red[ax0, ax1]) + T.writes(T_divide[ax0, ax1]) + T_divide[ax0, ax1] = rxplaceholder_red[ax0, ax1] * T.float32(0.1) + # fmt: on + + mod = LegalizeOps()(Mean) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_mean_symbolic(): + # fmt: off + @tvm.script.ir_module + class Mean: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("b", "c"), "float32"): + b = T.var("int64") + c = T.var("int64") + gv: R.Tensor((b, c), "float32") = R.mean(x, [0, 3]) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), dtype="float32")) -> R.Tensor(("b", "c"), dtype="float32"): + b = T.var("int64") + c = T.var("int64") + gv = R.call_tir(mean, (x,), R.Tensor((b, c), dtype="float32")) + return gv + + @T.prim_func + def mean(var_rxplaceholder: T.handle, var_T_divide: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") + T_divide = T.match_buffer(var_T_divide, [b, c], dtype="float32") + rxplaceholder_red = T.alloc_buffer([b, c], dtype="float32") + for i0, i1, i2, i3 in T.grid(b, c, a, d): + with T.block("rxplaceholder_red"): + ax0, ax1, k0, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[k0, ax0, ax1, k3]) + T.writes(rxplaceholder_red[ax0, ax1]) + with T.init(): + rxplaceholder_red[ax0, ax1] = T.float32(0) + rxplaceholder_red[ax0, ax1] = rxplaceholder_red[ax0, ax1] + rxplaceholder[k0, ax0, ax1, k3] + for i0, i1 in T.grid(b, c): + with T.block("T_divide"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder_red[ax0, ax1]) + T.writes(T_divide[ax0, ax1]) + T_divide[ax0, ax1] = rxplaceholder_red[ax0, ax1] / T.Cast("float32", a * d) + # fmt: on + + mod = LegalizeOps()(Mean) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_std(): + # fmt: off + @tvm.script.ir_module + class Std: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.std(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((), "float32"): + gv = R.call_tir(std, (x,), R.Tensor((), dtype="float32")) + return gv + + @T.prim_func + def std(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), compute: T.Buffer((), "float32")): + T.func_attr({"tir.noalias": True}) + rxplaceholder_red = T.alloc_buffer([], dtype="float32") + T_divide = T.alloc_buffer([], dtype="float32") + T_subtract = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(4), T.int64(5)], dtype="float32") + T_multiply = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(4), T.int64(5)], dtype="float32") + T_multiply_red = T.alloc_buffer([], dtype="float32") + T_divide_1 = T.alloc_buffer([], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("rxplaceholder_red"): + k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[k0, k1, k2, k3]) + T.writes(rxplaceholder_red[()]) + with T.init(): + rxplaceholder_red[()] = T.float32(0) + rxplaceholder_red[()] = rxplaceholder_red[()] + rxplaceholder[k0, k1, k2, k3] + with T.block("T_divide"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(rxplaceholder_red[()]) + T.writes(T_divide[()]) + T_divide[()] = rxplaceholder_red[()] * T.float32(0.0083333333333333332) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_divide[()]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1, ax2, ax3] - T_divide[()] + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3]) + T.writes(T_multiply[ax0, ax1, ax2, ax3]) + T_multiply[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] * T_subtract[ax0, ax1, ax2, ax3] + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("T_multiply_red"): + k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3]) + T.reads(T_multiply[k0, k1, k2, k3]) + T.writes(T_multiply_red[()]) + with T.init(): + T_multiply_red[()] = T.float32(0) + T_multiply_red[()] = T_multiply_red[()] + T_multiply[k0, k1, k2, k3] + with T.block("T_divide_1"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(T_multiply_red[()]) + T.writes(T_divide_1[()]) + T_divide_1[()] = T_multiply_red[()] * T.float32(0.0083333333333333332) + with T.block("compute"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(T_divide_1[()]) + T.writes(compute[()]) + compute[()] = T.sqrt(T_divide_1[()]) + # fmt: on + + mod = LegalizeOps()(Std) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_std_symbolic(): + # fmt: off + @tvm.script.ir_module + class Std: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.std(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((), "float32"): + gv = R.call_tir(std, (x,), R.Tensor((), dtype="float32")) + return gv + + @T.prim_func + def std(var_rxplaceholder: T.handle, compute: T.Buffer((), "float32")): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") + rxplaceholder_red = T.alloc_buffer([], dtype="float32") + T_divide = T.alloc_buffer([], dtype="float32") + T_subtract = T.alloc_buffer([a, b, c, d], dtype="float32") + T_multiply = T.alloc_buffer([a, b, c, d], dtype="float32") + T_multiply_red = T.alloc_buffer([], dtype="float32") + T_divide_1 = T.alloc_buffer([], dtype="float32") + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("rxplaceholder_red"): + k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[k0, k1, k2, k3]) + T.writes(rxplaceholder_red[()]) + with T.init(): + rxplaceholder_red[()] = T.float32(0) + rxplaceholder_red[()] = rxplaceholder_red[()] + rxplaceholder[k0, k1, k2, k3] + with T.block("T_divide"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(rxplaceholder_red[()]) + T.writes(T_divide[()]) + T_divide[()] = rxplaceholder_red[()] / T.Cast("float32", a * b * c * d) + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_divide[()]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1, ax2, ax3] - T_divide[()] + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3]) + T.writes(T_multiply[ax0, ax1, ax2, ax3]) + T_multiply[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] * T_subtract[ax0, ax1, ax2, ax3] + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_multiply_red"): + k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3]) + T.reads(T_multiply[k0, k1, k2, k3]) + T.writes(T_multiply_red[()]) + with T.init(): + T_multiply_red[()] = T.float32(0) + T_multiply_red[()] = T_multiply_red[()] + T_multiply[k0, k1, k2, k3] + with T.block("T_divide_1"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(T_multiply_red[()]) + T.writes(T_divide_1[()]) + T_divide_1[()] = T_multiply_red[()] / T.Cast("float32", a * b * c * d) + with T.block("compute"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(T_divide_1[()]) + T.writes(compute[()]) + compute[()] = T.sqrt(T_divide_1[()]) + # fmt: on + + mod = LegalizeOps()(Std) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_variance(): + # fmt: off + @tvm.script.ir_module + class Variance: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((1, 3, 4, 1), "float32"): + gv: R.Tensor((1, 3, 4, 1), "float32") = R.variance(x, [0, 3], keepdims=True) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((1, 3, 4, 1), dtype="float32"): + gv = R.call_tir(variance, (x,), R.Tensor((1, 3, 4, 1), dtype="float32")) + return gv + + @T.prim_func + def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), T_divide: T.Buffer((T.int64(1), T.int64(3), T.int64(4), T.int64(1)), "float32")): + T.func_attr({"tir.noalias": True}) + rxplaceholder_red = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(4), T.int64(1)], dtype="float32") + T_divide_1 = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(4), T.int64(1)], dtype="float32") + T_subtract = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(4), T.int64(5)], dtype="float32") + T_multiply = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(4), T.int64(5)], dtype="float32") + T_multiply_red = T.alloc_buffer([T.int64(1), T.int64(3), T.int64(4), T.int64(1)], dtype="float32") + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(1), T.int64(2), T.int64(5)): + with T.block("rxplaceholder_red"): + ax0, ax1, ax2, ax3, k0, k3 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[k0, ax1, ax2, k3]) + T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3]) + with T.init(): + rxplaceholder_red[ax0, ax1, ax2, ax3] = T.float32(0) + rxplaceholder_red[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] + rxplaceholder[k0, ax1, ax2, k3] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(1)): + with T.block("T_divide"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_red[ax0, ax1, ax2, ax3]) + T.writes(T_divide_1[ax0, ax1, ax2, ax3]) + T_divide_1[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] * T.float32(0.10000000000000001) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_divide_1[T.int64(0), ax1, ax2, T.int64(0)]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1, ax2, ax3] - T_divide_1[T.int64(0), ax1, ax2, T.int64(0)] + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3]) + T.writes(T_multiply[ax0, ax1, ax2, ax3]) + T_multiply[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] * T_subtract[ax0, ax1, ax2, ax3] + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(1), T.int64(2), T.int64(5)): + with T.block("T_multiply_red"): + ax0, ax1, ax2, ax3, k0, k3 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(T_multiply[k0, ax1, ax2, k3]) + T.writes(T_multiply_red[ax0, ax1, ax2, ax3]) + with T.init(): + T_multiply_red[ax0, ax1, ax2, ax3] = T.float32(0) + T_multiply_red[ax0, ax1, ax2, ax3] = T_multiply_red[ax0, ax1, ax2, ax3] + T_multiply[k0, ax1, ax2, k3] + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(1)): + with T.block("T_divide_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_multiply_red[ax0, ax1, ax2, ax3]) + T.writes(T_divide[ax0, ax1, ax2, ax3]) + T_divide[ax0, ax1, ax2, ax3] = T_multiply_red[ax0, ax1, ax2, ax3] * T.float32(0.10000000000000001) + # fmt: on + + mod = LegalizeOps()(Variance) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_variance_symbolic(): + # fmt: off + @tvm.script.ir_module + class Variance: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((1, "b", "c", 1), "float32"): + b = T.var("int64") + c = T.var("int64") + gv: R.Tensor((1, b, c, 1), "float32") = R.variance(x, [0, 3], keepdims=True) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((1, "b", "c", 1), "float32"): + b = T.var("int64") + c = T.var("int64") + gv = R.call_tir(variance, (x,), R.Tensor((1, b, c, 1), dtype="float32")) + return gv + + @T.prim_func + def variance(var_rxplaceholder: T.handle, var_T_divide: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + d = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") + T_divide = T.match_buffer(var_T_divide, [T.int64(1), b, c, T.int64(1)], dtype="float32") + rxplaceholder_red = T.alloc_buffer([T.int64(1), b, c, T.int64(1)], dtype="float32") + T_divide_1 = T.alloc_buffer([T.int64(1), b, c, T.int64(1)], dtype="float32") + T_subtract = T.alloc_buffer([a, b, c, d], dtype="float32") + T_multiply = T.alloc_buffer([a, b, c, d], dtype="float32") + T_multiply_red = T.alloc_buffer([T.int64(1), b, c, T.int64(1)], dtype="float32") + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(1), b, c, T.int64(1), a, d): + with T.block("rxplaceholder_red"): + ax0, ax1, ax2, ax3, k0, k3 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[k0, ax1, ax2, k3]) + T.writes(rxplaceholder_red[ax0, ax1, ax2, ax3]) + with T.init(): + rxplaceholder_red[ax0, ax1, ax2, ax3] = T.float32(0) + rxplaceholder_red[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] + rxplaceholder[k0, ax1, ax2, k3] + for i0, i1, i2, i3 in T.grid(T.int64(1), b, c, T.int64(1)): + with T.block("T_divide"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_red[ax0, ax1, ax2, ax3]) + T.writes(T_divide_1[ax0, ax1, ax2, ax3]) + T_divide_1[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] / T.Cast("float32", a * d) + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_subtract"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_divide_1[T.int64(0), ax1, ax2, T.int64(0)]) + T.writes(T_subtract[ax0, ax1, ax2, ax3]) + T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1, ax2, ax3] - T_divide_1[T.int64(0), ax1, ax2, T.int64(0)] + for i0, i1, i2, i3 in T.grid(a, b, c, d): + with T.block("T_multiply"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_subtract[ax0, ax1, ax2, ax3]) + T.writes(T_multiply[ax0, ax1, ax2, ax3]) + T_multiply[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] * T_subtract[ax0, ax1, ax2, ax3] + for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(1), b, c, T.int64(1), a, d): + with T.block("T_multiply_red"): + ax0, ax1, ax2, ax3, k0, k3 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(T_multiply[k0, ax1, ax2, k3]) + T.writes(T_multiply_red[ax0, ax1, ax2, ax3]) + with T.init(): + T_multiply_red[ax0, ax1, ax2, ax3] = T.float32(0) + T_multiply_red[ax0, ax1, ax2, ax3] = T_multiply_red[ax0, ax1, ax2, ax3] + T_multiply[k0, ax1, ax2, k3] + for i0, i1, i2, i3 in T.grid(T.int64(1), b, c, T.int64(1)): + with T.block("T_divide_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_multiply_red[ax0, ax1, ax2, ax3]) + T.writes(T_divide[ax0, ax1, ax2, ax3]) + T_divide[ax0, ax1, ax2, ax3] = T_multiply_red[ax0, ax1, ax2, ax3] / T.Cast("float32", a * d) + # fmt: on + + mod = LegalizeOps()(Variance) + tvm.ir.assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_unary.py b/tests/python/relax/test_transform_legalize_ops_unary.py new file mode 100644 index 000000000000..12ae366dcc8a --- /dev/null +++ b/tests/python/relax/test_transform_legalize_ops_unary.py @@ -0,0 +1,693 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm.relax.transform import LegalizeOps +from tvm.script import relax as R +from tvm.script import tir as T + + +def test_abs(): + # fmt: off + @tvm.script.ir_module + class Abs: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.abs(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + gv = R.call_tir(tir_abs, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def tir_abs(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32"),): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.fabs(rxplaceholder[v_i0, v_i1], dtype="float32") + # fmt: on + + mod = LegalizeOps()(Abs) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_abs_symbolic(): + # fmt: off + @tvm.script.ir_module + class Abs: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "float32") = R.abs(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(tir_abs, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_abs(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.fabs(rxplaceholder[v_i0, v_i1], dtype="float32") + # fmt: on + + mod = LegalizeOps()(Abs) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_cos(): + # fmt: off + @tvm.script.ir_module + class Cos: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.cos(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(tir_cos, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def tir_cos(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.cos(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Cos) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_cos_symbolic(): + # fmt: off + @tvm.script.ir_module + class Cos: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "float32") = R.cos(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(tir_cos, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_cos(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.cos(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Cos) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_exp(): + # fmt: off + @tvm.script.ir_module + class Exp: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.exp(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + gv = R.call_tir(tir_exp, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def tir_exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32"),): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.exp(rxplaceholder[v_i0, v_i1], dtype="float32") + # fmt: on + + mod = LegalizeOps()(Exp) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_exp_symbolic(): + # fmt: off + @tvm.script.ir_module + class Exp: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "float32") = R.exp(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(tir_exp, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.exp(rxplaceholder[v_i0, v_i1], dtype="float32") + # fmt: on + + mod = LegalizeOps()(Exp) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_log(): + # fmt: off + @tvm.script.ir_module + class Log: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.log(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(tir_log, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def tir_log(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.log(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Log) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_log_symbolic(): + # fmt: off + @tvm.script.ir_module + class Log: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "float32") = R.log(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(tir_log, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_log(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.log(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Log) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_negative(): + # fmt: off + @tvm.script.ir_module + class Negative: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.negative(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(tir_negative, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def tir_negative(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = rxplaceholder[i0_1, i1_1] * T.float32(-1) + # fmt: on + + mod = LegalizeOps()(Negative) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_negative_symbolic(): + # fmt: off + @tvm.script.ir_module + class Negative: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "float32") = R.negative(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(tir_negative, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_negative(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = rxplaceholder[i0_1, i1_1] * T.float32(-1) + # fmt: on + + mod = LegalizeOps()(Negative) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sigmoid(): + # fmt: off + @tvm.script.ir_module + class Sigmoid: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.sigmoid(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(tir_sigmoid, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def tir_sigmoid(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.sigmoid(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Sigmoid) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sigmoid_symbolic(): + # fmt: off + @tvm.script.ir_module + class Sigmoid: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "float32") = R.sigmoid(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(tir_sigmoid, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_sigmoid(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.sigmoid(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Sigmoid) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sin(): + # fmt: off + @tvm.script.ir_module + class Sin: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.sin(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(tir_sin, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def tir_sin(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.sin(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Sin) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sin_symbolic(): + # fmt: off + @tvm.script.ir_module + class Sin: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "float32") = R.sin(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(tir_sin, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_sin(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.sin(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Sin) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sqrt(): + # fmt: off + @tvm.script.ir_module + class Sqrt: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.sqrt(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(tir_sqrt, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def tir_sqrt(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.sqrt(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Sqrt) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_sqrt_symbolic(): + # fmt: off + @tvm.script.ir_module + class Sqrt: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "float32") = R.sqrt(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(tir_sqrt, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_sqrt(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.sqrt(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Sqrt) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_tanh(): + # fmt: off + @tvm.script.ir_module + class Tanh: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.tanh(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv = R.call_tir(tir_tanh, (x,), R.Tensor((2, 3), dtype="float32")) + return gv + + @T.prim_func + def tir_tanh(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.tanh(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Tanh) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_tanh_symbolic(): + # fmt: off + @tvm.script.ir_module + class Tanh: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "float32") = R.tanh(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(tir_tanh, (x,), R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_tanh(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.tanh(rxplaceholder[i0_1, i1_1]) + # fmt: on + + mod = LegalizeOps()(Tanh) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_clip_symbolic(): + @tvm.script.ir_module + class Clip: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + m = T.var("int64") + n = T.var("int64") + gv: R.Tensor((m, n), "float32") = R.clip(x, 5, 8) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): + m = T.var("int64") + n = T.var("int64") + gv = R.call_tir(tir_clip, (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) + return gv + + @T.prim_func + def tir_clip(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") + compute = T.match_buffer(var_compute, [m, n], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + compute[v_i0, v_i1] = T.max( + T.min(rxplaceholder[v_i0, v_i1], T.float32(8)), T.float32(5) + ) + + mod = LegalizeOps()(Clip) + tvm.ir.assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() From b2e46d010a0cc3b68211e50780226cdbd38a70ad Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Sat, 18 Feb 2023 21:14:34 +0800 Subject: [PATCH 36/81] [Unity][Op] Add ShapeExpr Tests for Reshape Op (#14035) This PR specially checks the relax.reshape operator when the input is a ShapeExpr. --- tests/python/relax/test_op_manipulate.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index 92d4bb26760a..6c7727b7d502 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -54,6 +54,7 @@ def test_reshape_infer_struct_info(): s0 = relax.Var("s", R.Shape((3, 8, 5))) s1 = relax.Var("s", R.Shape(ndim=3)) s2 = relax.Var("s", R.Shape()) + s3 = relax.ShapeExpr((3, 8, 5)) _check_inference( bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") @@ -98,6 +99,12 @@ def test_reshape_infer_struct_info(): _check_inference(bb, relax.op.reshape(x3, s2), relax.TensorStructInfo(s2, dtype="")) _check_inference(bb, relax.op.reshape(x4, s2), relax.TensorStructInfo(s2, dtype="")) _check_inference(bb, relax.op.reshape(x5, s2), relax.TensorStructInfo(s2, dtype="")) + _check_inference(bb, relax.op.reshape(x0, s3), relax.TensorStructInfo(s3, "float32")) + _check_inference(bb, relax.op.reshape(x1, s3), relax.TensorStructInfo(s3, "float32")) + _check_inference(bb, relax.op.reshape(x2, s3), relax.TensorStructInfo(s3, "float32")) + _check_inference(bb, relax.op.reshape(x3, s3), relax.TensorStructInfo(s3, dtype="")) + _check_inference(bb, relax.op.reshape(x4, s3), relax.TensorStructInfo(s3, dtype="")) + _check_inference(bb, relax.op.reshape(x5, s3), relax.TensorStructInfo(s3, dtype="")) def test_reshape_infer_struct_info_shape_symbolic(): @@ -109,6 +116,7 @@ def test_reshape_infer_struct_info_shape_symbolic(): x = relax.Var("x", R.Tensor((a, b, c, d), "float32")) s0 = relax.Var("s", R.Shape((c, a, d, b))) s1 = relax.Var("s", R.Shape()) + s2 = relax.ShapeExpr((c, a, d, b)) _check_inference( bb, relax.op.reshape(x, (c, a, d, b)), relax.TensorStructInfo((c, a, d, b), "float32") @@ -147,6 +155,7 @@ def test_reshape_infer_struct_info_shape_symbolic(): ) _check_inference(bb, relax.op.reshape(x, s0), relax.TensorStructInfo(s0, "float32")) _check_inference(bb, relax.op.reshape(x, s1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.reshape(x, s2), relax.TensorStructInfo(s2, "float32")) def test_reshape_infer_struct_info_shape_var(): From f8ad7845edc1297c4a43534623fe8edf41b77301 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 18 Feb 2023 11:19:48 -0500 Subject: [PATCH 37/81] [Unity] Initial PyTorch Frontend (#14037) [Unity] Initial PyTorch Frontend This PR introduces initial pytorch frontend components of Relax, including - a FX translator that translates a Torch FX graph module to an TVM IRModule, - a Relax-backend of Torch Dynamo, which brings the mechanism to build PyTorch model using Relax compilation pipeline, - a pipeline prototype that contains the collection of pre-defined pipelines that optimizes and lower IRModule before passing to minimum build. Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Tianqi Chen Co-authored-by: Siyuan Feng --- python/tvm/relax/__init__.py | 3 + python/tvm/relax/frontend/__init__.py | 19 + python/tvm/relax/frontend/torch/__init__.py | 21 + python/tvm/relax/frontend/torch/dynamo.py | 156 ++ .../tvm/relax/frontend/torch/fx_translator.py | 820 ++++++++ python/tvm/relax/pipeline.py | 84 + tests/python/relax/test_frontend_dynamo.py | 198 ++ tests/python/relax/test_frontend_from_fx.py | 1729 +++++++++++++++++ tests/python/relax/test_pipeline.py | 45 + 9 files changed, 3075 insertions(+) create mode 100644 python/tvm/relax/frontend/__init__.py create mode 100644 python/tvm/relax/frontend/torch/__init__.py create mode 100644 python/tvm/relax/frontend/torch/dynamo.py create mode 100644 python/tvm/relax/frontend/torch/fx_translator.py create mode 100644 python/tvm/relax/pipeline.py create mode 100644 tests/python/relax/test_frontend_dynamo.py create mode 100644 tests/python/relax/test_frontend_from_fx.py create mode 100644 tests/python/relax/test_pipeline.py diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index cfcf7876dc9f..33a9c2eece21 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -73,6 +73,9 @@ FuncStructInfo, ) +# pipeline +from .pipeline import get_pipeline + # Import submodules in the last to avoid dependency from . import exec_builder from . import expr diff --git a/python/tvm/relax/frontend/__init__.py b/python/tvm/relax/frontend/__init__.py new file mode 100644 index 000000000000..6c9c188aaad0 --- /dev/null +++ b/python/tvm/relax/frontend/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Frontends for constructing Relax programs, with the model importers +""" diff --git a/python/tvm/relax/frontend/torch/__init__.py b/python/tvm/relax/frontend/torch/__init__.py new file mode 100644 index 000000000000..55da5a456d6a --- /dev/null +++ b/python/tvm/relax/frontend/torch/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +PyTorch Frontends for constructing Relax programs, with the model importers +""" +from .fx_translator import from_fx +from .dynamo import relax_dynamo, dynamo_capture_subgraphs diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py new file mode 100644 index 000000000000..94de73a43115 --- /dev/null +++ b/python/tvm/relax/frontend/torch/dynamo.py @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name, missing-function-docstring, not-callable +# pylint: disable=import-outside-toplevel, unused-argument +# mypy: ignore-errors +"""PyTorch Dynamo backend of Relax.""" +import functools +from typing import Optional + +import tvm +from tvm.relax.vm import build as relax_build +from tvm.relax.frontend.torch.fx_translator import from_fx + + +def device_from_inputs(example_inputs): + for x in example_inputs: + if hasattr(x, "device"): + return x.device + return None + + +def relax_dynamo(pipeline: Optional[tvm.transform.Pass] = None): + """A helper function to create a relax backend. + + Parameters + ---------- + pipeline : Optional[tvm.transform.Pass] + The pipeline to be applied to the relax module before sent to build. + + Returns + ------- + backend : Callable[[torch.fx.GraphModule, List[torch.Tensor]], Callable] + The relax dynamo backend. + """ + + def _relax_backend(graph_module, example_inputs): + import torch # type: ignore[import] + + assert isinstance(graph_module, torch.fx.GraphModule) + + def to_torch_tensor(nd_tensor): + """A helper function to transfer a NDArray to torch.tensor.""" + if isinstance(nd_tensor, tvm.nd.NDArray): + return torch.from_numpy(nd_tensor.numpy()) + elif isinstance(nd_tensor, tvm.ir.Array): + return tuple(to_torch_tensor(x) for x in nd_tensor) + else: + raise ValueError(f"Unsupported type {type(nd_tensor)}") + + def to_tvm_tensor(torch_tensor): + """A helper function to transfer a torch.tensor to NDArray.""" + if not isinstance(torch_tensor, torch._subclasses.fake_tensor.FakeTensor): + return tvm.nd.array(torch_tensor.numpy()) + # Fake Tensor + real_tensor = torch.randn(torch_tensor.shape, dtype=torch_tensor.dtype) + return tvm.nd.array(real_tensor.numpy()) + + device = device_from_inputs(example_inputs) + input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in example_inputs] + mod = from_fx(graph_module, input_info) + + if device.type == "cuda": + dev = tvm.cuda(device.index) + target = tvm.target.cuda() + else: + dev = tvm.cpu(0) + target = tvm.target.Target(llvm_target()) + + # invoke optimization pipeline. + if pipeline is None: + # get default pipeline + seq = tvm.relax.get_pipeline() + elif isinstance(pipeline, str): + # lookup by name + seq = tvm.relax.get_pipeline(pipeline) + else: + seq = pipeline + + mod = mod.with_attr("target", target) + mod = seq(mod) + + ex = relax_build(mod, target=target) + + vm = tvm.relax.vm.VirtualMachine(exec=ex.mod, device=dev) + + def exec_tvm(*i_args): + args = [a.contiguous() for a in i_args] + vm_args = list() + for arg in args: + if arg.dim() != 0: + if arg.requires_grad: + arg = arg.detach() + vm_args.append(to_tvm_tensor(arg)) + outputs = vm["main"](*vm_args) + return to_torch_tensor(outputs) + + return exec_tvm + + return _relax_backend + + +def dynamo_capture_subgraphs(model, *params) -> tvm.ir.IRModule: + """Capture subgraphs of the PyTorch model using torch.compile into an IRModule. + + Parameters + ---------- + model : torch.nn.Module + The PyTorch model to be captured. + + params : List[torch.Tensor] + The parameters of the PyTorch model. + + Returns + ------- + mod : tvm.ir.IRModule + The IRModule that contains captured subgraphs. + """ + import torch # type: ignore[import] + from torch import fx # type: ignore[import] + from torch import _dynamo as dynamo # type: ignore[import] + + mod = tvm.IRModule() + + def _capture(graph_module: fx.GraphModule, example_inputs): + assert isinstance(graph_module, torch.fx.GraphModule) + input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in example_inputs] + subgraph = from_fx(graph_module, input_info) + mod["subgraph_" + str(len(mod.get_global_vars()))] = subgraph["main"] + return graph_module.forward + + dynamo.reset() + compiled_model = torch.compile(model, backend=_capture) + compiled_model(*params) + return mod + + +@functools.lru_cache(None) +def llvm_target(): + if "avx512" in open("/proc/cpuinfo").read(): + return "llvm -mcpu=skylake-avx512" + return "llvm -mcpu=core-avx2" diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py new file mode 100644 index 000000000000..582f2edbcf55 --- /dev/null +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -0,0 +1,820 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck +# pylint: disable=import-outside-toplevel +"""PyTorch FX frontend of Relax.""" +from typing import Callable, Dict, List, Tuple, Union +from functools import reduce + +import tvm +from tvm import relax + + +class TorchFXImporter: + """An importer from PyTorch FX to Relax.""" + + import torch # type: ignore + from torch import fx + + def __init__(self) -> None: + import torch # type: ignore + from torch import fx + + self.env: Dict[fx.node.Node, relax.Expr] = {} + self.params: Dict[torch.Tensor, relax.Constant] = {} + self.named_modules: Dict[str, torch.Module] = None + self.block_builder: relax.BlockBuilder = None + self.create_convert_map() + + ########## Utilities ########## + @staticmethod + def _fetch_attr(model, target: str): + import torch # type: ignore + + target_atoms = target.split(".") + attr_itr = model + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced non existing target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + if isinstance(attr_itr, torch.Tensor): + return TorchFXImporter._convert_torch_tensor_to_relax(attr_itr) + return attr_itr + + @staticmethod + def _convert_data_type(input_type): + """converts the PyTorch scalar type input_type to a TVM dtype.""" + import torch # type: ignore + + input_type = input_type.lower() + if input_type in ["float", "float32", "torch.float32", torch.float32]: + return "float32" + elif input_type in ["float16", "torch.float16", torch.float16]: + return "float16" + elif input_type in ["int64", "torch.int64", torch.int64]: + return "int64" + else: + raise NotImplementedError("input_type {} is not handled yet".format(input_type)) + + @staticmethod + def _convert_torch_tensor_to_relax(tensor: torch.Tensor) -> relax.Var: + tensor = tensor.detach().cpu() + shape = tensor.data.shape + dtype = TorchFXImporter._convert_data_type(str(tensor.data.dtype)) + return relax.const(tensor.data.numpy(), relax.TensorStructInfo(shape, dtype)) + + @staticmethod + def shape_of(tensor): + """Get the shape of a tensor.""" + import torch # type: ignore + + if isinstance(tensor, relax.Expr): + if not isinstance(tensor.struct_info, relax.TensorStructInfo): + raise TypeError("The input Expr of shape_of should be a Tensor") + return tensor.struct_info.shape + elif isinstance(tensor, torch.Tensor): + return tensor.shape + raise ValueError("Unsupported type: {}".format(type(tensor))) + + def retrieve_args(self, node): + return self._retrieve_args(node.args) + + def _retrieve_args(self, node): + from torch import fx + + if isinstance(node, fx.node.Node): + return self.env[node] + elif isinstance(node, tuple): + return tuple(self._retrieve_args(x) for x in node) + elif isinstance(node, list): + return [self._retrieve_args(x) for x in node] + elif isinstance(node, dict): + return {self._retrieve_args(k): self._retrieve_args(v) for k, v in node.items()} + else: + return node + + @staticmethod + def _promote_binary_op_args(lhs, rhs): + if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): + return lhs, rhs + elif isinstance(lhs, relax.Expr): + assert isinstance(lhs.struct_info, relax.TensorStructInfo) + return lhs, relax.const(rhs, lhs.struct_info.dtype) + elif isinstance(rhs, relax.Expr): + assert isinstance(rhs.struct_info, relax.TensorStructInfo) + return relax.const(lhs, rhs.struct_info.dtype), rhs + else: + assert False + + def _call_binary_op(self, op, lhs, rhs): + lhs, rhs = TorchFXImporter._promote_binary_op_args(lhs, rhs) + return self.block_builder.emit(op(lhs, rhs)) + + ########## Arithmetic ########## + + def _cos(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.cos(self.env[node.args[0]])) + + def _sin(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.sin(self.env[node.args[0]])) + + def _sqrt(self, node: fx.node.Node) -> relax.Expr: + arg = self.env[node.args[0]] + if isinstance(arg, (int, float)): + arg = relax.const(arg, "float32") + return self.block_builder.emit(relax.op.sqrt(arg)) + + def _add(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.add, lhs, rhs) + return lhs + rhs + + def _floordiv(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.floor_divide, lhs, rhs) + return lhs // rhs + + def _mul(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.multiply, lhs, rhs) + return lhs * rhs + + def _sub(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.subtract, lhs, rhs) + return lhs - rhs + + def _truediv(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.divide, lhs, rhs) + return lhs / rhs + + def _clamp(self, node: fx.node.Node) -> relax.Expr: + args = self.retrieve_args(node) + a_min = node.kwargs["min"] + a_max = node.kwargs["max"] + if not isinstance(a_min, (int, float)): + raise ValueError( + f"TVM only supports constant min value for torch.clamp/clip, " + f"but got {a_min} with type {type(a_min)}" + ) + if not isinstance(a_max, (int, float)): + raise ValueError( + f"TVM only supports constant max value for torch.clamp/clip, " + f"but got {a_max} with type {type(a_max)}" + ) + return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max)) + + ########## Compare ########## + + def _lt(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + return self._call_binary_op(relax.op.less, lhs, rhs) + + ########## Creation ########## + + def _tril(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + k = node.args[1] if len(node.args) > 1 else 0 + assert isinstance(k, int) + return self.block_builder.emit(relax.op.create.tril(x, k)) + + def _new_ones(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + self_var = args[0] + size = args[1:] + if not isinstance(size, (list, tuple)): + size = (size,) + size = relax.ShapeExpr(size) + return self.block_builder.emit( + relax.op.full( + size, + relax.const(1, self_var.struct_info.dtype), + self_var.struct_info.dtype, + ) + ) + + ########## Statistical ########## + + def _sum(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + if len(args) == 1: + return self.block_builder.emit(relax.op.sum(args[0])) + return self.block_builder.emit(relax.op.sum(args[0], args[1])) + + ########## DataType ########## + + def _float(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32")) + + def _half(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16")) + + def _type(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + return self.block_builder.emit(relax.op.astype(args[0], args[1])) + + ########## Linear Algebra ########## + + def _matmul_impl(self, a: relax.Expr, b: relax.Expr): + return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, out_dtype="float32")) + + def _matmul(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + res = self._matmul_impl( + args[0], + args[1], + ) + return res + + def _addmm(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + y = self.env[node.args[1]] + z = self.env[node.args[2]] + matmul = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32")) + return self.block_builder.emit(relax.op.add(x, matmul)) + + ########## Manipulation ########## + + def _cat(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + return self.block_builder.emit(relax.op.concat(args[0], axis=node.kwargs["dim"])) + + def _expand(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + return self.block_builder.emit(relax.op.broadcast_to(args[0], args[1:])) + + def _flatten(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + start_dim = module.start_dim + end_dim = module.end_dim + else: + start_dim = node.args[1] if len(node.args) >= 2 else 0 + end_dim = node.args[2] if len(node.args) == 3 else -1 + shape = self.shape_of(x) + start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim + end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim + flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)]) + new_shape = ( + [shape[i] for i in range(0, start_dim)] + + [flattened] + + [shape[i] for i in range(end_dim + 1, len(shape))] + ) + return self.block_builder.emit(relax.op.reshape(x, new_shape)) + + def _permute(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:])) + + def _reshape(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.reshape(args[0], tuple(args[1]))) + return self.block_builder.emit(relax.op.reshape(args[0], args[1:])) + + def _split(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + split_size = node.args[1] + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + else: + dim = 0 + n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size + return self.block_builder.emit(relax.op.split(x, n_section, dim)) + + def _transpose(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + full_idx = list(range(len(self.shape_of(args[0])))) + full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]] + return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx)) + + ########## Neural Network ########## + + def _linear(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = None if module.bias is None else self.params[module.bias] + return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + + def _conv2d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + + conv2d = self.block_builder.emit( + relax.op.nn.conv2d( + x, + weight, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + ) + + if module.bias is None: + return conv2d + + bias = self.params[module.bias] + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1)) + + return self.block_builder.emit(relax.op.add(conv2d, bias)) + + def _max_pool2d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + kernel = module.kernel_size + stride = module.stride + padding = module.padding + dilation = module.dilation + ceil_mode = module.ceil_mode + else: + nargs = len(node.args) + kernel = node.args[1] if nargs > 1 else node.kwargs["kernel_size"] + stride = node.args[2] if nargs > 2 else node.kwargs["stride"] + padding = node.args[3] if nargs > 3 else node.kwargs["padding"] + dilation = node.args[4] if nargs > 4 else node.kwargs["dilation"] + ceil_mode = node.args[5] if nargs > 5 else node.kwargs["ceil_mode"] + + stride = kernel if stride is None else stride + + return self.block_builder.emit( + relax.op.nn.max_pool2d( + x, + pool_size=kernel, + strides=stride, + padding=padding, + dilation=dilation, + layout="NCHW", + ceil_mode=ceil_mode, + ) + ) + + def _adaptive_avg_pool2d(self, is_module: bool) -> Callable: + from torch import fx + + def _impl(node: fx.node.Node) -> relax.Var: + if is_module: + module = self.named_modules[node.target] + x = self.env[node.args[0]] + output_size = module.output_size + else: + x = self.env[node.args[0]] + output_size = node.args[1] + return self.block_builder.emit( + relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") + ) + + return _impl + + def _softmax(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + dim = module.dim + else: + nargs = len(node.args) + dim = node.args[1] if nargs > 1 else node.kwargs["dim"] + assert dim is not None + return self.block_builder.emit(relax.op.nn.softmax(x, dim)) + + def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params[module.bias] + dtype = self._convert_data_type(str(module.running_mean.dtype)) + running_mean = relax.const(module.running_mean.cpu().detach().numpy(), dtype) + running_var = relax.const(module.running_var.cpu().detach().numpy(), dtype) + eps = module.eps + + res_tuple = self.block_builder.emit( + relax.op.nn.batch_norm( + x, + weight, + bias, + running_mean, + running_var, + axis=1, + epsilon=eps, + ) + ) + + return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0)) + + def _layer_norm(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + + x = self.env[node.args[0]] + module = self.named_modules[node.target] + + if module.elementwise_affine: + gamma = self.params[module.weight] + beta = self.params[module.bias] + else: + gamma = relax.const(torch.ones_like(module.normalized_shape), x.checked_type) + beta = relax.const(torch.zeros_like(module.normalized_shape), x.checked_type) + dim_num = len(module.normalized_shape) + axes = list(range(-dim_num, 0)) + + return self.block_builder.emit( + relax.op.nn.layer_norm( + x, + gamma, + beta, + axes=axes, + epsilon=module.eps, + ) + ) + + def _group_norm(self, node: fx.node.Node) -> relax.Var: + # torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, + # affine=True, device=None, dtype=None) + x = self.env[node.args[0]] + module = self.named_modules[node.target] + num_groups = module.num_groups + num_channels = module.num_channels + eps = module.eps + affine = module.affine + + shape = self.shape_of(x) + assert len(shape) == 4 + N, C, H, W = shape[0], shape[1], shape[2], shape[3] + assert C == num_channels + assert C % num_groups == 0 + grouped_x = self.block_builder.emit( + relax.op.reshape(x, [N, num_groups, C // num_groups, H, W]) + ) + mean_x = self.block_builder.emit(relax.op.mean(grouped_x, [2, 3, 4], keepdims=True)) + sub_x = self.block_builder.emit(relax.op.subtract(grouped_x, mean_x)) + square_x = self.block_builder.emit(relax.op.multiply(sub_x, sub_x)) + sum_square_x = self.block_builder.emit(relax.op.sum(square_x, [2, 3, 4], keepdims=True)) + var_x = self._call_binary_op(relax.op.divide, sum_square_x, (C // num_groups * H * W).value) + var_x_eps = self._call_binary_op(relax.op.add, var_x, eps) + std_x = self.block_builder.emit(relax.op.sqrt(var_x_eps)) + norm_x = self.block_builder.emit(relax.op.divide(sub_x, std_x)) + + if affine: + weight = self.params[module.weight] + bias = self.params[module.bias] + weight_reshape = self.block_builder.emit( + relax.op.reshape(weight, (1, num_groups, C // num_groups, 1, 1)) + ) + bias_reshape = self.block_builder.emit( + relax.op.reshape(bias, (1, num_groups, C // num_groups, 1, 1)) + ) + norm_x = self.block_builder.emit(relax.op.multiply(norm_x, weight_reshape)) + norm_x = self.block_builder.emit(relax.op.add(norm_x, bias_reshape)) + return self.block_builder.emit(relax.op.reshape(norm_x, (N, C, H, W))) + + def _embedding(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + x = self.block_builder.emit(relax.op.astype(x, "int32")) + return self.block_builder.emit(relax.op.take(weight, x, axis=0)) + + def _interpolate(self, node: fx.node.Node) -> relax.Var: + # torch.nn.functional.interpolate( + # input, size=None, scale_factor=None, mode='nearest', align_corners=None, + # recompute_scale_factor=None, antialias=False) + # (TODO) this is a temporary implementation for interpolate that only considers NCHW layout + # it basically replicates the implementation in tvm.relay.frontend.pytorch + data = self.env[node.args[0]] + size = node.kwargs["size"] + scale_factor = node.kwargs["scale_factor"] + method = node.kwargs["mode"] + align_corners = node.kwargs["align_corners"] + recompute_scale_factor = node.kwargs["recompute_scale_factor"] + antialias = node.kwargs["antialias"] + + assert recompute_scale_factor is None + assert antialias is False + + if size is None: + shape = self.shape_of(data) + assert isinstance(shape, relax.ShapeExpr) + size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape))) + + if method.startswith("nearest"): + method = "nearest_neighbor" + elif method[0:2] == "bi": + method = method[2:] + + if method == "nearest_neighbor": + coord_trans = "asymmetric" + elif align_corners: + coord_trans = "align_corners" + else: + coord_trans = "half_pixel" + + return self.block_builder.emit( + relax.op.image.resize2d( + data, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans + ) + ) + + ########## Others ########## + + def _size(self, node: fx.node.Node) -> relax.Expr: + x = self.env[node.args[0]] + shape = self.shape_of(x) + if len(node.args) == 1: + assert isinstance(shape, relax.ShapeExpr) + return shape + assert len(node.args) == 2 + idx = node.args[1] + return self.shape_of(x)[idx].value + + def _getattr(self, node: fx.node.Node) -> relax.Var: + if isinstance(self.env[node.args[0]], relax.Expr): + if node.args[1] == "dtype": + return self.env[node.args[0]].struct_info.dtype + elif node.args[1] == "shape": + return self.shape_of(self.env[node.args[0]]) + return getattr(self.env[node.args[0]], node.args[1]) + + def _getitem(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)): + return x[node.args[1]] + elif isinstance(x, relax.Var): + if isinstance(x.struct_info, relax.TupleStructInfo): + return self.block_builder.emit(relax.TupleGetItem(x, node.args[1])) + + assert isinstance(x.struct_info, relax.TensorStructInfo) + begin = [] + end = [] + stride = [] + axes = [] + expand_dim = [] + i = 0 + shape = self.shape_of(x) + for index in node.args[1]: + if isinstance(index, int): + begin.append(index) + end.append(index + 1) + stride.append(1) + axes.append(i) + i = i + 1 + elif isinstance(index, slice): + begin.append(0 if index.start is None else index.start) + end.append(shape[i] if index.stop is None else index.stop) + stride.append(1 if index.step is None else index.step) + axes.append(i) + i = i + 1 + elif index is None: + expand_dim.append(i) + i = i + 1 + else: + raise ValueError("Unsupported index type: " + str(type(index))) + while i < len(shape): + begin.append(0) + end.append(shape[i]) + axes.append(i) + i = i + 1 + sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride)) + sliced_shape = list(self.shape_of(sliced)) + for i in expand_dim: + sliced_shape.insert(i, 1) + return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape)) + else: + assert False + + def create_convert_map(self): + from torch import nn + from torch import fx + + self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.node.Node], relax.Var]] = { + # call_module + nn.Linear: self._linear, + nn.Conv2d: self._conv2d, + nn.MaxPool2d: self._max_pool2d, + nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True), + nn.Softmax: self._softmax, + nn.ReLU: lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])), + nn.ReLU6: lambda node: self.block_builder.emit( + relax.op.clip(self.env[node.args[0]], 0, 6) + ), + nn.SiLU: lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])), + nn.Flatten: self._flatten, + nn.BatchNorm2d: self._batch_norm_2d, + nn.LayerNorm: self._layer_norm, + nn.GroupNorm: self._group_norm, + nn.Dropout: lambda node: self.env[node.args[0]], + nn.modules.sparse.Embedding: self._embedding, + # call_function and call_method + "cos": self._cos, + "sin": self._sin, + "add": self._add, + "floordiv": self._floordiv, + "mul": self._mul, + "sub": self._sub, + "sqrt": self._sqrt, + "lt": self._lt, + "truediv": self._truediv, + "new_ones": self._new_ones, + "tril": self._tril, + "sum": self._sum, + "float": self._float, + "half": self._half, + "type": self._type, + "matmul": self._matmul, + "addmm": self._addmm, + "cat": self._cat, + "expand": self._expand, + "flatten": self._flatten, + "permute": self._permute, + "reshape": self._reshape, + "split": self._split, + "transpose": self._transpose, + "unsqueeze": lambda node: self.block_builder.emit( + relax.op.expand_dims(self.env[node.args[0]], node.args[1]) + ), + "view": self._reshape, + "softmax": self._softmax, + "clamp": self._clamp, + "relu": lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])), + "gelu": lambda node: self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])), + "interpolate": self._interpolate, + "size": self._size, + "getattr": self._getattr, + "getitem": self._getitem, + "contiguous": lambda node: self.env[node.args[0]], + "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False), + } + + def from_fx(self, model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule: + """Convert a PyTorch FX GraphModule to a Relax program.""" + from torch import fx + + self.named_modules = dict(model.named_modules()) + + graph: fx.Graph = model.graph + # Create input variables. + inputs = list() + for idx, (shape, dtype) in enumerate(input_info): + inputs.append( + relax.Var( + f"inp_{idx}", relax.TensorStructInfo(shape, self._convert_data_type(dtype)) + ) + ) + + # Initialize the block builder with a function and a dataflow block. + self.block_builder = relax.BlockBuilder() + with self.block_builder.function(name="main", params=inputs.copy()): + output = None + with self.block_builder.dataflow(): + # Translate model parameters. + for _, param in model.named_parameters(): + shape = param.data.shape + dtype = self._convert_data_type(str(param.data.dtype)) + if dtype in ("float32", "float16"): + self.params[param] = relax.const( + param.data.cpu().numpy(), relax.TensorStructInfo(shape, dtype) + ) + else: + raise ValueError("Unsupported data type for model parameters: %s" % dtype) + # Translate the model. + for node in graph.nodes: + if node.op == "placeholder": + assert len(inputs) > 0, "Provided inputs is less than actual inputs" + self.env[node] = inputs.pop(0) + elif node.op == "output": + args = self.retrieve_args(node) + output = self.block_builder.emit_output(args[0]) + break + elif node.op == "get_attr": + self.env[node] = TorchFXImporter._fetch_attr(model, node.target) + elif node.op == "call_module": + module = self.named_modules[node.target] + assert ( + type(module) in self.convert_map + ), f"Unsupported module type {type(module)}" + self.env[node] = self.convert_map[type(module)](node) + elif node.op == "call_function": + func_name = node.name.rstrip("0123456789_") + assert ( + func_name in self.convert_map + ), f"Unsupported function type {func_name}" + self.env[node] = self.convert_map[func_name](node) + elif node.op == "call_method": + assert ( + node.target in self.convert_map + ), f"Unsupported function target {node.target}" + self.env[node] = self.convert_map[node.target](node) + else: + raise ValueError(f"Unsupported op {node.op}") + assert output is not None + self.block_builder.emit_func_output(output) + + return self.block_builder.get() + + +def from_fx(model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule: + """Convert a PyTorch FX GraphModule to a Relax program + + Parameters + ---------- + model : fx.GraphModule + The PyTorch FX GraphModule to convert. + + input_info : List[Tuple[Tuple[int], str]] + A list of shapes and data types of input tensors. + + Returns + ------- + module : tvm.IRModule + The converted Relax program. + + Examples + -------- + Users can use the FX tracer or dynamo.export() to extract + a fx.GraphModule from a PyTorch model. The following codes show + how to convert a PyTorch model to a Relax program. + + .. code-block:: python + + # Import the importer. + import numpy as np + import torch + from tvm.relax.frontend.torch_fx import from_fx + from torch import _dynamo as dynamo + + # Define the module + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(in_features=10, out_features=7, bias=True) + + def forward(self, input): + return self.linear(input) + + # Instantiate the model and create the input info dict. + torch_model = MyModule() + input_info = [((128, 10), "float32")] + input_tensors = [ + torch.astensor(np.random.randn(*shape).astype(dtype)) + for shape, dtype in input_info + ] + + # Use FX tracer to trace the PyTorch model. + graph_module = fx.symbolic_trace(torch_model) + + # Use the dynamo.export() to export the PyTorch model to FX. + try: + graph_module = dynamo.export(torch_model, *input_tensors) + except: + raise RuntimeError("Failed to export the PyTorch model to FX.") + + # Use the importer to import the PyTorch model to Relax. + mod: tvm.IRModule = from_fx(graph_module, input_info) + + # Print out the imported model. + print(mod.script()) + + Notes + ----- + For a given PyTorch model, to lookup the names of the model inputs in + FX, one can use + + .. code-block:: python + + fx.symbolic_trace(model).graph.print_tabular() + + to print out the tabular representation of the PyTorch module, and then + check the placeholder rows in the beginning of the tabular. + """ + return TorchFXImporter().from_fx(model, input_info) diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py new file mode 100644 index 000000000000..a5da15b76d3b --- /dev/null +++ b/python/tvm/relax/pipeline.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Pre-defined pipelines. + +oRelax enables flexible pipeline optimizations before min build. +This namespace offers a pre-defined collection that can be used +as it is or serves as a basis to do further composition. +""" +# pylint: disable=unused-argument +import tvm +from tvm import meta_schedule as ms +from . import transform + + +@tvm.transform.module_pass(opt_level=0) +def zero_pipeline(mod: tvm.ir.IRModule, ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: + """Pipeline that applies pre-tuned logs. + + Parameters + ---------- + mod : tvm.ir.IRModule + Input IRModule. + + ctx : tvm.transform.PassContext + The pass context + + Returns + ------- + mod: tvm.ir.IRModule + The result transformed module. + """ + seq = tvm.transform.Sequential( + [ + transform.LegalizeOps(), + transform.AnnotateTIROpPattern(), + transform.FoldConstant(), + transform.FuseOps(), + transform.FuseTIR(), + ] + ) + mod = seq(mod) + if ms.Database.current(): + mod = transform.MetaScheduleApplyDatabase()(mod) + return mod + + +# global map of pre-built pipelines +PIPELINE_MAP = {"zero": zero_pipeline} + + +def get_pipeline(name: str = "zero") -> tvm.transform.Pass: + """Get pre-build pipeline by name + + Parameters + ---------- + name : Optional[str] + Name of the pipeline + + Returns + ------- + pipeline: tvm.transform.Pass + The transformation pipeline. + """ + + if name in PIPELINE_MAP: + return PIPELINE_MAP[name] + else: + raise ValueError( + f"Unknown pre-built pipeline {name}," f"candidates are {list(PIPELINE_MAP.keys())}" + ) diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py new file mode 100644 index 000000000000..370df2103d79 --- /dev/null +++ b/tests/python/relax/test_frontend_dynamo.py @@ -0,0 +1,198 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("torch._dynamo") + + +import tvm +from tvm import relax, meta_schedule as ms, tir +import tvm.testing +import torch +import torch._dynamo as dynamo +from tvm.relax.frontend.torch import relax_dynamo +from tvm.script.parser import relax as R, tir as T + + +def test_relax_dynamo(): + class Input1(torch.nn.Module): + def __init__(self): + super().__init__() + self.lin = torch.nn.Linear(100, 10) + + def forward(self, x): + return torch.nn.functional.relu(self.lin(x)) + + model = Input1() + ### construct the database + @tvm.script.ir_module + class Input1_ir: + @T.prim_func + def main( + inp_0: T.Buffer[(T.int64(10), T.int64(100)), "float32"], + param_0: T.Buffer[(T.int64(100), T.int64(10)), "float32"], + param_1: T.Buffer[T.int64(10), "float32"], + compute: T.Buffer[(T.int64(10), T.int64(10)), "float32"], + ): + # function attr dict + T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + # body + # with T.block("root") + matmul = T.alloc_buffer([T.int64(10), T.int64(10)], dtype="float32") + T_add = T.alloc_buffer([T.int64(10), T.int64(10)], dtype="float32") + for i0, i1, k in T.grid(T.int64(10), T.int64(10), T.int64(100)): + with T.block("matmul"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(inp_0[v_i0, v_k], param_0[v_k, v_i1]) + T.writes(matmul[v_i0, v_i1]) + with T.init(): + matmul[v_i0, v_i1] = T.float32(0) + matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + inp_0[v_i0, v_k] * param_0[v_k, v_i1] + for ax0, ax1 in T.grid(T.int64(10), T.int64(10)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(matmul[v_ax0, v_ax1], param_1[v_ax1]) + T.writes(T_add[v_ax0, v_ax1]) + T_add[v_ax0, v_ax1] = matmul[v_ax0, v_ax1] + param_1[v_ax1] + for i0, i1 in T.grid(T.int64(10), T.int64(10)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_add[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.max(T_add[v_i0, v_i1], T.float32(0)) + + db = ms.Database.create("memory") + workload = db.commit_workload(Input1_ir) + + sch = tir.Schedule(Input1_ir, debug_mask="all") + b0 = sch.get_block(name="matmul", func_name="main") + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + sch.compute_inline(block=b1) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") + l3, l4, l5 = sch.get_loops(block=b0) + v6, v7, v8, v9 = sch.sample_perfect_tile( + loop=l3, n=4, max_innermost_factor=64, decision=[1, 2, 5, 1] + ) + l10, l11, l12, l13 = sch.split(loop=l3, factors=[v6, v7, v8, v9], preserve_unit_iters=True) + v14, v15, v16, v17 = sch.sample_perfect_tile( + loop=l4, n=4, max_innermost_factor=64, decision=[1, 1, 10, 1] + ) + l18, l19, l20, l21 = sch.split(loop=l4, factors=[v14, v15, v16, v17], preserve_unit_iters=True) + v22, v23 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64, decision=[100, 1]) + l24, l25 = sch.split(loop=l5, factors=[v22, v23], preserve_unit_iters=True) + sch.reorder(l10, l18, l11, l19, l24, l12, l20, l25, l13, l21) + (b26,) = sch.get_consumers(block=b0) + sch.reverse_compute_at(block=b26, loop=l18, preserve_unit_loops=True, index=-1) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.parallel", ann_val=96) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.vectorize", ann_val=64) + v27 = sch.sample_categorical( + candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=0 + ) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v27) + + tuning_record = ms.database.TuningRecord(sch.trace, workload, run_secs=[0.0]) + db.commit_tuning_record(tuning_record) + ### Optimize the model with tuned-log + with db: + opt_model = torch.compile(model, backend=relax_dynamo()) + inp = torch.randn(10, 100) + tvm.testing.assert_allclose( + opt_model(inp).detach().numpy(), model(inp).detach().numpy(), rtol=1e-5, atol=1e-5 + ) + + +def test_subgraph_capture(): + import torch + from tvm.relax.frontend.torch.dynamo import dynamo_capture_subgraphs + + class Input1(torch.nn.Module): + def __init__(self): + super().__init__() + self.lin = torch.nn.Linear(100, 10) + + def forward(self, x): + return torch.nn.functional.relu(self.lin(x)) + + @tvm.script.ir_module + class Expected1: + @R.function + def subgraph_0( + inp_0: R.Tensor((10, 100), dtype="float32"), + w0: R.Tensor((10, 100), dtype="float32"), + w1: R.Tensor((10,), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, axes=None) + lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(inp_0, lv, out_dtype="float32") + lv2: R.Tensor((10, 10), dtype="float32") = R.add(lv1, w1) + lv3: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv2) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + model = Input1() + mod = dynamo_capture_subgraphs(model, torch.randn(10, 100)) + binding = {"w0": model.lin.weight.detach().numpy(), "w1": model.lin.bias.detach().numpy()} + binding = {k: tvm.nd.array(v) for k, v in binding.items()} + expected = relax.transform.BindParams("subgraph_0", binding)(Expected1) + tvm.ir.assert_structural_equal(mod, expected) + + def Input2(a, b): + x = a / (torch.sin(a) + 1) + if torch.sum(b) < 1: + b = b * -1 + return x * b + + @tvm.script.ir_module + class Expected2: + @R.function + def subgraph_0( + inp_0: R.Tensor((10,), dtype="float32"), inp_1: R.Tensor((10,), dtype="float32") + ) -> R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((), dtype="bool")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10,), dtype="float32") = R.sin(inp_0) + lv1: R.Tensor((10,), dtype="float32") = R.add(lv, R.const(1, "float32")) + lv2: R.Tensor((10,), dtype="float32") = R.divide(inp_0, lv1) + lv3: R.Tensor((), dtype="float32") = R.sum(inp_1, axis=None, keepdims=False) + lv4: R.Tensor((), dtype="bool") = R.less(lv3, R.const(1, "float32")) + gv: R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((), dtype="bool")) = ( + lv2, + lv4, + ) + R.output(gv) + return gv + + @R.function + def subgraph_1( + inp_01: R.Tensor((10,), dtype="float32"), inp_11: R.Tensor((10,), dtype="float32") + ) -> R.Tuple(R.Tensor((10,), dtype="float32")): + # block 0 + with R.dataflow(): + lv5: R.Tensor((10,), dtype="float32") = R.multiply(inp_11, inp_01) + gv1: R.Tuple(R.Tensor((10,), dtype="float32")) = (lv5,) + R.output(gv1) + return gv1 + + mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10)) + tvm.ir.assert_structural_equal(mod, Expected2) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py new file mode 100644 index 000000000000..9b35d34bd370 --- /dev/null +++ b/tests/python/relax/test_frontend_from_fx.py @@ -0,0 +1,1729 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm import relax +import tvm.testing +from tvm.script.parser import relax as R, tir as T + + +def verify_model(torch_model, input_info, binding, expected): + from torch import fx + from tvm.relax.frontend.torch import from_fx + + graph_model = fx.symbolic_trace(torch_model) + mod = from_fx(graph_model, input_info) + binding = {k: tvm.nd.array(v) for k, v in binding.items()} + expected = relax.transform.BindParams("main", binding)(expected) + tvm.ir.assert_structural_equal(mod, expected) + + +@tvm.testing.requires_gpu +def test_conv(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + class Conv2D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + ) -> R.Tensor((1, 6, 4, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1, 1)) = R.reshape(w2, [1, 6, 1, 1]) + lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2) + gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv3 + R.output(gv) + return gv + + class Conv2D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7), dtype="float32"), + ) -> R.Tensor((1, 6, 4, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv1 + R.output(gv) + return gv + + input_info = [([1, 3, 10, 10], "float32")] + + model = Conv2D1() + binding = {"w1": model.conv.weight.numpy(), "w2": model.conv.bias.numpy()} + verify_model(model, input_info, binding, expected1) + + model = Conv2D2() + binding = {"w1": model.conv.weight.numpy()} + verify_model(model, input_info, binding, expected2) + + +@tvm.testing.requires_gpu +def test_linear(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + # nn.Linear + class Dense1(Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 7, bias=True) + + def forward(self, input): + return self.linear(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((7, 10), dtype="float32"), + w2: R.Tensor((1, 7), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 7), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None) + lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul( + input_1, lv, out_dtype="float32" + ) + lv2: R.Tensor((1, 3, 10, 7), dtype="float32") = R.add(lv1, w2) + gv: R.Tensor((1, 3, 10, 7), dtype="float32") = lv2 + R.output(gv) + return gv + + class Dense2(Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 7, bias=False) + + def forward(self, input): + return self.linear(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((7, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 7), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None) + lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul( + input_1, lv, out_dtype="float32" + ) + gv: R.Tensor((1, 3, 10, 7), dtype="float32") = lv1 + R.output(gv) + return gv + + input_info = [([1, 3, 10, 10], "float32")] + + model = Dense1() + binding = {"w1": model.linear.weight.numpy(), "w2": model.linear.bias.numpy()} + verify_model(model, input_info, binding, expected1) + + model = Dense2() + binding = {"w1": model.linear.weight.numpy()} + verify_model(model, input_info, binding, expected2) + + # matmul + class MatMul1(Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32"), + input_2: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul( + input_1, input_2, out_dtype="float32" + ) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model( + MatMul1(), + [([10, 10], "float32"), ([10, 10], "float32")], + {}, + expected3, + ) + + +@tvm.testing.requires_gpu +def test_relu(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + + class ReLU0(Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, input): + return self.relu(input) + + class ReLU1(Module): + def forward(self, input): + return torch.nn.functional.relu(input) + + @tvm.script.ir_module + class expected: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.nn.relu(input_1) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + input_info = [([10, 10], "float32")] + verify_model(ReLU0(), input_info, {}, expected) + verify_model(ReLU1(), input_info, {}, expected) + + +@tvm.testing.requires_gpu +def test_relu6(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + + class ReLU6(Module): + def __init__(self): + super().__init__() + self.relu6 = torch.nn.ReLU6() + + def forward(self, input): + return self.relu6(input) + + @tvm.script.ir_module + class expected: + @R.function + def main(input: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.clip(input, 0, 6) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + input_info = [([10, 10], "float32")] + verify_model(ReLU6(), input_info, {}, expected) + + +@tvm.testing.requires_gpu +def test_maxpool2d(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class MaxPool2d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.max_pool2d( + input_1, + pool_size=[1, 1], + strides=[1, 1], + dilation=[1, 1], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class MaxPool2d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 4, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 4, 4), dtype="float32") = R.nn.max_pool2d( + input_1, + pool_size=[2, 2], + strides=[2, 2], + dilation=[2, 3], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tensor((1, 3, 4, 4), dtype="float32") = lv + R.output(gv) + return gv + + class MaxPool2d3(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 6, 6), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 6, 6), dtype="float32") = R.nn.max_pool2d( + input_1, + pool_size=[4, 4], + strides=[2, 2], + dilation=[1, 1], + padding=[2, 2, 2, 2], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tensor((1, 3, 6, 6), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(MaxPool2d(), input_info, {}, expected1) + verify_model(MaxPool2d2(), input_info, {}, expected2) + verify_model(MaxPool2d3(), input_info, {}, expected3) + + +@tvm.testing.requires_gpu +def test_adaptive_avgpool2d(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class AdaptiveAvgPool2d0(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d([10, 10]) + + def forward(self, input): + return self.pool(input) + + class AdaptiveAvgPool2d1(Module): + def forward(self, input): + return torch.nn.functional.adaptive_avg_pool2d(input, [10, 10]) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.adaptive_avg_pool2d( + input_1, output_size=[10, 10], layout="NCHW", out_layout="NCHW" + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(AdaptiveAvgPool2d0(), input_info, {}, expected1) + verify_model(AdaptiveAvgPool2d1(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_flatten(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Flatten(Module): + def __init__(self): + super().__init__() + self.f = torch.nn.Flatten(2, -1) + + def forward(self, input): + return self.f(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 100), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 100), dtype="float32") = R.reshape(input_1, (1, 3, 100)) + gv: R.Tensor((1, 3, 100), dtype="float32") = lv + R.output(gv) + return gv + + # call_module + verify_model(Flatten(), input_info, {}, expected1) + # call_method + verify_model(torch.nn.Flatten(2, -1), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_batchnorm2d(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class BatchNorm2d(Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, input): + return self.bn(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + w3: R.Tensor((3,), dtype="float32"), + w4: R.Tensor((3,), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((3,), dtype="float32"), + R.Tensor((3,), dtype="float32"), + ) = R.nn.batch_norm( + input_1, + w1, + w2, + w3, + w4, + axis=1, + epsilon=1e-05, + center=True, + scale=True, + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv1 + R.output(gv) + return gv + + model = BatchNorm2d() + binding = { + "w1": model.bn.weight.numpy(), + "w2": model.bn.bias.numpy(), + "w3": model.bn.running_mean.numpy(), + "w4": model.bn.running_var.numpy(), + } + verify_model(BatchNorm2d(), input_info, binding, expected1) + + +@tvm.testing.requires_gpu +def test_embedding(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([4], "int64")] + + class Embedding(Module): + def __init__(self): + super().__init__() + self.embedding = torch.nn.Embedding(10, 3) + + def forward(self, input): + return self.embedding(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((4,), dtype="int64"), w1: R.Tensor((10, 3), dtype="float32") + ) -> R.Tensor((4, 3), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((4,), dtype="int32") = R.astype(input_1, dtype="int32") + lv1: R.Tensor((4, 3), dtype="float32") = R.take(w1, lv, axis=0) + gv: R.Tensor((4, 3), dtype="float32") = lv1 + R.output(gv) + return gv + + model = Embedding() + binding = {"w1": model.embedding.weight.numpy()} + verify_model(model, input_info, binding, expected1) + + +@tvm.testing.requires_gpu +def test_dropout(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Dropout(Module): + def __init__(self): + super().__init__() + self.dropout = torch.nn.Dropout(0.5) + + def forward(self, input): + return self.dropout(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = input_1 + R.output(gv) + return gv + + verify_model(Dropout(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_layernorm(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class LayerNorm(Module): + def __init__(self): + super().__init__() + self.ln = torch.nn.LayerNorm((10, 10)) + + def forward(self, input): + return self.ln(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((10, 10), dtype="float32"), + w2: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm( + input_1, + w1, + w2, + axes=[-2, -1], + epsilon=1e-05, + center=True, + scale=True, + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + model = LayerNorm() + binding = { + "w1": model.ln.weight.numpy(), + "w2": model.ln.bias.numpy(), + } + verify_model(LayerNorm(), input_info, binding, expected1) + + +@tvm.testing.requires_gpu +def test_silu(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class SiLU(Module): + def __init__(self): + super().__init__() + self.silu = torch.nn.SiLU() + + def forward(self, input): + return self.silu(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.silu(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(SiLU(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_groupnorm(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class GroupNorm(Module): + def __init__(self): + super().__init__() + self.gn = torch.nn.GroupNorm(3, 3) + + def forward(self, input): + return self.gn(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.reshape( + input_1, (1, 3, 1, 10, 10) + ) + lv1: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.mean( + lv, axis=[2, 3, 4], keepdims=True + ) + lv2: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.subtract(lv, lv1) + lv3: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv2, lv2) + lv4: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sum( + lv3, axis=[2, 3, 4], keepdims=True + ) + lv5: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.divide(lv4, R.const(100.0)) + lv6: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.add(lv5, R.const(1e-05)) + lv7: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sqrt(lv6) + lv8: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.divide(lv2, lv7) + lv9: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w1, (1, 3, 1, 1, 1)) + lv10: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w2, (1, 3, 1, 1, 1)) + lv11: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv8, lv9) + lv12: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.add(lv11, lv10) + lv13: R.Tensor((1, 3, 10, 10), dtype="float32") = R.reshape(lv12, (1, 3, 10, 10)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv13 + R.output(gv) + return gv + + model = GroupNorm() + binding = { + "w1": model.gn.weight.numpy(), + "w2": model.gn.bias.numpy(), + } + verify_model(model, input_info, binding, expected1) + + +@tvm.testing.requires_gpu +def test_softmax(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Softmax(Module): + def __init__(self): + super().__init__() + self.sm = torch.nn.Softmax(dim=1) + + def forward(self, input): + return self.sm(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softmax(input_1, axis=1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Softmax(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_binary(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")] + input_info2 = [([1, 3, 10, 10], "float32")] + # Add + class Add1(Module): + def forward(self, lhs, rhs): + return lhs + rhs + + @tvm.script.ir_module + class expected1: + @R.function + def main( + lhs: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lhs, rhs) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class Add2(Module): + def forward(self, lhs): + return lhs + 1.0 + + @tvm.script.ir_module + class expected2: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Add1(), input_info1, {}, expected1) + verify_model(Add2(), input_info2, {}, expected2) + + # Sub + class Sub1(Module): + def forward(self, lhs, rhs): + return lhs - rhs + + @tvm.script.ir_module + class expected3: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class Sub2(Module): + def forward(self, lhs): + return lhs - 1.0 + + @tvm.script.ir_module + class expected4: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sub1(), input_info1, {}, expected3) + verify_model(Sub2(), input_info2, {}, expected4) + + # Mul + class Mul1(Module): + def forward(self, lhs, rhs): + return lhs * rhs + + @tvm.script.ir_module + class expected5: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class Mul2(Module): + def forward(self, lhs): + return lhs * 1.0 + + @tvm.script.ir_module + class expected6: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Mul1(), input_info1, {}, expected5) + verify_model(Mul2(), input_info2, {}, expected6) + + # True div + class TrueDiv1(Module): + def forward(self, lhs, rhs): + return lhs / rhs + + @tvm.script.ir_module + class expected7: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class TrueDiv2(Module): + def forward(self, lhs): + return lhs / 1.0 + + @tvm.script.ir_module + class expected8: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(TrueDiv1(), input_info1, {}, expected7) + verify_model(TrueDiv2(), input_info2, {}, expected8) + + # Floor div + class FloorDiv1(Module): + def forward(self, lhs, rhs): + return lhs // rhs + + @tvm.script.ir_module + class expected9: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.floor_divide(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class FloorDiv2(Module): + def forward(self, lhs): + return lhs // 1.0 + + @tvm.script.ir_module + class expected10: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.floor_divide(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(FloorDiv1(), input_info1, {}, expected9) + verify_model(FloorDiv2(), input_info2, {}, expected10) + + # LT + class LT1(Module): + def forward(self, lhs, rhs): + return lhs < rhs + + @tvm.script.ir_module + class expected11: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="bool"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv + R.output(gv) + return gv + + class LT2(Module): + def forward(self, lhs): + return lhs < 1.0 + + @tvm.script.ir_module + class expected12: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="bool"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv + R.output(gv) + return gv + + verify_model(LT1(), input_info1, {}, expected11) + verify_model(LT2(), input_info2, {}, expected12) + + +@tvm.testing.requires_gpu +def test_size(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Size(Module): + def forward(self, input): + return input.size() + + @tvm.script.ir_module + class expected1: + @R.function + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Shape([1, 3, 10, 10]): + # block 0 + with R.dataflow(): + gv: R.Shape([1, 3, 10, 10]) = R.shape([1, 3, 10, 10]) + R.output(gv) + return gv + + verify_model(Size(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_unsqueeze(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Unsqueeze1(Module): + def forward(self, input): + return input.unsqueeze(1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = R.expand_dims(input_1, 1) + gv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class Unsqueeze2(Module): + def forward(self, input): + return input.unsqueeze(-1) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10, 1), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10, 1), dtype="float32") = R.expand_dims(input_1, -1) + gv: R.Tensor((1, 3, 10, 10, 1), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Unsqueeze1(), input_info, {}, expected1) + verify_model(Unsqueeze2(), input_info, {}, expected2) + + +@tvm.testing.requires_gpu +def test_getattr(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class GetAttr1(Module): + def forward(self, input): + return input.shape + + @tvm.script.ir_module + class expected1: + @R.function + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Shape([1, 3, 10, 10]): + # block 0 + with R.dataflow(): + gv: R.Shape([1, 3, 10, 10]) = R.shape([1, 3, 10, 10]) + R.output(gv) + return gv + + verify_model(GetAttr1(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_getitem(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Slice1(Module): + def forward(self, x): + return x[0, 1::2, :, :3] + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 1, 10, 3), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 1, 10, 3), dtype="float32") = R.strided_slice( + x, + axes=[0, 1, 2, 3], + begin=[0, 1, 0, 0], + end=[1, T.int64(3), T.int64(10), 3], + strides=[1, 2, 1, 1], + ) + lv1: R.Tensor((1, 1, 10, 3), dtype="float32") = R.reshape(lv, (1, 1, 10, 3)) + gv: R.Tensor((1, 1, 10, 3), dtype="float32") = lv1 + R.output(gv) + return gv + + verify_model(Slice1(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_unary(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + # sin + class Sin(Module): + def forward(self, input): + return torch.sin(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sin(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sin(), input_info, {}, expected1) + + # cos + class Cos(Module): + def forward(self, input): + return torch.cos(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cos(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Cos(), input_info, {}, expected2) + + # sqrt + class Sqrt(Module): + def forward(self, input): + return torch.sqrt(input) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sqrt(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sqrt(), input_info, {}, expected3) + + +@tvm.testing.requires_gpu +def test_gelu(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Gelu(Module): + def forward(self, input): + return torch.nn.functional.gelu(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.gelu(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Gelu(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_clamp(): + import torch + from torch import fx + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Clamp(Module): + def forward(self, input): + return torch.clamp(input, min=0.1, max=0.5) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.1, 0.5) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Clamp(), input_info, {}, expected1) + + from tvm.relax.frontend.torch import from_fx + + with pytest.raises( + ValueError, match="TVM only supports constant max value for torch.clamp/clip" + ): + + class Clamp_Error(Module): + def forward(self, input): + return torch.clamp(input, min=0.5, max=None) + + gm = fx.symbolic_trace(Clamp_Error()) + from_fx(gm, input_info) + + with pytest.raises( + ValueError, match="TVM only supports constant min value for torch.clamp/clip" + ): + + class Clamp_Error(Module): + def forward(self, input): + return torch.clamp(input, min=input, max=input) + + gm = fx.symbolic_trace(Clamp_Error()) + from_fx(gm, input_info) + + +@tvm.testing.requires_gpu +def test_interpolate(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Interpolate(Module): + def forward(self, input): + return torch.nn.functional.interpolate(input, size=(5, 5)) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 5, 5), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 5, 5), dtype="float32") = R.image.resize2d( + input_1, + (5, 5), + roi=[0.000000, 0.000000, 0.000000, 0.000000], + layout="NCHW", + method="nearest_neighbor", + coordinate_transformation_mode="asymmetric", + rounding_method="round", + cubic_alpha=-0.5, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="", + ) + gv: R.Tensor((1, 3, 5, 5), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Interpolate(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_addmm(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [ + ([10, 10], "float32"), + ([10, 10], "float32"), + ([10, 10], "float32"), + ] + + class Addmm(Module): + def forward(self, x1, x2, x3): + return torch.addmm(x1, x2, x3) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x1: R.Tensor((10, 10), dtype="float32"), + x2: R.Tensor((10, 10), dtype="float32"), + x3: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32") + lv1: R.Tensor((10, 10), dtype="float32") = R.add(x1, lv) + gv: R.Tensor((10, 10), dtype="float32") = lv1 + R.output(gv) + return gv + + verify_model(Addmm(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_split(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Split(Module): + def forward(self, input): + return torch.split(input, 1, dim=1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ) = R.split(input_1, indices_or_sections=3, axis=1) + gv: R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ) = lv + R.output(gv) + return gv + + verify_model(Split(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_tril(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([10, 10], "float32")] + + class Tril(Module): + def forward(self, input): + return torch.tril(input, 1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.tril(input_1, 1) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Tril(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_new_ones(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3], "float32")] + + class NewOnes(Module): + def forward(self, x): + return x.new_ones(1, 2, 3) + + @tvm.script.ir_module + class expected1: + @R.function + def main(x: R.Tensor((1, 2, 3), dtype="float32")) -> R.Tensor((1, 2, 3), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3), dtype="float32") = R.full( + (1, 2, 3), R.const(1, "float32"), dtype="float32" + ) + gv: R.Tensor((1, 2, 3), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(NewOnes(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_expand(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + class Expand(Module): + def forward(self, x): + return x.expand(4, 2, 3, 4) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((4, 2, 3, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((4, 2, 3, 4), dtype="float32") = R.broadcast_to(x, (4, 2, 3, 4)) + gv: R.Tensor((4, 2, 3, 4), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Expand(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_reduce(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + # sum + class Sum(Module): + def forward(self, x): + return torch.sum(x, (2, 1)) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4), dtype="float32") = R.sum(inp_0, axis=[2, 1], keepdims=False) + gv: R.Tensor((1, 4), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sum(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_to(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + # float + class ToFloat(Module): + def forward(self, x): + return x.float() + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 2, 3, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x, dtype="float32") + gv: R.Tensor((1, 2, 3, 4), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(ToFloat(), input_info, {}, expected1) + + # half + class ToHalf(Module): + def forward(self, x): + return x.half() + + @tvm.script.ir_module + class expected2: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 2, 3, 4), dtype="float16"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(x, dtype="float16") + gv: R.Tensor((1, 2, 3, 4), dtype="float16") = lv + R.output(gv) + return gv + + verify_model(ToHalf(), input_info, {}, expected2) + + +@tvm.testing.requires_gpu +def test_permute(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + class Permute(Module): + def forward(self, x): + return x.permute(0, 3, 2, 1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 4, 3, 2), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1]) + gv: R.Tensor((1, 4, 3, 2), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Permute(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_reshape(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + class Reshape(Module): + def forward(self, x): + return x.reshape(2, 12) + + @tvm.script.ir_module + class expected1: + @R.function + def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12)) + gv: R.Tensor((2, 12), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Reshape(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_transpose(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + class Transpose(Module): + def forward(self, x): + return x.transpose(1, 3) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 4, 3, 2), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1]) + gv: R.Tensor((1, 4, 3, 2), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Transpose(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_view(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + class View(Module): + def forward(self, x): + return x.view(2, 12) + + @tvm.script.ir_module + class expected1: + @R.function + def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12)) + gv: R.Tensor((2, 12), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(View(), input_info, {}, expected1) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_pipeline.py b/tests/python/relax/test_pipeline.py new file mode 100644 index 000000000000..6d6704ae97ec --- /dev/null +++ b/tests/python/relax/test_pipeline.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +from tvm import relax +from tvm.script import relax as R + + +def test_pipeline_compile(): + pipeline = relax.get_pipeline() + + @tvm.script.ir_module + class Mod: + @R.function + def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): + lv0 = R.add(x, y) + return lv0 + + mod = Mod + mod = pipeline(mod) + target = tvm.target.Target("llvm", host="llvm") + + ex = relax.vm.build(mod, target) + x_np = np.random.rand(3, 4).astype(np.float32) + y_np = np.random.rand(3, 4).astype(np.float32) + x = tvm.nd.array(x_np) + y = tvm.nd.array(y_np) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + z = vm["main"](x, y) + tvm.testing.assert_allclose(z.numpy(), x_np + y_np, rtol=1e-7, atol=1e-7) From df0e043272038ce979140b99292e63083965cc00 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 18 Feb 2023 11:38:32 -0500 Subject: [PATCH 38/81] [Unity][Pass] Block-level static memory planning (#14038) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR introduces the static memory planning pass on binding block level, as well as an analysis function that estimate the memory usage after the memory planning pass. It supports the following features: nested-tuples, reuse memory of the input of reshape ops, an estimator that returns total memory size needed to be allocated before and after memory planning, as well as the number of tensors / memory blocks to be allocated before and after memory planning. The estimation is static -- it does not consider control flows (such as “if” and cross-function calls). It simply accumulates the size of every alloc_tensor and alloc_storage. We will produce “`relax.memory.alloc_tensor/storage`” as the results produced by memory planning. --- include/tvm/relax/transform.h | 9 + python/tvm/relax/analysis/__init__.py | 1 + .../relax/analysis/estimate_memory_usage.py | 164 ++++ python/tvm/relax/transform/transform.py | 11 + python/tvm/relax/vm.py | 1 + .../transform/static_plan_block_memory.cc | 750 ++++++++++++++++++ .../test_analysis_estimate_memory_usage.py | 125 +++ ...test_transform_static_plan_block_memory.py | 612 ++++++++++++++ 8 files changed, 1673 insertions(+) create mode 100644 python/tvm/relax/analysis/estimate_memory_usage.py create mode 100644 src/relax/transform/static_plan_block_memory.cc create mode 100644 tests/python/relax/test_analysis_estimate_memory_usage.py create mode 100644 tests/python/relax/test_transform_static_plan_block_memory.py diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 8b7c7880b9b6..1934a9f9f2a0 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -95,6 +95,15 @@ TVM_DLL Pass CallTIRRewrite(); */ TVM_DLL Pass RewriteDataflowReshape(); +/*! + * \brief The static memory planning pass on BindingBlock level. + * The pass will reuse allocated memory to its best effort, in order to + * reduce the total amount of allocated memory size. + * + * \return The pass. + */ +TVM_DLL Pass StaticPlanBlockMemory(); + /*! * \brief Bind params of function of the module to constant tensors. * diff --git a/python/tvm/relax/analysis/__init__.py b/python/tvm/relax/analysis/__init__.py index cc0089ff3134..7ba56ff40840 100644 --- a/python/tvm/relax/analysis/__init__.py +++ b/python/tvm/relax/analysis/__init__.py @@ -18,3 +18,4 @@ """Relax IR analysis. """ from .analysis import * +from .estimate_memory_usage import estimate_memory_usage diff --git a/python/tvm/relax/analysis/estimate_memory_usage.py b/python/tvm/relax/analysis/estimate_memory_usage.py new file mode 100644 index 000000000000..55f82740ec9c --- /dev/null +++ b/python/tvm/relax/analysis/estimate_memory_usage.py @@ -0,0 +1,164 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=abstract-method,unused-argument +# pylint: disable=missing-function-docstring,missing-module-docstring +from typing import Union +import tvm +from tvm.ir import Op +from tvm.ir.module import IRModule + +from ..expr import Call, Expr, Function, ShapeExpr +from ..expr_functor import visitor, PyExprVisitor + + +def estimate_memory_usage(mod: Union[IRModule, Function]) -> str: + """Analysis function that estimates the memory usage of Relax functions + in an IRModule. The estimation includes the total memory size needed to + be allocated before and after memory planning. + + The result might be over-estimated, as the estimation is static, which + does not consider control flows (such as "if" and cross-function calls). + It simply accumulates the size of every alloc_tensor and alloc_storage. + + This analysis function is used to demonstrate the effect of memory + planning. + + Parameters + ---------- + mod : Union[IRModule, Function] + The input IRModule whose functions inside are to be analyzed. + If the input is a Function, we will wrap it with a IRModule, with + the function named "main". + + Returns + ------- + est : str + The estimation information, in the form of a string. + + Notes + ----- + We regards "relax.memory.alloc_tensor/storage" as the results produced by memory planning. + """ + + @visitor + class MemoryEstimator(PyExprVisitor): + """The IR visitor which estimates the memory usage of each Relax function. + + Attributes + ---------- + total_alloc_tensor_mem : int + The total memory size of alloc_tensor, in bytes. + + total_const_size_tensor_num : int + The number of constant-size tensors. + + total_dyn_size_tensor_num : int + The number of dynamic-size tensors. + + planned_alloc_mem : int + The total memory size of memory.alloc_storage after memory planning, in bytes. + + planned_mem_num : int + The number of memory.alloc_storages. + """ + + total_alloc_tensor_mem: int + total_const_size_tensor_num: int + total_dyn_size_tensor_num: int + planned_alloc_mem: int + planned_mem_num: int + builtin_alloc_tensor_op = Op.get("relax.builtin.alloc_tensor") + memory_alloc_tensor_op = Op.get("relax.memory.alloc_tensor") + memory_alloc_storage_op = Op.get("relax.memory.alloc_storage") + + def estimate(self, mod: IRModule) -> str: + estimation: str = "" + for global_var, func in mod.functions.items(): + if not isinstance(func, Function): + continue + + self.cleanup() + self.visit_expr(func) + estimation += self.generate_est_string(global_var.name_hint) + + if estimation != "": + estimation = "Memory usage estimation:\n" + estimation + return estimation + + def cleanup(self) -> None: + self.total_alloc_tensor_mem = 0 + self.total_const_size_tensor_num = 0 + self.total_dyn_size_tensor_num = 0 + self.planned_alloc_mem = 0 + self.planned_mem_num = 0 + + def visit_call_(self, call: Call) -> None: # pylint: disable=arguments-differ + if call.op == self.builtin_alloc_tensor_op: + self.accumulate_tensor_alloc(shape=call.args[0], dtype_str=call.args[1].value) + elif call.op == self.memory_alloc_tensor_op: + self.accumulate_tensor_alloc(shape=call.args[2], dtype_str=call.args[3].value) + elif call.op == self.memory_alloc_storage_op: + self.accumulate_storage_alloc(size=call.args[0]) + + def accumulate_tensor_alloc(self, shape: Expr, dtype_str: str) -> None: + if not isinstance(shape, ShapeExpr): + raise TypeError( + "The shape of relax.builtin.alloc_tensor and " + "relax.memory.alloc_tensor is expected to be ShapeExpr" + ) + size: int = 1 + for dim_len in shape.values: + if not isinstance(dim_len, tvm.tir.IntImm): + self.total_dyn_size_tensor_num += 1 + return + size *= dim_len.value + + dtype = tvm.DataType(dtype_str) + self.total_const_size_tensor_num += 1 + self.total_alloc_tensor_mem += (size * dtype.bits * dtype.lanes + 7) // 8 + + def accumulate_storage_alloc(self, size: Expr) -> None: + if not isinstance(size, ShapeExpr): + raise TypeError( + "The size of relax.memory.alloc_storage is expected to be ShapeExpr" + ) + + self.planned_mem_num += 1 + self.planned_alloc_mem += size.values[0].value + + def generate_est_string(self, func_name: str) -> str: + est = ( + f" * Without memory planning, there are {self.total_const_size_tensor_num} " + "constant-size memory allocation(s) with total size " + "{0:.4} GB".format(self.total_alloc_tensor_mem / 2**30) + ) + if self.total_dyn_size_tensor_num > 0: + est += f", and {self.total_dyn_size_tensor_num} dynamic-size allocation(s)" + est += ( + f".\n * With memory planning, there are {self.planned_mem_num} constant-size " + "memory allocation(s) with total size " + "{0:.4} GB.\n".format(self.planned_alloc_mem / 2**30) + ) + est += " * Memory planning reduces constant memory size to " "{0:.1%}.".format( + self.planned_alloc_mem / self.total_alloc_tensor_mem + ) + return "- Function " + func_name + ":\n" + est + + if isinstance(mod, Function): + mod = tvm.IRModule({tvm.ir.GlobalVar("foo"): mod}) + + return MemoryEstimator().estimate(mod) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 4ba967935b52..1f14823b5a94 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -93,6 +93,17 @@ def RewriteDataflowReshape() -> tvm.ir.transform.Pass: return _ffi_api.RewriteDataflowReshape() # type: ignore +def StaticPlanBlockMemory() -> tvm.ir.transform.Pass: + """The static memory planning pass on BindingBlock level. + The pass will reuse allocated memory to its best effort, in order to + reduce the total amount of allocated memory size. + Returns + ------- + ret : tvm.ir.transform.Pass + """ + return _ffi_api.StaticPlanBlockMemory() # type: ignore + + def VMBuiltinLower() -> tvm.ir.transform.Pass: """Lowering generic intrinsic to VM intrinsics. diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py index ff6bf816b62b..2cf1250690a0 100644 --- a/python/tvm/relax/vm.py +++ b/python/tvm/relax/vm.py @@ -585,6 +585,7 @@ def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")): passes.append(relax.transform.RewriteDataflowReshape()) passes.append(relax.transform.ToNonDataflow()) passes.append(relax.transform.CallTIRRewrite()) + passes.append(relax.transform.StaticPlanBlockMemory()) passes.append(relax.transform.VMBuiltinLower()) passes.append(relax.transform.VMShapeLower()) passes.append(relax.transform.AttachGlobalSymbol()) diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc new file mode 100644 index 000000000000..8b7adae246eb --- /dev/null +++ b/src/relax/transform/static_plan_block_memory.cc @@ -0,0 +1,750 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/static_plan_block_memory.cc + * \brief The static memory planning pass on BindingBlock level. + * \details + * The core data structure of the planning pass is StorageToken, which denotes + * reusable memory in this planning pass. + * + * The memory planning pass contains three stages: + * + * The first stage is initialization. A storage token object will be created + * for each builtin alloc_tensor as long as the allocated storage satisfies + * the requirements (which are described in the code). The reference counter + * (i.e., the times of reference) for each token is recorded. + * + * The second stage is allocation planning. We maintain a pool of available + * allocated storage, in the form of storage tokens. For the storage token of + * each builtin alloc_tensor, we check if there is appropriate available token + * in the pool under certain criterion. If there is, we reuse that storage + * for this alloc_tensor. Otherwise, we decide to allocate a storage for the + * alloc_tensor. + * + * The third stage is IR rewrite. Based on the decision made in the second + * stage, we insert memory alloc_storage, alloc_tensor, kill_tensor, and + * kill_storage accordingly. Specifically, we + * - insert alloc_storage before the site that each storage token is firstly + * used, + * - insert memory alloc_tensor for each builtin alloc_tensor, + * - insert kill_tensor after the site that a tensor created by alloc_tensor + * is last referenced, and + * - insert kill_storage at the end of each binding block, for all the storage + * tokens that are allocated inside the binding block, as the memory planning + * only works on block level. + */ +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief A representation of a block of reusable memory required at runtime. + * \details Only the tensors whose memory can be "possibly reused" will have + * their storage token. In other words, we do not have storage token for tensor + * - that is a function parameter, + * - that is a function return value, + * - one of whose use site is a BindingBlock different from its allocation site, + * - that is used as a condition or branch return of a IfNode, + * - that is used as the body of a SeqExprNode, + * - that is used as arguments in a Call whose op is not a PrimFunc. + * + * In practice, we do create a storage token for such tensor at first. But at + * any time we find a tensor satisfying any of the conditions above, we erase + * its storage token. + */ +class StorageTokenNode : public Object { + public: + /*! \brief Reference counter. */ + int ref_counter{0}; + /*! \brief Number of bytes that this token requires. */ + int64_t bytes; + /*! \brief The dtype of this token. */ + DataType dtype; + /*! \brief The storage id, reserved for debug and demo use. */ + int storage_id{-1}; + /*! + * \brief The variable corresponding to the allocated storage, which is NullOpt + * before definition. + */ + Optional storage{NullOpt}; + + static constexpr const char* _type_key = "relax.transform.StorageToken"; + TVM_DECLARE_BASE_OBJECT_INFO(StorageTokenNode, Object); +}; + +/*! + * \brief Managed reference to StorageTokenNode. + * \sa StorageTokenNode + */ +class StorageToken : public ObjectRef { + public: + explicit StorageToken(Array shape, DataType dtype) { + // Compute the tensor size from the shape. + int64_t size = 1; + for (const PrimExpr& dim_len : shape) { + const auto* int_len = dim_len.as(); + ICHECK_NOTNULL(int_len); + size *= int_len->value; + } + + ObjectPtr n = make_object(); + n->bytes = (size * dtype.bits() * dtype.lanes() + 7) / 8; + n->dtype = dtype; + data_ = std::move(n); + } + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(StorageToken, ObjectRef, StorageTokenNode); +}; + +// We use NestedMsg to store the tokens used by each Expr. +using Tokens = NestedMsg; + +/*! + * \brief Memory manager for flattened 1d memory (buffers) + * \note We can generalize this implementation to multi-dimensional memory + * following the same flow in the future. + */ +class TokenAllocator1D { + public: + /*! + * \brief Request a storage token from the available token pool for a + * given prototype, or report no appropriate available token in the pool. + * \param prototype The requesting prototype storage token. + * \return The request result token. Return NullOpt if there is no + * appropriate available token in the pool. + */ + Optional RequestReuse(StorageToken prototype) { + // Step 0. Sanity check: the prototype token is supposed not to be allocated with actual storage + ICHECK_EQ(prototype->storage_id, -1) << "The token is expected not to be allocated before."; + // If the prototype has no reference at all, feel free to allocate new storage. + // The unused binding can be removed by cleaning passes. + if (prototype->ref_counter == 0) { + return NullOpt; + } + + // Step 1. Get the available pool of the token dtype. + std::multimap& pool = available_pool_[prototype->dtype]; + + // Step 2. Get the range of memory blocks in [size / match_range_, size * match_range_) + int64_t size = prototype->bytes; + auto begin = pool.lower_bound(size / match_range_); + auto mid = pool.lower_bound(size); + auto end = pool.upper_bound(size * match_range_); + // Step 3. Search for memory block that equals or is larger than the requested size. + if (mid != end) { + StorageToken available_token = mid->second; + ICHECK_EQ(available_token->ref_counter, 0) + << "Available tokens are expected to have 0 reference."; + ICHECK_LE(size, available_token->bytes); + available_token->ref_counter = prototype->ref_counter; + pool.erase(mid); + return available_token; + } + // Step 4. Then search for memory block that is smaller than the requested size. + if (mid != begin) { + --mid; + StorageToken available_token = mid->second; + ICHECK_EQ(available_token->ref_counter, 0) + << "Available tokens are expected to have 0 reference."; + ICHECK_GE(size, available_token->bytes); + // Enlarge the token size. + available_token->bytes = size; + available_token->ref_counter = prototype->ref_counter; + pool.erase(mid); + return available_token; + } + // Return `NullOpt` indicating that no satisfiable storage token is found in the available pool. + return NullOpt; + } + + /*! + * \brief Allocate a storage token for the input prototype token. + * \param prototype The prototype token. + * \param storage_id The id of this token. + */ + StorageToken Alloc(StorageToken prototype, int storage_id) { + // Sanity check: the prototype token is supposed not to be allocated with actual storage yet + ICHECK_EQ(prototype->storage_id, -1) << "The token is expected not to be allocated before."; + prototype->storage_id = storage_id; + full_pool_.push_back(prototype); + return prototype; + } + + /*! + * \brief Release the input token, putting it into the available pool. + * \param token The token to be released. + */ + void Release(StorageToken token) { + // Sanity check: the token has been allocated with actual storage, and should have 0 reference. + ICHECK_GE(token->storage_id, 0) + << "The token to be released is expected to be allocated before"; + ICHECK_EQ(token->ref_counter, 0) << "The token to be released is expected to have 0 reference."; + available_pool_[token->dtype].insert({token->bytes, token}); + } + + private: + /*! \brief A constant scale representing the token search range. */ + const int match_range_{16}; + /*! \brief The pool of available storage tokens for each dtype. */ + std::unordered_map> available_pool_; + /*! \brief All the storage tokens that have been allocated with actual storage. */ + std::vector full_pool_; +}; + +/*! \brief Check if the input op is "relax.reshape". */ +bool IsReshape(const Expr& op) { return op.same_as(Op::Get("relax.reshape")); } + +/*! \brief The base class for the storage allocation visitor. */ +class StorageAllocatorBaseVisitor : public ExprVisitor { + protected: + using ExprVisitor::VisitExpr_; + + void VisitBindingBlock_(const BindingBlockNode* block) override { + // We maintain a block stack for token allocation-site and use-site check. + block_stack_.push_back(block); + ExprVisitor::VisitBindingBlock_(block); + ICHECK(!block_stack_.empty()); + ICHECK(block_stack_.back() == block); + block_stack_.pop_back(); + } + + void VisitBinding_(const VarBindingNode* binding) override { + ExprVisitor::VisitBinding_(binding); + // The binding var has the same tokens as the binding value. + SetTokens(binding->var.get(), token_map_[binding->value.get()]); + } + + void VisitExpr_(const TupleNode* tuple) final { + Array tokens; + tokens.reserve(tuple->fields.size()); + for (const Expr& field : tuple->fields) { + Tokens field_tokens = GetTokens(field); + tokens.push_back(field_tokens); + } + SetTokens(tuple, Tokens(tokens)); + } + + void VisitExpr_(const TupleGetItemNode* tuple_item) final { + Tokens tokens = GetTokens(tuple_item->tuple); + // If the tuple has no token, every of its field has no token as well. + if (tokens.IsNull()) { + token_map_[tuple_item] = Tokens(); + return; + } + ICHECK(tokens.IsNested()); + Array field_tokens = tokens.NestedArray(); + ICHECK_GT(static_cast(field_tokens.size()), tuple_item->index); + ICHECK_GE(tuple_item->index, 0); + SetTokens(tuple_item, field_tokens[tuple_item->index]); + } + + /******************** Utilities ********************/ + + Tokens GetTokens(const Expr& expr) { + this->VisitExpr(expr); + return token_map_[expr.get()]; + } + + virtual void SetTokens(const ExprNode* expr, Tokens tokens) { token_map_[expr] = tokens; } + + /*! \brief The mapping from each Expr to its corresponding storage tokens. */ + std::unordered_map token_map_; + /*! \brief The binding block stack. */ + std::vector block_stack_; +}; + +/*! + * \brief The visitor class for storage token initialization. + * \details It goes through the entire function to get the storage tokens + * used by each Expr. After the initialization, we + * - know the tokens that each Expr is using, + * - know the number of references for each token, + * - rule out the builtin alloc_tensors to which the planning does not apply. + */ +class StorageAllocatorInit : public StorageAllocatorBaseVisitor { + public: + explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {} + + /*! + * \brief The entry of the initialization. + * \return The mapping from each Expr to the token it uses. + */ + std::unordered_map Initialize(const Function& func) { + // Recurse into the function to get its tokens. + Tokens body_tokens = GetTokens(func->body); + // Discard the tokens used by the function return value, as they are external referenced. + DiscardTokensIn(body_tokens); + return this->token_map_; + } + + private: + using ExprVisitor::VisitExpr_; + + void VisitExpr_(const CallNode* call) final { + static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); + if (call->op == alloc_tensor_op) { + // Create a storage token for builtin alloc_tensor. + this->CreateToken(call); + return; + } else if (IsReshape(call->op)) { + // Reuse the input's token for builtin reshape. + SetTokens(call, GetTokens(call->args[0])); + return; + } + + // - Increase the reference counters of the arguments when the callee is + // a PrimFunc of the context module. + // - Otherwise, discard the tokens used by the arguments, as there might be + // potential external reference. + if (IsPrimFuncGlobalVar(call->op)) { + ICHECK(!block_stack_.empty()); + for (const Expr& arg : call->args) { + Tokens tokens = GetTokensWithAllocSiteCheck(arg, block_stack_.back()); + ForEachLeaf(tokens, [](StorageToken token) { token->ref_counter += 1; }); + } + } else { + for (const Expr& arg : call->args) { + DiscardTokensIn(GetTokens(arg)); + } + } + } + + void VisitExpr_(const IfNode* if_node) final { + Tokens cond_tokens = GetTokens(if_node->cond); + Tokens then_tokens = GetTokens(if_node->true_branch); + Tokens else_tokens = GetTokens(if_node->false_branch); + // Discard the tokens used by the condition, then-body and else-body, + // as the planning works on block level. + DiscardTokensIn(cond_tokens); + DiscardTokensIn(then_tokens); + DiscardTokensIn(else_tokens); + } + + void VisitExpr_(const SeqExprNode* seq) final { + for (const BindingBlock& binding_block : seq->blocks) { + this->VisitBindingBlock(binding_block); + } + Tokens body_tokens = GetTokens(seq->body); + // Discard the tokens used by the body, as the planning works on block level. + DiscardTokensIn(body_tokens); + } + + /******************** Utilities ********************/ + + /*! + * \brief Check if the input op is GlobalVar corresponding to a PrimFunc inside the ctx module. + * \param op The op to be checked + * \return A boolean indicating if the input op corresponds to a PrimFunc. + */ + bool IsPrimFuncGlobalVar(const Expr& op) { + const auto* global_var = op.as(); + if (global_var == nullptr) { + return false; + } + auto func_it = ctx_mod_->functions.find(GetRef(global_var)); + if (func_it == ctx_mod_->functions.end()) { + return false; + } + return (*func_it).second->IsInstance(); + } + + /*! + * \brief Create a storage token for the builtin alloc_tensor call. + * \param call The call to be processed. + * \return The created token. + */ + Tokens CreateToken(const CallNode* call) { + // Sanity checks about + // - the call return value is a Tensor; + // - the shape of the tensor is known, in the form of ShapeExpr; + // - the tensor has known dtype; + // - no storage token was created for this call before. + const auto* sinfo = call->struct_info_.as(); + const auto* shape = sinfo->shape.as(); + ICHECK_NOTNULL(sinfo); + ICHECK_NOTNULL(shape); + ICHECK(!sinfo->IsUnknownDtype()); + ICHECK(sinfo->dtype == Downcast(call->args[1])->value); + ICHECK(!token_map_.count(call)); + + // No support for symbolic shape at this moment. + for (const PrimExpr& dim_len : shape->values) { + const auto* int_len = dim_len.as(); + if (!int_len) { + token_map_[call] = Tokens(); + return Tokens(); + } + } + + // Create and set token. + StorageToken token(shape->values, sinfo->dtype); + + Tokens tokens(token); + SetTokens(call, tokens); + ICHECK(!block_stack_.empty()); + token2block_[token.get()] = block_stack_.back(); + return tokens; + } + + /*! + * \brief Override the token setter in the base visitor. + * For each token, we keep record of all Expr that are using that token. + * When we want to discard one token, we use the records to remove the token + * from the Expr that are using it. + */ + void SetTokens(const ExprNode* expr, Tokens tokens) final { + StorageAllocatorBaseVisitor::SetTokens(expr, tokens); + ForEachLeaf(tokens, [this, expr](StorageToken token) { + this->token2exprs_[token.get()].push_back(expr); + }); + } + + /*! + * \brief Token getter with allocation site check. + * We first get the tokens used by the input Expr, and check if the allocation + * site of each token is the input current block. + * Since the planning works on block level, if some token's allocation site + * is not the current block, we discard the token so that it will not be planned. + * \param expr The Expr whose tokens is to be got. + * \param cur_block The pointer to the current block. + * \return The tokens used by the input Expr. + */ + Tokens GetTokensWithAllocSiteCheck(const Expr& expr, const BindingBlockNode* cur_block) { + Tokens tokens = GetTokens(expr); + ForEachLeaf(tokens, [this, cur_block](StorageToken token) { + auto it = this->token2block_.find(token.get()); + ICHECK(it != this->token2block_.end()); + if (it->second != cur_block) { + this->DiscardToken(token); + } + }); + return token_map_[expr.get()]; + } + + /*! \brief Discard the input tokens. */ + void DiscardTokensIn(Tokens tokens) { + ForEachLeaf(tokens, [this](StorageToken token) { this->DiscardToken(token); }); + } + + /*! + * \brief Discard the input token. + * For each Expr that is using the input token, remove the token from the Expr's token set. + * \param token_to_discard The token to be discarded. + */ + void DiscardToken(StorageToken token_to_discard) { + const std::vector& exprs = token2exprs_[token_to_discard.get()]; + for (const ExprNode* expr : exprs) { + token_map_[expr] = MapNestedMsg(token_map_[expr], [token_to_discard](StorageToken token) { + return token.same_as(token_to_discard) ? Tokens() : Tokens(token); + }); + } + token2exprs_.erase(token_to_discard.get()); + token2block_.erase(token_to_discard.get()); + } + + /*! + * \brief The context IRModule, used for checking if a callee function is + * a PrimFunc inside the IRModule. + */ + const IRModule& ctx_mod_; + /*! \brief The mapping from each token to the binding block where it is created. */ + std::unordered_map token2block_; + /*! \brief The mapping from each token to the Exprs that are using this token. */ + std::unordered_map> token2exprs_; +}; + +/*! + * \brief The visitor class for storage token allocation planning. + * \details + * - For each builtin alloc_tensor whose token is not discarded in the + * initialization stage, we request a storage reuse or decide to allocate + * storage for this token, depending on if there is appropriate available + * token in the token pool we maintain. + * - For each VM builtin reshape, we reuse the input's tokens. + * + * After the allocation planning, we + * - know the token that each builtin alloc_tensor plans to use. Compared + * with the initialization, here the token is possibly a reuse of some + * previous token, rather than we having one token for each alloc_tensor. + * - know the last referenced site for each builtin alloc_tensor. This + * information is used for inserting kill_tensor in the rewrite stage. + * - know the tokens allocated in each binding block. This information + * is used for inserting kill_storage in the rewrite stage. + */ +class StorageAllocator : public StorageAllocatorBaseVisitor { + public: + explicit StorageAllocator(std::unordered_map token_map) { + this->token_map_ = std::move(token_map); + } + + /*! + * \brief The mapping from each `builtin.alloc_tensor` to its corresponding + * underlying storage token that it is using. + */ + std::unordered_map alloc_tensor2token; + /*! \brief The mapping from each Expr to the tensors that need to be killed after it. */ + std::unordered_map> expr2killed_tensors; + /*! \brief The mapping from each binding block to the storage tokens that are create inside. */ + std::unordered_map> block2tokens; + + private: + using ExprVisitor::VisitBinding_; + using ExprVisitor::VisitExpr_; + + void VisitBindingBlock_(const BindingBlockNode* block) final { + StorageAllocatorBaseVisitor::VisitBindingBlock_(block); + // Sanity check: each token allocated inside the block should not be + // referenced by anyone at the end of the block. + for (const StorageTokenNode* token : block2tokens[block]) { + ICHECK_EQ(token->ref_counter, 0); + } + } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final { + static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); + if (call->op == alloc_tensor_op) { + auto it = token_map_.find(call); + ICHECK(it != token_map_.end()); + + if (it->second.IsNull()) { + // IsNull being true means the token was discarded, and this alloc_tensor + // is not considered by the planning. + return; + } + ICHECK(it->second.IsLeaf()); + StorageToken new_token = this->RequestReuseOrAlloc(it->second.LeafValue()); + + // Record that this alloc_tensor is using the token. + alloc_tensor2token.insert({call, new_token}); + token2cur_tensor_[new_token.get()].push_back(binding->var); + SetTokens(call, Tokens(new_token)); + // Record that the token is allocated in the current block. + ICHECK(!block_stack_.empty()); + std::vector& block_tokens = block2tokens[block_stack_.back()]; + if (std::find(block_tokens.begin(), block_tokens.end(), new_token.get()) == + block_tokens.end()) { + block_tokens.push_back(new_token.get()); + } + return; + } else if (IsReshape(call->op)) { + Tokens tokens = GetTokens(call->args[0]); + ICHECK(!tokens.IsNested()); + if (tokens.IsLeaf()) { + // If the input is using a token, record that the reshape uses the token as well. + token2cur_tensor_[tokens.LeafValue().get()].push_back(binding->var); + SetTokens(call, tokens); + } else { + ICHECK(token_map_[call].IsNull()); + } + return; + } + + // Decrease the reference counter by one for each token that the arguments use. + // Check if a token can be released (i.e., has no reference) after decrease. + // And release it if so. + for (const Expr& arg : call->args) { + Tokens tokens = GetTokens(arg); + ForEachLeaf(tokens, [this, call](StorageToken token) { + ICHECK_GT(token->ref_counter, 0); + token->ref_counter -= 1; + this->CheckForRelease(token, call); + }); + } + } + + /*! \brief Request a storage reuse, or allocate storage if no appropriate storage is reusable. */ + StorageToken RequestReuseOrAlloc(StorageToken prototype) { + Optional token = allocator_.RequestReuse(prototype); + if (!token.defined()) { + return allocator_.Alloc(prototype, this->n_storage_++); + } else { + return token.value(); + } + } + + /*! + * \brief Check if a token has no reference and thus can be released. And release it if so. + * \param token The token to be checked. + * \param release_site The CallNode where the the input token is send for release. + * If the token is checked to release here, we keep record of the release site so that + * kill_tensor can be inserted here at the rewrite stage. + */ + void CheckForRelease(StorageToken token, const CallNode* release_site) { + // Sanity check: the token was allocated before and has non-negative reference. + ICHECK_GE(token->storage_id, 0); + ICHECK_GE(token->ref_counter, 0); + + if (token->ref_counter == 0) { + allocator_.Release(token); + auto it = token2cur_tensor_.find(token.get()); + ICHECK(it != token2cur_tensor_.end()); + // Record that the tensors that are using this token will be killed + // immediately after the release site. + std::vector& killed_tensors = expr2killed_tensors[release_site]; + killed_tensors.insert(killed_tensors.end(), it->second.begin(), it->second.end()); + token2cur_tensor_.erase(it); + } + } + + /*! \brief Number of allocated storages. */ + int n_storage_{0}; + /*! \brief The 1D memory allocator. */ + TokenAllocator1D allocator_; + /*! \brief The mapping from each token to the tensors that are currently using it. */ + std::unordered_map> token2cur_tensor_; +}; + +/*! + * \brief The rewriter class based on the token allocation planning. + * \details + * - For each builtin alloc_tensor that was planned, substitute it with a memory + * alloc_tensor. If no memory alloc_storage was created for it before, create one. + * - Insert memory kill_tensor at the release site of each tensor. + * - Insert memory kill_storage at the end of each binding block, for the tokens allocated in it. + */ +class StorageAllocationRewriter : public ExprMutator { + public: + explicit StorageAllocationRewriter( + std::unordered_map alloc_tensor2token, + std::unordered_map> expr2killed_tensors, + std::unordered_map> + block2tokens) + : alloc_tensor2token_(std::move(alloc_tensor2token)), + expr2killed_tensors_(std::move(expr2killed_tensors)), + block2tokens_(std::move(block2tokens)) {} + + private: + using ExprMutator::VisitExpr_; + + BindingBlock VisitBindingBlock_(const BindingBlockNode* block) final { + builder_->BeginBindingBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + + // Insert `memory.kill_storage` for the storage tokens allocated inside this block. + for (const StorageTokenNode* token : block2tokens_[block]) { + ICHECK(token->storage.defined()); + static const Op& mem_kill_storage = Op::Get("relax.memory.kill_storage"); + this->builder_->Emit(Call(mem_kill_storage, {token->storage.value()}), /*name_hint=*/"_"); + } + + BindingBlock new_block = builder_->EndBlock(); + return new_block; + } + + void VisitBinding_(const VarBindingNode* binding) final { + ExprMutator::VisitBinding_(binding); + + // Insert `memory.kill_tensor` for the tensors that need to be killed after this binding. + auto it = expr2killed_tensors_.find(binding->value.get()); + if (it != expr2killed_tensors_.end()) { + for (const Var& var : it->second) { + static const Op& mem_kill_tensor = Op::Get("relax.memory.kill_tensor"); + this->builder_->Emit(Call(mem_kill_tensor, {Downcast(this->VisitExpr(var))}), + /*name_hint=*/"_"); + } + } + } + + Expr VisitExpr_(const CallNode* call) final { + auto it = alloc_tensor2token_.find(call); + if (it != alloc_tensor2token_.end()) { + const auto* sinfo = call->struct_info_.as(); + ICHECK_NOTNULL(sinfo); + ICHECK_NOTNULL(sinfo->shape.as()); + PrimValue runtime_device_index = Downcast(call->args[2]); + + // If the token is visited for the first time, create a storage variable using + // `memory.alloc_storage` for it. + StorageToken token = it->second; + if (!token->storage.defined()) { + static const Op& mem_alloc_storage = Op::Get("relax.memory.alloc_storage"); + ShapeExpr size({tir::make_const(DataType::Int(64), token->bytes)}); + PrimValue virtual_device_index = runtime_device_index; + std::string storage_scope = "global"; + DataType dtype = token->dtype; + Call alloc_storage( + mem_alloc_storage, + {std::move(size), virtual_device_index, StringImm(storage_scope), DataTypeImm(dtype)}, + Attrs()); + token->storage = builder_->Emit(alloc_storage, "storage"); + } + + // And always create a `memory.alloc_tensor` for the old `builtin.alloc_tensor`. + static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor"); + PrimValue offset = PrimValue::Int64(0); + DataType dtype = sinfo->dtype; + return Call(mem_alloc_tensor, + {token->storage.value(), offset, sinfo->shape.value(), DataTypeImm(dtype)}, + Attrs()); + } + + return ExprMutator::VisitExpr_(call); + } + + /*! + * \brief The mapping from each memory-reusable `builtin.alloc_tensor` to + its corresponding underlying storage token that it is using. + */ + std::unordered_map alloc_tensor2token_; + /*! \brief The mapping from each Expr to the tensors that need to be killed after it. */ + std::unordered_map> expr2killed_tensors_; + /*! \brief The mapping from each binding block to the storage tokens that are create inside. */ + std::unordered_map> block2tokens_; +}; + +Expr StaticPlanBlockMemory(Function func, const IRModule& ctx_mod) { + // Step 1. Initialize. + StorageAllocatorInit initializer(ctx_mod); + std::unordered_map token_map = initializer.Initialize(func); + // Step 2. Collect the memory allocation info. + StorageAllocator allocator(std::move(token_map)); + allocator(func); + // Step 3. Rewrite the function. + StorageAllocationRewriter rewriter(std::move(allocator.alloc_tensor2token), + std::move(allocator.expr2killed_tensors), + std::move(allocator.block2tokens)); + func = Downcast(rewriter(func)); + return func; +} + +namespace transform { + +Pass StaticPlanBlockMemory() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(StaticPlanBlockMemory(std::move(f), m)); + }; + return CreateFunctionPass(pass_func, /*opt_level=*/0, "StaticPlanBlockMemory", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.StaticPlanBlockMemory").set_body_typed(StaticPlanBlockMemory); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_analysis_estimate_memory_usage.py b/tests/python/relax/test_analysis_estimate_memory_usage.py new file mode 100644 index 000000000000..3e6ba4499fe6 --- /dev/null +++ b/tests/python/relax/test_analysis_estimate_memory_usage.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm.script import relax as R, tir as T +from tvm.relax.analysis import estimate_memory_usage + + +def test_basic(): + @tvm.script.ir_module + class Module: + @T.prim_func + def add( + rxplaceholder: T.Buffer(T.int64(8), "float32"), + rxplaceholder_1: T.Buffer((), "float32"), + T_add: T.Buffer(T.int64(8), "float32"), + ): + T.evaluate(0) + + @T.prim_func + def reshape( + rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), + T_reshape: T.Buffer(T.int64(8), "float32"), + ): + T.evaluate(0) + + @T.prim_func + def relu( + rxplaceholder: T.Buffer(T.int64(8), "float32"), compute: T.Buffer(T.int64(8), "float32") + ): + T.evaluate(0) + + @T.prim_func + def log( + rxplaceholder: T.Buffer(T.int64(10), "float32"), + compute: T.Buffer(T.int64(10), "float32"), + ): + T.evaluate(0) + + @T.prim_func + def exp( + rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), + compute: T.Buffer((T.int64(2), T.int64(4)), "float32"), + ): + T.evaluate(0) + + @T.prim_func + def pad( + rxplaceholder: T.Buffer(T.int64(8), "float32"), + PadInput: T.Buffer(T.int64(10), "float32"), + ): + T.evaluate(0) + + @R.function + def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + storage: R.Object = R.memory.alloc_storage( + R.shape([32]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor( + storage, offset=0, shape=R.shape([2, 4]), dtype="float32" + ) + _: R.Tuple() = exp(x, alloc) + lv: R.Tensor((2, 4), dtype="float32") = alloc + lv1: R.Tensor((8,), dtype="float32") = R.call_packed( + "vm.builtin.reshape", lv, R.shape([8]), sinfo_args=[R.Tensor((8,), dtype="float32")] + ) + storage1: R.Object = R.memory.alloc_storage( + R.shape([40]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc1: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor( + storage1, offset=0, shape=R.shape([8]), dtype="float32" + ) + _1: R.Tuple() = relu(lv1, alloc1) + _2: R.Tuple() = R.memory.kill_tensor(alloc) + _3: R.Tuple() = R.memory.kill_tensor(lv1) + lv2: R.Tensor((8,), dtype="float32") = alloc1 + alloc2: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor( + storage, offset=0, shape=R.shape([8]), dtype="float32" + ) + _4: R.Tuple() = add(lv2, R.const(1, "float32"), alloc2) + _5: R.Tuple() = R.memory.kill_tensor(alloc1) + lv3: R.Tensor((8,), dtype="float32") = alloc2 + alloc3: R.Tensor((10,), dtype="float32") = R.memory.alloc_tensor( + storage1, offset=0, shape=R.shape([10]), dtype="float32" + ) + _6: R.Tuple() = pad(lv3, alloc3) + _7: R.Tuple() = R.memory.kill_tensor(alloc2) + lv4: R.Tensor((10,), dtype="float32") = alloc3 + alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor( + R.shape([10]), dtype="float32", runtime_device_index=0 + ) + _8: R.Tuple() = log(lv4, alloc4) + _9: R.Tuple() = R.memory.kill_tensor(alloc3) + gv5: R.Tensor((10,), dtype="float32") = alloc4 + _11: R.Tuple() = R.memory.kill_storage(storage) + _10: R.Tuple() = R.memory.kill_storage(storage1) + return gv5 + + assert ( + estimate_memory_usage(Module) + == r"""Memory usage estimation: +- Function main: + * Without memory planning, there are 5 constant-size memory allocation(s) with total size 1.639e-07 GB. + * With memory planning, there are 2 constant-size memory allocation(s) with total size 6.706e-08 GB. + * Memory planning reduces constant memory size to 40.9%.""" + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py new file mode 100644 index 000000000000..f11df58b26ed --- /dev/null +++ b/tests/python/relax/test_transform_static_plan_block_memory.py @@ -0,0 +1,612 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm import relax +from tvm.script import relax as R, tir as T + + +def test_basic(): + # fmt: off + @tvm.script.ir_module + class Module: + @T.prim_func + def add(rxplaceholder: T.Buffer(T.int64(8), "float32"), rxplaceholder_1: T.Buffer((), "float32"), T_add: T.Buffer(T.int64(8), "float32")): + T.evaluate(0) + + @T.prim_func + def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), T_reshape: T.Buffer(T.int64(8), "float32")): + T.evaluate(0) + + @T.prim_func + def relu(rxplaceholder: T.Buffer(T.int64(8), "float32"), compute: T.Buffer(T.int64(8), "float32")): + T.evaluate(0) + + @T.prim_func + def log(rxplaceholder: T.Buffer(T.int64(10), "float32"), compute: T.Buffer(T.int64(10), "float32")): + T.evaluate(0) + + @T.prim_func + def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")): + T.evaluate(0) + + @T.prim_func + def pad(rxplaceholder: T.Buffer(T.int64(8), "float32"), PadInput: T.Buffer(T.int64(10), "float32")): + T.evaluate(0) + + @R.function + def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + alloc: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), dtype="float32", runtime_device_index=0) + _: R.Tuple() = exp(x, alloc) + lv: R.Tensor((2, 4), dtype="float32") = alloc + lv1: R.Tensor((8,), dtype="float32") = R.reshape(lv, (8,)) + alloc1: R.Tensor((8,), dtype="float32") = R.builtin.alloc_tensor(R.shape([8]), dtype="float32", runtime_device_index=0) + _1: R.Tuple() = relu(lv1, alloc1) + lv2: R.Tensor((8,), dtype="float32") = alloc1 + alloc2: R.Tensor((8,), dtype="float32") = R.builtin.alloc_tensor(R.shape([8]), dtype="float32", runtime_device_index=0) + _2: R.Tuple() = add(lv2, R.const(1, "float32"), alloc2) + lv3: R.Tensor((8,), dtype="float32") = alloc2 + alloc3: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0) + _3: R.Tuple() = pad(lv3, alloc3) + lv4: R.Tensor((10,), dtype="float32") = alloc3 + alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0) + _4: R.Tuple() = log(lv4, alloc4) + gv: R.Tensor((10,), dtype="float32") = alloc4 + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def add(rxplaceholder: T.Buffer(T.int64(8), "float32"), rxplaceholder_1: T.Buffer((), "float32"), T_add: T.Buffer(T.int64(8), "float32")): + T.evaluate(0) + + @T.prim_func + def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), T_reshape: T.Buffer(T.int64(8), "float32")): + T.evaluate(0) + + @T.prim_func + def relu(rxplaceholder: T.Buffer(T.int64(8), "float32"), compute: T.Buffer(T.int64(8), "float32")): + T.evaluate(0) + + @T.prim_func + def log(rxplaceholder: T.Buffer(T.int64(10), "float32"), compute: T.Buffer(T.int64(10), "float32")): + T.evaluate(0) + + @T.prim_func + def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")): + T.evaluate(0) + + @T.prim_func + def pad(rxplaceholder: T.Buffer(T.int64(8), "float32"), PadInput: T.Buffer(T.int64(10), "float32")): + T.evaluate(0) + + @R.function + def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + storage: R.Object = R.memory.alloc_storage(R.shape([32]), virtual_device_index=0, storage_scope="global", dtype="float32") + alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), dtype="float32") + _: R.Tuple() = exp(x, alloc) + lv: R.Tensor((2, 4), dtype="float32") = alloc + lv1: R.Tensor((8,), dtype="float32") = R.reshape(lv, (8,)) + storage1: R.Object = R.memory.alloc_storage(R.shape([40]), virtual_device_index=0, storage_scope="global", dtype="float32") + alloc1: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(storage1, 0, R.shape([8]), dtype="float32") + _1: R.Tuple() = relu(lv1, alloc1) + _2: R.Tuple() = R.memory.kill_tensor(alloc) + _3: R.Tuple() = R.memory.kill_tensor(lv1) + lv2: R.Tensor((8,), dtype="float32") = alloc1 + alloc2: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(storage, 0, R.shape([8]), dtype="float32") + _4: R.Tuple() = add(lv2, R.const(1, "float32"), alloc2) + _5: R.Tuple() = R.memory.kill_tensor(alloc1) + lv3: R.Tensor((8,), dtype="float32") = alloc2 + alloc3: R.Tensor((10,), dtype="float32") = R.memory.alloc_tensor(storage1, 0, R.shape([10]), dtype="float32") + _6: R.Tuple() = pad(lv3, alloc3) + _7: R.Tuple() = R.memory.kill_tensor(alloc2) + lv4: R.Tensor((10,), dtype="float32") = alloc3 + alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0) + _8: R.Tuple() = log(lv4, alloc4) + _9: R.Tuple() = R.memory.kill_tensor(alloc3) + gv5: R.Tensor((10,), dtype="float32") = alloc4 + _11: R.Tuple() = R.memory.kill_storage(storage) + _10: R.Tuple() = R.memory.kill_storage(storage1) + return gv5 + # fmt: on + + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_different_dtype(): + @tvm.script.ir_module + class Module: + @T.prim_func + def add( + A: T.Buffer((T.int64(2), T.int64(3)), "float32"), + B: T.Buffer((T.int64(2), T.int64(3)), "float32"), + C: T.Buffer((T.int64(2), T.int64(3)), "float32"), + ): + T.evaluate(0) + + @T.prim_func + def add1( + A: T.Buffer((T.int64(2), T.int64(3)), "int32"), + B: T.Buffer((T.int64(2), T.int64(3)), "int32"), + C: T.Buffer((T.int64(2), T.int64(3)), "int32"), + ): + T.evaluate(0) + + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="int32") + ) -> R.Tensor((2, 3), dtype="float32"): + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _: R.Tuple() = add(x, x, alloc) + gv: R.Tensor((2, 3), dtype="float32") = alloc + alloc1: R.Tensor((2, 3), dtype="int32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="int32", runtime_device_index=0 + ) + _1: R.Tuple() = add1(y, y, alloc1) + gv1: R.Tensor((2, 3), dtype="int32") = alloc1 + return x + + @tvm.script.ir_module + class Expected: + @T.prim_func + def add( + A: T.Buffer((T.int64(2), T.int64(3)), "float32"), + B: T.Buffer((T.int64(2), T.int64(3)), "float32"), + C: T.Buffer((T.int64(2), T.int64(3)), "float32"), + ): + T.evaluate(0) + + @T.prim_func + def add1( + A: T.Buffer((T.int64(2), T.int64(3)), "int32"), + B: T.Buffer((T.int64(2), T.int64(3)), "int32"), + C: T.Buffer((T.int64(2), T.int64(3)), "int32"), + ): + T.evaluate(0) + + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="int32") + ) -> R.Tensor((2, 3), dtype="float32"): + storage: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="float32" + ) + _: R.Tuple() = add(x, x, alloc) + _1: R.Tuple() = R.memory.kill_tensor(alloc) + gv1: R.Tensor((2, 3), dtype="float32") = alloc + storage1: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="int32" + ) + alloc1: R.Tensor((2, 3), dtype="int32") = R.memory.alloc_tensor( + storage1, 0, R.shape([2, 3]), dtype="int32" + ) + _2: R.Tuple() = add1(y, y, alloc1) + _3: R.Tuple() = R.memory.kill_tensor(alloc1) + gv12: R.Tensor((2, 3), dtype="int32") = alloc1 + _5: R.Tuple() = R.memory.kill_storage(storage) + _4: R.Tuple() = R.memory.kill_storage(storage1) + return x + + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_same_dtype(): + @tvm.script.ir_module + class Module: + @T.prim_func + def add( + A: T.Buffer((T.int64(2), T.int64(3)), "float32"), + B: T.Buffer((T.int64(2), T.int64(3)), "float32"), + C: T.Buffer((T.int64(2), T.int64(3)), "float32"), + ): + T.evaluate(0) + + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32") + ) -> R.Tensor((2, 3), dtype="float32"): + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _: R.Tuple() = add(x, x, alloc) + gv: R.Tensor((2, 3), dtype="float32") = alloc + alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _1: R.Tuple() = add(y, y, alloc1) + gv1: R.Tensor((2, 3), dtype="float32") = alloc1 + return x + + @tvm.script.ir_module + class Expected: + @T.prim_func + def add( + A: T.Buffer((T.int64(2), T.int64(3)), "float32"), + B: T.Buffer((T.int64(2), T.int64(3)), "float32"), + C: T.Buffer((T.int64(2), T.int64(3)), "float32"), + ): + T.evaluate(0) + + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32") + ) -> R.Tensor((2, 3), dtype="float32"): + storage: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="float32" + ) + _: R.Tuple() = add(x, x, alloc) + _1: R.Tuple() = R.memory.kill_tensor(alloc) + gv1: R.Tensor((2, 3), dtype="float32") = alloc + alloc1: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="float32" + ) + _2: R.Tuple() = add(y, y, alloc1) + _3: R.Tuple() = R.memory.kill_tensor(alloc1) + gv12: R.Tensor((2, 3), dtype="float32") = alloc1 + _4: R.Tuple() = R.memory.kill_storage(storage) + return x + + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_if_cond(): + @tvm.script.ir_module + class Module: + @T.prim_func + def all_less_than_zero(A: T.Buffer((2, 3), "float32"), B: T.Buffer((), "bool")): + T.evaluate(0) + + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + alloc: R.Tensor((), dtype="bool") = R.builtin.alloc_tensor( + R.shape([]), dtype="bool", runtime_device_index=0 + ) + _: R.Tuple() = all_less_than_zero(x, alloc) + x1: R.Tensor((), dtype="bool") = alloc + if x1: + y: R.Tensor((2, 3), dtype="float32") = x + else: + alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _1: R.Tuple() = exp(x, alloc1) + gv3: R.Tensor((2, 3), dtype="float32") = alloc1 + y: R.Tensor((2, 3), dtype="float32") = gv3 + return x + + # The pass does no change. + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Module) + + +def test_if_then_else(): + @tvm.script.ir_module + class Module: + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + + @R.function + def main( + cond: R.Tensor((), dtype="bool"), x: R.Tensor((2, 3), dtype="float32") + ) -> R.Tensor((2, 3), dtype="float32"): + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _: R.Tuple() = exp(x, alloc) + y: R.Tensor((2, 3), dtype="float32") = alloc + if cond: + z: R.Tensor((2, 3), dtype="float32") = y + else: + z: R.Tensor((2, 3), dtype="float32") = y + return x + + # The pass does no change. + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Module) + + +def test_cross_block_use(): + @tvm.script.ir_module + class Module: + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + + @R.function + def main( + cond: R.Tensor((), dtype="bool"), x: R.Tensor((2, 3), dtype="float32") + ) -> R.Tensor((2, 3), dtype="float32"): + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _: R.Tuple() = exp(x, alloc) + y: R.Tensor((2, 3), dtype="float32") = alloc + if cond: + alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _1: R.Tuple() = exp(y, alloc1) + y2: R.Tensor((2, 3), dtype="float32") = alloc1 + z: R.Tensor((2, 3), dtype="float32") = y2 + else: + alloc2: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _2: R.Tuple() = exp(y, alloc2) + y2: R.Tensor((2, 3), dtype="float32") = alloc2 + z: R.Tensor((2, 3), dtype="float32") = y2 + return x + + # The pass does no change. + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Module) + + +def test_nested_tuple(): + @tvm.script.ir_module + class Module: + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _: R.Tuple() = exp(x, alloc) + y1: R.Tensor((2, 3), dtype="float32") = alloc + alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _1: R.Tuple() = exp(x, alloc1) + y2: R.Tensor((2, 3), dtype="float32") = alloc1 + alloc2: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _2: R.Tuple() = exp(x, alloc2) + y3: R.Tensor((2, 3), dtype="float32") = alloc2 + t: R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")) = ( + y1, + y2, + ) + nt: R.Tuple( + R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")), + R.Tensor((2, 3), dtype="float32"), + ) = (t, y3) + nt0: R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")) = nt[ + 0 + ] + y1_: R.Tensor((2, 3), dtype="float32") = nt0[0] + y2_: R.Tensor((2, 3), dtype="float32") = nt0[1] + y3_: R.Tensor((2, 3), dtype="float32") = nt[1] + alloc3: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _3: R.Tuple() = exp(y1_, alloc3) + z1: R.Tensor((2, 3), dtype="float32") = alloc3 + alloc4: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _4: R.Tuple() = exp(y2_, alloc4) + z2: R.Tensor((2, 3), dtype="float32") = alloc4 + alloc5: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _5: R.Tuple() = exp(y3_, alloc5) + z3: R.Tensor((2, 3), dtype="float32") = alloc5 + return x + + @tvm.script.ir_module + class Expected: + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + storage: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="float32" + ) + _: R.Tuple() = exp(x, alloc) + y1: R.Tensor((2, 3), dtype="float32") = alloc + storage1: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc1: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage1, 0, R.shape([2, 3]), dtype="float32" + ) + _1: R.Tuple() = exp(x, alloc1) + y2: R.Tensor((2, 3), dtype="float32") = alloc1 + storage2: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc2: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage2, 0, R.shape([2, 3]), dtype="float32" + ) + _2: R.Tuple() = exp(x, alloc2) + y3: R.Tensor((2, 3), dtype="float32") = alloc2 + t: R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")) = ( + y1, + y2, + ) + nt: R.Tuple( + R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")), + R.Tensor((2, 3), dtype="float32"), + ) = (t, y3) + nt0: R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")) = nt[ + 0 + ] + y1_: R.Tensor((2, 3), dtype="float32") = nt0[0] + y2_: R.Tensor((2, 3), dtype="float32") = nt0[1] + y3_: R.Tensor((2, 3), dtype="float32") = nt[1] + storage3: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc3: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage3, 0, R.shape([2, 3]), dtype="float32" + ) + _3: R.Tuple() = exp(y1_, alloc3) + _4: R.Tuple() = R.memory.kill_tensor(alloc) + _11: R.Tuple() = R.memory.kill_tensor(alloc3) + z1: R.Tensor((2, 3), dtype="float32") = alloc3 + alloc4: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="float32" + ) + _41: R.Tuple() = exp(y2_, alloc4) + _21: R.Tuple() = R.memory.kill_tensor(alloc1) + _31: R.Tuple() = R.memory.kill_tensor(alloc4) + z2: R.Tensor((2, 3), dtype="float32") = alloc4 + alloc5: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage3, 0, R.shape([2, 3]), dtype="float32" + ) + _5: R.Tuple() = exp(y3_, alloc5) + _42: R.Tuple() = R.memory.kill_tensor(alloc2) + _51: R.Tuple() = R.memory.kill_tensor(alloc5) + z3: R.Tensor((2, 3), dtype="float32") = alloc5 + _9: R.Tuple() = R.memory.kill_storage(storage) + _7: R.Tuple() = R.memory.kill_storage(storage1) + _8: R.Tuple() = R.memory.kill_storage(storage2) + _6: R.Tuple() = R.memory.kill_storage(storage3) + return x + + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_call_func_other_than_primfunc(): + @tvm.script.ir_module + class Module: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _ = R.add(x, alloc) + y: R.Tensor((2, 3), dtype="float32") = alloc + return x + + # The pass does no change. + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Module) + + +def test_symbolic_shape(): + @tvm.script.ir_module + class Module: + @T.prim_func + def exp(var_A: T.handle, var_B: T.handle): + m = T.var("int64") + n = T.var("int64") + A = T.match_buffer(var_A, (m, n), "float32") + B = T.match_buffer(var_B, (m, n), "float32") + T.evaluate(0) + + @R.function + def main(x: R.Tensor(("m", "n"), "float32")): + m = T.var("int64") + n = T.var("int64") + alloc: R.Tensor((m, n), dtype="float32") = R.builtin.alloc_tensor( + R.shape([m, n]), dtype="float32", runtime_device_index=0 + ) + _ = exp(x, alloc) + y: R.Tensor((m, n), dtype="float32") = alloc + return x + + # The pass does no change. + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Module) + + +def test_zero_reference(): + @tvm.script.ir_module + class Module: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + return x + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + storage: R.Object = R.memory.alloc_storage( + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, 0, R.shape([2, 3]), dtype="float32" + ) + _: R.Tuple() = R.memory.kill_storage(storage) + return x + + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_reshape_param(): + @tvm.script.ir_module + class Module: + @T.prim_func + def add( + A: T.Buffer((T.int64(2), T.int64(25), T.int64(2)), "float32"), + B: T.Buffer((T.int64(2), T.int64(25), T.int64(2)), "float32"), + C: T.Buffer((T.int64(2), T.int64(25), T.int64(2)), "float32"), + ): + T.evaluate(0) + + @R.function + def main( + x: R.Tensor((2, 50), dtype="float32"), y: R.Tensor((100,), dtype="float32") + ) -> R.Tensor((2, 25, 2), dtype="float32"): + lv: R.Tensor((2, 25, 2), dtype="float32") = R.reshape(x, (2, 25, 2)) + lv1: R.Tensor((2, 25, 2), dtype="float32") = R.reshape(y, (2, 25, 2)) + alloc: R.Tensor((2, 25, 2), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 25, 2]), dtype="float32", runtime_device_index=0 + ) + _: R.Tuple() = add(lv, lv1, alloc) + gv: R.Tensor((2, 25, 2), dtype="float32") = alloc + return gv + + # The pass does no change. + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Module) + + +if __name__ == "__main__": + tvm.testing.main() From ff8473727099858a0b8629c747dcb11eb9051587 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Sat, 18 Feb 2023 11:41:10 -0800 Subject: [PATCH 39/81] [Unity] Disallow inline prim_func in relax IR (#14040) Disallow inline prim_func in relax IR --- python/tvm/script/parser/relax/parser.py | 10 +++++ src/relax/analysis/well_formed.cc | 6 ++- src/relax/ir/block_builder.cc | 22 --------- .../python/relax/test_analysis_well_formed.py | 36 +++++++++++++++ tests/python/relax/test_tvmscript_parser.py | 45 +++++++++---------- 5 files changed, 73 insertions(+), 46 deletions(-) diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index ef26ddd6e921..e5e5bb2743e1 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -139,6 +139,16 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: R.func_ret_struct_info(ann_sinfo) self.visit(node.args) + + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + if not stmt.decorator_list: + self.report_error(stmt, "Function must be decorated") + dec = self.eval_expr(stmt.decorator_list[-1]) + # inline prim_func was found + if dec.dispatch_token == "tir": + self.report_error(stmt, "inline prim_func is disallowed in Relax IR") + self.visit_body(node.body) diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index e7ec237fd577..05ad0954bbfc 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -316,7 +316,11 @@ class WellFormedChecker : public relax::ExprVisitor, } void VisitBinding_(const VarBindingNode* binding) final { - this->VisitExpr(binding->value); + if (binding->value->IsInstance()) { + Malformed(Diagnostic::Error(binding->value) << "Inline PrimFunc is disallowed in Relax IR."); + } else { + this->VisitExpr(binding->value); + } this->VisitVarDef(binding->var); } diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 6a2d7ea5c584..5976cbb3f441 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -469,12 +469,6 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor()) { - return NormalizePrimFunc(GetRef(prim_func)); - } - if (!block_stack_.empty()) { // cache lookup BlockFrame* cur_frame = CurrentBlockFrame(); @@ -520,23 +514,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor(var); } - // Temp patch to ensure we handle inline PrimFunc case. - // TODO(relax-team) remove such cases from parser and testcases. - Expr NormalizePrimFunc(tir::PrimFunc prim_func) { - if (!prim_func->struct_info_.defined()) { - auto finfo = FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type)); - UpdateStructInfo(prim_func, finfo); - } - return prim_func; - } - Expr VisitExpr(const Expr& expr) final { - // Temp patch to ensure we handle inline PrimFunc case. - // TODO(relax-team) remove such cases from parser and testcases. - if (auto* prim_func = expr.as()) { - return NormalizePrimFunc(GetRef(prim_func)); - } - // lookup normalize map if (!block_stack_.empty()) { BlockFrame* cur_frame = CurrentBlockFrame(); diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index cc0de84d53af..67da77274188 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -357,6 +357,42 @@ def test_complex_seq_body(): assert rx.analysis.well_formed(normalized, check_struct_info=True) +def test_inline_prim_func(): + # Error: inline prim_func is disallowed in Relax IR + x = rx.Var("x", R.Tensor([], "int32")) + y = rx.Var("y", R.Tensor([], "int32")) + new_func = rx.Function( + [], + rx.SeqExpr( + [ + rx.BindingBlock( + [ + rx.VarBinding( + var=x, + value=tir.PrimFunc([], tir.Evaluate(0)), + ), + rx.VarBinding( + var=y, + value=rx.Call( + op=tvm.ir.Op.get("relax.call_tir"), + args=[ + rx.GlobalVar("GlobalVar0"), + rx.Tuple([x, tir.PrimFunc([], tir.Evaluate(0))]), + rx.ShapeExpr([]), + ], + ), + ), + ] + ) + ], + y, + ), + R.Tensor(ndim=0, dtype="int32"), + ).with_attr("global_symbol", "foo") + new_mod = tvm.IRModule.from_expr(new_func) + assert not rx.analysis.well_formed(new_mod, check_struct_info=False) + + def test_ANF(): # Error: Nested Call gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 6e9e14d3dc47..507ce72c0676 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -736,30 +736,29 @@ def inner_func(x1: R.Tensor((2, 3), "float32")): inner_func = outer_func_bindings[0].value assert isinstance(inner_func, relax.Function) - @I.ir_module - class TestModule: - @R.function - def f(x: R.Tensor((128, 128), "float32"), y: R.Tensor((128, 128), "float32")): - @T.prim_func - def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (128, 128)) - B = T.match_buffer(b, (128, 128)) - C = T.match_buffer(c, (128, 128)) - - for i, j, k in T.grid(128, 128, 128): - with T.block(): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] += A[vi, vk] * B[vj, vk] - - z = relax.call_tir(my_matmul, (x, y), R.Tensor((128, 128), dtype="float32")) - return z - bindings = TestModule["f"].body.blocks[0].bindings - assert len(bindings) == 2 - tir_func = bindings[0].value - assert isinstance(tir_func, tir.PrimFunc) +def test_inline_prim_func(): + with pytest.raises(tvm.error.DiagnosticError): + + @I.ir_module + class TestModule: + @R.function + def f(x: R.Tensor((128, 128), "float32"), y: R.Tensor((128, 128), "float32")): + @T.prim_func + def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) + + for i, j, k in T.grid(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] += A[vi, vk] * B[vj, vk] + + z = relax.call_tir(my_matmul, (x, y), R.Tensor((128, 128), dtype="float32")) + return z def test_cross_function_call(): From 7d2296fb3905684dea98ee6032d6cd93e12db1bf Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 18 Feb 2023 16:56:36 -0500 Subject: [PATCH 40/81] [Unity] Update tests to adapt to latest TVMScript syntax (#14039) Given that some latest changes of TVMScript syntax have been merged, some test files are now containing deprecated uses of TVMScript syntax. This PR updates the test files with latest TVMScript syntax so that running the tests will not trigger deprecation warnings. Co-authored-by: Tianqi Chen --- .../tvm/relax/frontend/torch/fx_translator.py | 11 +- tests/python/relax/test_frontend_dynamo.py | 8 +- tests/python/relax/test_transform.py | 6 +- .../test_transform_annotate_tir_op_pattern.py | 6 +- .../test_transform_attach_global_symbol.py | 16 +- .../relax/test_transform_fold_constant.py | 24 +- .../relax/test_transform_lambda_lift.py | 10 +- .../test_transform_legalize_ops_binary.py | 264 +++++++++--------- ..._transform_legalize_ops_create_datatype.py | 120 ++++---- .../test_transform_legalize_ops_image.py | 28 +- ...sform_legalize_ops_index_linear_algebra.py | 54 ++-- .../test_transform_legalize_ops_manipulate.py | 142 +++++----- .../relax/test_transform_legalize_ops_nn.py | 154 +++++----- ...ansform_legalize_ops_search_statistical.py | 106 +++---- .../test_transform_legalize_ops_unary.py | 120 ++++---- .../test_transform_meta_schedule_tuning.py | 2 +- .../python/relax/test_transform_normalize.py | 2 +- ...test_transform_static_plan_block_memory.py | 8 +- tests/python/relax/test_tuning_api.py | 2 +- .../python/relax/test_tvmscript_ir_builder.py | 4 +- tests/python/relax/test_tvmscript_parser.py | 42 +-- tests/python/relax/test_vm_build.py | 22 +- 22 files changed, 574 insertions(+), 577 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 582f2edbcf55..a762b0a0fbbd 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -76,9 +76,8 @@ def _convert_data_type(input_type): @staticmethod def _convert_torch_tensor_to_relax(tensor: torch.Tensor) -> relax.Var: tensor = tensor.detach().cpu() - shape = tensor.data.shape dtype = TorchFXImporter._convert_data_type(str(tensor.data.dtype)) - return relax.const(tensor.data.numpy(), relax.TensorStructInfo(shape, dtype)) + return relax.const(tensor.data.numpy(), dtype) @staticmethod def shape_of(tensor): @@ -444,8 +443,8 @@ def _layer_norm(self, node: fx.node.Node) -> relax.Var: gamma = self.params[module.weight] beta = self.params[module.bias] else: - gamma = relax.const(torch.ones_like(module.normalized_shape), x.checked_type) - beta = relax.const(torch.zeros_like(module.normalized_shape), x.checked_type) + gamma = relax.const(torch.ones_like(module.normalized_shape), x.struct_info.dtype) + beta = relax.const(torch.zeros_like(module.normalized_shape), x.struct_info.dtype) dim_num = len(module.normalized_shape) axes = list(range(-dim_num, 0)) @@ -702,9 +701,7 @@ def from_fx(self, model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModu shape = param.data.shape dtype = self._convert_data_type(str(param.data.dtype)) if dtype in ("float32", "float16"): - self.params[param] = relax.const( - param.data.cpu().numpy(), relax.TensorStructInfo(shape, dtype) - ) + self.params[param] = relax.const(param.data.cpu().numpy(), dtype) else: raise ValueError("Unsupported data type for model parameters: %s" % dtype) # Translate the model. diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py index 370df2103d79..b47e3e22bd71 100644 --- a/tests/python/relax/test_frontend_dynamo.py +++ b/tests/python/relax/test_frontend_dynamo.py @@ -43,10 +43,10 @@ def forward(self, x): class Input1_ir: @T.prim_func def main( - inp_0: T.Buffer[(T.int64(10), T.int64(100)), "float32"], - param_0: T.Buffer[(T.int64(100), T.int64(10)), "float32"], - param_1: T.Buffer[T.int64(10), "float32"], - compute: T.Buffer[(T.int64(10), T.int64(10)), "float32"], + inp_0: T.Buffer((T.int64(10), T.int64(100)), "float32"), + param_0: T.Buffer((T.int64(100), T.int64(10)), "float32"), + param_1: T.Buffer(T.int64(10), "float32"), + compute: T.Buffer((T.int64(10), T.int64(10)), "float32"), ): # function attr dict T.func_attr({"tir.noalias": True, "global_symbol": "main"}) diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 12dd095c6b5d..85de4f912ecf 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -30,7 +30,7 @@ def test_to_non_dataflow(): class TestToNonDataflow: @R.function def foo(x: R.Tensor(("m", "n"), "float32")): - m, n = T.var("int64"), T.var("int64") + m, n = T.int64(), T.int64() with R.dataflow(): lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32")) @@ -75,7 +75,7 @@ def test_call_tir_rewrite(): class TestCallTIRRewrite: @R.function def foo(x: R.Tensor(("m", "n"), "float32")): - m, n = T.var("int64"), T.var("int64") + m, n = T.int64(), T.int64() gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) return gv0 @@ -108,7 +108,7 @@ def test_vm_builtin_lower(): class TestVMBuiltinLower: @R.function def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor: - m, n = T.var("int64"), T.var("int64") + m, n = T.int64(), T.int64() alloc = R.builtin.alloc_tensor(R.shape([m, n]), runtime_device_index=0, dtype="float32") _ = R.call_packed( "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) diff --git a/tests/python/relax/test_transform_annotate_tir_op_pattern.py b/tests/python/relax/test_transform_annotate_tir_op_pattern.py index 73c65378693a..23ce49a7c220 100644 --- a/tests/python/relax/test_transform_annotate_tir_op_pattern.py +++ b/tests/python/relax/test_transform_annotate_tir_op_pattern.py @@ -39,9 +39,9 @@ class InputModule: @T.prim_func def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: T.func_attr({"global_symbol": "tir_matmul"}) - m = T.var("int32") - n = T.var("int32") - k = T.var("int32") + m = T.int32() + n = T.int32() + k = T.int32() A = T.match_buffer(x, (m, n)) B = T.match_buffer(y, (n, k)) C = T.match_buffer(z, (m, k)) diff --git a/tests/python/relax/test_transform_attach_global_symbol.py b/tests/python/relax/test_transform_attach_global_symbol.py index edfc646e2108..cef3842e3e49 100644 --- a/tests/python/relax/test_transform_attach_global_symbol.py +++ b/tests/python/relax/test_transform_attach_global_symbol.py @@ -28,9 +28,9 @@ class Before: @T.prim_func def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: - m = T.var("int64") - n = T.var("int64") - k = T.var("int64") + m = T.int64() + n = T.int64() + k = T.int64() A = T.match_buffer(x, (m, n)) B = T.match_buffer(y, (n, k)) C = T.match_buffer(z, (m, k)) @@ -44,7 +44,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: @R.function def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")) -> R.Tensor: - m, n, k = T.var("int64"), T.var("int64"), T.var("int64") + m, n, k = T.int64(), T.int64(), T.int64() gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), dtype="float32")) return gv0 @@ -55,9 +55,9 @@ class Expected: @T.prim_func def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: T.func_attr({"global_symbol": "tir_matmul"}) - m = T.var("int64") - n = T.var("int64") - k = T.var("int64") + m = T.int64() + n = T.int64() + k = T.int64() A = T.match_buffer(x, (m, n)) B = T.match_buffer(y, (n, k)) C = T.match_buffer(z, (m, k)) @@ -74,7 +74,7 @@ def main( x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32") ) -> R.Tensor: R.func_attr({"global_symbol": "main"}) - m, n, k = T.var("int64"), T.var("int64"), T.var("int64") + m, n, k = T.int64(), T.int64(), T.int64() gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), dtype="float32")) return gv0 diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py index 32ee3e700080..95542dd4e6ca 100644 --- a/tests/python/relax/test_transform_fold_constant.py +++ b/tests/python/relax/test_transform_fold_constant.py @@ -59,7 +59,7 @@ def test_one_fold_addone(): @tvm.script.ir_module class Module: @T.prim_func - def addone(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"]) -> None: + def addone(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")) -> None: for i, j in T.grid(16, 16): with T.block("addone"): vi, vj = T.axis.remap("SS", [i, j]) @@ -89,7 +89,7 @@ def test_one_fold_transpose(): @tvm.script.ir_module class Module: @T.prim_func - def func(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(3, 2), "float32"]) -> None: + def func(A: T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2), "float32")) -> None: for i, j in T.grid(3, 2): with T.block("transpose"): vi, vj = T.axis.remap("SS", [i, j]) @@ -118,7 +118,7 @@ def test_two_hop_addone(): @tvm.script.ir_module class Module: @T.prim_func - def addone(A: T.Buffer[(2, 2), "float32"], B: T.Buffer[(2, 2), "float32"]) -> None: + def addone(A: T.Buffer((2, 2), "float32"), B: T.Buffer((2, 2), "float32")) -> None: for i, j in T.grid(2, 2): with T.block("addone"): vi, vj = T.axis.remap("SS", [i, j]) @@ -150,7 +150,7 @@ def test_dataflow_fold(): @tvm.script.ir_module class Module: @T.prim_func - def identity(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"]) -> None: + def identity(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")) -> None: for i, j in T.grid(16, 16): with T.block("identity"): vi, vj = T.axis.remap("SS", [i, j]) @@ -184,8 +184,8 @@ class Module: # TIR function can handle different cases. @T.prim_func def addone(a: T.handle, b: T.handle) -> None: - n = T.var("int32") - m = T.var("int32") + n = T.int32() + m = T.int32() A = T.match_buffer(a, (n, m)) B = T.match_buffer(b, (n, m)) for i, j in T.grid(n, m): @@ -195,9 +195,9 @@ def addone(a: T.handle, b: T.handle) -> None: @T.prim_func def sub( - A: T.Buffer[(16, 16), "float32"], - B: T.Buffer[(16, 16), "float32"], - C: T.Buffer[(16, 16), "float32"], + A: T.Buffer((16, 16), "float32"), + B: T.Buffer((16, 16), "float32"), + C: T.Buffer((16, 16), "float32"), ) -> None: for i, j in T.grid(16, 16): with T.block("sub"): @@ -206,7 +206,7 @@ def sub( @R.function def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor("float32", ndim=2)): - n, m = T.var("int64"), T.var("int64") + n, m = T.int64(), T.int64() x0 = R.match_cast(x, R.Tensor((n, m), "float32")) # this line cannot be folded because n is unknown lv0 = relax.call_tir(addone, (c0,), R.Tensor((n, 16), dtype="float32")) @@ -225,7 +225,7 @@ def expected( c2: R.Tensor((16, 16), "float32"), x: R.Tensor("float32", ndim=2), ) -> R.Tensor: - n, m = T.var("int64"), T.var("int64") + n, m = T.int64(), T.int64() x0 = R.match_cast(x, R.Tensor((n, m), "float32")) # this line cannot be folded because n is unknown lv0 = relax.call_tir(addone, (c0,), R.Tensor((n, 16), dtype="float32")) @@ -251,7 +251,7 @@ def test_int32_fold(): @tvm.script.ir_module class Module: @T.prim_func - def addone(A: T.Buffer[(16, 16), "int32"], B: T.Buffer[(16, 16), "int32"]) -> None: + def addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: for i, j in T.grid(16, 16): with T.block("addone"): vi, vj = T.axis.remap("SS", [i, j]) diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py index fbdb1fbdcea9..c9bbc0fb91e7 100644 --- a/tests/python/relax/test_transform_lambda_lift.py +++ b/tests/python/relax/test_transform_lambda_lift.py @@ -190,7 +190,7 @@ def while_loop( before = Before expected = Expected - # Perform Lamda Lifting + # Perform Lambda Lifting after = transform.LambdaLift()(before) assert len(after.functions) == 2 @@ -266,7 +266,7 @@ def inner( before = Before expected = Expected - # Perform Lamda Lifting + # Perform Lambda Lifting after = transform.LambdaLift()(before) assert len(after.functions) == 4 assert_structural_equal(after, expected, map_free_vars=True) @@ -278,9 +278,9 @@ def test_no_local_func(): class Before: @T.prim_func def sub( - A: T.Buffer[(16, 16), "float32"], - B: T.Buffer[(16, 16), "float32"], - C: T.Buffer[(16, 16), "float32"], + A: T.Buffer((16, 16), "float32"), + B: T.Buffer((16, 16), "float32"), + C: T.Buffer((16, 16), "float32"), ) -> None: for i, j in T.grid(16, 16): with T.block("sub"): diff --git a/tests/python/relax/test_transform_legalize_ops_binary.py b/tests/python/relax/test_transform_legalize_ops_binary.py index c2db7e9ba1a1..c99fb885c46c 100644 --- a/tests/python/relax/test_transform_legalize_ops_binary.py +++ b/tests/python/relax/test_transform_legalize_ops_binary.py @@ -124,10 +124,10 @@ def test_add_symbolic(): class Add: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv: R.Tensor((a, b, c, d), "float32") = R.add(x, y) return gv @@ -135,20 +135,20 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), class Expected: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv = R.call_tir(add, (x, y), R.Tensor((a, b, c, d), dtype="float32")) return gv @T.prim_func def add(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_add: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_add = T.match_buffer(var_T_add, [a, b, c, d], dtype="float32") @@ -263,10 +263,10 @@ def test_divide_symbolic(): class Divide: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv: R.Tensor((a, b, c, d), "float32") = R.divide(x, y) return gv @@ -274,20 +274,20 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), class Expected: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv = R.call_tir(divide, (x, y), R.Tensor((a, b, c, d), dtype="float32")) return gv @T.prim_func def divide(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_divide: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_divide = T.match_buffer(var_T_divide, [a, b, c, d], dtype="float32") @@ -402,10 +402,10 @@ def test_floor_divide_symbolic(): class FloorDivide: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv: R.Tensor((a, b, c, d), "float32") = R.floor_divide(x, y) return gv @@ -413,20 +413,20 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), class Expected: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv = R.call_tir(floor_divide, (x, y), R.Tensor((a, b, c, d), dtype="float32")) return gv @T.prim_func def floor_divide(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_floor_divide: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_floor_divide = T.match_buffer(var_T_floor_divide, [a, b, c, d], dtype="float32") @@ -479,10 +479,10 @@ def test_multiply_symbolic(): class Multiply: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv: R.Tensor((a, b, c, d), "float32") = R.multiply(x, y) return gv @@ -490,20 +490,20 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), class Expected: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv = R.call_tir(multiply, (x, y), R.Tensor((a, b, c, d), dtype="float32")) return gv @T.prim_func def multiply(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_multiply: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_multiply = T.match_buffer(var_T_multiply, [a, b, c, d], dtype="float32") @@ -556,10 +556,10 @@ def test_subtract_symbolic(): class Subtract: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv: R.Tensor((a, b, c, d), "float32") = R.subtract(x, y) return gv @@ -567,20 +567,20 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), class Expected: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv = R.call_tir(subtract, (x, y), R.Tensor((a, b, c, d), dtype="float32")) return gv @T.prim_func def subtract(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_subtract: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_subtract = T.match_buffer(var_T_subtract, [a, b, c, d], dtype="float32") @@ -698,10 +698,10 @@ def test_equal_symbolic(): class Equal: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv: R.Tensor((a, b, c, d), "bool") = R.equal(x, y) return gv @@ -709,20 +709,20 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), class Expected: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv = R.call_tir(equal, (x, y), R.Tensor((a, b, c, d), dtype="bool")) return gv @T.prim_func def equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_equal: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_equal = T.match_buffer(var_T_equal, [a, b, c, d], dtype="bool") @@ -837,10 +837,10 @@ def test_greater_symbolic(): class Greater: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv: R.Tensor((a, b, c, d), "bool") = R.greater(x, y) return gv @@ -848,20 +848,20 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), class Expected: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv = R.call_tir(greater, (x, y), R.Tensor((a, b, c, d), dtype="bool")) return gv @T.prim_func def greater(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_greater: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_greater = T.match_buffer(var_T_greater, [a, b, c, d], dtype="bool") @@ -914,10 +914,10 @@ def test_greater_equal_symbolic(): class GreaterEqual: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv: R.Tensor((a, b, c, d), "bool") = R.greater_equal(x, y) return gv @@ -925,20 +925,20 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), class Expected: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv = R.call_tir(greater_equal, (x, y), R.Tensor((a, b, c, d), dtype="bool")) return gv @T.prim_func def greater_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_greater_equal: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_greater_equal = T.match_buffer(var_T_greater_equal, [a, b, c, d], dtype="bool") @@ -991,10 +991,10 @@ def test_less_symbolic(): class Less: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv: R.Tensor((a, b, c, d), "bool") = R.less(x, y) return gv @@ -1002,20 +1002,20 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), class Expected: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv = R.call_tir(less, (x, y), R.Tensor((a, b, c, d), dtype="bool")) return gv @T.prim_func def less(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_less: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_less = T.match_buffer(var_T_less, [a, b, c, d], dtype="bool") @@ -1130,10 +1130,10 @@ def test_less_equal_symbolic(): class LessEqual: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv: R.Tensor((a, b, c, d), "bool") = R.less_equal(x, y) return gv @@ -1141,20 +1141,20 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), class Expected: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv = R.call_tir(less_equal, (x, y), R.Tensor((a, b, c, d), dtype="bool")) return gv @T.prim_func def less_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_less_equal: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_less_equal = T.match_buffer(var_T_less_equal, [a, b, c, d], dtype="bool") @@ -1207,10 +1207,10 @@ def test_not_equal_symbolic(): class NotEqual: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv: R.Tensor((a, b, c, d), "bool") = R.not_equal(x, y) return gv @@ -1218,20 +1218,20 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), class Expected: @R.function def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "bool"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv = R.call_tir(not_equal, (x, y), R.Tensor((a, b, c, d), dtype="bool")) return gv @T.prim_func def not_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_not_equal: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(1), c, d], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b, c, T.int64(1)], dtype="float32") T_not_equal = T.match_buffer(var_T_not_equal, [a, b, c, d], dtype="bool") diff --git a/tests/python/relax/test_transform_legalize_ops_create_datatype.py b/tests/python/relax/test_transform_legalize_ops_create_datatype.py index 2506e966345f..6082f7410264 100644 --- a/tests/python/relax/test_transform_legalize_ops_create_datatype.py +++ b/tests/python/relax/test_transform_legalize_ops_create_datatype.py @@ -123,8 +123,8 @@ def test_full_symbolic(): class Full: @R.function def main(dumb_param: R.Tensor(("m", "n")), v: R.Tensor((), "int32")) -> R.Tensor(("m", "n"), "int32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "int32") = R.full((m, n), v, dtype="int32") return gv @@ -132,16 +132,16 @@ def main(dumb_param: R.Tensor(("m", "n")), v: R.Tensor((), "int32")) -> R.Tensor class Expected: @R.function def main(dumb_param: R.Tensor(("m", "n")), v: R.Tensor((), "int32")) -> R.Tensor(("m", "n"), "int32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(full, (v,), R.Tensor((m, n), dtype="int32")) return gv @T.prim_func def full(rxplaceholder: T.Buffer((), "int32"), var_T_full: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() T_full = T.match_buffer(var_T_full, [m, n], dtype="int32") for i0, i1 in T.grid(m, n): with T.block("T_full"): @@ -254,8 +254,8 @@ def test_full_like_symbolic(): class FullLike: @R.function def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "float32") = R.full_like(x, v) return gv @@ -263,16 +263,16 @@ def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tens class Expected: @R.function def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(full, (v,), R.Tensor((m, n), dtype="float32")) return gv @T.prim_func def full(rxplaceholder: T.Buffer((), "float32"), var_T_full: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): with T.block("T_full"): @@ -323,8 +323,8 @@ def test_ones_symbolic(): class Ones: @R.function def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "float32") = R.ones((m, n), "float32") return gv @@ -332,16 +332,16 @@ def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), "float32"): class Expected: @R.function def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(ones, R.tuple(), R.Tensor((m, n), dtype="float32")) return gv @T.prim_func def ones(var_T_full: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): with T.block("T_full"): @@ -392,8 +392,8 @@ def test_ones_like_symbolic(): class OnesLike: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "float32") = R.ones_like(x) return gv @@ -401,16 +401,16 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): class Expected: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(ones, R.tuple(), R.Tensor((m, n), dtype="float32")) return gv @T.prim_func def ones(var_T_full: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): with T.block("T_full"): @@ -461,8 +461,8 @@ def test_zeros_symbolic(): class Zeros: @R.function def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "float32") = R.zeros((m, n), "float32") return gv @@ -470,16 +470,16 @@ def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), "float32"): class Expected: @R.function def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(zeros, R.tuple(), R.Tensor((m, n), dtype="float32")) return gv @T.prim_func def zeros(var_T_full: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): with T.block("T_full"): @@ -530,8 +530,8 @@ def test_zeros_like_symbolic(): class ZerosLike: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "float32") = R.zeros_like(x) return gv @@ -539,16 +539,16 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): class Expected: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(zeros, R.tuple(), R.Tensor((m, n), dtype="float32")) return gv @T.prim_func def zeros(var_T_full: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): with T.block("T_full"): @@ -599,9 +599,9 @@ def test_tril_symbolic(): class Tril: @R.function def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", "k"), "int8"): - m = T.var("int64") - n = T.var("int64") - k = T.var("int64") + m = T.int64() + n = T.int64() + k = T.int64() gv: R.Tensor((m, n, k), "int8") = R.tril(x, k=-2) return gv @@ -609,18 +609,18 @@ def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", "k"), "int class Expected: @R.function def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", "k"), "int8"): - m = T.var("int64") - n = T.var("int64") - k = T.var("int64") + m = T.int64() + n = T.int64() + k = T.int64() gv = R.call_tir(tril, (x,), R.Tensor((m, n, k), dtype="int8")) return gv @T.prim_func def tril(var_rxplaceholder: T.handle, var_trilu: T.handle): T.func_attr({"tir.noalias": True}) - k = T.var("int64") - m = T.var("int64") - n = T.var("int64") + k = T.int64() + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n, k], dtype="int8") trilu = T.match_buffer(var_trilu, [m, n, k], dtype="int8") for i0, i1, i2 in T.grid(m, n, k): @@ -672,9 +672,9 @@ def test_triu_symbolic(): class Triu: @R.function def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", "k"), "int8"): - m = T.var("int64") - n = T.var("int64") - k = T.var("int64") + m = T.int64() + n = T.int64() + k = T.int64() gv: R.Tensor((m, n, k), "int8") = R.triu(x, k=-2) return gv @@ -682,18 +682,18 @@ def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", "k"), "int class Expected: @R.function def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", "k"), "int8"): - m = T.var("int64") - n = T.var("int64") - k = T.var("int64") + m = T.int64() + n = T.int64() + k = T.int64() gv = R.call_tir(triu, (x,), R.Tensor((m, n, k), dtype="int8")) return gv @T.prim_func def triu(var_rxplaceholder: T.handle, var_trilu: T.handle): T.func_attr({"tir.noalias": True}) - k = T.var("int64") - m = T.var("int64") - n = T.var("int64") + k = T.int64() + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n, k], dtype="int8") trilu = T.match_buffer(var_trilu, [m, n, k], dtype="int8") for i0, i1, i2 in T.grid(m, n, k): @@ -769,8 +769,8 @@ def test_astype_symbolic(): class Astype: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "int32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "int32") = R.astype(x, "int32") return gv @@ -778,16 +778,16 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "int32"): class Expected: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "int32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(cast, (x,), R.Tensor((m, n), dtype="int32")) return gv @T.prim_func def cast(var_rxplaceholder: T.handle, var_compute: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") compute = T.match_buffer(var_compute, [m, n], dtype="int32") for i0, i1 in T.grid(m, n): diff --git a/tests/python/relax/test_transform_legalize_ops_image.py b/tests/python/relax/test_transform_legalize_ops_image.py index 36c8ecdd7b25..5860fea0bf7e 100644 --- a/tests/python/relax/test_transform_legalize_ops_image.py +++ b/tests/python/relax/test_transform_legalize_ops_image.py @@ -58,10 +58,10 @@ def test_image_resize2d_symbolic(): class Resize2D: @R.function def main(dumb_param: R.Tensor(("oh", "ow")), x: R.Tensor(("n", "c", "h", "w", 16), "float32")) -> R.Tensor(("n", "c", "oh", "ow", 16), "float32"): - n = T.var("int64") - c = T.var("int64") - oh = T.var("int64") - ow = T.var("int64") + n = T.int64() + c = T.int64() + oh = T.int64() + ow = T.int64() gv: R.Tensor((n, c, oh, ow, 16), "float32") = R.image.resize2d(x, size=(oh, ow), layout="NCHW16c", method="nearest_neighbor", coordinate_transformation_mode="asymmetric") return gv @@ -69,22 +69,22 @@ def main(dumb_param: R.Tensor(("oh", "ow")), x: R.Tensor(("n", "c", "h", "w", 16 class Expected: @R.function def main(dumb_param: R.Tensor(("oh", "ow")), x: R.Tensor(("n", "c", "h", "w", 16), "float32")) -> R.Tensor(("n", "c", "oh", "ow", 16), "float32"): - n = T.var("int64") - c = T.var("int64") - oh = T.var("int64") - ow = T.var("int64") + n = T.int64() + c = T.int64() + oh = T.int64() + ow = T.int64() gv = R.call_tir(resize2d, (x,), R.Tensor((n, c, oh, ow, 16), dtype="float32")) return gv @T.prim_func def resize2d(var_rxplaceholder: T.handle, var_resize: T.handle): T.func_attr({"tir.noalias": True}) - c = T.var("int64") - h = T.var("int64") - n = T.var("int64") - oh = T.var("int64") - ow = T.var("int64") - w = T.var("int64") + c = T.int64() + h = T.int64() + n = T.int64() + oh = T.int64() + ow = T.int64() + w = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [n, c, h, w, T.int64(16)], dtype="float32") resize = T.match_buffer(var_resize, [n, c, oh, ow, T.int64(16)], dtype="float32") for i0, i1, i2, i3, i4 in T.grid(n, c, oh, ow, T.int64(16)): diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py index 8b6f9de981bc..5dd9728918d5 100644 --- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -61,8 +61,8 @@ def test_take_symbolic(): class Take: @R.function def main(x: R.Tensor(("m", "n"), "float32"), indices: R.Tensor(("i",), "int64")) -> R.Tensor(("m", "i"), "float32"): - m = T.var("int64") - i = T.var("int64") + m = T.int64() + i = T.int64() gv: R.Tensor((m, i), "float32") = R.take(x, indices, axis=1) return gv @@ -70,17 +70,17 @@ def main(x: R.Tensor(("m", "n"), "float32"), indices: R.Tensor(("i",), "int64")) class Expected: @R.function def main(x: R.Tensor(("m", "n"), "float32"), indices: R.Tensor(("i",), "int64")) -> R.Tensor(("m", "i"), "float32"): - m = T.var("int64") - i = T.var("int64") + m = T.int64() + i = T.int64() gv = R.call_tir(take, (x, indices), R.Tensor((m, i), dtype="float32")) return gv @T.prim_func def take(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_take: T.handle): T.func_attr({"tir.noalias": True}) - i = T.var("int64") - m = T.var("int64") - n = T.var("int64") + i = T.int64() + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [i], dtype="int64") T_take = T.match_buffer(var_T_take, [m, i], dtype="float32") @@ -165,7 +165,7 @@ def test_strided_slice_symbolic_sliced_axis(): class StridedSlice: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor((2, "n"), "float32"): - n = T.var("int64") + n = T.int64() gv: R.Tensor((2, n), "float32") = R.strided_slice(x, axes=[0], begin=[1], end=[8], strides=[3]) return gv # fmt: on @@ -180,7 +180,7 @@ def test_strided_slice_symbolic(): class StridedSlice: @R.function def main(x: R.Tensor((10, "n"), "float32")) -> R.Tensor((3, "n"), "float32"): - n = T.var("int64") + n = T.int64() gv: R.Tensor((3, n), "float32") = R.strided_slice(x, axes=[0], begin=[1], end=[8], strides=[3]) return gv @@ -188,14 +188,14 @@ def main(x: R.Tensor((10, "n"), "float32")) -> R.Tensor((3, "n"), "float32"): class Expected: @R.function def main(x: R.Tensor((10, "n"), dtype="float32")) -> R.Tensor((3, "n"), dtype="float32"): - n = T.var("int64") + n = T.int64() gv = R.call_tir(strided_slice, (x,), R.Tensor((3, n), dtype="float32")) return gv @T.prim_func def strided_slice(var_rxplaceholder: T.handle, var_T_strided_slice_with_axes: T.handle): T.func_attr({"tir.noalias": True}) - n = T.var("int64") + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(10), n], dtype="float32") T_strided_slice_with_axes = T.match_buffer(var_T_strided_slice_with_axes, [T.int64(3), n], dtype="float32") for i0, i1 in T.grid(T.int64(3), n): @@ -351,11 +351,11 @@ def test_matmul_4_5_symbolic(): class Matmul: @R.function def main(x: R.Tensor(("b", 1, "m", "k"), "float32"), y: R.Tensor(("a", 1, "c", "k", "n"), "float32")) -> R.Tensor(("a", "b", "c", "m", "n"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - m = T.var("int64") - n = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + m = T.int64() + n = T.int64() gv: R.Tensor((a, b, c, m, n), "float32") = R.matmul(x, y) return gv @@ -363,23 +363,23 @@ def main(x: R.Tensor(("b", 1, "m", "k"), "float32"), y: R.Tensor(("a", 1, "c", " class Expected: @R.function def main(x: R.Tensor(("b", 1, "m", "k"), "float32"), y: R.Tensor(("a", 1, "c", "k", "n"), "float32")) -> R.Tensor(("a", "b", "c", "m", "n"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - m = T.var("int64") - n = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + m = T.int64() + n = T.int64() gv = R.call_tir(matmul, (x, y), R.Tensor((a, b, c, m, n), dtype="float32")) return gv @T.prim_func def matmul(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmul: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - k = T.var("int64") - m = T.var("int64") - n = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + k = T.int64() + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [b, T.int64(1), m, k], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, T.int64(1), c, k, n], dtype="float32") matmul = T.match_buffer(var_matmul, [a, b, c, m, n], dtype="float32") diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index 53aa868ffefd..2a30994b83c4 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -62,10 +62,10 @@ def test_broadcast_to_symbolic(): class BroadcastTo: @R.function def main(dumb_param: R.Tensor(("a", "c")), x: R.Tensor(("b", 1, "d"), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv: R.Tensor((a, b, c, d), "float32") = R.broadcast_to(x, (a, b, c, d)) return gv @@ -73,20 +73,20 @@ def main(dumb_param: R.Tensor(("a", "c")), x: R.Tensor(("b", 1, "d"), "float32") class Expected: @R.function def main(dumb_param: R.Tensor(("a", "c")), x: R.Tensor(("b", 1, "d"), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv = R.call_tir(broadcast_to, (x,), R.Tensor((a, b, c, d), dtype="float32")) return gv @T.prim_func def broadcast_to(var_rxplaceholder: T.handle, var_T_broadcast_to: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [b, T.int64(1), d], dtype="float32") T_broadcast_to = T.match_buffer(var_T_broadcast_to, [a, b, c, d], dtype="float32") for i0, i1, i2, i3 in T.grid(a, b, c, d): @@ -171,10 +171,10 @@ def test_concat_input_tuple_var_symbolic(): class Concat: @R.function def main(t: R.Tuple(R.Tensor(("a", "b0"), "float32"), R.Tensor(("a", "b1"), "float32"), R.Tensor(("a", "b2"), "float32"))) -> R.Tensor(("a", "b0 + b1 + b2"), "float32"): - a = T.var("int64") - b0 = T.var("int64") - b1 = T.var("int64") - b2 = T.var("int64") + a = T.int64() + b0 = T.int64() + b1 = T.int64() + b2 = T.int64() gv: R.Tensor((a, b0 + b1 + b2), "float32") = R.concat(t, axis=1) return gv @@ -182,10 +182,10 @@ def main(t: R.Tuple(R.Tensor(("a", "b0"), "float32"), R.Tensor(("a", "b1"), "flo class Expected: @R.function def main(t: R.Tuple(R.Tensor(("a", "b0"), "float32"), R.Tensor(("a", "b1"), "float32"), R.Tensor(("a", "b2"), "float32"))) -> R.Tensor(("a", "b0 + b1 + b2"), "float32"): - a = T.var("int64") - b0 = T.var("int64") - b1 = T.var("int64") - b2 = T.var("int64") + a = T.int64() + b0 = T.int64() + b1 = T.int64() + b2 = T.int64() gv: R.Tensor((a, b0), dtype="float32") = t[0] gv1: R.Tensor((a, b1), dtype="float32") = t[1] gv2: R.Tensor((a, b2), dtype="float32") = t[2] @@ -195,10 +195,10 @@ def main(t: R.Tuple(R.Tensor(("a", "b0"), "float32"), R.Tensor(("a", "b1"), "flo @T.prim_func def concatenate(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_concat: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b0 = T.var("int64") - b1 = T.var("int64") - b2 = T.var("int64") + a = T.int64() + b0 = T.int64() + b1 = T.int64() + b2 = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b0], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [a, b1], dtype="float32") rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [a, b2], dtype="float32") @@ -252,9 +252,9 @@ def test_expand_dims_symbolic(): class ExpandDims: @R.function def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", 1, "b", 1, "c", 1), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() gv: R.Tensor((a, 1, b, 1, c, 1), "float32") = R.expand_dims(x, axis=[1, 3, 5]) return gv @@ -262,18 +262,18 @@ def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", 1, "b", 1, " class Expected: @R.function def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", 1, "b", 1, "c", 1), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() gv = R.call_tir(expand_dims, (x,), R.Tensor((a, 1, b, 1, c, 1), dtype="float32")) return gv @T.prim_func def expand_dims(var_rxplaceholder: T.handle, var_expand_dims: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], dtype="float32") expand_dims = T.match_buffer(var_expand_dims, [a, T.int64(1), b, T.int64(1), c, T.int64(1)], dtype="float32") for i0, i1, i2, i3, i4, i5 in T.grid(a, T.int64(1), b, T.int64(1), c, T.int64(1)): @@ -356,9 +356,9 @@ def test_flatten_symbolic(): class Flatten: @R.function def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a * b * c",), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() gv: R.Tensor((a * b * c,), "float32") = R.flatten(x) return gv @@ -366,18 +366,18 @@ def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a * b * c",), "f class Expected: @R.function def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a * b * c",), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() gv = R.call_tir(reshape, (x,), R.Tensor((((a * b) * c),), dtype="float32")) return gv @T.prim_func def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], dtype="float32") T_reshape = T.match_buffer(var_T_reshape, [a * b * c], dtype="float32") for i0 in T.serial(a * b * c): @@ -429,10 +429,10 @@ def test_permute_dims_symbolic(): class PermuteDims: @R.function def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("b", "d", "c", "a"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() gv: R.Tensor((b, d, c, a), "float32") = R.permute_dims(x, axes=[1, -1, 2, -4]) return gv @@ -440,20 +440,20 @@ def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("b", "d", "c class Expected: @R.function def main(x: R.Tensor(("a", "b", "c", "d"), dtype="float32")) -> R.Tensor(("b", "d", "c", "a"), dtype="float32"): - b = T.var("int64") - d = T.var("int64") - c = T.var("int64") - a = T.var("int64") + b = T.int64() + d = T.int64() + c = T.int64() + a = T.int64() gv = R.call_tir(transpose, (x,), R.Tensor((b, d, c, a), dtype="float32")) return gv @T.prim_func def transpose(var_rxplaceholder: T.handle, var_T_transpose: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") T_transpose = T.match_buffer(var_T_transpose, [b, d, c, a], dtype="float32") for i0, i1, i2, i3 in T.grid(b, d, c, a): @@ -505,8 +505,8 @@ def test_reshape_symbolic(): class Reshape: @R.function def main(x: R.Tensor(("a", "b"), "float32")) -> R.Tensor(("a // 2", "b * 2"), "float32"): - a = T.var("int64") - b = T.var("int64") + a = T.int64() + b = T.int64() gv: R.Tensor((a // 2, b * 2), "float32") = R.reshape(x, (a // 2, b * 2)) return gv @@ -514,16 +514,16 @@ def main(x: R.Tensor(("a", "b"), "float32")) -> R.Tensor(("a // 2", "b * 2"), "f class Expected: @R.function def main(x: R.Tensor(("a", "b"), "float32")) -> R.Tensor(("a // 2", "b * 2"), "float32"): - a = T.var("int64") - b = T.var("int64") + a = T.int64() + b = T.int64() gv = R.call_tir(reshape, (x,), R.Tensor(((a // 2), (b * 2)), dtype="float32")) return gv @T.prim_func def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") + a = T.int64() + b = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b], dtype="float32") T_reshape = T.match_buffer(var_T_reshape, [a // T.int64(2), b * T.int64(2)], dtype="float32") for i0, i1 in T.grid(a // T.int64(2), b * T.int64(2)): @@ -638,8 +638,8 @@ def test_split_by_indices_n_section_divisible_symbolic(): class Split: @R.function def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "n * 3"), "float32")) -> R.Tuple([R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32")]): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tuple([R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32")]) = R.split(x, 3, axis=1) return gv @@ -647,15 +647,15 @@ def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "n * 3"), "float32")) - class Expected: @R.function def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "(n * 3)"), "float32")) -> R.Tuple(R.Tensor(("m", "((n * 3) // 3)"), "float32"), R.Tensor(("m", "((((n * 3) // 3) * 2) - ((n * 3) // 3))"), "float32"), R.Tensor(("m", "((n * 3) - (((n * 3) // 3) * 2))"), "float32")): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(split, (x,), [R.Tensor((m, ((n * 3) // 3)), "float32"), R.Tensor((m, ((((n * 3) // 3) * 2) - ((n * 3) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3) // 3) * 2))), "float32")], tir_vars=(n,)) return gv @T.prim_func def split(var_rxplaceholder: T.handle, var_T_split_sections: T.handle, var_T_split_sections_1: T.handle, var_T_split_sections_2: T.handle, n: T.int64): T.func_attr({"tir.noalias": True}) - m = T.var("int64") + m = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n * T.int64(3)], dtype="float32") T_split_sections = T.match_buffer(var_T_split_sections, [m, n * T.int64(3) // T.int64(3)], dtype="float32") T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, n * T.int64(3) // T.int64(3) * T.int64(2) - n * T.int64(3) // T.int64(3)], dtype="float32") @@ -752,8 +752,8 @@ def test_squeeze_symbolic(): class Squeeze: @R.function def main(x: R.Tensor(("a", 1, "b", 1), "float32")) -> R.Tensor(("a", "b", 1), "float32"): - a = T.var("int64") - b = T.var("int64") + a = T.int64() + b = T.int64() gv: R.Tensor((a, b, 1), "float32") = R.squeeze(x, [1]) return gv @@ -761,16 +761,16 @@ def main(x: R.Tensor(("a", 1, "b", 1), "float32")) -> R.Tensor(("a", "b", 1), "f class Expected: @R.function def main(x: R.Tensor(("a", 1, "b", 1), "float32")) -> R.Tensor(("a", "b", 1), "float32"): - a = T.var("int64") - b = T.var("int64") + a = T.int64() + b = T.int64() gv = R.call_tir(squeeze, (x,), R.Tensor((a, b, 1), dtype="float32")) return gv @T.prim_func def squeeze(var_rxplaceholder: T.handle, var_T_squeeze: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") + a = T.int64() + b = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, T.int64(1), b, T.int64(1)], dtype="float32") T_squeeze = T.match_buffer(var_T_squeeze, [a, b, T.int64(1)], dtype="float32") for i0, i1, i2 in T.grid(a, b, T.int64(1)): diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 3f9f02c410e9..729368b82a21 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -151,12 +151,12 @@ def test_conv2d_symbolic(): class Conv2d: @R.function def main(x: R.Tensor(("n", "c", "h", "w"), "float32"), kernel: R.Tensor(("f", "c", "kh", "kw"), "float32")) -> R.Tensor(("n", "f", "h - kh + 1", "w - kw + 1"), "float32"): - n = T.var("int64") - h = T.var("int64") - w = T.var("int64") - f = T.var("int64") - kh = T.var("int64") - kw = T.var("int64") + n = T.int64() + h = T.int64() + w = T.int64() + f = T.int64() + kh = T.int64() + kw = T.int64() gv: R.Tensor((n, f, h - kh + 1, w - kw + 1), "float32") = R.nn.conv2d(x, kernel) return gv @@ -164,25 +164,25 @@ def main(x: R.Tensor(("n", "c", "h", "w"), "float32"), kernel: R.Tensor(("f", "c class Expected: @R.function def main(x: R.Tensor(("n", "c", "h", "w"), "float32"), kernel: R.Tensor(("f", "c", "kh", "kw"), "float32")) -> R.Tensor(("n", "f", "h - kh + 1", "w - kw + 1"), "float32"): - n = T.var("int64") - f = T.var("int64") - h = T.var("int64") - kh = T.var("int64") - w = T.var("int64") - kw = T.var("int64") + n = T.int64() + f = T.int64() + h = T.int64() + kh = T.int64() + w = T.int64() + kw = T.int64() gv = R.call_tir(conv2d, (x, kernel), R.Tensor((n, f, ((h - kh) + 1), ((w - kw) + 1)), dtype="float32")) return gv @T.prim_func def conv2d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv2d_nchw: T.handle): T.func_attr({"tir.noalias": True}) - c = T.var("int64") - f = T.var("int64") - h = T.var("int64") - kh = T.var("int64") - kw = T.var("int64") - n = T.var("int64") - w = T.var("int64") + c = T.int64() + f = T.int64() + h = T.int64() + kh = T.int64() + kw = T.int64() + n = T.int64() + w = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [n, c, h, w], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [f, c, kh, kw], dtype="float32") conv2d_nchw = T.match_buffer(var_conv2d_nchw, [n, f, h - kh + T.int64(1), w - kw + T.int64(1)], dtype="float32") @@ -330,12 +330,12 @@ def test_max_pool2d_symbolic(): class MaxPool2D: @R.function def main(dumb_param: R.Tensor(("kh", "kw")), x: R.Tensor(("n", "c", "h", "w"), "float32")) -> R.Tensor(("n", "c", "h - kh + 1", "w - kw + 1"), "float32"): - n = T.var("int64") - c = T.var("int64") - h = T.var("int64") - w = T.var("int64") - kh = T.var("int64") - kw = T.var("int64") + n = T.int64() + c = T.int64() + h = T.int64() + w = T.int64() + kh = T.int64() + kw = T.int64() gv: R.Tensor((n, c, h - kh + 1, w - kw + 1), "float32") = R.nn.max_pool2d(x, pool_size=[kh, kw]) return gv @@ -434,10 +434,10 @@ def test_adaptive_avg_pool2d_symbolic(): class AdaptiveAvgPool2D: @R.function def main(dumb_param: R.Tensor(("oh", "ow")), x: R.Tensor(("n", "c", "h", "w"), "float32")) -> R.Tensor(("n", "c", "oh", "ow"), "float32"): - n = T.var("int64") - c = T.var("int64") - oh = T.var("int64") - ow = T.var("int64") + n = T.int64() + c = T.int64() + oh = T.int64() + ow = T.int64() gv: R.Tensor((n, c, oh, ow), "float32") = R.nn.adaptive_avg_pool2d(x, (oh, ow)) return gv # fmt: on @@ -483,8 +483,8 @@ def test_relu_symbolic(): class Relu: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "float32") = R.nn.relu(x) return gv @@ -492,16 +492,16 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): class Expected: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(relu, (x,), R.Tensor((m, n), dtype="float32")) return gv @T.prim_func def relu(var_rxplaceholder: T.handle, var_compute: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") compute = T.match_buffer(var_compute, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): @@ -581,8 +581,8 @@ def test_gelu_symbolic(): class Gelu: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "float32") = R.nn.gelu(x) return gv @@ -590,16 +590,16 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): class Expected: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(gelu, (x,), R.Tensor((m, n), dtype="float32")) return gv @T.prim_func def gelu(var_rxplaceholder: T.handle, var_T_multiply: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") T_multiply = T.match_buffer(var_T_multiply, [m, n], dtype="float32") T_multiply_1 = T.alloc_buffer([m, n], dtype="float32") @@ -686,8 +686,8 @@ def test_silu_symbolic(): class Silu: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "float32") = R.nn.silu(x) return gv @@ -695,16 +695,16 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): class Expected: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(silu, (x,), R.Tensor((m, n), dtype="float32")) return gv @T.prim_func def silu(var_rxplaceholder: T.handle, var_T_multiply: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") T_multiply = T.match_buffer(var_T_multiply, [m, n], dtype="float32") compute = T.alloc_buffer([m, n], dtype="float32") @@ -789,9 +789,9 @@ def test_softmax_symbolic(): class Softmax: @R.function def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", "b", "c"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() gv: R.Tensor((a, b, c), "float32") = R.nn.softmax(x) return gv @@ -799,18 +799,18 @@ def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", "b", "c"), " class Expected: @R.function def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", "b", "c"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() gv = R.call_tir(softmax, (x,), R.Tensor((a, b, c), dtype="float32")) return gv @T.prim_func def softmax(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], dtype="float32") T_softmax_norm = T.match_buffer(var_T_softmax_norm, [a, b, c], dtype="float32") T_softmax_maxelem = T.alloc_buffer([a, b], dtype="float32") @@ -963,10 +963,10 @@ def test_batch_norm_symbolic(): class BatchNorm: @R.function def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), "float32"), beta: R.Tensor(("c",), "float32"), moving_mean: R.Tensor(("c",), "float32"), moving_var: R.Tensor(("c",), "float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), "float32"), R.Tensor(("c",), "float32"), R.Tensor(("c",), "float32")): - n = T.var("int64") - h = T.var("int64") - w = T.var("int64") - c = T.var("int64") + n = T.int64() + h = T.int64() + w = T.int64() + c = T.int64() gv: R.Tuple(R.Tensor((n, h, w, c), "float32"), R.Tensor((c,), "float32"), R.Tensor((c,), "float32")) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=-1) return gv @@ -974,20 +974,20 @@ def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), " class Expected: @R.function def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), "float32"), beta: R.Tensor(("c",), "float32"), moving_mean: R.Tensor(("c",), "float32"), moving_var: R.Tensor(("c",), "float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), "float32"), R.Tensor(("c",), "float32"), R.Tensor(("c",), "float32")): - n = T.var("int64") - h = T.var("int64") - w = T.var("int64") - c = T.var("int64") + n = T.int64() + h = T.int64() + w = T.int64() + c = T.int64() gv = R.call_tir(batch_norm, (x, gamma, beta, moving_mean, moving_var), [R.Tensor((n, h, w, c), "float32"), R.Tensor((c,), "float32"), R.Tensor((c,), "float32")]) return gv @T.prim_func def batch_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_rxplaceholder_3: T.handle, var_rxplaceholder_4: T.handle, var_T_add: T.handle, var_T_multiply: T.handle, var_T_multiply_1: T.handle): T.func_attr({"tir.noalias": True}) - c = T.var("int64") - h = T.var("int64") - n = T.var("int64") - w = T.var("int64") + c = T.int64() + h = T.int64() + n = T.int64() + w = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [n, h, w, c], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [c], dtype="float32") rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [c], dtype="float32") @@ -1133,9 +1133,9 @@ def test_layer_norm_symbolic(): class LayerNorm: @R.function def main(x: R.Tensor(("n", "s", "f"), "float32"), gamma: R.Tensor(("s", "f"), "float32"), beta: R.Tensor(("s", "f"), "float32")) -> R.Tensor(("n", "s", "f"), "float32"): - n = T.var("int64") - s = T.var("int64") - f = T.var("int64") + n = T.int64() + s = T.int64() + f = T.int64() gv: R.Tensor((n, s, f), "float32") = R.nn.layer_norm(x, gamma, beta, axes=[1, 2]) return gv @@ -1143,18 +1143,18 @@ def main(x: R.Tensor(("n", "s", "f"), "float32"), gamma: R.Tensor(("s", "f"), "f class Expected: @R.function def main(x: R.Tensor(("n", "s", "f"), "float32"), gamma: R.Tensor(("s", "f"), "float32"), beta: R.Tensor(("s", "f"), "float32")) -> R.Tensor(("n", "s", "f"), "float32"): - n = T.var("int64") - s = T.var("int64") - f = T.var("int64") + n = T.int64() + s = T.int64() + f = T.int64() gv = R.call_tir(layer_norm, (x, gamma, beta), R.Tensor((n, s, f), dtype="float32")) return gv @T.prim_func def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_layer_norm: T.handle): T.func_attr({"tir.noalias": True}) - f = T.var("int64") - n = T.var("int64") - s = T.var("int64") + f = T.int64() + n = T.int64() + s = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [n, s, f], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [s, f], dtype="float32") rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [s, f], dtype="float32") diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py index 4c31077d9c4b..5bdfb1774c16 100644 --- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py +++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py @@ -61,9 +61,9 @@ def test_where_symbolic(): class Where: @R.function def main(condition: R.Tensor(("a", "b", 1), "bool"), x: R.Tensor(("b", "c"), "float32"), y: R.Tensor(("b", 1), "float32")) -> R.Tensor(("a", "b", "c"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() gv: R.Tensor((a, b, c), "float32") = R.where(condition, x, y) return gv @@ -71,18 +71,18 @@ def main(condition: R.Tensor(("a", "b", 1), "bool"), x: R.Tensor(("b", "c"), "fl class Expected: @R.function def main(condition: R.Tensor(("a", "b", 1), "bool"), x: R.Tensor(("b", "c"), "float32"), y: R.Tensor(("b", 1), "float32")) -> R.Tensor(("a", "b", "c"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() gv = R.call_tir(where, (condition, x, y), R.Tensor((a, b, c), dtype="float32")) return gv @T.prim_func def where(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_where: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, T.int64(1)], dtype="bool") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [b, c], dtype="float32") rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [b, T.int64(1)], dtype="float32") @@ -141,8 +141,8 @@ def test_max_symbolic(): class Max: @R.function def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("a", "d"), "float32"): - a = T.var("int64") - d = T.var("int64") + a = T.int64() + d = T.int64() gv: R.Tensor((a, d), "float32") = R.max(x, axis=[1, 2]) return gv @@ -150,18 +150,18 @@ def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("a", "d"), " class Expected: @R.function def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("a", "d"), "float32"): - a = T.var("int64") - d = T.var("int64") + a = T.int64() + d = T.int64() gv = R.call_tir(max, (x,), R.Tensor((a, d), dtype="float32")) return gv @T.prim_func def max(var_rxplaceholder: T.handle, var_rxplaceholder_red: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") rxplaceholder_red = T.match_buffer(var_rxplaceholder_red, [a, d], dtype="float32") for i0, i1, i2, i3 in T.grid(a, d, b, c): @@ -217,8 +217,8 @@ def test_min_symbolic(): class Min: @R.function def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("a", 1, 1, "d"), "float32"): - a = T.var("int64") - d = T.var("int64") + a = T.int64() + d = T.int64() gv: R.Tensor((a, 1, 1, d), "float32") = R.min(x, axis=[1, 2], keepdims=True) return gv @@ -226,18 +226,18 @@ def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("a", 1, 1, " class Expected: @R.function def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("a", 1, 1, "d"), "float32"): - a = T.var("int64") - d = T.var("int64") + a = T.int64() + d = T.int64() gv = R.call_tir(min, (x,), R.Tensor((a, 1, 1, d), dtype="float32")) return gv @T.prim_func def min(var_rxplaceholder: T.handle, var_rxplaceholder_red: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") rxplaceholder_red = T.match_buffer(var_rxplaceholder_red, [a, T.int64(1), T.int64(1), d], dtype="float32") for i0, i1, i2, i3, i4, i5 in T.grid(a, T.int64(1), T.int64(1), d, b, c): @@ -306,10 +306,10 @@ def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((), "float32" @T.prim_func def sum(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((), "float32")): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") for i0, i1, i2, i3 in T.grid(a, b, c, d): with T.block("rxplaceholder_red"): @@ -377,10 +377,10 @@ def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((1, 1, 1, 1), @T.prim_func def prod(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1)), "float32")): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), a, b, c, d): with T.block("rxplaceholder_red"): @@ -442,8 +442,8 @@ def test_mean_symbolic(): class Mean: @R.function def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("b", "c"), "float32"): - b = T.var("int64") - c = T.var("int64") + b = T.int64() + c = T.int64() gv: R.Tensor((b, c), "float32") = R.mean(x, [0, 3]) return gv @@ -451,18 +451,18 @@ def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("b", "c"), " class Expected: @R.function def main(x: R.Tensor(("a", "b", "c", "d"), dtype="float32")) -> R.Tensor(("b", "c"), dtype="float32"): - b = T.var("int64") - c = T.var("int64") + b = T.int64() + c = T.int64() gv = R.call_tir(mean, (x,), R.Tensor((b, c), dtype="float32")) return gv @T.prim_func def mean(var_rxplaceholder: T.handle, var_T_divide: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") T_divide = T.match_buffer(var_T_divide, [b, c], dtype="float32") rxplaceholder_red = T.alloc_buffer([b, c], dtype="float32") @@ -579,10 +579,10 @@ def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((), "float32" @T.prim_func def std(var_rxplaceholder: T.handle, compute: T.Buffer((), "float32")): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") rxplaceholder_red = T.alloc_buffer([], dtype="float32") T_divide = T.alloc_buffer([], dtype="float32") @@ -715,8 +715,8 @@ def test_variance_symbolic(): class Variance: @R.function def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((1, "b", "c", 1), "float32"): - b = T.var("int64") - c = T.var("int64") + b = T.int64() + c = T.int64() gv: R.Tensor((1, b, c, 1), "float32") = R.variance(x, [0, 3], keepdims=True) return gv @@ -724,18 +724,18 @@ def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((1, "b", "c", class Expected: @R.function def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((1, "b", "c", 1), "float32"): - b = T.var("int64") - c = T.var("int64") + b = T.int64() + c = T.int64() gv = R.call_tir(variance, (x,), R.Tensor((1, b, c, 1), dtype="float32")) return gv @T.prim_func def variance(var_rxplaceholder: T.handle, var_T_divide: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") - d = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() + d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d], dtype="float32") T_divide = T.match_buffer(var_T_divide, [T.int64(1), b, c, T.int64(1)], dtype="float32") rxplaceholder_red = T.alloc_buffer([T.int64(1), b, c, T.int64(1)], dtype="float32") diff --git a/tests/python/relax/test_transform_legalize_ops_unary.py b/tests/python/relax/test_transform_legalize_ops_unary.py index 12ae366dcc8a..7250e711beee 100644 --- a/tests/python/relax/test_transform_legalize_ops_unary.py +++ b/tests/python/relax/test_transform_legalize_ops_unary.py @@ -59,8 +59,8 @@ def test_abs_symbolic(): class Abs: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "float32") = R.abs(x) return gv @@ -68,16 +68,16 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): class Expected: @R.function def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(tir_abs, (x,), R.Tensor((m, n), dtype="float32")) return gv @T.prim_func def tir_abs(var_rxplaceholder: T.handle, var_compute: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") compute = T.match_buffer(var_compute, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): @@ -129,8 +129,8 @@ def test_cos_symbolic(): class Cos: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "float32") = R.cos(x) return gv @@ -138,16 +138,16 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): class Expected: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(tir_cos, (x,), R.Tensor((m, n), dtype="float32")) return gv @T.prim_func def tir_cos(var_rxplaceholder: T.handle, var_compute: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") compute = T.match_buffer(var_compute, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): @@ -199,8 +199,8 @@ def test_exp_symbolic(): class Exp: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "float32") = R.exp(x) return gv @@ -208,16 +208,16 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): class Expected: @R.function def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(tir_exp, (x,), R.Tensor((m, n), dtype="float32")) return gv @T.prim_func def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") compute = T.match_buffer(var_compute, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): @@ -269,8 +269,8 @@ def test_log_symbolic(): class Log: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "float32") = R.log(x) return gv @@ -278,16 +278,16 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): class Expected: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(tir_log, (x,), R.Tensor((m, n), dtype="float32")) return gv @T.prim_func def tir_log(var_rxplaceholder: T.handle, var_compute: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") compute = T.match_buffer(var_compute, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): @@ -339,8 +339,8 @@ def test_negative_symbolic(): class Negative: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "float32") = R.negative(x) return gv @@ -348,16 +348,16 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): class Expected: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(tir_negative, (x,), R.Tensor((m, n), dtype="float32")) return gv @T.prim_func def tir_negative(var_rxplaceholder: T.handle, var_compute: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") compute = T.match_buffer(var_compute, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): @@ -409,8 +409,8 @@ def test_sigmoid_symbolic(): class Sigmoid: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "float32") = R.sigmoid(x) return gv @@ -418,16 +418,16 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): class Expected: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(tir_sigmoid, (x,), R.Tensor((m, n), dtype="float32")) return gv @T.prim_func def tir_sigmoid(var_rxplaceholder: T.handle, var_compute: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") compute = T.match_buffer(var_compute, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): @@ -479,8 +479,8 @@ def test_sin_symbolic(): class Sin: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "float32") = R.sin(x) return gv @@ -488,16 +488,16 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): class Expected: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(tir_sin, (x,), R.Tensor((m, n), dtype="float32")) return gv @T.prim_func def tir_sin(var_rxplaceholder: T.handle, var_compute: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") compute = T.match_buffer(var_compute, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): @@ -549,8 +549,8 @@ def test_sqrt_symbolic(): class Sqrt: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "float32") = R.sqrt(x) return gv @@ -558,16 +558,16 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): class Expected: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(tir_sqrt, (x,), R.Tensor((m, n), dtype="float32")) return gv @T.prim_func def tir_sqrt(var_rxplaceholder: T.handle, var_compute: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") compute = T.match_buffer(var_compute, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): @@ -619,8 +619,8 @@ def test_tanh_symbolic(): class Tanh: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "float32") = R.tanh(x) return gv @@ -628,16 +628,16 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): class Expected: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(tir_tanh, (x,), R.Tensor((m, n), dtype="float32")) return gv @T.prim_func def tir_tanh(var_rxplaceholder: T.handle, var_compute: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") compute = T.match_buffer(var_compute, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): @@ -657,8 +657,8 @@ def test_clip_symbolic(): class Clip: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv: R.Tensor((m, n), "float32") = R.clip(x, 5, 8) return gv @@ -666,16 +666,16 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): class Expected: @R.function def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv = R.call_tir(tir_clip, (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) return gv @T.prim_func def tir_clip(var_rxplaceholder: T.handle, var_compute: T.handle): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") compute = T.match_buffer(var_compute, [m, n], dtype="float32") for i0, i1 in T.grid(m, n): diff --git a/tests/python/relax/test_transform_meta_schedule_tuning.py b/tests/python/relax/test_transform_meta_schedule_tuning.py index ff695b9436a3..d87ea5cec728 100644 --- a/tests/python/relax/test_transform_meta_schedule_tuning.py +++ b/tests/python/relax/test_transform_meta_schedule_tuning.py @@ -36,7 +36,7 @@ class InputModule: @T.prim_func def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: T.func_attr({"global_symbol": "tir_matmul"}) - k = T.var("int32") + k = T.int32() A = T.match_buffer(x, (32, 32)) B = T.match_buffer(y, (32, 32)) C = T.match_buffer(z, (32, 32)) diff --git a/tests/python/relax/test_transform_normalize.py b/tests/python/relax/test_transform_normalize.py index 9e9533a5ed23..da123f956d59 100644 --- a/tests/python/relax/test_transform_normalize.py +++ b/tests/python/relax/test_transform_normalize.py @@ -122,7 +122,7 @@ def f(x: R.Tensor(dtype="float32")): class ANFMod2: @R.function def foo(x: R.Tensor(("m", "n"), "float32")): - m, n = T.var("int64"), T.var("int64") + m, n = T.int64(), T.int64() with R.dataflow(): lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32")) diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py index f11df58b26ed..1b556139ccc9 100644 --- a/tests/python/relax/test_transform_static_plan_block_memory.py +++ b/tests/python/relax/test_transform_static_plan_block_memory.py @@ -530,16 +530,16 @@ def test_symbolic_shape(): class Module: @T.prim_func def exp(var_A: T.handle, var_B: T.handle): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() A = T.match_buffer(var_A, (m, n), "float32") B = T.match_buffer(var_B, (m, n), "float32") T.evaluate(0) @R.function def main(x: R.Tensor(("m", "n"), "float32")): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() alloc: R.Tensor((m, n), dtype="float32") = R.builtin.alloc_tensor( R.shape([m, n]), dtype="float32", runtime_device_index=0 ) diff --git a/tests/python/relax/test_tuning_api.py b/tests/python/relax/test_tuning_api.py index b12ff016705d..3fc2d41618a2 100644 --- a/tests/python/relax/test_tuning_api.py +++ b/tests/python/relax/test_tuning_api.py @@ -47,7 +47,7 @@ @tvm.script.ir_module class TestModule: @T.prim_func - def addone(A: T.Buffer[(16, 16), "int32"], B: T.Buffer[(16, 16), "int32"]) -> None: + def addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: T.func_attr(({"global_symbol": "addone"})) for i, j in T.grid(16, 16): with T.block("addone"): diff --git a/tests/python/relax/test_tvmscript_ir_builder.py b/tests/python/relax/test_tvmscript_ir_builder.py index 12d8b114b862..eb0aaf56040b 100644 --- a/tests/python/relax/test_tvmscript_ir_builder.py +++ b/tests/python/relax/test_tvmscript_ir_builder.py @@ -61,8 +61,8 @@ def test_match_cast(): """ @R.function def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() _ = R.match_cast(x, R.Tensor((m,), "float32")) y1 = R.match_cast(x, R.Tensor((n,), "float32")) return (m, n * 2) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 507ce72c0676..8df125ac72da 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -105,7 +105,7 @@ def test_unexpected_tir_cast_args(): @R.function def f(x: R.Tensor(("m",), "float32")): - m = T.var("int64") + m = T.int64() # tir.cast expects 2 arguments, but got 3 return R.call_tir("foo", (x,), R.Tensor((T.cast("int32", m, 1),), dtype="float32")) @@ -116,7 +116,7 @@ def test_unexpected_tir_max_args(): @R.function def f(x: R.Tensor(("m", "n"), "float32")): - m = T.var("int64") + m = T.int64() # tir.max expects 2 arguments, but got 1 return relax.call_tir("foo", (x,), R.Tensor((T.max(m),), dtype="float32")) @@ -220,15 +220,15 @@ def foo(x: R.Tensor((4, 4), "float32")): def test_symbolic_shape(): @R.function def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64", "m") - n = T.var("int64", "n") + m = T.int64() + n = T.int64() gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32")) return gv0 @R.function def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32")) return gv0 @@ -236,8 +236,8 @@ def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): @R.function def mismatch_dtype(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float32", ndim=2): - m = T.var("int64") - n = T.var("int32") # The shape dtype should be int64 + m = T.int64() + n = T.int32() # The shape dtype should be int64 gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32")) return gv0 @@ -282,8 +282,8 @@ def foo(x: R.Tensor((4, 4), "float32")): def test_match_cast(): @R.function def foo(x: R.Tensor("float32"), y: R.Tensor("float32")): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() x0 = R.match_cast(x, R.Tensor([m], "float32")) with R.dataflow(): y0 = R.match_cast(y, R.Tensor([n], "float32")) @@ -327,7 +327,7 @@ def foo(x: R.Tensor((4, 4), "float32")): def test_tuple_return_2(): @R.function def foo(x: R.Tensor("float32", ndim=2)): - n, m = T.var("int64"), T.var("int64") + n, m = T.int64(), T.int64() x0 = R.match_cast(x, R.Tensor((n, m), "float32")) return (x0, R.shape([n + 1, m, 1])) @@ -344,7 +344,7 @@ def foo(x: R.Tensor("float32", ndim=2)): def test_tuple_binding(): @R.function def foo(x: R.Tensor("float32", ndim=2)): - n, m = T.var("int64"), T.var("int64") + n, m = T.int64(), T.int64() x0 = R.match_cast(x, R.Tensor((n, m), "float32")) t0 = (x, x0) t1 = (x, R.shape([n, m]), t0) @@ -414,8 +414,8 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) gv1 = R.call_tir("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) with R.dataflow(): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() lv0 = R.call_tir("extern_func", gv1, R.Tensor((128, 128), dtype="float32")) lv1 = R.match_cast(lv0, R.Tensor((m, n), "float32")) gv2 = R.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) @@ -601,7 +601,7 @@ def foo( y: R.Tensor(("m",), "float32"), r: R.Tensor(dtype="int64"), ) -> R.Object: - m = T.var("int64", "m") + m = T.int64() z: R.Tensor((32, m), "float32") = R.multiply(x, y) w: R.Tensor = R.multiply(z, z) q: R.Tensor(ndim=2) = R.add(w, w) @@ -690,7 +690,7 @@ class Module: def main( dumb_param: R.Tensor(("n",), "float32"), x: R.Tensor(("n * 2", "float32")) ) -> R.Tensor(("n * 2",), "float32"): - n = T.var("int64") + n = T.int64() y = R.call_tir(copy, (x,), R.Tensor(((n * 2,)), dtype="float32"), tir_vars=(n,)) return y @@ -884,7 +884,7 @@ def test_erase_to_well_defined(): @R.function def foo(x: R.Tensor): q = x - m, n = T.var("int64"), T.var("int64") + m, n = T.int64(), T.int64() z = R.match_cast(q, R.Tensor((m, n))) w = z return w @@ -930,7 +930,7 @@ def foo(x: R.Tensor(("m + 1",), "float32"), y: R.Tensor(("m", 1), "float32")): def bar( x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32") ) -> R.Tensor(("T.max(m, 20) + 1",), "float32"): - m = T.var("int64") + m = T.int64() z = R.call_tir("test_intrin", (x, y), R.Tensor((T.max(m, 20) + 1,), dtype="float32")) return z @@ -949,7 +949,7 @@ def bar( # Shape Case @R.function def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")): - m = T.var("int64") + m = T.int64() z = R.call_tir("test_intrin", y, R.Tensor((m * 2,), dtype="float32")) return z @@ -977,8 +977,8 @@ def foo(x: R.Tensor(("m + 1", "m * 2"), "float32")): # name 'm' is not defined def test_vm_ops(): @R.function def foo(x: R.Tensor(("m", "n"), dtype="float32")): - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() storage = R.vm.alloc_storage(R.shape([4 * m * n]), dtype="float32", runtime_device_index=0) alloc = R.vm.alloc_tensor(storage, shape=R.shape([m, n]), offset=0, dtype="float32") tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32", runtime_device_index=0) diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index d57efd8b9992..e78e926dcb7c 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -86,7 +86,7 @@ def test_vm_compile_stage2(exec_mode): class TestVMCompileStage2: @R.function def foo(x: R.Tensor(dtype="float32")) -> R.Shape: - n, m = T.var("int64"), T.var("int64") + n, m = T.int64(), T.int64() _ = R.match_cast(x, R.Tensor((n, m), "float32")) return R.shape([n * 2, m * 3]) @@ -143,7 +143,7 @@ class TestVMCompileE2E: @R.function def foo(x: R.Tensor(dtype="float32")) -> R.Tensor: with R.dataflow(): - n, m = T.var("int64"), T.var("int64") + n, m = T.int64(), T.int64() _ = R.match_cast(x, R.Tensor((n, m), "float32")) y = R.call_tir("test.vm.tile", (x), R.Tensor((n, m * 2), dtype="float32")) R.output(y) @@ -168,9 +168,9 @@ class TestVMCompileE2E2: @T.prim_func def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: T.func_attr({"global_symbol": "tir_matmul"}) - m = T.var("int32") - n = T.var("int32") - k = T.var("int32") + m = T.int32() + n = T.int32() + k = T.int32() A = T.match_buffer(x, (m, n)) B = T.match_buffer(y, (n, k)) C = T.match_buffer(z, (m, k)) @@ -186,7 +186,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: def func( x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32") ) -> R.Tensor: - m, k = T.var("int64"), T.var("int64") + m, k = T.int64(), T.int64() gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((m, k), dtype="float32")) return gv0 @@ -540,9 +540,9 @@ class TestVMSubFunction: @T.prim_func def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: T.func_attr({"global_symbol": "tir_matmul"}) - m = T.var("int32") - n = T.var("int32") - k = T.var("int32") + m = T.int32() + n = T.int32() + k = T.int32() A = T.match_buffer(x, (m, n)) B = T.match_buffer(y, (n, k)) C = T.match_buffer(z, (m, k)) @@ -680,8 +680,8 @@ class TestVMSetInput: @T.prim_func def test_vm_mul(x: T.handle, y: T.handle, z: T.handle): T.func_attr({"global_symbol": "test_vm_mul"}) - m = T.var("int32") - n = T.var("int32") + m = T.int32() + n = T.int32() A = T.match_buffer(x, (m, n)) B = T.match_buffer(y, (m, n)) C = T.match_buffer(z, (m, n)) From ef3524a6c95fc41e4429799323eb6ae27f011601 Mon Sep 17 00:00:00 2001 From: Jiawei Liu Date: Sat, 18 Feb 2023 15:57:02 -0600 Subject: [PATCH 41/81] [Unity] Relax dataflow pattern language (matching) (#14041) The dataflow pattern language for Relax (originally from https://github.com/tlc-pack/relax/pull/163). The implementation splits patterns into two parts: - Match an Expression: match an expression syntactically (MatchExprPattern, i.e., DFPatternMatcher); - Match a Graph: match a graph (cross multiple VarBinding) topologically (MatchGraphPattern); --- include/tvm/relax/analysis.h | 32 + include/tvm/relax/dataflow_matcher.h | 80 ++ include/tvm/relax/dataflow_pattern.h | 810 +++++++++++++ include/tvm/relax/dataflow_pattern_functor.h | 183 +++ python/tvm/relax/analysis/analysis.py | 38 +- python/tvm/relax/dpl/__init__.py | 21 + python/tvm/relax/dpl/_ffi.py | 20 + python/tvm/relax/dpl/context.py | 86 ++ python/tvm/relax/dpl/pattern.py | 1095 ++++++++++++++++++ src/relax/analysis/udchain.cc | 102 ++ src/relax/analysis/var2value.cc | 91 ++ src/relax/ir/dataflow_matcher.cc | 768 ++++++++++++ src/relax/ir/dataflow_matcher_impl.h | 87 ++ src/relax/ir/dataflow_pattern.cc | 607 ++++++++++ src/relax/ir/dataflow_pattern_functor.cc | 111 ++ tests/python/relax/test_analysis.py | 23 +- tests/python/relax/test_dataflow_pattern.py | 867 ++++++++++++++ 17 files changed, 5018 insertions(+), 3 deletions(-) create mode 100644 include/tvm/relax/dataflow_matcher.h create mode 100644 include/tvm/relax/dataflow_pattern.h create mode 100644 include/tvm/relax/dataflow_pattern_functor.h create mode 100644 python/tvm/relax/dpl/__init__.py create mode 100644 python/tvm/relax/dpl/_ffi.py create mode 100644 python/tvm/relax/dpl/context.py create mode 100644 python/tvm/relax/dpl/pattern.py create mode 100644 src/relax/analysis/udchain.cc create mode 100644 src/relax/analysis/var2value.cc create mode 100644 src/relax/ir/dataflow_matcher.cc create mode 100644 src/relax/ir/dataflow_matcher_impl.h create mode 100644 src/relax/ir/dataflow_pattern.cc create mode 100644 src/relax/ir/dataflow_pattern_functor.cc create mode 100644 tests/python/relax/test_dataflow_pattern.py diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index f9896efdf272..32e1582134c7 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -317,6 +317,38 @@ TVM_DLL tvm::Array CalledGlobalVars(const Expr& expr); */ TVM_DLL tvm::Array AllGlobalVars(const Expr& expr); +/*! + * \brief Analyze var -> value mapping from VarBindings. + * + * \param m The IRModule to check. + * \return Var -> Value (Expr) + */ +TVM_DLL Map AnalyzeVar2Value(const IRModule& m); + +/*! + * \brief Analyze var -> value mapping from VarBindings. + * + * \param expr The expression to check. + * \return Var -> Value (Expr) + */ +TVM_DLL Map AnalyzeVar2Value(const Expr& expr); + +/*! + * \brief Analyze var -> value mapping from VarBindings. + * + * \param dfb The dataflow block to check. + * \return Var -> Value (Expr) + */ +TVM_DLL Map AnalyzeVar2Value(const DataflowBlock& dfb); + +/*! + * \brief Get the use-def chain of variables inside a dataflow block. + * + * \param dfb The dataflow block to be analyzed. + * \return A map mapping variable definitions to a set of uses. + */ +TVM_DLL Map> DataflowBlockUseDef(const DataflowBlock& dfb); + /*! * \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps. * diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h new file mode 100644 index 000000000000..fa58308faced --- /dev/null +++ b/include/tvm/relax/dataflow_matcher.h @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_matcher.h + * \brief A pattern matcher for matching dataflow properties. + */ +#ifndef TVM_RELAX_DATAFLOW_MATCHER_H_ +#define TVM_RELAX_DATAFLOW_MATCHER_H_ + +#include +#include + +#include + +namespace tvm { +namespace relax { + +/** + * \brief Determine if a pattern matches an expression. + * \note The behavior of MatchExpr is to match a relax.Expr (`expr`) syntactically through + * one given pattern (`pattern`). + * + * \param pattern The pattern to match + * \param expr The expression to match + * \param bindings The mapping from relax.Var to relax.Expr + * \return true if matched + * \return false if unmatched + */ +bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings = NullOpt); + +/* \brief Similar to above, but return pairs of a matching pattern and an expression. */ +Optional> ExtractMatchedExpr( + DFPattern pattern, Expr expr, Optional> bindings = NullOpt); + +/** + * \brief Match a sub-graph in a DataflowBlock with a graph of patterns and return the mapping. + * \note This algorithm returns the first matched sub-graph. Use `start_hint` to specify the + * starting point of the matching so that we can distinguish multiple matches. + * + * \param ctx The graph-wise patterns. + * \param dfb The function to match. + * \param start_hint The starting point expression to match to distinguish multiple matches. + * \param must_include_hint If start_hint is given, the return pattern must include start_hint. + * \return tvm::runtime::Map + */ +TVM_DLL tvm::runtime::Map MatchGraph(const PatternContext& ctx, + const DataflowBlock& dfb, + Optional start_hint = NullOpt, + bool must_include_hint = false); + +/** + * \brief Match a graph-wise pattern with the current context (PatternContext::Current()). + */ +inline tvm::runtime::Map MatchGraphDefault(const DataflowBlock& dfb, + Optional start_hint = NullOpt, + bool must_include_hint = false) { + return MatchGraph(PatternContext::Current(), dfb, start_hint, must_include_hint); +} + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_DATAFLOW_MATCHER_H_ diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h new file mode 100644 index 000000000000..701879745efa --- /dev/null +++ b/include/tvm/relax/dataflow_pattern.h @@ -0,0 +1,810 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_pattern.h + * \brief A pattern language for matching dataflow properties. + */ +#ifndef TVM_RELAX_DATAFLOW_PATTERN_H_ +#define TVM_RELAX_DATAFLOW_PATTERN_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +class PatternSeq; +class CallPattern; +class OrPattern; +class AndPattern; +class NotPattern; +class ShapePattern; +class TypePattern; +class DataTypePattern; +class AttrPattern; + +/*! + * \brief Create used-by relationship between lhs[-1] and rhs[0], with [*lhs, *rhs] returned. + * + * \param lhs Left hand side of the used-by relationship. + * \param rhs Right hand side of the used-by relationship. + * \param index lhs[-1] is used as the index'th argument of rhs[0]. + * \return PatternSeq The concatenated sequence of [*lhs, *rhs]. + */ +TVM_DLL PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index = -1); +/*! \brief Syntax sugar of UsedBy(lhs, rhs, -1). */ +TVM_DLL PatternSeq operator^(const PatternSeq& lhs, const PatternSeq& rhs); + +/*! + * \brief Create only-used-by relationship between lhs[-1] and rhs[0], with [*lhs, *rhs] returned. + * + * \param lhs Left hand side of the used-by relationship. + * \param rhs Right hand side of the used-by relationship. + * \param index lhs[-1] is used as the index'th argument of rhs[0]. + * \return PatternSeq The concatenated sequence of [*lhs, *rhs]. + */ +TVM_DLL PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index = -1); +/*! \brief Syntax sugar of OnlyUsedBy(lhs, rhs, -1). */ +TVM_DLL PatternSeq operator>>(const PatternSeq& lhs, const PatternSeq& rhs); + +/*! + * \brief Base type of all dataflow patterns. + * \sa DFPattern + */ +class DFPatternNode : public Object { + public: + static constexpr const char* _type_key = "DFPatternNode"; + TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object); +}; + +/*! + * \brief Managed reference to dataflow patterns. + * \sa DFPatternNode + */ +class DFPattern : public ObjectRef { + public: + /*! \brief Syntatic Sugar for creating a CallPattern */ + template + CallPattern operator()(Args&&... args) const; + /*! \brief Syntatic Sugar for creating a CallPattern */ + TVM_DLL CallPattern operator()(const std::vector& args) const; + /*! \brief Syntatic Sugar for creating an OrPattern */ + TVM_DLL OrPattern operator|(const DFPattern& other) const; + /*! \brief Syntatic Sugar for creating an AndPattern */ + TVM_DLL AndPattern operator&(const DFPattern& other) const; + /*! \brief Syntatic Sugar for creating a NotPattern */ + TVM_DLL NotPattern operator~() const; + /*! \brief Syntatic Sugar for creating an AttrPattern */ + TVM_DLL AttrPattern HasAttr(const Map& attrs) const; + /*! \brief Syntatic Sugar for creating a TypePattern */ + TVM_DLL TypePattern HasType(const Type& type) const; + /*! \brief Syntatic Sugar for creating a DataTypePattern with a DataType */ + TVM_DLL DataTypePattern HasDtype(const DataType& dtype) const; + /*! \brief Syntatic Sugar for creating a DataTypePattern with a data type's name */ + TVM_DLL DataTypePattern HasDtype(const std::string& dtype) const; + /*! \brief Syntatic Sugar for creating a ShapePattern */ + TVM_DLL ShapePattern HasShape(const Array& shape) const; + /*! \brief Syntatic Sugar for duplicating the current pattern */ + TVM_DLL DFPattern dup() const; + + /*! \brief Implicit conversion from DFPattern to PatternSeq */ + TVM_DLL operator PatternSeq() const; + + TVM_DEFINE_OBJECT_REF_METHODS(DFPattern, ObjectRef, DFPatternNode); +}; + +/*! \brief Constraint of a DFPattern edge (producer -> consumer) in graph-level matching */ +struct PairCons { + /*! \brief Constraint types of the edge */ + enum Type { + kUsedBy, /*!< producer ^ consumer */ + kOnlyUsedBy, /*!< producer >> consumer */ + } type = kUsedBy; + int index = -1; /*!< The argument index of the producer in the consumer caller site */ + + /*! + * \brief Construct a new PairCons object + * + * \param t The constraint type + * \param index The producer is called as the index'th argument of the consumer function. + */ + TVM_DLL explicit PairCons(Type t, int index = -1) : type(t), index(index) {} + + bool operator==(const PairCons& other) const { + return type == other.type && index == other.index; + } +}; + +/*! + * \brief A sequence of DFPatterns that the previous DFPattern is connected to the next one. + * \sa PatternSeq + */ +class PatternSeqNode final : public Object { + public: + tvm::Array patterns; /*!< The sequence of DFPatterns */ + std::vector pair_constraints; /*!< Constraints between the previous and next patterns */ + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("patterns", &patterns); } + static constexpr const char* _type_key = "relax.dpl.PatternSeq"; + TVM_DECLARE_BASE_OBJECT_INFO(PatternSeqNode, Object); +}; + +/*! + * \brief Managed reference to pattern sequences. + * \sa PatternSeqNode + */ +class PatternSeq final : public ObjectRef { + public: + TVM_DLL explicit PatternSeq(DFPattern init_pattern); + TVM_DLL explicit PatternSeq(tvm::Array patterns, bool only_used_by = false); + + PatternSeq UsedBy(PatternSeq other, int index = -1) const; + PatternSeq OnlyUsedBy(PatternSeq other, int index = -1) const; + + /*! \brief Syntatic Sugar for duplicating the current pattern sequence */ + PatternSeq dup() const; + + // friend functions + friend PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index); + friend PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index); + + TVM_DEFINE_OBJECT_REF_METHODS(PatternSeq, ObjectRef, PatternSeqNode); +}; + +/*! + * \brief A context to manage the graph-level pattern matching. + * \sa PatternContext + */ +class PatternContextNode : public Object { + public: + /*! \brief Constrainting matched graph with assertion to external uses */ + enum ExternUse { + kMay, /*!< No constraints */ + kMustNot, /*!< All nodes except outputs only have internal depedencies in the matched graph. */ + } allow_extern_use = kMay; + // src node -> constraints. + std::map>> constraints; + + static constexpr const char* _type_key = "relax.dpl.PatternContext"; + TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextNode, Object); +}; + +/*! + * \brief Managed reference to a pattern context. + * \sa PatternContextNode + */ +class PatternContext : public ObjectRef { + public: + TVM_DLL explicit PatternContext(ObjectPtr n) : ObjectRef(n) {} + TVM_DLL explicit PatternContext(bool incremental = false); + + const PatternContextNode* operator->() const { + ICHECK(get() != nullptr); + return static_cast(get()); + } + + PatternContextNode* operator->() { + ICHECK(get() != nullptr); + return static_cast(get_mutable()); + } + + /*! + * \brief Build an edge constraint between two patterns (producer and consumer). + * + * \param producer The pattern corresponding to the producer node. + * \param consumer The pattern corresponding to the consumer node. + * \param cons The constraint type. \sa PairCons + */ + void add_constraint(DFPattern producer, DFPattern consumer, PairCons cons) { + auto& vec = (*this)->constraints[producer][consumer]; + ICHECK(std::find(vec.cbegin(), vec.cend(), cons) == vec.cend()) << "Constraint already exists"; + vec.push_back(cons); + } + + /*! \brief Get the pass context object on the top of the stack */ + TVM_DLL static PatternContext Current(); + + class Internal; + + private: + /*! \brief The RAII-like entry of a pass context scope */ + TVM_DLL void EnterWithScope(); + /*! \brief The RAII-like exit of a pass context scope */ + TVM_DLL void ExitWithScope(); + friend class Internal; + friend class With; +}; + +/*! + * \brief Pattern for Relax Expression. + * \sa ExprPattern + */ +class ExprPatternNode : public DFPatternNode { + public: + Expr expr; /*!< The expression to match */ + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("expr", &expr); } + + static constexpr const char* _type_key = "relax.dpl.ExprPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to an ExprPattern. + * \sa ExprPatternNode + */ +class ExprPattern : public DFPattern { + public: + TVM_DLL explicit ExprPattern(Expr expr); + TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode); +}; + +/*! + * \brief A Pattern to Match a Relax Variable. + * \note The name field matches any string if it is empty. + * \sa VarPattern + */ +class VarPatternNode : public DFPatternNode { + public: + String name; + const String& name_hint() const { return name; } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); } + + static constexpr const char* _type_key = "relax.dpl.VarPattern"; + TVM_DECLARE_BASE_OBJECT_INFO(VarPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to a VarPattern. + * \sa VarPatternNode + */ +class VarPattern : public DFPattern { + public: + /*! + * \brief Create a pattern matching by variable name. + * + * \param name_hint Variable name to match. Any if empty (""). + */ + TVM_DLL VarPattern(String name_hint); + TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode); +}; + +/*! + * \brief A Pattern to Match a Relax Dataflow Variable + * \sa DataflowVarPattern + */ +class DataflowVarPatternNode : public VarPatternNode { + public: + static constexpr const char* _type_key = "relax.dpl.DataflowVarPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowVarPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to a DataflowVarPattern. + * \sa DataflowVarPatternNode + */ +class DataflowVarPattern : public DFPattern { + public: + /*! \sa VarPattern::VarPattern */ + TVM_DLL DataflowVarPattern(String name_hint); + TVM_DEFINE_OBJECT_REF_METHODS(DataflowVarPattern, DFPattern, DataflowVarPatternNode); +}; + +/*! + * \brief A Pattern to Match a Relax Global Variable + * \sa GlobalVarPattern + */ +class GlobalVarPatternNode : public VarPatternNode { + public: + static constexpr const char* _type_key = "relax.dpl.GlobalVarPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to a GlobalVarPattern. + * \sa GlobalVarPatternNode + */ +class GlobalVarPattern : public DFPattern { + public: + TVM_DLL GlobalVarPattern(String name_hint); + TVM_DEFINE_OBJECT_REF_METHODS(GlobalVarPattern, DFPattern, GlobalVarPatternNode); +}; + +/*! + * \brief A Pattern to Match a Relax Constant. + * \sa ConstantPattern + */ +class ConstantPatternNode : public DFPatternNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "relax.dpl.ConstantPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(ConstantPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to a ConstantPattern. + * \sa ConstantPatternNode + */ +class ConstantPattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(ConstantPattern, DFPattern, ConstantPatternNode); +}; + +/*! + * \brief A pattern to match a callable node in Relax. + * \sa CallPattern + */ +class CallPatternNode : public DFPatternNode { + public: + /*! + * \note The op field can be: + * - relay::Op which corresponds to the primitive operators. + * - user defined functions (Function, GlobalVar, Var). + */ + DFPattern op; /*!< The operator (function) being invoked */ + tvm::Array args; /*!< The arguments of the function call */ + /*! + * \note If varg_default_wildcard is true. Given args of [pA, pB], when matching a call whose + * arguments are [A, B, ...], the pattern will still match despite N(args) < N(call.args). That + * said, with varg_default_wildcard set to true, we match the args in the order we have, and + * regard the rest of the arguments as wildcards. + */ + bool varg_default_wildcard; /*!< N(args) can be < N(real args) by the padding of Wildcard */ + + // Todo(relax-team): Dataflow pattern for StructInfo, and match sinfo_args + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("op", &op); + v->Visit("args", &args); + } + + static constexpr const char* _type_key = "relax.dpl.CallPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(CallPatternNode, DFPatternNode); +}; + +class CallPattern : public DFPattern { + public: + TVM_DLL CallPattern(DFPattern op, Array args, bool varg_default_wildcard = false); + TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode); +}; + +/*! + * \brief A pattern to match an array of PrimExpr. + * \sa PrimArrPattern + * \note This is often used to match shapes specified as arguments to a function. + */ +class PrimArrPatternNode : public DFPatternNode { + public: + Array fields; /*!< The array to match */ + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } + static constexpr const char* _type_key = "relax.dpl.PrimArrPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimArrPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to a PrimArrPattern. + * \sa PrimArrPatternNode + */ +class PrimArrPattern : public DFPattern { + public: + TVM_DLL PrimArrPattern(Array arr); + TVM_DEFINE_OBJECT_REF_METHODS(PrimArrPattern, DFPattern, PrimArrPatternNode); +}; + +/*! + * \brief A pattern to match a Relax Function + * \sa Function + * \sa FunctionPattern + */ +class FunctionPatternNode : public DFPatternNode { + public: + tvm::Array params; /*!< The parameters of the function */ + /*! + * \note Note that in Relax, the function body is a SeqExpr which contains + * 1) SeqExprNode::blocks, which is a list of blocks of statements; and 2) + * SeqExprNode::body, which is an Expr that can be anything. FunctionPattern + * only matches the body of the function (writing patterns to statements is tricky). + */ + DFPattern body; /*!< The body of the function */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("params", ¶ms); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "relax.dpl.FunctionPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to FunctionPatternNode. + * \sa FunctionPatternNode + */ +class FunctionPattern : public DFPattern { + public: + /*! + * \brief Constructor + * \param params The parameters of the function. + * \param body The body of the function. + */ + TVM_DLL FunctionPattern(tvm::Array params, DFPattern body); + + TVM_DEFINE_OBJECT_REF_METHODS(FunctionPattern, DFPattern, FunctionPatternNode); +}; + +/*! + * \brief Pattern to match a tuple of ordered expressions. + * \sa TuplePattern + */ +class TuplePatternNode : public DFPatternNode { + public: + tvm::Array fields; /*!< The fields of the tuple */ + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } + + static constexpr const char* _type_key = "relax.dpl.TuplePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(TuplePatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to TuplePatternNode. + * \sa TuplePatternNode + */ +class TuplePattern : public DFPattern { + public: + TVM_DLL explicit TuplePattern(tvm::Array fields); + TVM_DEFINE_OBJECT_REF_METHODS(TuplePattern, DFPattern, TuplePatternNode); +}; + +/*! + * \brief A pattern to match multiple expressions unorderedly. + * \sa UnorderedTuplePattern + */ +class UnorderedTuplePatternNode : public DFPatternNode { + public: + tvm::Array fields; /*!< The fields of the tuple */ + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } + + static constexpr const char* _type_key = "relax.dpl.UnorderedTuplePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(UnorderedTuplePatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to UnorderedTuplePatternNode. + * \sa UnorderedTuplePatternNode + */ +class UnorderedTuplePattern : public DFPattern { + public: + TVM_DLL explicit UnorderedTuplePattern(tvm::Array fields); + TVM_DEFINE_OBJECT_REF_METHODS(UnorderedTuplePattern, DFPattern, UnorderedTuplePatternNode); +}; + +/*! + * \brief A pattern to match n'th indexing to a tuple. + * \sa TupleGetItem + * \sa TupleGetItemPattern + */ +class TupleGetItemPatternNode : public DFPatternNode { + public: + DFPattern tuple; /*!< The tuple Expression */ + int index; /*!< The index of the tuple with -1 meaning arbitrary */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("tuple", &tuple); + v->Visit("index", &index); + } + + static constexpr const char* _type_key = "relax.dpl.TupleGetItemPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to TupleGetItemPatternNode. + * \sa TupleGetItemPatternNode + */ +class TupleGetItemPattern : public DFPattern { + public: + TVM_DLL TupleGetItemPattern(DFPattern tuple, int index); + TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItemPattern, DFPattern, TupleGetItemPatternNode); +}; + +/*! + * \brief Match a conjunction of other patterns. + * \sa AndPattern + */ +class AndPatternNode : public DFPatternNode { + public: + DFPattern left; /*!< The left hand side of the conjunction */ + DFPattern right; /*!< The right hand side of the conjunction */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("left", &left); + v->Visit("right", &right); + } + + static constexpr const char* _type_key = "relax.dpl.AndPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(AndPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to AndPatternNode. + * \sa AndPatternNode + */ +class AndPattern : public DFPattern { + public: + TVM_DLL AndPattern(DFPattern lhs, DFPattern rhs); + TVM_DEFINE_OBJECT_REF_METHODS(AndPattern, DFPattern, AndPatternNode); +}; + +/*! + * \brief Match a disjunction of other patterns. + * \sa OrPattern + */ +class OrPatternNode : public DFPatternNode { + public: + DFPattern left; /*!< The left hand side of the disjunction */ + DFPattern right; /*!< The right hand side of the disjunction */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("left", &left); + v->Visit("right", &right); + } + + static constexpr const char* _type_key = "relax.dpl.OrPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(OrPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to OrPatternNode. + * \sa OrPatternNode + */ +class OrPattern : public DFPattern { + public: + TVM_DLL OrPattern(DFPattern left, DFPattern right); + TVM_DEFINE_OBJECT_REF_METHODS(OrPattern, DFPattern, OrPatternNode); +}; + +/*! + * \brief Pattern for rejecting a certain pattern. + * \sa NotPattern + */ +class NotPatternNode : public DFPatternNode { + public: + DFPattern reject; /*!< The pattern to reject */ + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("reject", &reject); } + + static constexpr const char* _type_key = "relax.dpl.NotPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(NotPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to NotPatternNode. + * \sa NotPatternNode + */ +class NotPattern : public DFPattern { + public: + TVM_DLL NotPattern(DFPattern reject); + TVM_DEFINE_OBJECT_REF_METHODS(NotPattern, DFPattern, NotPatternNode); +}; + +/*! + * \brief Wildcard Pattern is a pattern that can match anything. + * \sa WildcardPattern + */ +class WildcardPatternNode : public DFPatternNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "relax.dpl.WildcardPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to WildcardPatternNode. + * \sa WildcardPatternNode + */ +class WildcardPattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode); +}; + +/*! + * \brief Pattern for matching a certain type. + * \sa TypePattern + */ +class TypePatternNode : public DFPatternNode { + public: + DFPattern pattern; /*!< The pattern to match */ + Type type; /*!< The type to match */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("type", &type); + } + + static constexpr const char* _type_key = "relax.dpl.TypePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(TypePatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to TypePatternNode. + * \sa TypePatternNode + */ +class TypePattern : public DFPattern { + public: + TVM_DLL TypePattern(DFPattern pattern, Type type); + TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode); +}; + +/*! + * \brief A pattern that asserting a root pattern has a certain shape. + * \sa ShapePattern + */ +class ShapePatternNode : public DFPatternNode { + public: + DFPattern pattern; /*!< The root pattern to match */ + Array shape; /*!< The shape to match */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("shape", &shape); + } + + static constexpr const char* _type_key = "relax.dpl.ShapePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapePatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to ShapePatternNode. + * \sa ShapePatternNode + */ +class ShapePattern : public DFPattern { + public: + TVM_DLL ShapePattern(DFPattern pattern, Array type); + TVM_DEFINE_OBJECT_REF_METHODS(ShapePattern, DFPattern, ShapePatternNode); +}; + +/*! + * \brief A pattern that asserting a root pattern has a certain data type. + * \sa DataTypePattern + */ +class DataTypePatternNode : public DFPatternNode { + public: + DFPattern pattern; /*!< The root pattern to match */ + DataType dtype; /*!< The data type to match */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("dtype", &dtype); + } + + static constexpr const char* _type_key = "relax.dpl.DataTypePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataTypePatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to DataTypePatternNode. + * \sa DataTypePatternNode + */ +class DataTypePattern : public DFPattern { + public: + TVM_DLL DataTypePattern(DFPattern pattern, DataType dtype); + TVM_DEFINE_OBJECT_REF_METHODS(DataTypePattern, DFPattern, DataTypePatternNode); +}; + +/*! + * \brief A pattern that asserting a root pattern has certain attributes. + * \sa AttrPattern + */ +class AttrPatternNode : public DFPatternNode { + public: + DFPattern pattern; /*!< The root pattern to match */ + DictAttrs attrs; /*!< The attributes (a map/dictionary) to match */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("attrs", &attrs); + } + + static constexpr const char* _type_key = "relax.dpl.AttrPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttrPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to AttrPatternNode. + * \sa AttrPatternNode + */ +class AttrPattern : public DFPattern { + public: + TVM_DLL AttrPattern(DFPattern pattern, DictAttrs attrs); + TVM_DEFINE_OBJECT_REF_METHODS(AttrPattern, DFPattern, AttrPatternNode); +}; + +/*! + * \brief A pattern of external function. + * \sa ExternFunc + * \sa ExternFuncPattern + */ +class ExternFuncPatternNode : public DFPatternNode { + public: + String global_symbol_; /*!< The global symbol name of the external function */ + + /*! \brief The the external function name */ + const String& global_symbol() const { return global_symbol_; } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("global_symbol", &global_symbol_); } + + static constexpr const char* _type_key = "relax.dpl.ExternFuncPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to ExternFuncPatternNode. + * \sa ExternFuncPatternNode + */ +class ExternFuncPattern : public DFPattern { + public: + TVM_DLL ExternFuncPattern(String global_symbol); + TVM_DEFINE_OBJECT_REF_METHODS(ExternFuncPattern, DFPattern, ExternFuncPatternNode); +}; + +/*! \brief Syntatic Sugar for creating a VarPattern with a name */ +VarPattern IsVar(const String& name); +/*! \brief Syntatic Sugar for creating a ConstantPattern */ +ConstantPattern IsConst(); +/*! \brief Syntatic Sugar for creating a WildcardPattern */ +WildcardPattern Wildcard(); +/*! \brief Syntatic Sugar for creating a ExprPattern */ +ExprPattern IsExpr(const Expr& expr); +/*! \brief Syntatic Sugar for creating a ExprPattern base on an Op */ +ExprPattern IsOp(const String& op_name); +/*! \brief Syntatic Sugar for call_tir (return a tensor) */ +// Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo +CallPattern IsCallTIR(const String& name, Optional args = NullOpt); +/*! \brief Syntatic Sugar for call_tir (return a tuple of tensor) */ +CallPattern IsCallTIR(const String& name, TuplePattern var_args); +/*! \brief Syntatic Sugar for creating TuplePattern or UnorderedTuplePattern (unordered=true) */ +DFPattern IsTuple(const Array& fields, bool unordered = false); +/*! \brief Syntatic Sugar for creating a TupleGetItemPattern */ +TupleGetItemPattern IsTupleGetItem(const DFPattern tuple, int index = -1); + +/*! \brief Implementation of the templated CallPattern syntax sugar */ +template +CallPattern DFPattern::operator()(Args&&... args) const { + return CallPattern(GetRef(this->get()), + Array({std::forward(args)...})); +} + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_DATAFLOW_PATTERN_H_ diff --git a/include/tvm/relax/dataflow_pattern_functor.h b/include/tvm/relax/dataflow_pattern_functor.h new file mode 100644 index 000000000000..983881ddc9a7 --- /dev/null +++ b/include/tvm/relax/dataflow_pattern_functor.h @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_pattern_functor.h + * \brief Functors and visitors for dataflow patterns. + */ +#ifndef TVM_RELAX_DATAFLOW_PATTERN_FUNCTOR_H_ +#define TVM_RELAX_DATAFLOW_PATTERN_FUNCTOR_H_ + +#include + +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief A dynamical functor that dispatches on in the first DFPattern argument. + * + * \tparam FType function signature + * This type is only defined for FType with function signature R(const DFPattern&, + * Args...) + */ +template +class DFPatternFunctor; + +// functions to be overriden. +#define DFPATTERN_FUNCTOR_DEFAULT \ + { return VisitDFPatternDefault_(op, std::forward(args)...); } + +#define RELAX_DFPATTERN_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitDFPattern_(static_cast(n.get()), std::forward(args)...); \ + }); + +template +class DFPatternFunctor { + private: + using TSelf = DFPatternFunctor; + using FType = tvm::NodeFunctor; + + public: + /*! \brief virtual destructor */ + virtual ~DFPatternFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const DFPattern& n, Args... args) { + return VisitDFPattern(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitDFPattern(const DFPattern& n, Args... args) { + ICHECK(n.defined()); + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitDFPattern_(const OrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const AndPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const NotPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const FunctionPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TupleGetItemPatternNode* op, + Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + + virtual R VisitDFPattern_(const DataflowVarPatternNode* op, + Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const GlobalVarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const ExternFuncPatternNode* op, + Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const PrimArrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const UnorderedTuplePatternNode* op, + Args... args) DFPATTERN_FUNCTOR_DEFAULT; + + virtual R VisitDFPatternDefault_(const Object* op, Args...) { + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); + throw; + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + RELAX_DFPATTERN_FUNCTOR_DISPATCH(OrPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(AndPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(NotPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(DataTypePatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode); + + RELAX_DFPATTERN_FUNCTOR_DISPATCH(DataflowVarPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(GlobalVarPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(ExternFuncPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(PrimArrPatternNode); + RELAX_DFPATTERN_FUNCTOR_DISPATCH(UnorderedTuplePatternNode); + return vtable; + } +}; + +/*! + * \brief A simple visitor wrapper around DFPatternFunctor. + * Recursively visit the content. + * + * DFPatternVisitor treats the Pattern as dataflow graph,and only visit each Expr node once. + */ +class DFPatternVisitor : public DFPatternFunctor { + public: + void VisitDFPattern(const DFPattern& pattern) override; + void VisitDFPattern_(const OrPatternNode* op) override; + void VisitDFPattern_(const AndPatternNode* op) override; + void VisitDFPattern_(const NotPatternNode* op) override; + void VisitDFPattern_(const AttrPatternNode* op) override; + void VisitDFPattern_(const CallPatternNode* op) override; + void VisitDFPattern_(const ConstantPatternNode* op) override; + void VisitDFPattern_(const DataTypePatternNode* op) override; + void VisitDFPattern_(const ExprPatternNode* op) override; + void VisitDFPattern_(const FunctionPatternNode* op) override; + void VisitDFPattern_(const ShapePatternNode* op) override; + void VisitDFPattern_(const TupleGetItemPatternNode* op) override; + void VisitDFPattern_(const TuplePatternNode* op) override; + void VisitDFPattern_(const TypePatternNode* op) override; + void VisitDFPattern_(const WildcardPatternNode* op) override; + void VisitDFPattern_(const VarPatternNode* op) override; + + void VisitDFPattern_(const DataflowVarPatternNode* op) override; + void VisitDFPattern_(const GlobalVarPatternNode* op) override; + void VisitDFPattern_(const ExternFuncPatternNode* op) override; + void VisitDFPattern_(const PrimArrPatternNode* op) override; + void VisitDFPattern_(const UnorderedTuplePatternNode* op) override; + + protected: + // set of already-visited nodes + std::unordered_set visited_; +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_DATAFLOW_PATTERN_FUNCTOR_H_ diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 710788347829..45c5b6f96288 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -21,14 +21,14 @@ configuring the passes and scripting them in Python. """ -from typing import Dict +from typing import Dict, List from enum import IntEnum from tvm import tir from tvm import IRModule from tvm.relax.ty import Type from tvm.relax.struct_info import StructInfo, FuncStructInfo -from tvm.relax.expr import Var, Expr, Call +from tvm.relax.expr import DataflowBlock, Var, Expr, Function, Call from . import _ffi_api @@ -210,6 +210,40 @@ def has_reshape_pattern(func: tir.PrimFunc) -> bool: return _ffi_api.has_reshape_pattern(func) # type: ignore +def get_var2val(func: Function) -> Dict[Var, Expr]: + """ + Get a mapping from Var to Expr for each variable in the function. + + Parameters + ---------- + func : Function + The input function to be analyzed. + + Returns + ------- + Dict[Var, Expr] + A mapping from Var to Expr. + """ + return _ffi_api.get_var2val(func) # type: ignore + + +def udchain(dfb: DataflowBlock) -> Dict[Var, List[Var]]: + """ + Analyze the variable use-def chain in a dataflow block. + + Parameters + ---------- + dfb : DataflowBlock + The dataflow block to analyze + + Returns + ------- + Dict[Var, List[Var]] + A mapping from variable definition to its uses. + """ + return _ffi_api.udchain(dfb) # type: ignore + + def well_formed(mod: IRModule, check_struct_info: bool = True) -> bool: """Check if the IRModule is well formed. diff --git a/python/tvm/relax/dpl/__init__.py b/python/tvm/relax/dpl/__init__.py new file mode 100644 index 000000000000..e0bbdaff0512 --- /dev/null +++ b/python/tvm/relax/dpl/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""The Relax Dataflow Pattern Language.""" + +from .pattern import * +from .context import * diff --git a/python/tvm/relax/dpl/_ffi.py b/python/tvm/relax/dpl/_ffi.py new file mode 100644 index 000000000000..6699e42bee63 --- /dev/null +++ b/python/tvm/relax/dpl/_ffi.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""DataFlow Pattern Language FFI bindings.""" +import tvm._ffi + +tvm._ffi._init_api("relax.dpl", __name__) diff --git a/python/tvm/relax/dpl/context.py b/python/tvm/relax/dpl/context.py new file mode 100644 index 000000000000..69a5e70ed0f1 --- /dev/null +++ b/python/tvm/relax/dpl/context.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""The Graph Matching Context Manager for Dataflow Pattern Language.""" + +from typing import Optional, Dict + +import tvm +from ..expr import DataflowBlock, Var +from .pattern import DFPattern +from . import _ffi as ffi + + +class PatternContext(tvm.runtime.Object): + """A context object for doing graph (topogical) pattern matching.""" + + def __init__(self, incremental=False): + """ + Initialize the PatternContext + + Parameters + ---------- + incremental : bool, optional + perform incremental matching based on the recent context, by default False + """ + self.__init_handle_by_constructor__(ffi.PatternContext, incremental) # type: ignore + + def __enter__(self): + """Enter the context""" + ffi.enter_context(self) # type: ignore + return self + + def __exit__(self, exc_type, exc_value, traceback): + """Exit the context""" + ffi.exit_context(self) # type: ignore + + @staticmethod + def current() -> "PatternContext": + """ + Get the current context + + Returns + ------- + PatternContext + The current context + """ + return ffi.current_context() # type: ignore + + def match_dfb( + self, + dfb: DataflowBlock, + start_hint: Optional[Var] = None, + must_include_hint: bool = False, + ) -> Dict[DFPattern, Var]: + """ + Match a DataflowBlock via a graph of DFPattern and corresponding constraints + + Parameters + ---------- + dfb : DataflowBlock + The DataflowBlock to match + start_hint : Optional[Var], optional + Indicating the starting expression to match, by default None + must_include_hint : bool, optional + Whether the start_hint expression must be matched, by default False + + Returns + ------- + Dict[DFPattern, Var] + The mapping from DFPattern to matched expression + """ + return ffi.match_dfb(self, dfb, start_hint, must_include_hint) # type: ignore diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py new file mode 100644 index 000000000000..44faa0c93a14 --- /dev/null +++ b/python/tvm/relax/dpl/pattern.py @@ -0,0 +1,1095 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pattern types in Relax Dataflow Pattern Language""" +# pylint: disable=no-member +# pylint: disable=pointless-statement + +import typing +from typing import Dict, List, Optional, Tuple, Union + +import tvm +import tvm._ffi as tvm_ffi +from tvm.ir.container import Array +from tvm.ir.expr import PrimExpr +from tvm.relay.op import get + +from ...ir import make_node +from ...ir.base import Node +from ...runtime import Object +from ..expr import Expr, Var +from . import _ffi as ffi + + +def register_df_node(type_key=None): + """ + Register a Relax node type + + Parameters + ---------- + type_key : str or cls + The type key of the node + """ + if not isinstance(type_key, str): + return tvm_ffi.register_object("relax.dpl." + type_key.__name__)(type_key) + return tvm_ffi.register_object(type_key) + + +class DFPattern(Node): + """Base class of all Patterns.""" + + def __call__(self, *args, varg_default_wildcard=False) -> "CallPattern": + """ + Syntax sugar for creating a CallPattern with argument patterns + + Returns + ------- + result: CallPattern + The resulting CallPattern + """ + return CallPattern(self, args, varg_default_wildcard) + + def __or__(self, other: "DFPattern") -> "OrPattern": + """ + Syntax sugar for creating an OrPattern + + Parameters + ---------- + other: DFPattern + Alternative pattern + + Returns + ------- + result: OrPattern + The resulting OrPattern + """ + return OrPattern(self, other) + + def __and__(self, other: "DFPattern") -> "AndPattern": + """ + Syntax sugar for creating an AndPattern + + Parameters + ---------- + other: DFPattern + Additional pattern to satisfy + + Returns + ------- + result: AndPattern + The resulting AndPattern + """ + return AndPattern(self, other) + + def __invert__(self) -> "NotPattern": + """ + Syntax sugar for creating a DFPattern to reject + + Returns + ------- + result: NotPattern + The resulting NotPattern + """ + return reject(self) + + def has_attr(self, attrs: Dict[str, Object]) -> "AttrPattern": + """ + Add an attribute constraint to this pattern + + Parameters + ---------- + attrs: Dict[str, Object] + + Returns + ------- + result: AttrPattern + The resulting AttrPattern + """ + attrs = make_node("DictAttrs", **attrs) + return AttrPattern(self, attrs) + + def has_type(self, ttype: tvm.ir.type.Type) -> "TypePattern": + """ + Add a type constraint to this pattern + + Parameters + ---------- + ttype: tvm.ir.type.Type + The type to match + + Returns + ------- + result: TypePattern + The resulting TypePattern + """ + return TypePattern(self, ttype) + + def has_dtype(self, dtype: str) -> "DataTypePattern": + """ + Add a type constraint to this pattern + + Parameters + ---------- + dtype: str + The dtype to match + + Returns + ------- + result: DataTypePattern + The resulting DataTypePattern + """ + return has_dtype(dtype, self) + + def has_shape(self, shape: List[PrimExpr]) -> "ShapePattern": + """ + Add a shape constraint to this pattern + + Parameters + ---------- + shape: List[PrimExpr] + Expected shape list + + Returns + ------- + result: ShapePattern + The resulting ShapePattern + + Note + ---- + has_shape assumes that the matched relax.Expr only has one + output tensor. Use is_tuple for those with multiple outputs. + """ + if not isinstance(shape, (list, tuple, tvm.ir.PrimExpr)): + raise ValueError("has_shape takes a list or tuple as input.") + return ShapePattern(pattern=self, shape=shape) + + def match(self, expr, var2val: Optional[Dict[Var, Expr]] = None) -> bool: + """ + Match a relax.Expr syntactically + + Parameters + ---------- + expr : tvm.relax.Expr + The expression to match + var2val : Optional[Dict[tvm.relax.Var, tvm.relax.Expr]] + A mapping from relax.Var to relax.Expr for autojump. + + Returns + ------- + result: bool + Whether or not the expression matches the pattern + + Note + ---- + Unlike Relay whose function is an expression, functions in Relax consist + of blocks of bindings that are not syntactically connected. We use a + mapping (i.e., var2val) to mitigate the gap. For example, when matching + "relax.add(lv0, lv1)", given var2val, we match lv0's bound expression + when the recursive pattern matching goes to check lv0. The var2val mapping + can be computed through the tvm.relax.analysis.get_var2val function. + """ + return ffi.match_expr(self, expr, var2val) # type: ignore + + def used_by(self, other: Union["DFPattern", "PatternSeq"], index=-1) -> "PatternSeq": + """ + The current pattern being used by another pattern (sequence) + + Parameters + ---------- + other : Union[DFPattern, DFPattern] + The consumer pattern (sequence) + index : int, optional + The argument index called by the consumer pattern, by default -1 + + Returns + ------- + result: PatternSeq + A chained pattern sequence + """ + return _used_by(self, other, index) + + def __xor__(self, other: Union["DFPattern", "PatternSeq"]) -> "PatternSeq": + """Syntax sugar of DFPattern.used_by""" + return self.used_by(other, -1) + + def only_used_by(self, other: Union["DFPattern", "PatternSeq"], index=-1) -> "PatternSeq": + """ + The current pattern being **ONLY** used by another pattern (sequence) + + Parameters + ---------- + other : Union[DFPattern, DFPattern] + The consumer pattern (sequence) + index : int, optional + The argument index called by the consumer pattern, by default -1 + + Returns + ------- + result: PatternSeq + A chained pattern sequence + """ + return _only_used_by(self, other, index) + + def __rshift__(self, other: Union["DFPattern", "PatternSeq"]) -> "PatternSeq": + """Syntax sugar of DFPattern.only_used_by""" + return self.only_used_by(other, -1) + + def dup(self) -> "DFPattern": + """ + Duplicate the current pattern (new object under different address) + + Returns + ------- + DFPattern + A duplicated pattern + """ + return ffi.dup_pattern(self) # type: ignore + + def fork_to(self, *args) -> None: + """Fork the current pattern to multiple pattern branches""" + for v in args: + self ^ v + + +@register_df_node +class ExprPattern(DFPattern): + """A pattern which matches an expression. + + Parameters + ---------- + expr : tvm.relax.Expr + The expression to match. + """ + + def __init__(self, expr: Expr): + self.__init_handle_by_constructor__(ffi.ExprPattern, expr) # type: ignore + + +@register_df_node +class VarPattern(DFPattern): + """A pattern for Var. + + Parameters + ---------- + name_hint: str + The name of the variable. Optional, if not provided, + the pattern will match any VarNode. + """ + + def __init__(self, name_hint: str = ""): + self.__init_handle_by_constructor__(ffi.VarPattern, name_hint) # type: ignore + + +@register_df_node +class DataflowVarPattern(DFPattern): + """A pattern for DataflowVar. + + Parameters + ---------- + name_hint: str + The name of the variable. Optional, if not provided, + the pattern will match any VarNode. + """ + + def __init__(self, name_hint: str = ""): + self.__init_handle_by_constructor__(ffi.DataflowVarPattern, name_hint) # type: ignore + + +@register_df_node +class GlobalVarPattern(DFPattern): + """A pattern for GlobalVar. + + Parameters + ---------- + name_hint: str + The name of the variable. Optional, if not provided, + the pattern will match any GlobalVarNode. + """ + + def __init__(self, name_hint: str = ""): + self.__init_handle_by_constructor__(ffi.GlobalVarPattern, name_hint) # type: ignore + + +@register_df_node +class ExternFuncPattern(DFPattern): + """A external function pattern. + + Parameters + ---------- + global_symbol: str + The name of the function. Optional, if not provided, + the pattern will match any ExternFuncNode. + """ + + def __init__(self, global_symbol: str = ""): + self.__init_handle_by_constructor__(ffi.ExternFuncPattern, global_symbol) # type: ignore + + +@register_df_node +class ConstantPattern(DFPattern): + """A pattern matching a Relax Constant.""" + + def __init__(self): + self.__init_handle_by_constructor__(ffi.ConstantPattern) # type: ignore + + +@register_df_node +class CallPattern(DFPattern): + """A pattern matching a function call node. + + Parameters + ---------- + op: tvm.relax.dpl.DFPattern + The operation to be called. + + args: List[tvm.relax.dpl.DFPattern] + The arguments to the call or None to match any arguments. + + varg_default_wildcard: bool + If True, args can be fewer than actual provided arguments. + + Note + ---- + By setting varg_default_wildcard to True, we can only focus on the argument + patterns we specified. For example, CallPattern(Op, [A, B]) can match + a call of Op(A, B) or Op(A, B, C, ...) that has more arguments. However, + the specified argument patterns must be matched (i.e., A and B). + """ + + def __init__( + self, + op: "DFPattern", + args: Union[List["DFPattern"], typing.Tuple["DFPattern", ...]], + varg_default_wildcard: bool = False, + ): + self.__init_handle_by_constructor__( + ffi.CallPattern, op, args, varg_default_wildcard # type: ignore + ) + + +@register_df_node +class FunctionPattern(DFPattern): + """A pattern matching a function node in Relax. + + Parameters + ---------- + params: List[tvm.relax.dpl.DFPattern] + The parameters to the Function or None to match any parameters. + + body: tvm.relax.dpl.DFPattern + The body fo the Function + + """ + + def __init__( + self, + params: List["DFPattern"], + body: "DFPattern", + ): + self.__init_handle_by_constructor__(ffi.FunctionPattern, params, body) # type: ignore + + +@register_df_node +class TuplePattern(DFPattern): + """A patern matching a Relax Tuple. + + Parameters + ---------- + fields : Array[tvm.relax.dpl.DFPattern] + The fields in the tuple. + """ + + def __init__(self, fields: Array): + self.__init_handle_by_constructor__(ffi.TuplePattern, fields) # type: ignore + + def __getitem__(self, index: Optional[int]) -> "TupleGetItemPattern": + if index is not None: + # support negative index for being pythonic + if index < 0: + index += len(self) + if index >= len(self): + raise IndexError("TuplePattern index out of range") + else: + index = -1 # -1 means matching any index + return TupleGetItemPattern(self, index) + + def __len__(self): + return len(self.fields) + + +@register_df_node +class UnorderedTuplePattern(DFPattern): + """A patern matching a Relax Tuple unorderedly. + + Parameters + ---------- + fields : Array[tvm.relax.dpl.DFPattern] + The fields in the tuple. + """ + + def __init__(self, fields: Array): + self.__init_handle_by_constructor__(ffi.UnorderedTuplePattern, fields) # type: ignore + + def __len__(self): + return len(self.fields) + + +@register_df_node +class TupleGetItemPattern(DFPattern): + """Get index-th item from a TuplePattern. + + Parameters + ---------- + tuple_value: tvm.relax.dpl.DFPattern + The input tuple expression. + + index: Optional[int] + The index to match; Default (None) to match a TupleGetItem with any index. + """ + + def __init__(self, tuple_value: "DFPattern", index: Optional[int] = None): + match_index = index if index is not None else -1 + self.__init_handle_by_constructor__( + ffi.TupleGetItemPattern, tuple_value, match_index # type: ignore + ) + + +@register_df_node +class OrPattern(DFPattern): + """Create a Pattern that can match one of two conditions + + Parameters + ---------- + left: tvm.relax.dpl.DFPattern + One possible matching pattern. + right: tvm.relax.dpl.DFPattern + One possible matching pattern. + """ + + def __init__(self, left: "DFPattern", right: "DFPattern"): + self.__init_handle_by_constructor__(ffi.OrPattern, left, right) # type: ignore + + +@register_df_node +class AndPattern(DFPattern): + """Create a Pattern that must match two conditions + + Parameters + ---------- + left: tvm.relax.dpl.DFPattern + One must-matching pattern. + right: tvm.relax.dpl.DFPattern + One must-matching pattern. + """ + + def __init__(self, left: "DFPattern", right: "DFPattern"): + self.__init_handle_by_constructor__(ffi.AndPattern, left, right) # type: ignore + + +@register_df_node +class NotPattern(DFPattern): + """Create a Pattern that matches the negation of a condition. + + Parameters + ---------- + to_reject: tvm.relax.dpl.DFPattern + The pattern to deny. + """ + + def __init__(self, to_reject: "DFPattern"): + self.__init_handle_by_constructor__(ffi.NotPattern, to_reject) # type: ignore + + +@register_df_node +class WildcardPattern(DFPattern): + """A pattern which matches anything.""" + + def __init__(self): + self.__init_handle_by_constructor__(ffi.WildcardPattern) # type: ignore + + +@register_df_node +class TypePattern(DFPattern): + """A pattern that matches another pattern with a certain type annotation. + + Parameters + ---------- + pattern: tvm.relax.dpl.DFPattern + The input pattern that needs type annotation. + + ttype: tvm.ir.type.Type + The type to match. + """ + + def __init__(self, pattern: "DFPattern", ttype: tvm.ir.type.Type): + self.__init_handle_by_constructor__(ffi.TypePattern, pattern, ttype) # type: ignore + + +@register_df_node +class DataTypePattern(DFPattern): + """A pattern that matches another pattern with certain data type + + Parameters + ---------- + pattern: tvm.relax.dpl.DFPattern + The input pattern that needs type annotation. + + dtype: str + The dtype to match. + """ + + def __init__(self, pattern: "DFPattern", dtype: str): + self.__init_handle_by_constructor__(ffi.DataTypePattern, pattern, dtype) # type: ignore + + +@register_df_node +class ShapePattern(DFPattern): + """A pattern that matches another pattern with a certain tensor shape + + Parameters + ---------- + pattern: tvm.relax.dpl.DFPattern + The input pattern that needs type annotation. + + shape: List[tvm.ir.PrimExpr] + The shape to match. + """ + + def __init__(self, pattern: "DFPattern", shape: List[tvm.ir.PrimExpr]): + self.__init_handle_by_constructor__(ffi.ShapePattern, pattern, shape) # type: ignore + + +@register_df_node +class PrimArrPattern(DFPattern): + """ + A pattern to match an array of PrimExpr + + Parameters + ---------- + shape : List[tvm.ir.PrimExpr] + The shape to match. + """ + + def __init__(self, shape: List[tvm.ir.PrimExpr]): + self.__init_handle_by_constructor__(ffi.PrimArrPattern, shape) # type: ignore + + def __getitem__(self, index: int): + if index >= len(self): + raise IndexError("PrimArrPattern index out of range") + return self.fields[index] + + def __len__(self): + return len(self.fields) + + +@register_df_node +class AttrPattern(DFPattern): + """Get match an expression with a certain attributes. + Currently only supports Op Attributes, not call Attributes. + + Parameters + ---------- + pattern: tvm.relax.dpl.DFPattern + The input pattern. + + attrs: tvm.ir.attrs.Attrs + The attributes to match. + """ + + def __init__(self, pattern: "DFPattern", attrs: tvm.ir.attrs.Attrs): + self.__init_handle_by_constructor__(ffi.AttrPattern, pattern, attrs) # type: ignore + + +def is_var(name: str = "") -> VarPattern: + """ + Syntatic sugar for creating an optionally named VarPattern. + + Parameters + ---------- + name: str + The name of the input pattern to match. + + Returns + ------- + result: tvm.relax.dpl.VarPattern + The resulting pattern. + """ + return VarPattern(name) + + +def is_gv(name: str = "") -> GlobalVarPattern: + """Syntax sugar for creating an optionally (if name is empty) named GlobalVarPattern.""" + return GlobalVarPattern(name) + + +def is_dfv(name: str = "") -> DataflowVarPattern: + """Syntax sugar for creating an optionally (if name is empty) named DataflowVarPattern.""" + return DataflowVarPattern(name) + + +def is_const() -> ConstantPattern: + """ + Syntatic sugar for creating a ConstantPattern. + + Parameters + ---------- + name: str + The name of the input pattern to match. + + Returns + ------- + result: tvm.relax.dpl.ConstantPattern + The resulting pattern. + """ + return ConstantPattern() + + +def is_expr(expr: Expr) -> ExprPattern: + """ + Syntatic sugar for creating an ExprPattern. + + Parameters + ---------- + expr: Expr + The Relax expression to match. + + Returns + ------- + result: tvm.relax.dpl.ExprPattern + The resulting pattern. + """ + return ExprPattern(expr) + + +def is_op(op_name: str) -> ExprPattern: + """ + Syntatic sugar for creating an operator ExprPattern. + + Parameters + ---------- + op_name: String + The name of the tvm.ir.op.Op object + + Returns + ------- + result: tvm.relax.dpl.ExprPattern + The resulting ExprPattern + """ + op = get(op_name) + return ExprPattern(op) + + +def is_tuple( + fields: Union[Array, List, Tuple], unordered=False +) -> Union[TuplePattern, UnorderedTuplePattern]: + """ + Syntatic sugar for creating an ExprPattern. + + Parameters + ---------- + fields : Array[tvm.relax.dpl.DFPattern] + The fields in the tuple. + + Returns + ------- + result: tvm.relax.dpl.DFPattern + The resulting pattern. + """ + if not isinstance(fields, (list, tuple, Array)): + raise ValueError("fields must be a list, tuple, or Array") + if unordered: + return UnorderedTuplePattern(fields) + return TuplePattern(fields) + + +def is_tuple_get_item(tuple_value: DFPattern, index: Optional[int] = None) -> TupleGetItemPattern: + """ + Syntatic sugar for creating an ExprPattern. + + Parameters + ---------- + tuple_value: tvm.relax.dpl.DFPattern + The input tuple expression. + + index: Optional[int] + The index to match; Default (None) to match a TupleGetItem with any index. + + Returns + ------- + result: tvm.relax.dpl.TupleGetItemPattern + The resulting pattern. + """ + return TupleGetItemPattern(tuple_value, index) + + +def wildcard() -> WildcardPattern: + """ + Syntatic sugar for creating a WildcardPattern. + + Returns + ------- + result: tvm.relax.dpl.WildcardPattern + The resulting pattern. + """ + return WildcardPattern() + + +def has_dtype(dtype: str, pattern: DFPattern = None) -> DataTypePattern: + """ + Syntatic sugar for creating a DataTypePattern + + Parameters + ---------- + dtype: str + The dtype to match + + pattern: tvm.relax.dpl.DFPattern + The pattern that needs type annotation + + Returns + ------- + result: tvm.relax.dpl.DataTypePattern + The resulting DataTypePattern + """ + if pattern is None: + pattern = wildcard() + return DataTypePattern(pattern, dtype) + + +def is_shape(shape: List[tvm.ir.PrimExpr]) -> "PrimArrPattern": + """ + Directly matches a shape which is an array of PrimExpr + + Parameters + ---------- + shape : List[tvm.ir.PrimExpr] + The expected shape + + Returns + ------- + PrimArrPattern + The resulting PrimArrPattern pattern + + Raises + ------ + ValueError + If the argument shape is not a list/tuple/tvm.ir.Array + + Note + ---- + The difference between p.has_shape(s) and is_shape(s) is that: has_shape + puts assumptions on the shape of the tensor matched by pattern p. While + is_shape directly matches the shape (an array of PrimExpr). + """ + if not isinstance(shape, (list, tuple, tvm.ir.Array)): + raise ValueError("is_shape takes a list or tuple as input.") + return PrimArrPattern(shape) + + +# Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo +def _is_call_tir( + func_pattern: DFPattern, + args: Union[List, Tuple, TuplePattern] = None, +) -> CallPattern: + if args is None: + args = wildcard() + elif isinstance(args, (list, tuple)): + args = TuplePattern(args) + + return is_op("relax.call_tir")(func_pattern, args) + + +# Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo +def is_call_tir( + func_name: str, + args: Union[List, Tuple, TuplePattern] = None, +) -> CallPattern: + """ + Syntax sugar for creating a CallPattern for call_tir that calls an function through global var. + + Parameters + ---------- + func_name : str + Name of the CPS function to call. + args : Union[List[DFPattern], Tuple[DFPattern]], optional + Arguments in expected call_packed, by default None meaning arbitrary (number of) arguments + + Returns + ------- + CallPattern + The resulting CallPattern + """ + func_pattern = GlobalVarPattern(func_name) + return _is_call_tir(func_pattern, args) + + +def is_call_tir_extern( + func_name: str, + args: Union[List, Tuple, TuplePattern] = None, +) -> CallPattern: + """Syntax sugar for creating a CallPattern for call_tir that calls an extern function + + Parameters + ---------- + func_name : str + Name of the CPS function to call. + args : Union[List[DFPattern], Tuple[DFPattern]], optional + Arguments in expected call_packed, by default None meaning arbitrary (number of) arguments + + Returns + ------- + CallPattern + The resulting CallPattern + """ + func_pattern = ExternFuncPattern(func_name) + return _is_call_tir(func_pattern, args) + + +def is_call_packed( + func_name: str, args: Union[List[DFPattern], Tuple[DFPattern]] = None +) -> CallPattern: + """ + Syntax sugar for creating a CallPattern for call_packed + + Parameters + ---------- + func_name : str + Name of the external function to call + args : Union[List[DFPattern], Tuple[DFPattern]], optional + Arguments in expected call_packed, by default None meaning arbitrary (number of) arguments + + Returns + ------- + CallPattern + The resulting CallPattern + """ + if args is None: + return ExternFuncPattern(func_name)(varg_default_wildcard=True) + return ExternFuncPattern(func_name)(*args) + + +def reject(pattern: DFPattern) -> NotPattern: + """ + Syntax sugar for creating a DFPattern to reject + + Parameters + ---------- + pattern : DFPattern + The pattern to deny + + Returns + ------- + result: NotPattern + The resulting NotPattern + """ + return NotPattern(pattern) + + +def has_attr(attrs, pattern=None) -> AttrPattern: + """ + Syntatic sugar for creating an AttrPattern + + Parameters + ---------- + attrs: Dict[str, Object] + The attributes to match + + pattern: Optional[tvm.relax.dpl.DFPattern] + The input pattern. + + Returns + ------- + result: tvm.relax.dpl.DFPattern + The resulting AttrPattern + """ + if pattern is None: + pattern = wildcard() + return pattern.has_attr(attrs) + + +@register_df_node +class PatternSeq(Node): + """A sequence of patterns with consecutive constraints""" + + def __init__(self, patterns: List[DFPattern], only_use=False): + """ + Initializer to PatternSeq + + Parameters + ---------- + patterns : List[DFPattern] + A chain of patterns + only_use : bool, optional + Whether the patterns follows only-used-by relations consecutively, by default False + """ + self.__init_handle_by_constructor__(ffi.PatternSeq, patterns, only_use) # type: ignore + + def used_by(self, other: Union[DFPattern, "PatternSeq"], index=-1) -> "PatternSeq": + """ + Assuming the right-most pattern must be used by the `other` pattern as a producer + + Parameters + ---------- + other : Union[DFPattern, PatternSeq] + The consumer pattern (sequence) + index : int, optional + The argument index called by the consumer pattern, by default -1 + + Returns + ------- + PatternSeq + A chained pattern sequence + + Note + ---- + If other is PatternSeq, it means the right-most pattern must be used by the left-most + pattern of the other sequence. + """ + return _used_by(self, other, index) + + def only_used_by(self, other: Union[DFPattern, "PatternSeq"], index=-1) -> "PatternSeq": + + """ + Assuming the right-most pattern must be **ONLY** used by the `other` pattern as a producer + + Parameters + ---------- + other : Union[DFPattern, PatternSeq] + The consumer pattern (sequence) + index : int, optional + The argument index called by the consumer pattern, by default -1 + + Returns + ------- + PatternSeq + A chained pattern sequence + + Note + ---- + If other is PatternSeq, it means the right-most pattern must be **ONLY** used by the + left-most pattern of the other sequence. + """ + return _only_used_by(self, other, index) + + def __getitem__(self, index: int) -> DFPattern: + """ + Access the pattern at the given index + + Parameters + ---------- + index : int + Index of the accessed pattern + + Returns + ------- + DFPattern + The accessed pattern + """ + return self.patterns[index] + + def __xor__(self, other) -> "PatternSeq": + """Syntax sugar of PatternSeq.used_by""" + return self.used_by(other, -1) + + def __rshift__(self, other) -> "PatternSeq": + """Syntax sugar of PatternSeq.only_used_by""" + return self.only_used_by(other, -1) + + def dup(self) -> "PatternSeq": + """ + Duplicate the pattern sequence (new object under different address) + + Returns + ------- + PatternSeq + A duplicated chain + """ + return ffi.dup_seq(self) # type: ignore + + +### Private functions + + +def _used_by( + lhs: Union[DFPattern, PatternSeq], + rhs: Union[DFPattern, PatternSeq], + index=-1, +) -> PatternSeq: + if isinstance(lhs, DFPattern): + lhs = PatternSeq([lhs]) + if isinstance(rhs, DFPattern): + rhs = PatternSeq([rhs]) + return ffi.used_by(lhs, rhs, index) # type: ignore + + +def _only_used_by( + lhs: Union[DFPattern, PatternSeq], rhs: Union[DFPattern, PatternSeq], index=-1 +) -> PatternSeq: + if isinstance(lhs, DFPattern): + lhs = PatternSeq([lhs]) + if isinstance(rhs, DFPattern): + rhs = PatternSeq([rhs]) + return ffi.only_used_by(lhs, rhs, index) # type: ignore + + +def _add_bias_activation_pattern(out, with_bias=False, activation=None): + if with_bias: + bias = wildcard() + out = is_op("relax.add")(out, bias) + + if activation: + return is_op(activation)(out) + + return out + + +def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None): + """ + A simple utility to create patterns for an operation fused with bias addition and activation. + + Parameters + ---------- + op_name: str + The name of a Relax op, such as "relax.nn.conv2d" + + with_bias: bool + Whether or not to include bias addition + + activation: str + The name of an activation Relax op, such as "relax.nn.relu" + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a fused operation + """ + lhs = wildcard() + rhs = wildcard() + out = is_op(op_name)(lhs, rhs) + + return _add_bias_activation_pattern(out, with_bias, activation) + + +def make_matmul_pattern(with_bias=False, activation=None, transposed_b=False): + lhs = wildcard() + if transposed_b: + rhs = is_op("relax.permute_dims")(wildcard()) + else: + rhs = wildcard() + out = is_op("relax.matmul")(lhs, rhs) + + return _add_bias_activation_pattern(out, with_bias, activation) diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc new file mode 100644 index 000000000000..f3d9b4686b7d --- /dev/null +++ b/src/relax/analysis/udchain.cc @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/analysis/udchain.cc + * \brief Implementation of use-def analysis. + */ + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +class UDChain : public relax::ExprVisitor { + public: + // nullptr users means it is the output of the function. + std::map> to_users; + + const VarNode* cur_user_; + + void VisitBinding_(const VarBindingNode* binding) override { + // init + cur_user_ = binding->var.get(); + this->VisitVarDef(binding->var); + this->VisitExpr(binding->value); + cur_user_ = nullptr; + } + + void VisitExpr_(const VarNode* op) override { to_users[op].insert(cur_user_); } + void VisitVarDef(const Var& var) override { to_users[var.get()] = {}; } + void VisitExpr_(const FunctionNode* op) override { ExprVisitor::VisitExpr_(op); } + + void VisitExpr_(const DataflowVarNode* op) override { + VisitExpr_(static_cast(op)); + } +}; + +std::pair>, runtime::Array> FunctionUseDef( + const Function& fn) { + UDChain udchain; + udchain.VisitExpr_(fn.get()); + + Map> user_map; + Array fn_outs; + + for (const auto& kv : udchain.to_users) { + Array uses{}; + uses.reserve(kv.second.size()); + for (const auto& v : kv.second) { + if (nullptr == v && + fn_outs.end() == std::find(fn_outs.begin(), fn_outs.end(), GetRef(kv.first))) { + fn_outs.push_back(GetRef(kv.first)); + } else { + uses.push_back(GetRef(v)); + } + } + user_map.Set(GetRef(kv.first), std::move(uses)); + } + return std::make_pair(std::move(user_map), std::move(fn_outs)); +} + +runtime::Map> DataflowBlockUseDef(const DataflowBlock& dfb) { + UDChain udchain; + udchain.VisitBindingBlock_(dfb.get()); + runtime::Map> ret; + for (const auto& kv : udchain.to_users) { + Array uses{}; + uses.reserve(kv.second.size()); + for (const auto& v : kv.second) uses.push_back(GetRef(v)); + ret.Set(GetRef(kv.first), std::move(uses)); + } + return ret; +} + +TVM_REGISTER_GLOBAL("relax.analysis.udchain").set_body_typed(DataflowBlockUseDef); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/var2value.cc b/src/relax/analysis/var2value.cc new file mode 100644 index 000000000000..be50e9bdcef2 --- /dev/null +++ b/src/relax/analysis/var2value.cc @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +namespace tvm { +namespace relax { +class Var2ValAnalysis : public relax::ExprVisitor { + public: + tvm::runtime::Map var2value_; + void VisitBinding_(const VarBindingNode* binding) override { + var2value_.Set(binding->var, binding->value); + // Recursively visit the value to handle local functions. + VisitExpr(binding->value); + } +}; + +tvm::runtime::Map AnalyzeVar2Value(const Expr& expr) { + Var2ValAnalysis var2val_analysis; + var2val_analysis.VisitExpr(expr); + return std::move(var2val_analysis.var2value_); +} + +tvm::runtime::Map AnalyzeVar2Value(const DataflowBlock& dfb) { + Var2ValAnalysis var2val_analysis; + var2val_analysis.VisitBindingBlock_(dfb.get()); + return std::move(var2val_analysis.var2value_); +} + +tvm::runtime::Map AnalyzeVar2Value(const IRModule& m) { + Var2ValAnalysis var2val_analysis; + + for (const auto& it : m->functions) { + // visit relax.Function + if (auto* n = it.second.as()) { + var2val_analysis.VisitExpr(GetRef(n)); + } + } + + return std::move(var2val_analysis.var2value_); +} + +TVM_REGISTER_GLOBAL(("relax.analysis.get_var2val")).set_body_typed([](const Function& f) { + return AnalyzeVar2Value(f); +}); + +class Name2BindingAnalysis : public relax::ExprVisitor { + public: + // runtime::Map is not suitable for doing in-place update. + // so we use standard container for internal usage. + std::map> name2bindings_; + void VisitBinding_(const VarBindingNode* binding) override { + const auto& vname = binding->var->name_hint(); + name2bindings_[vname].push_back(GetRef(binding)); + } + + void VisitBinding_(const MatchCastNode* binding) override { + const auto& vname = binding->var->name_hint(); + name2bindings_[vname].push_back(GetRef(binding)); + } +}; + +Map> NameToBinding(const Function& fn) { + Name2BindingAnalysis analysis{}; + analysis.VisitExpr_(fn.get()); + return Map>(std::make_move_iterator(analysis.name2bindings_.begin()), + std::make_move_iterator(analysis.name2bindings_.end())); +} + +TVM_REGISTER_GLOBAL(("relax.analysis.name_to_binding")).set_body_typed(NameToBinding); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc new file mode 100644 index 000000000000..92eb452a0065 --- /dev/null +++ b/src/relax/ir/dataflow_matcher.cc @@ -0,0 +1,768 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/dataflow_matcher.cc + * \brief The dataflow pattern matcher for Relax. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "dataflow_matcher_impl.h" + +namespace tvm { +namespace relax { + +using tvm::arith::Analyzer; + +// Pattern Matcher +bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { + memo_.clear(); + matched_nodes_.clear(); + return VisitDFPattern(pattern, expr); +} + +static Expr TryGetValOfVar(const Expr& expr, const Map& var2val) { + if (var2val.empty()) return expr; + + // if not match, try to match value of var if expr is a var. + if (const VarNode* var = expr.as()) { + auto may = var2val.Get(GetRef(var)); + if (may.defined()) return may.value(); + } + + return expr; +} + +void DFPatternMatcher::ClearMap(size_t watermark) { + for (size_t i = watermark; i < matched_nodes_.size(); ++i) { + memo_.erase(matched_nodes_[i]); + } + matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end()); +} + +bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + if (memoize_ && memo_.count(pattern)) { + ICHECK_EQ(memo_[pattern].size(), 1); + return expr.same_as(memo_[pattern][0]); + } else { + size_t watermark = matched_nodes_.size(); + bool out = DFPatternFunctor::VisitDFPattern(pattern, expr); + if (out) { + memo_[pattern].push_back(expr); + matched_nodes_.push_back(pattern); + } else { + ClearMap(watermark); + } + return out; + } +} + +bool DFPatternMatcher::VisitDFPattern_(const OrPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr); +} + +bool DFPatternMatcher::VisitDFPattern_(const AndPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + return VisitDFPattern(op->left, expr) && VisitDFPattern(op->right, expr); +} + +bool DFPatternMatcher::VisitDFPattern_(const NotPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + return !VisitDFPattern(op->reject, expr); +} + +bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { + switch (rhs.type_code()) { + case kDLInt: + if (auto* val = lhs.as()) { + return val->value == rhs.operator int64_t(); + } + break; + case kDLFloat: + if (auto* val = lhs.as()) { + return val->value == rhs.operator double(); + } + break; + case kTVMStr: + if (auto* val = lhs.as()) { + return val->value == rhs.operator std::string(); + } else if (auto* val = lhs.as()) { + return val->data == rhs.operator std::string(); + } + break; + case kTVMDataType: + if (auto* val = lhs.as()) { + return rhs.operator std::string() == val->value; + } else if (auto* val = lhs.as()) { + return rhs.operator std::string() == val->data; + } else { + ICHECK(false) << "PatternMatcher: Unsupported TVMDataType " << lhs; + } + break; + case kTVMObjectHandle: + if (rhs.IsObjectRef()) { + if (auto* val = lhs.as()) { + return rhs.operator String() == val->value; + } else if (auto* val = lhs.as()) { + return rhs.operator String() == val->data; + } + } else { + // Compare the objects for structural equality + static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual"); + ICHECK(structural_equal) << "node.StructuralEqual is not registered."; + if ((*structural_equal)(lhs, GetRef(rhs.ptr()), false, true)) { + return true; + } + } + break; + default: + ICHECK(false) << "Unsupported type code in Pattern Node " << rhs.type_code(); + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + bool matches = VisitDFPattern(attr_pattern->pattern, expr); + if (!matches) return matches; + VLOG(1) << "considering AttrPatternNode at:\n" << expr; + auto attributes = attr_pattern->attrs.as()->dict; + if (const auto* op_node = expr.as()) { + Op op = GetRef(op_node); + for (auto kv : attributes) { + auto attr_name = kv.first; + auto attr_value = kv.second; + if (Op::HasAttrMap(attr_name)) { + auto op_map = Op::GetAttrMap(attr_name); + if (op_map.count(op)) { + matches &= MatchRetValue(attr_value, op_map[op]); + } else { + matches = false; + } + } else { + matches = false; + } + } + } else if (auto* op = expr.as()) { + matches = true; + // TODO(mbrookhart): When OpNode Attrs move from TVMRetValue to the Object system, remove this + // and replace the whole thing with a Visitor-based approach + ReflectionVTable* reflection = ReflectionVTable::Global(); + auto attrs_node = const_cast(op->attrs.get()); + // attrs may be undefined on non-op calls so we check first + std::vector attr_names; + if (attrs_node) { + attr_names = reflection->ListAttrNames(attrs_node); + } + for (auto kv : attributes) { + std::string attr = kv.first; + if (matches && std::find(attr_names.begin(), attr_names.end(), attr) != attr_names.end()) { + matches &= MatchRetValue(kv.second, reflection->GetAttr(attrs_node, attr)); + } else { + matches = false; + break; + } + } + } else if (auto* op = expr.as()) { + matches = true; + for (auto kv : attributes) { + if (matches && op->attrs.defined() && op->attrs->dict.count(kv.first)) { + matches &= StructuralEqual()(kv.second, op->attrs->dict[kv.first]); + } else { + matches = false; + break; + } + } + } else { + matches = false; + } + return matches; +} + +bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + // utilities + auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* { + if (op) { + if (auto* expr_pattern = op->op.as()) { + return expr_pattern->expr.as(); + } + } + return nullptr; + }; + auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) { + if (const auto* op_node = get_op_node(op)) { + if (op_node->name == op_type) { + return true; + } + } + return false; + }; + auto is_expr_op = [](const Expr& expr, std::string op_type) { + if (const auto* call_node = expr.as()) { + if (const auto* op_node = call_node->op.as()) { + if (op_node->name == op_type) { + return true; + } + } + } + return false; + }; + + // logic + auto watermark = matched_nodes_.size(); + if (const auto* call_node = expr.as()) { + auto matches_op = VisitDFPattern(op->op, call_node->op); + if (matches_op) { + auto watermark2 = matched_nodes_.size(); + + auto match_args = [this, &watermark2](const Array& pattern_args, auto expr_begin, + auto expr_end) { + bool matches = true; + auto pattern_it = pattern_args.begin(); + auto expr_it = expr_begin; + if (pattern_args.defined()) { + while (matches && pattern_it != pattern_args.end()) + matches &= VisitDFPattern(*(pattern_it++), *(expr_it++)); + } + if (!matches) ClearMap(watermark2); + return matches; + }; + + const size_t n_arg_pattern = op->args.size(); + const size_t n_arg_expr = call_node->args.size(); + // if allow variable args, #pattern must >= #expr. + if (op->varg_default_wildcard && n_arg_expr < n_arg_pattern) return false; + // if variable args are not allowed, #pattern must == #expr. + if (!op->varg_default_wildcard && n_arg_expr != n_arg_pattern) return false; + + // Standard case + if (match_args(op->args, call_node->args.begin(), call_node->args.end())) return true; + + // Commutative Matching + if (const OpNode* op_node = get_op_node(op)) { + if ((op_node->name == "relax.add") || (op_node->name == "relax.multiply")) { + if (match_args(op->args, call_node->args.rbegin(), call_node->args.rend())) { + return true; + } + } + } + } else { + ClearMap(watermark); + // associate divide/multiply + if (is_pattern_op(op, "relax.divide")) { + if (const auto* arg_node = op->args[0].as()) { + if (is_pattern_op(arg_node, "relax.multiply") && is_expr_op(expr, "relax.multiply") && + (is_expr_op(call_node->args[0], "relax.divide") || + is_expr_op(call_node->args[1], "relax.divide"))) { + bool out = false; + for (size_t arg_id = 0; arg_id < 2; ++arg_id) { + auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]}); + auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div}); + out = VisitDFPattern(mul, expr); + if (out) { + return true; + } else { + ClearMap(watermark); + } + } + return out; + } + } + } + if (is_pattern_op(op, "relax.multiply")) { + // associate multiply/divide + for (size_t arg_id = 0; arg_id < 2; ++arg_id) { + if (auto* arg_node = op->args[arg_id].as()) { + if (is_pattern_op(arg_node, "relax.divide") && is_expr_op(expr, "relax.divide") && + (is_expr_op(call_node->args[0], "relax.multiply") || + is_expr_op(call_node->args[1], "relax.multiply"))) { + auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]}); + auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]}); + return VisitDFPattern(div, expr); + } + } + } + } + } + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + return StructuralEqual()(op->expr, expr); +} + +bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + bool matches = false; + if (const auto* func = expr.as()) { + matches = true; + if (op->params.defined()) { + size_t i = 0; + if (op->params.size() == func->params.size()) { + while (matches && i < op->params.size()) { + matches &= VisitDFPattern(op->params[i], func->params[i]); + ++i; + } + } else { + matches = false; + } + } + if (matches) { + matches &= VisitDFPattern(op->body, func->body); + } + } + return matches; +} + +bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + if (const auto* tuple_get_item_node = expr.as()) { + return (op->index == -1 || op->index == tuple_get_item_node->index) && + VisitDFPattern(op->tuple, tuple_get_item_node->tuple); + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + bool matches = false; + if (const auto* tuple_node = expr.as()) { + matches = true; + if (op->fields.size() == tuple_node->fields.size()) { + size_t i = 0; + while (matches && i < op->fields.size()) { + matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]); + ++i; + } + } else { + matches = false; + } + } + return matches; +} + +bool DFPatternMatcher::TryUnorderedMatch(size_t idx, const tvm::Array patterns, + const tvm::Array fields, + std::vector& match_cache, + std::vector& matched) { + if (idx >= patterns.size()) return true; + constexpr int8_t kUnknown = -1; + auto this_pattern = patterns[idx]; + for (size_t i = 0; i < fields.size(); ++i) { + if (matched[i]) continue; + const size_t table_idx = idx * fields.size() + i; + match_cache[table_idx] = + kUnknown ? VisitDFPattern(this_pattern, fields[i]) : match_cache[table_idx]; + if (match_cache[table_idx]) { + // continue to match the rest; + matched[i] = true; + if (TryUnorderedMatch(idx + 1, patterns, fields, match_cache, matched)) return true; + matched[i] = false; + } + } + + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const UnorderedTuplePatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + + if (const auto* tuple_node = expr.as()) { + if (op->fields.size() == tuple_node->fields.size()) { + constexpr int8_t kUnknown = -1; + ICHECK_LE(op->fields.size(), std::numeric_limits::max()) << "Too many fields!"; + // dynamic programming. + std::vector match_cache(op->fields.size() * op->fields.size(), kUnknown); + std::vector field_match_bitmap(op->fields.size(), false); + return TryUnorderedMatch(0, op->fields, tuple_node->fields, match_cache, field_match_bitmap); + } + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + auto expr_type = expr.as()->checked_type(); + return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr); +} + +static bool ShapeEqual(Analyzer* analyzer, const Array& lhs, const Array& rhs) { + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); ++i) + if (!tir::is_one(analyzer->Simplify(lhs[i] == rhs[i]))) return false; + return true; +} + +bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) { + // no need to jump, as var.shape == value.shape + if (const auto* tinfo = GetStructInfoAs(expr)) { + if (const ShapeExprNode* shape_expr = tinfo->shape.as()) { + return ShapeEqual(&analyzer_, op->shape, shape_expr->values) && + VisitDFPattern(op->pattern, expr); + } + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const PrimArrPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + if (const ShapeExprNode* shape_expr = expr.as()) + return ShapeEqual(&analyzer_, op->fields, shape_expr->values); + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) { + // no need to jump, as var.dtype == value.dtype + auto expr_type = expr.as()->checked_type(); + if (const DynTensorTypeNode* tensor_type = expr_type.as()) { + return (StructuralEqual()(op->dtype, tensor_type->dtype)) && VisitDFPattern(op->pattern, expr); + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) { + // We don't jump for var pattern, as there's no need to access its value to judge it. + if (const auto* var_node = expr.as()) { + // "" means any name. + return "" == op->name_hint() || op->name_hint() == var_node->name_hint(); + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const ExternFuncPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + if (const auto* extern_fn = expr.as()) { + return "" == op->global_symbol() || op->global_symbol() == extern_fn->global_symbol; + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr0) { + // constants can be binded to relax.Var as well. + auto expr = TryGetValOfVar(expr0, var2val_); + return expr.as() != nullptr; +} + +bool DFPatternMatcher::VisitDFPattern_(const DataflowVarPatternNode* op, const Expr& expr) { + // DataflowVar is inherented from Var, so dispatch it to VarPattern. + return expr->IsInstance() && + VisitDFPattern_(static_cast(op), expr); +} + +bool DFPatternMatcher::VisitDFPattern_(const GlobalVarPatternNode* op, const Expr& expr) { + // GlobalVarPattern is not inherited from Var, so we need to handle it separately. + if (const auto* var_node = expr.as()) + return "" == op->name_hint() || op->name_hint() == var_node->name_hint; + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) { + return true; +} + +Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, + Optional> bindings_opt) { + auto bindings = bindings_opt.value_or({}); + DFPatternMatcher matcher(bindings); + + if (!matcher.Match(pattern, expr)) { + return NullOpt; + } + + Map matching; + for (const auto& [pat, matches] : matcher.GetMemo()) { + ICHECK_EQ(matches.size(), 1) << "More than one match for the pattern " << pat; + matching.Set(pat, matches[0]); + } + return matching; +} + +bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings_opt) { + return static_cast(ExtractMatchedExpr(pattern, expr, bindings_opt)); +} + +TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); + +struct PNode { + const DFPatternNode* ptr; + const VarNode* matched = nullptr; + std::vector&>> children; + std::vector&>> parents; +}; + +struct RNode { + const VarNode* ptr; + const DFPatternNode* matched = nullptr; + std::vector children; + std::vector parents; +}; + +/** + * \brief This method try to match a real node and a pattern node along with its neighbors. + */ +static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m, + const std::map>& def2use, + const std::map>& use2def) { + if (nullptr != p->matched && p->matched == r->ptr) return true; // matched before. + if (!m->Match(GetRef(p->ptr), GetRef(r->ptr))) return false; + + std::stack> undo_stack{}; + + const auto commit = [&undo_stack](PNode* p, RNode* r) { + // match with each other. + p->matched = r->ptr; + r->matched = p->ptr; + undo_stack.emplace(p, r); + }; + + const auto quit = [&undo_stack] { + while (!undo_stack.empty()) { + auto& top = undo_stack.top(); + top.first->matched = nullptr; + top.second->matched = nullptr; + undo_stack.pop(); + } + return false; + }; + + commit(p, r); + + // match parent patterns. + for (auto& pparent_pairs : p->parents) { + PNode* pparent = pparent_pairs.first; + const std::vector& constraints = pparent_pairs.second; + + bool any_cons_sat = false; + for (auto& rparent : r->parents) { + // skip if mismatch. + if (rparent->matched && rparent->matched != pparent->ptr) continue; + + const auto& uses = def2use.at(rparent->ptr); + // skip if `rparent` is not used by `r`. + if (uses.cend() == uses.find(r->ptr)) continue; + + // check edge constraints. + bool cons_sat = true; + for (const auto& cons : constraints) { + if (PairCons::kOnlyUsedBy == cons.type && uses.size() != 1) { + cons_sat = false; + break; + } + + if (-1 != cons.index) { + const auto& callees = use2def.at(r->ptr); + if (static_cast(cons.index) >= callees.size() || + rparent->ptr != callees[cons.index]) { + cons_sat = false; + break; + } + } + } + if (!cons_sat) continue; + any_cons_sat = true; + + // try all parent R nodes that are not matched yet. + // as long as ppattern can match one node. + if (!pparent->matched && try_match(pparent, rparent, m, def2use, use2def)) { + commit(pparent, rparent); + break; + } + } + if (!pparent->matched || !any_cons_sat) return quit(); + } + + // forward matching; + for (auto& pchild_pairs : p->children) { + PNode* pchild = pchild_pairs.first; + const std::vector& constraints = pchild_pairs.second; + bool any_cons_sat = false; + for (auto& rchild : r->children) { + if (rchild->matched && rchild->matched != pchild->ptr) continue; + + const auto& uses = def2use.at(r->ptr); + if (uses.cend() == uses.find(rchild->ptr)) continue; + + // check edge constraints. + bool all_cons_pass = true; + for (const auto& cons : constraints) { + if (PairCons::kOnlyUsedBy == cons.type && uses.size() != 1) { + all_cons_pass = false; + break; + } + + if (-1 != cons.index) { + const auto& callees = use2def.at(rchild->ptr); + if (static_cast(cons.index) >= callees.size() || r->ptr != callees[cons.index]) { + all_cons_pass = false; + break; + } + } + } + if (!all_cons_pass) continue; + any_cons_sat = true; + + if (!pchild->matched && try_match(pchild, rchild, m, def2use, use2def)) { + commit(pchild, rchild); + break; + } + } + if (!pchild->matched || !any_cons_sat) return quit(); + } + + return true; +} + +class MatcherUseDefAnalysis : public relax::ExprVisitor { + public: + std::map> def2use; + // caller -> callee table. + std::map> caller2callees; + + const VarNode* cur_user_; + + void VisitBinding_(const VarBindingNode* binding) override { + // init + cur_user_ = binding->var.get(); + this->VisitVarDef(binding->var); + this->VisitExpr(binding->value); + cur_user_ = nullptr; + } + + void VisitExpr_(const VarNode* op) override { + if (nullptr == cur_user_) return; + + def2use[op].insert(cur_user_); + caller2callees[cur_user_].push_back(op); + } + + void VisitExpr_(const DataflowVarNode* op) override { + VisitExpr_(static_cast(op)); + } +}; + +Map MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb, + Optional start_hint, bool must_include_hint) { + Map ret; + // TODO(@ganler): Handle non-may external use. + ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; + ICHECK(!must_include_hint || start_hint.defined()) + << "must_include_hint is only supported with start_hint."; + + const auto var2val = AnalyzeVar2Value(dfb); + DFPatternMatcher matcher(var2val); + + // std::map> + MatcherUseDefAnalysis ud_analysis; + ud_analysis.VisitBindingBlock_(dfb.get()); + const auto& def2use = ud_analysis.def2use; + const auto& caller2callees = ud_analysis.caller2callees; + + // First construct a graph of PNode and RNode. + std::unordered_map var2node; + var2node.reserve(dfb->bindings.size()); + + for (const auto& du : def2use) { + const VarNode* cur_var = du.first; + const std::set& uses = du.second; + RNode& cur_node = var2node[cur_var]; + cur_node.ptr = cur_var; + for (const VarNode* use : uses) { + auto& use_node = var2node[use]; + use_node.ptr = use; + cur_node.children.push_back(&use_node); + use_node.parents.push_back(&cur_node); + } + } + + std::unordered_map pattern2node; + pattern2node.reserve(ctx->constraints.size()); + + for (const auto& def2use_pattern : ctx->constraints) { + const DFPatternNode* def_pattern = def2use_pattern.first.get(); + const std::map>& uses = def2use_pattern.second; + PNode& def_node = pattern2node[def_pattern]; + def_node.ptr = def_pattern; + def_node.children.reserve(uses.size()); + for (const auto& use : uses) { + const auto& cons = use.second; + const DFPatternNode* use_pattern = use.first.get(); + PNode& use_node = pattern2node[use_pattern]; + use_node.ptr = use_pattern; + use_node.parents.emplace_back(&def_node, std::ref(cons)); + def_node.children.emplace_back(&use_node, std::ref(cons)); + } + } + + if (start_hint.defined()) { + Var v = start_hint.value(); + auto rnode_ptr = var2node.find(v.get()); + for (auto& ppair : pattern2node) { + if (try_match(&ppair.second, &rnode_ptr->second, &matcher, def2use, caller2callees)) { + for (auto ppair : pattern2node) + ret.Set(GetRef(ppair.first), GetRef(ppair.second.matched)); + return ret; + } + } + + if (must_include_hint) return ret; + } + + PNode* pnode_start = &pattern2node.begin()->second; + + if (!pnode_start->matched) { + for (auto& rpair : var2node) { + if (start_hint.defined() && start_hint.value().get() == rpair.first) continue; + if (try_match(pnode_start, &rpair.second, &matcher, def2use, caller2callees)) { + for (auto ppair : pattern2node) + ret.Set(GetRef(ppair.first), GetRef(ppair.second.matched)); + + return ret; + } + } + } + + return ret; +} + +TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_matcher_impl.h b/src/relax/ir/dataflow_matcher_impl.h new file mode 100644 index 000000000000..89f3d114c1e3 --- /dev/null +++ b/src/relax/ir/dataflow_matcher_impl.h @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/relax/dataflow_matcher_impl.h + * \brief The auxiliary data structure for dataflow matcher. + */ +#ifndef TVM_RELAX_IR_DATAFLOW_MATCHER_IMPL_H_ +#define TVM_RELAX_IR_DATAFLOW_MATCHER_IMPL_H_ + +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace relax { + +class DFPatternMatcher : public DFPatternFunctor { + public: + using var2val_t = runtime::Map; + + explicit DFPatternMatcher() {} + explicit DFPatternMatcher(var2val_t var2val) : var2val_(std::move(var2val)) {} + bool Match(const DFPattern& pattern, const Expr& expr); + Map> GetMemo() { return Map>(memo_); } + + protected: + bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; + bool VisitDFPattern_(const OrPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const AndPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const NotPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override; + + bool VisitDFPattern_(const DataflowVarPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const GlobalVarPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ExternFuncPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const PrimArrPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const UnorderedTuplePatternNode* op, const Expr& expr) override; + + void ClearMap(size_t watermark); + bool TryUnorderedMatch(size_t idx, const tvm::Array patterns, + const tvm::Array fields, std::vector& match_cache, + std::vector& matched); + + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> memo_; + var2val_t var2val_; + std::vector matched_nodes_; + arith::Analyzer analyzer_; + bool memoize_ = true; +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_IR_DATAFLOW_MATCHER_IMPL_H_ diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc new file mode 100644 index 000000000000..3768627c204c --- /dev/null +++ b/src/relax/ir/dataflow_pattern.cc @@ -0,0 +1,607 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/dataflow_pattern.cc + * \brief The dataflow pattern language for Relax (inherited from Relay). + */ + +#include +#include + +#include +#include +#include + +#define RELAX_PATTERN_PRINTER_DEF(NODE_TYPE, REPR_LAMBDA) \ + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) \ + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { \ + auto* node = static_cast(ref.get()); \ + REPR_LAMBDA(p, node); \ + }) + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(ExternFuncPatternNode); +ExternFuncPattern::ExternFuncPattern(String global_symbol) { + ObjectPtr n = make_object(); + n->global_symbol_ = std::move(global_symbol); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.ExternFuncPattern").set_body_typed([](String global_symbol) { + return ExternFuncPattern(global_symbol); +}); +RELAX_PATTERN_PRINTER_DEF(ExternFuncPatternNode, [](auto p, auto node) { + p->stream << "ExternFuncPattern(" << node->global_symbol() << ")"; +}); + +TVM_REGISTER_NODE_TYPE(VarPatternNode); +VarPattern::VarPattern(String name_hint) { + ObjectPtr n = make_object(); + n->name = std::move(name_hint); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.VarPattern").set_body_typed([](String name_hint) { + return VarPattern(name_hint); +}); +RELAX_PATTERN_PRINTER_DEF(VarPatternNode, [](auto p, auto node) { + p->stream << "VarPattern(" << node->name_hint() << ")"; +}); + +TVM_REGISTER_NODE_TYPE(DataflowVarPatternNode); +TVM_REGISTER_GLOBAL("relax.dpl.DataflowVarPattern").set_body_typed([](String name_hint) { + return DataflowVarPattern(name_hint); +}); +DataflowVarPattern::DataflowVarPattern(String name_hint) { + ObjectPtr n = make_object(); + n->name = std::move(name_hint); + data_ = std::move(n); +} +RELAX_PATTERN_PRINTER_DEF(DataflowVarPatternNode, [](auto p, auto node) { + p->stream << "DataflowVarPattern(" << node->name_hint() << ")"; +}); + +TVM_REGISTER_NODE_TYPE(GlobalVarPatternNode); +GlobalVarPattern::GlobalVarPattern(String name_hint) { + ObjectPtr n = make_object(); + n->name = std::move(name_hint); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.GlobalVarPattern").set_body_typed([](String name_hint) { + return GlobalVarPattern(name_hint); +}); +RELAX_PATTERN_PRINTER_DEF(GlobalVarPatternNode, [](auto p, auto node) { + p->stream << "GlobalVarPattern(" << node->name_hint() << ")"; +}); + +TVM_REGISTER_NODE_TYPE(ExprPatternNode); +ExprPattern::ExprPattern(Expr expr) { + ObjectPtr n = make_object(); + n->expr = std::move(expr); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.ExprPattern").set_body_typed([](Expr e) { return ExprPattern(e); }); +RELAX_PATTERN_PRINTER_DEF(ExprPatternNode, [](auto p, auto node) { p->Print(node->expr); }); + +TVM_REGISTER_NODE_TYPE(ConstantPatternNode); +TVM_REGISTER_GLOBAL("relax.dpl.ConstantPattern").set_body_typed([]() { + auto c = ConstantPattern(make_object()); + return c; +}); +RELAX_PATTERN_PRINTER_DEF(ConstantPatternNode, + [](auto p, auto node) { p->stream << "ConstantPattern()"; }); + +TVM_REGISTER_NODE_TYPE(CallPatternNode); +CallPattern::CallPattern(DFPattern op, Array args, bool varg_default_wildcard) { + ObjectPtr n = make_object(); + n->op = std::move(op); + n->args = std::move(args); + n->varg_default_wildcard = varg_default_wildcard; + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.CallPattern") + .set_body_typed([](DFPattern op, Array args, bool varg_default_wildcard) { + return CallPattern(op, args, varg_default_wildcard); + }); +RELAX_PATTERN_PRINTER_DEF(CallPatternNode, [](auto p, auto node) { + p->stream << node->op << "("; + for (size_t i = 0; i < node->args.size(); ++i) { + if (i != 0) p->stream << ", "; + p->stream << node->args[i]; + } + if (node->varg_default_wildcard) { + if (node->args.size() != 0) p->stream << ", "; + p->stream << "..."; + } + p->stream << ")"; +}); + +TVM_REGISTER_NODE_TYPE(PrimArrPatternNode); +PrimArrPattern::PrimArrPattern(Array arr) { + ObjectPtr n = make_object(); + n->fields = std::move(arr); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.PrimArrPattern").set_body_typed([](Array arr) { + return PrimArrPattern(std::move(arr)); +}); +RELAX_PATTERN_PRINTER_DEF(PrimArrPatternNode, [](auto p, auto node) { + p->stream << "PrimArrPattern(" << node->fields << ")"; +}); + +TVM_REGISTER_NODE_TYPE(FunctionPatternNode); +FunctionPattern::FunctionPattern(Array params, DFPattern body) { + ObjectPtr n = make_object(); + n->params = std::move(params); + n->body = std::move(body); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.FunctionPattern") + .set_body_typed([](Array params, DFPattern body) { + return FunctionPattern(params, body); + }); +RELAX_PATTERN_PRINTER_DEF(FunctionPatternNode, [](auto p, auto node) { + p->stream << "FunctionPattern(" << node->params << ", " << node->body << ")"; +}); + +TVM_REGISTER_NODE_TYPE(TuplePatternNode); +TuplePattern::TuplePattern(tvm::Array fields) { + ObjectPtr n = make_object(); + n->fields = std::move(fields); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.TuplePattern").set_body_typed([](tvm::Array fields) { + return TuplePattern(fields); +}); +RELAX_PATTERN_PRINTER_DEF(TuplePatternNode, [](auto p, auto node) { + p->stream << "TuplePattern(" << node->fields << ")"; +}); + +TVM_REGISTER_NODE_TYPE(UnorderedTuplePatternNode); +UnorderedTuplePattern::UnorderedTuplePattern(tvm::Array fields) { + ObjectPtr n = make_object(); + n->fields = std::move(fields); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.UnorderedTuplePattern") + .set_body_typed([](tvm::Array fields) { return UnorderedTuplePattern(fields); }); +RELAX_PATTERN_PRINTER_DEF(UnorderedTuplePatternNode, [](auto p, auto node) { + p->stream << "UnorderedTuplePattern(" << node->fields << ")"; +}); + +TVM_REGISTER_NODE_TYPE(TupleGetItemPatternNode); +TupleGetItemPattern::TupleGetItemPattern(DFPattern tuple, int index) { + ObjectPtr n = make_object(); + n->tuple = std::move(tuple); + n->index = index; + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.TupleGetItemPattern").set_body_typed([](DFPattern tuple, int index) { + return TupleGetItemPattern(tuple, index); +}); +RELAX_PATTERN_PRINTER_DEF(TupleGetItemPatternNode, [](auto p, auto node) { + p->stream << "TupleGetItemPattern(" << node->tuple << ", " << node->index << ")"; +}); + +TVM_REGISTER_NODE_TYPE(AndPatternNode); +AndPattern::AndPattern(DFPattern left, DFPattern right) { + ObjectPtr n = make_object(); + n->left = std::move(left); + n->right = std::move(right); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.AndPattern").set_body_typed([](DFPattern left, DFPattern right) { + return AndPattern(left, right); +}); +RELAX_PATTERN_PRINTER_DEF(AndPatternNode, [](auto p, auto node) { + p->stream << "AndPattern(" << node->left << " & " << node->right << ")"; +}); + +TVM_REGISTER_NODE_TYPE(OrPatternNode); +OrPattern::OrPattern(DFPattern left, DFPattern right) { + ObjectPtr n = make_object(); + n->left = std::move(left); + n->right = std::move(right); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.OrPattern").set_body_typed([](DFPattern left, DFPattern right) { + return OrPattern(left, right); +}); +RELAX_PATTERN_PRINTER_DEF(OrPatternNode, [](auto p, auto node) { + p->stream << "OrPattern(" << node->left << " | " << node->right << ")"; +}); + +TVM_REGISTER_NODE_TYPE(NotPatternNode); +NotPattern::NotPattern(DFPattern reject) { + ObjectPtr n = make_object(); + n->reject = std::move(reject); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.NotPattern").set_body_typed([](DFPattern reject) { + return NotPattern(reject); +}); +RELAX_PATTERN_PRINTER_DEF(NotPatternNode, + [](auto p, auto node) { p->stream << "!(" << node->reject << ")"; }); + +TVM_REGISTER_NODE_TYPE(WildcardPatternNode); +TVM_REGISTER_GLOBAL("relax.dpl.WildcardPattern").set_body_typed([]() { + auto w = WildcardPattern(make_object()); + return w; +}); +RELAX_PATTERN_PRINTER_DEF(WildcardPatternNode, [](auto p, auto node) { p->stream << "*"; }); + +TVM_REGISTER_NODE_TYPE(TypePatternNode); +TypePattern::TypePattern(DFPattern pattern, Type type) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->type = std::move(type); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.TypePattern").set_body_typed([](DFPattern pattern, Type type) { + return TypePattern(pattern, type); +}); +RELAX_PATTERN_PRINTER_DEF(TypePatternNode, [](auto p, auto node) { + p->stream << "TypePattern(" << node->pattern << " has type " << node->type << ")"; +}); + +TVM_REGISTER_NODE_TYPE(ShapePatternNode); +ShapePattern::ShapePattern(DFPattern pattern, Array shape) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->shape = std::move(shape); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.ShapePattern") + .set_body_typed([](DFPattern pattern, Array shape) { + return ShapePattern(pattern, shape); + }); +RELAX_PATTERN_PRINTER_DEF(ShapePatternNode, [](auto p, auto node) { + p->stream << "ShapePattern(" << node->pattern << " has shape " << node->shape << ")"; +}); + +TVM_REGISTER_NODE_TYPE(DataTypePatternNode); +DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->dtype = std::move(dtype); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.DataTypePattern") + .set_body_typed([](DFPattern pattern, DataType dtype) { + return DataTypePattern(pattern, dtype); + }); +RELAX_PATTERN_PRINTER_DEF(DataTypePatternNode, [](auto p, auto node) { + p->stream << "DataTypePattern(" << node->pattern << " has dtype " << node->dtype << ")"; +}); + +TVM_REGISTER_NODE_TYPE(AttrPatternNode); +AttrPattern::AttrPattern(DFPattern pattern, DictAttrs attrs) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->attrs = std::move(attrs); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.AttrPattern").set_body_typed([](DFPattern pattern, DictAttrs attrs) { + return AttrPattern(pattern, attrs); +}); +RELAX_PATTERN_PRINTER_DEF(AttrPatternNode, [](auto p, auto node) { + p->stream << "AttrPattern(" << node->pattern << " has attributes " << node->attrs << ")"; +}); + +class DFPatternDuplicator : public DFPatternFunctor { + public: + DFPattern VisitDFPattern(const DFPattern& pattern) override { + return DFPatternFunctor::VisitDFPattern(pattern); + } + DFPattern VisitDFPattern_(const OrPatternNode* op) override { + return OrPattern(op->left, op->right); + } + DFPattern VisitDFPattern_(const AndPatternNode* op) override { + return AndPattern(op->left, op->right); + } + DFPattern VisitDFPattern_(const NotPatternNode* op) override { return NotPattern(op->reject); } + DFPattern VisitDFPattern_(const VarPatternNode* op) override { return VarPattern(op->name); } + DFPattern VisitDFPattern_(const ConstantPatternNode* op) override { + return ConstantPattern(make_object()); + } + DFPattern VisitDFPattern_(const WildcardPatternNode* op) override { + return WildcardPattern(make_object()); + } + DFPattern VisitDFPattern_(const ExprPatternNode* op) override { return ExprPattern(op->expr); } + DFPattern VisitDFPattern_(const GlobalVarPatternNode* op) override { + return GlobalVarPattern(op->name); + } + DFPattern VisitDFPattern_(const TuplePatternNode* op) override { + return TuplePattern(op->fields); + } + DFPattern VisitDFPattern_(const UnorderedTuplePatternNode* op) override { + return UnorderedTuplePattern(op->fields); + } + DFPattern VisitDFPattern_(const TupleGetItemPatternNode* op) override { + return TupleGetItemPattern(op->tuple, op->index); + } + DFPattern VisitDFPattern_(const CallPatternNode* op) override { + return CallPattern(op->op, op->args); + } + DFPattern VisitDFPattern_(const DataTypePatternNode* op) override { + return DataTypePattern(op->pattern, op->dtype); + } + DFPattern VisitDFPattern_(const FunctionPatternNode* op) override { + return FunctionPattern(op->params, op->body); + } + DFPattern VisitDFPattern_(const ShapePatternNode* op) override { + return ShapePattern(op->pattern, op->shape); + } + DFPattern VisitDFPattern_(const TypePatternNode* op) override { + return TypePattern(op->pattern, op->type); + } + DFPattern VisitDFPattern_(const DataflowVarPatternNode* op) override { + return DataflowVarPattern(op->name); + } + DFPattern VisitDFPattern_(const ExternFuncPatternNode* op) override { + return ExternFuncPattern(op->global_symbol()); + } + DFPattern VisitDFPattern_(const PrimArrPatternNode* op) override { + return PrimArrPattern(op->fields); + } +}; + +// Syntatic Sugar +CallPattern DFPattern::operator()(const std::vector& args) const { + return CallPattern(*this, Array(args)); +} +OrPattern DFPattern::operator|(const DFPattern& other) const { return OrPattern(*this, other); } + +AndPattern DFPattern::operator&(const DFPattern& other) const { return AndPattern(*this, other); } + +NotPattern DFPattern::operator~() const { return NotPattern(*this); } + +AttrPattern DFPattern::HasAttr(const Map& attrs) const { + return AttrPattern(*this, DictAttrs(attrs)); +} +TypePattern DFPattern::HasType(const Type& type) const { return TypePattern(*this, type); } +DataTypePattern DFPattern::HasDtype(const DataType& dtype) const { + return DataTypePattern(*this, dtype); +} +DataTypePattern DFPattern::HasDtype(const std::string& dtype) const { + return HasDtype(DataType(runtime::String2DLDataType(dtype))); +} +ShapePattern DFPattern::HasShape(const Array& shape) const { + return ShapePattern(*this, shape); +} + +DFPattern::operator PatternSeq() const { return PatternSeq{{*this}}; } + +std::stack& pattern_ctx_stack() { + thread_local std::stack graph_pattern_managers; + return graph_pattern_managers; +} + +PatternContext PatternContext::Current() { + ICHECK(!pattern_ctx_stack().empty()) << "No active PatternContext found."; + return pattern_ctx_stack().top(); +} + +PatternContext::PatternContext(bool incremental) { + auto n = make_object(); + if (incremental) { + ICHECK(!pattern_ctx_stack().empty()) + << "Incremental context needs to be built inside a existing context."; + n->allow_extern_use = pattern_ctx_stack().top()->allow_extern_use; + n->constraints = pattern_ctx_stack().top()->constraints; + } + + data_ = std::move(n); +} + +void PatternContext::EnterWithScope() { pattern_ctx_stack().push(*this); } + +void PatternContext::ExitWithScope() { + ICHECK(pattern_ctx_stack().top().same_as(*this)); + pattern_ctx_stack().pop(); +} + +static void sync_graph_constraints(const DFPattern& lhs, const DFPattern& rhs, PairCons pcon) { + PatternContext::Current().add_constraint(lhs, rhs, pcon); +} + +TVM_REGISTER_NODE_TYPE(PatternSeqNode); +PatternSeq::PatternSeq(DFPattern init_pattern) { + ObjectPtr n = make_object(); + n->patterns = {init_pattern}; + n->pair_constraints = {}; + data_ = std::move(n); +} +PatternSeq::PatternSeq(tvm::Array patterns, bool only_used_by) { + ICHECK_GE(patterns.size(), 1) << "PatternSeq must have at least one pattern"; + const auto cons = PairCons(only_used_by ? PairCons::kOnlyUsedBy : PairCons::kUsedBy); + + ObjectPtr n = make_object(); + n->patterns = std::move(patterns); + n->pair_constraints = std::vector(n->patterns.size() - 1, cons); + data_ = std::move(n); +} + +PatternSeq PatternSeq::UsedBy(PatternSeq other, int index) const { + return relax::UsedBy(*this, other, index); +} + +PatternSeq PatternSeq::OnlyUsedBy(PatternSeq other, int index) const { + return relax::OnlyUsedBy(*this, other, index); +} + +PatternSeq PatternSeq::dup() const { + PatternSeq ret; + + ObjectPtr n = make_object(); + n->patterns = Array{}; + n->patterns.reserve(get()->patterns.size()); + n->pair_constraints = this->get()->pair_constraints; + + for (size_t i = 0; i < get()->patterns.size(); ++i) { + n->patterns.push_back(get()->patterns[i].dup()); + if (i >= 1) + sync_graph_constraints(n->patterns[i - 1], n->patterns[i], n->pair_constraints[i - 1]); + } + + ret.data_ = std::move(n); + + return ret; +} +TVM_REGISTER_GLOBAL("relax.dpl.PatternSeq") + .set_body_typed([](Array patterns, bool only_used_by) { + return PatternSeq(std::move(patterns), only_used_by); + }); +RELAX_PATTERN_PRINTER_DEF(PatternSeqNode, [](auto p, auto node) { + p->stream << "["; + for (size_t i = 0; i < node->patterns.size(); ++i) { + if (i != 0) + p->stream << (PairCons::kOnlyUsedBy == node->pair_constraints[i].type ? " >> " : " ^ "); + p->stream << node->patterns[i]; + } + p->stream << "]"; +}); + +TVM_REGISTER_GLOBAL("relax.dpl.used_by") + .set_body_typed([](PatternSeq lhs, PatternSeq rhs, int index) { + return lhs.UsedBy(rhs, index); + }); + +TVM_REGISTER_GLOBAL("relax.dpl.only_used_by") + .set_body_typed([](PatternSeq lhs, PatternSeq rhs, int index) { + return lhs.OnlyUsedBy(rhs, index); + }); + +PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index) { + PatternSeq ret; + + const auto constraint = PairCons{PairCons::kOnlyUsedBy, index}; + + sync_graph_constraints(lhs->patterns.back(), rhs->patterns.front(), + PairCons{PairCons::kUsedBy, index}); + + Array patterns; + patterns.reserve(lhs->patterns.size() + rhs->patterns.size()); + patterns.insert(patterns.end(), lhs->patterns.begin(), lhs->patterns.end()); + patterns.insert(patterns.end(), rhs->patterns.begin(), rhs->patterns.end()); + + std::vector pair_constraints = lhs->pair_constraints; + pair_constraints.reserve(pair_constraints.size() + rhs->pair_constraints.size() + 1); + pair_constraints.push_back(constraint); + pair_constraints.insert(pair_constraints.end(), rhs->pair_constraints.begin(), + rhs->pair_constraints.end()); + + ObjectPtr n = make_object(); + n->patterns = std::move(patterns); + n->pair_constraints = std::move(pair_constraints); + ret.data_ = std::move(n); + + return ret; +} +PatternSeq operator^(const PatternSeq& lhs, const PatternSeq& rhs) { return lhs.UsedBy(rhs); } + +PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index) { + PatternSeq ret; + + const auto constraint = PairCons{PairCons::kOnlyUsedBy, index}; + + sync_graph_constraints(lhs->patterns.back(), rhs->patterns.front(), constraint); + + Array patterns; + patterns.reserve(lhs->patterns.size() + rhs->patterns.size()); + patterns.insert(patterns.end(), lhs->patterns.begin(), lhs->patterns.end()); + patterns.insert(patterns.end(), rhs->patterns.begin(), rhs->patterns.end()); + + std::vector pair_constraints = lhs->pair_constraints; + pair_constraints.reserve(pair_constraints.size() + rhs->pair_constraints.size() + 1); + pair_constraints.push_back(constraint); + pair_constraints.insert(pair_constraints.end(), rhs->pair_constraints.begin(), + rhs->pair_constraints.end()); + + ObjectPtr n = make_object(); + n->patterns = std::move(patterns); + n->pair_constraints = std::move(pair_constraints); + ret.data_ = std::move(n); + + return ret; +} +PatternSeq operator>>(const PatternSeq& lhs, const PatternSeq& rhs) { return lhs.OnlyUsedBy(rhs); } + +VarPattern IsVar(const String& name) { return VarPattern(name); } +ConstantPattern IsConst() { return ConstantPattern(make_object()); } +WildcardPattern Wildcard() { return WildcardPattern(make_object()); } +ExprPattern IsExpr(const Expr& expr) { return ExprPattern(expr); } +ExprPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); } +CallPattern IsCallTIR(const String& name, Optional var_args) { + DFPattern arg_pattern; + if (!var_args.defined()) { + arg_pattern = Wildcard(); + } else { + arg_pattern = var_args.value(); + } + + return IsOp("relax.call_tir")(GlobalVarPattern(name), arg_pattern); +} + +CallPattern IsCallTIR(const String& name, TuplePattern var_args) { + return IsOp("relax.call_tir")(GlobalVarPattern(name), var_args); +} + +DFPattern IsTuple(const Array& fields, bool unordered) { + if (unordered) + return UnorderedTuplePattern(fields); + else + return TuplePattern(fields); +} +TupleGetItemPattern IsTupleGetItem(const DFPattern tuple, int index) { + return TupleGetItemPattern(tuple, index); +} + +DFPattern DFPattern::dup() const { + auto pattern = DFPatternDuplicator().VisitDFPattern(*this); + return pattern; +} + +TVM_REGISTER_GLOBAL("relax.dpl.dup_pattern").set_body_typed([](DFPattern pattern) { + return pattern.dup(); +}); + +TVM_REGISTER_GLOBAL("relax.dpl.dup_seq").set_body_typed([](PatternSeq seq) { return seq.dup(); }); + +TVM_REGISTER_GLOBAL("relax.dpl.PatternContext").set_body_typed([](bool incre) { + return PatternContext(incre); +}); + +TVM_REGISTER_GLOBAL("relax.dpl.current_context").set_body_typed([] { + return PatternContext::Current(); +}); + +class PatternContext::Internal { + public: + static void EnterScope(PatternContext pass_ctx) { pass_ctx.EnterWithScope(); } + static void ExitScope(PatternContext pass_ctx) { pass_ctx.ExitWithScope(); } +}; + +TVM_REGISTER_GLOBAL("relax.dpl.enter_context").set_body_typed(PatternContext::Internal::EnterScope); + +TVM_REGISTER_GLOBAL("relax.dpl.exit_context").set_body_typed(PatternContext::Internal::ExitScope); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_pattern_functor.cc b/src/relax/ir/dataflow_pattern_functor.cc new file mode 100644 index 000000000000..37a98f28beef --- /dev/null +++ b/src/relax/ir/dataflow_pattern_functor.cc @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/relay/dataflow_matcher.cc + * \brief The dataflow pattern matcher for Relay. + */ + +#include + +namespace tvm { +namespace relax { + +// DFPatternVisitor + +void DFPatternVisitor::VisitDFPattern(const DFPattern& pattern) { + if (this->visited_.count(pattern.get()) == 0) { + visited_.insert(pattern.get()); + DFPatternFunctor::VisitDFPattern(pattern); + } +} + +void DFPatternVisitor::VisitDFPattern_(const OrPatternNode* op) { + VisitDFPattern(op->left); + VisitDFPattern(op->right); +} + +void DFPatternVisitor::VisitDFPattern_(const AndPatternNode* op) { + VisitDFPattern(op->left); + VisitDFPattern(op->right); +} + +void DFPatternVisitor::VisitDFPattern_(const NotPatternNode* op) { VisitDFPattern(op->reject); } + +void DFPatternVisitor::VisitDFPattern_(const AttrPatternNode* op) { VisitDFPattern(op->pattern); } + +void DFPatternVisitor::VisitDFPattern_(const CallPatternNode* op) { + VisitDFPattern(op->op); + if (op->args.defined()) { + for (auto arg : op->args) { + VisitDFPattern(arg); + } + } +} + +void DFPatternVisitor::VisitDFPattern_(const DataTypePatternNode* op) { + VisitDFPattern(op->pattern); +} + +void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {} + +void DFPatternVisitor::VisitDFPattern_(const FunctionPatternNode* op) { + if (op->params.defined()) { + for (auto param : op->params) { + VisitDFPattern(param); + } + } + VisitDFPattern(op->body); +} + +void DFPatternVisitor::VisitDFPattern_(const ShapePatternNode* op) { VisitDFPattern(op->pattern); } + +void DFPatternVisitor::VisitDFPattern_(const TupleGetItemPatternNode* op) { + VisitDFPattern(op->tuple); +} + +void DFPatternVisitor::VisitDFPattern_(const TuplePatternNode* op) { + if (op->fields.defined()) { + for (auto field : op->fields) { + VisitDFPattern(field); + } + } +} + +void DFPatternVisitor::VisitDFPattern_(const UnorderedTuplePatternNode* op) { + if (op->fields.defined()) { + for (auto field : op->fields) { + VisitDFPattern(field); + } + } +} + +void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPattern(op->pattern); } + +// leaf nodes. +void DFPatternVisitor::VisitDFPattern_(const PrimArrPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const ConstantPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const DataflowVarPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const GlobalVarPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const ExternFuncPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {} + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index 5dd83f2da24c..43558a52be32 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -21,10 +21,31 @@ import tvm.testing from tvm import tir from tvm import relax as rx -from tvm.relax.analysis import has_reshape_pattern +from tvm.relax.analysis import has_reshape_pattern, udchain from tvm.script import relax as R, tir as T +def test_use_def(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", R.Tensor([m, n], "float16")) + y = rx.Var("y", R.Tensor([n], "float16")) + ib = rx.BlockBuilder() + with ib.function("func", [x, y]): + with ib.dataflow(): + lv0 = ib.emit(rx.op.add(x, y)) + lv1 = ib.emit(rx.op.multiply(lv0, y)) + gv0 = ib.emit_output(lv1) + ib.emit_func_output(gv0) + dfb = ib.get()["func"].body.blocks[0] + udc = udchain(dfb) + assert set(udc[x]) == {lv0} + assert set(udc[y]) == {lv0, lv1} + assert set(udc[lv0]) == {lv1} + assert set(udc[lv1]) == {gv0} + assert set(udc[gv0]) == set() + + def test_reshape_pattern_reshape(): @T.prim_func def reshape( diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py new file mode 100644 index 000000000000..ab7a5540ad66 --- /dev/null +++ b/tests/python/relax/test_dataflow_pattern.py @@ -0,0 +1,867 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm.testing + +from tvm import relay +from tvm.relax.dpl import * +from tvm.relax.analysis import get_var2val +from tvm import relax as rx, tir +from tvm.script import relax as R, tir as T + + +@tvm.script.ir_module +class Module: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + k = T.int32() + A = T.match_buffer(x, (32, 32)) + B = T.match_buffer(y, (32, 32)) + C = T.match_buffer(z, (32, 32)) + + for (i0, j0, k0) in T.grid(32, 32, 32): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + with T.init(): + C[i, j] = 0.0 + C[i, j] += A[i, k] * B[j, k] + + @T.prim_func + def tir_relu(x: T.handle, y: T.handle): + T.func_attr({"global_symbol": "tir_relu"}) + A = T.match_buffer(x, (32, 32)) + B = T.match_buffer(y, (32, 32)) + for (i, j) in T.grid(32, 32): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = T.max(A[vi, vj], 0.0) + + @R.function + def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir(tir_relu, (lv0), R.Tensor((32, 32), dtype="float32")) + R.output(lv1) + return lv1 + + +main_fn = Module["main"] +bindings = main_fn.body.blocks[0].bindings + +## Node-wise Matching +def test_expr_pattern(): + ep = is_expr(rx.Var("x")) + assert isinstance(ep, ExprPattern) + assert isinstance(ep.expr, rx.Var) + + +def test_var_pattern(): + v = is_var("x") + assert isinstance(v, VarPattern) + assert v.name == "x" + assert v.match(rx.Var("x")) + assert is_var().match(rx.Var("x")) + assert is_var().match(rx.DataflowVar("x")) # DataflowVar is also a Var + assert not v.match(rx.GlobalVar("x")) + + +def test_dataflow_var_pattern(): + v = is_dfv("x") + assert isinstance(v, DataflowVarPattern) + assert v.name == "x" + assert v.match(rx.DataflowVar("x")) + assert not v.match(rx.GlobalVar("x")) + assert is_dfv().match(bindings[0].var) + + +def test_global_var_pattern(): + assert is_gv("x").match(rx.GlobalVar("x")) + assert is_gv().match(rx.GlobalVar("x")) + assert not is_gv("x").match(rx.GlobalVar("y")) + assert not is_gv("x").match(rx.Var("x")) + + +def test_constant_pattern(): + c = is_const() + assert isinstance(c, ConstantPattern) + assert c.match(rx.const([[0.1, 1.1, 2.1], [3.1, 4.1, 5.1]])) + + +def test_wildcard_pattern(): + wc = wildcard() + assert isinstance(wc, WildcardPattern) + assert wc.match(rx.Var("x")) + + +def test_call_pattern(): + wc1 = wildcard() + wc2 = wildcard() + c = is_op("relax.add")(wc1, wc2) + assert isinstance(c, CallPattern) + assert isinstance(c.args[0], WildcardPattern) + assert isinstance(c.args[1], WildcardPattern) + assert c.match(rx.op.add(rx.Var("x"), rx.Var("y"))) + + +def test_function_pattern(): + wc1 = wildcard() + wc2 = wildcard() + f = FunctionPattern([wc1, wc2], is_op("relax.add")(wc1, wc2)) + assert isinstance(f, FunctionPattern) + assert isinstance(f.params[0], WildcardPattern) + assert isinstance(f.params[1], WildcardPattern) + assert isinstance(f.body, CallPattern) + assert isinstance(f.body.args[0], WildcardPattern) + assert isinstance(f.body.args[1], WildcardPattern) + x = rx.Var("x", R.Tensor("float32")) + y = rx.Var("y", R.Tensor("float32")) + assert f.match(rx.Function([x, y], rx.op.add(x, y), ret_struct_info=R.Tensor("float32"))) + assert not f.match( + rx.Function([x, y], rx.op.multiply(x, y), ret_struct_info=R.Tensor("float32")) + ) + + +def test_tuple_pattern(): + wc1 = wildcard() + wc2 = is_dfv() + t = is_tuple([wc1, wc2]) + assert isinstance(t, TuplePattern) + assert isinstance(t.fields[0], WildcardPattern) + assert isinstance(t.fields[1], DataflowVarPattern) + assert t.match(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")])) + assert not t.match(rx.Tuple([rx.DataflowVar("x"), rx.GlobalVar("y")])) + assert not t.match(rx.Tuple([])) + assert t[0].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0)) + assert t[1].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 1)) + # Negative index is also allowed + assert t[-1].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 1)) + # None means any index. + assert t[None].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0)) + assert t[None].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 1)) + with pytest.raises(IndexError): + t[2] # index cannot be greater than or equal to the tuple size. + + +def test_unordered_tuple_pattern(): + t = is_tuple([is_const(), is_dfv()], unordered=True) + assert isinstance(t, UnorderedTuplePattern) + assert isinstance(t.fields[0], ConstantPattern) + assert isinstance(t.fields[1], DataflowVarPattern) + assert t.match(rx.Tuple([rx.const([]), rx.DataflowVar("x")])) + assert t.match(rx.Tuple([rx.DataflowVar("x"), rx.const([])])) + assert not t.match(rx.Tuple([rx.DataflowVar("x"), rx.DataflowVar("y")])) + assert not t.match(rx.Tuple([])) + + +def test_tuple_get_item_pattern(): + assert is_tuple_get_item(is_tuple([is_gv("x"), is_dfv("y")]), 0).match( + rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0) + ) + assert is_tuple_get_item(is_tuple([is_gv("x"), is_dfv("y")]), 0).match( + rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0) + ) + + +def test_or_pattern(): + dfv_or_gv = is_dfv("x") | is_gv("x") + assert isinstance(dfv_or_gv, OrPattern) + assert dfv_or_gv.match(rx.DataflowVar("x")) + assert dfv_or_gv.match(rx.GlobalVar("x")) + assert not dfv_or_gv.match(rx.Var("x")) + assert not dfv_or_gv.match(rx.DataflowVar("y")) + assert not dfv_or_gv.match(rx.GlobalVar("y")) + + +def test_and_pattern(): + # float[2, 3, 3] + f32_233 = wildcard().has_shape((2, 3, 3)) & has_dtype("float32") + assert isinstance(f32_233, AndPattern) + assert f32_233.match(rx.Var("x", R.Tensor((2, 3, 3), "float32"))) + assert not f32_233.match(rx.Var("x", R.Tensor((3, 3, 3), "float32"))) + assert not f32_233.match(rx.Var("x", R.Tensor("float32", ndim=3))) + + +def test_not_pattern(): + no_shape233 = ~wildcard().has_shape((2, 3, 3)) + assert isinstance(no_shape233, NotPattern) + assert no_shape233.match(rx.Var("x", R.Tensor((3, 3, 3), "float32"))) + assert not no_shape233.match(rx.Var("x", R.Tensor((2, 3, 3), "float32"))) + + +def test_type_pattern(): + assert wildcard().has_type(rx.DynTensorType(2, "float32")).match(bindings[0].var) + + +def test_dtype_pattern(): + dtype = "float16" + pattern = has_dtype(dtype) + assert isinstance(pattern, DataTypePattern) + assert pattern.dtype == dtype + assert has_dtype("float32").match(bindings[0].var) + + +def test_shape_pattern(): + shape = [32, 32] + pattern = wildcard().has_shape(shape) + assert isinstance(pattern, ShapePattern) + tvm.ir.structural_equal(pattern.shape, shape) + assert pattern.match(bindings[0].var) + assert wildcard().has_shape([32, 32]).match(bindings[0].var) + n, m = tir.Var("n", dtype="int64"), tir.Var("m", dtype="int64") + symsh_var = rx.Var("x", R.Tensor([n, m, n + m], "float32")) + assert wildcard().has_shape([n, m, n + m]).match(symsh_var) + assert wildcard().has_shape([n, m, m + n]).match(symsh_var) # + is commutative. + assert not wildcard().has_shape([1, 2, 3]).match(symsh_var) + assert not wildcard().has_shape([m, n, n + m]).match(symsh_var) + + +def test_prim_arr_pattern(): + """ + The difference between is_shape and has_shape is that: + 1) is_shape directly matches a shape (e.g., as an argument); + 2) has_shape matches a tensor and puts assumptions on the tensor's shape. + """ + pattern = is_shape([32, 32]) + assert pattern[0] == 32 + assert pattern[1] == 32 + assert isinstance(pattern, PrimArrPattern) + assert pattern.match(rx.get_shape_of(bindings[0].var)) + n, m = tir.Var("n", dtype="int64"), tir.Var("m", dtype="int64") + symbolic_shape = rx.ShapeExpr([n, m, n + m]) + assert is_shape([n, m, n + m]).match(symbolic_shape) + assert not is_shape([n, m, n * m]).match(symbolic_shape) + + +def test_extern_fn_pattern(): + pattern = ExternFuncPattern("test.blockbuilder.nop") + assert pattern.match(rx.ExternFunc("test.blockbuilder.nop")) + + +def test_op_attr(): + x = rx.Var("x", R.Tensor("float32")) + y = rx.Var("y", R.Tensor("float32")) + conv2d = relay.nn.conv2d(x, y, kernel_size=(3, 3)) + xp = is_var("x") + yp = is_var("y") + # TODO(@yuchen): reenable the assert after figuring out why it fails + # assert is_op("nn.conv2d")(xp, yp).has_attr({"kernel_size": [3, 3]}).match(conv2d) + assert not is_op("nn.conv2d")(xp, yp).has_attr({"kernel_size": [4, 3]}).match(conv2d) + assert not is_op("nn.conv2d")(xp, yp).has_attr({"kernel_size_": [3, 3]}).match(conv2d) + + +def test_match_call_attr(): + x = rx.Var("x", R.Tensor("float32")) + y = rx.Var("y", R.Tensor("float32")) + fn = rx.Function([x, y], rx.op.add(x, y), ret_struct_info=R.Tensor("float32")) + annotated_fn = fn.with_attr({"Codegen": "test-codegen", "global_symbol": "test-symbol"}) + xp = is_var("x") + yp = is_var("y") + root_pattern = FunctionPattern([xp, yp], is_op("relax.add")(xp, yp)) + assert root_pattern.has_attr({"Codegen": "test-codegen", "global_symbol": "test-symbol"}).match( + annotated_fn + ) + + assert root_pattern.has_attr({"Codegen": "test-codegen"}).match(annotated_fn) + assert not root_pattern.has_attr({"ping": "pong"}).match(annotated_fn) + assert root_pattern.has_attr({}).match(annotated_fn) + + +def test_is_call_tir(): + lv1_val = bindings[1].value + var2val = get_var2val(Module["main"]) + assert is_call_tir("tir_relu").match(lv1_val) + assert is_call_tir("tir_relu", [is_call_tir("tir_matmul")]).match(lv1_val, var2val=var2val) + assert not is_call_tir("tir_relu", [is_call_tir("tir_relu")]).match(lv1_val, var2val=var2val) + + +@R.function +def simple_call_packed( + x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") +) -> R.Tensor: + gv0 = R.call_packed("test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))) + return gv0 + + +def test_varg_default_wildcard(): + expr = simple_call_packed.body.blocks[0].bindings[0].value + yes_pattern_explicit = ExternFuncPattern("test.vm.mul")(wildcard(), wildcard()) + yes_pattern_implicit = ExternFuncPattern("test.vm.mul")(varg_default_wildcard=True) + no_pattern = ExternFuncPattern("test.vm.mul")(wildcard()) + + assert yes_pattern_explicit.match(expr) + assert yes_pattern_implicit.match(expr) + assert not no_pattern.match(expr) + + +def test_simple_call_packed(): + expr = simple_call_packed.body.blocks[0].bindings[0].value + assert is_call_packed("test.vm.mul").match(expr) + assert is_call_packed("test.vm.mul", [is_var("x"), is_var("w")]).match(expr) + + +## Graph-wise Matching +def test_simple_used_by(): + with PatternContext() as ctx: + n0 = is_var("x") # x is a free var (fn arg) + n1 = wildcard() + n0 ^ n1 + dfb = main_fn.body.blocks[0] + matched = ctx.match_dfb(dfb) + assert matched + assert matched[n0] == main_fn.params[0] + assert matched[n1] == dfb.bindings[0].var + + +def test_simple_call_tir_edge(): + with PatternContext() as ctx: + n0 = is_call_tir("tir_matmul") + n1 = is_call_tir("tir_relu") + n0.used_by(n1) + dfb = main_fn.body.blocks[0] + matched = ctx.match_dfb(dfb) + assert matched + assert matched[n0] == dfb.bindings[0].var + assert matched[n1] == dfb.bindings[1].var + + +def test_simple_oub(): + with PatternContext() as ctx: + n0 = is_call_tir("tir_matmul") + n1 = is_call_tir("tir_relu") + n0 >> n1 + dfb = main_fn.body.blocks[0] + matched = ctx.match_dfb(dfb) + assert matched + assert matched[n0] == dfb.bindings[0].var + assert matched[n1] == dfb.bindings[1].var + + +def test_counter_syntax_match(): + with PatternContext() as ctx: + n0 = is_call_tir_extern("tir_matmul") + n1 = is_call_tir_extern("tir_impossible") + n0 >> n1 + dfb = main_fn.body.blocks[0] + assert not ctx.match_dfb(dfb) + + with PatternContext() as ctx: + n0 = is_call_tir_extern("tir_matmul") + n1 = is_call_tir_extern("tir_impossible") + n0 ^ n1 + dfb = main_fn.body.blocks[0] + assert not ctx.match_dfb(dfb) + + +@tvm.script.ir_module +class Diamond: + @R.function + def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + # matmul + # / \ + # relu sigmoid + # \ / + # add + lv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir("tir_relu", (lv0,), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_tir("tir_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")) + lv3 = R.call_tir("tir_add", (lv1, lv2), R.Tensor((32, 32), dtype="float32")) + R.output(lv3) + return lv3 + + +def test_diamond(): + with PatternContext() as ctx: + n0 = is_call_tir_extern("tir_matmul") + n1 = is_call_tir_extern("tir_relu") + n2 = is_call_tir_extern("tir_sigmoid") + n3 = is_call_tir_extern("tir_add") + + n0 ^ n1 + n0 ^ n2 + n1 >> n3 + n2 >> n3 + + dfb = Diamond["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + # simplify it with fork_to + with PatternContext() as ctx: + n1 = is_call_tir_extern("tir_relu") + n2 = is_call_tir_extern("tir_sigmoid") + n3 = is_call_tir_extern("tir_add") + + is_call_tir_extern("tir_matmul").fork_to(n1, n2) + n1 >> n3 + n2 >> n3 + + dfb = Diamond["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + +def test_diamond_counter_oub(): + with PatternContext() as ctx: + n0 = is_call_tir_extern("tir_matmul") + n1 = is_call_tir_extern("tir_relu") + n2 = is_call_tir_extern("tir_sigmoid") + n3 = is_call_tir_extern("tir_add") + + n0 >> n1 + n0 >> n2 + n1 >> n3 + n2 >> n3 + + dfb = Diamond["main"].body.blocks[0] + assert not ctx.match_dfb(dfb) + + +@tvm.script.ir_module +class SmallDiamond: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + # relu + # / \ + # \ / + # add + lv0 = R.call_tir("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir("my_add", (lv0, lv0), R.Tensor((32, 32), dtype="float32")) + R.output(lv1) + return lv1 + + +@tvm.script.ir_module +class SmallParallel: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + # relu relu + # \ / + # add + lv0 = R.call_tir("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_tir("my_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32")) + R.output(lv2) + return lv2 + + +def test_distiguish_diamond_and_parallel(): + # relay pattern lang cannot distinguish the two cases above. + diamond = SmallDiamond["main"].body.blocks[0] + parallel = SmallParallel["main"].body.blocks[0] + + with PatternContext() as ctx: + # describe a diamond pattern + fork = is_call_tir_extern("my_relu") + join = is_call_tir_extern("my_add") + fork.only_used_by(join, index=0) + fork.only_used_by(join, index=1) + + assert ctx.match_dfb(diamond) + assert not ctx.match_dfb(parallel) + + with PatternContext() as ctx: + # describe a parallel pattern + join = is_call_tir_extern("my_add") + # Due to one-one mathcing: + # is_call_tir_extern("my_relu") creates the 1st relu + is_call_tir_extern("my_relu") >> join + # is_call_tir_extern("my_relu") + # creates the another different relu (obj address is different) + is_call_tir_extern("my_relu") >> join + + assert ctx.match_dfb(parallel) + assert not ctx.match_dfb(diamond) + + +@tvm.script.ir_module +class CBRx2: + @R.function + def main( + x: R.Tensor((32, 32), "float32"), + w0: R.Tensor((1, 1), "float32"), + bias0: R.Tensor((32, 32), "float32"), + w1: R.Tensor((1, 1), "float32"), + bias1: R.Tensor((32, 32), "float32"), + ) -> R.Tensor: + # R.TensorRT's CBR Optimization Pattern + # input + # / \ + # cbr0 cbr1 + # \ / + # concat + with R.dataflow(): + lv0 = R.call_tir("conv1x1", (x, w0), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir("bias_add", (lv0, bias0), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_tir("my_relu", (lv1), R.Tensor((32, 32), dtype="float32")) + lv3 = R.call_tir("conv1x1", (x, w1), R.Tensor((32, 32), dtype="float32")) + lv4 = R.call_tir("bias_add", (lv3, bias1), R.Tensor((32, 32), dtype="float32")) + lv5 = R.call_tir("my_relu", (lv4), R.Tensor((32, 32), dtype="float32")) + lv6 = R.call_tir("concat", (lv2, lv5), R.Tensor((32, 64), dtype="float32")) + R.output(lv6) + return lv6 + + +def test_single_cbr(): + with PatternContext() as ctx: + ( + is_call_tir_extern("conv1x1") + >> is_call_tir_extern("bias_add") + >> is_call_tir_extern("my_relu") + ) + dfb = CBRx2["main"].body.blocks[0] + matched = ctx.match_dfb(dfb) + assert matched + + with PatternContext() as ctx: + chain = ( + is_call_tir_extern("conv1x1") + >> is_call_tir_extern("bias_add") + >> is_call_tir_extern("my_relu") + ) + dfb = CBRx2["main"].body.blocks[0] + # we want to specifically match the first CBR (lv0) + matched = ctx.match_dfb(dfb, start_hint=dfb.bindings[0].var) + assert matched + assert matched[chain[0]] == dfb.bindings[0].var + # we want to specifically match the second CBR (lv3) + matched = ctx.match_dfb(dfb, start_hint=dfb.bindings[3].var) + assert matched + assert matched[chain[0]] == dfb.bindings[3].var + + +def test_counter_single_crb(): + with PatternContext() as ctx: + ( + is_call_tir_extern("conv1x1") + >> is_call_tir_extern("my_relu") + >> is_call_tir_extern("bias_add") + ) + dfb = CBRx2["main"].body.blocks[0] + assert not ctx.match_dfb(dfb) + # Quickly fails unpromising matches by assumiung `start_hint` must be matched by a pattern. + # This is usually faster than the full match: + # Full match: let one pattern to match -> all Var: complexity ~ #Var + # must_include_hint: let `start_hint` to match -> all patterns: complexity ~ #patterns + # Usually #patterns is much smaller than #Var, so this is faster. + assert not ctx.match_dfb(dfb, start_hint=dfb.bindings[0].var, must_include_hint=True) + + +def test_nested_context(): + dfb = CBRx2["main"].body.blocks[0] + with PatternContext() as ctx0: + ( + is_call_tir_extern("conv1x1") + >> is_call_tir_extern("bias_add") + >> is_call_tir_extern("my_relu") + ) + with PatternContext() as ctx1: + is_call_tir_extern("conv1x1") >> is_call_tir_extern("my_relu") # pattern to miss + with PatternContext() as ctx2: + is_call_tir_extern("bias_add") >> is_call_tir_extern("my_relu") + assert ctx2.match_dfb(dfb) + assert PatternContext.current() == ctx2 + assert not ctx1.match_dfb(dfb) + assert PatternContext.current() == ctx1 + assert ctx0.match_dfb(dfb) + assert PatternContext.current() == ctx0 + + +def test_two_cbr(): + with PatternContext() as ctx: + cbr0 = ( + is_call_tir_extern("conv1x1") + >> is_call_tir_extern("bias_add") + >> is_call_tir_extern("my_relu") + ) + cbr1 = cbr0.dup() + + assert cbr0.patterns[0] != cbr1.patterns[0] + assert cbr0.patterns[1] != cbr1.patterns[1] + assert cbr0.patterns[2] != cbr1.patterns[2] + + is_var("x").fork_to(cbr0, cbr1) + dfb = CBRx2["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + with PatternContext() as ctx: + # Deny the pattern + cbr0 = ( + is_call_tir_extern("conv1x1") + >> is_call_tir_extern("bias_add") + >> is_call_tir_extern("my_relu") + ) + cbr1 = cbr0.dup() + + # input has no fork at y. + is_var("y").fork_to(cbr0, cbr1) + dfb = CBRx2["main"].body.blocks[0] + assert not ctx.match_dfb(dfb) + + +def test_two_matmul(): + # Same as Figure 2(a) in TASO paper. + @tvm.script.ir_module + class MatMul2: + @R.function + def main( + a: R.Tensor((32, 16), "float32"), + b: R.Tensor((16, 48), "float32"), + c: R.Tensor((48, 32), "float32"), + ) -> R.Tensor: + with R.dataflow(): + lv0 = R.call_tir("matmul", (a, b), R.Tensor((32, 48), dtype="float32")) + lv1 = R.call_tir("matmul", (lv0, c), R.Tensor((32, 32), dtype="float32")) + R.output(lv1) + return lv1 + + with PatternContext() as ctx: + is_call_tir_extern("matmul") >> is_call_tir_extern("matmul") + dfb = MatMul2["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + with PatternContext() as ctx: + is_call_tir_extern("matmul").has_shape([32, 48]) >> is_call_tir_extern("matmul").has_shape( + [32, 32] + ) + dfb = MatMul2["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + with PatternContext() as ctx: + is_call_tir_extern("matmul") >> is_call_tir_extern("matmul") >> is_call_tir_extern("matmul") + dfb = MatMul2["main"].body.blocks[0] + # Three MatMul cannot match + assert not ctx.match_dfb(dfb) + + +def test_concat_mm_split(): + # Same as Figure 2(b) in TASO paper. + @tvm.script.ir_module + class CMS: + @R.function + def main( + a: R.Tensor((32, 32), "float32"), + b: R.Tensor((16, 32), "float32"), + c: R.Tensor((16, 32), "float32"), + ) -> R.Tensor: + with R.dataflow(): + lv0 = R.call_tir("my_concat", (b, c), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir("my_matmul", (a, lv0), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_tir( + "my_split", + (lv1,), + [R.Tensor((16, 32), dtype="float32"), R.Tensor((16, 32), dtype="float32")], + ) + lv3 = R.TupleGetItem(lv2, 0) + lv4 = R.TupleGetItem(lv2, 1) + lv5 = R.add(lv3, lv4) + R.output(lv5) + return lv5 + + with PatternContext() as ctx: + ( + is_call_tir_extern("my_concat") + >> is_call_tir_extern("my_matmul") + >> is_call_tir_extern("my_split") + ) + dfb = CMS["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + with PatternContext() as ctx: + split = is_call_tir_extern("my_split") + lv3 = TupleGetItemPattern(split, 0).has_shape([16, 32]) + lv4 = TupleGetItemPattern(split, 1).has_shape([16, 32]) + split.fork_to(lv3, lv4) + add = is_op("relax.add")(lv3, lv4) + # TODO(@ganler): simplify this through implicit graph pattern. + lv3 >> add + lv4 >> add + + dfb = CMS["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + +def test_self_attention(): + # The example comes from. + # https://developer.nvidia.com/blog/nlu-with-tensorrt-bert/ + @tvm.script.ir_module + class SelfAttention: + @R.function + def main( + x: R.Tensor(("b", "s", "n", "h"), "float32"), + wq: R.Tensor(("h", "h"), "float32"), + wk: R.Tensor(("h", "h"), "float32"), + wv: R.Tensor(("h", "h"), "float32"), + ) -> R.Tensor: + b, s, n, h = T.int64(), T.int64(), T.int64(), T.int64() + with R.dataflow(): + fcq = R.call_tir("my_fc", (x, wq), R.Tensor((b, s, n, h), dtype="float32")) + tpq = R.call_tir("my_transpose", (fcq,), R.Tensor((b, s, h, n), dtype="float32")) + + fck = R.call_tir("my_fc", (x, wk), R.Tensor((b, s, n, h), dtype="float32")) + tpk = R.call_tir("my_transpose", (fck,), R.Tensor((b, s, h, n), dtype="float32")) + + mul = R.multiply(tpq, tpk) + scale = R.multiply(mul, R.const(1.1, "float32")) + softmax = R.call_tir("softmax", (scale,), R.Tensor((b, s, n, h), dtype="float32")) + + fcv = R.call_tir("my_fc", (x, wv), R.Tensor((b, s, n, h), dtype="float32")) + tpv = R.call_tir("my_transpose", (fcv,), R.Tensor((b, s, h, n), dtype="float32")) + + out = R.multiply(softmax, tpv) + R.output(out) + + return out + + with PatternContext() as ctx: + fc_trans_q = is_call_tir_extern("my_fc") >> is_call_tir_extern("my_transpose") + fc_trans_k = fc_trans_q.dup() + fc_trans_v = fc_trans_q.dup() + + is_var("x").fork_to(fc_trans_q, fc_trans_k, fc_trans_v) + dfb = SelfAttention["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + +def test_nested_diamond(): + @tvm.script.ir_module + class DiamondInDiamond: + @R.function + def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + # matmul0 matmul1 + # / \ / \ + # sigmoid2 add4 sigmoid3 + # \ / \ / + # add5 add6 + # \ / + # add7 + lv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_tir("tir_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")) + lv3 = R.call_tir("tir_sigmoid", (lv1), R.Tensor((32, 32), dtype="float32")) + lv4 = R.call_tir("tir_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32")) + lv5 = R.call_tir("tir_add", (lv2, lv4), R.Tensor((32, 32), dtype="float32")) + lv6 = R.call_tir("tir_add", (lv3, lv4), R.Tensor((32, 32), dtype="float32")) + lv7 = R.call_tir("tir_add", (lv5, lv6), R.Tensor((32, 32), dtype="float32")) + R.output(lv7) + return lv7 + + # match matmul0 diamond + with PatternContext() as ctx: + sigmoid2 = is_call_tir_extern("tir_sigmoid") + add4 = is_call_tir_extern("tir_add") + is_call_tir_extern("tir_matmul").fork_to(sigmoid2, add4) + add5 = is_call_tir_extern("tir_add") + sigmoid2 >> add5 + add4 ^ add5 + assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) + + # counter case: mis-match matmul0 diamond + with PatternContext() as ctx: + sigmoid2 = is_call_tir_extern("tir_sigmoid") + add4 = is_call_tir_extern("tir_add") + is_call_tir_extern("tir_matmul").fork_to(sigmoid2, add4) + add5 = is_call_tir_extern("tir_add") + sigmoid2 >> add5 + add4 >> add5 # not only-used-by relation + assert not ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) + + # match matmul1 diamond + with PatternContext() as ctx: + sigmoid3 = is_call_tir_extern("tir_sigmoid") + add4 = is_call_tir_extern("tir_add") + is_call_tir_extern("tir_matmul").fork_to(sigmoid3, add4) + add6 = is_call_tir_extern("tir_add") + sigmoid3 >> add6 + add4 ^ add6 + assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) + + # match add-4-5-6-7 + with PatternContext() as ctx: + add5, add6, add7 = ( + is_call_tir_extern("tir_add"), + is_call_tir_extern("tir_add"), + is_call_tir_extern("tir_add"), + ) + is_call_tir_extern("tir_add").fork_to(add5, add6) # add4 + add5 >> add7 + add6 >> add7 + assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) + + +def test_incremental_solving(): + @R.function + def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + # relu -> sigmoid -> neg + lv0 = R.call_tir("tir_relu", (x), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir("tir_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_tir("tir_neg", (lv1), R.Tensor((32, 32), dtype="float32")) + R.output(lv2) + return lv2 + + relu = is_call_tir_extern("tir_relu") + sigmoid = is_call_tir_extern("tir_sigmoid") + neg = is_call_tir_extern("tir_neg") + + with PatternContext() as ctx0: + relu >> sigmoid + with PatternContext(incremental=True) as ctx1: + # because we are doing incremental solving + # relu >> sigmoid is still a constraint in this context. + # that said the total constraint is: + # relu >> sigmoid >> neg + sigmoid >> neg + assert ctx1.match_dfb(simple_chain.body.blocks[0]) + + # match relue -> sigmoid + assert ctx0.match_dfb(simple_chain.body.blocks[0]) + + +def test_incremental_solving_counter(): + @R.function + def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + # sigmoid -> neg + lv0 = R.call_tir("tir_sigmoid", (x), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_tir("tir_neg", (lv0), R.Tensor((32, 32), dtype="float32")) + R.output(lv1) + return lv1 + + relu = is_call_tir_extern("tir_relu") + sigmoid = is_call_tir_extern("tir_sigmoid") + neg = is_call_tir_extern("tir_neg") + + with PatternContext() as ctx0: + relu >> sigmoid # cannot match + + with PatternContext(incremental=False) as ctx1: + # total constraint: sigmoid >> neg + sigmoid >> neg + assert ctx1.match_dfb(simple_chain.body.blocks[0]) + + with PatternContext(incremental=True) as ctx1: + # total constraint: relu >> sigmoid >> neg + sigmoid >> neg + assert not ctx1.match_dfb(simple_chain.body.blocks[0]) + + +if __name__ == "__main__": + tvm.testing.main() From 988b2aaf0f775e415beeb304237cf6af1ddcd548 Mon Sep 17 00:00:00 2001 From: Jiawei Liu Date: Sat, 18 Feb 2023 18:13:57 -0600 Subject: [PATCH 42/81] [Unity] Statement rewriter for DataflowBlock (#14043) This PR implements a few APIs to quickly perform statement-level mutation: `add`/`remove_unused`/`remove_all_unused`/`replace_all_uses`. It also implements `remove_all_unused` to remove dead statements inside `DataflowBlock`. --- include/tvm/relax/analysis.h | 24 ++ include/tvm/relax/binding_rewrite.h | 115 +++++++ include/tvm/relax/utils.h | 1 + python/tvm/relax/analysis/analysis.py | 23 +- python/tvm/relax/binding_rewrite.py | 155 ++++++++++ src/relax/ir/binding_rewrite.cc | 324 ++++++++++++++++++++ tests/python/relax/test_analysis.py | 118 +++++++- tests/python/relax/test_binding_rewrite.py | 334 +++++++++++++++++++++ 8 files changed, 1092 insertions(+), 2 deletions(-) create mode 100644 include/tvm/relax/binding_rewrite.h create mode 100644 python/tvm/relax/binding_rewrite.py create mode 100644 src/relax/ir/binding_rewrite.cc create mode 100644 tests/python/relax/test_binding_rewrite.py diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 32e1582134c7..b9866577e9b6 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -341,6 +341,14 @@ TVM_DLL Map AnalyzeVar2Value(const Expr& expr); */ TVM_DLL Map AnalyzeVar2Value(const DataflowBlock& dfb); +/*! + * \brief Return a mapping from variable name to its Bindings. + * + * \param fn The function to be analyzed. + * \return A mapping from variable name to its Bindings. + */ +TVM_DLL Map> NameToBinding(const Function& fn); + /*! * \brief Get the use-def chain of variables inside a dataflow block. * @@ -349,6 +357,22 @@ TVM_DLL Map AnalyzeVar2Value(const DataflowBlock& dfb); */ TVM_DLL Map> DataflowBlockUseDef(const DataflowBlock& dfb); +/*! + * \brief Get the use-def chain of variables inside a function. + * + * \param fn The function to be analyzed. + * \return A map from variable definitions to a set of uses and variables needed by return value. + */ +std::pair>, Array> FunctionUseDef(const Function& fn); + +/*! + * \brief Remove unused statements inside DataflowBlocks. + * + * \param fn The function to remove unused statements. + * \return The function that contains no unused statements in DataflowBlock. + */ +TVM_DLL Function RemoveAllUnused(const Function fn); + /*! * \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps. * diff --git a/include/tvm/relax/binding_rewrite.h b/include/tvm/relax/binding_rewrite.h new file mode 100644 index 000000000000..a4b534965ae2 --- /dev/null +++ b/include/tvm/relax/binding_rewrite.h @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/binding_rewrite.h + * \brief An IR rewriter to easily add/remove/replace bindings (statements). + */ + +#ifndef TVM_RELAX_BINDING_REWRITE_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! \brief Statement rewriter for relax.DataflowBlock. */ +class DataflowBlockRewriteNode : public Object { + public: + /*! \brief Replace all uses of old_var with new_var. */ + void ReplaceAllUses(Var old_var, Var new_var); + /*! \brief Insert a Binding statement. */ + void Add(Binding binding); + /*! \brief Insert an expression as VarBinding with variable name. */ + void Add(String var_name, Expr expr, bool is_dfvar = false) { + auto var = is_dfvar ? DataflowVar(var_name, GetStructInfo(expr)) // + : Var(var_name, GetStructInfo(expr)); + Add(VarBinding(std::move(var), std::move(expr))); + } + /*! \brief Insert an expression as VarBinding with automatic variable name. */ + void Add(Expr expr, bool is_dfvar = false) { + Add(name_table_.GetUniqueName("tmp"), expr, is_dfvar); + } + /*! \brief Remove the definition statement of an unused variable. */ + void RemoveUnused(Var unused, bool allow_undef = false); + /*! \brief Remove the definition statements of all unused variables. */ + void RemoveAllUnused(); + + /*! \brief The rewritten dataflow block. */ + DataflowBlock MutatedDataflowBlock() { return dfb_.value(); } + /*! \brief The rewritten function. */ + Function MutatedFunc() { return root_fn_.value(); } + /*! \brief The rewritten IRModule. */ + IRModule MutateIRModule(IRModule irmod); + + /*! \brief Visit attributes. */ + void VisitAttrs(AttrVisitor* v) { + v->Visit("dfb", &dfb_); + v->Visit("root_fn", &root_fn_); + } + + static constexpr const char* _type_key = "relax.DataflowBlockRewrite"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockRewriteNode, Object); + + protected: + friend class DataflowBlockRewrite; + + Optional dfb_; //!< The rewritten dataflow block. + Optional root_fn_; //!< The rewritten function. + const FunctionNode* original_fn_ptr_; //!< Pointer to the original function. + Map> to_users_; //!< Map from variable to its users. + Array fn_outputs_; //!< Variables required by function outputs. + + private: + NameTable name_table_; //!< Name table for tracking and generating unique names. +}; + +/*! + * \brief A statement rewriter for relax.DataflowBlock. + * \sa DataflowBlockRewriteNode + */ +class DataflowBlockRewrite : public ObjectRef { + public: + TVM_DLL explicit DataflowBlockRewrite(DataflowBlock dfb, Function root_fn); + + /*! + * \brief mutable accessor. + * \return mutable access pointer. + */ + DataflowBlockRewriteNode* operator->() { + ICHECK(get() != nullptr); + return static_cast(get_mutable()); + } + + TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlockRewrite, ObjectRef, DataflowBlockRewriteNode); +}; + +} // namespace relax +} // namespace tvm + +#define TVM_RELAX_BINDING_REWRITE_H_ +#endif // TVM_RELAX_BINDING_REWRITE_H_ diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index 1457a16427cc..c1d984a21a7b 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -25,6 +25,7 @@ #define TVM_RELAX_UTILS_H_ #include +#include #include #include diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 45c5b6f96288..ffcdaceb4076 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -28,7 +28,7 @@ from tvm import IRModule from tvm.relax.ty import Type from tvm.relax.struct_info import StructInfo, FuncStructInfo -from tvm.relax.expr import DataflowBlock, Var, Expr, Function, Call +from tvm.relax.expr import DataflowBlock, Var, Expr, Function, Call, Binding from . import _ffi_api @@ -244,6 +244,27 @@ def udchain(dfb: DataflowBlock) -> Dict[Var, List[Var]]: return _ffi_api.udchain(dfb) # type: ignore +def name_to_binding(func: Function) -> Dict[str, List[Binding]]: + """Return a map from variable name to its bindings.""" + return _ffi_api.name_to_binding(func) # type: ignore + + +def remove_all_unused(func: Function) -> Function: + """Remove all unused variables from the function. + + Parameters + ---------- + func : Function + The input function to be analyzed. + + Returns + ------- + Function + The function with unused variables removed. + """ + return _ffi_api.remove_all_unused(func) # type: ignore + + def well_formed(mod: IRModule, check_struct_info: bool = True) -> bool: """Check if the IRModule is well formed. diff --git a/python/tvm/relax/binding_rewrite.py b/python/tvm/relax/binding_rewrite.py new file mode 100644 index 000000000000..a9f6d878ad0d --- /dev/null +++ b/python/tvm/relax/binding_rewrite.py @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, invalid-name +"""Developer API of add/remove/replace bindings in Relax.""" + +from typing import Optional + +import tvm +import tvm._ffi +from tvm.runtime import Object +from . import Binding, DataflowBlock, Expr, Function, Var +from . import _ffi_api + + +@tvm._ffi.register_object("relax.DataflowBlockRewrite") +class DataflowBlockRewrite(Object): + """ + A binding/statement-level dataflow block rewriter. + + Notes + ----- + Due to the immutable and copy-on-write nature of TVM AST nodes, the rewriting is not done in + place. Instead, a new DataflowBlock is created and returned with mutated_dfb. Similarly, its new + root Function is created and returned by mutated_root_fn. To apply this change for an IRModule, + use mutate_irmodule which rewrites the old function that registered in the constructor. + """ + + def __init__(self, dfb: DataflowBlock, root_fn: Function): + """ + Construct a rewriter with the DataflowBlock to rewrite and its root function. + + Parameters + ---------- + dfb : DataflowBlock + The DataflowBlock to rewrite. + root_fn : Function + The root function of the DataflowBlock. + """ + self.func_name = root_fn.__name__ if hasattr(root_fn, "__name__") else None + self.__init_handle_by_constructor__( + _ffi_api.DataflowBlockRewrite, dfb, root_fn # type: ignore + ) + + def replace_all_uses(self, old_var: Var, new_var: Var) -> None: + """ + Replace all uses of old_var with new_var. + + Parameters + ---------- + old_var : Var + The old variable to replace. + new_var : Var + The new variable to replace with. + """ + _ffi_api.dfb_rewrite_replace_all_uses(self, old_var, new_var) # type: ignore + + def add_binding(self, binding: Binding) -> None: + return _ffi_api.dfb_rewrite_add_binding(self, binding) # type: ignore + + def add(self, expr: Expr, name: Optional[str] = None, is_dfvar: bool = False) -> None: + """ + Add a new statement to the DataflowBlock with an automatically generated variable name. + + Parameters + ---------- + expr : Expr + The expression to add. + name : Optional[str], optional + Variable name, by default None + is_dfvar : bool, optional + The variable type, by default False + + Notes + ----- + If the variable name is not given, it will be automatically generated in a form of + "tmp${COUNTER}". The variable type will be DataflowVar if is_dfvar is True, otherwise + it will be Var. Being Var means the variables are output variables of the DataflowBlock. + While being DataflowVar means the variables are internal variables of the DataflowBlock. + """ + _ffi_api.dfb_rewrite_add(self, expr, name, is_dfvar) # type: ignore + + def remove_unused(self, var: Var, allow_undef=False) -> None: + """ + Remove a statement by its variable definition if and only if it is unused. + + Parameters + ---------- + var : Var + The unused variable definition. + allow_undef : bool, optional + Whether to allow var being undefined variable, by default False + + Raises + ------ + TVMError if the variable is used or undefined (allow_undef=False). + """ + _ffi_api.dfb_rewrite_remove_unused(self, var, allow_undef) # type: ignore + + def remove_all_unused(self) -> None: + """ + Remove all unused variables. + + Notes + ----- + This could remove unused variables in other DataflowBlocks as well. + """ + _ffi_api.dfb_rewrite_remove_all_unused(self) # type: ignore + + def mutated_dfb(self) -> DataflowBlock: + """ + Returns the mutated DataflowBlock. + """ + return self.dfb + + def mutated_root_fn(self) -> Function: + """ + Returns the mutated root function. + """ + ret = self.root_fn + if self.func_name: + ret.__name__ = self.func_name + return ret + + def mutate_irmodule(self, irmodule: tvm.IRModule) -> tvm.IRModule: + """ + Return an updated IRModule by replacing the old function with the mutated root function. + + Parameters + ---------- + irmodule : tvm.IRModule + The base IRModule to update. + + Returns + ------- + tvm.IRModule + The updated IRModule. + """ + ret = _ffi_api.dfb_rewrite_mutate_irmodule(self, irmodule) # type: ignore + if hasattr(irmodule, "__name__"): + ret.__name__ = irmodule.__name__ + return ret diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc new file mode 100644 index 000000000000..dd9fac9fdcd0 --- /dev/null +++ b/src/relax/ir/binding_rewrite.cc @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/binding_rewrite.cc + * \brief Implementation of binding rewriters. + */ + +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(DataflowBlockRewriteNode); +DataflowBlockRewrite::DataflowBlockRewrite(DataflowBlock dfb, Function root_fn) { + auto n = make_object(); + n->dfb_ = dfb; + n->root_fn_ = root_fn; + n->original_fn_ptr_ = root_fn.get(); + auto p = FunctionUseDef(root_fn); + n->to_users_ = std::move(p.first); + n->fn_outputs_ = std::move(p.second); + n->name_table_ = NameTable(n->to_users_.begin(), n->to_users_.end(), + [](const auto& p) { return p.first->name_hint(); }); + + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.DataflowBlockRewrite") + .set_body_typed([](DataflowBlock dfb, Function root_fn) { + return DataflowBlockRewrite(dfb, root_fn); + }); + +void DataflowBlockRewriteNode::ReplaceAllUses(Var old_var, Var new_var) { + class ReplaceAllUsePass : public ExprMutator { + Var old_var, new_var; + const DataflowBlockNode* const to_catch; + + public: + const DataflowBlockNode* caught = nullptr; + + ReplaceAllUsePass(Var old_var, Var new_var, const DataflowBlockNode* to_catch) + : old_var(old_var), new_var(new_var), to_catch(to_catch) {} + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const VarNode* op) override { + return (op == old_var.get()) ? new_var : GetRef(op); + } + + Expr VisitExpr_(const DataflowVarNode* op) override { + return (op == old_var.get()) ? new_var : GetRef(op); + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { + BindingBlock res = ExprMutator::VisitBindingBlock_(op); + if (op == to_catch) caught = static_cast(res.get()); + return res; + } + }; + + ICHECK(to_users_.find(old_var) != to_users_.end()) << "Cannot find " << old_var; + ICHECK(to_users_.find(new_var) != to_users_.end()) << "Cannot find " << new_var; + + // replace uses inside the DataflowBlock. + ReplaceAllUsePass replacer(old_var, new_var, dfb_.get()); + root_fn_ = Downcast(replacer.VisitExpr_(root_fn_.get())); + dfb_ = GetRef(replacer.caught); + + // update udchain + // old_var -> old_var users | changed to {} + // new_var -> {?} | changed to old_var users + for (Var user : to_users_[old_var]) { + auto new_var_uses = to_users_[new_var]; + if (new_var_uses.end() == std::find(new_var_uses.begin(), new_var_uses.end(), user)) { + new_var_uses.push_back(user); + } + } + + to_users_.Set(old_var, {}); + + auto it_old_output = std::find(fn_outputs_.begin(), fn_outputs_.end(), old_var); + if (it_old_output != fn_outputs_.end()) { + fn_outputs_.Set(std::distance(fn_outputs_.begin(), it_old_output), new_var); + } +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_replace_all_uses") + .set_body_typed([](DataflowBlockRewrite rwt, Var old_var, Var new_var) { + rwt->ReplaceAllUses(old_var, new_var); + }); + +class UpdateDFB : public ExprMutator { + private: + DataflowBlock old_dfb, new_dfb; + + public: + UpdateDFB(DataflowBlock old_dfb, DataflowBlock new_dfb) + : old_dfb(std::move(old_dfb)), new_dfb(std::move(new_dfb)) {} + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { + return old_dfb.get() == op ? new_dfb : old_dfb; + } +}; + +void DataflowBlockRewriteNode::Add(Binding binding) { + auto p = [binding] { + if (auto vb = binding.as()) { + return std::make_pair(vb->var, vb->value); + } else if (auto mc = binding.as()) { + return std::make_pair(mc->var, mc->value); + } + LOG(FATAL) << "Unsupported binding type"; + return std::make_pair(Var{}, Expr{}); + }(); + Var var = p.first; + Expr val = p.second; + + ICHECK(0 == to_users_.count(var)) << var << " has been defined so cannot be added."; + + // Add this VarBinding statement after the definition of uses. + std::set used_vars = [val] { + class UsedVars : public ExprVisitor { + public: + std::set used_vars; + void VisitExpr_(const VarNode* op) override { used_vars.insert(op); } + void VisitExpr_(const DataflowVarNode* op) override { used_vars.insert(op); } + } uvar{}; + uvar.VisitExpr(val); + return std::move(uvar.used_vars); + }(); + + size_t line_last_req_def = 0; + for (size_t i = 0; i < dfb_.value()->bindings.size(); ++i) { + auto line = dfb_.value()->bindings[i]; + if (used_vars.find(line->var.get()) != used_vars.cend()) line_last_req_def = i; + } + + auto old_dfb = dfb_.value(); + + dfb_ = [old_dfb, binding, line_last_req_def, this] { + auto new_dfb = dfb_.value(); + new_dfb.CopyOnWrite()->bindings.insert(dfb_.value()->bindings.begin() + 1 + line_last_req_def, + binding); + return new_dfb; + }(); + + auto updater = UpdateDFB(old_dfb, dfb_.value()); + root_fn_ = Downcast(updater.VisitExpr_(root_fn_.get())); + + for (const VarNode* v : used_vars) to_users_.Get(GetRef(v)).value().push_back(var); +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_add_binding") + .set_body_typed([](DataflowBlockRewrite rwt, Binding vb) { rwt->Add(vb); }); + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_add") + .set_body_typed([](DataflowBlockRewrite rwt, Expr expr, Optional name, bool is_dfvar) { + if (name.get()) { + rwt->Add(name.value(), expr, is_dfvar); + } else { + rwt->Add(expr, is_dfvar); + } + }); + +class RemoveUnusedVars : public ExprMutator { + public: + std::set unused_vars; + Optional caught_rewrite = NullOpt; + + RemoveUnusedVars(Map> users, Array fn_outputs) + : unused_vars([&] { + std::vector unused; + + // iterative dataflow algorithm. + size_t prev_size; + do { + prev_size = unused.size(); + + std::vector used; + used.reserve(users.size()); + for (const auto& kv : users) { + // var -> [users...] + // var is unused iff + // user -> empty + // var is not output var + if (kv.second.empty() && // kv.first is not used by fn outputs. + fn_outputs.end() == std::find(fn_outputs.begin(), fn_outputs.end(), kv.first)) { + unused.push_back(kv.first); + } else { + used.push_back(kv.first); + } + } + + for (size_t i = prev_size; i < unused.size(); ++i) { + users.erase(unused[i]); + // remove def site. + for (const auto& used_var : used) { + ICHECK(users.count(used_var)); + Array var_users = users[used_var]; + // remove the unused var from the use site. + auto it = std::find(var_users.begin(), var_users.end(), unused[i]); + if (it != var_users.end()) { + var_users.erase(it); + users.Set(used_var, std::move(var_users)); + } + } + } + } while (prev_size != unused.size()); // changed? => continue. + + return std::set(unused.begin(), unused.end()); + }()) {} + + RemoveUnusedVars(std::pair>, Array> users_and_outputs) + : RemoveUnusedVars(std::move(users_and_outputs.first), std::move(users_and_outputs.second)) {} + RemoveUnusedVars(Function fn) : RemoveUnusedVars(FunctionUseDef(fn)) {} + RemoveUnusedVars(std::set unused_vars) : unused_vars(std::move(unused_vars)) {} + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) { + auto prev_dfb = GetRef(block); + builder_->BeginDataflowBlock(); + for (Binding binding : block->bindings) { + if (!unused_vars.count(binding->var)) { + VisitBinding(binding); + } + } + auto new_dfb = builder_->EndBlock(); + if (caught_rewrite == prev_dfb) caught_rewrite = Downcast(new_dfb); + return std::move(new_dfb); + } +}; + +void DataflowBlockRewriteNode::RemoveUnused(Var unused, bool allow_undef) { + // first need to check if this var is used. + if (0 == to_users_.count(unused)) { // no def. + if (allow_undef) return; + LOG(FATAL) << unused << " undefined. Set allow_undef=True to allow 'removing' undefined var"; + } + + ICHECK(to_users_[unused].empty()) + << unused << " is used by " << to_users_[unused].size() << " vars"; + + auto old_dfb = dfb_.value(); + + RemoveUnusedVars remover({unused}); + dfb_ = Downcast(remover.VisitBindingBlock_(old_dfb.get())); + + auto updater = UpdateDFB(old_dfb, dfb_.value()); + root_fn_ = Downcast(updater.VisitExpr_(root_fn_.get())); + + to_users_.erase(unused); // update use-def chain. +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_unused") + .set_body_typed([](DataflowBlockRewrite rwt, Var unused, bool allow_undef) { + rwt->RemoveUnused(unused, allow_undef); + }); + +void DataflowBlockRewriteNode::RemoveAllUnused() { + RemoveUnusedVars remover(to_users_, fn_outputs_); + remover.caught_rewrite = dfb_.value(); + + // this could also clean unused variables in other DataflowBlock. + root_fn_ = Downcast(remover.VisitExpr_(root_fn_.get())); + + // DataflowBlock could be None. + dfb_ = remover.caught_rewrite.value(); + + // clean up use-def chain. + for (const auto& unused : remover.unused_vars) to_users_.erase(unused); +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_all_unused") + .set_body_typed([](DataflowBlockRewrite rwt) { rwt->RemoveAllUnused(); }); + +Function RemoveAllUnused(Function fn) { + RemoveUnusedVars remover(fn); + return Downcast(remover.VisitExpr_(fn.get())); +} + +TVM_REGISTER_GLOBAL("relax.analysis.remove_all_unused").set_body_typed(RemoveAllUnused); + +IRModule DataflowBlockRewriteNode::MutateIRModule(IRModule irmod) { + BlockBuilder builder = BlockBuilder::Create(irmod); + + for (auto& p : irmod->functions) { + if (original_fn_ptr_ == p.second.get()) { + builder->UpdateFunction(p.first, root_fn_.value()); + break; + } + } + + return builder->GetContextIRModule(); +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_mutate_irmodule") + .set_body_typed([](DataflowBlockRewrite rwt, IRModule irmod) { + return rwt->MutateIRModule(irmod); + }); + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index 43558a52be32..e939d2b20830 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -21,7 +21,7 @@ import tvm.testing from tvm import tir from tvm import relax as rx -from tvm.relax.analysis import has_reshape_pattern, udchain +from tvm.relax.analysis import has_reshape_pattern, udchain, remove_all_unused, name_to_binding from tvm.script import relax as R, tir as T @@ -46,6 +46,122 @@ def test_use_def(): assert set(udc[gv0]) == set() +def test_chained_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) + unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32")) + R.output(lv0) + return lv0 + + optimized = remove_all_unused(IdentityUnused["main"]) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) + + +def test_binding_block_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) + unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32")) + R.output(lv0) + z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + return z + + optimized = remove_all_unused(IdentityUnused["main"]) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + return z + + tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) + + +def test_binding_block_fake_unused_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + return lv0 + + optimized = remove_all_unused(IdentityUnused["main"]) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + # This might bring side effect so cannot be removed. + z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + return lv0 + + tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) + + +def test_edge_binding_block_fake_unused_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + z = R.call_packed("vm.builtin.copy", x, sinfo_args=(R.Tensor((32, 32), "float32"))) + return x + + optimized = remove_all_unused(IdentityUnused["main"]) + tvm.ir.assert_structural_equal(optimized, IdentityUnused["main"]) + + +def test_name_to_binding_var_shadowing(): + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + lv1 = lv0 + R.output(lv1) + + with R.dataflow(): + lv0 = lv1 # shadowing + lv2 = lv0 + R.output(lv2) + return lv2 + + n2binding = name_to_binding(main) + + assert "lv0" in n2binding + assert "lv1" in n2binding + assert "lv2" in n2binding + + assert len(n2binding["lv0"]) == 2 + + def test_reshape_pattern_reshape(): @T.prim_func def reshape( diff --git a/tests/python/relax/test_binding_rewrite.py b/tests/python/relax/test_binding_rewrite.py new file mode 100644 index 000000000000..1b424b97923a --- /dev/null +++ b/tests/python/relax/test_binding_rewrite.py @@ -0,0 +1,334 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm +import tvm.testing +from tvm._ffi.base import TVMError +from tvm.relax.analysis import name_to_binding +from tvm.relax.binding_rewrite import DataflowBlockRewrite +from tvm.relax.expr import DataflowVar, Var +from tvm.script import relax as R + + +@tvm.script.ir_module +class Identity: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + +def assert_immutability(rwt, original_dfb, original_root_fn): + assert rwt.mutated_dfb() != original_dfb + assert rwt.mutated_root_fn() != original_root_fn + assert rwt.mutated_root_fn().body.blocks[0] != original_dfb + assert rwt.mutated_root_fn().body.blocks[0] == rwt.mutated_dfb() + + +def test_null_construct(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + DataflowBlockRewrite(dfb, root_fn) + + +def test_simple_add(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.add(name="tmp", expr=Identity["main"].params[0], is_dfvar=True) + + assert_immutability(rwt, dfb, root_fn) + + # check "tmp" added + assert "tmp" in name_to_binding(rwt.mutated_root_fn()) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + tmp: R.Tensor((32, 32), "float32") = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_simple_auto_add_var(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.add(root_fn.params[0], is_dfvar=False) + + assert isinstance(rwt.mutated_dfb().bindings[-1].var, Var) + + assert_immutability(rwt, dfb, root_fn) + + +def test_simple_auto_add_dfvar(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.add(root_fn.params[0], is_dfvar=True) + + assert isinstance(rwt.mutated_dfb().bindings[-1].var, DataflowVar) + + # immutatbility + assert_immutability(rwt, dfb, root_fn) + + +def test_simple_remove_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + unused = lv0 + R.output(lv0) + return lv0 + + root_fn = IdentityUnused["main"] + dfb = root_fn.body.blocks[0] + + n2binding = name_to_binding(IdentityUnused["main"]) + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_unused(n2binding["unused"][0].var) + + assert_immutability(rwt, dfb, root_fn) + + # check "unused" removed + assert "unused" not in name_to_binding(rwt.mutated_root_fn()) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_remove_unused_undef(): + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + with pytest.raises(TVMError): + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_unused(Var("whatever")) + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_unused(Var("whatever"), allow_undef=True) + + assert root_fn == rwt.mutated_root_fn() + + +def test_simple_rm_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + unused0 = lv0 + unused1 = lv0 + R.output(lv0) + return lv0 + + root_fn = IdentityUnused["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_all_unused() + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +@tvm.script.ir_module +class DeadDFBlock: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + with R.dataflow(): + lv0 = x + R.output(lv0) + return x + + +def test_empty_dfb_after_removal(): + root_fn = DeadDFBlock["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_unused(DeadDFBlock["main"].body.blocks[0].bindings[0].var) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + return x + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_empty_dfb_after_all_removal(): + dfb = DeadDFBlock["main"].body.blocks[0] + root_fn = DeadDFBlock["main"] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_all_unused() + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + return x + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_chained_rm_all_unused(): + @tvm.script.ir_module + class IdentityChainedUnused: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) + unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32")) + R.output(lv0) + return lv0 + + root_fn = IdentityChainedUnused["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_all_unused() + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_simple_replace_all_uses(): + @tvm.script.ir_module + class Lv0To1: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + # lv0 => lv1 + # / \ + # lv2 lv3 + # \ / + # lv4 + with R.dataflow(): + lv0: R.Tensor((32, 32), "float32") = R.call_tir( + "my_relu", (x,), R.Tensor((32, 32), dtype="float32") + ) + lv1: R.Tensor((32, 32), "float32") = R.call_tir( + "my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32") + ) + lv2: R.Tensor((32, 32), "float32") = R.call_tir( + "my_add", (x, lv0), R.Tensor((32, 32), dtype="float32") + ) + lv3: R.Tensor((32, 32), "float32") = R.call_tir( + "my_mul", (x, lv0), R.Tensor((32, 32), dtype="float32") + ) + lv4: R.Tensor((32, 32), "float32") = R.call_tir( + "my_whatever", (lv2, lv3), R.Tensor((32, 32), dtype="float32") + ) + R.output(lv4) + return lv4 + + root_fn = Lv0To1["main"] + dfb = root_fn.body.blocks[0] + + n2binding = name_to_binding(root_fn) + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.replace_all_uses(n2binding["lv0"][0].var, n2binding["lv1"][0].var) + rwt.remove_unused(n2binding["lv0"][0].var) + + assert_immutability(rwt, dfb, root_fn) + + n2binding_after = name_to_binding(rwt.mutated_root_fn()) + assert "lv0" not in n2binding_after + + +def test_simple_module_update(): + @tvm.script.ir_module + class Identity: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + root_fn = Identity["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.add(name="tmp", expr=root_fn.params[0], is_dfvar=True) + + new_ir = rwt.mutate_irmodule(Identity) + + # immutatbility + assert new_ir != Identity + assert 2 == len(new_ir["main"].body.blocks[0].bindings) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + tmp: R.Tensor((32, 32), "float32") = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(new_ir, GroundTruth) + + +if __name__ == "__main__": + tvm.testing.main() From 63166441e38909d4529e584315cf12afb12f9f95 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 18 Feb 2023 22:19:54 -0500 Subject: [PATCH 43/81] [Unity][Pass] FuseOps FuseTIR fixes (#14044) This PR fixes two bugs of FuseOps and FuseTIR: It fixes FuseOps who only rewrites the "main" function of the IRModule. After the fix, FuseOps now goes through each non-primitive Relax function. Test cases for both FuseOps and FuseTIR sides are added so ensure that both of the two passes work for cases of multiple Relax functions. It also fixes FuseOps and FuseTIR who did not take "call_dps_packed" style "call_tir" into account. The previous behavior will directly downcast the first argument of "call_tir" to GlobalVar, which is not right when the "call_tir" is in "call_dps_packed" stype and the first argument is a PackedFunc. With this fix, FuseOps and FuseTIR will skip such "call_tir"s. Tests for both CallTIR and CallOps are added accordingly. --- src/relax/transform/fuse_ops.cc | 54 +++---- src/relax/transform/fuse_tir.cc | 15 +- tests/python/relax/test_transform_fuse_ops.py | 81 +++++++++- tests/python/relax/test_transform_fuse_tir.py | 141 +++++++++++++++++- 4 files changed, 252 insertions(+), 39 deletions(-) diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index f3559b72da3f..0a0209bb8769 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -100,11 +100,15 @@ class GraphCreator : public ExprVisitor { * \return The created IndexedForwardGraph */ static IndexedForwardGraph Create(IRModule mod, support::Arena* arena) { - // Since cross-function call is not supported yet, FuseOps only serves the entry function, whose - // name is "main". - auto relax_func = Downcast(mod->Lookup("main")); GraphCreator creator(mod, arena); - creator(relax_func); + for (const auto& it : mod->functions) { + // Only visit Relax function without attr kPrimitive. + const auto* func = it.second.as(); + if (func == nullptr || func->HasNonzeroAttr(attr::kPrimitive)) { + continue; + } + creator(GetRef(func)); + } // The algorithm of the graph creator ensures that each created node will be added to the // post-dfs order and will be set its op pattern. Thus we check whether all these containers @@ -178,25 +182,26 @@ class GraphCreator : public ExprVisitor { // recurse into the call expression. const auto* op = call->op.as(); if (op == call_tir_op_.get()) { - const GlobalVar& global_var = Downcast(call->args[0]); - tir::PrimFunc func = Downcast(mod_->Lookup(global_var)); + // Skip ExternFunc for call_dps_packed. + if (const auto* global_var = call->args[0].as()) { + tir::PrimFunc func = Downcast(mod_->Lookup(GetRef(global_var))); - // Override args for call_tir - args = Downcast(call->args[1])->fields; + // Override args for call_tir + args = Downcast(call->args[1])->fields; - // TODO(tvm-team): handle the shape argument (args[3]) - Optional opt_pattern = func->GetAttr("op_pattern"); - if (opt_pattern.defined()) { - pattern = static_cast(Downcast(opt_pattern)->value); - } else { - pattern = OpPatternKind::kOpaque; + Optional opt_pattern = func->GetAttr("op_pattern"); + if (opt_pattern.defined()) { + pattern = static_cast(Downcast(opt_pattern)->value); + } else { + pattern = OpPatternKind::kOpaque; + } } } // The pattern of the current binding variable node is set to the pattern of this operator. SetNodePattern(binding_var_node, pattern); // Visit all call args for (const Expr& arg : args) { - ICHECK(IsLeaf(arg)); + ICHECK(IsLeafOrTuple(arg)); VisitLeaf(arg, binding_var_node, pattern); } } @@ -226,6 +231,10 @@ class GraphCreator : public ExprVisitor { void VisitLeaf(const Expr& leaf_expr, IndexedForwardGraph::Node* binding_var_node, const OpPatternKind& pattern) { ICHECK_NOTNULL(binding_var_node); + if (!leaf_expr->IsInstance()) { + // Skip GlobalVar, ExternFunc, OpNode. + return; + } // Recursive visit if it's Tuple if (const auto* tuple = leaf_expr.as()) { @@ -253,21 +262,6 @@ class GraphCreator : public ExprVisitor { /********** Helper Functions **********/ - /*! - * \brief Check whether the expression is a leaf expression - * \param expr The expression to be checked - * \return Whether the expression is a leaf expression - * \note In order to avoid too much refactor, this method is a simple copy-paste of the is-leaf - * check in "block_builder.cc". And it should be refactored in the future. - * \sa src/relax/ir/block_builder.cc - */ - static bool IsLeaf(const Expr& expr) { - // NOTE: Tuples are treated as leaf nodes for ergonomics - return expr.as() || expr.as() || expr.as() || - expr.as() || expr.as() || expr.as() || - expr.as(); - } - /*! * \brief Create a graph node corresponding to the input key * \param key The object which is used to create the graph node diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index fa5c296d278e..925f09d85d34 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -670,14 +670,15 @@ class TIRFuseMutator : public ExprMutator { } } else if (call->op == call_tir_op_) { // Case 2. It is a call_tir, re-emit the PrimFunc. - GlobalVar gv = Downcast(call->args[0]); - tir::PrimFunc func = Downcast(mod_->Lookup(gv)); - GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint); - return Call(call->op, {new_gv, call->args[1]}, call->attrs, call->sinfo_args, call->span); - } else { - // Case 3. CallNode in other types. Leave it as it is. - return call; + if (const auto* gv = call->args[0].as()) { + tir::PrimFunc func = Downcast(mod_->Lookup(GetRef(gv))); + GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint); + return Call(call->op, {new_gv, call->args[1]}, call->attrs, call->sinfo_args, call->span); + } } + + // Case 3. CallNode in other types. Leave it as it is. + return call; } /********** Helper Functions **********/ diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index 1a228bb268fa..6fad4f8165c1 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -18,7 +18,7 @@ import tvm import tvm.testing from tvm import relax, topi -from tvm.script import relax as R +from tvm.script import ir as I, relax as R def _check(mod_actual, mod_expected): @@ -755,5 +755,84 @@ def expected(): _check(before(), expected()) +def test_multiple_relax_functions(): + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("func1", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + + x = relax.Var("x", R.Tensor([20, 10], "float32")) + with bb.function("func2", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze") + + x = relax.Var("x", R.Tensor([20, 10], "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + with bb.function("fused_add1_exp1_squeeze1", [x, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + fused_add1_exp1_squeeze1 = bb.get().get_global_var("fused_add1_exp1_squeeze1") + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("func1", [x]): + with bb.dataflow(): + gv = bb.emit_output( + relax.Call(fused_add_exp_squeeze, [x, relax.const(1, "float32")]) + ) + bb.emit_func_output(gv) + + x = relax.Var("x", R.Tensor([20, 10], "float32")) + with bb.function("func2", [x]): + with bb.dataflow(): + gv = bb.emit_output( + relax.Call(fused_add1_exp1_squeeze1, [x, relax.const(1, "float32")]) + ) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_skip_call_dps_packed(): + @I.ir_module + class Module: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + with R.dataflow(): + y = R.call_tir("func_packed_dps", x, R.Tensor((2, 3), "float32")) + R.output(y) + return y + + # FuseOps should does no change to it. + _check(Module, Module) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 91edab2bbb98..c2784edec733 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -18,7 +18,7 @@ import tvm import tvm.testing from tvm import relax, topi -from tvm.script import relax as R +from tvm.script import ir as I, relax as R, tir as T def _check(mod_before, mod_expected): @@ -559,5 +559,144 @@ def fused_argmax_add(x, offset): _check(before(), expected()) +def test_multiple_relax_functions(): + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze") + + x = relax.Var("x", R.Tensor([20, 10], "float32")) + p0 = relax.Var("p0", R.Tensor((), "float32")) + with bb.function("fused_add1_exp1_squeeze1", [x, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + fused_add1_exp1_squeeze1 = bb.get().get_global_var("fused_add1_exp1_squeeze1") + + x = relax.Var("x", R.Tensor([10, 20], "float32")) + with bb.function("func1", [x]): + with bb.dataflow(): + gv = bb.emit_output( + relax.Call(fused_add_exp_squeeze, [x, relax.const(1, "float32")]) + ) + bb.emit_func_output(gv) + + x = relax.Var("x", R.Tensor([20, 10], "float32")) + with bb.function("func2", [x]): + with bb.dataflow(): + gv = bb.emit_output( + relax.Call(fused_add1_exp1_squeeze1, [x, relax.const(1, "float32")]) + ) + bb.emit_func_output(gv) + + return bb.get() + + @I.ir_module + class Expected: + @R.function + def func1(x: R.Tensor((10, 20), dtype="float32")) -> R.Tensor((10, 20), dtype="float32"): + with R.dataflow(): + gv2 = R.call_tir( + fused_add_exp_squeeze, + (x, R.const(1, "float32")), + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + R.output(gv2) + return gv2 + + @R.function + def func2(x: R.Tensor((20, 10), dtype="float32")) -> R.Tensor((20, 10), dtype="float32"): + with R.dataflow(): + gv3 = R.call_tir( + fused_add1_exp1_squeeze1, + (x, R.const(1, "float32")), + out_sinfo=R.Tensor((20, 10), dtype="float32"), + ) + R.output(gv3) + return gv3 + + @T.prim_func + def fused_add1_exp1_squeeze1( + x: T.Buffer((T.int64(20), T.int64(10)), "float32"), + p0: T.Buffer((), "float32"), + T_squeeze: T.Buffer((T.int64(20), T.int64(10)), "float32"), + ): + T.func_attr({"tir.noalias": True}) + T_add = T.alloc_buffer((T.int64(20), T.int64(10))) + compute = T.alloc_buffer((T.int64(20), T.int64(10))) + for ax0, ax1 in T.grid(T.int64(20), T.int64(10)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(x[v_ax0, v_ax1], p0[()]) + T.writes(T_add[v_ax0, v_ax1]) + T_add[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] + for i0, i1 in T.grid(T.int64(20), T.int64(10)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_add[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.exp(T_add[v_i0, v_i1]) + for ax0, ax1 in T.grid(T.int64(20), T.int64(10)): + with T.block("T_squeeze"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(compute[v_ax0, v_ax1]) + T.writes(T_squeeze[v_ax0, v_ax1]) + T_squeeze[v_ax0, v_ax1] = compute[v_ax0, v_ax1] + + @T.prim_func + def fused_add_exp_squeeze( + x: T.Buffer((T.int64(10), T.int64(20)), "float32"), + p0: T.Buffer((), "float32"), + T_squeeze: T.Buffer((T.int64(10), T.int64(20)), "float32"), + ): + T.func_attr({"tir.noalias": True}) + T_add = T.alloc_buffer((T.int64(10), T.int64(20))) + compute = T.alloc_buffer((T.int64(10), T.int64(20))) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(x[v_ax0, v_ax1], p0[()]) + T.writes(T_add[v_ax0, v_ax1]) + T_add[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] + for i0, i1 in T.grid(T.int64(10), T.int64(20)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_add[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.exp(T_add[v_i0, v_i1]) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_squeeze"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(compute[v_ax0, v_ax1]) + T.writes(T_squeeze[v_ax0, v_ax1]) + T_squeeze[v_ax0, v_ax1] = compute[v_ax0, v_ax1] + + _check(before(), Expected) + + +def test_skip_call_dps_packed(): + @I.ir_module + class Module: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + with R.dataflow(): + y = R.call_tir("func_packed_dps", x, R.Tensor((2, 3), "float32")) + R.output(y) + return y + + # FuseTIR should does no change to it. + _check(Module, Module) + + if __name__ == "__main__": tvm.testing.main() From 166bb92fd3660cf2185e0b2d4b1b6d2394f04d4e Mon Sep 17 00:00:00 2001 From: Chaosfan <1713833595@qq.com> Date: Sun, 19 Feb 2023 12:29:53 +0800 Subject: [PATCH 44/81] [Unity][TVMScript] Overload `__neg__` for relax expr (#14045) This PR overloads `__neg__` given that `relax.negative` is now supported. Besides, it adds `test_op_misc.py` and brings tests for calling overloaded operators. --- python/tvm/relax/expr.py | 2 +- tests/python/relax/test_op_misc.py | 98 +++++++++++++++++++++ tests/python/relax/test_tvmscript_parser.py | 49 +++++++++++ 3 files changed, 148 insertions(+), 1 deletion(-) create mode 100644 tests/python/relax/test_op_misc.py diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index f1cf815d8ea5..a20181e6fc42 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -135,7 +135,7 @@ def astype(self, dtype: Union[str, DataType]) -> "ExprWithOp": return _op_ffi_api.astype(self, dtype) # type: ignore def __neg__(self) -> "ExprWithOp": - raise ValueError("relax.negative is not supported yet.") + return _op_ffi_api.negative(self) # type: ignore def __lt__(self, other: Expr) -> "ExprWithOp": return _binary_op_helper(self, other, _op_ffi_api.less) # type: ignore diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py new file mode 100644 index 000000000000..65772baadfdf --- /dev/null +++ b/tests/python/relax/test_op_misc.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm import relax as rx +from tvm.script import relax as R +from tvm.script import tir as T + + +@tvm.register_func("test.op.identity", override=True) +def identity_packed(a): + return tvm.nd.array(a.asnumpy()) + + +@T.prim_func +def identity_tir(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [54, 96]) + B = T.match_buffer(b, [54, 96]) + + for i, j in T.grid(54, 96): + with T.block("compute"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + + +def test_call_tir() -> None: + v0 = rx.Var("v0", R.Tensor([54, 96], "float32")) + v1 = rx.call_tir(rx.extern("test.op.identity"), [v0], R.Tensor((54, 96), "float32")) + v1 = rx.call_tir(identity_tir, [v0], R.Tensor((54, 96), "float32")) + + +def test_implicit_op(): + m, n = tvm.tir.Var("m", "int64"), tvm.tir.Var("n", "int64") + x = rx.Var("x", R.Tensor([m, n], "float32")) + y = rx.Var("y", R.Tensor([m, n], "float32")) + + def _check_call(expr, op_name: str): + assert isinstance(expr, rx.Call) + if not op_name.startswith("relax."): + op_name = "relax." + op_name + op = tvm.ir.Op.get(op_name) + assert expr.op == op + + # Comparison operators + _check_call(x > y, "greater") + _check_call(x >= y, "greater_equal") + _check_call(x < y, "less") + _check_call(x <= y, "less_equal") + + # Arithmetic operators + _check_call(-x, "negative") + _check_call(x + y, "add") + _check_call(x - y, "subtract") + _check_call(x * y, "multiply") + _check_call(x / y, "divide") + _check_call(x // y, "floor_divide") + # _check_call(x % y, "mod") <= relax.mod is not implemented yet + + # Cast + _check_call(x.astype("float32"), "astype") + + # Call + call_expr = x(y)(y) + assert isinstance(call_expr.op, rx.Call) + assert call_expr.op.op == x + + # GetTupleItem + ## Eager get item for tuple + tuple_expr = rx.Tuple((x, y)) + assert tuple_expr[0] == x + assert tuple_expr[1] == y + + ## Eager get item for ShapeExpr + shape_expr = rx.ShapeExpr((1, 2)) + assert shape_expr[0] == 1 + assert shape_expr[1] == 2 + + ## Create TupleGetItem for other expr + assert isinstance(x[0], rx.TupleGetItem) + assert isinstance(x[1][0], rx.TupleGetItem) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 8df125ac72da..b458b290ec13 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -972,6 +972,55 @@ def foo(x: R.Tensor(("m + 1", "m * 2"), "float32")): # name 'm' is not defined return z +def test_arith_operators(): + @R.function + def foo(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")): + a0 = -x + a1 = x + y + a2 = x - y + a3 = x * y + a4 = x / y + a5 = x // y + + c0 = x > y + c1 = x < y + c2 = x >= y + c3 = x <= y + + tuple_expr = ((x, x), y) + t0 = tuple_expr[0] + t1 = tuple_expr[1] + t2 = tuple_expr[0][0] # <= Will normalize to two bindings + return a0, a1, a2, a3, a4, a5, c0, c1, c2, c3, t0, t1, t2 + + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = relax.Var("x", relax.TensorStructInfo([m, n], "float32")) + y = relax.Var("y", relax.TensorStructInfo([m, n], "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + a0 = bb.emit(relax.op.negative(x)) + a1 = bb.emit(relax.op.add(x, y)) + a2 = bb.emit(relax.op.subtract(x, y)) + a3 = bb.emit(relax.op.multiply(x, y)) + a4 = bb.emit(relax.op.divide(x, y)) + a5 = bb.emit(relax.op.floor_divide(x, y)) + + c0 = bb.emit(relax.op.greater(x, y)) + c1 = bb.emit(relax.op.less(x, y)) + c2 = bb.emit(relax.op.greater_equal(x, y)) + c3 = bb.emit(relax.op.less_equal(x, y)) + + tuple_expr = bb.emit(relax.Tuple((relax.Tuple((x, x)), y))) + t0 = bb.emit(relax.TupleGetItem(tuple_expr, 0)) + t1 = bb.emit(relax.TupleGetItem(tuple_expr, 1)) + tmp = bb.emit(relax.TupleGetItem(tuple_expr, 0)) + t2 = bb.emit(relax.TupleGetItem(tmp, 0)) + bb.emit_func_output(relax.Tuple((a0, a1, a2, a3, a4, a5, c0, c1, c2, c3, t0, t1, t2))) + + _check(foo, bb.get()["foo"]) + + # TODO(relax-team): enable this when vm ops are ready @pytest.mark.xfail def test_vm_ops(): From 6f4ca6b29c2f1f9a48a85c14078ac946644e7667 Mon Sep 17 00:00:00 2001 From: masahi Date: Mon, 20 Feb 2023 16:47:56 +0900 Subject: [PATCH 45/81] [Unity][VM] Add per-op profiling support (#14053) Adds per-op profiling support to Relax VM, in a way similar to how Relay VM is instrumented via the common profiling infra in the runtime. Profiling over RPC is supported. Example output: ``` Name Duration (us) Percent Device Count Argument Shapes conv2d1 705,779.00 51.22 hexagon0 1 float32[1, 64, 56, 56], float32[1, 64, 54, 54] conv2d 669,589.00 48.60 hexagon0 1 float32[1, 64, 56, 56], float32[1, 64, 56, 56] relu 683.00 0.05 hexagon0 1 float32[1, 64, 56, 56], float32[1, 64, 56, 56] relu1 679.00 0.05 hexagon0 1 float32[1, 64, 54, 54], float32[1, 64, 54, 54] vm.builtin.check_tensor_info 28.00 0.00 hexagon0 1 float32[1, 64, 56, 56] vm.builtin.match_shape 25.00 0.00 hexagon0 1 float32[1, 64, 56, 56] ---------- Sum 1,376,783.00 99.93 6 Total 0.00 cpu0 1 Total 1,377,809.00 hexagon0 1 Configuration ------------- Number of threads: 4 Executor: VM ``` The original PR: https://github.com/tlc-pack/relax/pull/422 --- include/tvm/runtime/relax_vm/vm.h | 5 + python/tvm/relax/vm.py | 33 ++++++- src/runtime/relax_vm/executable.cc | 7 ++ src/runtime/relax_vm/vm.cc | 129 +++++++++++++++++++++++- tests/python/relax/test_vm_profiler.py | 130 +++++++++++++++++++++++++ 5 files changed, 295 insertions(+), 9 deletions(-) create mode 100644 tests/python/relax/test_vm_profiler.py diff --git a/include/tvm/runtime/relax_vm/vm.h b/include/tvm/runtime/relax_vm/vm.h index cfe388090456..d39de74f2dab 100644 --- a/include/tvm/runtime/relax_vm/vm.h +++ b/include/tvm/runtime/relax_vm/vm.h @@ -120,6 +120,11 @@ class VirtualMachine : public runtime::ModuleNode { * \return Created VM */ static ObjectPtr Create(); + /*! + * \brief Create an instance of VM with the profiling feature enabled. + * \return Created VM + */ + static ObjectPtr CreateProfiler(); /*! * \brief Helper function for vm closure functions to get the context ptr * \param arg The argument value. diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py index 2cf1250690a0..0594d86f2a82 100644 --- a/python/tvm/relax/vm.py +++ b/python/tvm/relax/vm.py @@ -25,6 +25,7 @@ from tvm.ir.module import IRModule from tvm.runtime import Device, Module, PackedFunc, container from tvm.runtime.object import Object +from tvm.runtime.profiling import Report from tvm.tir.function import PrimFunc from . import _ffi_api from ..rpc.base import RPC_SESS_MASK @@ -63,6 +64,7 @@ def __init__( exec: Union[Executable, Module], device: Union[Device, List[Device]], memory_cfg: Optional[Union[str, Dict[Device, str]]] = None, + profile: bool = False, ) -> None: """ Construct a VirtualMachine wrapper object. @@ -82,12 +84,12 @@ def __init__( allocator type. If memory_cfg is a dict, each device uses the allocator type specified in the dict, or pooled allocator if not specified in the dict. + + profile : Optional[bool] + Whether or not to enable profiling. """ - self.module = ( - exec.mod["vm_load_executable"]() - if isinstance(exec, Executable) - else exec["vm_load_executable"]() - ) + load_exec = "vm_profiler_load_executable" if profile else "vm_load_executable" + self.module = exec.mod[load_exec]() if isinstance(exec, Executable) else exec[load_exec]() self._invoke_closure = self.module["invoke_closure"] self._save_function = self.module["save_function"] self._set_input = self.module["set_input"] @@ -449,6 +451,27 @@ def time_evaluator( f_preproc=f_preproc, ) + def profile(self, func_name: str, *args): + """Profile a function call. + Parameters + ---------- + func_name : str + The name of the function. + args: List of NDArray or other objects supported by PackedFunc. + The arguments to the function. + Returns + ------- + report: tvm.runtime.profiling.Report + The formatted profiling result, showing per-op timing measurements. + """ + cargs: List[Any] = [] + + for arg in args: + self._convert(arg, cargs) + + report_json = self.module["profile"](func_name, *cargs) + return Report.from_json(report_json) + def _vmcodegen( builder: "relax.ExecBuilder", diff --git a/src/runtime/relax_vm/executable.cc b/src/runtime/relax_vm/executable.cc index b7915d7978aa..2090a3b25413 100644 --- a/src/runtime/relax_vm/executable.cc +++ b/src/runtime/relax_vm/executable.cc @@ -67,6 +67,13 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtrLoadExecutable(GetObjectPtr(this)); *rv = Module(vm); }); + } else if (name == "vm_profiler_load_executable") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ObjectPtr vm = VirtualMachine::CreateProfiler(); + ICHECK(sptr_to_self.get() == this); + vm->LoadExecutable(GetObjectPtr(this)); + *rv = Module(vm); + }); } return nullptr; } diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc index 3cf65faaa81a..3b952c1ff597 100644 --- a/src/runtime/relax_vm/vm.cc +++ b/src/runtime/relax_vm/vm.cc @@ -23,8 +23,11 @@ #include #include +#include #include +#include + namespace tvm { namespace runtime { namespace relax_vm { @@ -177,7 +180,7 @@ class VirtualMachineImpl : public VirtualMachine { void Init(const std::vector& devices, const std::vector& alloc_types) final; - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override; VMClosure GetClosure(const String& func_name) final; @@ -315,11 +318,29 @@ class VirtualMachineImpl : public VirtualMachine { * \param curr_frame The current frame. * \param inst The call instruction. */ - inline void RunInstrCall(VMFrame* curr_frame, Instruction inst); + virtual void RunInstrCall(VMFrame* curr_frame, Instruction inst); /*! \brief Run VM dispatch loop. */ void RunLoop(); + /*! + * \brief Retrieve the name of the function identified by the given index. + * \param idx The index into the VM executable function table. + * \return The name of the function. + */ + const std::string& GetFuncName(int idx) { return exec_->func_table[idx].name; } + + /*! + * \brief Retrieve the inputs for a function. + * \param func_name The name of the function. + * \return The function inputs. + */ + const std::vector& GetInputsFor(const std::string& func_name) { + return inputs_[func_name]; + } + + void ClearInputsFor(const std::string& func_name) { inputs_.erase(func_name); } + private: //-------------------------------------------------------- // Internal states for execution. @@ -519,7 +540,7 @@ void VirtualMachineImpl::SetInput(std::string func_name, TVMArgs args, int offse int index = i - offset; func_args[index] = ConvertArgToDevice(args[i], devices[0]); } - inputs_.emplace(func_name, func_args); + inputs_[func_name] = func_args; } else { LOG(FATAL) << "ValueError: Unknown function: " << func_name; } @@ -706,7 +727,7 @@ void VirtualMachineImpl::InitFuncPool() { } void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) { - DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << exec_->func_table[instr.func_idx].name; + DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << GetFuncName(instr.func_idx); // Use the call arg stack from the current frame to increase reuse // and avoid re-allocation @@ -806,6 +827,106 @@ void VirtualMachineImpl::RunLoop() { ObjectPtr VirtualMachine::Create() { return make_object(); } +/*! + * \brief An extension of VirtualMachineImpl to support per-op profiling + * It overrides RunInstrCall to add instrumentations around it. + */ +class VirtualMachineProfiler : public VirtualMachineImpl { + public: + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override { + if (name == "profile") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string f_name = args[0]; + VMClosure clo = this->GetClosure(f_name); + + std::vector devices; + for (auto dev : this->devices) { + if (dev.device_type > 0) { + devices.push_back(dev); + } + } + + prof_ = profiling::Profiler(devices, {}, {{String("Executor"), String("VM")}}); + + auto inputs = GetInputsFor(f_name); + + bool clear_inputs = false; + if (inputs.size() == 0) { + ICHECK(args.num_args > 1) << "No input is provided"; + TVMArgs f_args(args.values + 1, args.type_codes + 1, args.num_args - 1); + SetInput(f_name, args, 1); + inputs = GetInputsFor(f_name); + clear_inputs = true; + } else { + ICHECK_EQ(args.num_args, 1) << "Inputs are already provided by set_input."; + } + + // warmup + this->InvokeClosureInternal(clo, inputs); + + prof_->Start(); + this->InvokeClosureInternal(clo, inputs); + prof_->Stop(); + + // Return the report as json, since profiling::Report object is not supported by RPC + std::string report_json = prof_->Report()->AsJSON(); + *rv = report_json; + + prof_ = std::nullopt; // releases hardware counters + if (clear_inputs) { + // SetInput modifies the internal states of VM. Undo the change after profiling. + ClearInputsFor(f_name); + } + }); + } else { + return VirtualMachineImpl::GetFunction(name, sptr_to_self); + } + } + + protected: + void RunInstrCall(VMFrame* curr_frame, Instruction inst) override { + bool profiling = false; + if (prof_ && prof_->IsRunning()) { + auto f_name = GetFuncName(inst.func_idx); + std::optional dev; + std::vector arrs; + for (Index i = 0; i < inst.num_args; ++i) { + Instruction::Arg arg = inst.args[i]; + if (arg.kind() == Instruction::ArgKind::kRegister) { + auto reg = ReadRegister(curr_frame, arg.value()); + if (reg.type_code() == kTVMNDArrayHandle) { + NDArray arr = reg; + dev = arr->device; + arrs.push_back(arr); + } + } + } + + std::unordered_map metrics; + metrics["Argument Shapes"] = profiling::ShapeString(arrs); + + // If a sutiable device is found, enable profiling. + if (dev) { + profiling = true; + prof_->StartCall(f_name, *dev, metrics); + } + } + + VirtualMachineImpl::RunInstrCall(curr_frame, inst); + + if (profiling) { + prof_->StopCall(); + } + } + + private: + std::optional prof_; +}; + +ObjectPtr VirtualMachine::CreateProfiler() { + return make_object(); +} + } // namespace relax_vm } // namespace runtime } // namespace tvm diff --git a/tests/python/relax/test_vm_profiler.py b/tests/python/relax/test_vm_profiler.py new file mode 100644 index 000000000000..90737cc9c980 --- /dev/null +++ b/tests/python/relax/test_vm_profiler.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +import tvm.testing + +from tvm import relax, rpc +from tvm.contrib import utils +from tvm.relax.testing import nn +from tvm.script import relax as R + + +def get_exec(data_shape): + builder = relax.BlockBuilder() + weight1_np = np.random.randn(64, 64).astype("float32") + weight2_np = np.random.randn(64, 64).astype("float32") + + with builder.function("main"): + model = nn.Sequential( + nn.Linear(data_shape[1], weight1_np.shape[0], bias=False), + nn.ReLU(), + nn.Linear(weight2_np.shape[0], weight2_np.shape[1], bias=False), + nn.ReLU(), + ) + data = nn.Placeholder(data_shape, name="data") + output = model(data) + params = [data] + model.parameters() + builder.emit_func_output(output, params=params) + + mod = builder.get() + + params = {"linear_weight": weight1_np, "linear_weight1": weight2_np} + mod = relax.transform.BindParams("main", params)(mod) + + target = "llvm" + return relax.vm.build(mod, target) + + +def test_conv2d_cpu(): + data_np = np.random.randn(1, 64).astype("float32") + ex = get_exec(data_np.shape) + + vm = relax.VirtualMachine(ex, tvm.cpu(), profile=True) + report = vm.profile("main", tvm.nd.array(data_np)) + print(report) + + assert "Duration" in str(report) + assert "matmul" in str(report) + + +def with_rpc(ex, f, data_np): + temp = utils.tempdir() + path = temp.relpath("vm_library.so") + ex.mod.export_library(path) + + server = rpc.Server("127.0.0.1") + remote = rpc.connect(server.host, server.port, session_timeout=10) + + remote.upload(path) + rexec = remote.load_module("vm_library.so") + + device = remote.cpu() + + vm = relax.vm.VirtualMachine(exec=rexec, device=device, profile=True) + data = tvm.nd.array(data_np, device) + + f(vm, data) + + +def test_rpc(): + data_np = np.random.randn(1, 64).astype("float32") + ex = get_exec(data_np.shape) + + def callback(vm, data): + vm.profile("main", data) + + vm.set_input("main", data) + report = vm.profile("main") + + assert "matmul" in str(report) + print(report) + + with_rpc(ex, callback, data_np) + + +def test_tuple(): + @tvm.script.ir_module + class NestedTuple: + @R.function + def main( + x: R.Tensor((16,), "float32") + ) -> R.Tuple( + R.Tuple( + R.Tensor((16,), "float32"), + R.Tuple( + R.Tensor((16,), "float32"), + ), + ), + R.Tensor((16,), "float32"), + ): + return ((x, (x,)), x) + + target = "llvm" + ex = relax.vm.build(NestedTuple, target) + + data_np = np.random.randn(16).astype("float32") + + def callback(vm, data): + report = vm.profile("main", data) + assert "vm.builtin.make_tuple" in str(report) + + with_rpc(ex, callback, data_np) + + +if __name__ == "__main__": + tvm.testing.main() From be1cc698d291095fe0e7bc708da2cc85821bd5f2 Mon Sep 17 00:00:00 2001 From: masahi Date: Mon, 20 Feb 2023 17:08:44 +0900 Subject: [PATCH 46/81] [Unity][BYOC] Add pattern-based partitioning pass (#14054) This adds a new pass, FuseOpsByPattern, which applies pattern matching to each function in the given module, and groups matched expressions into a new function. The end result is similar to FuseOps, but fusion is driven completely by the provided patterns. The implementation also reuses OperatorFusor used by FuseOps to create grouped functions from partitioned groups, further illustrating the similarity between the two passes. The new pass will serve the same role the MergeComposite pass plays in Relay BYOC - grouped functions are annotated with the "composite" attribute to denote what operations a given function consists of, and offloaded to external backends. But it can be also useful in non-BYOC settings, for example to support advanced fusion that the op-kind based one doesn't handle (fused MHA, conv2d / gemm + reduction fusion, etc). The original PR: https://github.com/tlc-pack/relax/pull/366 --- python/tvm/relax/transform/transform.py | 37 +- src/relax/transform/fuse_ops.cc | 199 ++++++++ .../test_transform_fuse_ops_by_pattern.py | 464 ++++++++++++++++++ 3 files changed, 699 insertions(+), 1 deletion(-) create mode 100644 tests/python/relax/test_transform_fuse_ops_by_pattern.py diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 1f14823b5a94..bf90ef0b0986 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -19,7 +19,7 @@ import functools import inspect import types -from typing import Callable, Dict, Union, Optional, List +from typing import Callable, Dict, Union, Optional, List, Tuple import numpy as np # type: ignore import tvm.ir from tvm.runtime import NDArray @@ -241,6 +241,41 @@ def FuseTIR() -> tvm.ir.transform.Pass: return _ffi_api.FuseTIR() # type: ignore +def FuseOpsByPattern( + patterns: List[Tuple], annotate_codegen: bool = False +) -> tvm.ir.transform.Pass: + """Apply pattern matching to each function in the given module, and group matched expressions + into a new function. + + The end result is similar to FuseOps, but fusion is driven completely by the provided patterns. + + Parameters + ---------- + patterns : List[Tuple[str, DFPattern]] + The patterns to detect. The order of the patterns determines the order of priority in which + they are matched. Higher-priority patterns should come earlier in the list. + The string is the name of the corresponding pattern. It becomes the value of the kComposite + attribute of a fused function after a successful matching. + + annotate_codegen : bool + If True, wrap each created composite function with another function, whose body consists + only of a call to the composite function, and annotate the outer function with "Codegen" + and "global_symbol" attributes. The "Codegen" attribute is set as the prefix of the + corresponding pattern name. For example, "dnnl" if the pattern name is "dnnl.conv2d_relu". + + This must be True if the created composite functions are intended to be offloaded to + an external backend without using the MergeCompositeFunctions pass. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for pattern-based fusion. + + """ + pattern_names, df_patterns = zip(*patterns) + return _ffi_api.FuseOpsByPattern(pattern_names, df_patterns, annotate_codegen) # type: ignore + + def LegalizeOps(customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None): """Legalize high-level operator calls in Relax functions to call_tir with corresponding low-level TIR PrimFuncs. diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 0a0209bb8769..3b78274cec58 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -28,12 +28,15 @@ */ #include +#include +#include #include #include #include #include #include +#include #include "../../relay/analysis/graph_partitioner.h" #include "../../support/arena.h" @@ -880,6 +883,188 @@ IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) { return OperatorFusor(mod, graph, groups, /*lift_constants*/ true).Transform(); } +IRModule MakeGroupedFunctions( + IRModule mod, const std::unordered_map& partition, + bool lift_constants) { + return OperatorFusor(mod, partition, lift_constants).Transform(); +} + +static Map GetBindingInverse(const Map& binding) { + Map value_to_bound_var; + for (const auto& [var, val] : binding) { + value_to_bound_var.Set(val, var); + } + return value_to_bound_var; +} + +/*! \brief Create a "partitioning", a map from interior / leaf expr to its representative group, + * based on the provided pattern. The result can be passed to OperatorFusor above to fuse operations + * in a group and create a grouped function. + */ +class PatternBasedPartitioner : ExprVisitor { + public: + using Group = GraphPartitioner::Group; + using GroupMap = OperatorFusor::GroupMap; + using ExprVisitor::VisitExpr_; + + static GroupMap Run(String pattern_name, DFPattern pattern, Expr expr, support::Arena* arena) { + PatternBasedPartitioner part(pattern_name, pattern, AnalyzeVar2Value(expr)); + // Initialize each expr to have its own group + PostOrderVisit( + expr, [arena, &part](const Expr& e) { part.group_map_[e.get()] = arena->make(); }); + part.VisitExpr(expr); + return part.group_map_; + } + + PatternBasedPartitioner(String pattern_name, DFPattern pattern, const Map& bindings) + : pat_name_(pattern_name), + pat_(pattern), + bindings_(bindings), + value_to_bound_var_(GetBindingInverse(bindings)) {} + + void VisitExpr_(const CallNode* call) override { + if (auto matches_opt = ExtractMatchedExpr(pat_, GetRef(call), bindings_)) { + // If a match is found, put all matching expressions into the same group. + // OperatorFusor also requires that the bound variable be in the same group as the RHS value. + // Since is_op(...) based pattern only matches against call nodes on the right hand side, + // we need to take care of groups corresponding to the LHS bound variables carefully. + + // In the example below, conv2d + relu pattern would match if the "call" variable in this + // function points to the relu op. We identify the group corresponding to "conv1", and make + // it the representative group for relu and conv2d on the RHS and also "lv" on the LHS. + + // with R.dataflow(): + // lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(...) + // conv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv) + + // parent_group corresponds to the group of "conv1" above. + auto parent_group = GetGroupForBoundVar(GetRef(call)); + ICHECK(parent_group); + parent_group->attrs.Set(attr::kComposite, pat_name_); + + for (const auto& [pat, match] : matches_opt.value()) { + ICHECK(group_map_.count(match.get())); + // Put all matching call nodes into the parent group. + if (pat->IsInstance() && match != GetRef(call)) { + AddToGroup(match, parent_group); + // Put the bound variable on the LHS into the same parent group. + AddToGroup(value_to_bound_var_[match], parent_group); + } + } + } + } + + private: + void AddToGroup(Expr e, Group* to) { + if (group_map_[e.get()] != to) { + --group_map_[e.get()]->num_nodes; + group_map_[e.get()]->parent = to; + ++to->num_nodes; + } + } + + Group* GetGroupForBoundVar(Expr e) { + ICHECK(value_to_bound_var_.count(e)); + auto bound_var = value_to_bound_var_[e]; + ICHECK(group_map_.count(bound_var.get())); + return group_map_[bound_var.get()]->FindRoot(); + } + + String pat_name_; + DFPattern pat_; + Map bindings_; + Map value_to_bound_var_; + GroupMap group_map_; +}; + +/*! + * \brief Wrap each created composite function with another function, whose body consists + * only of a call to the composite function, and annotate the outer function with kCodegen + * and kGlobalSymbol attributes. + */ +class CompositeFunctionAnnotator : public ExprMutator { + public: + explicit CompositeFunctionAnnotator(IRModule mod) : ExprMutator(mod) {} + using ExprMutator::VisitExpr_; + + IRModule Run() { + auto mod = builder_->GetContextIRModule(); + auto gvar = mod->GetGlobalVar("main"); + auto func = Downcast(mod->Lookup(gvar)); + auto new_func = + Function(func->params, VisitExpr(func->body), func->ret_struct_info, func->attrs); + builder_->UpdateFunction(gvar, new_func); + return builder_->GetContextIRModule(); + } + + Expr VisitExpr_(const CallNode* call_node) final { + if (auto const* gvar = call_node->op.as()) { + if (auto it = gvar_map_.find(gvar); it != gvar_map_.end()) { + return Call(it->second, call_node->args); + } + auto func = builder_->GetContextIRModule()->Lookup(GetRef(gvar)); + if (auto composite_name = func->GetAttr(attr::kComposite)) { + auto new_func = Downcast(VisitExpr(func)); + auto codegen_name = GetCodegenName(composite_name.value()); + auto gsymbol = gvar->name_hint + "_" + codegen_name; + new_func = WithAttrs(new_func, + {{attr::kCodegen, codegen_name}, {tvm::attr::kGlobalSymbol, gsymbol}}); + builder_->GetContextIRModule()->Remove(GetRef(gvar)); + auto new_gvar = builder_->AddFunction(new_func, gsymbol); + gvar_map_[gvar] = new_gvar; + return Call(new_gvar, call_node->args); + } + } + return ExprMutator::VisitExpr_(call_node); + } + + Expr VisitExpr_(const FunctionNode* func_node) final { + auto f_inner = ExprMutator::VisitExpr_(func_node); + auto composite_name = func_node->GetAttr(attr::kComposite); + ICHECK(composite_name); + + Array param_vars; + Array params; + + for (auto v : func_node->params) { + Var new_v(v->name_hint(), GetStructInfo(v)); + param_vars.push_back(new_v); + params.push_back(new_v); + } + + return Function(param_vars, Call(f_inner, params), func_node->ret_struct_info); + } + + private: + String GetCodegenName(const std::string& composite_name) { + auto delim_pos = composite_name.find("."); + ICHECK(delim_pos != std::string::npos) << "The pattern name for a composite function should " + "start with a compiler name followed by period."; + return composite_name.substr(0, delim_pos); + } + + /*! \brief A map from old global vars to their replacements. */ + std::unordered_map gvar_map_; +}; + +IRModule FuseOpsByPattern(const tvm::Array& pattern_names, + const tvm::Array& patterns, IRModule mod, + bool annotate_codegen) { + support::Arena arena; + for (size_t i = 0; i < pattern_names.size(); ++i) { + OperatorFusor::GroupMap group_map; + for (const auto& entry : mod->functions) { + auto map = PatternBasedPartitioner::Run(pattern_names[i], patterns[i], entry.second, &arena); + group_map.insert(map.begin(), map.end()); + } + mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ false); + } + if (annotate_codegen) { + return CompositeFunctionAnnotator(mod).Run(); + } + return mod; +} + namespace transform { Pass FuseOps(int fuse_opt_level) { @@ -897,6 +1082,20 @@ Pass FuseOps(int fuse_opt_level) { TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps); +Pass FuseOpsByPattern(const tvm::Array& pattern_names, + const tvm::Array& patterns, bool annotate_codegen) { + runtime::TypedPackedFunc pass_func = // + [=](IRModule m, PassContext pc) { + return relax::FuseOpsByPattern(pattern_names, patterns, m, annotate_codegen); + }; + return CreateModulePass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"FuseOpsByPattern", // + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.FuseOpsByPattern").set_body_typed(FuseOpsByPattern); + } // namespace transform } // namespace relax diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py new file mode 100644 index 000000000000..da5b92fb64e0 --- /dev/null +++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py @@ -0,0 +1,464 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import numpy as np + +import tvm + +from tvm import relax +from tvm.script import relax as R +from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern, is_op, wildcard + + +@tvm.script.ir_module +class Conv2dReLU: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight1: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = R.nn.relu(R.nn.conv2d(data, weight1, padding=(1, 1))) + R.output(conv1) + + return conv1 + + +@tvm.script.ir_module +class Conv2dReLU_composite_annotated: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + with R.dataflow(): + gv: R.Tensor( + (1, 64, 56, 56), dtype="float32" + ) = fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu_dnnl( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr( + {"Codegen": "dnnl", "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu_dnnl"} + ) + + @R.function + def gv1( + data2: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight12: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"}) + with R.dataflow(): + lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data2, + weight12, + padding=[1, 1, 1, 1], + ) + gv2: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv) + R.output(gv2) + return gv2 + + gv11: R.Tensor((1, 64, 56, 56), dtype="float32") = gv1(data1, weight11) + return gv11 + + +@tvm.script.ir_module +class Conv2dReLUx2: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight1: R.Tensor((64, 64, 3, 3), "float32"), + weight2: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = R.nn.relu(R.nn.conv2d(data, weight1, padding=(1, 1))) + conv2 = R.nn.relu(R.nn.conv2d(conv1, weight2, padding=(0, 0))) + R.output(conv2) + + return conv2 + + +@tvm.script.ir_module +class Conv2dReLUx2Partitioned: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d_relax_nn_relu( + data, weight1 + ) + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_conv2d_relax_nn_relu1( + lv, weight2 + ) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"}) + with R.dataflow(): + lv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data1, weight11, padding=[1, 1, 1, 1] + ) + gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv1) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu1( + conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"}) + with R.dataflow(): + lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + conv1, weight21, padding=[0, 0, 0, 0] + ) + gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv2) + R.output(gv2) + return gv2 + + +@tvm.script.ir_module +class Conv2dReLUx2Partitioned_only_conv2d: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d(data, weight1) + conv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv) + lv1: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_conv2d1(conv1, weight2) + conv2d: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1) + R.output(conv2d) + return conv2d + + @R.function + def fused_relax_nn_conv2d( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d"}) + with R.dataflow(): + gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data1, weight11, padding=[1, 1, 1, 1] + ) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d1( + conv11: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d"}) + with R.dataflow(): + gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + conv11, weight21, padding=[0, 0, 0, 0] + ) + R.output(gv1) + return gv1 + + +@tvm.script.ir_module +class Conv2dConv2dReLU: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight1: R.Tensor((64, 64, 3, 3), "float32"), + weight2: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = R.nn.conv2d(data, weight1, padding=(1, 1)) + conv2d = R.nn.relu(R.nn.conv2d(conv1, weight2, padding=(0, 0))) + R.output(conv2d) + + return conv2d + + +@tvm.script.ir_module +class Conv2dConv2dReLUPartitioned: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d(data, weight1) + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_conv2d_relax_nn_relu( + lv, weight2 + ) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu( + conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"}) + with R.dataflow(): + lv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + conv1, weight21, padding=[0, 0, 0, 0] + ) + gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_nn_conv2d( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d"}) + with R.dataflow(): + gv2: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data1, weight11, padding=[1, 1, 1, 1] + ) + R.output(gv2) + return gv2 + + +@tvm.script.ir_module +class BranchTupleOutput: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = R.nn.conv2d(data, weight) + relu1 = R.nn.relu(conv1) + gelu1 = R.nn.gelu(relu1) + gelu2 = R.nn.gelu(conv1) + out = relax.op.add(gelu1, gelu2) + R.output(out) + + return out + + +@tvm.script.ir_module +class BranchTupleOutputPartitioned: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 64, 54, 54), dtype="float32"), + R.Tensor((1, 64, 54, 54), dtype="float32"), + ) = fused_relax_nn_conv2d_relax_nn_relu(data, weight) + lv1: R.Tensor((1, 64, 54, 54), dtype="float32") = lv[1] # conv1 + lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv[0] # relu(conv1) + gelu1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv2) + gelu2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv1) + out: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(gelu1, gelu2) + R.output(out) + return out + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tuple( + R.Tensor((1, 64, 54, 54), dtype="float32"), R.Tensor((1, 64, 54, 54), dtype="float32") + ): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"}) + with R.dataflow(): + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(data1, weight1) + gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(gv) + R.output(gv, gv1) + return (gv1, gv) + + +@tvm.script.ir_module +class Branch: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = R.nn.conv2d(data, weight) + relu1 = R.nn.relu(conv1) + gelu1 = R.nn.gelu(conv1) + + out = relax.op.add(relu1, gelu1) + R.output(out) + + return out + + +@tvm.script.ir_module +class Conv2dx2: + @R.function + def main( + data: R.Tensor((16, 32, 32, 16), "float16"), + weight1: R.Tensor((16, 3, 3, 16), "float16"), + weight2: R.Tensor((16, 3, 3, 16), "float16"), + ): + with R.dataflow(): + conv1 = relax.op.nn.conv2d( + data, weight1, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI" + ) + conv2 = relax.op.nn.conv2d( + conv1, weight2, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI" + ) + R.output(conv2) + + return conv2 + + +@tvm.script.ir_module +class Conv2dx2_partitioned: + @R.function + def main( + data: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight1: R.Tensor((16, 3, 3, 16), dtype="float16"), + weight2: R.Tensor((16, 3, 3, 16), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 16), dtype="float16"): + with R.dataflow(): + lv: R.Tensor((16, 32, 32, 16), dtype="float16") = fused_relax_nn_conv2d_cutlass( + data, weight1 + ) + gv: R.Tensor((16, 32, 32, 16), dtype="float16") = fused_relax_nn_conv2d_cutlass( + lv, weight2 + ) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d_cutlass( + data: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight1: R.Tensor((16, 3, 3, 16), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 16), dtype="float16"): + R.func_attr({"Codegen": "cutlass", "global_symbol": "fused_relax_nn_conv2d_cutlass"}) + + @R.function + def gv( + data_1: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight1_1: R.Tensor((16, 3, 3, 16), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 16), dtype="float16"): + R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1}) + with R.dataflow(): + gv_1: R.Tensor((16, 32, 32, 16), dtype="float16") = R.nn.conv2d( + data_1, + weight1_1, + padding=[1, 1, 1, 1], + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + R.output(gv_1) + return gv_1 + + gv1: R.Tensor((16, 32, 32, 16), dtype="float16") = gv(data, weight1) + return gv1 + + +conv2d_pat = make_fused_bias_activation_pattern("relax.nn.conv2d", activation=None) +conv2d_relu_pat = make_fused_bias_activation_pattern("relax.nn.conv2d", activation="relax.nn.relu") + + +def check(mod, patterns, expected, annoatate_codegen=False): + partitioned = relax.transform.FuseOpsByPattern(patterns, annoatate_codegen)(mod) + tvm.ir.assert_structural_equal(partitioned, expected) + + +def test_partition_conv2d_relu(): + check(Conv2dReLUx2, [("dnnl.conv2d_relu", conv2d_relu_pat)], Conv2dReLUx2Partitioned) + + +def test_partition_multiple_patterns(): + check( + Conv2dConv2dReLU, + [("dnnl.conv2d_relu", conv2d_relu_pat), ("dnnl.conv2d", conv2d_pat)], + Conv2dConv2dReLUPartitioned, + ) + + +def test_partition_order(): + check( + Conv2dReLUx2, + [("dnnl.conv2d", conv2d_pat), ("dnnl.conv2d_relu", conv2d_relu_pat)], + Conv2dReLUx2Partitioned_only_conv2d, + ) + + +def test_branch_tuple_output(): + check(BranchTupleOutput, [("dnnl.conv2d_relu", conv2d_relu_pat)], BranchTupleOutputPartitioned) + + +def test_cyclic_dependency(): + conv_pat = make_fused_bias_activation_pattern("relax.nn.conv2d") + relu_pat = is_op("relax.nn.relu")(conv_pat) + add_pat = is_op("relax.add")(relu_pat, wildcard()) + + with pytest.raises(tvm.error.TVMError) as err: + relax.transform.FuseOpsByPattern([("compiler_A.conv2d_relu_add", add_pat)])(Branch) + + assert "A cyclic dependency detected" in str(err.value) + + +def test_bind_params(): + weight_np = np.random.randn(64, 64, 3, 3).astype("float32") + mod = tvm.transform.Sequential( + [ + relax.transform.BindParams("main", {"weight1": weight_np}), + relax.transform.FuseOpsByPattern([("dnnl.conv2d_relu", conv2d_relu_pat)]), + ] + )(Conv2dReLU) + + assert "fused_relax_nn_conv2d_relax_nn_relu" in [var.name_hint for var in mod.functions.keys()] + + for gvar, f in mod.functions.items(): + if gvar.name_hint == "fused_relax_nn_conv2d_relax_nn_relu": + conv2d = f.body.blocks[0].bindings[0].value + assert isinstance(conv2d.args[1], relax.Constant) + + +def test_annotate_codegen(): + check( + Conv2dReLU, + [("dnnl.conv2d_relu", conv2d_relu_pat)], + Conv2dReLU_composite_annotated, + annoatate_codegen=True, + ) + + +def test_multiple_calls_same_extern(): + pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None) + check(Conv2dx2, [("cutlass.conv2d", pat)], Conv2dx2_partitioned, annoatate_codegen=True) + + +if __name__ == "__main__": + pytest.main([__file__]) From 6d5f6f0e93b67080a741ff17ed6000d331194843 Mon Sep 17 00:00:00 2001 From: Chaosfan <1713833595@qq.com> Date: Tue, 21 Feb 2023 11:52:29 +0800 Subject: [PATCH 47/81] [Unity] Relax op: collapse sum (#14059) This PR brings high-level operators `relax.collapse_sum_like` and `relax.collapse_sum_to` which is useful when doing AD in Relax. To achieve this, it exposes the interface of `topi.collapse_sum`. Moreover, this PR also implements the legalization of these op and adds corresponding tests. --- python/tvm/relax/op/manipulate.py | 53 +++ .../transform/legalize_ops/manipulate.py | 5 + python/tvm/script/ir_builder/relax/ir.py | 4 + python/tvm/topi/reduction.py | 31 ++ src/relax/op/tensor/manipulate.cc | 130 +++++++ src/relax/op/tensor/manipulate.h | 21 ++ src/topi/reduction.cc | 4 + tests/python/relax/test_op_manipulate.py | 326 ++++++++++++++++++ .../test_transform_legalize_ops_manipulate.py | 103 ++++++ .../test_tvmscript_parser_op_manipulate.py | 33 ++ tests/python/topi/python/test_topi_reduce.py | 39 +++ 11 files changed, 749 insertions(+) diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index a46c62e1f12b..25bf5251912a 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -261,3 +261,56 @@ def squeeze(x: Expr, axis: Optional[Union[int, List[int]]] = None) -> Expr: if isinstance(axis, int): axis = [axis] return _ffi_api.squeeze(x, axis) # type: ignore + + +def collapse_sum_like(data: Expr, collapse_target: Expr) -> Expr: + """Return a summation of data to the shape of collapse_target. + + For details, please see relax.op.collapse_sum_to. + + Parameters + ---------- + data : relax.Expr + The input tensor. + + collapse_target : relax.Expr + The tensor whose shape is the shape to collapse to. + + Returns + ------- + result : relax.Expr + The result tensor after summation. + """ + return _ffi_api.collapse_sum_like(data, collapse_target) # type: ignore + + +def collapse_sum_to(data: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr: + """Return a summation of data to the given shape. + + collapse_sum_to is intended as the backward operator of tvm.relax.op.broadcast_to and + other broadcast operators in the automatic differentiation process. + + We expect that data is the result of broadcasting some tensor of the given shape in some + broadcast operation. Thus the given `shape` and `data.shape` must follow broadcast rules. + + During computation, all axes of `data.shape` and `shape` are checked from right to left. + For an axis, if it follows these rules, `data` will be summed over this axis: + - the axis exists in `data.shape` but not in `shape`, or + - the axis exists in `data.shape` and equals to 1 in `shape`. + + Parameters + ---------- + data : relax.Expr + The input tensor. + + shape : Union[Tuple[PrimExprLike], relax.Expr] + The shape to collapse to. + + Returns + ------- + result : relax.Expr + The result tensor of the given shape after summation. + """ + if isinstance(shape, (tuple, list)): + shape = ShapeExpr(shape) + return _ffi_api.collapse_sum_to(data, shape) # type: ignore diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 76e3e74bab9b..5b992eff1d07 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -37,6 +37,11 @@ def reshape_call_te(bb: BlockBuilder, call: Call): register_legalize("relax.broadcast_to", _reshape(topi.broadcast_to, "broadcast_to")) register_legalize("relax.reshape", _reshape(topi.reshape, "reshape")) +register_legalize( + "relax.collapse_sum_like", + _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True), +) +register_legalize("relax.collapse_sum_to", _reshape(topi.collapse_sum, "collapse_sum")) @register_legalize("relax.concat") diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 7298b8c6e54f..43918ce7ec83 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -45,6 +45,8 @@ call_tir, ceil, clip, + collapse_sum_like, + collapse_sum_to, concat, cos, cosh, @@ -485,6 +487,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "call_builtin_with_ctx", "ceil", "clip", + "collapse_sum_like", + "collapse_sum_to", "concat", "cos", "cosh", diff --git a/python/tvm/topi/reduction.py b/python/tvm/topi/reduction.py index 45d07af577a3..5045cb817457 100644 --- a/python/tvm/topi/reduction.py +++ b/python/tvm/topi/reduction.py @@ -248,3 +248,34 @@ def prod(data, axis=None, keepdims=False): ret : tvm.te.Tensor """ return cpp.prod(data, axis, keepdims) + + +def collapse_sum(data, target_shape): + """Return a summation of data to the given shape. + + collapse_sum is intended as the backward operator of topi broadcast operators in the automatic + differentiation process. + + We expect that data is the result of broadcasting some tensor of target_shape in some + broadcast operation. Thus target_shape and data.shape must follow broadcast rules. + + During computation, the axes of data.shape and target_shape are checked from right to left. + For every axis, if it either: + - exist in data but not in target_shape, or + - is larger than 1 in data and equals to 1 in target_shape, + data will be summed over this axis. + + Parameters + ---------- + data : tvm.te.Tensor + The input tensor. + + shape : Tuple[int] + The shape to collapse to. + + Returns + ------- + ret : tvm.te.Tensor + The result tensor after summation. + """ + return cpp.collapse_sum(data, target_shape) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 8ce2a541da53..e146a604affd 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -839,5 +839,135 @@ TVM_REGISTER_OP("relax.squeeze") .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoSqueeze); +void CheckCollapseShape(const Call& call, const BlockBuilder& ctx, + const Array& data_shape, const Array& target_shape) { + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + + int data_ndim = data_shape.size(); + int target_ndim = target_shape.size(); + + int data_ax = data_ndim - 1; + int target_ax = target_ndim - 1; + for (; data_ax >= 0; --data_ax) { + if (target_ax < 0) { + continue; + } + const PrimExpr& dim0 = data_shape[data_ax]; + const PrimExpr& dim1 = target_shape[target_ax]; + const auto* int_dim0 = dim0.as(); + const auto* int_dim1 = dim1.as(); + + if (analyzer->CanProveEqual(dim0, dim1) || (int_dim1 != nullptr && int_dim1->value == 1)) { + --target_ax; + } else if (int_dim0 && int_dim1 && int_dim0->value != int_dim1->value) { + ctx->ReportFatal(Diagnostic::Error(call) + << "In " << call->op << ", the data shape at dim " << data_ax << " is " + << dim0 << " and the target shape at dim " << target_ax << " is " << dim1 + << ", which do not match the rule of collapse sum."); + } else { + // Todo(relax-team): At this moment, enforcing MatchCast is fine. But we may need to revisit + // this requirement to reduce the workload of importers and better support dynamic shapes. + ctx->ReportFatal(Diagnostic::Error(call) + << call->op + << " fails to match the axes because of unknown dim or symbolic" + " shape. In this position the dim of data shape is " + << dim0 << " while the dim of target shape is " << dim1 + << ". If it is symbolic, consider use MatchCast first."); + } + } +} + +/* relax.collapse_sum_like */ +Expr collapse_sum_like(Expr data, Expr collapse_target) { + static const Op& op = Op::Get("relax.collapse_sum_like"); + return Call(op, {std::move(data), std::move(collapse_target)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.collapse_sum_like").set_body_typed(collapse_sum_like); + +StructInfo InferStructInfoCollapseSumLike(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo data_sinfo = input_sinfo[0]; + TensorStructInfo collapse_target_sinfo = input_sinfo[1]; + + DataType output_dtype = data_sinfo->dtype; + + Optional> data_shape_value; + if (data_sinfo->shape.defined()) { + data_shape_value = GetStructInfoAs(data_sinfo->shape.value())->values; + } + Optional> collapse_target_shape_value; + if (collapse_target_sinfo->shape.defined()) { + collapse_target_shape_value = + GetStructInfoAs(collapse_target_sinfo->shape.value())->values; + } + + if (data_shape_value.defined() && collapse_target_shape_value.defined()) { + CheckCollapseShape(call, ctx, data_shape_value.value(), collapse_target_shape_value.value()); + } + + if (collapse_target_sinfo->shape.defined()) { + return TensorStructInfo(collapse_target_sinfo->shape.value(), output_dtype); + } else { + return TensorStructInfo(output_dtype, collapse_target_sinfo->ndim); + } +} + +TVM_REGISTER_OP("relax.collapse_sum_like") + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("collapse_target", "Tensor", + "The tensor whose shape is the shape to collapse to.") + .set_attr("FInferStructInfo", InferStructInfoCollapseSumLike); + +/* relax.collapse_sum_to */ +Expr collapse_sum_to(Expr data, Expr shape) { + static const Op& op = Op::Get("relax.collapse_sum_to"); + return Call(op, {std::move(data), std::move(shape)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.collapse_sum_to").set_body_typed(collapse_sum_to); + +StructInfo InferStructInfoCollapseSumTo(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 2) { + ctx->ReportFatal(Diagnostic::Error(call) << "CollapseSumTo should have 2 arguments"); + } + + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* shape_sinfo = GetStructInfoAs(call->args[1]); + + if (data_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "CollapseSumTo requires the input data to be a Tensor. However, the given one is " + << call->args[0]->struct_info_->GetTypeKey()); + } + if (shape_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "CollapseSumTo requires the input shape to be a Shape. However, the given one is " + << call->args[1]->struct_info_->GetTypeKey()); + } + + DataType output_dtype = data_sinfo->dtype; + + Optional> data_shape_value; + if (data_sinfo->shape.defined()) { + data_shape_value = GetStructInfoAs(data_sinfo->shape.value())->values; + } + + if (data_shape_value.defined() && shape_sinfo->values.defined()) { + CheckCollapseShape(call, ctx, data_shape_value.value(), shape_sinfo->values.value()); + } + + return TensorStructInfo(/*shape=*/call->args[1], output_dtype); +} + +TVM_REGISTER_OP("relax.collapse_sum_to") + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("shape", "Shape", "The shape to collapse to.") + .set_attr("FInferStructInfo", InferStructInfoCollapseSumTo); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 6a2b23ecbdbb..95e29a3dce04 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -112,6 +112,27 @@ Expr split(Expr x, ObjectRef indices_or_sections, int axis); */ Expr squeeze(Expr x, Optional> axis); +/*! + * \brief Return a summation of data to the shape of collapse_target. + * For details, please see the operator `relax.collapse_sum_to`. + * \param data The input tensor. + * \param collapse_target The tensor whose shape is the shape to collapse to. + * \return The result tensor after summation. + */ +Expr collapse_sum_like(Expr data, Expr collapse_target); + +/*! + * \brief Return a summation of data to the given shape. + * collapse_sum_to is intended as the backward operator of broadcast_to and + * other broadcast operators in the automatic differentiation process. + * We expect that data is the result of broadcasting some tensor of the given shape in some + * broadcast operation. Thus the given shape and data.shape must follow broadcast rules. + * \param data The input tensor. + * \param shape The shape to collapse to. + * \return The result tensor of the given shape after summation. + */ +Expr collapse_sum_to(Expr data, Expr shape); + } // namespace relax } // namespace tvm diff --git a/src/topi/reduction.cc b/src/topi/reduction.cc index 3d1c6f9f7d5b..a9d692cc0752 100644 --- a/src/topi/reduction.cc +++ b/src/topi/reduction.cc @@ -64,5 +64,9 @@ TVM_REGISTER_GLOBAL("topi.any").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]); }); +TVM_REGISTER_GLOBAL("topi.collapse_sum").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = topi::collapse_sum(args[0], args[1]); +}); + } // namespace topi } // namespace tvm diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index 6c7727b7d502..abb414b4724c 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -36,6 +36,9 @@ def test_op_correctness(): assert relax.op.layout_transform(x, index_map=lambda a, b, c: (b, c, a)).op == Op.get( "relax.layout_transform" ) + assert relax.op.collapse_sum_to(x, (4, 5)).op == Op.get("relax.collapse_sum_to") + y = relax.Var("x", R.Tensor((4, 5), "float32")) + assert relax.op.collapse_sum_like(x, y).op == Op.get("relax.collapse_sum_like") def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): @@ -2378,5 +2381,328 @@ def test_broadcast_to_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.broadcast_to(x1, stgt)) +def test_collapse_sum_like_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + y0 = relax.Var("y", R.Tensor((3, 4), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=2)) + y2 = relax.Var("y", R.Tensor("float32")) + y3 = relax.Var("y", R.Tensor((3, 4))) + y4 = relax.Var("y", R.Tensor(ndim=2)) + y5 = relax.Var("y", R.Tensor((1, 4))) + + _check_inference( + bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((3, 4), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.collapse_sum_like(x0, y1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.collapse_sum_like(x0, y2), relax.TensorStructInfo(dtype="float32", ndim=-1) + ) + _check_inference( + bb, relax.op.collapse_sum_like(x0, y3), relax.TensorStructInfo((3, 4), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_like(x2, y0), relax.TensorStructInfo((3, 4), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_like(x2, y4), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.collapse_sum_like(x4, y1), relax.TensorStructInfo(dtype="", ndim=2) + ) + _check_inference( + bb, relax.op.collapse_sum_like(x5, y3), relax.TensorStructInfo((3, 4), dtype="") + ) + _check_inference( + bb, relax.op.collapse_sum_like(x0, y5), relax.TensorStructInfo((1, 4), "float32") + ) + + +def test_collapse_sum_like_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x0 = relax.Var("x", R.Tensor((3, 4, a), "float32")) + y0 = relax.Var("y", R.Tensor((4, a), "float32")) + x1 = relax.Var("x", R.Tensor((3, 4, b + a), "float32")) + y1 = relax.Var("x", R.Tensor((1, a + b), "float32")) + + _check_inference( + bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((4, a), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo((1, a + b), "float32") + ) + + +def test_collapse_sum_like_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s2", relax.ShapeStructInfo()) + s3 = relax.Var("s3", relax.ShapeStructInfo((3, 4))) + s4 = relax.Var("s4", relax.ShapeStructInfo(ndim=2)) + s5 = relax.Var("s5", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + y0 = relax.Var("y", relax.TensorStructInfo(s3, "float32")) + y1 = relax.Var("y", relax.TensorStructInfo(s4, "float32")) + y2 = relax.Var("y", relax.TensorStructInfo(s5, "float32")) + + _check_inference(bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo(s3, "float32")) + _check_inference(bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo(s4, "float32")) + _check_inference(bb, relax.op.collapse_sum_like(x2, y2), relax.TensorStructInfo(s5, "float32")) + + +def test_collapse_sum_like_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) + y0 = relax.Var("y", R.Tensor((3, 4), "float16")) + y1 = relax.Var("y", R.Tensor((3, 4), "int8")) + + _check_inference( + bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((3, 4), "float16") + ) + _check_inference(bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo((3, 4), "int8")) + + +def test_collapse_sum_like_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((4, 5))) + x2 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_like(x0, x1)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_like(x2, x0)) + + +def test_collapse_sum_like_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + y0 = relax.Var("y", R.Tensor((3, 6, 5), "float32")) + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x1 = relax.Var("z", R.Tensor((3, a, 5), "float32")) + y1 = relax.Var("w", R.Tensor((3, b, 5), "float32")) + + s0 = relax.Var("s0", relax.ShapeStructInfo((3, 4, 5))) + s1 = relax.Var("s1", relax.ShapeStructInfo((3, 6, 5))) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + y2 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) + + s2 = relax.Var("s2", relax.ShapeStructInfo((3, a, 5))) + s3 = relax.Var("s3", relax.ShapeStructInfo((3, b, 5))) + x3 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + y3 = relax.Var("y", relax.TensorStructInfo(s3, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_like(x0, y0)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_like(x1, y1)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_like(x2, y2)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_like(x3, y3)) + + +def test_collapse_sum_to_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3, 4), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x2, (3, 4)), relax.TensorStructInfo((3, 4), "float32") + ) + _check_inference(bb, relax.op.collapse_sum_to(x3, (3, 4)), relax.TensorStructInfo((3, 4), "")) + _check_inference(bb, relax.op.collapse_sum_to(x4, (3, 4)), relax.TensorStructInfo((3, 4), "")) + _check_inference(bb, relax.op.collapse_sum_to(x5, (3, 4)), relax.TensorStructInfo((3, 4), "")) + + +def test_collapse_sum_to_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x0 = relax.Var("x", R.Tensor((3, 4, a), "float32")) + x1 = relax.Var("x", R.Tensor((3, 4, b + a), "float32")) + + _check_inference( + bb, relax.op.collapse_sum_to(x0, (4, a)), relax.TensorStructInfo((4, a), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, (1, a + b)), relax.TensorStructInfo((1, a + b), "float32") + ) + + +def test_collapse_sum_to_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s2", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + _check_inference( + bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3, 4), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "float32") + ) + + +def test_collapse_sum_to_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) + + _check_inference( + bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3, 4), "float16") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "int8") + ) + + +def test_collapse_sum_to_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((4, 5))) + x2 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x0, x0)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x0, x2)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x1, x1)) + + +def test_collapse_sum_to_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x1 = relax.Var("x", R.Tensor((3, a, 5), "float32")) + + s0 = relax.Var("s0", relax.ShapeStructInfo((3, 4, 5))) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + + s1 = relax.Var("s1", relax.ShapeStructInfo((3, a, 5))) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x0, (4, 4, 5))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x1, (3, b, 5))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x2, (4, 4, 5))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x3, (3, b, 5))) + + +def test_collapse_sum_to_infer_struct_info_struct_info_tgt_shape_var(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d = tir.Var("d", "int64") + s0 = relax.Var("s0", relax.ShapeStructInfo((3, a, b))) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s2", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor((3, a, b), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + stgt0 = relax.Var("stgt0", relax.ShapeStructInfo((a, b))) + stgt1 = relax.Var("stgt1", relax.ShapeStructInfo(ndim=2)) + stgt2 = relax.Var("stgt2", relax.ShapeStructInfo()) + + _check_inference( + bb, relax.op.collapse_sum_to(x0, stgt0), relax.TensorStructInfo(stgt0, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, stgt0), relax.TensorStructInfo(stgt0, "float32") + ) + _check_inference(bb, relax.op.collapse_sum_to(x2, stgt0), relax.TensorStructInfo(stgt0, "")) + _check_inference( + bb, relax.op.collapse_sum_to(x3, stgt0), relax.TensorStructInfo(stgt0, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x4, stgt0), relax.TensorStructInfo(stgt0, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x5, stgt0), relax.TensorStructInfo(stgt0, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x0, stgt1), relax.TensorStructInfo(stgt1, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, stgt1), relax.TensorStructInfo(stgt1, "float32") + ) + _check_inference(bb, relax.op.collapse_sum_to(x2, stgt1), relax.TensorStructInfo(stgt1, "")) + _check_inference( + bb, relax.op.collapse_sum_to(x3, stgt1), relax.TensorStructInfo(stgt1, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x4, stgt1), relax.TensorStructInfo(stgt1, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x5, stgt1), relax.TensorStructInfo(stgt1, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x0, stgt2), relax.TensorStructInfo(stgt2, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, stgt2), relax.TensorStructInfo(stgt2, "float32") + ) + _check_inference(bb, relax.op.collapse_sum_to(x2, stgt2), relax.TensorStructInfo(stgt2, "")) + _check_inference( + bb, relax.op.collapse_sum_to(x3, stgt2), relax.TensorStructInfo(stgt2, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x4, stgt2), relax.TensorStructInfo(stgt2, "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x5, stgt2), relax.TensorStructInfo(stgt2, "float32") + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index 2a30994b83c4..8743261ee71e 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -785,5 +785,108 @@ def squeeze(var_rxplaceholder: T.handle, var_T_squeeze: T.handle): tvm.ir.assert_structural_equal(mod, Expected) +def test_collapse_sum_like(): + # fmt: off + @tvm.script.ir_module + class CollapseSumLike: + @R.function + def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1, 3), "float32")) -> R.Tensor((1, 3), "float32"): + gv: R.Tensor((1, 3), "float32") = R.collapse_sum_like(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1, 3), "float32")) -> R.Tensor((1, 3), "float32"): + gv = R.call_tir(collapse_sum, (x,), R.Tensor((1, 3), dtype="float32")) + return gv + + @T.prim_func + def collapse_sum(rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)), "float32"], rxplaceholder_red: T.Buffer[(T.int64(1), T.int64(3)), "float32"]): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(3), T.int64(2)): + with T.block("rxplaceholder_red"): + ax0, ax1, k0 = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(rxplaceholder[k0, ax1]) + T.writes(rxplaceholder_red[ax0, ax1]) + with T.init(): + rxplaceholder_red[ax0, ax1] = T.float32(0) + rxplaceholder_red[ax0, ax1] = rxplaceholder_red[ax0, ax1] + rxplaceholder[k0, ax1] + # fmt: on + + mod = LegalizeOps()(CollapseSumLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +@pytest.mark.skip("TOPI collapse_sum not support symbolic now") +def test_collapse_sum_like_symbolic(): + # fmt: off + @tvm.script.ir_module + class CollapseSumLike: + @R.function + def main(x: R.Tensor(("a", "b", "a"), "float32"), y: R.Tensor(("b", 1), "float32")) -> R.Tensor(("b", 1), "float32"): + b = T.var("int64") + gv: R.Tensor((b, 1), "float32") = R.collapse_sum_like(x, y) + return gv + + # fmt: on + + mod = LegalizeOps()(CollapseSumLike) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_collapse_sum_to(): + # fmt: off + @tvm.script.ir_module + class CollapseSumTo: + @R.function + def main(x: R.Tensor((3, 2, 3), "float32")) -> R.Tensor((2, 1), "float32"): + gv: R.Tensor((2, 1), "float32") = R.collapse_sum_to(x, (2, 1)) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((3, 2, 3), dtype="float32") + ) -> R.Tensor((2, 1), dtype="float32"): + # block 0 + gv = R.call_tir(collapse_sum, (x,), R.Tensor((2, 1), dtype="float32")) + return gv + + @T.prim_func + def collapse_sum(rxplaceholder: T.Buffer[(T.int64(3), T.int64(2), T.int64(3)), "float32"], rxplaceholder_red: T.Buffer[(T.int64(2), T.int64(1)), "float32"]): + T.func_attr({"tir.noalias": True}) + for ax0, ax1, k0, k2 in T.grid(T.int64(2), T.int64(1), T.int64(3), T.int64(3)): + with T.block("rxplaceholder_red"): + v_ax0, v_ax1, v_k0, v_k2 = T.axis.remap("SSRR", [ax0, ax1, k0, k2]) + T.reads(rxplaceholder[v_k0, v_ax0, v_k2]) + T.writes(rxplaceholder_red[v_ax0, v_ax1]) + with T.init(): + rxplaceholder_red[v_ax0, v_ax1] = T.float32(0) + rxplaceholder_red[v_ax0, v_ax1] = (rxplaceholder_red[v_ax0, v_ax1] + rxplaceholder[v_k0, v_ax0, v_k2]) + # fmt: on + + mod = LegalizeOps()(CollapseSumTo) + tvm.ir.assert_structural_equal(mod, Expected) + + +@pytest.mark.skip("TOPI collapse_sum not support symbolic now") +def test_collapse_sum_to_symbolic(): + # fmt: off + @tvm.script.ir_module + class CollapseSumTo: + @R.function + def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("b", 1), "float32"): + b = T.var("int64") + gv: R.Tensor((b, 1), "float32") = R.collapse_sum_to(x, (b, 1)) + return gv + + # fmt: on + + mod = LegalizeOps()(CollapseSumTo) + tvm.ir.assert_structural_equal(mod, Expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py b/tests/python/relax/test_tvmscript_parser_op_manipulate.py index 27f089ee67c1..c1d0c90d3462 100644 --- a/tests/python/relax/test_tvmscript_parser_op_manipulate.py +++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py @@ -310,5 +310,38 @@ def foo(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2, 3, 1, 4), "f _check(foo, bb.get()["foo"]) +def test_collapse_sum_like(): + @R.function + def foo( + x: R.Tensor((3, 4, 5), "float32"), y: R.Tensor((4, 5), "float32") + ) -> R.Tensor((4, 5), "float32"): + gv: R.Tensor((4, 5), "float32") = R.collapse_sum_like(x, y) + return gv + + x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + y = relax.Var("y", R.Tensor((4, 5), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, y]): + gv = bb.emit(relax.op.collapse_sum_like(x, y)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_collapse_sum_to(): + @R.function + def foo(x: R.Tensor((3, 4, 5), "float32")) -> R.Tensor((4, 5), "float32"): + gv: R.Tensor((4, 5), "float32") = R.collapse_sum_to(x, (4, 5)) + return gv + + x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.collapse_sum_to(x, (4, 5))) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/topi/python/test_topi_reduce.py b/tests/python/topi/python/test_topi_reduce.py index e7f47ba0c4db..0f585fec964d 100644 --- a/tests/python/topi/python/test_topi_reduce.py +++ b/tests/python/topi/python/test_topi_reduce.py @@ -26,6 +26,7 @@ import tvm.topi.testing from tvm import te, topi +from tvm.topi.utils import get_const_tuple in_shape, axis, keepdims, reduce_type, dtype = tvm.testing.parameters( ((32,), 0, False, "argmax", "float32"), @@ -183,5 +184,43 @@ def test_complex_reduce(target, dev): tvm.testing.assert_allclose(out_tvm.numpy(), out_npy, 1e-3, 1e-3) +data_shape, target_shape = tvm.testing.parameters( + ((2, 3), (3,)), + ((2, 3, 4), (2, 1, 4)), + ((2, 3, 4, 5), (3, 1, 5)), +) + + +def _my_npy_collapse_sum(data, target_shape): + reduce_axes = [] + i = data.ndim - 1 + j = len(target_shape) - 1 + while i >= 0: + if j < 0: + reduce_axes.append(i) + elif target_shape[j] == 1 and data.shape[i] > 1: + reduce_axes.append(i) + i -= 1 + j -= 1 + return np.sum(data, tuple(reduce_axes)).reshape(target_shape) + + +def test_collapse_sum(data_shape, target_shape): + A = te.placeholder(data_shape, name="A") + B = topi.collapse_sum(A, target_shape) + s = te.create_schedule([B.op]) + + a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) + b_np = _my_npy_collapse_sum(a_np, target_shape) + dev = tvm.cpu(0) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev) + # Building with the CSE pass disabled + with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]): + foo = tvm.build(s, [A, B], "llvm", name="collapse_sum") + foo(a, b) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5) + + if __name__ == "__main__": tvm.testing.main() From 93cf0874e83fe2e4ce20b6f5e5bda2e8c41214b5 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 20 Feb 2023 23:37:22 -0500 Subject: [PATCH 48/81] [Unity][Fix][Pass] Fix FuseOps for lack graph edges (#14058) This PR fixes a mistake of #14044. In #14044, in VisitLeaf of graph construction of FuseOps, we first check if the input node is Leaf and then check if it is Tuple. This is not right: as Tuple is not categorized as one leaf node, when the input node is a Tuple, the function will return since the input is not a LeafNode. And the check for Tuple will thereby never holds. It is quite interesting that our existing unit tests fail to filter this mistake out. I add a regression test for this case, which can ensure that the tuple is always visited. --- src/relax/transform/fuse_ops.cc | 9 ++++---- tests/python/relax/test_transform_fuse_ops.py | 22 ++++++++++++++++++- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 3b78274cec58..813c0c8f0366 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -234,10 +234,6 @@ class GraphCreator : public ExprVisitor { void VisitLeaf(const Expr& leaf_expr, IndexedForwardGraph::Node* binding_var_node, const OpPatternKind& pattern) { ICHECK_NOTNULL(binding_var_node); - if (!leaf_expr->IsInstance()) { - // Skip GlobalVar, ExternFunc, OpNode. - return; - } // Recursive visit if it's Tuple if (const auto* tuple = leaf_expr.as()) { @@ -247,6 +243,11 @@ class GraphCreator : public ExprVisitor { return; } + if (!leaf_expr->IsInstance()) { + // Skip GlobalVar, ExternFunc, OpNode. + return; + } + auto it = graph_.node_map.find(leaf_expr.get()); IndexedForwardGraph::Node* leaf_node = nullptr; if (it != graph_.node_map.end()) { diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index 6fad4f8165c1..d38e5829815c 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -18,7 +18,7 @@ import tvm import tvm.testing from tvm import relax, topi -from tvm.script import ir as I, relax as R +from tvm.script import ir as I, relax as R, tir as T def _check(mod_actual, mod_expected): @@ -834,5 +834,25 @@ def main(x: R.Tensor((2, 3), "float32")): _check(Module, Module) +def test_edge_with_call_dps_packed(): + @I.ir_module + class Module: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + with R.dataflow(): + a = R.call_tir(exp, (x,), out_sinfo=R.Tensor((2, 3), "float32")) + b = R.call_tir(exp, (a,), out_sinfo=R.Tensor((2, 3), "float32")) + c = R.call_tir("packed_dps", (a,), out_sinfo=R.Tensor((2, 3), "float32")) + R.output(b, c) + return R.tuple(b, c) + + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + + # FuseOps should does no change to it. + _check(Module, Module) + + if __name__ == "__main__": tvm.testing.main() From f5149054afaeac3a55231eed8f9ef1de3aeccf22 Mon Sep 17 00:00:00 2001 From: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Date: Mon, 20 Feb 2023 21:44:11 -0800 Subject: [PATCH 49/81] [Unity][Pass] Remove Unused Function (#14061) This PR implements a pass to clean up unused functions. Co-authored-by: masahi --- python/tvm/ir/function.py | 26 ++- src/relax/transform/remove_unused_funcs.cc | 120 ++++++++++ src/relax/transform/utils.h | 122 ++++++++++ .../test_transform_remove_unused_funcs.py | 211 ++++++++++++++++++ 4 files changed, 475 insertions(+), 4 deletions(-) create mode 100644 src/relax/transform/remove_unused_funcs.cc create mode 100644 src/relax/transform/utils.h create mode 100644 tests/python/relax/test_transform_remove_unused_funcs.py diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py index d02698edb54d..b64553d31ce1 100644 --- a/python/tvm/ir/function.py +++ b/python/tvm/ir/function.py @@ -14,11 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Function defintiions.""" +"""Function definitions.""" +from typing import Union, Dict from enum import IntEnum import tvm.runtime - +from tvm.runtime.object import Object from .expr import RelayExpr +from .attrs import DictAttrs from . import _ffi_api @@ -38,7 +40,7 @@ def attrs(self): """Return the attrs member of the function.""" return _ffi_api.BaseFunc_Attrs(self) - def with_attr(self, attr_key_or_dict, attr_value=None): + def with_attr(self, attr_key_or_dict, attr_value=None) -> "BaseFunc": """Create a new copy of the function and update the attribute. Parameters @@ -51,7 +53,7 @@ def with_attr(self, attr_key_or_dict, attr_value=None): Returns ------- - func : Function + func : BaseFunc A new copy of the function """ # make sure we first copy so that we can safely do copy on write @@ -67,6 +69,22 @@ def with_attr(self, attr_key_or_dict, attr_value=None): res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value) ) + def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> "BaseFunc": + """Copy the IRModule and add the given attribute map to it. + Parameters + ---------- + attr_map: Union[DictAttrs, Dict[str, Object]] + The attribute map + Returns + ------- + func : BaseFunc + A new copy of the function + """ + if isinstance(attr_map, tvm.ir.DictAttrs): + attr_map = attr_map._dict() + + return _ffi_api.BaseFuncWithAttrs(self, attr_map) + def without_attr(self, attr_key: str) -> "BaseFunc": """Create a new copy of the function with an attribute without provided key. diff --git a/src/relax/transform/remove_unused_funcs.cc b/src/relax/transform/remove_unused_funcs.cc new file mode 100644 index 000000000000..5572da13388c --- /dev/null +++ b/src/relax/transform/remove_unused_funcs.cc @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/remove_unused_funcs.cc + * \brief Remove unused global relax functions in a IRModule. + */ + +#include +#include + +#include +#include + +#include "utils.h" + +namespace tvm { +namespace relax { + +/** + * \brief Detects all the functions that can be possibly called by entry function. + */ +class CallTracer : ExprVisitor { + public: + explicit CallTracer(IRModule mod_) : mod_{mod_}, called_funcs_{}, visiting_{} {} + + void VisitExpr_(const GlobalVarNode* op) final { + called_funcs_.insert(GetRef(op)); + auto func = mod_->Lookup(op->name_hint); + if (const auto* function_node = func.as()) { + VisitExpr(GetRef(function_node)); + } + // else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein. + } + + void VisitExpr_(const CallNode* call_node) final { ExprVisitor::VisitExpr_(call_node); } + + void VisitExpr_(const FunctionNode* func_node) final { + auto func = GetRef(func_node); + if (visiting_.find(func) == visiting_.end()) { + visiting_.insert(func); + for (auto param : func_node->params) { + ExprVisitor::VisitExpr(param); + } + ExprVisitor::VisitExpr(func_node->body); + } + } + + void Trace(std::string entry) { + called_funcs_.insert(mod_->GetGlobalVar(entry)); + auto main_func = mod_->Lookup(entry); + VisitExpr(main_func); + } + + bool check_if_called(GlobalVar gv) { return called_funcs_.count(gv) > 0; } + + private: + IRModule mod_; + + // Record the names of all encountered functions. + std::unordered_set called_funcs_; + + // Record the expressions that are being visited. + std::unordered_set visiting_; +}; + +/*! + * \brief Remove functions that are not used. + * + * \param mod_ IRModule. + * \param entry_funcs The set of functions that can be entry function. + * + * \return The module with dead functions removed. + */ +IRModule RemoveUnusedFunctions(IRModule mod_, Array entry_funcs) { + auto tracer = CallTracer(mod_); + for (auto entry : entry_funcs) { + tracer.Trace(entry); + } + auto existing_functions = mod_->functions; + for (auto f : existing_functions) { + // If a function has an external linkage type, we do not remove it. + // Otherwise, we check the function and remove it if it is not used anywhere. + if (f.second->GetLinkageType() == LinkageType::kInternal && !tracer.check_if_called(f.first)) { + mod_->Remove(f.first); + } + } + return mod_; +} + +} // namespace relax + +namespace transform { +Pass RemoveUnusedFunctions(Array entry_functions) { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return relax::RemoveUnusedFunctions(m, entry_functions); }; + return CreateModulePass(pass_func, 0, "RemoveUnusedFunctions", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.RemoveUnusedFunctions").set_body_typed(RemoveUnusedFunctions); + +} // namespace transform +} // namespace tvm diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h new file mode 100644 index 000000000000..d94c1e3b3ec0 --- /dev/null +++ b/src/relax/transform/utils.h @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/utils.h + * \brief Additional utility classes and functions for working with the Relax IR. + */ +#ifndef TVM_RELAX_TRANSFORM_UTILS_H_ +#define TVM_RELAX_TRANSFORM_UTILS_H_ + +#include +#include +#include + +#include +#include + +#include "../../relay/analysis/graph_partitioner.h" + +namespace tvm { +namespace relax { + +/*! + * \brief A simple wrapper around ExprFunctor for a single argument case. + * The result of visit is memoized. + */ +template +class MemoizedExprTranslator : public ::tvm::relax::ExprFunctor { + using BaseFunctor = ::tvm::relax::ExprFunctor; + + public: + /*! \brief virtual destructor */ + virtual ~MemoizedExprTranslator() {} + + /*! + * \brief The memoized call. + * \param n The expression node. + * \return The result of the call + */ + virtual OutputType VisitExpr(const Expr& n) { + ICHECK(n.defined()); + auto it = memo_.find(n); + if (it != memo_.end()) { + return it->second; + } + auto res = BaseFunctor::VisitExpr(n); + memo_[n] = res; + return res; + } + + virtual OutputType VisitExpr_(const VarNode* vn) { + ICHECK(memo_.count(GetRef(vn))); + return memo_[GetRef(vn)]; + } + + virtual OutputType VisitBinding_(const VarBindingNode* binding) { + ICHECK_EQ(memo_.count(binding->var), 0); + auto v = VisitExpr(binding->value); + memo_[binding->var] = v; + return v; + } + + protected: + /*! \brief Internal map used for memoization. */ + std::unordered_map memo_; +}; + +/*! + * \brief Remove unused global relax functions in an IRModule. + * \param mod The target module + * \param entry_functions list of entry functions + * \return The updated module. + */ +TVM_DLL IRModule RemoveUnusedFunctions(IRModule mod, Array entry_funcs); + +/*! + * \brief Get the external symbol of the Relax function name. + * + * \param func The provided function. + * \return An external symbol. + */ +inline std::string GetExtSymbol(const Function& func) { + const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(name_node.defined()) << "Fail to retrieve external symbol."; + return std::string(name_node.value()); +} + +/*! + * \brief Fuse ops or functions according to the given partition, and grouped them into a new + * function. + * + * \param mod The input module. + * \param partition A mapping from a subexpression to the containing group. + * \param lift_constants Whether or not to lift bound constants to parameters of the + * grouped function. + * \return A new module containing grouped functions. + */ +IRModule MakeGroupedFunctions( + IRModule mod, + const std::unordered_map& partition, + bool lift_constants = true); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_TRANSFORM_UTILS_H_ diff --git a/tests/python/relax/test_transform_remove_unused_funcs.py b/tests/python/relax/test_transform_remove_unused_funcs.py new file mode 100644 index 000000000000..8a57b38508d0 --- /dev/null +++ b/tests/python/relax/test_transform_remove_unused_funcs.py @@ -0,0 +1,211 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm +import tvm.script +import tvm.testing +from tvm import relax +from tvm.script import relax as R +from tvm.script import tir as T + + +def check_if_func_exists(mod, func_name): + gvs = [gv.name_hint for gv in mod.get_global_vars()] + return func_name in gvs + + +def test_unused_relax_func(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_add( + x: T.Buffer[(16, 16), "float32"], + y: T.Buffer[(16, 16), "float32"], + z: T.Buffer[(16, 16), "float32"], + ) -> None: + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + z[vi, vj] = x[vi, vj] + y[vi, vj] + + @R.function + def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")): + gv0 = R.add(x, w) + return gv0 + + @R.function + def main( + x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): + gv0 = R.call_tir(tir_add, (x, w), R.Tensor((16, 16), dtype="float32")) + return gv0 + + mod = InputModule + assert mod + new_mod = relax.transform.RemoveUnusedFunctions()(mod) + assert check_if_func_exists(new_mod, "main") + assert check_if_func_exists(new_mod, "tir_add") + assert not check_if_func_exists(new_mod, "unused_func") + + +def test_unused_relax_func_custom_entry_func(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_add( + x: T.Buffer[(16, 16), "float32"], + y: T.Buffer[(16, 16), "float32"], + z: T.Buffer[(16, 16), "float32"], + ) -> None: + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + z[vi, vj] = x[vi, vj] + y[vi, vj] + + @R.function + def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")): + gv0 = R.add(x, w) + return gv0 + + @R.function + def foo( + x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): + gv0 = R.call_tir(tir_add, (x, w), R.Tensor((16, 16), dtype="float32")) + return gv0 + + mod = InputModule + assert mod + + # Test entry function other than "main". + new_mod = relax.transform.RemoveUnusedFunctions(entry_functions=["foo"])(mod) + assert check_if_func_exists(new_mod, "foo") + assert check_if_func_exists(new_mod, "tir_add") + assert not check_if_func_exists(new_mod, "unused_func") + + +def test_unused_relax_func_symbolic_shape(): + # Test with relax function w/ symbolic shape. + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_add( + x: T.Buffer[(16, 16), "float32"], + y: T.Buffer[(16, 16), "float32"], + z: T.Buffer[(16, 16), "float32"], + ) -> None: + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + z[vi, vj] = x[vi, vj] + y[vi, vj] + + @R.function + def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")): + gv0 = R.add(x, w) + return gv0 + + @R.function + def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")): + m, k = T.var("int64"), T.var("int64") + gv0 = R.call_tir(tir_add, (x, w), R.Tensor((m + 1, k), dtype="float32")) + return gv0 + + mod = InputModule + assert mod + + new_mod = relax.transform.RemoveUnusedFunctions()(mod) + assert check_if_func_exists(new_mod, "main") + assert check_if_func_exists(new_mod, "tir_add") + assert not check_if_func_exists(new_mod, "unused_func") + + +def test_unused_prim_func(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def unused_func( + x: T.Buffer[(16, 16), "float32"], + y: T.Buffer[(16, 16), "float32"], + z: T.Buffer[(16, 16), "float32"], + ) -> None: + T.func_attr({"global_symbol": "tir_unused"}) + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + z[vi, vj] = x[vi, vj] + y[vi, vj] + + @R.function + def relax_add(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")): + gv0 = R.add(x, w) + return gv0 + + @R.function + def main( + x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): + gv0 = relax_add(x, w) + return gv0 + + mod = InputModule + assert mod + new_mod = relax.transform.RemoveUnusedFunctions()(mod) + assert check_if_func_exists(new_mod, "main") + assert check_if_func_exists(new_mod, "relax_add") + # RemoveUnusedFunction pass won't remove the function with global symbol for the external linkage. + assert check_if_func_exists(new_mod, "unused_func") + + +def test_multiple_unused_funcs(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def unused_func1( + x: T.Buffer[(16, 16), "float32"], + y: T.Buffer[(16, 16), "float32"], + z: T.Buffer[(16, 16), "float32"], + ) -> None: + T.func_attr({"global_symbol": "tir_unused"}) + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + z[vi, vj] = x[vi, vj] + y[vi, vj] + + @R.function + def unused_func2(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")): + gv0 = R.add(x, w) + return gv0 + + @R.function + def main( + x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): + gv0 = R.add(x, w) + return gv0 + + mod = InputModule + assert mod + + new_mod = relax.transform.RemoveUnusedFunctions()(mod) + assert check_if_func_exists(new_mod, "main") + # RemoveUnusedFunction pass won't remove the function with global symbol for the external linkage. + assert check_if_func_exists(new_mod, "unused_func1") + assert not check_if_func_exists(new_mod, "unused_func2") + + +if __name__ == "__main__": + pytest.main([__file__]) From 8083332e1bfe8a0eccbb70dcd1d8f82349f6ebb0 Mon Sep 17 00:00:00 2001 From: masahi Date: Tue, 21 Feb 2023 16:13:58 +0900 Subject: [PATCH 50/81] [Unity][BYOC] Add pass to merge composite functions to offload large subgraphs (#14062) This PR adds a pass that merges neighboring calls to composite functions offloaded to the same external backend into one function. This is important for backends that want to receive as large subgraph as possible, for example TensorRT. It plays the same role as the MergeCompilerRegion pass in Relay BYOC does, and the algorithm follows the same idea described in https://discuss.tvm.apache.org/t/relay-improved-graph-partitioning-algorithm/5830. Original PR https://github.com/tlc-pack/relax/pull/372 Substantial improvement by @yelite https://github.com/tlc-pack/relax/pull/411 Related fix PR by @yelite https://github.com/tlc-pack/relax/pull/406 Co-authored-by: Lite Ye --- include/tvm/relax/utils.h | 11 +- python/tvm/relax/transform/transform.py | 14 + python/tvm/relax/utils.py | 12 +- .../transform/merge_composite_functions.cc | 355 ++++++ src/relax/utils.cc | 29 + ...est_transform_merge_composite_functions.py | 1051 +++++++++++++++++ tests/python/relax/test_utils.py | 107 ++ 7 files changed, 1570 insertions(+), 9 deletions(-) create mode 100644 src/relax/transform/merge_composite_functions.cc create mode 100644 tests/python/relax/test_transform_merge_composite_functions.py create mode 100644 tests/python/relax/test_utils.py diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index c1d984a21a7b..b3cc76768dd4 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -142,13 +142,16 @@ TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true, TVM_DLL bool IsLeafOrTuple(const Expr& expr); /*! - * \brief Copy the given function. The parameters of the original function would be copied to - * satisfy the restriction in the well-formed check: any two functions cannot share the same - * parameter variable. + * \brief Copy the given function. All variables that are bound inside the original function + * would be copied to satisfy the restriction in the well-formed check: Variables in + * Relax must be bound exactly once. This also ensures that both the function and its copy + * can be inserted into the same IRModule, and be asserted on the structural equality + * agaisnt IRModule created by TVMScript. + * * \param func The relax function to copy. * \return The copied function. */ -TVM_DLL Function CopyWithNewParams(Function func); +TVM_DLL Function CopyWithNewVars(Function func); } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index bf90ef0b0986..12ed27f73a21 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -276,6 +276,20 @@ def FuseOpsByPattern( return _ffi_api.FuseOpsByPattern(pattern_names, df_patterns, annotate_codegen) # type: ignore +def MergeCompositeFunctions() -> tvm.ir.transform.Pass: + """Group one or multiple composite functions created by FuseOpsByPattern into a new function. + The new function will be annotated with "Codegen" and "global_symbol" attributes, and it + is intented to be offloaded to an external backend. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for merging composite functions. + + """ + return _ffi_api.MergeCompositeFunctions() # type: ignore + + def LegalizeOps(customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None): """Legalize high-level operator calls in Relax functions to call_tir with corresponding low-level TIR PrimFuncs. diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 0bb82c79f4f8..d6b405f183e5 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -250,10 +250,12 @@ def auto(func: FType) -> FType: args_converter = _ArgsConverter() # pylint: disable=invalid-name -def copy_with_new_params(func: Function) -> Function: - """Copy the given function. The parameters of the original function would be copied to - satisfy the restriction in the well-formed check: any two functions cannot share the same - parameter variable. +def copy_with_new_vars(func: Function) -> Function: + """Copy the given function. All variables that are bound inside the original function + would be copied to satisfy the restriction in the well-formed check: Variables in + Relax must be bound exactly once. This also ensures that both the function and its copy + can be inserted into the same IRModule, and be asserted on the structural equality + agaisnt IRModule created by TVMScript. Parameters ---------- @@ -265,4 +267,4 @@ def copy_with_new_params(func: Function) -> Function: ret : Function The copied function. """ - return _ffi_api.CopyWithNewParams(func) # type: ignore + return _ffi_api.CopyWithNewVars(func) # type: ignore diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc new file mode 100644 index 000000000000..db73392b02e6 --- /dev/null +++ b/src/relax/transform/merge_composite_functions.cc @@ -0,0 +1,355 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/merge_composite_functions.cc + * \brief Group one or multiple composite functions created by FuseOpsByPattern into a new + * function. + * + * The new function will be annotated with kCodegen and kGlobalSymbol attributes, and it is + * intented to be offloaded to an external backend. + * + * A group for one composite function can be merged into another group for one of its arguments, + * which we call the parent group for brevity, if the following conditions are met: + * - The argument is the result of calling a composite function offloaded to the same backend + * - Merging into the parent group would not create a cyclic dependency with other parent groups + * + * For example, in the subgraph below the bottom group cannot be merged into the left parent group, + * since the right parent group for X depends on an output from the left parent group. + * + * O = Offloaded to A + * X = Offloaded to B + * + * Correct partitioning: + * + * O O + * / \ / \ + * O X --> O + + X + * \ / \ / + * O O + * + * The algorithm proceeds by assigning a group to each subexpression in the function according to + * its dataflow. On encountering a call node whose callee is a composite function, we check the + * two conditions above to see if we can merge this call node into one of its parent groups, and + * if we can merge some of its parent groups. + * + * To detect cyclic dependencies between groups, we propagate dependency relations, both direct + * and indirect ones, as we flow through the function. The propagation of indirect dependencies + * is important since the dependency relation is transitive. + */ + +#include +#include +#include +#include + +#include "../../support/arena.h" +#include "utils.h" + +namespace tvm { +namespace relax { + +using relay::GraphPartitioner; + +namespace { + +using Group = GraphPartitioner::Group; + +/*! \brief Assign group to each subexpression in a function according to its + * dataflow, and returns a mapping from a subexpression to its group. */ +class CompositeGroupsBuilder : public MemoizedExprTranslator { + public: + using GroupMap = std::unordered_map; + using MemoizedExprTranslator::VisitExpr_; + + CompositeGroupsBuilder(IRModule mod, support::Arena* arena) : mod_(mod), arena_(arena) {} + + GroupMap Run(Function func) { + for (const auto& param : func->params) { + memo_[param] = arena_->make(); + } + VisitExpr(func->body); + + GroupMap group_map; + for (const auto& [expr, group] : memo_) { + group_map[expr.get()] = group->FindRoot(); + } + + return group_map; + } + + Group* VisitBinding(const Binding& binding) { + if (const auto* node = binding.as()) { + return VisitBinding_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } + } + + void VisitBindingBlock_(const BindingBlockNode* block) { + for (Binding binding : block->bindings) { + VisitBinding(binding); + } + } + + void VisitBindingBlock_(const DataflowBlockNode* block) { + for (Binding binding : block->bindings) { + VisitBinding(binding); + } + } + + void VisitBindingBlock(const BindingBlock& block) { + if (const auto* node = block.as()) { + VisitBindingBlock_(node); + } else if (const auto* node = block.as()) { + VisitBindingBlock_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + } + + Group* VisitExpr_(const SeqExprNode* op) { + for (BindingBlock block : op->blocks) { + VisitBindingBlock(block); + } + return VisitExpr(op->body); + } + + Group* VisitExpr_(const CallNode* call) { + std::vector groups_to_merge = GetGroupsToMerge(call); + Group* group; + + if (groups_to_merge.size() == 0) { + // Create new group if there is nothing to merge with + group = CreateNewGroup(call); + } else { + auto it = groups_to_merge.cbegin(); + // Assign the first mergable group to current node + // to reduce the number of groups created + group = *it++; + group->num_nodes += 1; + + // Merge all groups + for (; it != groups_to_merge.cend(); ++it) { + MergeGroup(*it, group); + } + } + + UpdateGroupDependencies(group, call->args); + return group; + } + + private: + String GetCodegenName(const std::string& composite_name) { + auto delim_pos = composite_name.find("."); + ICHECK(delim_pos != std::string::npos) << "The pattern name for a composite function should " + "start with a compiler name followed by period."; + return composite_name.substr(0, delim_pos); + } + + Optional GetCodegenName(const Expr& callee) { + auto const* gvar = callee.as(); + if (!gvar) { + return NullOpt; + } + + auto composite_name_opt = + mod_->Lookup(GetRef(gvar))->GetAttr(attr::kComposite); + if (!composite_name_opt) { + return NullOpt; + } + + return GetCodegenName(composite_name_opt.value()); + } + + Optional GetCodegenName(Group* group) { + return Downcast>(group->attrs.Get(attr::kCodegen)); + } + + Group* CreateNewGroup(const CallNode* call) { + Group* group = arena_->make(); + if (Optional codegen_name = GetCodegenName(call->op)) { + group->attrs.Set(attr::kCodegen, codegen_name.value()); + } + return group; + } + + void MergeGroup(Group* from, Group* to) { + ICHECK_EQ(GetCodegenName(from), GetCodegenName(to)); + + Group* from_root = from->FindRoot(); + Group* to_root = to->FindRoot(); + if (from_root == to_root) { + return; + } + + from_root->parent = to_root; + to_root->num_nodes += from_root->num_nodes; + + // Update the group_deps_, maintaining the invariant that + // all groups in the map are root groups. + group_deps_[to_root].merge(group_deps_[from_root]); + group_deps_.erase(from_root); + for (auto& it : group_deps_) { + if (it.second.count(from_root)) { + it.second.erase(from_root); + it.second.insert(to_root); + } + } + } + + std::unordered_set GetParentGroupDependencies(const Array& args) { + // Collect groups that parent groups depend on + std::unordered_set dependencies; + + for (const auto& arg : args) { + for (auto dep : group_deps_[memo_[arg]->FindRoot()]) { + dependencies.insert(dep); + } + } + + return dependencies; + } + + void UpdateGroupDependencies(Group* group, const Array& args) { + Group* group_root = group->FindRoot(); + + for (const auto& arg : args) { + auto arg_group_root = memo_[arg]->FindRoot(); + if (arg_group_root == group_root) { + // If arg and the current node are in the same group, + // there is nothing to update. + continue; + } + // Add the group of arg as dependency + group_deps_[group_root].insert(arg_group_root); + // Propagate dependencies of arg + for (auto dep : group_deps_[arg_group_root]) { + group_deps_[group_root].insert(dep); + } + } + } + + std::vector GetGroupsToMerge(const CallNode* call) { + Optional codegen_name = GetCodegenName(call->op); + if (!codegen_name.defined()) { + return {}; + } + + std::vector groups_to_merge; + std::unordered_set parent_dependencies = GetParentGroupDependencies(call->args); + + for (const auto& arg : call->args) { + auto arg_group = memo_[arg]; + Optional arg_codegen_name = GetCodegenName(arg_group); + if (arg_codegen_name == codegen_name && !parent_dependencies.count(arg_group->FindRoot())) { + // If there is a parent group with the same target, which none of the parent dependency + // groups depends on, merging "this" call node into the parent group will not form a cyclic + // dependency. + groups_to_merge.push_back(arg_group); + } + } + + return groups_to_merge; + } + + IRModule mod_; + support::Arena* arena_; + // Map from group to its dependencies. All groups in this map, whether it's + // the key or in value, should be root node (that is, group->parent == nullptr). + std::unordered_map> group_deps_; +}; + +/*! \brief Inline definitions of composite functions at the global level into their call sites. + This is necessary to make functions created by MergeCompositeFunctions self-contained - each + external backend compiler does not need to refer to the original containing module. + */ +class CompositeInliner : public ExprMutator { + public: + explicit CompositeInliner(IRModule mod) : ExprMutator(mod), mod_(mod) {} + using ExprMutator::VisitExpr_; + + Function Run(Function func) { + inlined_functions_ = Map(); + auto new_body = VisitExpr(func->body); + auto new_func = + Function(func->params, new_body, func->ret_struct_info, func->attrs, func->span); + return new_func; + } + + Expr VisitExpr_(const CallNode* call) { + if (call->op->IsInstance()) { + auto gvar = Downcast(call->op); + auto func = Downcast(mod_->Lookup(gvar)); + + if (func->GetAttr(attr::kComposite)) { + if (!inlined_functions_.count(func)) { + inlined_functions_.Set(func, CopyWithNewVars(func)); + } + return Call(inlined_functions_[func], call->args); + } + } + + return ExprMutator::VisitExpr_(call); + } + + private: + IRModule mod_; + Map inlined_functions_; +}; + +} // namespace + +IRModule MergeCompositeFunctions(IRModule mod) { + auto gvar = mod->GetGlobalVar("main"); + auto func = Downcast(mod->Lookup(gvar)); + support::Arena arena; + auto group_map = CompositeGroupsBuilder(mod, &arena).Run(func); + auto new_mod = MakeGroupedFunctions(mod, group_map); + + CompositeInliner inliner(mod); + for (const auto& [gvar, func] : new_mod->functions) { + if (func->GetAttr(attr::kCodegen)) { + auto new_func = inliner.Run(Downcast(func)); + new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol, gvar->name_hint); + new_mod->Update(gvar, new_func); + } + } + // TODO(@tvm-team): Implicit pass dependency. Revisit when we have a better way to handle this. + return RemoveUnusedFunctions(new_mod, {"main"}); +} + +namespace transform { + +Pass MergeCompositeFunctions() { + runtime::TypedPackedFunc pass_func = // + [=](IRModule mod, PassContext pc) { return relax::MergeCompositeFunctions(mod); }; + return CreateModulePass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"FuseOpsByPattern", // + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.MergeCompositeFunctions") + .set_body_typed(MergeCompositeFunctions); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 24414f250cbc..110bdb5c8c20 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -82,5 +82,34 @@ bool IsLeafOrTuple(const Expr& expr) { expr.as() || expr.as(); } +class FunctionCopier : public ExprMutator { + public: + static Function Transform(Function func) { + FunctionCopier copier; + // All variables that are bound inside the original function would be copied + // to satisfy the restriction in the well-formed check: Variables in Relax + // must be bound exactly once. + return Downcast(copier.VisitExpr(func)); + } + + Var VisitVarDef_(const DataflowVarNode* var) override { + Var new_var = ExprMutator::VisitVarDef_(var); + Var copied_var = DataflowVar(new_var->name_hint(), GetStructInfo(new_var), new_var->span); + var_remap_[var->vid] = copied_var; + return copied_var; + } + + Var VisitVarDef_(const VarNode* var) override { + Var new_var = ExprMutator::VisitVarDef_(var); + Var copied_var = Var(new_var->name_hint(), GetStructInfo(new_var), new_var->span); + var_remap_[var->vid] = copied_var; + return copied_var; + } +}; + +Function CopyWithNewVars(Function func) { return FunctionCopier::Transform(func); } + +TVM_REGISTER_GLOBAL("relax.CopyWithNewVars").set_body_typed(CopyWithNewVars); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_transform_merge_composite_functions.py b/tests/python/relax/test_transform_merge_composite_functions.py new file mode 100644 index 000000000000..8577a4d93c73 --- /dev/null +++ b/tests/python/relax/test_transform_merge_composite_functions.py @@ -0,0 +1,1051 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm import relax +from tvm.script import relax as R + + +@tvm.script.ir_module +class Conv2dReLUx2: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d_relax_nn_relu( + data, weight1 + ) + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_conv2d_relax_nn_relu1( + lv, weight2 + ) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"}) + with R.dataflow(): + lv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data1, + weight11, + padding=[1, 1, 1, 1], + ) + gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv1) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu1( + conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"}) + with R.dataflow(): + lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + conv1, + weight21, + padding=[0, 0, 0, 0], + ) + gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv2) + R.output(gv2) + return gv2 + + +@tvm.script.ir_module +class Conv2dReLUx2_merged: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + with R.dataflow(): + gv: R.Tensor( + (1, 64, 54, 54), dtype="float32" + ) = fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1( + data, weight1, weight2 + ) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), + weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr( + { + "Primitive": 1, + "Codegen": "dnnl", + "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1", + } + ) + with R.dataflow(): + + @R.function + def lv( + data11: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight111: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Composite": "dnnl.conv2d_relu", "Primitive": 1}) + with R.dataflow(): + lv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data11, + weight111, + padding=[1, 1, 1, 1], + ) + gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv1) + R.output(gv1) + return gv1 + + lv2: R.Tensor((1, 64, 56, 56), dtype="float32") = lv(data1, weight11) + + @R.function + def lv11( + conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight211: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Composite": "dnnl.conv2d_relu", "Primitive": 1}) + with R.dataflow(): + lv21: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + conv1, + weight211, + padding=[0, 0, 0, 0], + ) + gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv21) + R.output(gv2) + return gv2 + + gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv11(lv2, weight21) + R.output(gv3) + return gv3 + + +@tvm.script.ir_module +class Diamond: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + with R.dataflow(): + lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_conv2d(data, weight) + lv3: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_relu(lv2) + lv4: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_gelu(lv2) + gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_add(lv3, lv4) + R.output(gv2) + return gv2 + + @R.function + def fused_relax_nn_gelu( + lv: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.gelu"}) + with R.dataflow(): + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_relu( + lv1: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) + with R.dataflow(): + gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_add( + lv5: R.Tensor((1, 64, 54, 54), dtype="float32"), + gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"}) + with R.dataflow(): + gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(lv5, gelu1) + R.output(gv3) + return gv3 + + @R.function + def fused_relax_nn_conv2d( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.conv2d"}) + with R.dataflow(): + gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + data1, + weight1, + padding=[0, 0, 0, 0], + ) + R.output(gv4) + return gv4 + + +@tvm.script.ir_module +class Diamond_merged: + @R.function + def fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # function attr dict + R.func_attr( + { + "Codegen": "compiler_A", + "Primitive": 1, + "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add", + } + ) + # block 0 + with R.dataflow(): + + @R.function + def lv( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.conv2d", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + data1, + weight1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="", + ) + R.output(gv4) + return gv4 + + lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv(data, weight) + + @R.function + def lv1( + lv11: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv11) + R.output(gv1) + return gv1 + + lv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv1(lv2) + + @R.function + def lv21( + lv4: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv4) + R.output(gv) + return gv + + lv41: R.Tensor((1, 64, 54, 54), dtype="float32") = lv21(lv2) + + @R.function + def lv31( + lv5: R.Tensor((1, 64, 54, 54), dtype="float32"), + gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.add", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(lv5, gelu1) + R.output(gv3) + return gv3 + + gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv31(lv3, lv41) + R.output(gv2) + return gv2 + + @R.function + def main( + data2: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # block 0 + with R.dataflow(): + gv5: R.Tensor( + (1, 64, 54, 54), dtype="float32" + ) = fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add(data2, weight2) + R.output(gv5) + return gv5 + + +@tvm.script.ir_module +class Diamond_cyclic_dep: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + with R.dataflow(): + lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_conv2d(data, weight) + lv3: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_relu(lv2) + lv4: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_gelu(lv2) + gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_add(lv3, lv4) + R.output(gv2) + return gv2 + + @R.function + def fused_relax_nn_gelu( + lv: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_B.gelu"}) + with R.dataflow(): + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_relu( + lv1: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) + with R.dataflow(): + gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_add( + lv5: R.Tensor((1, 64, 54, 54), dtype="float32"), + gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"}) + with R.dataflow(): + gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(lv5, gelu1) + R.output(gv3) + return gv3 + + @R.function + def fused_relax_nn_conv2d( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.conv2d"}) + with R.dataflow(): + gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + data1, + weight1, + padding=[0, 0, 0, 0], + ) + R.output(gv4) + return gv4 + + +@tvm.script.ir_module +class Diamond_cyclic_dep_merged: + @R.function + def main( + data2: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + with R.dataflow(): + lv4: R.Tuple( + R.Tensor((1, 64, 54, 54), dtype="float32"), + R.Tensor((1, 64, 54, 54), dtype="float32"), + ) = fused_relax_nn_conv2d_relax_nn_relu(data2, weight2) + lv12: R.Tensor((1, 64, 54, 54), dtype="float32") = lv4[0] + lv22: R.Tensor((1, 64, 54, 54), dtype="float32") = lv4[1] + lv31: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_gelu1(lv12) + gv5: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_add1(lv22, lv31) + R.output(gv5) + return gv5 + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tuple( + R.Tensor((1, 64, 54, 54), dtype="float32"), R.Tensor((1, 64, 54, 54), dtype="float32") + ): + R.func_attr( + { + "Primitive": 1, + "Codegen": "compiler_A", + "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu", + } + ) + with R.dataflow(): + + @R.function + def lv( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Composite": "compiler_A.conv2d", "Primitive": 1}) + with R.dataflow(): + gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + data1, + weight1, + padding=[0, 0, 0, 0], + ) + R.output(gv4) + return gv4 + + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = lv(data, weight) + + @R.function + def lv1( + lv11: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1}) + with R.dataflow(): + gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv11) + R.output(gv1) + return gv1 + + gv11: R.Tensor((1, 64, 54, 54), dtype="float32") = lv1(gv) + R.output(gv, gv11) + return (gv, gv11) + + @R.function + def fused_relax_nn_gelu1( + lv2: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr( + {"Primitive": 1, "Codegen": "compiler_B", "global_symbol": "fused_relax_nn_gelu1"} + ) + with R.dataflow(): + + @R.function + def lv21( + lv3: R.Tensor((1, 64, 54, 54), dtype="float32") + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Composite": "compiler_B.gelu", "Primitive": 1}) + with R.dataflow(): + gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv3) + R.output(gv2) + return gv2 + + gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv21(lv2) + R.output(gv3) + return gv3 + + @R.function + def fused_relax_add1( + lv32: R.Tensor((1, 64, 54, 54), dtype="float32"), + lv41: R.Tensor((1, 64, 54, 54), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Primitive": 1, "Codegen": "compiler_A", "global_symbol": "fused_relax_add1"}) + with R.dataflow(): + + @R.function + def lv33( + lv5: R.Tensor((1, 64, 54, 54), dtype="float32"), + gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Composite": "compiler_A.add", "Primitive": 1}) + with R.dataflow(): + gv31: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(lv5, gelu1) + R.output(gv31) + return gv31 + + gv6: R.Tensor((1, 64, 54, 54), dtype="float32") = lv33(lv32, lv41) + R.output(gv6) + return gv6 + + +@tvm.script.ir_module +class MultipleProducers: + @R.function + def main( + x1: R.Tensor((10,), dtype="float32"), x2: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + with R.dataflow(): + lv1: R.Tensor((10,), dtype="float32") = fused_relax_nn_relu(x1) + lv2: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu(x2) + lv3: R.Tensor((10,), dtype="float32") = fused_relax_nn_relu(lv1) + lv4: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu(lv2) + gv1: R.Tensor((10,), dtype="float32") = fused_relax_add(lv3, lv4) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_nn_relu( + x11: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) + with R.dataflow(): + gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x11) + R.output(gv2) + return gv2 + + @R.function + def fused_relax_nn_gelu( + x21: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.gelu"}) + with R.dataflow(): + gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21) + R.output(gv3) + return gv3 + + @R.function + def fused_relax_add( + lv: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"}) + with R.dataflow(): + gv: R.Tensor((10,), dtype="float32") = R.add(lv, gelu1) + R.output(gv) + return gv + + +@tvm.script.ir_module +class MultipleProducers_merged: + @R.function + def fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add( + x1: R.Tensor((10,), dtype="float32"), x2: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr( + { + "Codegen": "compiler_A", + "Primitive": 1, + "global_symbol": "fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add", + } + ) + # block 0 + with R.dataflow(): + + @R.function + def lv(x11: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x11) + R.output(gv2) + return gv2 + + lv1: R.Tensor((10,), dtype="float32") = lv(x1) + + @R.function + def lv11(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21) + R.output(gv3) + return gv3 + + lv2: R.Tensor((10,), dtype="float32") = lv11(x2) + lv3: R.Tensor((10,), dtype="float32") = lv(lv1) + lv4: R.Tensor((10,), dtype="float32") = lv11(lv2) + + @R.function + def lv21( + lv5: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.add", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv: R.Tensor((10,), dtype="float32") = R.add(lv5, gelu1) + R.output(gv) + return gv + + gv1: R.Tensor((10,), dtype="float32") = lv21(lv3, lv4) + R.output(gv1) + return gv1 + + @R.function + def main( + x12: R.Tensor((10,), dtype="float32"), x22: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + # block 0 + with R.dataflow(): + gv4: R.Tensor( + (10,), dtype="float32" + ) = fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add(x12, x22) + R.output(gv4) + return gv4 + + +@tvm.script.ir_module +class MultipleProducersCyclic: + @R.function + def main(x1: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + with R.dataflow(): + lv1: R.Tensor((10,), dtype="float32") = fused_relax_nn_relu(x1) + lv2: R.Tensor((10,), dtype="float32") = R.nn.relu(lv1) + lv3: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu(lv2) + gv1: R.Tensor((10,), dtype="float32") = fused_relax_add(lv1, lv3) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_nn_relu( + x11: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) + with R.dataflow(): + gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x11) + R.output(gv2) + return gv2 + + @R.function + def fused_relax_nn_gelu( + x21: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.gelu"}) + with R.dataflow(): + gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21) + R.output(gv3) + return gv3 + + @R.function + def fused_relax_add( + lv: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"}) + with R.dataflow(): + gv: R.Tensor((10,), dtype="float32") = R.add(lv, gelu1) + R.output(gv) + return gv + + +@tvm.script.ir_module +class MultipleProducersCyclic_merged: + @R.function + def main(x1: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10,), dtype="float32") = fused_relax_nn_relu1(x1) + lv2: R.Tensor((10,), dtype="float32") = R.nn.relu(lv) + gv: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu_relax_add(lv2, lv) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_relu1( + x11: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr( + {"Codegen": "compiler_A", "Primitive": 1, "global_symbol": "fused_relax_nn_relu1"} + ) + # block 0 + with R.dataflow(): + + @R.function + def lv1(x111: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x111) + R.output(gv2) + return gv2 + + gv1: R.Tensor((10,), dtype="float32") = lv1(x11) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_nn_gelu_relax_add( + lv21: R.Tensor((10,), dtype="float32"), lv11: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr( + { + "Codegen": "compiler_A", + "Primitive": 1, + "global_symbol": "fused_relax_nn_gelu_relax_add", + } + ) + # block 0 + with R.dataflow(): + + @R.function + def lv12(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21) + R.output(gv3) + return gv3 + + lv3: R.Tensor((10,), dtype="float32") = lv12(lv21) + + @R.function + def lv22( + lv4: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + # function attr dict + R.func_attr({"Composite": "compiler_A.add", "Primitive": 1}) + # block 0 + with R.dataflow(): + gv4: R.Tensor((10,), dtype="float32") = R.add(lv4, gelu1) + R.output(gv4) + return gv4 + + gv5: R.Tensor((10,), dtype="float32") = lv22(lv11, lv3) + R.output(gv5) + return gv5 + + +@tvm.script.ir_module +class MergeCompilerRegionsExample: + @R.function + def main( + x1: R.Tensor((10,), dtype="float32"), + x2: R.Tensor((10,), dtype="float32"), + x3: R.Tensor((10,), dtype="float32"), + ) -> R.Tensor((10,), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((10,), dtype="float32") = fused_relax_add(x1, x2) + lv1: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu(x3) + lv11: R.Tensor((10,), dtype="float32") = fused_relax_add(lv, lv1) + lv12: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu(lv11) + lv2: R.Tensor((10,), dtype="float32") = fused_relax_nn_relu(lv11) + lv21: R.Tensor((10,), dtype="float32") = fused_relax_add(lv12, lv2) + gv1: R.Tensor((10,), dtype="float32") = fused_relax_nn_relu(lv21) + R.output(gv1) + return gv1 + + @R.function + def fused_relax_nn_relu( + add2: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) + with R.dataflow(): + gv: R.Tensor((10,), dtype="float32") = R.nn.relu(add2) + R.output(gv) + return gv + + @R.function + def fused_relax_add( + x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"}) + with R.dataflow(): + gv2: R.Tensor((10,), dtype="float32") = R.add(x11, x21) + R.output(gv2) + return gv2 + + @R.function + def fused_relax_nn_gelu( + x31: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_B.gelu"}) + with R.dataflow(): + gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x31) + R.output(gv3) + return gv3 + + +@tvm.script.ir_module +class MergeCompilerRegionsExampleRef: + @R.function + def fused_relax_add_relax_add_relax_nn_relu( + x1: R.Tensor((10,), dtype="float32"), + x2: R.Tensor((10,), dtype="float32"), + lv: R.Tensor((10,), dtype="float32"), + ) -> R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((10,), dtype="float32")): + R.func_attr( + { + "Primitive": 1, + "Codegen": "compiler_A", + "global_symbol": "fused_relax_add_relax_add_relax_nn_relu", + } + ) + with R.dataflow(): + + @R.function + def lv1( + x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"}) + with R.dataflow(): + gv: R.Tensor((10,), dtype="float32") = R.add(x11, x21) + R.output(gv) + return gv + + lv2: R.Tensor((10,), dtype="float32") = lv1(x1, x2) + gv1: R.Tensor((10,), dtype="float32") = lv1(lv2, lv) + + @R.function + def lv11(add2: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) + with R.dataflow(): + gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(add2) + R.output(gv2) + return gv2 + + gv11: R.Tensor((10,), dtype="float32") = lv11(gv1) + R.output(gv1, gv11) + return (gv1, gv11) + + @R.function + def fused_relax_add_relax_nn_relu( + lv12: R.Tensor((10,), dtype="float32"), lv3: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr( + { + "Primitive": 1, + "Codegen": "compiler_A", + "global_symbol": "fused_relax_add_relax_nn_relu", + } + ) + with R.dataflow(): + + @R.function + def lv21( + x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"}) + with R.dataflow(): + gv: R.Tensor((10,), dtype="float32") = R.add(x11, x21) + R.output(gv) + return gv + + lv22: R.Tensor((10,), dtype="float32") = lv21(lv12, lv3) + + @R.function + def lv31(add2: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"}) + with R.dataflow(): + gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(add2) + R.output(gv2) + return gv2 + + gv3: R.Tensor((10,), dtype="float32") = lv31(lv22) + R.output(gv3) + return gv3 + + @R.function + def fused_relax_nn_gelu1( + x3: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((10,), dtype="float32"): + R.func_attr( + {"Primitive": 1, "Codegen": "compiler_B", "global_symbol": "fused_relax_nn_gelu1"} + ) + with R.dataflow(): + + @R.function + def lv4(x31: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + R.func_attr({"Primitive": 1, "Composite": "compiler_B.gelu"}) + with R.dataflow(): + gv4: R.Tensor((10,), dtype="float32") = R.nn.gelu(x31) + R.output(gv4) + return gv4 + + gv5: R.Tensor((10,), dtype="float32") = lv4(x3) + R.output(gv5) + return gv5 + + @R.function + def main( + x12: R.Tensor((10,), dtype="float32"), + x22: R.Tensor((10,), dtype="float32"), + x32: R.Tensor((10,), dtype="float32"), + ) -> R.Tensor((10,), dtype="float32"): + with R.dataflow(): + lv5: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu1(x32) + lv13: R.Tuple( + R.Tensor((10,), dtype="float32"), R.Tensor((10,), dtype="float32") + ) = fused_relax_add_relax_add_relax_nn_relu(x12, x22, lv5) + lv23: R.Tensor((10,), dtype="float32") = lv13[0] + lv32: R.Tensor((10,), dtype="float32") = lv13[1] + lv41: R.Tensor((10,), dtype="float32") = fused_relax_nn_gelu1(lv23) + gv6: R.Tensor((10,), dtype="float32") = fused_relax_add_relax_nn_relu(lv41, lv32) + R.output(gv6) + return gv6 + + +@tvm.script.ir_module +class ModuleWithNonComposite: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d(data, weight) + conv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv) + R.output(conv) + return conv + + @R.function + def fused_relax_nn_conv2d( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Composite": "tensorrt.conv2d", "Primitive": 1}) + with R.dataflow(): + gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data1, + weight1, + padding=[1, 1, 1, 1], + ) + R.output(gv) + return gv + + +@tvm.script.ir_module +class ModuleWithNonComposite_ref: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d1(data, weight) + conv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv) + R.output(conv) + return conv + + @R.function + def fused_relax_nn_conv2d1( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr( + {"Codegen": "tensorrt", "Primitive": 1, "global_symbol": "fused_relax_nn_conv2d1"} + ) + with R.dataflow(): + + @R.function + def lv1( + data2: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Composite": "tensorrt.conv2d", "Primitive": 1}) + with R.dataflow(): + gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data2, + weight2, + padding=[1, 1, 1, 1], + ) + R.output(gv) + return gv + + gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = lv1(data1, weight1) + R.output(gv1) + return gv1 + + +def check(mod, expected): + partitioned = relax.transform.MergeCompositeFunctions()(mod) + tvm.ir.assert_structural_equal(partitioned, expected) + + +def test_conv2d_relu_x2(): + check(Conv2dReLUx2, Conv2dReLUx2_merged) + + +def test_diamond_cyclic_dep(): + """ + O = Offloaded to A + X = Offloaded to B + + O O + / \\ / \\ + O X --> O + + X + \\ / \\ / + O O + + We cannot merge all 'O' since it would create a cyclic dependency between the group of `X`. + """ + check(Diamond_cyclic_dep, Diamond_cyclic_dep_merged) + + +def test_diamond(): + """ + O = Offloaded to A + + O O + / \\ / \\ + O O --> O O + \\ / \\ / + O O + + """ + check(Diamond, Diamond_merged) + + +def test_merge_producers(): + """ + Test merging multiple producer groups into a single representative group. + O O + | | + O O + \\ / + O + """ + check(MultipleProducers, MultipleProducers_merged) + + +def test_merge_producers_cyclic_dep(): + """ + Test when multiple producer groups being blocked to merge due to circular dependency + in the result. + O + |\\ + | X + | | + | O + |/ + O + """ + check(MultipleProducersCyclic, MultipleProducersCyclic_merged) + + +def test_merge_compiler_regions_example(): + """ + A tricky example from https://discuss.tvm.apache.org/t/relay-improved-graph-partitioning-algorithm/5830 + See also the corresponding test case for Relay MergeCompilerRegions in relay/test_pass_merge_compiler_regions.py. + """ + check( + MergeCompilerRegionsExample, + MergeCompilerRegionsExampleRef, + ) + + +def test_mixed_non_composite(): + check(ModuleWithNonComposite, ModuleWithNonComposite_ref) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py new file mode 100644 index 000000000000..fbeb57564fb5 --- /dev/null +++ b/tests/python/relax/test_utils.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm import relax +from tvm.ir.base import assert_structural_equal +from tvm.script.parser import relax as R + + +def test_copy_with_new_vars(): + @R.function + def before(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): + gv = R.add(x, y) + return gv + + after = relax.utils.copy_with_new_vars(before) + assert_structural_equal(after, before) + + assert len(after.params) == len(before.params) + for before_var, after_var in zip(before.params, after.params): + assert before_var != after_var + + +def test_copy_with_new_vars_on_ir_module(): + @tvm.script.ir_module + class Actual: + @R.function + def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): + gv = R.add(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): + gv = R.add(x, y) + return gv + + @R.function + def func_copied(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): + gv = R.add(x, y) + return gv + + Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"]) + + # Assertion will fail if the f_copied contains the same VarNode that's used in + # the original function, due to var mapping during structural equal. + assert_structural_equal(Actual, Expected) + + +def test_copy_with_new_vars_on_ir_module_nested_function(): + @tvm.script.ir_module + class Actual: + @R.function + def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): + @R.function + def inner(x: R.Tensor((3,), "float32")): + gv = R.add(x, x) + return gv + + gv = R.add(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): + @R.function + def inner(x: R.Tensor((3,), "float32")): + gv = R.add(x, x) + return gv + + gv = R.add(x, y) + return gv + + @R.function + def func_copied(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): + @R.function + def inner(x: R.Tensor((3,), "float32")): + gv = R.add(x, x) + return gv + + gv = R.add(x, y) + return gv + + Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"]) + + assert_structural_equal(Actual, Expected) + + +if __name__ == "__main__": + pytest.main([__file__]) From c575220cbc3789d26ddfb4d7165fb2232b8de73d Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 21 Feb 2023 10:18:09 -0800 Subject: [PATCH 51/81] [Unity][Frontend] Annotate number of non-static input of FX function (#14067) --- .../tvm/relax/frontend/torch/fx_translator.py | 30 ++++++++--- tests/python/relax/test_frontend_from_fx.py | 51 ++++++++++++++++++- 2 files changed, 73 insertions(+), 8 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index a762b0a0fbbd..4acad6185592 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -36,7 +36,7 @@ def __init__(self) -> None: from torch import fx self.env: Dict[fx.node.Node, relax.Expr] = {} - self.params: Dict[torch.Tensor, relax.Constant] = {} + self.params: Dict[torch.Tensor, relax.Expr] = {} self.named_modules: Dict[str, torch.Module] = None self.block_builder: relax.BlockBuilder = None self.create_convert_map() @@ -675,7 +675,9 @@ def create_convert_map(self): "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False), } - def from_fx(self, model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule: + def from_fx( + self, model, input_info: List[Tuple[Tuple[int], str]], keep_params_as_input: bool + ) -> tvm.IRModule: """Convert a PyTorch FX GraphModule to a Relax program.""" from torch import fx @@ -693,7 +695,17 @@ def from_fx(self, model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModu # Initialize the block builder with a function and a dataflow block. self.block_builder = relax.BlockBuilder() - with self.block_builder.function(name="main", params=inputs.copy()): + if keep_params_as_input: + func_attrs = {"num_input": len(inputs)} + for name, param in model.named_parameters(): + shape = param.data.shape + dtype = self._convert_data_type(str(param.data.dtype)) + inputs.append(relax.Var(name, relax.TensorStructInfo(shape, dtype))) + self.params[param] = inputs[-1] + else: + func_attrs = None + + with self.block_builder.function(name="main", params=inputs.copy(), attrs=func_attrs): output = None with self.block_builder.dataflow(): # Translate model parameters. @@ -701,7 +713,8 @@ def from_fx(self, model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModu shape = param.data.shape dtype = self._convert_data_type(str(param.data.dtype)) if dtype in ("float32", "float16"): - self.params[param] = relax.const(param.data.cpu().numpy(), dtype) + if not keep_params_as_input: + self.params[param] = relax.const(param.data.cpu().numpy(), dtype) else: raise ValueError("Unsupported data type for model parameters: %s" % dtype) # Translate the model. @@ -740,7 +753,9 @@ def from_fx(self, model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModu return self.block_builder.get() -def from_fx(model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule: +def from_fx( + model, input_info: List[Tuple[Tuple[int], str]], keep_params_as_input: bool = False +) -> tvm.IRModule: """Convert a PyTorch FX GraphModule to a Relax program Parameters @@ -751,6 +766,9 @@ def from_fx(model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule: input_info : List[Tuple[Tuple[int], str]] A list of shapes and data types of input tensors. + keep_params_as_input : bool + Whether to keep model parameters as input variables. + Returns ------- module : tvm.IRModule @@ -814,4 +832,4 @@ def forward(self, input): to print out the tabular representation of the PyTorch module, and then check the placeholder rows in the beginning of the tabular. """ - return TorchFXImporter().from_fx(model, input_info) + return TorchFXImporter().from_fx(model, input_info, keep_params_as_input) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 9b35d34bd370..24ed9946a3ff 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -22,12 +22,12 @@ from tvm.script.parser import relax as R, tir as T -def verify_model(torch_model, input_info, binding, expected): +def verify_model(torch_model, input_info, binding, expected, keep_params_as_input=False): from torch import fx from tvm.relax.frontend.torch import from_fx graph_model = fx.symbolic_trace(torch_model) - mod = from_fx(graph_model, input_info) + mod = from_fx(graph_model, input_info, keep_params_as_input=keep_params_as_input) binding = {k: tvm.nd.array(v) for k, v in binding.items()} expected = relax.transform.BindParams("main", binding)(expected) tvm.ir.assert_structural_equal(mod, expected) @@ -786,6 +786,7 @@ def test_binary(): input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")] input_info2 = [([1, 3, 10, 10], "float32")] + # Add class Add1(Module): def forward(self, lhs, rhs): @@ -1725,5 +1726,51 @@ def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype= verify_model(View(), input_info, {}, expected1) +@tvm.testing.requires_gpu +def test_keep_params(): + import torch + from torch.nn import Module + + class Conv2D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + ) -> R.Tensor((1, 6, 4, 4), dtype="float32"): + R.func_attr({"num_input": 1}) + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(w2, [1, 6, 1, 1]) + lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2) + gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv3 + R.output(gv) + return gv + + model = Conv2D1() + input_info = [([1, 3, 10, 10], "float32")] + verify_model(model, input_info, {}, expected1, keep_params_as_input=True) + + if __name__ == "__main__": tvm.testing.main() From 9be900becf7125b4439ca9ee6eb1b279b33a3a29 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 21 Feb 2023 12:21:56 -0800 Subject: [PATCH 52/81] [Unity][Transform] Add LiftTransformParams pass (#14069) This PR added a pass `LiftTransformParams`. It allows to compile the end-to-end model without weights provided. The idea is annotate the input parameters that are weights, and identify and lift the transformations to weights, and compile it to a separate function `transform_params` that can be executed in runtime. Users can run `transform_params` with weights to get the weights for the optimized model as a prep step before the deployment. In this way, we perform the same optimizations and defer the weight transformations to the user side, while the overhead of the deferred weight transformation can be ignored as it only need to be run once. This pass is integrated with the default `vm.build`. It is optional and only necessary when the parameters are kept as inputs when importing the model from the frontend. --- include/tvm/relax/transform.h | 16 + python/tvm/relax/transform/transform.py | 21 +- src/relax/transform/lift_transform_params.cc | 297 ++++++++++++++++++ .../test_transform_lift_transform_params.py | 295 +++++++++++++++++ 4 files changed, 628 insertions(+), 1 deletion(-) create mode 100644 src/relax/transform/lift_transform_params.cc create mode 100644 tests/python/relax/test_transform_lift_transform_params.py diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 1934a9f9f2a0..7d9f3d64b0a5 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -174,6 +174,22 @@ TVM_DLL Pass Normalize(); */ TVM_DLL Pass LegalizeOps(Optional> cmap); +/* + * \brief Lift transformation of the parameters of a function. + * + * When some inputs of the function is marked as 'parameters' (the model weights), this pass + * identifies the transformation of the parameters and lifts them to a separate function called + * `transform_params`. `transform_params` takes a tuple of the original parameters as input and + * returns a tuple of the transformed parameters. The original function will be rewritten to accept + * a tuple of transformed parameters as input. + * + * Users are expected to invoke the `transform_params` function in runtime and pass the transformed + * parameters to the original function as input. + * + * \return The Pass. + */ +TVM_DLL Pass LiftTransformParams(); + } // namespace transform } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 12ed27f73a21..590059739ce3 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -285,11 +285,30 @@ def MergeCompositeFunctions() -> tvm.ir.transform.Pass: ------- ret : tvm.transform.Pass The registered pass for merging composite functions. - """ return _ffi_api.MergeCompositeFunctions() # type: ignore +def LiftTransformParams() -> tvm.ir.transform.Pass: + """Lift transformation of the parameters of a function. + + When some inputs of the function is marked as 'parameters' (the model weights), this pass + identifies the transformation of the parameters and lifts them to a separate function called + `transform_params`. `transform_params` takes a tuple of the original parameters as input and + returns a tuple of the transformed parameters. The original function will be rewritten to accept + a tuple of transformed parameters as input. + + Users are expected to invoke the `transform_params` function in runtime and pass the transformed + parameters to the original function as input. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for lifting transformation of parameters. + """ + return _ffi_api.LiftTransformParams() # type: ignore + + def LegalizeOps(customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None): """Legalize high-level operator calls in Relax functions to call_tir with corresponding low-level TIR PrimFuncs. diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc new file mode 100644 index 000000000000..97ed8b24a08e --- /dev/null +++ b/src/relax/transform/lift_transform_params.cc @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/lambda_lift.cc + * \brief Lift local functions into global functions. + */ + +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +/*! \brief Plan of lifting transform params */ +struct LiftTransformParamsInfoPlan { + Function f_transform_params; // the lifted function that transforms the parameters + std::unordered_map + output_to_index; // the index of the original bindings in the output tuple + std::unordered_set + lifted_bindings; // the bindings of the original function that are lifted +}; + +/*! \brief Builder of the function that transforms the parameters. */ +class TransformParamsFuncBuilder : public ExprMutator { + public: + TransformParamsFuncBuilder() { builder_->BeginDataflowBlock(); } + + /*! \brief Add a input parameter. */ + void AddInput(const Var& var) { inputs_.push_back(var); } + + /*! \brief Add a binding to lift. */ + void AddBinding(const VarBinding& binding) { bindings_.push_back(binding); } + + /*! \brief Mark a variable as the output of the function. */ + void MarkOutput(const Var& output) { outputs_.insert(output); } + + /*! + * \brief Build the function that transforms the parameters + * \return The created function, and a map from the variable in the original function to the index + * of the element of the output tuple + */ + std::pair> Build() { + Array input_sinfo; + Array output_vars; + std::unordered_map output_to_index; + + for (const auto& input : inputs_) { + input_sinfo.push_back(Downcast(input->struct_info_.value())); + } + Var params("params", TupleStructInfo(input_sinfo)); + + // Helper to add a variable to the output tuple + // original_var: the binding variable in the original function + // output_var: the variable, which is a binding in the transform_params function, that is added + // to the output tuple + auto f_add_output = [&](const Var& original_var, const Var& output_var) -> void { + output_to_index[original_var] = output_vars.size(); + output_vars.push_back(output_var); + }; + + // Create mapping from the original input variables to the TupleGetItem from the packed + // parameter tuple Add the parameters that are marked as the output of the function to the + // output tuple + for (const auto& input : inputs_) { + input_remap_.emplace(input.get(), TupleGetItem(params, input_remap_.size())); + if (outputs_.count(input)) { + auto output_var = builder_->Emit(input_remap_.at(input.get())); + f_add_output(input, output_var); + } + } + + // Re-emit the bindings that are lifted. Update the output tuple if the binding is marked as the + // output. + for (const auto& binding : bindings_) { + if (outputs_.count(binding->var)) { + auto output_var = builder_->Emit(VisitExpr(binding->value)); + var_remap_[binding->var->vid] = output_var; + f_add_output(binding->var, output_var); + } else { + VisitBinding(binding); + } + } + + // Create the function. + Expr transformed_params = builder_->EmitOutput(Tuple(output_vars)); + BindingBlock block = builder_->EndBlock(); + Expr body = builder_->Normalize(SeqExpr({block}, transformed_params)); + Function f_transform_params = + Function(/*params=*/{params}, /*body=*/body, /*ret_struct_info=*/NullOpt); + return {f_transform_params, output_to_index}; + } + + Expr VisitExpr_(const VarNode* var) final { + if (auto it = input_remap_.find(var); it != input_remap_.end()) { + return builder_->Emit((*it).second); + } else { + return ExprMutator::VisitExpr_(var); + } + } + + // The input parameters of the function. + Array inputs_; + // Remap from the original input variable to TupleGetItem from the packed parameter tuple, which + // is the input of the lifted function. + std::unordered_map input_remap_; + // The bindings that are lifted. + Array bindings_; + // The variables that are marked as the output of the function. + std::unordered_set outputs_; +}; + +/*! + * \brief Visitor that creates the plan of lifting transform params. + * + * Starting from the parameters of the function (they are the initial set of lifted bindings), we + * will visit the body of the function to find the bindings that can be lifted. A binding can be + * lifted if all the variables that it depends on are also lifted. + * + * When a binding cannot be lifted, all the variables that 1) it depends on, and 2) have been + * lifted, will be marked as the boundary variable and will be in the output of the lifted function. + */ +class LiftTransformParamsPlanner : public ExprVisitor { + public: + LiftTransformParamsInfoPlan Plan(const Function& function, int num_inputs) { + for (int i = num_inputs; i < static_cast(function->params.size()); ++i) { + builder_.AddInput(function->params[i]); + lifted_bindings_.emplace(function->params[i]); + } + VisitExpr(function->body); + + const auto& [f_transform_params, output_to_index] = builder_.Build(); + return {f_transform_params, output_to_index, std::move(lifted_bindings_)}; + } + + private: + void VisitBindingBlock_(const DataflowBlockNode* block) final { + is_in_dataflow_block_ = true; + ExprVisitor::VisitBindingBlock_(block); + is_in_dataflow_block_ = false; + } + + void VisitBinding_(const VarBindingNode* binding) final { + std::vector producers; + bool can_lift = true; + if (!is_in_dataflow_block_) { + can_lift = false; + } + + PostOrderVisit(binding->value, [&](const ObjectRef& obj) { + if (const VarNode* var = obj.as()) { + producers.push_back(var); + if (!lifted_bindings_.count(GetRef(var))) { + can_lift = false; + } + } + }); + if (can_lift) { + lifted_bindings_.insert(binding->var); + builder_.AddBinding(GetRef(binding)); + } else { + for (const VarNode* producer : producers) { + if (lifted_bindings_.count(GetRef(producer))) { + builder_.MarkOutput(GetRef(producer)); + } + } + } + } + + // The bindings that are lifted + std::unordered_set lifted_bindings_; + // The builder of the function that transforms the parameters + TransformParamsFuncBuilder builder_; + // Whether we are in a dataflow block + bool is_in_dataflow_block_{false}; +}; + +/*! + *\brief The rewriter that lifts the transform params of a function and updates the original + * function. + */ +class TransformParamsLifter : public ExprMutator { + public: + explicit TransformParamsLifter(const IRModule& module) : ExprMutator(module) {} + + IRModule Lift() { + auto mod = builder_->GetContextIRModule(); + GlobalVar gv_main = mod->GetGlobalVar("main"); + Function func = Downcast(mod->Lookup(gv_main)); + func = RewriteFunc(func); + builder_->UpdateFunction(gv_main, func); + return builder_->GetContextIRModule(); + } + + private: + Function RewriteFunc(const Function& func) { + const std::string attr_num_input = "num_input"; + auto opt_num_input = func->attrs.GetAttr(attr_num_input); + if (!opt_num_input.defined()) { + return func; + } + LiftTransformParamsPlanner planner; + int64_t params_begin = opt_num_input.value()->value; + + // Step 1: Create the plan of lifting transform params + lift_plan_ = planner.Plan(func, params_begin); + + // Step 2: Add the lifted function to the module + builder_->AddFunction(lift_plan_.f_transform_params, "transform_params"); + + // Step 3: Update the current function. + + // Step 3.1: Update the function signature + Var params("params", lift_plan_.f_transform_params->ret_struct_info); + Array new_params; + for (int i = 0; i < params_begin; ++i) { + new_params.push_back(func->params[i]); + } + new_params.push_back(params); + + // Step 3.2: Update the function body + for (const auto& [var, index] : lift_plan_.output_to_index) { + param_remap_[var] = TupleGetItem(params, index); + } + auto new_body = VisitExpr(func->body); + + // Step 3.3: Remove function attributes that are not needed + auto new_attrs = func->attrs; + auto* new_attrs_node = new_attrs.CopyOnWrite(); + new_attrs_node->dict.erase(attr_num_input); + if (new_attrs->dict.empty()) { + new_attrs = NullValue(); + } + + Function new_func(new_params, new_body, func->ret_struct_info, new_attrs); + return new_func; + } + + void VisitBinding_(const VarBindingNode* binding) final { + if (lift_plan_.lifted_bindings.count(binding->var)) { + return; + } + ExprMutator::VisitBinding_(binding); + } + + Expr VisitExpr_(const VarNode* var) final { + auto it = param_remap_.find(GetRef(var)); + if (it != param_remap_.end()) { + return builder_->Emit(it->second); + } + return ExprMutator::VisitExpr_(var); + } + + Expr VisitExpr_(const DataflowVarNode* var) final { + return VisitExpr_(static_cast(var)); + } + + // Remap the original parameters to TupleGetItem from the packed tuple of transformed parameters. + std::unordered_map param_remap_; + // The plan of lifting the transform params + LiftTransformParamsInfoPlan lift_plan_; +}; + +namespace transform { +Pass LiftTransformParams() { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return TransformParamsLifter(m).Lift(); }; + return CreateModulePass(pass_func, 1, "LiftTransformParams", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.LiftTransformParams").set_body_typed(LiftTransformParams); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py new file mode 100644 index 000000000000..a1f67d41da4d --- /dev/null +++ b/tests/python/relax/test_transform_lift_transform_params.py @@ -0,0 +1,295 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm import relax +from tvm.script import relax as R, tir as T +import numpy as np +import tvm.topi.testing + + +def test_basic(): + @tvm.script.ir_module + class Before: + @T.prim_func + def transform_layout_IOHW_to_OIHW( + w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32") + ) -> None: + for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3): + with T.block("layout_transform"): + o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + out[o, i, h, w] = w1[i, o, h, w] + + @R.function + def main( + x: R.Tensor((1, 3, 224, 224), "float32"), + w1: R.Tensor((3, 16, 3, 3), "float32"), + w2: R.Tensor((16, 16, 3, 3), "float32"), + ) -> R.Tensor((1, 16, 224, 224), "float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + w1_transformed = R.call_tir( + transform_layout_IOHW_to_OIHW, w1, R.Tensor((16, 3, 3, 3), "float32") + ) + conv1 = R.nn.conv2d( + x, w1_transformed, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW" + ) + conv2 = R.nn.conv2d( + conv1, w2, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW" + ) + R.output(conv2) + return conv2 + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((1, 3, 224, 224), dtype="float32"), + params: R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") + ), + ) -> R.Tensor((1, 16, 224, 224), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((16, 3, 3, 3), dtype="float32") = params[1] + conv1: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d( + x, + lv, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="void", + ) + lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0] + conv2: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d( + conv1, + lv1, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="void", + ) + R.output(conv2) + return conv2 + + @T.prim_func + def transform_layout_IOHW_to_OIHW( + w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32") + ): + for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3): + with T.block("layout_transform"): + o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(w1[i, o, h, w]) + T.writes(out[o, i, h, w]) + out[o, i, h, w] = w1[i, o, h, w] + + @R.function + def transform_params( + params: R.Tuple( + R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") + ) + ) -> R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") + ): + with R.dataflow(): + lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1] + lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0] + lv2 = R.call_tir( + transform_layout_IOHW_to_OIHW, + (lv1,), + out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"), + ) + gv: R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((16, 3, 3, 3), dtype="float32"), + ) = (lv, lv2) + R.output(gv) + return gv + + mod = Before + after = relax.transform.LiftTransformParams()(mod) + tvm.ir.assert_structural_equal(after, Expected) + + +def test_tuple(): + @tvm.script.ir_module + class Before: + @R.function + def main( + x: R.Tensor((1, 16, 224, 224), "float32"), w1: R.Tensor((16, 16, 3, 3), "float32") + ) -> R.Tensor((1, 16, 224, 224), "float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + l0 = (w1,) + l1 = (l0,) + l2 = l1[0] + l3 = l2[0] + conv1 = R.nn.conv2d(x, l3, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW") + conv2 = R.nn.conv2d( + conv1, w1, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW" + ) + R.output(conv2) + return conv2 + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((1, 16, 224, 224), dtype="float32"), + params: R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") + ), + ) -> R.Tensor((1, 16, 224, 224), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1] + conv1: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d( + x, + lv, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="void", + ) + lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0] + conv2: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d( + conv1, + lv1, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="void", + ) + R.output(conv2) + return conv2 + + @R.function + def transform_params( + params: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32")) + ) -> R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") + ): + with R.dataflow(): + lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0] + lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0] + l0: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32")) = (lv1,) + l1: R.Tuple(R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32"))) = (l0,) + l2: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32")) = l1[0] + lv2: R.Tensor((16, 16, 3, 3), dtype="float32") = l2[0] + gv: R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((16, 16, 3, 3), dtype="float32"), + ) = (lv, lv2) + R.output(gv) + return gv + + mod = Before + after = relax.transform.LiftTransformParams()(mod) + tvm.ir.assert_structural_equal(after, Expected) + + +def test_condition(): + """Test case that the conditional statement can't be lifted""" + + @tvm.script.ir_module + class Before: + @R.function + def main( + x: R.Tensor((1, 16, 224, 224), "float32"), + w1: R.Tensor((16, 16, 3, 3), "float32"), + w2: R.Tensor((16, 16, 3, 3), "float32"), + cond: R.Tensor((), "bool"), + ) -> R.Tensor((1, 16, 224, 224), "float32"): + R.func_attr({"num_input": 1}) + if cond: + w = w1 + else: + w = w2 + with R.dataflow(): + conv1 = R.nn.conv2d(x, w, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW") + R.output(conv1) + return conv1 + + @tvm.script.ir_module + class Expected: + @R.function + def transform_params( + params: R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((), dtype="bool"), + ) + ) -> R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((), dtype="bool"), + ): + with R.dataflow(): + lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0] + lv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1] + lv2: R.Tensor((), dtype="bool") = params[2] + gv: R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((), dtype="bool"), + ) = (lv, lv1, lv2) + R.output(gv) + return gv + + @R.function + def main( + x: R.Tensor((1, 16, 224, 224), "float32"), + params: R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((), dtype="bool"), + ), + ) -> R.Tensor((1, 16, 224, 224), "float32"): + gv: R.Tensor((), dtype="bool") = params[2] + if gv: + gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = params[0] + w: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1 + else: + gv2: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1] + w: R.Tensor((16, 16, 3, 3), dtype="float32") = gv2 + with R.dataflow(): + conv1 = R.nn.conv2d(x, w, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW") + R.output(conv1) + return conv1 + + mod = Before + after = relax.transform.LiftTransformParams()(mod) + tvm.ir.assert_structural_equal(after, Expected) + + +if __name__ == "__main__": + tvm.testing.main() From 5eee3af667668a3de221151c8257b084098afcf5 Mon Sep 17 00:00:00 2001 From: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Date: Tue, 21 Feb 2023 19:14:52 -0800 Subject: [PATCH 53/81] [Unity][BYOC][Pass] RunCodegen and TensorRT (#14078) This PR introduces the fundamental workflow for BYOC and integrate TensorRT as a demonstration. --- cmake/modules/contrib/TensorRT.cmake | 2 +- include/tvm/ir/module.h | 6 + include/tvm/relax/transform.h | 110 ++++- python/tvm/ir/module.py | 35 ++ python/tvm/relax/transform/transform.py | 23 + src/ir/module.cc | 12 + .../contrib/codegen_json/codegen_json.h | 419 ++++++++++++++++++ src/relax/backend/contrib/tensorrt/codegen.cc | 267 +++++++++++ src/relax/backend/contrib/utils.h | 127 ++++++ src/relax/transform/run_codegen.cc | 190 ++++++++ tests/python/relax/test_codegen_tensorrt.py | 124 ++++++ .../relax/test_transform_codegen_pass.py | 260 +++++++++++ 12 files changed, 1553 insertions(+), 22 deletions(-) create mode 100644 src/relax/backend/contrib/codegen_json/codegen_json.h create mode 100644 src/relax/backend/contrib/tensorrt/codegen.cc create mode 100644 src/relax/backend/contrib/utils.h create mode 100644 src/relax/transform/run_codegen.cc create mode 100644 tests/python/relax/test_codegen_tensorrt.py create mode 100644 tests/python/relax/test_transform_codegen_pass.py diff --git a/cmake/modules/contrib/TensorRT.cmake b/cmake/modules/contrib/TensorRT.cmake index 696108b50142..a749b6e80fd2 100644 --- a/cmake/modules/contrib/TensorRT.cmake +++ b/cmake/modules/contrib/TensorRT.cmake @@ -23,7 +23,7 @@ include (FindPackageHandleStandardArgs) if(USE_TENSORRT_CODEGEN) message(STATUS "Build with TensorRT codegen") - tvm_file_glob(GLOB COMPILER_TENSORRT_SRCS src/relay/backend/contrib/tensorrt/*.cc) + tvm_file_glob(GLOB COMPILER_TENSORRT_SRCS src/relay/backend/contrib/tensorrt/*.cc src/relax/backend/contrib/tensorrt/*.cc) set_source_files_properties(${COMPILER_TENSORRT_SRCS} PROPERTIES COMPILE_FLAGS "-Wno-deprecated-declarations") tvm_file_glob(GLOB RUNTIME_TENSORRT_SRCS src/runtime/contrib/tensorrt/tensorrt_runtime.cc) set_source_files_properties(${RUNTIME_TENSORRT_SRCS} PROPERTIES COMPILE_FLAGS "-Wno-deprecated-declarations") diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index fdb44b11887c..538ff64ca3fb 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -115,6 +115,12 @@ class IRModuleNode : public Object { return GetAttr(attr_key, Optional(default_value)); } + /*! + * \brief Get the metadata attributes. + * \returns The additional meta-data attributes + */ + DictAttrs GetAttrs() const { return attrs; } + /*! * \brief Check whether the module has an non-zero integer attr. * diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 7d9f3d64b0a5..7d6c93bcde48 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -25,6 +25,7 @@ #define TVM_RELAX_TRANSFORM_H_ #include +#include #include namespace tvm { @@ -67,6 +68,13 @@ TVM_DLL Pass CreateDataflowBlockPass( const runtime::TypedPackedFunc& pass_func, int opt_level, String name, tvm::Array required, bool traceable = false); +/*! + * \brief Perform lambda lifting to lift functions from nested into global. + * + * \return The Pass. + */ +TVM_DLL Pass LambdaLift(); + /*! * \brief Transform all dataflow structure to non-dataflow version. * @@ -105,27 +113,20 @@ TVM_DLL Pass RewriteDataflowReshape(); TVM_DLL Pass StaticPlanBlockMemory(); /*! - * \brief Bind params of function of the module to constant tensors. - * - * \param func_name The name of the function to bind parameters. - * \param params The parameters to bind. + * \brief Attach global_symbol to Relax functions and TIR Primfuncs for codegen. * * \return The Pass. */ -TVM_DLL Pass BindParams(String func_name, Map params); +TVM_DLL Pass AttachGlobalSymbol(); /*! - * \brief Fold constant expressions. - * - * \return The Pass. - */ -TVM_DLL Pass FoldConstant(); -/*! - * \brief Attach global_symbol to Relax functions and TIR Primfuncs for codegen. + * \brief Transform Relax IR to normal form: transform AST to A-normal form, and fill the + * checked_type_ and shape_ of expressions. * * \return The Pass. */ -TVM_DLL Pass AttachGlobalSymbol(); +TVM_DLL Pass Normalize(); + /*! * \brief Bind params of function of the module to constant tensors. * @@ -143,14 +144,6 @@ TVM_DLL Pass BindParams(String func_name, Map params); */ TVM_DLL Pass FoldConstant(); -/*! - * \brief Transform Relax IR to normal form: transform AST to A-normal form, and fill the - * checked_type_ and shape_ of expressions. - * - * \return The Pass. - */ -TVM_DLL Pass Normalize(); - /*! * \brief Legalize high-level operator calls in Relax functions to call_tir * with corresponding low-level TIR PrimFuncs. @@ -190,6 +183,81 @@ TVM_DLL Pass LegalizeOps(Optional> cmap); */ TVM_DLL Pass LiftTransformParams(); +/*! + * \brief Annotate Op Pattern Kind for TIR functions, which is used in FuseOps. + * \note It is an auto-detect pass for "unscheduled prim_funcs", the op_pattern will be + * "opaque" of we can't detect it. Users can manually annotate the attr `op_pattern` + * to prim_func. + * \return The Pass. + */ +TVM_DLL Pass AnnotateTIROpPattern(); + +/*! + * \brief This pass groups bindings in a dataflow block of Relax functions and generates a new + * grouped Relax function for each group, according to the fusion algorithm described in the pass + * implementation. By grouping bindings into new Relax functions, we substitute the bindings in the + * function being manipulated into function calls to the new grouped function. + * + * A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function. + * \param fuse_opt_level The level of fuse optimization. + * -1 indicates that the level will be inferred from pass context. + * \return The Pass. + */ +TVM_DLL Pass FuseOps(int fuse_opt_level = -1); + +/*! + * \brief Apply pattern matching to each function in the given module, and group matched + * expressions into a new function. The end result is similar to FuseOps, but fusion is driven + * completely by the provided patterns. + * + * \param pattern_names The name of each pattern. It becomes the value of the kComposite attribute + * of a fused function after successful matching. + * \param patterns The patterns to detect. The order of the patterns determines the order + * of priority in which they are matched. Higher-priority patterns should come earlier in the list. + * \param annotate_codegen If true, wrap each created composite function with another function, + * whose body consists only of a call to the composite function, and annotate the outer function + * with kCodegen and kGlobalSymbol attributes. The kCodegen attribute is set as the prefix of the + * corresponding pattern name. For example, "dnnl" if the pattern name is "dnnl.conv2d_relu". + * This must be True if the created composite functions are intended to be offloaded to + * an external backend without using the MergeCompositeFunctions pass. + * \return The Pass. + */ +TVM_DLL Pass FuseOpsByPattern(const tvm::Array& pattern_names, + const tvm::Array& patterns, bool annotate_codegen = false); + +/*! + * \brief Group one or multiple composite functions created by FuseOpsByPattern into a new + * function. The new function will be annotated with kCodegen and GlobalSymbol attributes, + * and it is intented to be offloaded to an external backend. + * + * \return The Pass. + */ +TVM_DLL Pass MergeCompositeFunctions(); + +/*! + * \brief Fuse relax sub-function into a larger TIR function if possible. + this pass works together with FuseOps to perform operator fusion. + + * \return The Pass. + */ +TVM_DLL Pass FuseTIR(); + +/*! + * \brief Remove unused global relax functions in an IRModule. + * \param entry_functions list of entry functions + * \return The Pass. + */ +TVM_DLL Pass RemoveUnusedFunctions(Array entry_functions); + +/*! + * \brief Run codegen. + * \param target_options pairs of target name and compilation options + * \param entry_functions list of entry functions + * \return The Pass. + */ +TVM_DLL Pass RunCodegen(Optional>> target_options, + Array entry_functions); + } // namespace transform } // namespace relax } // namespace tvm diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 3daffb2640c5..6a151d5a897c 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -15,13 +15,18 @@ # specific language governing permissions and limitations # under the License. """IRModule that holds the functions and type definitions.""" +from __future__ import annotations + +from typing import Dict, Union import tvm._ffi from tvm._ffi.base import string_types from tvm.runtime import Scriptable +from tvm.runtime.object import Object from . import _ffi_api from . import expr as _expr from . import type as _ty +from .attrs import DictAttrs from .base import Node @@ -286,6 +291,36 @@ def with_attr(self, attr_key, attr_value): return _ffi_api.Module_WithAttr(self, attr_key, attr_value) + def without_attr(self, attr_key: str) -> "IRModule": + """Copy the IRModule and remove an attribute key and its associated value. + Parameters + ---------- + attr_key : str + The attribute key. + Returns + ------- + mod : IRModule + A new copy of the IRModule without the attribute + """ + + return _ffi_api.Module_WithoutAttr(self, attr_key) + + def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> "IRModule": + """Copy the IRModule and add the given attribute map to it. + Parameters + ---------- + attr_map: Union[DictAttrs, Dict[str, Object]] + The attribute map + Returns + ------- + mod : IRModule + A new copy of the IRModule with the attribute + """ + if isinstance(attr_map, tvm.ir.DictAttrs): + attr_map = attr_map._dict() + + return _ffi_api.Module_WithAttrs(self, attr_map) + def astext(self, show_meta_data=True, annotate=None): """Get the text format of the expression. diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 590059739ce3..9fb2458dc026 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -188,6 +188,29 @@ def RemoveUnusedFunctions(entry_functions: Optional[List[str]] = None) -> tvm.ir return _ffi_api.RemoveUnusedFunctions(entry_functions) # type: ignore +def RunCodegen( + target_options: Optional[dict] = None, + entry_functions: Optional[List[str]] = None, +) -> tvm.ir.transform.Pass: + """Produce the runtime::Module with an annotated codegen and global symbol. + + Parameters + ---------- + target_options: Optional[dict] + Pairs of a target name and compilation options + entry_functions: Optional[List[str]] + The set of entry functions to start from. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass to remove unused functions. + """ + if entry_functions is None: + entry_functions = ["main"] + return _ffi_api.RunCodegen(target_options, entry_functions) # type: ignore + + def FoldConstant() -> tvm.ir.transform.Pass: """Fold constant expressions. diff --git a/src/ir/module.cc b/src/ir/module.cc index 4a09bdaaf7c6..8f23f19d352e 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -431,11 +431,23 @@ TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, S mod->ImportFromStd(path); }); +TVM_REGISTER_GLOBAL("ir.Module_GetAttrs").set_body_typed([](IRModule mod) -> ObjectRef { + return mod->GetAttrs(); +}); + TVM_REGISTER_GLOBAL("ir.Module_WithAttr") .set_body_typed([](IRModule mod, String key, ObjectRef value) -> IRModule { return WithAttr(mod, key, value); }); +TVM_REGISTER_GLOBAL("ir.Module_WithoutAttr") + .set_body_typed([](IRModule mod, String key) -> IRModule { return WithoutAttr(mod, key); }); + +TVM_REGISTER_GLOBAL("ir.Module_WithAttrs") + .set_body_typed([](IRModule mod, Map attr_map) -> IRModule { + return WithAttrs(mod, attr_map); + }); + TVM_REGISTER_GLOBAL("ir.Module_GetAttr").set_body_typed([](IRModule mod, String key) -> ObjectRef { return mod->GetAttr(key); }); diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h new file mode 100644 index 000000000000..219799870728 --- /dev/null +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -0,0 +1,419 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/backend/contrib/codegen_json/codegen_json.h + * \brief Utilities for json codegen and runtime + */ +#ifndef TVM_RELAX_BACKEND_CONTRIB_CODEGEN_JSON_CODEGEN_JSON_H_ +#define TVM_RELAX_BACKEND_CONTRIB_CODEGEN_JSON_CODEGEN_JSON_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../../../../runtime/contrib/json/json_node.h" +#include "../../../../runtime/contrib/json/json_runtime.h" +#include "../../../transform/utils.h" +#include "../utils.h" + +namespace tvm { +namespace relax { +namespace backend { +namespace contrib { + +using namespace tvm::runtime::json; + +using ShapeVector = std::vector>; +using TypeVector = std::vector; +using JSONGraphObjectPtr = std::shared_ptr; + +/*! + * \brief Helper class to extract all attributes of a certain op and save them + * into text format. + */ +class OpAttrExtractor : public AttrVisitor { + public: + explicit OpAttrExtractor(JSONGraphObjectPtr node) : node_(node) {} + + template ::value>> + std::string Fp2String(const T value) { + std::ostringstream out; + out.precision(std::numeric_limits::max_digits10); + out << value; + return out.str(); + } + + void SetNodeAttr(const char* key, const std::vector& value) { + std::vector attr; + attr.emplace_back(value); + node_->SetAttr(key, attr); + } + + void Visit(const char* key, double* value) final { SetNodeAttr(key, {Fp2String(*value)}); } + + void Visit(const char* key, int64_t* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + + void Visit(const char* key, uint64_t* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + + void Visit(const char* key, int* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + + void Visit(const char* key, bool* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + + void Visit(const char* key, std::string* value) final { SetNodeAttr(key, {*value}); } + + void Visit(const char* key, DataType* value) final { + if (!value->is_void()) { + SetNodeAttr(key, {runtime::DLDataType2String(*value)}); + } else { + SetNodeAttr(key, {""}); + } + } + + void Visit(const char* key, runtime::ObjectRef* value) final { + if (const auto* an = (*value).as()) { + std::vector attr; + for (size_t i = 0; i < an->size(); ++i) { + if (const auto* im = (*an)[i].as()) { + attr.push_back(std::to_string(im->value)); + } else if (const auto* fm = (*an)[i].as()) { + attr.push_back(Fp2String(fm->value)); + } else if (const auto* str = (*an)[i].as()) { + String s = GetRef(str); + attr.push_back(s); + } else { + LOG(FATAL) << "Not supported type: " << (*an)[i]->GetTypeKey(); + } + } + SetNodeAttr(key, attr); + } else if (!(*value).defined()) { // Skip NullValue + SetNodeAttr(key, std::vector{""}); + } else if (const auto* im = (*value).as()) { + SetNodeAttr(key, std::vector{std::to_string(im->value)}); + } else if (const auto* fm = (*value).as()) { + SetNodeAttr(key, std::vector{Fp2String(fm->value)}); + } else if (const auto* str = (*value).as()) { + String s = GetRef(str); + SetNodeAttr(key, std::vector{s}); + } else { + LOG(FATAL) << "Not yet supported type: " << (*value)->GetTypeKey() << ": " << *value; + } + } + + void Visit(const char* key, runtime::NDArray* value) final { + LOG(FATAL) << "NDArray is not allowed in op attribute"; + } + + void Visit(const char* key, void** value) final { + LOG(FATAL) << "void pointer is not allowed in op attribute"; + } + + void Extract(Object* node) { + if (node) { + reflection_->VisitAttrs(node, this); + } + } + + private: + JSONGraphObjectPtr node_; + ReflectionVTable* reflection_ = ReflectionVTable::Global(); +}; + +using NodeEntries = std::vector; + +/*! \brief Serialize a Relax expression to JSON. */ +class JSONSerializer : public relax::MemoizedExprTranslator { + public: + using MemoizedExprTranslator::VisitExpr_; + using MemoizedExprTranslator::VisitBinding_; + + /*! + * \brief Constructor + * \param constant_names The names of all constants in the original module. + */ + explicit JSONSerializer(const Map& constant_names) + : constant_names_(constant_names) {} + + void serialize(Function func) { + // First we convert all the parameters into input nodes. + for (const auto& param : func->params) { + auto node_ptr = std::make_shared(param->name_hint(), "input" /* op_type_ */); + memo_[param] = AddNode(node_ptr, param); + } + heads_ = VisitExpr(func->body); + } + + /*!\brief Return the required constants. */ + Array GetConstantNames() const { return constants_used_; } + + /*!\brief Return the generated json. */ + std::string GetJSON() { + std::ostringstream os; + dmlc::JSONWriter writer(&os); + Save(&writer); + return os.str(); + } + + protected: + /*! + * \brief Add a node to graph. + * + * \param node A graph node. It is a shared pointer. Some attributes of it + * will be added, i.e. shape and type. These attributes are attached to + * the JSON graph in the end. + * \param expr The relax expression. + * \return A list of graph entry nodes. It the relax expr is a tuple type, we + * will flatten it. + */ + NodeEntries AddNode(JSONGraphObjectPtr node, const Expr& expr) { + auto struct_info = GetStructInfo(expr); + auto node_id = nodes_.size(); + nodes_.push_back(node); + NodeEntries ret; + ShapeVector shape; + TypeVector dtype; + + // Flatten tuple node. + if (const auto* tuple_sinfo = struct_info.as()) { + for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { + const auto* tensor_sinfo = tuple_sinfo->fields[i].as(); + ICHECK(tensor_sinfo) << "Expect TensorStructInfo, but received: ." + << tuple_sinfo->fields[i]->GetTypeKey(); + ICHECK(tensor_sinfo->shape.defined()) << "Expect shape to be defined."; + ShapeExpr output_shape = Downcast(tensor_sinfo->shape.value()); + ret.push_back(JSONGraphNodeEntry(node_id, i)); + shape.emplace_back(GetIntShape(output_shape->values)); + dtype.emplace_back(DType2String(tensor_sinfo->dtype)); + } + node->SetNumOutput(tuple_sinfo->fields.size()); + } else { + const auto* tensor_sinfo = struct_info.as(); + ICHECK(tensor_sinfo) << "Expect TensorStructInfo, but received: " + << struct_info->GetTypeKey(); + ICHECK(tensor_sinfo->shape.defined()) << "Expect shape to be defined."; + ShapeExpr output_shape = Downcast(tensor_sinfo->shape.value()); + + shape.emplace_back(GetIntShape(output_shape->values)); + dtype.emplace_back(DType2String(tensor_sinfo->dtype)); + ret.push_back(JSONGraphNodeEntry(node_id, 0)); + } + std::vector shape_attrs; + shape_attrs.emplace_back(shape); + node->SetAttr("shape", shape_attrs); + + std::vector type_attrs; + type_attrs.emplace_back(dtype); + node->SetAttr("dtype", type_attrs); + return ret; + } + + void SetCallNodeAttribute(JSONGraphObjectPtr node, const CallNode* cn) { + if (cn->op.as()) { + OpAttrExtractor extractor(node); + const Object* call_attr = cn->attrs.get(); + extractor.Extract(const_cast(call_attr)); + } else if (const auto* fn = cn->op.as()) { + ICHECK(false); + auto pattern = fn->GetAttr(attr::kPartitionedFromPattern); + ICHECK(pattern.defined()); + std::vector values; + values.push_back(pattern.value()); + std::vector attr; + attr.emplace_back(values); + node->SetAttr("PartitionedFromPattern", attr); + } + } + + NodeEntries VisitBinding_(const MatchCastNode* binding) { + LOG(FATAL) << "JSON runtime currently doesn't match cast\n"; + return {}; + } + + NodeEntries VisitBinding(const Binding& binding) { + NodeEntries nodes; + if (const auto* node = binding.as()) { + auto from_b = VisitBinding_(node); + nodes.insert(nodes.end(), from_b.begin(), from_b.end()); + } else if (const auto* node = binding.as()) { + auto from_b = VisitBinding_(node); + nodes.insert(nodes.end(), from_b.begin(), from_b.end()); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } + return nodes; + } + + NodeEntries VisitBindingBlock(const BindingBlock& block) { + NodeEntries nodes; + if (const auto* node = block.as()) { + auto from_bb = VisitBindingBlock_(node); + nodes.insert(nodes.end(), from_bb.begin(), from_bb.end()); + } else if (const auto* node = block.as()) { + auto from_bb = VisitBindingBlock_(node); + nodes.insert(nodes.end(), from_bb.begin(), from_bb.end()); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + return nodes; + } + + NodeEntries VisitBindingBlock_(const BindingBlockNode* block) { + NodeEntries nodes; + for (Binding binding : block->bindings) { + auto from_b = VisitBinding(binding); + nodes.insert(nodes.end(), from_b.begin(), from_b.end()); + } + return nodes; + } + + NodeEntries VisitBindingBlock_(const DataflowBlockNode* block) { + NodeEntries nodes; + for (Binding binding : block->bindings) { + auto from_b = VisitBinding(binding); + nodes.insert(nodes.end(), from_b.begin(), from_b.end()); + } + return nodes; + } + + NodeEntries VisitExpr_(const SeqExprNode* op) { + NodeEntries nodes; + for (BindingBlock block : op->blocks) { + VisitBindingBlock(block); + } + auto from_body = VisitExpr(op->body); + nodes.insert(nodes.end(), from_body.begin(), from_body.end()); + return nodes; + } + + NodeEntries VisitExprDefault_(const Object* op) { + LOG(FATAL) << "JSON runtime currently doesn't support " << op->GetTypeKey(); + return {}; + } + + NodeEntries VisitExpr_(const ConstantNode* cn) { + auto name = constant_names_.find(GetRef(cn)); + ICHECK(name != constant_names_.end()) + << "Cannot find the name of the constant: " << GetRef(cn); + constants_used_.push_back((*name).second); + auto node = std::make_shared((*name).second, "const" /* op_type_ */); + return AddNode(node, GetRef(cn)); + } + + NodeEntries VisitExpr_(const TupleNode* tn) { + NodeEntries fields; + for (const auto& field : tn->fields) { + auto ref = VisitExpr(field); + fields.insert(fields.end(), ref.begin(), ref.end()); + } + return fields; + } + + NodeEntries VisitExpr_(const CallNode* cn) { + Expr expr = GetRef(cn); + std::string name; + if (const auto* op_node = cn->op.as()) { + name = op_node->name; + } else if (const auto* fn = cn->op.as()) { + auto comp = fn->GetAttr(attr::kComposite); + ICHECK(comp.defined()) << "JSON runtime only supports composite functions."; + name = comp.value(); + } else { + LOG(FATAL) << "JSON runtime does not support calls to " << cn->op->GetTypeKey(); + } + + // TODO(@sunggg): Revisit when we have op naming convention. + // Currently, simply remove "relax." prefix to make it work. + name = std::string("tensorrt.") + name.substr(6); + + NodeEntries inputs; + for (const auto& arg : cn->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + auto node = std::make_shared(name, /* name_ */ + "kernel", /* op_type_ */ + inputs, 1 /* num_outputs_ */); + SetCallNodeAttribute(node, cn); + return AddNode(node, GetRef(cn)); + } + + NodeEntries VisitExpr_(const TupleGetItemNode* gtn) { + auto vtuple = VisitExpr(gtn->tuple); + return {vtuple[gtn->index]}; + } + + NodeEntries VisitExpr_(const FunctionNode* fn) { + ICHECK(fn->GetAttr(attr::kComposite).defined()) + << "JSON runtime only supports composite functions"; + + // FunctionNode should be handled by the caller. + return {}; + } + + /*! + * \brief Save to JSON graph + * + * \param writer A json writer + */ + void Save(dmlc::JSONWriter* writer) { + std::vector arg_nodes; + for (size_t i = 0; i < nodes_.size(); ++i) { + auto node = nodes_[i]; + if (node->IsLeaf()) { + arg_nodes.push_back(i); + } + } + size_t num_entry = 0; + std::vector node_row_ptr{0}; + for (auto node : nodes_) { + num_entry += node->GetNumOutput(); + node_row_ptr.push_back(num_entry); + } + writer->BeginObject(); + writer->WriteObjectKeyValue("nodes", nodes_); + writer->WriteObjectKeyValue("arg_nodes", arg_nodes); + writer->WriteObjectKeyValue("heads", heads_); + writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr); + writer->EndObject(); + } + + private: + /*! \brief JSON graph nodes. */ + std::vector nodes_; + /*! \brief Output of the JSON graph. */ + NodeEntries heads_; + /*! \brief The list of required constants, ordered. */ + Array constants_used_; + /*! \brief The names of all constants in the original module. */ + const Map& constant_names_; +}; + +} // namespace contrib +} // namespace backend +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_BACKEND_CONTRIB_CODEGEN_JSON_CODEGEN_JSON_H_ diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc b/src/relax/backend/contrib/tensorrt/codegen.cc new file mode 100644 index 000000000000..5ce6bf5e7d42 --- /dev/null +++ b/src/relax/backend/contrib/tensorrt/codegen.cc @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/contrib/tensorrt/codegen.cc + * \brief Implementation of the TensorRT JSON serializer. + */ +#include +// TODO(sunggg): add operator attribute when it's ready +// #include +#include + +#include +#include +#include + +#include "../../../transform/utils.h" +#include "../codegen_json/codegen_json.h" +#include "../utils.h" + +#if TVM_GRAPH_EXECUTOR_TENSORRT +#include "NvInfer.h" +#endif + +namespace tvm { +namespace relax { +namespace contrib { + +/*! \brief Attributes to store the compiler options for TensorRT. */ +struct TensorRTCompilerConfigNode : public tvm::AttrsNode { + Array tensorrt_version; + bool use_implicit_batch; + size_t max_workspace_size; + bool remove_no_mac_subgraphs; + bool use_fp16; + bool use_uint8; + + TVM_DECLARE_ATTRS(TensorRTCompilerConfigNode, "relax.ext.attrs.TensorRTCompilerConfigNode") { + TVM_ATTR_FIELD(tensorrt_version) + .describe("TensorRT version as (major, minor, patch).") + .set_default(Array({6, 0, 1})); + TVM_ATTR_FIELD(use_implicit_batch).set_default(true); + TVM_ATTR_FIELD(max_workspace_size).set_default(size_t(1) << 30); + TVM_ATTR_FIELD(remove_no_mac_subgraphs).set_default(false); + TVM_ATTR_FIELD(use_fp16).set_default(false); + TVM_ATTR_FIELD(use_uint8).set_default(false); + } +}; + +class TensorRTCompilerConfig : public Attrs { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorRTCompilerConfig, Attrs, + TensorRTCompilerConfigNode); +}; + +TVM_REGISTER_NODE_TYPE(TensorRTCompilerConfigNode); +TVM_REGISTER_PASS_CONFIG_OPTION("relax.ext.tensorrt.options", TensorRTCompilerConfig); + +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; +using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; +using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr; +using OpAttrExtractor = backend::contrib::OpAttrExtractor; +using JSONSerializer = backend::contrib::JSONSerializer; + +class TensorRTJSONSerializer; + +/*! + * \brief Collect the constants and attributes from all operator calls in the body + * of a "Composite" function. + */ +class CollectFromCompositeFunctionBody : public ExprVisitor { + public: + explicit CollectFromCompositeFunctionBody(TensorRTJSONSerializer* serializer) + : serializer_(serializer), node_(std::make_shared()) {} + + void VisitExpr_(const ConstantNode* constant_node) final; + void VisitExpr_(const CallNode* call_node) final; + + void SetGenericAttributes(const CallNode* call_node) { + OpAttrExtractor extractor(node_); + const Object* attr_obj = call_node->attrs.get(); + extractor.Extract(const_cast(attr_obj)); + } + + TensorRTJSONSerializer* serializer_; + /*! \brief Accumulated translated arguments. */ + std::vector args_; + /*! + * \brief Temporary node into which we'll accumulate attributes. Ideally this would be the + * final JSONGraphNode however we don't yet know how many inputs that will have. + */ + JSONGraphObjectPtr node_; +}; + +/*! + * \brief Generates an TensorRTModule from a relax expression by serializing the expression to a + * json representation. TensorRT is not required here because use of TensorRT APIs is deferred until + * runtime. + */ +class TensorRTJSONSerializer : public JSONSerializer { + public: + explicit TensorRTJSONSerializer(Map constant_names, Map bindings) + : JSONSerializer(constant_names), bindings_(bindings) {} + + using JSONSerializer::VisitExpr_; + + std::vector VisitExpr_(const CallNode* call_node) final { + // The call must be to an inline "Composite" function + const auto* fn_var = call_node->op.as(); + ICHECK(fn_var); + const auto fn = Downcast(bindings_[GetRef(fn_var)]); + + auto opt_composite = fn->GetAttr(attr::kComposite); + ICHECK(opt_composite.defined()); + std::string name = opt_composite.value(); + + // Collect the constants and attributes of all operator calls inside the composite body. + CollectFromCompositeFunctionBody collector(this); + collector.VisitExpr(fn->body); + + // Capture the args to the "Composite" function as inputs for this node. + std::vector inputs; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + + // Capture constants from the composite function body as additional inputs for this node. + for (const auto& node : collector.args_) { + inputs.emplace_back(node); + } + + // Create the final node. + auto node = std::make_shared(name, + /*op_type=*/"kernel", inputs, + /*num_output=*/1); + + // Transfer attributes from the collector's node to the final node. + node->CaptureAttrs(*collector.node_); + + // Capture global settings on the JSON node. + SaveGlobalAttributes(node); + + VLOG(1) << name << " has " << node->GetInputs().size() << " inputs"; + + return AddNode(node, GetRef(call_node)); + } + + static void SaveGlobalAttributes(std::shared_ptr node) { + auto ctx = transform::PassContext::Current(); + auto cfg = ctx->GetConfig("relax.ext.tensorrt.options"); + if (!cfg.defined()) { + cfg = AttrsWithDefaultValues(); + } + ICHECK_EQ(cfg.value()->tensorrt_version.size(), 3); + std::vector tensorrt_version = { + std::to_string(cfg.value()->tensorrt_version[0].IntValue()), + std::to_string(cfg.value()->tensorrt_version[1].IntValue()), + std::to_string(cfg.value()->tensorrt_version[2].IntValue())}; + std::vector use_implicit_batch = {std::to_string(cfg.value()->use_implicit_batch)}; + std::vector max_workspace_size = {std::to_string(cfg.value()->max_workspace_size)}; + std::vector use_fp16 = {std::to_string(cfg.value()->use_fp16)}; + std::vector use_uint8 = {std::to_string(cfg.value()->use_uint8)}; + std::vector tensorrt_version_attr, use_implicit_batch_attr, max_workspace_size_attr, + use_fp16_attr, use_uint8_attr; + tensorrt_version_attr.emplace_back(tensorrt_version); + use_implicit_batch_attr.emplace_back(use_implicit_batch); + max_workspace_size_attr.emplace_back(max_workspace_size); + use_fp16_attr.emplace_back(use_fp16); + use_uint8_attr.emplace_back(use_uint8); + node->SetAttr("tensorrt_version", tensorrt_version_attr); + node->SetAttr("use_implicit_batch", use_implicit_batch_attr); + node->SetAttr("max_workspace_size", max_workspace_size_attr); + node->SetAttr("use_fp16", use_fp16_attr); + node->SetAttr("use_uint8", use_uint8_attr); + } + + private: + /*! \brief The bindings to look up composite functions. */ + Map bindings_; +}; + +void CollectFromCompositeFunctionBody::VisitExpr_(const ConstantNode* constant_node) { + for (const auto& entry : serializer_->VisitExpr(GetRef(constant_node))) { + args_.emplace_back(entry); + } +} + +void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { + SetGenericAttributes(call_node); + ExprVisitor::VisitExpr_(call_node); +} + +/*! + * \brief Create runtime modules for TensorRT. + * \param functions The extern functions to be compiled via TensorRT + * \return Runtime modules. + */ +Array TensorRTCompiler(Array functions, + Map /*unused*/, + Map constant_names) { + Array compiled_functions; + for (const auto& func : functions) { + VLOG(1) << "TensorRT partition:" << std::endl << func; + TensorRTJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); + serializer.serialize(func); + std::string graph_json = serializer.GetJSON(); + VLOG(1) << "TensorRT JSON:" << std::endl << graph_json; + auto constant_names = serializer.GetConstantNames(); + const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create"); + ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create function."; + std::string func_name = GetExtSymbol(func); + VLOG(1) << "Creating tensorrt runtime::Module for '" << func_name << "'"; + compiled_functions.push_back((*pf)(func_name, graph_json, constant_names)); + } + return compiled_functions; +} + +TVM_REGISTER_GLOBAL("relax.ext.tensorrt").set_body_typed(TensorRTCompiler); + +/*! + * \brief Check whether TensorRT graph executor is enabled. + * \return True if enabled, False if not. + */ +inline constexpr bool IsTensorRTRuntimeEnabled() { +#if TVM_GRAPH_EXECUTOR_TENSORRT + return true; +#else + return false; +#endif // TVM_GRAPH_EXECUTOR_TENSORRT +} + +/*! + * \brief Get TensorRT version that TVM is built against. + * \return Array of three integers for major, minor, and patch, or empty array if TensorRT graph + * runtime is not enabled. + */ +Array GetTensorRTVersion() { +#if TVM_GRAPH_EXECUTOR_TENSORRT + return {Integer(NV_TENSORRT_MAJOR), Integer(NV_TENSORRT_MINOR), Integer(NV_TENSORRT_PATCH)}; +#else + return {}; +#endif // TVM_GRAPH_EXECUTOR_TENSORRT +} + +TVM_REGISTER_GLOBAL("relax.is_tensorrt_runtime_enabled").set_body_typed(IsTensorRTRuntimeEnabled); +TVM_REGISTER_GLOBAL("relax.get_tensorrt_version").set_body_typed(GetTensorRTVersion); + +} // namespace contrib +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/contrib/utils.h b/src/relax/backend/contrib/utils.h new file mode 100644 index 000000000000..4190ad66b6df --- /dev/null +++ b/src/relax/backend/contrib/utils.h @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/backend/contrib/utils.h + * \brief Utils function for backend + */ +#ifndef TVM_RELAX_BACKEND_CONTRIB_UTILS_H_ +#define TVM_RELAX_BACKEND_CONTRIB_UTILS_H_ + +#include +#include + +#include +#include + +#include "../../transform/utils.h" + +namespace tvm { +namespace relax { +namespace backend { + +/*! + * \brief Get the Packed Func + * + * \param func_name + * \return const PackedFunc* + */ +inline const PackedFunc* GetPackedFunc(const std::string& func_name) { + return tvm::runtime::Registry::Get(func_name); +} + +/*! + * \brief Extract shape from an IndexExpr array to std::vector + * + * \param shape The shape in Array + * \return The converted shape in std::vector + */ + +inline std::vector GetIntShape(const Array& shape) { + std::vector ret; + for (const auto& dim : shape) { + const int64_t* pval = tir::as_const_int(dim); + ret.push_back(pval ? *pval : -1); + } + return ret; +} + +/*! + * \brief Convert type to string + * + * \param typ + * \return std::string string format of type + */ +inline std::string DType2String(const tvm::DataType dtype) { + std::ostringstream os; + if (dtype.is_float()) { + os << "float"; + } else if (dtype.is_int()) { + os << "int"; + } else if (dtype.is_uint()) { + os << "uint"; + } else if (dtype.is_bfloat16()) { + os << "bfloat"; + } else if ((*GetPackedFunc("runtime._datatype_get_type_registered"))(dtype.code())) { + os << "custom[" + << (*GetPackedFunc("runtime._datatype_get_type_name"))(dtype.code()).operator std::string() + << "]"; + } else { + LOG(FATAL) << "Unknown type with code " << static_cast(dtype.code()); + } + os << dtype.bits(); + return os.str(); +} + +/*! + * \brief Check if a call node is calling an op with the given name + * \param call The call node whose callee we want to check + * \param op_name The name of the op + * \return true if the callee op matches with the op name + */ +inline bool IsOp(const CallNode* call, const std::string& op_name) { + const auto* op_node = call->op.as(); + if (!op_node) return false; + Op op = GetRef(op_node); + return op == Op::Get(op_name); +} + +/*! + * \brief Return a call node within the function which calls an op with the given name + * The function must contain exactly one call to such op. + * \param f The function to look for an op. + * \param op_name The name of the op + * \return A call node which calls an op with the given name + */ +inline const CallNode* GetOpInFunction(Function f, const std::string& op_name) { + auto local_bindings = AnalyzeVar2Value(f); + for (const auto& entry : local_bindings) { + if (auto call = entry.second.as(); call && backend::IsOp(call, op_name)) { + return call; + } + } + LOG(FATAL) << op_name << " not found in the function:\n" << f; + return nullptr; +} + +} // namespace backend +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_BACKEND_CONTRIB_UTILS_H_ diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc new file mode 100644 index 000000000000..114b7d2a345d --- /dev/null +++ b/src/relax/transform/run_codegen.cc @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/run_codegen.cc + * \brief Run codegen for annotated relax functions. + */ + +#include +#include + +#include + +#include "utils.h" + +namespace tvm { +namespace relax { + +class CodeGenRunner : ExprMutator { + public: + using OptionMap = Map; + + explicit CodeGenRunner(IRModule mod) : ExprMutator(mod) {} + + IRModule Run(Optional> target_options, Array entry_functions) { + IRModule mod = builder_->GetContextIRModule(); + for (const String& entry_func_name : entry_functions) { + auto entry_func = mod->Lookup(entry_func_name); + auto gvar = mod->GetGlobalVar(entry_func_name); + builder_->UpdateFunction(gvar, Downcast(VisitExpr(entry_func))); + } + + auto ext_mods = InvokeCodegen(mod, target_options.value_or({})); + auto out_mod = builder_->GetContextIRModule(); + + if (ext_mods.size()) { + out_mod = WithAttr(out_mod, tvm::attr::kExternalMods, std::move(ext_mods)); + } + + if (constant_names.size()) { + // Some backends (e.g. TensorRT) expect constants to be passed when they are instantiated + Map constants; + for (const auto& [constant, name] : constant_names) { + ICHECK(!constants.count(name)) << "More than one constant with the name " << name; + constants.Set(name, constant->data); + } + out_mod = WithAttr(out_mod, tvm::attr::kConstNameToConstant, std::move(constants)); + } + + // TODO(@tvm-team): Implicit pass dependency. Revisit when we have a better way to handle this. + return RemoveUnusedFunctions(out_mod, entry_functions); + } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call_node) override { + auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + if (auto const* gvar_node = call_node->op.as()) { + const GlobalVar gvar = GetRef(gvar_node); + + auto create_call_tir = [call_node, this](Expr extern_func, StructInfo ret_struct_info) { + Array new_args({extern_func}); + new_args.push_back(Tuple(call_node->args.Map([this](Expr arg) { return VisitExpr(arg); }))); + + static const Op& call_op = Op::Get("relax.call_tir"); + + return Call(call_op, new_args, tvm::Attrs(), {ret_struct_info}); + }; + + if (auto it = extern_funcs_.find(gvar_node); it != extern_funcs_.end()) { + return create_call_tir(it->second.first, it->second.second); + } else { + // TODO(@sunggg): Is there any better way to get this func? + Function func = Downcast(builder_->GetContextIRModule()->Lookup(gvar)); + Expr new_func = VisitExpr(func); + + if (new_func->IsInstance()) { + extern_funcs_[gvar_node] = {new_func, func->ret_struct_info}; + // Remove the global symbol and codegen attributes from the function so that it can be + // removed the module. + static const runtime::PackedFunc* RemoveFuncAttrFunc = + runtime::Registry::Get("ir.BaseFuncWithoutAttr"); + ICHECK(RemoveFuncAttrFunc); + func = (*RemoveFuncAttrFunc)(func, tvm::attr::kGlobalSymbol); + func = (*RemoveFuncAttrFunc)(func, attr::kCodegen); + builder_->UpdateFunction(gvar, func); + return create_call_tir(new_func, func->ret_struct_info); + } + } + } + Array new_args; + for (const auto& arg : call_node->args) { + new_args.push_back(VisitExpr(arg)); + } + + return Call(call_node->op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); + } + + Expr VisitExpr_(const FunctionNode* func_node) override { + Function func = GetRef(func_node); + auto opt_codegen = func->GetAttr(attr::kCodegen); + if (opt_codegen) { + auto ext_symbol = GetExtSymbol(func); + size_t count = 0; + PostOrderVisit(func->body, [=, &count](Expr e) { + if (e->IsInstance()) { + // Make sure to pick a unique name + auto name = ext_symbol + "_" + opt_codegen.value() + "_const_" + std::to_string(count++); + auto constant = Downcast(e); + constant_names.Set(constant, name); + } + }); + return ExternFunc(GetExtSymbol(func)); + } else { + return ExprMutator::VisitExpr_(func_node); + } + } + + private: + Array InvokeCodegen(IRModule mod, Map target_options) { + std::unordered_map> target_functions; + + for (const auto& entry : mod->functions) { + PostOrderVisit(entry.second, [&target_functions](Expr e) { + if (e->IsInstance()) { + auto f = Downcast(e); + if (auto target_opt = f->GetAttr(attr::kCodegen)) { + String target = target_opt.value(); + target_functions[target].push_back(f); + } + } + }); + } + + Array ext_mods; + + for (const auto& [target, functions] : target_functions) { + OptionMap options = target_options.Get(target).value_or({}); + // Start the codegen process. + // Get the codegen with its ffi key. + String codegen_name = "relax.ext." + target; + auto codegen = runtime::Registry::Get(codegen_name); + ICHECK(codegen) << "Codegen is not found: " << codegen_name << "\n"; + + Array compiled_functions = (*codegen)(functions, options, constant_names); + ext_mods.insert(ext_mods.end(), compiled_functions.begin(), compiled_functions.end()); + } + + return ext_mods; + } + + /*! \brief The names of all constants in the original module. */ + Map constant_names; + /*! \brief Extern funcs and their return struct infos for each global variable. */ + std::unordered_map> extern_funcs_; +}; + +} // namespace relax + +namespace transform { +Pass RunCodegen(Optional>> target_options, + Array entry_functions) { + runtime::TypedPackedFunc pass_func = [=](IRModule m, + PassContext pc) { + return relax::CodeGenRunner(m).Run(target_options, entry_functions); + }; + return CreateModulePass(pass_func, 0, "RunCodegen", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.RunCodegen").set_body_typed(RunCodegen); + +} // namespace transform +} // namespace tvm diff --git a/tests/python/relax/test_codegen_tensorrt.py b/tests/python/relax/test_codegen_tensorrt.py new file mode 100644 index 000000000000..164cf3a8189e --- /dev/null +++ b/tests/python/relax/test_codegen_tensorrt.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import numpy as np +import tvm +import tvm.testing + +from tvm import relax, relay +from tvm.script import relax as R +from tvm.relax.dpl import make_fused_bias_activation_pattern, is_op, wildcard + + +def get_relay_residual_block(d_shape, w_shape): + data = relay.var("data", shape=d_shape) + weight1 = relay.var("weight1", shape=w_shape) + weight2 = relay.var("weight2", shape=w_shape) + conv1 = relay.nn.relu( + relay.nn.conv2d( + data=data, + weight=weight1, + padding=(1, 1), + ) + ) + conv2d = relay.nn.relu( + relay.nn.conv2d( + data=conv1, + weight=weight2, + padding=(1, 1), + ) + ) + return conv2d + data + + +@tvm.script.ir_module +class Conv2dResidualBlock: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight1: R.Tensor((64, 64, 3, 3), "float32"), + weight2: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = relax.op.nn.relu(relax.op.nn.conv2d(data, weight1, padding=(1, 1))) + conv2 = relax.op.nn.relu(relax.op.nn.conv2d(conv1, weight2, padding=(1, 1))) + out = relax.op.add(conv2, data) + R.output(out) + + return out + + +has_tensorrt = tvm.get_global_func("relax.ext.tensorrt", True) + +tensorrt_enabled = pytest.mark.skipif( + not has_tensorrt, + reason="TENSORRT not enabled.", +) + +pytestmark = [tensorrt_enabled] + + +def test_tensorrt_offload(): + weight1_np = np.random.randn(64, 64, 3, 3).astype("float32") + weight2_np = np.random.randn(64, 64, 3, 3).astype("float32") + + conv_pat = make_fused_bias_activation_pattern( + "relax.nn.conv2d", with_bias=False, activation=None + ) + relu_pat = is_op("relax.nn.relu")(wildcard()) + add_pat = is_op("relax.add")(wildcard(), wildcard()) + + patterns = [ + ("tensorrt.nn.conv2d", conv_pat), + ("tensorrt.nn.relu", relu_pat), + ("tensorrt.add", add_pat), + ] + + params_np = {"weight1": weight1_np, "weight2": weight2_np} + + mod = tvm.transform.Sequential( + [ + relax.transform.BindParams("main", params_np), + relax.transform.FuseOpsByPattern(patterns), + relax.transform.MergeCompositeFunctions(), + relax.transform.RunCodegen(), + ] + )(Conv2dResidualBlock) + + target = "cuda" + dev = tvm.device(target, 0) + ex = relax.vm.build(mod, target) + + vm = relax.VirtualMachine(ex, dev) + f = vm["main"] + + data_np = np.random.randn(1, 64, 56, 56).astype("float32") + out = f(tvm.nd.array(data_np, dev)).numpy() + + relay_mod = tvm.IRModule.from_expr(get_relay_residual_block(data_np.shape, weight1_np.shape)) + + ref = ( + relay.create_executor("graph", mod=relay_mod, device=tvm.cpu(0), target="llvm") + .evaluate()(*[data_np, weight1_np, weight2_np]) + .numpy() + ) + + tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + test_tensorrt_offload() diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py new file mode 100644 index 000000000000..e50ad8f5f427 --- /dev/null +++ b/tests/python/relax/test_transform_codegen_pass.py @@ -0,0 +1,260 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import os +import tvm +import tvm.testing +from tvm import relax +import numpy as np +from tvm.script import relax as R +from tvm.relax.testing import transform +import tempfile +from tvm.relax.transform.tuning_api import Trace +from tvm.relax.dpl import is_op, wildcard + +env_checker_codegen = tvm.get_global_func("relax.ext.tensorrt", True) +env_checker_runtime = tvm.get_global_func("relax.is_tensorrt_runtime_enabled", True) + +has_tensorrt_codegen = pytest.mark.skipif( + not env_checker_codegen, + reason="TensorRT codegen not available", +) +has_tensorrt_runtime = pytest.mark.skipif( + not env_checker_runtime or not env_checker_runtime(), + reason="TensorRT runtime not available", +) + +# Global variable in pytest that applies markers to all tests. +pytestmark = [has_tensorrt_codegen, has_tensorrt_runtime] + +# Target gpu +target_str = "nvidia/geforce-rtx-3070" # "nvidia/nvidia-t4" +target = tvm.target.Target(target_str) +dev = tvm.cuda() + + +def check_executable(exec, dev, inputs, expected): + vm = relax.VirtualMachine(exec, dev) + out = vm["main"](*inputs) + tvm.testing.assert_allclose(out.numpy(), expected.numpy(), atol=1e-5, rtol=1e-5) + + +def check_roundtrip(exec0, dev, inputs, expected): + exec0.mod.export_library("exec.so") + exec1 = relax.vm.Executable(tvm.runtime.load_module("exec.so")) + os.remove("exec.so") + assert exec0.stats() == exec1.stats() + assert exec0.as_text() == exec1.as_text() + + check_executable(exec0, dev, inputs, expected) + check_executable(exec1, dev, inputs, expected) + + +def gen_ground_truth(mod, target, dev, inputs): + # Lower and run tuning + # Since there is no default schedule for GPU in MS yet, this is necessary + with tempfile.TemporaryDirectory() as work_dir: + with target, tvm.transform.PassContext(trace=Trace(mod), opt_level=0): + seq = tvm.transform.Sequential( + [ + relax.transform.LegalizeOps(), + relax.transform.MetaScheduleTuneIRMod( + params={}, work_dir=work_dir, max_trials_global=8 + ), + relax.transform.MetaScheduleApplyDatabase(work_dir), + ] + ) + new_mod = seq(mod) + assert relax.analysis.well_formed(new_mod) + exec = relax.vm.build(new_mod, target, params={}) + vm = relax.VirtualMachine(exec, dev) + return vm["main"](*inputs) + + +@tvm.script.ir_module +class InputModule: + @R.function + def main( + x: R.Tensor((16, 16), "float32"), y: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): + with R.dataflow(): + z1 = R.multiply(x, y) + z2 = R.add(z1, x) + z3 = R.add(z1, z2) + z4 = R.multiply(z3, z2) + z5 = R.add(z4, z1) + R.output(z5) + return z5 + + +def setup_test(): + # Prepare IRModule and its input + mod = InputModule + assert isinstance(mod, tvm.IRModule) + + np0 = np.random.rand(16, 16).astype(np.float32) + np1 = np.random.rand(16, 16).astype(np.float32) + data0 = tvm.nd.array(np0, dev) + data1 = tvm.nd.array(np1, dev) + inputs = [data0, data1] + + # Ground truth should be generated before annotation + # due to the conflict with MS task extraction + # TODO(@sunggg): Sort this out + expected = gen_ground_truth(mod, target, dev, inputs) + return mod, inputs, expected + + +@tvm.testing.requires_gpu +def test_tensorrt_only(): + mod, inputs, expected = setup_test() + + # Define patterns that we want to offload to byoc + # This test will offload entire model + # Thus, define patterns for both `multiply` and `add` ops + patterns = [ + ("tensorrt.multiply", is_op("relax.multiply")(wildcard(), wildcard())), + ("tensorrt.add", is_op("relax.add")(wildcard(), wildcard())), + ] + + new_mod = tvm.transform.Sequential( + [ + relax.transform.FuseOpsByPattern(patterns), + relax.transform.MergeCompositeFunctions(), + relax.transform.RunCodegen(), + ] + )(mod) + + ex0 = relax.vm.build(new_mod, target, params={}) + # Sanity check for the correctness and rountrip + check_roundtrip(ex0, dev, inputs, expected) + + +@tvm.testing.requires_gpu +def test_mix_use_tensorrt_and_tvm(): + mod, inputs, expected = setup_test() + + # Define patterns that we want to offload to byoc + # This test will only offload `add` op to tensorrt + # and tune `multiply` op with MetaSchedule + patterns = [ + ("tensorrt.add", is_op("relax.add")(wildcard(), wildcard())), + ] + + # Run Codegen pass + with tempfile.TemporaryDirectory() as work_dir: + with target, tvm.transform.PassContext(trace=Trace(mod), opt_level=0): + new_mod = tvm.transform.Sequential( + [ + relax.transform.FuseOpsByPattern(patterns), + relax.transform.MergeCompositeFunctions(), + relax.transform.RunCodegen(), + relax.transform.LegalizeOps(), + relax.transform.MetaScheduleTuneIRMod( + params={}, work_dir=work_dir, max_trials_global=8 + ), + relax.transform.MetaScheduleApplyDatabase(work_dir), + ] + )(mod) + assert relax.analysis.well_formed(new_mod) + with transform.PassContext(opt_level=0): + ex0 = relax.vm.build(new_mod, target, params={}) + + # Sanity check for the correctness and rountrip + check_roundtrip(ex0, dev, inputs, expected) + + +@tvm.script.ir_module +class Conv2dx2: + @R.function + def main( + data: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight1: R.Tensor((16, 3, 3, 16), dtype="float16"), + weight2: R.Tensor((16, 3, 3, 16), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 16), dtype="float16"): + with R.dataflow(): + lv: R.Tensor((16, 32, 32, 16), dtype="float16") = fused_relax_nn_conv2d_tensorrt( + data, weight1 + ) + gv: R.Tensor((16, 32, 32, 16), dtype="float16") = fused_relax_nn_conv2d_tensorrt( + lv, weight2 + ) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d_tensorrt( + data: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight1: R.Tensor((16, 3, 3, 16), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 16), dtype="float16"): + R.func_attr({"Codegen": "tensorrt", "global_symbol": "fused_relax_nn_conv2d_tensorrt"}) + + @R.function + def gv( + data_1: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight1_1: R.Tensor((16, 3, 3, 16), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 16), dtype="float16"): + R.func_attr({"Composite": "tensorrt.conv2d", "Primitive": 1}) + with R.dataflow(): + gv_1: R.Tensor((16, 32, 32, 16), dtype="float16") = R.nn.conv2d( + data_1, + weight1_1, + padding=[1, 1, 1, 1], + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + R.output(gv_1) + return gv_1 + + gv1: R.Tensor((16, 32, 32, 16), dtype="float16") = gv(data, weight1) + return gv1 + + +@tvm.script.ir_module +class Conv2dx2_after: + @R.function + def main( + data: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight1: R.Tensor((16, 3, 3, 16), dtype="float16"), + weight2: R.Tensor((16, 3, 3, 16), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 16), dtype="float16"): + with R.dataflow(): + lv = R.call_tir( + "fused_relax_nn_conv2d_tensorrt", + (data, weight1), + out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"), + ) + gv = R.call_tir( + "fused_relax_nn_conv2d_tensorrt", + (lv, weight2), + out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"), + ) + R.output(gv) + return gv + + +def test_multiple_calls_same_extern(): + mod = relax.transform.RunCodegen()(Conv2dx2) + tvm.ir.assert_structural_equal(mod["main"], Conv2dx2_after["main"]) + + +# TODO(@sunggg): test with more complex patterns (e.g., multiple annots, mixed codegens, different ops, const binding) + +if __name__ == "__main__": + pytest.main([__file__]) From 69cf869a4a2927698e2a9b173d2b1c27718af39a Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Tue, 21 Feb 2023 21:31:52 -0800 Subject: [PATCH 54/81] [Unity][Pass] Canonicalize Bindings (#14079) It may be useful for some passes to collapse chains of definitions, particularly after other compiler transformations that may reduce or simplify some expressions. This pass will take chains of definitions and replace references to later definitions to the original one. It works by checking `LookupBinding` for each var use-site and replacing the var with its definition if the definition was another var. Additionally, `MatchCast` bindings where the LHS and the RHS are guaranteed to match at compile time are canonicalized into ordinary `VarBinding`s. Example: ```python y = x z = y w = z o = w p = o ``` Will be replaced with ```python y = x z = x w = x o = x p = x ``` Original PR: https://github.com/tlc-pack/relax/pull/233 Co-authored-by: Steven S. Lyubomirsky --- include/tvm/relax/transform.h | 9 + python/tvm/relax/transform/transform.py | 14 ++ src/relax/transform/canonicalize_bindings.cc | 135 +++++++++++ .../test_transform_canonicalize_bindings.py | 224 ++++++++++++++++++ 4 files changed, 382 insertions(+) create mode 100644 src/relax/transform/canonicalize_bindings.cc create mode 100644 tests/python/relax/test_transform_canonicalize_bindings.py diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 7d6c93bcde48..b42fb5864ef7 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -127,6 +127,15 @@ TVM_DLL Pass AttachGlobalSymbol(); */ TVM_DLL Pass Normalize(); +/*! + * \brief Simplify a Relax module by folding var bindings and match shape nodes. + * May include other forms of expression simplification in the future. + * Best used alongside constant folding and eliminating unused bindings. + * + * \return The Pass. + */ +TVM_DLL Pass CanonicalizeBindings(); + /*! * \brief Bind params of function of the module to constant tensors. * diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 9fb2458dc026..c72d053290af 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -80,6 +80,20 @@ def Normalize() -> tvm.ir.transform.Pass: return _ffi_api.Normalize() # type: ignore +def CanonicalizeBindings() -> tvm.ir.transform.Pass: + """ + Canonicalizes variable definitions + (e.g., if there is y = x and z = y, it replaces uses of y and z with x). + + Best combined with constant folding and the elimination of unused definitions. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.CanonicalizeBindings() # type: ignore + + def RewriteDataflowReshape() -> tvm.ir.transform.Pass: """Convert all reshape-like call_tir to VM reshape operator call. The VM reshape operator calls will be further lowered to a CreateView diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc new file mode 100644 index 000000000000..962f76a376b6 --- /dev/null +++ b/src/relax/transform/canonicalize_bindings.cc @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/canonicalize_bindings.cc + * \brief Pass for simplifying modules by folding var bindings and match shape nodes. + * May include other forms of simplification in the future. + * Ideally should be used before constant folding and eliminating unused bindings. + */ + +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +class BindingCanonicalizer : public ExprMutator { + public: + BindingCanonicalizer() {} + + Expr VisitExpr_(const VarNode* op) override { + // remap first + Var v = Downcast(ExprMutator::VisitExpr_(op)); + if (!CanCanonicalizeVar(v)) { + return Downcast(v); + } + // visit again in case we need to do a substitution in the value + return ExprMutator::VisitExpr_(LookupBinding(v).as()); + } + + Expr VisitExpr_(const DataflowVarNode* op) override { + Var v = Downcast(ExprMutator::VisitExpr_(op)); + if (!CanCanonicalizeVar(v)) { + return Downcast(v); + } + return ExprMutator::VisitExpr_(LookupBinding(v).as()); + } + + void VisitBinding_(const VarBindingNode* binding) override { + // Unlike default visitor, we do not permit the checked type to change + // if the new value's checked type is different (this preserves user annotations) + Expr new_value = this->VisitExpr(binding->value); + Var new_var = this->VisitVarDef(binding->var); + + if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { + this->builder_->EmitNormalized(GetRef(binding)); + return; + } + + this->builder_->EmitNormalized(VarBinding(new_var, new_value)); + } + + void VisitBinding_(const MatchCastNode* binding) override { + // If we have a trivial shape check (the shape_ of LHS and RHS is the same), + // we can canonicalize to a var binding + Expr new_value = this->VisitExpr(binding->value); + + // if the LHS and RHS have the same struct info, we canonicalize to a var binding instead + if (StructuralEqual()(binding->struct_info, GetStructInfo(new_value))) { + builder_->EmitNormalized(VarBinding(binding->var, new_value)); + } else if (new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + } else { + builder_->EmitNormalized(MatchCast(binding->var, new_value, binding->struct_info)); + } + } + + private: + bool AnnotationsDiffer(const ObjectRef& obj1, const ObjectRef& obj2, + std::function check_eq) { + // annotations differ if one is present but not the other + // or they're both present and they differ + bool both_present = obj1.defined() && obj2.defined(); + bool neither_present = !obj1.defined() && !obj2.defined(); + return !(both_present || neither_present) || (both_present && !check_eq(obj1, obj2)); + } + + bool CanCanonicalizeVar(Var v) { + Optional value = LookupBinding(v); + // can replace only if the value is also a var + if (!value || !value.as()) { + return false; + } + Var parent_var = Downcast(value); + + // Cases when we conservatively do not unify: + // 1. checked_type_ or shape_ of the child differs from that of the parent + // In this case, we could be overriding user annotations. + // 2. If the child is a Var and the parent is a DataflowVar. + // That could result in a DataflowVar leaving the current DataflowBlock. + bool annotations_differ = AnnotationsDiffer(v->struct_info_, parent_var->struct_info_, + [&](const ObjectRef& lhs, const ObjectRef& rhs) { + return tvm::StructuralEqual()(lhs, rhs); + }); + bool var_to_dataflow = (!v.as() && parent_var.as()); + return !annotations_differ && !var_to_dataflow; + } +}; + +Expr CanonicalizeBindings(const Expr& e) { return BindingCanonicalizer().VisitExpr(e); } + +namespace transform { + +Pass CanonicalizeBindings() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(CanonicalizeBindings(f)); + }; + return CreateFunctionPass(pass_func, 1, "CanonicalizeBindings", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.CanonicalizeBindings").set_body_typed(CanonicalizeBindings); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py new file mode 100644 index 000000000000..4694e98973f4 --- /dev/null +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -0,0 +1,224 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.script +import tvm.testing +import pytest +from tvm import relax +from tvm.ir.base import assert_structural_equal +from tvm.script import relax as R, tir as T + + +def test_simple_assignments(): + @tvm.script.ir_module + class TestChainAssignments: + @R.function + def main(x: R.Tensor): + y = x + z = y + q = z + p = q + o = p + return o + + # a little annoying to have these unused bindings around + # but they can be eliminated in a separate pass + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor): + y = x + z = x + q = x + p = x + o = x + return x + + new_mod = relax.transform.CanonicalizeBindings()(TestChainAssignments) + assert_structural_equal(new_mod, Expected) + + +def test_dataflow_block(): + @tvm.script.ir_module + class TestDataflowAssignments: + @R.function + def main(x: R.Tensor): + with R.dataflow(): + y = R.const(1) + z = y + o = z + p = o + m = p + n = m + R.output(n) + return n + + # a little annoying to have these unused bindings around + # but they can be eliminated in a separate pass + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor): + with R.dataflow(): + y = R.const(1) + z = y + o = y + p = y + m = y + # we can't get rid of n because it leaves the block + n = y + R.output(n) + return n + + new_mod = relax.transform.CanonicalizeBindings()(TestDataflowAssignments) + assert_structural_equal(new_mod, Expected) + + +def test_ops(): + @tvm.script.ir_module + class TestOps: + @R.function + def main(x: R.Tensor, y: R.Tensor): + w = y + q = x + z = R.add(w, q) + return R.add(q, z) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor, y: R.Tensor): + w = y + q = x + z = R.add(y, x) + return R.add(x, z) + + new_mod = relax.transform.CanonicalizeBindings()(TestOps) + assert_structural_equal(new_mod, Expected) + + +@pytest.mark.xfail(reason="The lhs and rhs of an assignment should have the same struct info.") +def test_casting(): + @tvm.script.ir_module + class TestCasting: + @R.function + def main(x: R.Tensor) -> R.Object: + y = x + # z will be treated as object type even though it's a tensor + z: R.Object = y + return z + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor) -> R.Object: + y = x + # Cannot unify because the cast indicates user intent + z: R.Object = x + return z + + new_mod = relax.transform.CanonicalizeBindings()(TestCasting) + assert_structural_equal(new_mod, Expected) + + +def test_match_cast(): + @tvm.script.ir_module + class TestMatchCast: + @R.function + def main(x: R.Tensor): + q = x + m, n = T.var("int64"), T.var("int64") + z = R.match_cast(q, R.Tensor((m, n))) + w = z + return w + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor): + q = x + # can't get rid of z because its shape_ is different from x's + m, n = T.var("int64"), T.var("int64") + z = R.match_cast(x, R.Tensor((m, n))) + w = z + return z + + new_mod = relax.transform.CanonicalizeBindings()(TestMatchCast) + assert_structural_equal(new_mod, Expected) + + +def test_same_shape(): + @tvm.script.ir_module + class TestSameShape: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")): + m, n = T.var("int64"), T.var("int64") + y = x + # trivial check + z = R.match_cast(x, R.Tensor((m, n), "float32")) + w = z + q = R.add(w, y) + return R.add(q, w) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")): + m, n = T.var("int64"), T.var("int64") + y = x + # canonicalized into a var binding + z = x + w = x + q = R.add(x, x) + return R.add(q, x) + + new_mod = relax.transform.CanonicalizeBindings()(TestSameShape) + assert_structural_equal(new_mod, Expected) + + +def test_change_shape(): + @tvm.script.ir_module + class TestChangeShape: + @R.function + def main(x: R.Tensor(("m", "n"))): + y = x + # not trivial: introduces new shape vars + o, p = T.var("int64"), T.var("int64") + z = R.match_cast(x, R.Tensor((o, p))) + w = z + q = R.add(w, y) + return R.add(q, w) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"))): + y = x + o, p = T.var("int64"), T.var("int64") + z = R.match_cast(x, R.Tensor((o, p))) + w = z + # the shape_ field on q will need to be updated + q = R.add(z, x) + return R.add(q, z) + + new_mod = relax.transform.CanonicalizeBindings()(TestChangeShape) + assert_structural_equal(new_mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() From a40f1da47fa2ef5c6b9bdef3ae6830bf8397ec73 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 22 Feb 2023 20:51:10 +0800 Subject: [PATCH 55/81] [Unity] Add testcases for `expr_args_converter` (#14080) This is a missing test file when we added the `expr_args_converter`. This PR adds it back. --- .../python/relax/test_expr_args_converter.py | 146 ++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 tests/python/relax/test_expr_args_converter.py diff --git a/tests/python/relax/test_expr_args_converter.py b/tests/python/relax/test_expr_args_converter.py new file mode 100644 index 000000000000..bd058e897979 --- /dev/null +++ b/tests/python/relax/test_expr_args_converter.py @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Callable, List, Optional, Union + +import pytest +import tvm +import tvm.testing +from tvm import relax +from tvm.relax import Expr +from tvm.relax.utils import args_converter + + +def _test_base(f_checker: Callable, arg: Any, *args: Any, **kwargs: Any) -> None: + # Test converting to `Expr` + assert f_checker(arg) + # Test converting `*args` + assert isinstance(args, tuple) + assert all([f_checker(arg) for arg in args]) + # Test converting `**kwargs` + assert isinstance(kwargs, dict) + assert all([f_checker(arg) for arg in kwargs.values()]) + + +def _test_expr(arg: Expr, *args: Expr, **kwargs: Expr) -> None: + f_checker = lambda x: isinstance(x, Expr) + _test_base(f_checker, arg, *args, **kwargs) + + +def _test_optional_expr( + arg: Optional[Expr], *args: Optional[Expr], **kwargs: Optional[Expr] +) -> None: + f_checker = lambda x: x is None or isinstance(x, Expr) + _test_base(f_checker, arg, *args, **kwargs) + + +def _test_list_expr(arg: List[Expr], *args: List[Expr], **kwargs: List[Expr]) -> None: + f_checker = lambda x: isinstance(x, list) and all([isinstance(arg, Expr) for arg in x]) + _test_base(f_checker, arg, *args, **kwargs) + + +def _test_optional_list_expr( + arg: Optional[List[Expr]], *args: Optional[List[Expr]], **kwargs: Optional[List[Expr]] +) -> None: + f_checker = lambda x: x is None or ( + isinstance(x, list) and all([isinstance(arg, Expr) for arg in x]) + ) + _test_base(f_checker, arg, *args, **kwargs) + + +prim_value = 1 +str_value = "value_to_be_convert" +shape_value = (1, 1) +tuple_value = (relax.const(1), (1, 1)) +placeholder = relax.const(0) + +test_cases = [prim_value, str_value, shape_value, tuple_value, placeholder] + + +def test_args_to_expr(): + for _f in [_test_expr, _test_optional_expr]: + f = args_converter.to_expr("arg", "args", "kwargs")(_f) + for x in test_cases: + f( + x, + x, # the first argument in *args + x, # the second argument in *args + test_kwargs=x, + ) + + if _f == _test_optional_expr: + f(None, None, x, test_kwargs=None) + + +def test_args_to_list_expr(): + for _f in [_test_list_expr, _test_optional_list_expr]: + f = args_converter.to_list_expr("arg", "args", "kwargs")(_f) + for x in test_cases: + f( + [x], + [x], # the first argument in *args + [x, x], # the second argument in *args + test_kwargs=[x, (x,)], + ) + + if _f == _test_optional_list_expr: + f(None, None, [x], test_kwargs=None) + + +def test_error(): + f = args_converter.to_list_expr("arg", "args", "kwargs")(_test_list_expr) + with pytest.raises(TypeError): + f(prim_value) # fail to convert prim_value to `List[Expr]` + + +def test_auto_convert(): + for _f in [_test_expr, _test_optional_expr]: + f = args_converter.auto(_f) + for x in test_cases: + f(x, (x,), test_kwargs=x) + + if _f == _test_optional_expr: + f(None, x, test_kwargs=None) + + for _f in [_test_list_expr, _test_optional_list_expr]: + f = args_converter.auto(_f) + for x in test_cases: + f([x], [x, x], test_kwargs=[x, (x,)]) + + if _f == _test_optional_list_expr: + f(None, None, [x], test_kwargs=None) + + +def test_auto_convert_skip(): + def _test_expr_skip(arg: int, *args: Union[str, Expr], **kwargs: List[Optional[Expr]]) -> None: + f_checker = lambda x: not isinstance(x, Expr) + _test_base(f_checker, arg, *args, **kwargs) + + f = args_converter.auto(_test_expr_skip) + f(1, "str", test_kwargs=[None]) + + +def test_empty_tuple(): + def _test(arg: Expr): + assert isinstance(arg, relax.Tuple) + + f = args_converter.auto(_test) + f(()) + + +if __name__ == "__main__": + tvm.testing.main() From 59692e75a975e1f0c6e7aab015879d931bc245f5 Mon Sep 17 00:00:00 2001 From: masahi Date: Wed, 22 Feb 2023 21:55:01 +0900 Subject: [PATCH 56/81] [Unity][BYOC] Add CUTLASS backend (#14081) Co-authored-by: Lite Ye --- cmake/modules/contrib/CUTLASS.cmake | 4 +- python/tvm/contrib/cutlass/build.py | 219 ++++++++- python/tvm/contrib/cutlass/gemm_profiler.py | 4 +- python/tvm/contrib/cutlass/gen_gemm.py | 26 +- python/tvm/contrib/cutlass/gen_tensor_op.py | 4 + python/tvm/relax/vm.py | 12 +- src/ir/function.cc | 18 + src/relax/backend/contrib/cutlass/codegen.cc | 274 +++++++++++ src/relax/ir/expr.cc | 24 + .../backend/contrib/codegen_c/codegen_c.h | 35 +- src/relay/backend/contrib/cutlass/codegen.h | 21 + tests/python/relax/test_codegen_cutlass.py | 429 ++++++++++++++++++ 12 files changed, 1044 insertions(+), 26 deletions(-) create mode 100644 src/relax/backend/contrib/cutlass/codegen.cc create mode 100644 tests/python/relax/test_codegen_cutlass.py diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake index afd5ef530252..4b4ef355b678 100644 --- a/cmake/modules/contrib/CUTLASS.cmake +++ b/cmake/modules/contrib/CUTLASS.cmake @@ -16,8 +16,8 @@ # under the License. if(USE_CUDA AND USE_CUTLASS) - tvm_file_glob(GLOB CUTLASS_RELAY_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc) - list(APPEND COMPILER_SRCS ${CUTLASS_RELAY_CONTRIB_SRC}) + tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc src/relax/backend/contrib/cutlass/*.cc) + list(APPEND COMPILER_SRCS ${CUTLASS_CONTRIB_SRC}) message(STATUS "Build with CUTLASS") endif() diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 363548fb2ba0..ad0e59af02fa 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -17,16 +17,17 @@ # pylint: disable=invalid-name, dangerous-default-value, arguments-differ """Driver for partitioning and building a Relay module for CUTLASS offload.""" import logging -import os import multiprocessing +import os + import tvm -from tvm import relay, runtime +from tvm import relax, relay, runtime from tvm._ffi.registry import register_func from tvm.contrib.nvcc import get_cuda_version from .gen_conv2d import CutlassConv2DProfiler from .gen_gemm import CutlassGemmProfiler -from .library import ConvKind +from .library import ConvKind, LayoutType logger = logging.getLogger("cutlass") @@ -521,6 +522,192 @@ def tune_cutlass_function( ) +def _extract_relax_function_info(f): + signature = {} + + for i, arg in enumerate(f.params): + sinfo = arg.struct_info + signature["arg%d_shape" % i] = list(sinfo.shape) + signature["arg%d_dtype" % i] = sinfo.dtype + + ret_sinfo = f.ret_struct_info + signature["ret_shape"] = list(ret_sinfo.shape) + signature["ret_dtype"] = ret_sinfo.dtype + + op_attrs = {} + + def fvisit(e): + nonlocal op_attrs + if isinstance(e, relax.Call) and e.op.name in ["relax.nn.conv2d"]: + op_attrs = e.attrs + + relax.analysis.post_order_visit(f.body, fvisit) + + return signature, op_attrs + + +@relax.expr_functor.mutator +class CutlassRelaxFunctionAnnotator(relax.PyExprMutator): + """A Relax function mutator that tunes and annotates CUTLASS composite functions + with shape, dtype and generated templates. + """ + + def __init__( + self, + mod, + conv2d_profiler: CutlassConv2DProfiler, + gemm_profiler: CutlassGemmProfiler, + options, + ): + super().__init__(mod) + self.options = options + self.conv2d_profiler = conv2d_profiler + self.gemm_profiler = gemm_profiler + + def handle_conv2d(self, f, op_type): + """Tune and annotate a conv2d op.""" + signature, op_attrs = _extract_relax_function_info(f) + + d_shape = signature["arg0_shape"] + w_shape = signature["arg1_shape"] + out_shape = signature["ret_shape"] + data_dtype = signature["arg0_dtype"] + weight_dtype = signature["arg1_dtype"] + out_dtype = signature["ret_dtype"] + padding = op_attrs["padding"] + strides = op_attrs["strides"] + dilation = op_attrs["dilation"] + conv_kind = ConvKind.Fprop + + use_3xtf32 = self.options.get("use_3xtf32", False) + profile_all_alignments = self.options.get("profile_all_alignments", False) + find_first_valid = self.options.get("find_first_valid", True) + use_multiprocessing = self.options.get("use_multiprocessing", True) + split_k_slices = self.options.get("split_k_slices", [1]) + + op_name, op_def, _ = self.conv2d_profiler.profile( + op_type, + d_shape, + w_shape, + padding, + strides, + dilation, + out_dtype, + data_dtype, + weight_dtype, + use_3xtf32, + conv_kind, + split_k_slices, + profile_all_alignments, + find_first_valid=find_first_valid, + use_multiprocessing=use_multiprocessing, + ) + + return f.with_attrs( + { + "op_type": op_type, + "arg0_dtype": data_dtype, + "arg1_dtype": weight_dtype, + "ret_dtype": out_dtype, + "arg0_shape": d_shape, + "arg1_shape": w_shape, + "ret_shape": out_shape, + "strides": strides, + "padding": padding, + "dilation": dilation, + "cutlass_op_name": op_name, + "cutlass_op_def": op_def, + } + ) + + def handle_matmul(self, f, op_type): + """Tune and annotate a dense op.""" + signature, _ = _extract_relax_function_info(f) + + arg0_shape = signature["arg0_shape"] + arg1_shape = signature["arg1_shape"] + out_shape = signature["ret_shape"] + arg0_dtype = signature["arg0_dtype"] + arg1_dtype = signature["arg1_dtype"] + out_dtype = signature["ret_dtype"] + + MM = arg0_shape[0] + KK = arg0_shape[1] + NN = arg1_shape[1] + + use_3xtf32 = self.options.get("use_3xtf32", False) + find_first_valid = self.options.get("find_first_valid", True) + use_multiprocessing = self.options.get("use_multiprocessing", True) + + op_name, op_def, _ = self.gemm_profiler.profile( + op_type, + MM, + NN, + KK, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + batched=False, + find_first_valid=find_first_valid, + use_multiprocessing=use_multiprocessing, + layout_b=LayoutType.RowMajor, + ) + + return f.with_attrs( + { + "op_type": op_type, + "arg0_dtype": arg0_dtype, + "arg1_dtype": arg1_dtype, + "ret_dtype": out_dtype, + "arg0_shape": arg0_shape, + "arg1_shape": arg1_shape, + "ret_shape": out_shape, + "lda": "K", + "ldb": "N", + "ldc": "N", + "cutlass_op_name": op_name, + "cutlass_op_def": op_def, + } + ) + + def visit_function_(self, f): + if "Composite" not in f.attrs: + body = super().visit_expr(f.body) + return relax.Function(f.params, body, f.ret_struct_info, f.attrs, f.span) + + op_type = f.attrs["Composite"] + + if "conv2d" in op_type: + return self.handle_conv2d(f, op_type) + elif "matmul" in op_type: + return self.handle_matmul(f, op_type) + + raise ValueError("Unsupported composite {}".format(op_type)) + + def visit_span(self, span): + return span + + +@register_func("contrib.cutlass.tune_relax_function") +def profile_relax_function(functions, options): + """Tune and annotate CUTLASS composite functions with shape, dtype and generated templates.""" + tmp_dir = options.get("tmp_dir", "./tmp") + sm = options.get("sm", 80) + conv2d_profiler = CutlassConv2DProfiler(sm, _get_cutlass_path(), tmp_dir) + gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), tmp_dir) + + annotated_functions = [] + + for f in functions: + annotator = CutlassRelaxFunctionAnnotator( + tvm.IRModule.from_expr(f), conv2d_profiler, gemm_profiler, options + ) + annotated_functions.append(annotator.visit_expr(f)) + + return annotated_functions + + @register_func("contrib.cutlass.compile") def compile_cutlass_module(c_source_module, options): """Compile all CUTLASS kernels in the given C-source module. @@ -664,3 +851,29 @@ def finalize_modules_vm(vm_exec, lib_path="compile.so", vmcode_path="vmcode.ro", fo.write(code) lib = tvm.runtime.load_module(lib_path) return tvm.runtime.vm.Executable.load_exec(code, lib) + + +def finalize_modules_relax(vm_exec, lib_path="compile.so", tmp_dir="./tmp"): + """finalize_modules_vm equivalent for Relax VM. + + Parameters + ---------- + vm_exec : vm.Executable + The output from relax.vm.build containing compiled host code and kernels. + + lib_path : string + The path to a shared library which will be generated as the result of the build process. + + tmp_dir : string + A temporary directory where intermediate compiled artifacts will be stored. + + Returns + ------- + updated_vm_exec : relax.vm.Executable + The updated VM executable with all compilation and linking completed. + """ + lib_path = os.path.join(tmp_dir, lib_path) + vm_exec.mod.export_library(lib_path, workspace_dir=tmp_dir, cc="nvcc") + lib = tvm.runtime.load_module(lib_path) + + return relax.vm.Executable(lib) diff --git a/python/tvm/contrib/cutlass/gemm_profiler.py b/python/tvm/contrib/cutlass/gemm_profiler.py index 13679cd05c42..e89e7defbfb7 100644 --- a/python/tvm/contrib/cutlass/gemm_profiler.py +++ b/python/tvm/contrib/cutlass/gemm_profiler.py @@ -55,7 +55,7 @@ def __init__(self): } template -cudaError_t CutlassGemmRCR( +cudaError_t CutlassGemm( int M, int N, int K, @@ -148,7 +148,7 @@ def __init__(self): cudaFree(B); return result; } - result = CutlassGemmRCR(M, N, K, alpha, A, lda, B, ldb, + result = CutlassGemm(M, N, K, alpha, A, lda, B, ldb, beta, C_cutlass, ldc); if (result != cudaSuccess) { std::cerr << "CUTLASS GEMM kernel failed: " diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index ddeddbd39cac..6aa4c5122164 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=invalid-name """GEMM kernel generator and profiler for CUTLASS.""" +from functools import partial + from .gemm_operation import EmitGemmInstance, GemmOperation from .gemm_profiler import GemmProfilerEmitter from .gen_tensor_op import EPILOGUE_MAP, GENERATOR_FUNC_TABLE, ProfilerEngine @@ -36,6 +38,7 @@ def create_gemm_operator_with_epilogue( alignment, swizzling_functor, batched=False, + layout_b=LayoutType.ColumnMajor, ): """ Instantiate a cutlass kernel from the given configuration, @@ -44,7 +47,7 @@ def create_gemm_operator_with_epilogue( element_a, element_b, element_c, element_epilogue = data_type A = TensorDescription(element_a, LayoutType.RowMajor, alignment) - B = TensorDescription(element_b, LayoutType.ColumnMajor, alignment) + B = TensorDescription(element_b, layout_b, alignment) C = TensorDescription(element_c, LayoutType.RowMajor, alignment) if batched: @@ -74,6 +77,7 @@ def enumerate_gemm_operators( data_type, alignment_constraints, swizzling_functor=SwizzlingFunctor.Identity8, + layout_b=LayoutType.ColumnMajor, ): """Exhaustively instantiate all kernels from a given configuration.""" ret = [] @@ -85,7 +89,7 @@ def enumerate_gemm_operators( for tile_description in tile_descriptions: for alignment in alignment_constraints: A = TensorDescription(element_a, LayoutType.RowMajor, alignment) - B = TensorDescription(element_b, LayoutType.ColumnMajor, alignment) + B = TensorDescription(element_b, layout_b, alignment) C = TensorDescription(element_c, LayoutType.RowMajor, alignment) if element_c == DataType.s32 and A.alignment == 1: @@ -160,7 +164,14 @@ def __init__(self, sm, cutlass_path, binary_path): self.cache = {} def get_default( - self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32=True, batched=False + self, + op_type, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32=True, + batched=False, + layout_b=LayoutType.ColumnMajor, ): """Return the default kernel for the requested architecture. For now, the default kernel was picked arbitrary. @@ -169,7 +180,7 @@ def get_default( out_dtype, arg0_dtype, arg1_dtype, - enumerate_gemm_operators, + partial(enumerate_gemm_operators, layout_b=layout_b), lambda align: align == 1, # Only request align1 kernels use_3xtf32, profile_all_alignments=True, # To include all align1 kernels @@ -194,6 +205,7 @@ def get_default( op["alignment"], op["swizzle_functor"], batched=batched, + layout_b=layout_b, ) op.update({"name": name, "opdef": opdef}) return op @@ -210,6 +222,7 @@ def select_op( profile_all_alignments=False, find_first_valid=False, use_multiprocessing=False, + layout_b=LayoutType.ColumnMajor, ): """ Profile and select the best kernel from candidate kernels. @@ -227,7 +240,7 @@ def select_op( out_dtype, arg0_dtype, arg1_dtype, - enumerate_gemm_operators, + partial(enumerate_gemm_operators, layout_b=layout_b), lambda align: all([dim % align == 0 for dim in [M, N, K]]), use_3xtf32, profile_all_alignments=profile_all_alignments, @@ -263,6 +276,7 @@ def profile( find_first_valid=False, use_multiprocessing=False, batched=False, + layout_b=LayoutType.ColumnMajor, ): """Profile and select the best kernel from candidate kernels. If find_first_valid is True, return immediately after the first applicable kernel is found. @@ -279,6 +293,7 @@ def profile( profile_all_alignments=profile_all_alignments, find_first_valid=find_first_valid, use_multiprocessing=use_multiprocessing, + layout_b=layout_b, ) name, opdef = create_gemm_operator_with_epilogue( @@ -288,6 +303,7 @@ def profile( op["alignment"], op["swizzle_functor"], batched=batched, + layout_b=layout_b, ) return name, opdef, op["runtime"] diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 1eeb0f4b26b6..d3ab020839f3 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -367,6 +367,10 @@ def get_tile_descriptions(math_inst): "cutlass.dense_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True), "cutlass.dense_bias_gelu_fp16": (EpilogueFunctor.LinearCombinationGelu, False), "cutlass.dense_bias_gelu_fp32": (EpilogueFunctor.LinearCombinationGelu, False), + "cutlass.matmul": (EpilogueFunctor.LinearCombination, False), + "cutlass.matmul_bias": (EpilogueFunctor.LinearCombinationBias, True), + "cutlass.matmul_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True), + "cutlass.matmul_bias_gelu": (EpilogueFunctor.LinearCombinationGelu, False), "cutlass.batch_matmul": (EpilogueFunctor.LinearCombination, False), "cutlass.conv2d_bias_hardswish": (EpilogueFunctor.LinearCombinationHardSwish, False), "cutlass.conv2d_bias_silu": (EpilogueFunctor.LinearCombinationSilu, False), diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py index 0594d86f2a82..a3578c8a409d 100644 --- a/python/tvm/relax/vm.py +++ b/python/tvm/relax/vm.py @@ -616,9 +616,15 @@ def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")): new_mod = seq(mod) # Extract external runtime modules if exist. - ext_libs = [] - if mod.attrs and "external_mods" in mod.attrs: - ext_libs = mod.attrs["external_mods"] + attrs = dict(mod.attrs) if mod.attrs else {} + + ext_libs = attrs.get("external_mods", []) + constants = attrs.get("const_name_to_constant", {}) + + if params is not None: + params.update(dict(constants)) + else: + params = constants # builder collects the executable builder = relax.ExecBuilder() diff --git a/src/ir/function.cc b/src/ir/function.cc index 6a7ccc7cf27b..59f94201b241 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -46,6 +46,24 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") } }); +TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttrs") + .set_body_typed([](BaseFunc func, Map attr_map) -> BaseFunc { + if (func->IsInstance()) { + return WithAttrs(Downcast(std::move(func)), attr_map); + } + if (const auto* f = runtime::Registry::Get("relay.ir.FuncWithAttrs")) { + if (Optional ret = (*f)(func, attr_map)) { + return ret.value(); + } + } + if (const auto* f = runtime::Registry::Get("relax.FuncWithAttrs")) { + if (Optional ret = (*f)(func, attr_map)) { + return ret.value(); + } + } + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + }); + TVM_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr") .set_body_typed([](BaseFunc func, String key) -> BaseFunc { if (func->IsInstance()) { diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc new file mode 100644 index 000000000000..b99cae77ec56 --- /dev/null +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/contrib/cutlass/codegen.cc + * \brief Implementation of the CUTLASS code generator for Relax. + */ +#include "../../../../relay/backend/contrib/cutlass/codegen.h" + +#include +#include +#include +#include + +#include +#include +#include + +#include "../../../../relay/backend/contrib/codegen_c/codegen_c.h" +#include "../utils.h" + +namespace tvm { +namespace relax { +namespace contrib { + +using namespace relay::contrib::cutlass; + +using Output = relay::contrib::Output; +using GenerateBodyOutput = relay::contrib::GenerateBodyOutput; +using relay::contrib::cutlass::GenerateBody; +using OutputType = std::vector; + +class CodegenCutlass : public relax::MemoizedExprTranslator, + public relay::contrib::CodegenCBase { + public: + CodegenCutlass(const std::string& id, const Map& bindings) + : ext_func_id_(id), bindings_(bindings) {} + + std::string JIT(const OutputType& out) final { + std::vector arg_types, arg_names; + + for (const auto& arg : ext_func_args_) { + auto sinfo = GetStructInfo(arg); + if (const auto* tensor_sinfo = sinfo.as()) { + arg_types.emplace_back(backend::DType2String(tensor_sinfo->dtype)); + } else { + LOG(FATAL) << "Unimplemented"; + } + arg_names.push_back(arg->name_hint()); + } + + code_stream_ << EmitSignature(out, ext_func_id_, arg_names) << "{\n"; + + this->EnterScope(); + + for (auto decl : buf_decl_) { + this->PrintIndents(); + code_stream_ << decl << "\n"; + } + code_stream_ << "\n"; + for (auto stmt : ext_func_body_) { + this->PrintIndents(); + code_stream_ << stmt << "\n"; + } + + this->ExitScope(); + code_stream_ << "}\n"; + + this->GenerateBackendCFunc(ext_func_id_, arg_types, /*const_arr_name=*/"", out, true); + return code_stream_.str(); + } + + Array GetHeaders() { return headers_; } + + protected: + OutputType VisitExpr_(const VarNode* node) final { + ext_func_args_.push_back(GetRef(node)); + Output output; + output.name = node->name_hint(); + return {output}; + } + + OutputType VisitExpr_(const CallNode* call) final { + const auto* fn_var = call->op.as(); + ICHECK(fn_var); + const auto func = Downcast(bindings_[GetRef(fn_var)]); + const auto pattern_name_opt = func->GetAttr(attr::kComposite); + ICHECK(pattern_name_opt) << "Only composite function is supported for CUTLASS."; + auto ret = GenerateBody(call, pattern_name_opt.value(), func->attrs->dict); + ext_func_body_.push_back(ret.decl); + headers_ = ret.headers; + return ret.outputs; + } + + OutputType VisitExpr_(const FunctionNode* fn) { + ICHECK(fn->GetAttr(attr::kComposite).defined()) + << "JSON runtime only supports composite functions"; + // FunctionNode should be handled by the caller. + return {}; + } + + OutputType VisitBinding(const Binding& binding) { + OutputType outputs; + if (const auto* node = binding.as()) { + auto from_b = VisitBinding_(node); + outputs.insert(outputs.end(), from_b.begin(), from_b.end()); + } else { + LOG(FATAL) << "Unimplemented type: " << binding->GetTypeKey(); + } + return outputs; + } + + OutputType VisitBindingBlock(const BindingBlock& block) { + OutputType outputs; + if (const auto* node = block.as()) { + auto from_bb = VisitBindingBlock_(node); + outputs.insert(outputs.end(), from_bb.begin(), from_bb.end()); + } else if (const auto* node = block.as()) { + auto from_bb = VisitBindingBlock_(node); + outputs.insert(outputs.end(), from_bb.begin(), from_bb.end()); + } else { + LOG(FATAL) << "Unimplemented type: " << block->GetTypeKey(); + } + return outputs; + } + + OutputType VisitBindingBlock_(const BindingBlockNode* block) { + OutputType outputs; + for (Binding binding : block->bindings) { + auto from_b = VisitBinding(binding); + outputs.insert(outputs.end(), from_b.begin(), from_b.end()); + } + return outputs; + } + + OutputType VisitBindingBlock_(const DataflowBlockNode* block) { + OutputType outputs; + for (Binding binding : block->bindings) { + auto from_b = VisitBinding(binding); + outputs.insert(outputs.end(), from_b.begin(), from_b.end()); + } + return outputs; + } + + OutputType VisitExpr_(const SeqExprNode* op) { + OutputType outputs; + + for (BindingBlock block : op->blocks) { + VisitBindingBlock(block); + } + + auto from_body = VisitExpr(op->body); + outputs.insert(outputs.end(), from_body.begin(), from_body.end()); + + return outputs; + } + + private: + Array GetArgumentNames(const CallNode* call) { + Array arg_names; + for (size_t i = 0; i < call->args.size(); ++i) { + auto res = VisitExpr(call->args[i]); + for (const auto& out : res) { + arg_names.push_back(out.name); + } + } + return arg_names; + } + + GenerateBodyOutput GenerateBody(const CallNode* call, const std::string& func_name, + const Map& attrs) { + auto func_args = GetArgumentNames(call); + auto struct_info = GetStructInfo(GetRef(call)); + + std::vector out_types; + if (const auto* tensor_sinfo = struct_info.as()) { + out_types.emplace_back(backend::DType2String(tensor_sinfo->dtype)); + } else { + LOG(FATAL) << "Unimplemented sinfo type: " << struct_info; + } + + return contrib::GenerateBody(func_name, ext_func_id_, out_types, func_args, attrs, &buf_idx_); + } + + /*! \brief The id of the external cutlass ext_func. */ + std::string ext_func_id_; + /*! + * \brief The index to track the output buffer. Each kernel will redirect the + * output to a buffer that may be consumed by other kernels. + */ + int buf_idx_{0}; + /*! \brief The arguments used by a wrapped function that calls CUTLASS kernels. */ + Array ext_func_args_; + /*! \brief The statements of the function that will be compiled using CUTLASS kernels. */ + std::vector ext_func_body_; + /*! \brief The declaration of intermediate buffers. */ + std::vector buf_decl_; + /*! \brief The binding to look up composite functions. */ + Map bindings_; + /*! \brief Required header-file names. */ + Array headers_; +}; + +class CutlassModuleCodegen { + public: + runtime::Module CreateCSourceModule(Function f, const Map& options) { + std::string headers = ""; + auto [code, op_headers] = GenCutlassFunc(f, options); + for (const auto& header : op_headers) { + headers += "#include <" + header + ">\n"; + } + return Finalize(headers + "\n" + code, func_names_); + } + + private: + std::pair> GenCutlassFunc(const Function& function, + const Map& options) { + ICHECK(function.defined()) << "Input error: expect a Relay function."; + + auto sid = GetExtSymbol(function); + func_names_.push_back(sid); + + CodegenCutlass builder(sid, AnalyzeVar2Value(function)); + auto out = builder.VisitExpr(function->body); + return {builder.JIT(out), builder.GetHeaders()}; + } + + /*! \brief The accumulated function names. */ + Array func_names_; +}; + +Array CUTLASSCompiler(Array functions, Map options, + Map /*unused*/) { + const auto* tune_func = runtime::Registry::Get("contrib.cutlass.tune_relax_function"); + ICHECK(tune_func != nullptr) + << "The packed function contrib.cutlass.tune_relax_function not found, " + "please import tvm.contrib.cutlass.build"; + + Array annotated_functions = (*tune_func)(functions, options); + + Array compiled_functions; + for (const auto& func : annotated_functions) { + auto func_name = GetExtSymbol(func); + auto source_mod = CutlassModuleCodegen().CreateCSourceModule(func, options); + const auto* pf = runtime::Registry::Get("contrib.cutlass.compile"); + ICHECK(pf != nullptr) << "The packed function contrib.cutlass.compile not found, please import " + "tvm.contrib.cutlass.build"; + compiled_functions.push_back((*pf)(source_mod, options)); + } + + return compiled_functions; +} + +TVM_REGISTER_GLOBAL("relax.ext.cutlass").set_body_typed(CUTLASSCompiler); + +} // namespace contrib +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index a0aaea886ddc..5392be7cb69b 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -552,5 +552,29 @@ TVM_REGISTER_GLOBAL("relax.GetShapeOf").set_body_typed([](const Expr& expr) { return GetShapeOf(expr); }); +TVM_REGISTER_GLOBAL("relax.FuncWithAttr") + .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> Optional { + if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } + return NullOpt; + }); + +TVM_REGISTER_GLOBAL("relax.FuncWithAttrs") + .set_body_typed([](BaseFunc func, Map attr_map) -> Optional { + if (func->IsInstance()) { + return WithAttrs(Downcast(std::move(func)), attr_map); + } + return NullOpt; + }); + +TVM_REGISTER_GLOBAL("relax.FuncWithoutAttr") + .set_body_typed([](BaseFunc func, String key) -> Optional { + if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } + return NullOpt; + }); + } // namespace relax } // namespace tvm diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index db8e0329d943..cdbfbed8db89 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -133,7 +133,7 @@ class CodegenCBase { * \brief Gerenate C code for the external function. * * \param func_name The name of the external function. - * \param args arguments to the external function. + * \param arg_types Types of arguments represented as string * * \code * @@ -160,14 +160,14 @@ class CodegenCBase { * * \endcode */ - void GenerateBackendCFunc(const std::string& func_name, const Array& args, + void GenerateBackendCFunc(const std::string& func_name, const std::vector& arg_types, const std::string& const_arr_name, const std::vector& outs, bool pass_dl_tensor = false) { // Print signature code_stream_ << "\n"; code_stream_ << "int " << func_name << "_wrapper_("; - for (size_t i = 0; i < args.size(); i++) { + for (size_t i = 0; i < arg_types.size(); i++) { code_stream_ << "DLTensor* arg" << i << ",\n"; code_stream_ << "\t"; } @@ -182,12 +182,11 @@ class CodegenCBase { // Generate the internal call. PrintIndents(); code_stream_ << func_name << "_("; - for (size_t i = 0; i < args.size(); i++) { + for (size_t i = 0; i < arg_types.size(); i++) { if (pass_dl_tensor) { code_stream_ << "arg" << i << ",\n"; } else { - const auto& dtype_str = GetDtypeString(args[i]); - code_stream_ << "(" << dtype_str << "*)(arg" << i << "->data),\n"; + code_stream_ << "(" << arg_types[i] << "*)(arg" << i << "->data),\n"; } PrintIndents(); } @@ -212,21 +211,21 @@ class CodegenCBase { // Create the external function PrintRuntimeFunctionHeader(func_name); EnterScope(); - for (size_t i = 0; i < args.size(); i++) { + for (size_t i = 0; i < arg_types.size(); i++) { PrintArgToData(i); } for (size_t i = 0; i < outs.size(); i++) { - PrintRetToData(args.size() + i); + PrintRetToData(arg_types.size() + i); } PrintIndents(); code_stream_ << func_name << "_wrapper_("; - for (size_t i = 0; i < args.size(); i++) { + for (size_t i = 0; i < arg_types.size(); i++) { code_stream_ << "arg" << i << ","; } for (size_t i = 0; i < outs.size() - 1; i++) { - code_stream_ << "ret" << args.size() + i << ","; + code_stream_ << "ret" << arg_types.size() + i << ","; } - code_stream_ << "ret" << args.size() + outs.size() - 1 << ");\n"; + code_stream_ << "ret" << arg_types.size() + outs.size() - 1 << ");\n"; PrintIndents(); code_stream_ << "return 0;\n"; ExitScope(); @@ -256,6 +255,16 @@ class CodegenCBase { } } + void GenerateBackendCFunc(const std::string& func_name, const Array& args, + const std::string& const_arr_name, const std::vector& outs, + bool pass_dl_tensor = false) { + std::vector arg_types; + for (size_t i = 0; i < args.size(); i++) { + arg_types.push_back(GetDtypeString(args[i])); + } + return GenerateBackendCFunc(func_name, arg_types, const_arr_name, outs, pass_dl_tensor); + } + /*! * \brief Emit the code for external runtime. * @@ -370,6 +379,10 @@ class CodegenCBase { dtype = "int"; } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 64)) { dtype = "int64_t"; + } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 8)) { + dtype = "int8_t"; + } else if (runtime::TypeMatch(ttype->dtype, kDLUInt, 8)) { + dtype = "uint8_t"; } else { LOG(FATAL) << "Unsupported dtype " << ttype->dtype; } diff --git a/src/relay/backend/contrib/cutlass/codegen.h b/src/relay/backend/contrib/cutlass/codegen.h index e70e97a2fafa..03b8e6afbddc 100644 --- a/src/relay/backend/contrib/cutlass/codegen.h +++ b/src/relay/backend/contrib/cutlass/codegen.h @@ -27,6 +27,11 @@ #include +#include +#include + +#include "../codegen_c/codegen_c.h" + namespace tvm { namespace relay { namespace contrib { @@ -40,6 +45,22 @@ namespace cutlass { */ transform::Pass CompileForCutlass(); +// The rest is sparsely documented since they are exposed only for code sharing between Relay +// and Relax backend implementations. + +/*! \brief Emit the function signature for a kernel */ +std::string EmitSignature(const std::vector& out, + const std::string& func_id, const std::vector& arg_names); + +/*! \brief Generate the body of the kernel */ +GenerateBodyOutput GenerateBody(const std::string& func_name, const std::string& ext_func_id, + const std::vector& output_types, + const Array& func_args, const Map& attrs, + int* buf_idx); + +/*! \brief Create a C-source module from the given kernel string */ +runtime::Module Finalize(const std::string& code, const Array& func_names); + } // namespace cutlass } // namespace contrib } // namespace relay diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py new file mode 100644 index 000000000000..1eafb1bc1caf --- /dev/null +++ b/tests/python/relax/test_codegen_cutlass.py @@ -0,0 +1,429 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import math +from typing import List, Tuple + +import numpy as np +import pytest + +import tvm +import tvm.testing +from tvm import relax, relay +from tvm.contrib.cutlass.build import finalize_modules_relax +from tvm.relax.dpl import make_fused_bias_activation_pattern, make_matmul_pattern +from tvm.script import relax as R + + +@pytest.fixture(autouse=True) +def reset_seed(): + np.random.seed(0) + + +def get_relay_conv2d_bias_relu( + d_shape, w_shape, data_dtype="float16", weight_dtype="float16", out_dtype="float16" +): + data = relay.var("data", shape=d_shape, dtype=data_dtype) + weight = relay.var("weight", shape=w_shape, dtype=weight_dtype) + bias = relay.var("bias", shape=(1, 1, 1, w_shape[0]), dtype=out_dtype) + return relay.nn.relu( + relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="OHWI", + out_dtype=out_dtype, + ) + + bias + ) + + +def get_relay_matmul( + x_shape, + y_shape, + x_dtype="float16", + y_dtype="float16", + out_dtype="float16", +): + x = relay.var("x", shape=x_shape, dtype=x_dtype) + y = relay.var("y", shape=y_shape, dtype=y_dtype) + return relay.nn.dense(x, y, out_dtype=out_dtype) + + +def get_relay_matmul_bias( + x_shape, + y_shape, + x_dtype="float16", + y_dtype="float16", + bias_dtype="float16", + out_dtype="float16", +): + bias = relay.var("bias", shape=(y_shape[0],), dtype=bias_dtype) + return relay.nn.bias_add( + get_relay_matmul( + x_shape, + y_shape, + x_dtype, + y_dtype, + out_dtype, + ), + bias, + ) + + +def get_relay_matmul_bias_relu( + x_shape, + y_shape, + x_dtype="float16", + y_dtype="float16", + bias_dtype="float16", + out_dtype="float16", +): + return relay.nn.relu( + get_relay_matmul_bias( + x_shape, + y_shape, + x_dtype, + y_dtype, + bias_dtype, + out_dtype, + ) + ) + + +def get_relay_matmul_bias_gelu( + x_shape, + y_shape, + x_dtype="float16", + y_dtype="float16", + bias_dtype="float16", + out_dtype="float16", +): + bias_add = get_relay_matmul_bias(x_shape, y_shape, x_dtype, y_dtype, bias_dtype, out_dtype) + mul = bias_add * relay.const((1.0 / math.sqrt(2.0)), dtype=out_dtype) + if out_dtype == "float16": + erf = relay.cast(relay.op.erf(relay.cast(mul, "float32")), "float16") + else: + erf = relay.op.erf(mul) + mul_half = erf * relay.const(0.5, dtype=out_dtype) + add = mul_half + relay.const(0.5, dtype=out_dtype) + return add * bias_add + + +def get_relay_conv2d_relu_x2( + d_shape, w_shape, data_dtype="float16", weight_dtype="float16", out_dtype="float16" +): + data = relay.var("data", shape=d_shape, dtype=data_dtype) + weight1 = relay.var("weight1", shape=w_shape, dtype=weight_dtype) + weight2 = relay.var("weight2", shape=w_shape, dtype=weight_dtype) + + conv1 = relay.nn.conv2d( + data=data, + weight=weight1, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="OHWI", + out_dtype=out_dtype, + ) + return relay.nn.conv2d( + data=conv1, + weight=weight2, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="OHWI", + out_dtype=out_dtype, + ) + + +def get_relay_ref(relay_expr, *args): + relay_mod = tvm.IRModule.from_expr(relay_expr) + + with tvm.transform.PassContext(opt_level=3): + seq = tvm.transform.Sequential( + [relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "HWIO"]})] + ) + relay_mod = seq(relay_mod) + + return ( + relay.create_executor("graph", mod=relay_mod, device=tvm.gpu(0), target="cuda") + .evaluate()(*args) + .numpy() + ) + + +@tvm.script.ir_module +class Conv2dBiasReLU: + @R.function + def main( + data: R.Tensor((16, 32, 32, 16), "float16"), + weight: R.Tensor((32, 3, 3, 16), "float16"), + bias: R.Tensor((1, 1, 1, 32), "float16"), + ): + with R.dataflow(): + conv1 = relax.op.nn.relu( + relax.op.add( + relax.op.nn.conv2d( + data, weight, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI" + ), + bias, + ) + ) + R.output(conv1) + + return conv1 + + +@tvm.script.ir_module +class Conv2dx2: + @R.function + def main( + data: R.Tensor((16, 32, 32, 16), "float16"), + weight1: R.Tensor((16, 3, 3, 16), "float16"), + weight2: R.Tensor((16, 3, 3, 16), "float16"), + ): + with R.dataflow(): + conv1 = relax.op.nn.conv2d( + data, weight1, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI" + ) + conv2 = relax.op.nn.conv2d( + conv1, weight2, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI" + ) + R.output(conv2) + + return conv2 + + +has_cutlass = tvm.get_global_func("relax.ext.cutlass", True) + +cutlass_enabled = pytest.mark.skipif( + not has_cutlass, + reason="CUTLASS note enabled.", +) + +pytestmark = [cutlass_enabled] + + +def get_result_with_relax_cutlass_offload(mod, patterns: List[Tuple], *args): + seq = tvm.transform.Sequential( + [ + relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True), + relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}}), + ] + ) + + mod = seq(mod) + + target = tvm.target.Target("cuda") + ex = relax.vm.build(mod, target) + ex = finalize_modules_relax(ex) + + dev = tvm.gpu(0) + vm = relax.VirtualMachine(ex, dev) + + return vm["main"](*(tvm.nd.array(arg, dev) for arg in args)).numpy() + + +def test_conv2d_offload(): + data = np.random.randn(16, 32, 32, 16).astype("float16") + weight = np.random.randn(32, 3, 3, 16).astype("float16") + bias = np.random.randn(1, 1, 1, 32).astype("float16") + + patterns = [ + ( + "cutlass.conv2d_bias_relu", + make_fused_bias_activation_pattern( + "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu" + ), + ) + ] + out = get_result_with_relax_cutlass_offload(Conv2dBiasReLU, patterns, data, weight, bias) + + ref_relay_expr = get_relay_conv2d_bias_relu(data.shape, weight.shape) + ref = get_relay_ref(ref_relay_expr, data, weight, bias) + + tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) + + +def get_relax_matmul_module(x, y, with_bias=False, activation=None): + m, k = x.shape + n = y.shape[-1] + dtype = str(x.dtype) + + from tvm.script.ir_builder import IRBuilder + from tvm.script.ir_builder import relax as relax_builder + + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + x = R.arg("x", R.Tensor((m, k), dtype)) + y = R.arg("y", R.Tensor((k, n), dtype)) + if with_bias: + bias = R.arg("bias", R.Tensor((n,), dtype)) + + with R.dataflow() as frame: + result = R.emit(R.matmul(x, y)) + if with_bias: + result = R.emit(result + bias) + if activation is not None: + result = R.emit(activation(result)) + R.output(result) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) + + +@pytest.fixture(params=["float16"]) +def target_dtype(request): + return request.param + + +@pytest.fixture( + params=[ + # M, K, N + (32, 6, 16), + (29, 17, 19), + (64, 512, 1024), + ] +) +def matmul_size(request): + return request.param + + +@pytest.fixture +def matmul_x(matmul_size, target_dtype): + m, k, _ = matmul_size + return np.random.randn(m, k).astype(target_dtype) + + +@pytest.fixture +def matmul_y(matmul_size, target_dtype): + _, k, n = matmul_size + return np.random.randn(k, n).astype(target_dtype) + + +@pytest.fixture +def matmul_bias(matmul_size, target_dtype): + _, _, n = matmul_size + return np.random.randn(n).astype(target_dtype) + + +def test_matmul_offload(matmul_x, matmul_y): + x, y = matmul_x, matmul_y + + patterns = [ + ( + "cutlass.matmul", + make_matmul_pattern( + with_bias=False, + ), + ), + ] + + mod = get_relax_matmul_module(x, y) + out = get_result_with_relax_cutlass_offload(mod, patterns, x, y) + ref_relay_expr = get_relay_matmul(x.shape, y.shape[::-1]) + ref = get_relay_ref(ref_relay_expr, x, y.transpose()) + + tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-4) + + +def test_matmul_bias_offload(matmul_x, matmul_y, matmul_bias): + x, y, bias = matmul_x, matmul_y, matmul_bias + + patterns = [ + ( + "cutlass.matmul_bias", + make_matmul_pattern( + with_bias=True, + ), + ), + ] + mod = get_relax_matmul_module(x, y, with_bias=True) + out = get_result_with_relax_cutlass_offload(mod, patterns, x, y, bias) + + ref_relay_expr = get_relay_matmul_bias(x.shape, y.shape[::-1]) + ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias) + + tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-4) + + +def test_matmul_bias_relu_offload(matmul_x, matmul_y, matmul_bias): + x, y, bias = matmul_x, matmul_y, matmul_bias + + patterns = [ + ( + "cutlass.matmul_bias_relu", + make_matmul_pattern( + with_bias=True, + activation="relax.nn.relu", + ), + ), + ] + mod = get_relax_matmul_module(x, y, with_bias=True, activation=R.nn.relu) + out = get_result_with_relax_cutlass_offload(mod, patterns, x, y, bias) + + ref_relay_expr = get_relay_matmul_bias_relu(x.shape, y.shape[::-1]) + ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias) + + tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-4) + + +def test_matmul_bias_gelu_offload(matmul_x, matmul_y, matmul_bias): + x, y, bias = matmul_x, matmul_y, matmul_bias + + patterns = [ + ( + "cutlass.matmul_bias_gelu", + make_matmul_pattern( + with_bias=True, + activation="relax.nn.gelu", + ), + ), + ] + mod = get_relax_matmul_module(x, y, with_bias=True, activation=R.nn.gelu) + out = get_result_with_relax_cutlass_offload(mod, patterns, x, y, bias) + + ref_relay_expr = get_relay_matmul_bias_gelu(x.shape, y.shape[::-1]) + ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias) + + tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-3) + + +def test_kernel_sharing(): + data_np = np.random.randn(16, 32, 32, 16).astype("float16") + weight1_np = np.random.randn(16, 3, 3, 16).astype("float16") + weight2_np = np.random.randn(16, 3, 3, 16).astype("float16") + + pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None) + + out = get_result_with_relax_cutlass_offload( + Conv2dx2, [("cutlass.conv2d", pat)], data_np, weight1_np, weight2_np + ) + + relay_expr = get_relay_conv2d_relu_x2(data_np.shape, weight1_np.shape) + ref = get_relay_ref(relay_expr, data_np, weight1_np, weight2_np) + + tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__": + tvm.testing.main() From cdd61cdf0f94d0fb69e41b0e0250a0c0ce0a01da Mon Sep 17 00:00:00 2001 From: masahi Date: Wed, 22 Feb 2023 21:55:32 +0900 Subject: [PATCH 57/81] [Unity][BYOC] Add DNNL backend (#14082) This PR adds dnnl backend to the unity flow. --- cmake/modules/contrib/DNNL.cmake | 8 +- src/relax/backend/contrib/dnnl/codegen.cc | 105 +++++++++++++++++++ tests/python/relax/test_codegen_dnnl.py | 120 ++++++++++++++++++++++ 3 files changed, 229 insertions(+), 4 deletions(-) create mode 100644 src/relax/backend/contrib/dnnl/codegen.cc create mode 100644 tests/python/relax/test_codegen_dnnl.py diff --git a/cmake/modules/contrib/DNNL.cmake b/cmake/modules/contrib/DNNL.cmake index 7547af81eb1a..857f7bdfd597 100644 --- a/cmake/modules/contrib/DNNL.cmake +++ b/cmake/modules/contrib/DNNL.cmake @@ -21,8 +21,8 @@ if(IS_DIRECTORY ${USE_DNNL}) message(WARNING "Cannot find DNNL library at ${USE_DNNL}.") else() add_definitions(-DUSE_JSON_RUNTIME=1) - tvm_file_glob(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc) - list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC}) + tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc src/relax/backend/contrib/dnnl/*.cc) + list(APPEND COMPILER_SRCS ${DNNL_CONTRIB_SRC}) list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL}) tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -34,8 +34,8 @@ if(IS_DIRECTORY ${USE_DNNL}) endif() elseif((USE_DNNL STREQUAL "ON") OR (USE_DNNL STREQUAL "JSON")) add_definitions(-DUSE_JSON_RUNTIME=1) - tvm_file_glob(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc) - list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC}) + tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc src/relax/backend/contrib/dnnl/*.cc) + list(APPEND COMPILER_SRCS ${DNNL_CONTRIB_SRC}) find_library(EXTERN_LIBRARY_DNNL dnnl) list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL}) diff --git a/src/relax/backend/contrib/dnnl/codegen.cc b/src/relax/backend/contrib/dnnl/codegen.cc new file mode 100644 index 000000000000..3cbf4cfa2ace --- /dev/null +++ b/src/relax/backend/contrib/dnnl/codegen.cc @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/contrib/dnnl/codegen.cc + * \brief Implementation of the DNNL JSON serializer. + */ +#include + +#include + +#include "../codegen_json/codegen_json.h" +#include "../utils.h" + +namespace tvm { +namespace relax { +namespace contrib { + +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; +using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; +using JSONSerializer = backend::contrib::JSONSerializer; +using backend::contrib::NodeEntries; + +class DNNLJSONSerializer : public JSONSerializer { + public: + DNNLJSONSerializer(Map constant_names, Map bindings) + : JSONSerializer(constant_names), bindings_(bindings) {} + + using JSONSerializer::VisitExpr_; + + NodeEntries VisitExpr_(const CallNode* call_node) final { + const auto* fn_var = call_node->op.as(); + ICHECK(fn_var); + const auto fn = Downcast(bindings_[GetRef(fn_var)]); + ICHECK(fn.defined()) << "Expects the callee to be a function."; + + auto composite_opt = fn->GetAttr(attr::kComposite); + ICHECK(composite_opt.defined()) << "Only composite functions are supported."; + + std::string composite_name = composite_opt.value(); + + NodeEntries inputs; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + auto node = std::make_shared(composite_name, /* name_ */ + "kernel", /* op_type_ */ + inputs, 1 /* num_outputs_ */); + + const CallNode* root_call = nullptr; + if (composite_name.find("conv2d") != std::string::npos) { + root_call = backend::GetOpInFunction(fn, "relax.nn.conv2d"); + } else { + LOG(FATAL) << "Unimplemented pattern: " << composite_name; + } + + SetCallNodeAttribute(node, root_call); + return AddNode(node, GetRef(call_node)); + } + + private: + /*! \brief The bindings to look up composite functions. */ + Map bindings_; +}; + +Array DNNLCompiler(Array functions, Map /*unused*/, + Map constant_names) { + Array compiled_functions; + + for (const auto& func : functions) { + DNNLJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); + serializer.serialize(func); + auto graph_json = serializer.GetJSON(); + auto constant_names = serializer.GetConstantNames(); + const auto* pf = runtime::Registry::Get("runtime.DNNLJSONRuntimeCreate"); + ICHECK(pf != nullptr) << "Cannot find DNNL runtime module create function."; + auto func_name = GetExtSymbol(func); + compiled_functions.push_back((*pf)(func_name, graph_json, constant_names)); + } + + return compiled_functions; +} + +TVM_REGISTER_GLOBAL("relax.ext.dnnl").set_body_typed(DNNLCompiler); + +} // namespace contrib +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_codegen_dnnl.py b/tests/python/relax/test_codegen_dnnl.py new file mode 100644 index 000000000000..69139b28ef34 --- /dev/null +++ b/tests/python/relax/test_codegen_dnnl.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import numpy as np +import tvm +import tvm.testing + +from tvm import relax, relay +from tvm.script import relax as R +from tvm.relax.dpl import make_fused_bias_activation_pattern + + +def get_relay_conv2d_relu_x2(d_shape, w_shape): + data = relay.var("data", shape=d_shape) + weight1 = relay.var("weight1", shape=w_shape) + weight2 = relay.var("weight2", shape=w_shape) + conv1 = relay.nn.relu( + relay.nn.conv2d( + data=data, + weight=weight1, + kernel_size=w_shape[2:], + padding=(1, 1), + ) + ) + return relay.nn.relu( + relay.nn.conv2d( + data=conv1, + weight=weight2, + kernel_size=w_shape[2:], + padding=(0, 0), + ) + ) + + +@tvm.script.ir_module +class Conv2dReLUx2: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight1: R.Tensor((64, 64, 3, 3), "float32"), + weight2: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = relax.op.nn.relu(relax.op.nn.conv2d(data, weight1, padding=(1, 1))) + conv2 = relax.op.nn.relu(relax.op.nn.conv2d(conv1, weight2, padding=(0, 0))) + R.output(conv2) + + return conv2 + + +has_dnnl = tvm.get_global_func("relax.ext.dnnl", True) + +dnnl_enabled = pytest.mark.skipif( + not has_dnnl, + reason="DNNL note enabled.", +) + +pytestmark = [dnnl_enabled] + + +def test_dnnl_offload(): + pat = make_fused_bias_activation_pattern( + "relax.nn.conv2d", with_bias=False, activation="relax.nn.relu" + ) + + seq = tvm.transform.Sequential( + [ + relax.transform.FuseOpsByPattern([("dnnl.conv2d_relu", pat)]), + relax.transform.MergeCompositeFunctions(), + relax.transform.RunCodegen(), + ] + ) + + mod = seq(Conv2dReLUx2) + + target = tvm.target.Target("llvm") + ex = relax.vm.build(mod, target) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + f = vm["main"] + + data_np = np.random.randn(1, 64, 56, 56).astype("float32") + weight1_np = np.random.randn(64, 64, 3, 3).astype("float32") + weight2_np = np.random.randn(64, 64, 3, 3).astype("float32") + out = f(tvm.nd.array(data_np), tvm.nd.array(weight1_np), tvm.nd.array(weight2_np)).numpy() + + relay_mod = tvm.IRModule.from_expr(get_relay_conv2d_relu_x2(data_np.shape, weight1_np.shape)) + + ref = ( + relay.create_executor("graph", mod=relay_mod, device=tvm.cpu(0), target="llvm") + .evaluate()(*[data_np, weight1_np, weight2_np]) + .numpy() + ) + + tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3) + + profiler_vm = relax.VirtualMachine(ex, tvm.cpu(), profile=True) + report = profiler_vm.profile( + "main", tvm.nd.array(data_np), tvm.nd.array(weight1_np), tvm.nd.array(weight2_np) + ) + + print(report) + + +if __name__ == "__main__": + test_dnnl_offload() From e7354e6463374195e0de5048611931b93a1daa5d Mon Sep 17 00:00:00 2001 From: Chaosfan <1713833595@qq.com> Date: Wed, 22 Feb 2023 20:56:24 +0800 Subject: [PATCH 58/81] [Unity][Op] `log_softmax` and `cross_entropy_with_logits` (#14083) This PR introduces two high-level operators log_softmax and cross_entropy_with_logits, which are important when we are calculating CrossEntropyLoss (in torch). Co-authored-by: Yixin Dong --- python/tvm/relax/op/nn/nn.py | 55 ++++ python/tvm/relax/transform/legalize_ops/nn.py | 20 ++ src/relax/op/nn/nn.cc | 74 +++++ src/relax/op/nn/nn.h | 6 + tests/python/relax/test_op_nn.py | 173 ++++++++++- .../relax/test_transform_legalize_ops_nn.py | 268 ++++++++++++++++++ .../relax/test_tvmscript_parser_op_nn.py | 33 +++ 7 files changed, 621 insertions(+), 8 deletions(-) diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index cdf0e9646492..0ff143fd045b 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -348,6 +348,34 @@ def softmax(data: Expr, axis: int = -1) -> Expr: return _ffi_api.softmax(data, axis) # type: ignore +def log_softmax(data: Expr, axis: int = -1) -> Expr: + r"""Computes log softmax. + + .. math:: + + \text{log\_softmax}(x_i) = \log\left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}\right) + + .. note:: + This operator can be optimized away for inference. + + Parameters + ---------- + data: relax.Expr + The input data to the operator. + + axis: int + The axis to sum over when computing log softmax. + If not specified, it is by default the last axis of the input tensor. + Supports negative indexing. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.log_softmax(data, axis) # type: ignore + + def batch_norm( data: Expr, gamma: Expr, @@ -522,3 +550,30 @@ def dropout(data: Expr, rate: float = 0.5) -> Expr: mask tensor (1.0 where element not dropped, 0.0 where dropped) """ return _ffi_api.dropout(data, rate) # type: ignore + + +def cross_entropy_with_logits(predictions: Expr, labels: Expr) -> Expr: + r"""CrossEntropy with logits between the predictions and labels. + + The shape of predictions and labels must be the same. And when ndim >= 2, + the first dimension is regarded as the batch_size N. In this case the + computed result will divide by N to perform a mean reduction. + + .. math:: + + \text{cross\_entropy\_with\_logits}(x_i, y_i) = \frac{\sum_i -x_i \cdot y_i}{N} + + Parameters + ---------- + predictions : relax.Expr + The predictions. + + labels : relax.Expr + The labels (the ground truth values). + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.cross_entropy_with_logits(predictions, labels) # type: ignore diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 49f198306d14..31c6eb04cd2d 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -144,6 +144,26 @@ def _nn_softmax(bb: BlockBuilder, call: Call) -> Expr: return bb.call_te(topi.nn.softmax, call.args[0], call.attrs.axis) +@register_legalize("relax.nn.log_softmax") +def _nn_log_softmax(bb: BlockBuilder, call: Call): + return bb.call_te(topi.nn.log_softmax, call.args[0], call.attrs.axis) + + +@register_legalize("relax.nn.cross_entropy_with_logits") +def _nn_cross_entropy_with_logits(bb: BlockBuilder, call: Call): + def te_cross_entropy_with_logits(x, y): + if len(x.shape) > 1: + return -topi.sum(x * y) / x.shape[0] + return -topi.sum(x * y) + + return bb.call_te( + te_cross_entropy_with_logits, + call.args[0], + call.args[1], + primfunc_name_hint="cross_entropy_with_logits", + ) + + @register_legalize("relax.nn.batch_norm") def _nn_batch_norm(bb: BlockBuilder, call: Call) -> Expr: return bb.call_te( diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 66ae10fe6ccd..e63b3306f25d 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -68,6 +68,22 @@ TVM_REGISTER_OP("relax.nn.softmax") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoSoftmax); +/* relax.nn.log_softmax */ +Expr log_softmax(Expr data, int axis) { + auto attrs = make_object(); + attrs->axis = axis; + static const Op& op = Op::Get("relax.nn.log_softmax"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.log_softmax").set_body_typed(log_softmax); + +TVM_REGISTER_OP("relax.nn.log_softmax") + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_attrs_type() + .set_attr("FInferStructInfo", InferStructInfoSoftmax); + bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, const Array& input_sinfo, Array axes) { Op op = Downcast(call->op); @@ -241,5 +257,63 @@ TVM_REGISTER_OP("relax.nn.dropout") .add_argument("data", "Tensor", "Input to which dropout will be applied.") .set_attr("FInferStructInfo", InferStructInfoDropout); +/* relax.nn.cross_entropy_with_logits */ +StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo pred_sinfo = input_sinfo[0]; + TensorStructInfo label_sinfo = input_sinfo[1]; + + // infer dtype + DataType dtype = InferBinaryArithOpOutDtype(call, ctx, pred_sinfo, label_sinfo); + + // infer ndim + if (!pred_sinfo->IsUnknownNdim() && !label_sinfo->IsUnknownNdim() && + pred_sinfo->ndim != label_sinfo->ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "CrossEntropy requires predictions and labels to have the same ndim. " + "However, the ndim of predictions is " + << pred_sinfo->ndim << " while the ndim of labels is " << label_sinfo->ndim); + } + + Optional> pred_shape_value; + if (pred_sinfo->shape.defined()) { + pred_shape_value = GetStructInfoAs(pred_sinfo->shape.value())->values; + } + + Optional> label_shape_value; + if (label_sinfo->shape.defined()) { + label_shape_value = GetStructInfoAs(label_sinfo->shape.value())->values; + } + + if (pred_shape_value.defined() && label_shape_value.defined()) { + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + for (size_t i = 0; i < pred_shape_value.value().size(); ++i) { + if (analyzer->CanProve(pred_shape_value.value()[i] != label_shape_value.value()[i])) { + ctx->ReportFatal(Diagnostic::Error(call) + << "CrossEntropy requires the predictions and labels to have " + "the same shape. However, the shape of predictions at dim " + << i << " is" << pred_shape_value.value()[i] + << " while the shape of labels at this dim is " + << label_shape_value.value()[i]); + } + } + } + return TensorStructInfo(ShapeExpr(Array()), dtype); +} + +Expr cross_entropy_with_logits(Expr predictions, Expr labels) { + static const Op& op = Op::Get("relax.nn.cross_entropy_with_logits"); + return Call(op, {std::move(predictions), std::move(labels)}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.cross_entropy_with_logits") + .set_body_typed(cross_entropy_with_logits); + +TVM_REGISTER_OP("relax.nn.cross_entropy_with_logits") + .set_num_inputs(2) + .add_argument("predictions", "Tensor", "The predictions.") + .add_argument("labels", "Tensor", "The labels.") + .set_attr("FInferStructInfo", InferStructInfoCrossEntropy); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index df2b978fc296..f13b930fc246 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -57,6 +57,9 @@ Expr silu(Expr data); /*! \brief Softmax function. */ Expr softmax(Expr data, int axis); +/*! \brief LogSoftmax function. */ +Expr log_softmax(Expr data, int axis); + /*! \brief Compute batch normalization. */ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // int axis, double epsilon, bool center, bool scale); @@ -75,6 +78,9 @@ Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double ep */ Expr dropout(Expr data, double rate); +/*! \brief CrossEntropy with logits. */ +Expr cross_entropy_with_logits(Expr predictions, Expr labels); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index d047448309ab..5294596cee34 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -29,6 +29,7 @@ def test_op_correctness(): assert relax.op.nn.gelu(x).op == Op.get("relax.nn.gelu") assert relax.op.nn.silu(x).op == Op.get("relax.nn.silu") assert relax.op.nn.softmax(x).op == Op.get("relax.nn.softmax") + assert relax.op.nn.log_softmax(x).op == Op.get("relax.nn.log_softmax") assert relax.op.nn.dropout(x).op == Op.get("relax.nn.dropout") x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) @@ -41,6 +42,12 @@ def test_op_correctness(): ) assert relax.op.nn.layer_norm(x, gamma, beta, axes=1).op == Op.get("relax.nn.layer_norm") + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + assert relax.op.nn.cross_entropy_with_logits(x, y).op == Op.get( + "relax.nn.cross_entropy_with_logits" + ) + def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): ret = bb.normalize(call) @@ -117,7 +124,7 @@ def test_linear_unit_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.silu(x1)) -def test_softmax_infer_struct_info(): +def test_softmax_log_softmax_infer_struct_info(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=3)) @@ -133,8 +140,20 @@ def test_softmax_infer_struct_info(): _check_inference(bb, relax.op.nn.softmax(x3, axis=-1), relax.TensorStructInfo((2, 3), dtype="")) _check_inference(bb, relax.op.nn.softmax(x4, axis=-2), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.nn.log_softmax(x1, axis=0), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.nn.log_softmax(x2, axis=1), relax.TensorStructInfo(dtype="float32") + ) + _check_inference( + bb, relax.op.nn.log_softmax(x3, axis=-1), relax.TensorStructInfo((2, 3), dtype="") + ) + _check_inference(bb, relax.op.nn.log_softmax(x4, axis=-2), relax.TensorStructInfo(dtype="")) + -def test_softmax_infer_struct_info_shape_symbolic(): +def test_softmax_log_softmax_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() m = tir.Var("m", "int64") n = tir.Var("n", "int64") @@ -144,8 +163,13 @@ def test_softmax_infer_struct_info_shape_symbolic(): _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((m, n), "float32")) _check_inference(bb, relax.op.nn.softmax(x1, axis=0), relax.TensorStructInfo((4, n), "float32")) + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo((m, n), "float32")) + _check_inference( + bb, relax.op.nn.log_softmax(x1, axis=0), relax.TensorStructInfo((4, n), "float32") + ) -def test_softmax_infer_struct_info_shape_var(): + +def test_softmax_log_softmax_infer_struct_info_shape_var(): bb = relax.BlockBuilder() s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) s1 = relax.Var("s", relax.ShapeStructInfo()) @@ -155,8 +179,11 @@ def test_softmax_infer_struct_info_shape_var(): _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo(s0, "float32")) _check_inference(bb, relax.op.nn.softmax(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.nn.log_softmax(x1), relax.TensorStructInfo(s1, "float32")) + -def test_softmax_infer_struct_info_more_input_dtype(): +def test_softmax_log_softmax_infer_struct_info_more_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float16")) x1 = relax.Var("x", R.Tensor((2, 3), "float64")) @@ -164,8 +191,11 @@ def test_softmax_infer_struct_info_more_input_dtype(): _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((2, 3), "float16")) _check_inference(bb, relax.op.nn.softmax(x1), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, relax.op.nn.log_softmax(x0), relax.TensorStructInfo((2, 3), "float16")) + _check_inference(bb, relax.op.nn.log_softmax(x1), relax.TensorStructInfo((2, 3), "float64")) + -def test_softmax_infer_struct_info_invalid_input_dtype(): +def test_softmax_log_softmax_infer_struct_info_invalid_input_dtype(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "int8")) x1 = relax.Var("x", R.Tensor((2, 3), "int64")) @@ -174,26 +204,40 @@ def test_softmax_infer_struct_info_invalid_input_dtype(): bb.normalize(relax.op.nn.softmax(x0)) with pytest.raises(TVMError): bb.normalize(relax.op.nn.softmax(x1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x1)) -def test_softmax_infer_struct_info_axis_out_of_range(): +def test_softmax_log_softmax_infer_struct_info_axis_out_of_range(): bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + with pytest.raises(TVMError): bb.normalize(relax.op.nn.softmax(x, axis=3)) with pytest.raises(TVMError): bb.normalize(relax.op.nn.softmax(x, axis=-4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x, axis=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x, axis=-4)) -def test_softmax_wrong_with_multiple_axes(): +def test_softmax_log_softmax_wrong_with_multiple_axes(): x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + with pytest.raises(TVMError): relax.op.nn.softmax(x, axis=[1, 2]) with pytest.raises(TVMError): relax.op.nn.softmax(x, axis=[-1, -2, -3]) + with pytest.raises(TVMError): + relax.op.nn.log_softmax(x, axis=[1, 2]) + with pytest.raises(TVMError): + relax.op.nn.log_softmax(x, axis=[-1, -2, -3]) -def test_softmax_infer_struct_info_wrong_input_type(): +def test_softmax_log_softmax_infer_struct_info_wrong_input_type(): bb = relax.BlockBuilder() x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) @@ -202,6 +246,10 @@ def test_softmax_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.softmax(x0)) with pytest.raises(TVMError): bb.normalize(relax.op.nn.softmax(x1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.log_softmax(x1)) def test_batch_norm_infer_struct_info(): @@ -925,5 +973,114 @@ def test_dropout_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.dropout(x1)) +def test_cross_entropy_infer_struct_info(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=2)) + y2 = relax.Var("y", R.Tensor((2, 3))) + y3 = relax.Var("y", R.Tensor(ndim=2)) + + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y0), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, + relax.op.nn.cross_entropy_with_logits(x, y1), + relax.TensorStructInfo((), dtype="float32"), + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y2), relax.TensorStructInfo((), dtype="") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y3), relax.TensorStructInfo((), dtype="") + ) + + +def test_cross_entropy_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m0 = tir.Var("m", "int64") + m1 = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m0, n), "float32")) + x1 = relax.Var("x", R.Tensor((m1, n), "float32")) + y = relax.Var("y", R.Tensor((m0, n), "float32")) + + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x0, y), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x1, y), relax.TensorStructInfo((), "float32") + ) + + +def test_cross_entropy_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + x = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + y0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + y1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y0), relax.TensorStructInfo((), "float32") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x, y1), relax.TensorStructInfo((), "float32") + ) + + +def test_cross_entropy_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float16")) + y0 = relax.Var("y", R.Tensor((2, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + y1 = relax.Var("y", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int32")) + y2 = relax.Var("y", R.Tensor((2, 3), "int32")) + + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x0, y0), relax.TensorStructInfo((), "float16") + ) + _check_inference( + bb, relax.op.nn.cross_entropy_with_logits(x1, y1), relax.TensorStructInfo((), "int8") + ) + + +def test_cross_entropy_infer_struct_info_wrong_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=4)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x1, y0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x0, y1)) + + +def test_cross_entropy_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 4), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x0, y0)) + + +def test_cross_entropy_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x0, y)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.cross_entropy_with_logits(x1, y)) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 729368b82a21..07d414980e30 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -851,6 +851,274 @@ def softmax(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): tvm.ir.assert_structural_equal(mod, Expected) +def test_log_softmax(): + # fmt: off + @tvm.script.ir_module + class LogSoftmax: + @R.function + def main(x: R.Tensor((2, 3, 16, 32), "float32")) -> R.Tensor(None, "float32", ndim=4): + gv: R.Tensor((2, 3, 16, 32), "float32") = R.nn.log_softmax(x, axis=-2) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 16, 32), dtype="float32")) -> R.Tensor((2, 3, 16, 32), dtype="float32"): + gv = R.call_tir(log_softmax, (x,), R.Tensor((2, 3, 16, 32), dtype="float32")) + return gv + + @T.prim_func + def log_softmax(rxplaceholder: T.Buffer[(T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32"], compute: T.Buffer[(T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32"],): + T.func_attr({"tir.noalias": True}) + T_softmax_maxelem = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32") + compute_1 = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)): + with T.block("T_softmax_maxelem"): + i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_1, i1_1, k, i2_1]) + T.writes(T_softmax_maxelem[i0_1, i1_1, i2_1]) + with T.init(): + T_softmax_maxelem[i0_1, i1_1, i2_1] = T.float32(-3.4028234663852886e38) + T_softmax_maxelem[i0_1, i1_1, i2_1] = T.max(T_softmax_maxelem[i0_1, i1_1, i2_1], rxplaceholder[i0_1, i1_1, k, i2_1]) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)): + with T.block("compute"): + i0_2, i1_2, i2_2, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) + T.reads(rxplaceholder[i0_2, i1_2, k, i2_2], T_softmax_maxelem[i0_2, i1_2, i2_2]) + T.writes(compute_1[i0_2, i1_2, i2_2]) + with T.init(): + compute_1[i0_2, i1_2, i2_2] = T.float32(0) + compute_1[i0_2, i1_2, i2_2] = compute_1[i0_2, i1_2, i2_2] + T.exp(rxplaceholder[i0_2, i1_2, k, i2_2] - T_softmax_maxelem[i0_2, i1_2, i2_2], dtype="float32") + for i0_3, i1_3, i2_3, i3 in T.grid(T.int64(2), T.int64(3), T.int64(16), T.int64(32)): + with T.block("compute_1"): + i0_4, i1_4, i2_4, i3_1 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3]) + T.reads(rxplaceholder[i0_4, i1_4, i2_4, i3_1], T_softmax_maxelem[i0_4, i1_4, i3_1], compute_1[i0_4, i1_4, i3_1]) + T.writes(compute[i0_4, i1_4, i2_4, i3_1]) + T.block_attr({"axis": 2}) + compute[i0_4, i1_4, i2_4, i3_1] = (rxplaceholder[i0_4, i1_4, i2_4, i3_1] - T_softmax_maxelem[i0_4, i1_4, i3_1] - T.log(compute_1[i0_4, i1_4, i3_1], dtype="float32")) + # fmt: on + + mod = LegalizeOps()(LogSoftmax) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_log_softmax_symbolic(): + # fmt: off + @tvm.script.ir_module + class LogSoftmax: + @R.function + def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", "b", "c"), "float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + gv: R.Tensor((a, b, c), "float32") = R.nn.log_softmax(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor(("a", "b", "c"), dtype="float32"): + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + # block 0 + gv = R.call_tir(log_softmax, (x,), R.Tensor((a, b, c), dtype="float32")) + return gv + + @T.prim_func + def log_softmax(var_rxplaceholder: T.handle, var_compute: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.var("int64") + b = T.var("int64") + c = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], dtype="float32") + compute = T.match_buffer(var_compute, [a, b, c], dtype="float32") + T_softmax_maxelem = T.alloc_buffer([a, b], dtype="float32") + compute_1 = T.alloc_buffer([a, b], dtype="float32") + for i0, i1, k in T.grid(a, b, c): + with T.block("T_softmax_maxelem"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(rxplaceholder[v_i0, v_i1, v_k]) + T.writes(T_softmax_maxelem[v_i0, v_i1]) + with T.init(): + T_softmax_maxelem[v_i0, v_i1] = T.float32(-3.4028234663852886e38) + T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], rxplaceholder[v_i0, v_i1, v_k]) + for i0, i1, k in T.grid(a, b, c): + with T.block("compute"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(rxplaceholder[v_i0, v_i1, v_k], T_softmax_maxelem[v_i0, v_i1]) + T.writes(compute_1[v_i0, v_i1]) + with T.init(): + compute_1[v_i0, v_i1] = T.float32(0) + compute_1[v_i0, v_i1] = compute_1[v_i0, v_i1] + T.exp(rxplaceholder[v_i0, v_i1, v_k] - T_softmax_maxelem[v_i0, v_i1], dtype="float32") + for i0, i1, i2 in T.grid(a, b, c): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1], compute_1[v_i0, v_i1],) + T.writes(compute[v_i0, v_i1, v_i2]) + T.block_attr({"axis": 2}) + compute[v_i0, v_i1, v_i2] = (rxplaceholder[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1] - T.log(compute_1[v_i0, v_i1], dtype="float32")) + # fmt: on + + mod = LegalizeOps()(LogSoftmax) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_cross_entropy_with_logits(): + # fmt: off + @tvm.script.ir_module + class CrossEntropyWithLogits: + @R.function + def main(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv: R.Tensor((), "float32") = R.nn.cross_entropy_with_logits(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((3,), dtype="float32"), y: R.Tensor((3,), dtype="float32")) -> R.Tensor(dtype="float32", ndim=2): + gv = R.call_tir(cross_entropy_with_logits, (x, y), R.Tensor((), dtype="float32")) + return gv + + @T.prim_func + def cross_entropy_with_logits(rxplaceholder: T.Buffer[T.int64(3), "float32"], rxplaceholder_1: T.Buffer[T.int64(3), "float32"], T_multiply: T.Buffer[(), "float32"]): + T.func_attr({"tir.noalias": True}) + T_multiply_1 = T.alloc_buffer([T.int64(3)], dtype="float32") + T_multiply_red = T.alloc_buffer([], dtype="float32") + for i0 in T.serial(T.int64(3)): + with T.block("T_multiply"): + ax0 = T.axis.spatial(T.int64(3), i0) + T.reads(rxplaceholder[ax0], rxplaceholder_1[ax0]) + T.writes(T_multiply_1[ax0]) + T_multiply_1[ax0] = rxplaceholder[ax0] * rxplaceholder_1[ax0] + for i0 in T.serial(T.int64(3)): + with T.block("T_multiply_red"): + k0 = T.axis.reduce(T.int64(3), i0) + T.reads(T_multiply_1[k0]) + T.writes(T_multiply_red[()]) + with T.init(): + T_multiply_red[()] = T.float32(0) + T_multiply_red[()] = T_multiply_red[()] + T_multiply_1[k0] + with T.block("T_multiply_1"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(T_multiply_red[()]) + T.writes(T_multiply[()]) + T_multiply[()] = T_multiply_red[()] * T.float32(-1) + # fmt: on + + mod = LegalizeOps()(CrossEntropyWithLogits) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_cross_entropy_with_logits_batch(): + # fmt: off + @tvm.script.ir_module + class CrossEntropyWithLogits: + @R.function + def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv: R.Tensor((), "float32") = R.nn.cross_entropy_with_logits(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")) -> R.Tensor(dtype="float32", ndim=2): + gv = R.call_tir(cross_entropy_with_logits, (x, y), R.Tensor((), dtype="float32")) + return gv + + @T.prim_func + def cross_entropy_with_logits(rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)), "float32"], rxplaceholder_1: T.Buffer[(T.int64(2), T.int64(3)), "float32"], T_divide: T.Buffer[(), "float32"]): + T.func_attr({"tir.noalias": True}) + T_multiply = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") + T_multiply_red = T.alloc_buffer([], dtype="float32") + T_multiply_1 = T.alloc_buffer([], dtype="float32") + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_multiply"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax0, ax1]) + T.writes(T_multiply[ax0, ax1]) + T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * rxplaceholder_1[ax0, ax1] + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_multiply_red"): + k0, k1 = T.axis.remap("RR", [i0, i1]) + T.reads(T_multiply[k0, k1]) + T.writes(T_multiply_red[()]) + with T.init(): + T_multiply_red[()] = T.float32(0) + T_multiply_red[()] = T_multiply_red[()] + T_multiply[k0, k1] + with T.block("T_multiply_1"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(T_multiply_red[()]) + T.writes(T_multiply_1[()]) + T_multiply_1[()] = T_multiply_red[()] * T.float32(-1) + with T.block("T_divide"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(T_multiply_1[()]) + T.writes(T_divide[()]) + T_divide[()] = T_multiply_1[()] * T.float32(0.5) + # fmt: on + + mod = LegalizeOps()(CrossEntropyWithLogits) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_cross_entropy_with_logits_batch_symbolic(): + # fmt: off + @tvm.script.ir_module + class CrossEntropyWithLogits: + @R.function + def main(x: R.Tensor(("n", "m"), "float32"), y: R.Tensor(("n", "m"), "float32")) -> R.Tensor(None, "float32", ndim=2): + n = T.var("int64") + m = T.var("int64") + gv: R.Tensor((), "float32") = R.nn.cross_entropy_with_logits(x, y) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("n", "m"), dtype="float32"), y: R.Tensor(("n", "m"), dtype="float32")) -> R.Tensor(dtype="float32", ndim=2): + gv = R.call_tir(cross_entropy_with_logits, (x, y), R.Tensor((), dtype="float32")) + return gv + + @T.prim_func + def cross_entropy_with_logits(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, T_divide: T.Buffer[(), "float32"]): + T.func_attr({"tir.noalias": True}) + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [n, m], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [n, m], dtype="float32") + T_multiply = T.alloc_buffer([n, m], dtype="float32") + T_multiply_red = T.alloc_buffer([], dtype="float32") + T_multiply_1 = T.alloc_buffer([], dtype="float32") + for ax0, ax1 in T.grid(n, m): + with T.block("T_multiply"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax0, v_ax1]) + T.writes(T_multiply[v_ax0, v_ax1]) + T_multiply[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] * rxplaceholder_1[v_ax0, v_ax1] + for k0, k1 in T.grid(n, m): + with T.block("T_multiply_red"): + v_k0, v_k1 = T.axis.remap("RR", [k0, k1]) + T.reads(T_multiply[v_k0, v_k1]) + T.writes(T_multiply_red[()]) + with T.init(): + T_multiply_red[()] = T.float32(0) + T_multiply_red[()] = T_multiply_red[()] + T_multiply[v_k0, v_k1] + with T.block("T_multiply_1"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(T_multiply_red[()]) + T.writes(T_multiply_1[()]) + T_multiply_1[()] = T_multiply_red[()] * T.float32(-1) + with T.block("T_divide"): + vi = T.axis.spatial(1, T.int64(0)) + T.reads(T_multiply_1[()]) + T.writes(T_divide[()]) + T_divide[()] = T_multiply_1[()] / T.Cast("float32", n) + # fmt: on + + mod = LegalizeOps()(CrossEntropyWithLogits) + tvm.ir.assert_structural_equal(mod, Expected) + + def test_batch_norm(): # fmt: off @tvm.script.ir_module diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py b/tests/python/relax/test_tvmscript_parser_op_nn.py index 4e52bccb8637..781700af7b82 100644 --- a/tests/python/relax/test_tvmscript_parser_op_nn.py +++ b/tests/python/relax/test_tvmscript_parser_op_nn.py @@ -115,6 +115,21 @@ def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): _check(foo, bb.get()["foo"]) +def test_log_softmax(): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.nn.log_softmax(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.nn.log_softmax(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + def test_batch_norm(): @R.function def foo( @@ -189,5 +204,23 @@ def foo( _check(foo, bb.get()["foo"]) +def test_cross_entropy_with_logits(): + @R.function + def foo( + predictions: R.Tensor((2, 3), "float32"), labels: R.Tensor((2, 3), "float32") + ) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.nn.cross_entropy_with_logits(predictions, labels) + return gv + + predictions = relax.Var("predictions", R.Tensor((2, 3), "float32")) + labels = relax.Var("labels", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [predictions, labels]): + gv = bb.emit(relax.op.nn.cross_entropy_with_logits(predictions, labels)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + if __name__ == "__main__": tvm.testing.main() From df67561c71a4aa4c7ad534ad9ea3c2a07867f5dd Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 22 Feb 2023 07:56:45 -0500 Subject: [PATCH 59/81] [Unity][Analysis] TIR pattern kind analysis for multi-buffer write block (#14075) This PR supports TIR pattern kind analysis for TIR blocks which write to multiple buffers, which is helpful for normalization operators like layernorm, groupnorm, etc. Prior to this PR, the analyzer does not support a blocks which write to multiple buffers. On seeing such a block, the analyzer simply sets the analysis result to "opaque". With this PR, on seeing a block which writes multiple buffers, the analyzer will check if all the BufferStores have the same indices. And it will only set the result to "opaque" when the BufferStores have different indices. By doing this, the analysis works for common cases where a block may write to multiple buffers, like layernorm or groupnorm. Besides the unit test for the analysis itself, this PR also adds a unit test for FuseOps pass, make sure that a "layernorm + relu" pattern can be fused together. --- src/relax/analysis/tir_op_pattern_kind.cc | 14 +-- .../test_transform_annotate_tir_op_pattern.py | 27 +++++ tests/python/relax/test_transform_fuse_ops.py | 102 ++++++++++++++++++ 3 files changed, 136 insertions(+), 7 deletions(-) diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc index b7ac8faddd23..dfa073fd9c08 100644 --- a/src/relax/analysis/tir_op_pattern_kind.cc +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -50,9 +50,9 @@ class PatternKindAnalyzer : public StmtExprVisitor { } void VisitStmt_(const BufferStoreNode* op) final { - // We only support one buffer store in a block (ususally generated by TE compute) + // We only support one buffer store in a block (usually generated by TE compute) // If we have already seen buffer store in the current block, classify as Opaque. - if (store_.defined()) { + if (store_.defined() && !IsSameArray(op->indices, store_.value()->indices)) { kind_ = relay::kOpaque; return; } @@ -85,7 +85,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { for (const BufferLoad& load : loads_) { // Since elemwise is stricter than broadcast and broadcast is stricter than injective, // while the order amount enums: kElemWise < kBroadcast < kInjective. - // We can simpily use `std::max` to detect these three patterns. + // We can simply use `std::max` to detect these three patterns. // E.g Here is only one store node but two load nodes, like C[i, j] = A[i, j] + B[i] // Buffer C and A are elemwise but C and B are broadcast. So the whole block follows // broadcast pattern. @@ -190,7 +190,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { continue; } - // Try to find the i-th load indice in the store indices. + // Try to find the i-th load index in the store indices. while (j < ndim_store_buf && !store->indices[j].same_as(load->indices[i])) { ++j; } @@ -205,10 +205,10 @@ class PatternKindAnalyzer : public StmtExprVisitor { /*! * \brief Checking the load indices and store indices follows injective pattern. - * It's injective pattern iff all load indice vars are in the store indices, no matter orders. + * It's injective pattern iff all load index vars are in the store indices, no matter orders. * Note that we only support store indices are direct vars so far, which can be enhance later. * E.g. A[i, j] = B[j, i] is injective. - * A[i, j] = B[i - j] is injective since the load indice vars are only i, j + * A[i, j] = B[i - j] is injective since the load index vars are only i, j */ static bool IsInjectivePattern(const BufferStore& store, const BufferLoad& load) { std::unordered_set vars; @@ -307,7 +307,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { private: /*! * \brief The BufferStore node in the current block. - * \note We only support one BufferStore node in a block (ususally generated by TE compute) + * \note We only support one BufferStore node in a block (usually generated by TE compute) */ Optional store_; /*! \brief The BufferLoad nodes in the current block. */ diff --git a/tests/python/relax/test_transform_annotate_tir_op_pattern.py b/tests/python/relax/test_transform_annotate_tir_op_pattern.py index 23ce49a7c220..5fc8c9936706 100644 --- a/tests/python/relax/test_transform_annotate_tir_op_pattern.py +++ b/tests/python/relax/test_transform_annotate_tir_op_pattern.py @@ -356,5 +356,32 @@ def cumsum(var_rxplaceholder: T.handle, out_buf: T.Buffer(160, "float32")): assert new_mod["cumsum"].attrs["op_pattern"] == OpPatternKind.kOpaque +def test_sum_sqsum(): + @tvm.script.ir_module + class Module: + @T.prim_func + def sum_sqsum( + A: T.Buffer((32, 64), "float32"), + vsum: T.Buffer((32,), "float32"), + sqsum: T.Buffer((32,), "float32"), + ): + for ax0, k0 in T.grid(32, 64): + with T.block("block"): + v_ax0, v_k0 = T.axis.remap("SR", [ax0, k0]) + T.reads(A[v_ax0, v_k0]) + T.writes(vsum[v_ax0], sqsum[v_ax0]) + with T.init(): + vsum[v_ax0] = T.float32(0) + sqsum[v_ax0] = T.float32(0) + v_vsum: T.float32 = vsum[v_ax0] + A[v_ax0, v_k0] + v_sqsum: T.float32 = sqsum[v_ax0] + A[v_ax0, v_k0] * A[v_ax0, v_k0] + vsum[v_ax0] = v_vsum + sqsum[v_ax0] = v_sqsum + + mod = Module + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["sum_sqsum"].attrs["op_pattern"] == OpPatternKind.kCommReduce + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index d38e5829815c..a7a6066c4b60 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -854,5 +854,107 @@ def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): _check(Module, Module) +def test_layer_norm_silu(): + # fmt: off + @I.ir_module + class Module: + @R.function + def main(x: R.Tensor((1, 512, 64, 64), "float32"), mean: R.Tensor((64, 64), "float32"), var: R.Tensor((64, 64), "float32")): + with R.dataflow(): + gv0 = R.call_tir(layer_norm, (x, mean, var), out_sinfo=R.Tensor((1, 512, 64, 64))) + gv1 = R.call_tir(relu, gv0, out_sinfo=R.Tensor((1, 512, 64, 64), "float32")) + R.output(gv1) + return gv1 + + @T.prim_func + def layer_norm(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32"), gamma: T.Buffer((T.int64(64), T.int64(64)), "float32"), beta: T.Buffer((T.int64(64), T.int64(64)), "float32"), T_layer_norm: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32")): + rxplaceholder_red_temp_v0 = T.alloc_buffer([T.int64(64), T.int64(64)], dtype="float32") + rxplaceholder_red_temp_v1 = T.alloc_buffer([T.int64(64), T.int64(64)], dtype="float32") + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)): + with T.block("rxplaceholder_red_temp"): + ax0, ax1, k2, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3]) + T.reads(A[ax0, ax1, k2, k3]) + T.writes(rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1]) + with T.init(): + rxplaceholder_red_temp_v0[ax0, ax1] = T.float32(0) + rxplaceholder_red_temp_v1[ax0, ax1] = T.float32(0) + v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[ax0, ax1] + A[ax0, ax1, k2, k3] + v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[ax0, ax1] + A[ax0, ax1, k2, k3] * A[ax0, ax1, k2, k3] + rxplaceholder_red_temp_v0[ax0, ax1] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[ax0, ax1] = v_rxplaceholder_red_temp_v1 + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)): + with T.block("T_layer_norm"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(A[ax0, ax1, ax2, ax3], rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1], gamma[ax2, ax3], beta[ax2, ax3]) + T.writes(T_layer_norm[ax0, ax1, ax2, ax3]) + T_layer_norm[ax0, ax1, ax2, ax3] = (A[ax0, ax1, ax2, ax3] - rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05)) * T.rsqrt(rxplaceholder_red_temp_v1[ax0, ax1] * T.float32(0.05) - rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05) * (rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.05)) + T.float32(1e-05), dtype="float32") * gamma[ax2, ax3] + beta[ax2, ax3] + + @T.prim_func + def relu(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32"), B: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32")): + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)): + with T.block("relu"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(A[v_i0, v_i1, v_i2, v_i3]) + T.writes(B[v_i0, v_i1, v_i2, v_i3]) + B[v_i0, v_i1, v_i2, v_i3] = T.max(A[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + @I.ir_module + class Expected: + @T.prim_func + def layer_norm(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32"), gamma: T.Buffer((T.int64(64), T.int64(64)), "float32"), beta: T.Buffer((T.int64(64), T.int64(64)), "float32"), T_layer_norm: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32")): + T.func_attr({"op_pattern": 4}) + # with T.block("root"): + rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(64), T.int64(64))) + rxplaceholder_red_temp_v1 = T.alloc_buffer((T.int64(64), T.int64(64))) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)): + with T.block("rxplaceholder_red_temp"): + ax0, ax1, k2, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3]) + T.reads(A[ax0, ax1, k2, k3]) + T.writes(rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1]) + with T.init(): + rxplaceholder_red_temp_v0[ax0, ax1] = T.float32(0) + rxplaceholder_red_temp_v1[ax0, ax1] = T.float32(0) + v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[ax0, ax1] + A[ax0, ax1, k2, k3] + v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[ax0, ax1] + A[ax0, ax1, k2, k3] * A[ax0, ax1, k2, k3] + rxplaceholder_red_temp_v0[ax0, ax1] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[ax0, ax1] = v_rxplaceholder_red_temp_v1 + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)): + with T.block("T_layer_norm"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(A[ax0, ax1, ax2, ax3], rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1], gamma[ax2, ax3], beta[ax2, ax3]) + T.writes(T_layer_norm[ax0, ax1, ax2, ax3]) + T_layer_norm[ax0, ax1, ax2, ax3] = (A[ax0, ax1, ax2, ax3] - rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.050000000000000003)) * T.rsqrt(rxplaceholder_red_temp_v1[ax0, ax1] * T.float32(0.050000000000000003) - rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.050000000000000003) * (rxplaceholder_red_temp_v0[ax0, ax1] * T.float32(0.050000000000000003)) + T.float32(1.0000000000000001e-05)) * gamma[ax2, ax3] + beta[ax2, ax3] + + @T.prim_func + def relu(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32"), B: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32")): + T.func_attr({"op_pattern": 0}) + # with T.block("root"): + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)): + with T.block("relu"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(A[v_i0, v_i1, v_i2, v_i3]) + T.writes(B[v_i0, v_i1, v_i2, v_i3]) + B[v_i0, v_i1, v_i2, v_i3] = T.max(A[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + @R.function + def fused_layer_norm_relu(x: R.Tensor((1, 512, 64, 64), dtype="float32"), mean: R.Tensor((64, 64), dtype="float32"), var: R.Tensor((64, 64), dtype="float32")) -> R.Tensor((1, 512, 64, 64), dtype="float32"): + R.func_attr({"Primitive": 1}) + with R.dataflow(): + gv0 = R.call_tir(layer_norm, (x, mean, var), out_sinfo=R.Tensor((1, 512, 64, 64))) + gv = R.call_tir(relu, (gv0,), out_sinfo=R.Tensor((1, 512, 64, 64), dtype="float32")) + R.output(gv) + return gv + + @R.function + def main(x: R.Tensor((1, 512, 64, 64), dtype="float32"), mean: R.Tensor((64, 64), dtype="float32"), var: R.Tensor((64, 64), dtype="float32")) -> R.Tensor((1, 512, 64, 64), dtype="float32"): + with R.dataflow(): + gv: R.Tensor((1, 512, 64, 64), dtype="float32") = fused_layer_norm_relu(x, mean, var) + R.output(gv) + return gv + # fmt: on + + _check(Module, Expected) + + if __name__ == "__main__": tvm.testing.main() From c0a591d222a02e512d19966927af9d36e32cb94c Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 22 Feb 2023 14:44:00 -0500 Subject: [PATCH 60/81] [Unity][Fix][Pass] FoldConstant with DCE in dataflow block (#14087) The current FoldConstant pass does not support removing unused bindings in the post-folding function. Therefore, for large real-world models, the built executable will be overlarge because of the redundant unused constants. This PR removes the redundant unused constant bindings in FoldConstant by using the analysis function "RemoveAllUnused". Note that "RemoveAllUnused" only works at dataflow block level. Therefore FoldConstant will not remove unused bindings outside of dataflow block as well. --- src/relax/analysis/udchain.cc | 5 ++++- src/relax/transform/fold_constant.cc | 14 +++++++++----- tests/python/relax/test_transform_fold_constant.py | 3 --- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc index f3d9b4686b7d..77e52408a710 100644 --- a/src/relax/analysis/udchain.cc +++ b/src/relax/analysis/udchain.cc @@ -52,7 +52,10 @@ class UDChain : public relax::ExprVisitor { void VisitExpr_(const VarNode* op) override { to_users[op].insert(cur_user_); } void VisitVarDef(const Var& var) override { to_users[var.get()] = {}; } - void VisitExpr_(const FunctionNode* op) override { ExprVisitor::VisitExpr_(op); } + void VisitExpr_(const FunctionNode* op) override { + cur_user_ = nullptr; + ExprVisitor::VisitExpr_(op); + } void VisitExpr_(const DataflowVarNode* op) override { VisitExpr_(static_cast(op)); diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index aa55ee7f7e3d..87b022c8ae08 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -30,9 +31,15 @@ namespace relax { class ConstantFolder : public ExprMutator { public: - explicit ConstantFolder(IRModule ctx_module) : ctx_module_(ctx_module) {} + static Function Fold(Function func, IRModule ctx_module) { + ConstantFolder folder(std::move(ctx_module)); + func = RemoveAllUnused(Downcast(folder(func))); + return func; + } private: + explicit ConstantFolder(IRModule ctx_module) : ctx_module_(ctx_module) {} + /*! * \brief Pattern match the shape inside the given struct info to a * constant shape and get runtime shape tuple from it. @@ -215,10 +222,7 @@ namespace transform { Pass FoldConstant() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - ConstantFolder folder(m); - return Downcast(folder(f)); - }; + [=](Function f, IRModule m, PassContext pc) { return ConstantFolder::Fold(f, m); }; return CreateFunctionPass(pass_func, 0, "FoldConstant", {}); } diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py index 95542dd4e6ca..da0816ef807c 100644 --- a/tests/python/relax/test_transform_fold_constant.py +++ b/tests/python/relax/test_transform_fold_constant.py @@ -165,9 +165,6 @@ def before(c0: R.Tensor((16, 16), "float32")): @R.function def expected(c1: R.Tensor((16, 16), "float32")): - with R.dataflow(): - gv0 = c1 - R.output(gv0) return c1 c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) From a283a71be6b27f94216bc0ead72050ae2d6143ec Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 22 Feb 2023 14:47:11 -0500 Subject: [PATCH 61/81] [Unity] Refactor Relax Build JIT UX (#14088) This PR refactors relax build so it get exposed at the opt-level. We also introduces an explicit jit functionality to handle live loading of compiled artifacts from cutlass. We also move relax vm to runtime so it can be clearly isolated from the rest of the compiler stack. --- apps/relax_examples/e2e_auto_tir.py | 2 +- apps/relax_examples/mlp.py | 2 +- apps/relax_examples/nn_module.py | 2 +- apps/relax_examples/resnet.py | 2 +- python/tvm/contrib/cutlass/build.py | 26 -- python/tvm/meta_schedule/relax_integration.py | 6 +- python/tvm/relax/__init__.py | 8 +- python/tvm/relax/exec_builder.py | 2 +- python/tvm/relax/frontend/torch/dynamo.py | 4 +- python/tvm/relax/transform/transform.py | 4 + .../transform/tuning_api/default_functions.py | 6 +- python/tvm/relax/vm_build.py | 317 ++++++++++++++++++ .../tvm/{relax/vm.py => runtime/relax_vm.py} | 231 ++----------- tests/python/relax/test_codegen_cutlass.py | 6 +- tests/python/relax/test_codegen_dnnl.py | 2 +- tests/python/relax/test_codegen_tensorrt.py | 2 +- tests/python/relax/test_pipeline.py | 2 +- tests/python/relax/test_relay_translator.py | 2 +- .../relax/test_transform_bind_params.py | 4 +- .../relax/test_transform_codegen_pass.py | 12 +- tests/python/relax/test_vm_build.py | 62 ++-- tests/python/relax/test_vm_codegen_only.py | 10 +- tests/python/relax/test_vm_codegen_tir.py | 2 +- tests/python/relax/test_vm_profiler.py | 8 +- tests/python/relay/test_vm.py | 10 +- 25 files changed, 423 insertions(+), 311 deletions(-) create mode 100644 python/tvm/relax/vm_build.py rename python/tvm/{relax/vm.py => runtime/relax_vm.py} (73%) diff --git a/apps/relax_examples/e2e_auto_tir.py b/apps/relax_examples/e2e_auto_tir.py index 92cda16f7927..8113f942d166 100644 --- a/apps/relax_examples/e2e_auto_tir.py +++ b/apps/relax_examples/e2e_auto_tir.py @@ -142,7 +142,7 @@ def apply_opt_before_tuning( def f_measurement( rt_mod: runtime.Module, device: runtime.ndarray.Device, input_data: Dict[str, runtime.NDArray] ): - vm = relax.vm.VirtualMachine(exec=rt_mod, device=device) + vm = relax.VirtualMachine(rt_mod, device=device) vm.save_function("main", "measure_func", **input_data, include_return=False) evaluator = vm.time_evaluator( func_name="measure_func", diff --git a/apps/relax_examples/mlp.py b/apps/relax_examples/mlp.py index 02e17dc3041a..2a81b61543fd 100644 --- a/apps/relax_examples/mlp.py +++ b/apps/relax_examples/mlp.py @@ -47,7 +47,7 @@ def build_mlp(data, weight): # build and create vm executor target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target) + ex = relax.build(mod, target) vm = relax.VirtualMachine(ex, tvm.cpu()) # run the mlp model on relax vm diff --git a/apps/relax_examples/nn_module.py b/apps/relax_examples/nn_module.py index b57cb00685ae..57a13e4fb51b 100644 --- a/apps/relax_examples/nn_module.py +++ b/apps/relax_examples/nn_module.py @@ -56,7 +56,7 @@ # build the IRModule and create relax vm target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target) + ex = relax.build(mod, target) vm = relax.VirtualMachine(ex, tvm.cpu()) # init parameters diff --git a/apps/relax_examples/resnet.py b/apps/relax_examples/resnet.py index df0cab02f19c..6c7350d77847 100644 --- a/apps/relax_examples/resnet.py +++ b/apps/relax_examples/resnet.py @@ -36,7 +36,7 @@ relax_mod.show() # build the IRModule and create relax vm - ex = relax.vm.build(relax_mod, target) + ex = relax.build(relax_mod, target) vm = relax.VirtualMachine(ex, tvm.cpu()) # init weights and run the model on relax vm diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index ad0e59af02fa..c6e5adacec86 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -851,29 +851,3 @@ def finalize_modules_vm(vm_exec, lib_path="compile.so", vmcode_path="vmcode.ro", fo.write(code) lib = tvm.runtime.load_module(lib_path) return tvm.runtime.vm.Executable.load_exec(code, lib) - - -def finalize_modules_relax(vm_exec, lib_path="compile.so", tmp_dir="./tmp"): - """finalize_modules_vm equivalent for Relax VM. - - Parameters - ---------- - vm_exec : vm.Executable - The output from relax.vm.build containing compiled host code and kernels. - - lib_path : string - The path to a shared library which will be generated as the result of the build process. - - tmp_dir : string - A temporary directory where intermediate compiled artifacts will be stored. - - Returns - ------- - updated_vm_exec : relax.vm.Executable - The updated VM executable with all compilation and linking completed. - """ - lib_path = os.path.join(tmp_dir, lib_path) - vm_exec.mod.export_library(lib_path, workspace_dir=tmp_dir, cc="nvcc") - lib = tvm.runtime.load_module(lib_path) - - return relax.vm.Executable(lib) diff --git a/python/tvm/meta_schedule/relax_integration.py b/python/tvm/meta_schedule/relax_integration.py index a82d8996858b..db22214b768f 100644 --- a/python/tvm/meta_schedule/relax_integration.py +++ b/python/tvm/meta_schedule/relax_integration.py @@ -317,7 +317,7 @@ def compile_relax( mod: IRModule, target: Union[Target, str], params: Optional[Dict[str, NDArray]], -) -> "relax.vm.Executable": +) -> "relax.Executable": """Compile a relax program with a MetaSchedule database. Parameters @@ -333,12 +333,12 @@ def compile_relax( Returns ------- - lib : relax.vm.Executable + lib : relax.Executable The built runtime module or vm Executable for the given relax workload. """ # pylint: disable=import-outside-toplevel from tvm.relax.transform import BindParams, MetaScheduleApplyDatabase - from tvm.relax.vm import build as relax_build + from tvm.relax import build as relax_build # pylint: enable=import-outside-toplevel if not isinstance(target, Target): diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 33a9c2eece21..d0a1942ebdcb 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -16,6 +16,9 @@ # under the License. # pylint: disable=invalid-name, wrong-import-position """The Relax IR namespace containing the IR, type, operator, builder, vm, etc.""" +from tvm.runtime import relax_vm as vm +from tvm.runtime.relax_vm import VirtualMachine + # Expr from .expr import ( Expr, @@ -51,7 +54,6 @@ # VM from .exec_builder import ExecBuilder -from .vm import VirtualMachine # Operator from .op.base import call_tir @@ -82,7 +84,9 @@ from . import ty from . import analysis from . import transform -from . import vm from . import block_builder from . import op from . import struct_info + +# VM +from .vm_build import build, Executable diff --git a/python/tvm/relax/exec_builder.py b/python/tvm/relax/exec_builder.py index 1e28c967d18f..140c497eb967 100644 --- a/python/tvm/relax/exec_builder.py +++ b/python/tvm/relax/exec_builder.py @@ -21,7 +21,7 @@ import tvm from tvm.runtime import Object from tvm.runtime.container import ShapeTuple -from .vm import Executable +from .vm_build import Executable from . import _ffi_api diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py index 94de73a43115..589c6be3b5b5 100644 --- a/python/tvm/relax/frontend/torch/dynamo.py +++ b/python/tvm/relax/frontend/torch/dynamo.py @@ -23,7 +23,7 @@ from typing import Optional import tvm -from tvm.relax.vm import build as relax_build +from tvm.relax import build as relax_build from tvm.relax.frontend.torch.fx_translator import from_fx @@ -96,7 +96,7 @@ def to_tvm_tensor(torch_tensor): ex = relax_build(mod, target=target) - vm = tvm.relax.vm.VirtualMachine(exec=ex.mod, device=dev) + vm = tvm.relax.VirtualMachine(ex.mod, device=dev) def exec_tvm(*i_args): args = [a.contiguous() for a in i_args] diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index c72d053290af..7044314e8581 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -222,6 +222,10 @@ def RunCodegen( """ if entry_functions is None: entry_functions = ["main"] + # enable cutlass byoc registries + # pylint: disable=unused-import,import-outside-toplevel + from tvm.contrib import cutlass as _cutlass + return _ffi_api.RunCodegen(target_options, entry_functions) # type: ignore diff --git a/python/tvm/relax/transform/tuning_api/default_functions.py b/python/tvm/relax/transform/tuning_api/default_functions.py index b72b2f30ee2b..7cdb211bd32f 100644 --- a/python/tvm/relax/transform/tuning_api/default_functions.py +++ b/python/tvm/relax/transform/tuning_api/default_functions.py @@ -176,7 +176,7 @@ def relax_build( ): if params: mod = tvm.relax.transform.BindParams("main", params)(mod) - relax_exec = tvm.relax.vm.build(mod, target) + relax_exec = tvm.relax.build(mod, target) return relax_exec.mod builder = LocalBuilder(f_build=relax_build) @@ -185,8 +185,8 @@ def relax_build( if runner is None: def relax_eval_func(rt_mod, device, evaluator_config, repeated_args): - relax_exec = tvm.relax.vm.Executable(rt_mod) - relax_vm = tvm.relax.VirtualMachine(exec=relax_exec, device=device) + relax_exec = tvm.relax.Executable(rt_mod) + relax_vm = tvm.relax.VirtualMachine(relax_exec, device=device) evaluator = relax_vm.module.time_evaluator( func_name="main", diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py new file mode 100644 index 000000000000..35fc65bdc6c0 --- /dev/null +++ b/python/tvm/relax/vm_build.py @@ -0,0 +1,317 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, no-member +"""VM build logics""" +from typing import List, Optional, Union, Dict, Any + +import tvm +from tvm import relax + +from tvm.contrib import utils as _utils + +from tvm.ir.module import IRModule +from tvm.tir.function import PrimFunc + +from . import _ffi_api + + +class Executable: + """The executable object emitted by the VM compiler or the ExecBuilder.""" + + def __init__(self, mod: tvm.runtime.Module): + self.mod = mod + self._stats = self.mod["stats"] + self._as_text = self.mod["as_text"] + self._as_python = self.mod["as_python"] + + def stats(self) -> str: + """print the detailed statistics of the executable.""" + return self._stats() + + def as_text(self) -> str: + """print the instructions as text format.""" + return self._as_text() + + def as_python(self) -> str: + """print the instructions as python program.""" + return self._as_python() + + def jit(self, fcompile=None, addons=None, **kwargs) -> tvm.runtime.Module: + """Just-in-time compile and link the modules. + + The Executable returned by relax.build may not be directly + runnable as they may contain cuda source files and objects that + are yet to be compiled and linked. + This function helps to create a runtime.Module for these cases. + + Parameters + ---------- + fcompile : function(target, file_list, kwargs), optional + The compilation function to use create the final library object during + + kwargs : dict, optional + Additional arguments passed to fcompile + + Returns + ------- + rt_mod: tvm.runtime.Module + A runnable runtime module that can be passed to VirtualMachine. + + Examples + -------- + .. code:: python + + ex = relax.build(mod, target) + # build a runnable module using nvcc to link everything + rt_mod = ex.jit() + vm = tvm.relax.VirtualMachine(rt_mod, tvm.cuda()) + """ + # TODO(tvm-team): Update runtime.Module interfac + # to query these properties as bitmask. + def _not_runnable(x): + return x.type_key in ("c", "static_library") + + # pylint:disable = protected-access + not_runnable_list = self.mod._collect_from_import_tree(_not_runnable) + + # everything is runnable, directly return mod. + if len(not_runnable_list) == 0: + return self.mod + + # found source module, or other not runnable modules + # need to be export and load + # TODO(tvm-team): Support runnable but not exportable module. + # by collecting the link and allow export_library skip those modules. + workspace_dir = _utils.tempdir() + dso_path = workspace_dir.relpath("exported.so") + self.mod.export_library(dso_path, fcompile=fcompile, addons=addons, **kwargs) + return tvm.runtime.load_module(dso_path) + + def export_library( + self, + file_name: str, + fcompile: Optional[Union[str, callable]] = None, + workspace_dir: Optional[str] = None, + **kwargs, + ) -> Any: + """Export the executable to a library which can then be loaded back. + + Parameters + ---------- + file_name : str + The name of the shared library. + + fcompile : function(target, file_list, kwargs), optional + The compilation function to use create the final library object during + + workspace_dir : str, optional + The path of the directory used to create the intermediate + artifacts when exporting the module. + If this is not provided a temporary dir will be created. + + kwargs : dict, optional + Additional arguments passed to fcompile + + Returns + ------- + result of fcompile() : unknown, optional + If the compilation function returns an artifact it would be returned via + export_library, if any. + + Examples + -------- + .. code:: python + + ex = relax.build(mod, target) + # export the library + ex.export_library("exported.so") + + # load it back for future uses. + rt_mod = tvm.runtime.load_module("exported.so") + vm = tvm.relax.VirtualMachine(rt_mod, tvm.cuda()) + """ + return self.mod.export_library( + file_name=file_name, fcompile=fcompile, workspace_dir=workspace_dir, **kwargs + ) + + +def _vmcodegen( + builder: "relax.ExecBuilder", + mod: tvm.IRModule, + exec_mode: str = "bytecode", +) -> tvm.IRModule: + """Running VM codegen. + + Parameters + ---------- + builder: relax.ExecBuilder + ExecBuilder to collect the vm executable. + + mod: IRModule + The input IRModule to be built. + + exec_mode: {"bytecode", "compiled"} + The execution mode. + + Return + ------ + leftover: IRModule + Left over IRModule that may contain extra functions. + """ + + if exec_mode == "bytecode": + return _ffi_api.VMCodeGen(builder, mod) # type:ignore + if exec_mode == "compiled": + return _ffi_api.VMTIRCodeGen(builder, mod) # type: ignore + raise ValueError("Unknown exec_mode %s" % exec_mode) + + +def _vmlink( + builder: "relax.ExecBuilder", + target: Union[str, tvm.target.Target], + tir_mod: Optional[tvm.IRModule] = None, + ext_libs: List[tvm.runtime.Module] = None, + params: Optional[Dict[str, list]] = None, +): + """ + Internal codegen function to make executable. + + This function is only used for unit-testing purpoes. + + Use build instead. + + Parameters + ---------- + builder: relax.ExecBuilder + Builder used to collect executables. + + target : Union[str, tvm.target.Target] + A build target which can have optional host side compilation target. + + tir_mod: IRModule + The input TIR IRModule to be linked together. + + ext_libs: List[tvm.runtime.Module] + List of compiled external modules. + + params: Optional[Dict[str, list]] + Extra parameter mappings. + + Returns + ------- + ex: tvm.relax.Executable + An executable that can be loaded by virtual machine. + """ + if isinstance(target, str): + target = tvm.target.Target(target) + if params is None: + params = {} + if ext_libs is None: + ext_libs = [] + lib = None + if tir_mod is not None: + lib = tvm.build(tir_mod, target=target) + return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params)) # type: ignore + + +def build( + mod: tvm.IRModule, + target: Union[str, tvm.target.Target], + params: Optional[Dict[str, list]] = None, + exec_mode: str = "bytecode", +) -> Executable: + """ + Build an IRModule to VM executable. + + Parameters + ---------- + mod: IRModule + The input IRModule to be built. + + target : Union[str, tvm.target.Target] + A build target which can have optional host side compilation target. + + When TVM compiles device specific program such as CUDA, + we also need host(CPU) side code to interact with the driver + to setup the dimensions and parameters correctly. + host is used to specify the host side codegen target. + By default, llvm is used if it is enabled, + otherwise a stackvm interpreter is used. + + params: Optional[Dict[str, list]] + Parameters for the input IRModule that will be bound. + + exec_mode: {"bytecode", "compiled"} + The execution mode. + + Returns + ------- + ex: tvm.relax.Executable + An executable that can be loaded by virtual machine. + + Example + ------- + + .. code-block:: python + class InputModule: + @R.function + def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")): + z = R.add(x, y) + return z + + mod = InputModule + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target) + """ + if isinstance(target, str): + target = tvm.target.Target(target) + + passes = [] + passes.append(relax.transform.RewriteDataflowReshape()) + passes.append(relax.transform.ToNonDataflow()) + passes.append(relax.transform.CallTIRRewrite()) + passes.append(relax.transform.StaticPlanBlockMemory()) + passes.append(relax.transform.VMBuiltinLower()) + passes.append(relax.transform.VMShapeLower()) + passes.append(relax.transform.AttachGlobalSymbol()) + seq = tvm.transform.Sequential(passes) + new_mod = seq(mod) + + # Extract external runtime modules if exist. + attrs = dict(mod.attrs) if mod.attrs else {} + + ext_libs = attrs.get("external_mods", []) + constants = attrs.get("const_name_to_constant", {}) + + if params is not None: + params.update(dict(constants)) + else: + params = constants + + # builder collects the executable + builder = relax.ExecBuilder() + leftover_mod = _vmcodegen(builder, new_mod, exec_mode=exec_mode) + tir_mod = _filter_tir(leftover_mod) + return _vmlink(builder, target, tir_mod, ext_libs, params) + + +def _filter_tir(mod: tvm.IRModule) -> tvm.IRModule: + tir_mod = IRModule({}) + for gv in mod.get_global_vars(): + if isinstance(mod[gv], PrimFunc): + tir_mod[gv] = mod[gv] + return tir_mod diff --git a/python/tvm/relax/vm.py b/python/tvm/runtime/relax_vm.py similarity index 73% rename from python/tvm/relax/vm.py rename to python/tvm/runtime/relax_vm.py index a3578c8a409d..9defcb7d80f3 100644 --- a/python/tvm/relax/vm.py +++ b/python/tvm/runtime/relax_vm.py @@ -14,43 +14,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, redefined-builtin, no-else-return -"""The Relax virtual machine""" +# pylint: disable=invalid-name, redefined-builtin, no-else-return, consider-using-dict-items +"""The Relax virtual machine.""" from typing import Callable, List, Optional, Union, Dict, Tuple, Any import numpy as np # type: ignore -from tvm._ffi import base as _base import tvm -from tvm import relax -from tvm.ir.module import IRModule -from tvm.runtime import Device, Module, PackedFunc, container -from tvm.runtime.object import Object -from tvm.runtime.profiling import Report -from tvm.tir.function import PrimFunc -from . import _ffi_api -from ..rpc.base import RPC_SESS_MASK - - -class Executable(object): - """The executable object emitted by the VM compiler or the ExecBuilder.""" - - def __init__(self, mod: Module): - self.mod = mod - self._stats = self.mod["stats"] - self._as_text = self.mod["as_text"] - self._as_python = self.mod["as_python"] - - def stats(self) -> str: - """print the detailed statistics of the executable.""" - return self._stats() +from tvm._ffi import base as _base - def as_text(self) -> str: - """print the instructions as text format.""" - return self._as_text() +from tvm.runtime import Device, PackedFunc, container, Object +from tvm.runtime.profiling import Report - def as_python(self) -> str: - """print the instructions as python program.""" - return self._as_python() +from ..rpc.base import RPC_SESS_MASK class VirtualMachine(object): @@ -61,7 +36,7 @@ class VirtualMachine(object): def __init__( self, - exec: Union[Executable, Module], + rt_mod: Union[tvm.runtime.Module, "tvm.relax.Executable"], device: Union[Device, List[Device]], memory_cfg: Optional[Union[str, Dict[Device, str]]] = None, profile: bool = False, @@ -71,8 +46,8 @@ def __init__( Parameters ---------- - exec: Union[Executable, Module] - The VM executable or Runtime Module + mod: Union[tvm.runtime.Module, tvm.relax.Executable] + Runtime module exported by the result of build. device : Union[Device, List[Device]] The device to deploy the module. @@ -88,8 +63,20 @@ def __init__( profile : Optional[bool] Whether or not to enable profiling. """ + if not isinstance(rt_mod, tvm.runtime.Module): + # important to keep this import local + # as the relax_vm needs to be isolated from compiler + # if we do not use the jit feature + # pylint:disable=import-outside-toplevel + from tvm import relax + + if isinstance(rt_mod, relax.Executable): + rt_mod = rt_mod.jit() + else: + raise ValueError("Expect the rt_mod to be an runtime.Module") + load_exec = "vm_profiler_load_executable" if profile else "vm_load_executable" - self.module = exec.mod[load_exec]() if isinstance(exec, Executable) else exec[load_exec]() + self.module = rt_mod[load_exec]() self._invoke_closure = self.module["invoke_closure"] self._save_function = self.module["save_function"] self._set_input = self.module["set_input"] @@ -408,7 +395,7 @@ def time_evaluator( .. code-block:: python target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(TestTimeEvaluator, target) + ex = relax.build(TestTimeEvaluator, target) vm = relax.VirtualMachine(mod, tvm.cpu()) timing_res = vm.time_evaluator("func_name", tvm.cpu())(arg0, arg1, ..., argn) @@ -417,7 +404,7 @@ def time_evaluator( .. code-block:: python target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(TestTimeEvaluator, target) + ex = relax.build(TestTimeEvaluator, target) vm = relax.VirtualMachine(mod, tvm.cpu()) vm.set_input("func_name", arg0, arg1, ..., argn) timing_res = vm.time_evaluator("invoke_stateful", tvm.cpu())("func_name") @@ -428,7 +415,7 @@ def time_evaluator( .. code-block:: python target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(TestTimeEvaluator, target) + ex = relax.build(TestTimeEvaluator, target) vm = relax.VirtualMachine(mod, tvm.cpu()) vm.save_function("func_name", "func_name_saved", arg0, arg1, ..., argn) timing_res = vm.time_evaluator("func_name_saved", tvm.cpu())() @@ -471,171 +458,3 @@ def profile(self, func_name: str, *args): report_json = self.module["profile"](func_name, *cargs) return Report.from_json(report_json) - - -def _vmcodegen( - builder: "relax.ExecBuilder", - mod: tvm.IRModule, - exec_mode: str = "bytecode", -) -> tvm.IRModule: - """Running VM codegen. - - Parameters - ---------- - builder: relax.ExecBuilder - ExecBuilder to collect the vm executable. - - mod: IRModule - The input IRModule to be built. - - exec_mode: {"bytecode", "compiled"} - The execution mode. - - Return - ------ - leftover: IRModule - Left over IRModule that may contain extra functions. - """ - - if exec_mode == "bytecode": - return _ffi_api.VMCodeGen(builder, mod) # type:ignore - if exec_mode == "compiled": - return _ffi_api.VMTIRCodeGen(builder, mod) # type: ignore - raise ValueError("Unknown exec_mode %s" % exec_mode) - - -def _vmlink( - builder: "relax.ExecBuilder", - target: Union[str, tvm.target.Target], - tir_mod: Optional[tvm.IRModule] = None, - ext_libs: List[tvm.runtime.Module] = None, - params: Optional[Dict[str, list]] = None, -): - """ - Internal codegen function to make executable. - - This function is only used for unit-testing purpoes. - - Use build instead. - - Parameters - ---------- - builder: relax.ExecBuilder - Builder used to collect executables. - - target : Union[str, tvm.target.Target] - A build target which can have optional host side compilation target. - - tir_mod: IRModule - The input TIR IRModule to be linked together. - - ext_libs: List[tvm.runtime.Module] - List of compiled external modules. - - params: Optional[Dict[str, list]] - Extra parameter mappings. - - Returns - ------- - ex: tvm.relax.vm.Executable - An executable that can be loaded by virtual machine. - """ - if isinstance(target, str): - target = tvm.target.Target(target) - if params is None: - params = {} - if ext_libs is None: - ext_libs = [] - lib = None - if tir_mod is not None: - lib = tvm.build(tir_mod, target=target) - return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params)) # type: ignore - - -def build( - mod: tvm.IRModule, - target: Union[str, tvm.target.Target], - params: Optional[Dict[str, list]] = None, - exec_mode: str = "bytecode", -) -> Executable: - """ - Build an IRModule to VM executable. - - Parameters - ---------- - mod: IRModule - The input IRModule to be built. - - target : Union[str, tvm.target.Target] - A build target which can have optional host side compilation target. - - When TVM compiles device specific program such as CUDA, - we also need host(CPU) side code to interact with the driver - to setup the dimensions and parameters correctly. - host is used to specify the host side codegen target. - By default, llvm is used if it is enabled, - otherwise a stackvm interpreter is used. - - params: Optional[Dict[str, list]] - Parameters for the input IRModule that will be bound. - - exec_mode: {"bytecode", "compiled"} - The execution mode. - - Returns - ------- - ex: tvm.relax.vm.Executable - An executable that can be loaded by virtual machine. - - Example - ------- - - .. code-block:: python - class InputModule: - @R.function - def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")): - z = R.add(x, y) - return z - - mod = InputModule - target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target) - """ - if isinstance(target, str): - target = tvm.target.Target(target) - - passes = [] - passes.append(relax.transform.RewriteDataflowReshape()) - passes.append(relax.transform.ToNonDataflow()) - passes.append(relax.transform.CallTIRRewrite()) - passes.append(relax.transform.StaticPlanBlockMemory()) - passes.append(relax.transform.VMBuiltinLower()) - passes.append(relax.transform.VMShapeLower()) - passes.append(relax.transform.AttachGlobalSymbol()) - seq = tvm.transform.Sequential(passes) - new_mod = seq(mod) - - # Extract external runtime modules if exist. - attrs = dict(mod.attrs) if mod.attrs else {} - - ext_libs = attrs.get("external_mods", []) - constants = attrs.get("const_name_to_constant", {}) - - if params is not None: - params.update(dict(constants)) - else: - params = constants - - # builder collects the executable - builder = relax.ExecBuilder() - leftover_mod = _vmcodegen(builder, new_mod, exec_mode=exec_mode) - tir_mod = _filter_tir(leftover_mod) - return _vmlink(builder, target, tir_mod, ext_libs, params) - - -def _filter_tir(mod: tvm.IRModule) -> tvm.IRModule: - tir_mod = IRModule({}) - for gv in mod.get_global_vars(): - if isinstance(mod[gv], PrimFunc): - tir_mod[gv] = mod[gv] - return tir_mod diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index 1eafb1bc1caf..5556d1e5d9a8 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -23,7 +23,6 @@ import tvm import tvm.testing from tvm import relax, relay -from tvm.contrib.cutlass.build import finalize_modules_relax from tvm.relax.dpl import make_fused_bias_activation_pattern, make_matmul_pattern from tvm.script import relax as R @@ -214,7 +213,7 @@ def main( cutlass_enabled = pytest.mark.skipif( not has_cutlass, - reason="CUTLASS note enabled.", + reason="CUTLASS not enabled.", ) pytestmark = [cutlass_enabled] @@ -231,8 +230,7 @@ def get_result_with_relax_cutlass_offload(mod, patterns: List[Tuple], *args): mod = seq(mod) target = tvm.target.Target("cuda") - ex = relax.vm.build(mod, target) - ex = finalize_modules_relax(ex) + ex = relax.build(mod, target) dev = tvm.gpu(0) vm = relax.VirtualMachine(ex, dev) diff --git a/tests/python/relax/test_codegen_dnnl.py b/tests/python/relax/test_codegen_dnnl.py index 69139b28ef34..885c88f3b0b4 100644 --- a/tests/python/relax/test_codegen_dnnl.py +++ b/tests/python/relax/test_codegen_dnnl.py @@ -88,7 +88,7 @@ def test_dnnl_offload(): mod = seq(Conv2dReLUx2) target = tvm.target.Target("llvm") - ex = relax.vm.build(mod, target) + ex = relax.build(mod, target) vm = relax.VirtualMachine(ex, tvm.cpu()) f = vm["main"] diff --git a/tests/python/relax/test_codegen_tensorrt.py b/tests/python/relax/test_codegen_tensorrt.py index 164cf3a8189e..47a4b1eec6cd 100644 --- a/tests/python/relax/test_codegen_tensorrt.py +++ b/tests/python/relax/test_codegen_tensorrt.py @@ -101,7 +101,7 @@ def test_tensorrt_offload(): target = "cuda" dev = tvm.device(target, 0) - ex = relax.vm.build(mod, target) + ex = relax.build(mod, target) vm = relax.VirtualMachine(ex, dev) f = vm["main"] diff --git a/tests/python/relax/test_pipeline.py b/tests/python/relax/test_pipeline.py index 6d6704ae97ec..c66066f8f830 100644 --- a/tests/python/relax/test_pipeline.py +++ b/tests/python/relax/test_pipeline.py @@ -34,7 +34,7 @@ def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): mod = pipeline(mod) target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target) + ex = relax.build(mod, target) x_np = np.random.rand(3, 4).astype(np.float32) y_np = np.random.rand(3, 4).astype(np.float32) x = tvm.nd.array(x_np) diff --git a/tests/python/relax/test_relay_translator.py b/tests/python/relax/test_relay_translator.py index 5f7e05b02d3a..b4f84027ebe4 100644 --- a/tests/python/relax/test_relay_translator.py +++ b/tests/python/relax/test_relay_translator.py @@ -184,7 +184,7 @@ def translate_and_build_vms(relay_mod, target_str="llvm", translate_op_with_tir= relax_mod = relay_translator.from_relay( relay_mod["main"], target, translate_op_with_tir=translate_op_with_tir ) - relax_ex = relax.vm.build(relax_mod, target) + relax_ex = relax.build(relax_mod, target) relax_vm = relax.VirtualMachine(relax_ex, tvm.cpu()) return relay_vm, relax_vm, relax_mod diff --git a/tests/python/relax/test_transform_bind_params.py b/tests/python/relax/test_transform_bind_params.py index b96fb89e6c0a..ceaf8fb16554 100644 --- a/tests/python/relax/test_transform_bind_params.py +++ b/tests/python/relax/test_transform_bind_params.py @@ -60,11 +60,11 @@ def main( assert len(mod["main"].params) == 1 target = tvm.target.Target("llvm") - ex_after = relax.vm.build(mod, target) + ex_after = relax.build(mod, target) vm_after = relax.VirtualMachine(ex_after, tvm.cpu()) res_after = vm_after["main"](x_tvm) - ex_before = relax.vm.build(InputModule, target) + ex_before = relax.build(InputModule, target) vm_before = relax.VirtualMachine(ex_before, tvm.cpu()) res_before = vm_before["main"](x_tvm, w_tvm) diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py index e50ad8f5f427..3e9501147aa0 100644 --- a/tests/python/relax/test_transform_codegen_pass.py +++ b/tests/python/relax/test_transform_codegen_pass.py @@ -56,10 +56,10 @@ def check_executable(exec, dev, inputs, expected): def check_roundtrip(exec0, dev, inputs, expected): exec0.mod.export_library("exec.so") - exec1 = relax.vm.Executable(tvm.runtime.load_module("exec.so")) + exec1 = tvm.runtime.load_module("exec.so") os.remove("exec.so") - assert exec0.stats() == exec1.stats() - assert exec0.as_text() == exec1.as_text() + assert exec0.stats() == exec1["stats"] + assert exec0.as_text() == exec1["as_text"]() check_executable(exec0, dev, inputs, expected) check_executable(exec1, dev, inputs, expected) @@ -81,7 +81,7 @@ def gen_ground_truth(mod, target, dev, inputs): ) new_mod = seq(mod) assert relax.analysis.well_formed(new_mod) - exec = relax.vm.build(new_mod, target, params={}) + exec = relax.build(new_mod, target, params={}) vm = relax.VirtualMachine(exec, dev) return vm["main"](*inputs) @@ -140,7 +140,7 @@ def test_tensorrt_only(): ] )(mod) - ex0 = relax.vm.build(new_mod, target, params={}) + ex0 = relax.build(new_mod, target, params={}) # Sanity check for the correctness and rountrip check_roundtrip(ex0, dev, inputs, expected) @@ -173,7 +173,7 @@ def test_mix_use_tensorrt_and_tvm(): )(mod) assert relax.analysis.well_formed(new_mod) with transform.PassContext(opt_level=0): - ex0 = relax.vm.build(new_mod, target, params={}) + ex0 = relax.build(new_mod, target, params={}) # Sanity check for the correctness and rountrip check_roundtrip(ex0, dev, inputs, expected) diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index e78e926dcb7c..e51e22e3233c 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -46,7 +46,7 @@ def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): mod = TestVMCompileStage0 target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target, exec_mode=exec_mode) + ex = relax.build(mod, target, exec_mode=exec_mode) inp1 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) inp2 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) vm = relax.VirtualMachine(ex, tvm.cpu()) @@ -64,7 +64,7 @@ def foo(x: R.Tensor(["n", "m"], "int32"), y: R.Object) -> R.Tensor(["m", "n"], d mod = TestMatchCheck target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target, exec_mode=exec_mode) + ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) x0 = tvm.nd.array(np.zeros((1, 2)).astype("int32")) y0 = tvm.nd.array(np.zeros((2, 1)).astype("float32")) @@ -92,7 +92,7 @@ def foo(x: R.Tensor(dtype="float32")) -> R.Shape: mod = TestVMCompileStage2 target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target, exec_mode=exec_mode) + ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) shape = (32, 16) @@ -127,7 +127,7 @@ def foo(x: R.Tensor((32, 16), "float32")) -> R.Tensor: mod = TestVMCompileStage3 target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target, exec_mode=exec_mode) + ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) shape = (32, 16) @@ -152,7 +152,7 @@ def foo(x: R.Tensor(dtype="float32")) -> R.Tensor: mod = TestVMCompileE2E target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target, exec_mode=exec_mode) + ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) shape = (32, 16) @@ -193,7 +193,7 @@ def func( mod = TestVMCompileE2E2 target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target, exec_mode=exec_mode) + ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) data = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) @@ -220,7 +220,7 @@ def test_vm_emit_te_extern(exec_mode): mod = bb.get() target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target, exec_mode=exec_mode) + ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) @@ -249,7 +249,7 @@ def te_func(A, B): mod = bb.get() target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target, exec_mode=exec_mode) + ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) inp = tvm.nd.array( @@ -288,7 +288,7 @@ def te_func(A): new_mod = relax.transform.CallTIRRewrite()(mod) target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target, exec_mode=exec_mode) + ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) inp = tvm.nd.array( @@ -317,7 +317,7 @@ def te_func(A): mod = bb.get() target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target, exec_mode=exec_mode) + ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) shape = (9,) @@ -346,7 +346,7 @@ def test_vm_emit_te_constant_param_cpu(exec_mode): bb.emit_func_output(gv) mod = bb.get() - exec = relax.vm.build(mod, "llvm", exec_mode=exec_mode) + exec = relax.build(mod, "llvm", exec_mode=exec_mode) dev = tvm.cpu() vm = relax.VirtualMachine(exec, dev) @@ -374,7 +374,7 @@ def test_vm_emit_te_constant_param_gpu(exec_mode): loops = sch.get_loops(sch.get_block(name="T_add", func_name="add")) sch.bind(loops[0], "threadIdx.x") - exec = relax.vm.build(sch.mod, "cuda", exec_mode=exec_mode) + exec = relax.build(sch.mod, "cuda", exec_mode=exec_mode) dev = tvm.cuda() vm = relax.VirtualMachine(exec, dev) @@ -400,7 +400,7 @@ def te_func(A, B): mod = bb.get() target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target, exec_mode=exec_mode) + ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) shape1 = (5,) @@ -435,14 +435,10 @@ def te_func(A): mod = bb.get() target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target, exec_mode=exec_mode) + ex = relax.build(mod, target, exec_mode=exec_mode) - ex.mod.export_library("exec.so") - exec1 = relax.vm.Executable(tvm.runtime.load_module("exec.so")) - os.remove("exec.so") - assert ex.as_text() == exec1.as_text() - - vm = relax.VirtualMachine(ex, tvm.cpu()) + ex.export_library("exec.so") + vm = relax.VirtualMachine(tvm.runtime.load_module("exec.so"), tvm.cpu()) inp = tvm.nd.array(np.random.rand(2).astype(np.float32)) inp2 = tvm.nd.array(np.random.rand(3).astype(np.float32)) @@ -466,7 +462,7 @@ def test_vm_tuple(exec_mode): mod = bb.get() target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target, exec_mode=exec_mode) + ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) shape = (5,) @@ -496,7 +492,7 @@ def tuple_get_item( mod = TestVMTupleGetItem target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target, exec_mode=exec_mode) + ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) x_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32")) y_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32")) @@ -526,7 +522,7 @@ def copy(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): mod = TestMemoryAllocStorageTensor target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target, exec_mode=exec_mode) + ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) x = tvm.nd.array(np.random.rand(2, 3).astype("float32")) y = vm["main"](x) @@ -577,7 +573,7 @@ def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> return gv1 target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(TestVMSubFunction, target, exec_mode=exec_mode) + ex = relax.build(TestVMSubFunction, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) x_inp = tvm.nd.array(np.random.rand(32, 32).astype(np.float32)) y_inp = tvm.nd.array(np.random.rand(32, 32).astype(np.float32)) @@ -609,7 +605,7 @@ def recursion(n: R.Tensor((1,), "float32")) -> R.Tensor: return res target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(TestVMRecursion, target, exec_mode=exec_mode) + ex = relax.build(TestVMRecursion, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) inp = np.empty(1).astype("float32") @@ -639,7 +635,7 @@ def main( mod = TestClosure target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(mod, target, exec_mode=exec_mode) + ex = relax.build(mod, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) x_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32")) y_inp = tvm.nd.array(np.array([[3.1, 4.0, 5.0], [6.0, 7.1, 9.0]], dtype="float32")) @@ -658,7 +654,7 @@ def main(x: R.Tensor((1,), "float32"), y: R.Tensor((1,), "float32")): ) target = tvm.target.Target("llvm", host="llvm") - ex = relax.vm.build(TestTimeEvaluator, target, exec_mode=exec_mode) + ex = relax.build(TestTimeEvaluator, target, exec_mode=exec_mode) vm = relax.VirtualMachine(ex, tvm.cpu()) x = tvm.nd.array(np.random.rand(1).astype("float32")) y = tvm.nd.array(np.random.rand(1).astype("float32")) @@ -780,9 +776,9 @@ def set_input_attempt_get(vm: relax.VirtualMachine, device: tvm.runtime.Device) def make_vm(mod, exec_mode) -> Tuple[relax.VirtualMachine, tvm.runtime.Device]: """Returns a local VM for the given mod and the device""" target = tvm.target.Target("llvm", host="llvm") - exec = relax.vm.build(TestVMSetInput, target, exec_mode=exec_mode) - exec.mod.export_library("exec.so") - exec_loaded = relax.vm.Executable(tvm.runtime.load_module("exec.so")) + exec = relax.build(TestVMSetInput, target, exec_mode=exec_mode) + exec.export_library("exec.so") + exec_loaded = tvm.runtime.load_module("exec.so") os.remove("exec.so") device = tvm.cpu() return relax.VirtualMachine(exec_loaded, device), device @@ -798,10 +794,10 @@ def run_on_rpc( The trial function should take a VM and a device """ target = tvm.target.Target("llvm", host="llvm") - exec = relax.vm.build(mod, target, exec_mode=exec_mode) + exec = relax.build(mod, target, exec_mode=exec_mode) temp = utils.tempdir() path = temp.relpath("vm_library.so") - exec.mod.export_library(path) + exec.export_library(path) # Use local rpc server for testing. # Server must use popen so it doesn't inherit the current process state. It @@ -817,7 +813,7 @@ def check_remote(server): device = remote.cpu() # Build a VM out of the executable and context. - vm = relax.vm.VirtualMachine(exec=rexec, device=device) + vm = relax.VirtualMachine(rexec, device=device) trial_func(vm, device) check_remote(rpc.Server("127.0.0.1")) diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py index 600d2456174e..679641de13be 100644 --- a/tests/python/relax/test_vm_codegen_only.py +++ b/tests/python/relax/test_vm_codegen_only.py @@ -33,8 +33,8 @@ def codegen(mod, target, exec_mode="bytecode"): builder = relax.ExecBuilder() - tir_mod = relax.vm._vmcodegen(builder, mod, exec_mode=exec_mode) - return relax.vm._vmlink(builder, target, tir_mod) + tir_mod = relax.vm_build._vmcodegen(builder, mod, exec_mode=exec_mode) + return relax.vm_build._vmlink(builder, target, tir_mod) @pytest.mark.parametrize("exec_mode", EXEC_MODE) @@ -95,10 +95,10 @@ def foo(x: R.Tensor((3, 4), "float32")): temp_dir = utils.tempdir() path_exec = temp_dir.relpath("exec.so") - ex.mod.export_library(path_exec) + ex.export_library(path_exec) - loaded_exec = relax.vm.Executable(tvm.runtime.load_module(path_exec)) - assert ex.as_text() == loaded_exec.as_text() + loaded_exec = tvm.runtime.load_module(path_exec) + assert ex.as_text() == loaded_exec["as_text"]() @pytest.mark.parametrize("exec_mode", EXEC_MODE) diff --git a/tests/python/relax/test_vm_codegen_tir.py b/tests/python/relax/test_vm_codegen_tir.py index 6f3bced38581..d6bac6ae157c 100644 --- a/tests/python/relax/test_vm_codegen_tir.py +++ b/tests/python/relax/test_vm_codegen_tir.py @@ -28,7 +28,7 @@ def get_tir_mod(mod): builder = relax.ExecBuilder() - return relax.vm._vmcodegen(builder, mod, exec_mode="compiled") + return relax.vm_build._vmcodegen(builder, mod, exec_mode="compiled") def test_add(): diff --git a/tests/python/relax/test_vm_profiler.py b/tests/python/relax/test_vm_profiler.py index 90737cc9c980..114596741113 100644 --- a/tests/python/relax/test_vm_profiler.py +++ b/tests/python/relax/test_vm_profiler.py @@ -47,7 +47,7 @@ def get_exec(data_shape): mod = relax.transform.BindParams("main", params)(mod) target = "llvm" - return relax.vm.build(mod, target) + return relax.build(mod, target) def test_conv2d_cpu(): @@ -65,7 +65,7 @@ def test_conv2d_cpu(): def with_rpc(ex, f, data_np): temp = utils.tempdir() path = temp.relpath("vm_library.so") - ex.mod.export_library(path) + ex.export_library(path) server = rpc.Server("127.0.0.1") remote = rpc.connect(server.host, server.port, session_timeout=10) @@ -75,7 +75,7 @@ def with_rpc(ex, f, data_np): device = remote.cpu() - vm = relax.vm.VirtualMachine(exec=rexec, device=device, profile=True) + vm = relax.VirtualMachine(rexec, device=device, profile=True) data = tvm.nd.array(data_np, device) f(vm, data) @@ -115,7 +115,7 @@ def main( return ((x, (x,)), x) target = "llvm" - ex = relax.vm.build(NestedTuple, target) + ex = relax.build(NestedTuple, target) data_np = np.random.randn(16).astype("float32") diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 6443d50f9e98..63ff66eaa291 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -862,7 +862,7 @@ def prepare_vm_model(path, tensor_shape): vm_exec = vm.compile(mod, target=target) # Export to Disk - vm_exec.mod.export_library(path) + vm_exec.export_library(path) def test_vm_rpc(): @@ -1393,7 +1393,7 @@ def test_large_constants(): path_consts = temp.relpath("consts") vm_exec.move_late_bound_consts(path_consts, byte_limit=256) path_dso = temp.relpath("lib.so") - vm_exec.mod.export_library(path_dso) + vm_exec.export_library(path_dso) # Load library files and constants mod = runtime.load_module(path_dso) @@ -1442,7 +1442,7 @@ def test_load_late_bound_consts_with_no_late_bound_consts(): # Ensure const_data is below the byte threshold for a late-bound const. byte_limit = len(const_data.tobytes()) + 1 vm_exec.move_late_bound_consts(path_consts, byte_limit=byte_limit) - vm_exec.mod.export_library(path_dso) + vm_exec.export_library(path_dso) mod = runtime.load_module(path_dso) mod["load_late_bound_consts"](path_consts) @@ -1503,7 +1503,7 @@ def test_load_and_save_constants_via_map(): # Save to constants and library files temp = utils.tempdir() path_dso = temp.relpath("lib.so") - vm_exec.mod.export_library(path_dso) + vm_exec.export_library(path_dso) # Load library files and constants mod = runtime.load_module(path_dso) @@ -1551,7 +1551,7 @@ def test_load_late_bound_consts_via_map_with_no_late_bound_consts(): # Ensure const_data is below the byte threshold for a late-bound const. byte_limit = len(const_data.tobytes()) + 1 consts_map = vm_exec.get_late_bound_consts(byte_limit=byte_limit) - vm_exec.mod.export_library(path_dso) + vm_exec.export_library(path_dso) mod = runtime.load_module(path_dso) mod["load_late_bound_consts_from_map"](consts_map) From d1997fd53e3317124d35aa3814deef09c202ca9e Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 22 Feb 2023 12:50:15 -0800 Subject: [PATCH 62/81] [Unity][Relax] Set Shape Function to Be Host Function (#14090) Set shape function to be host func. --- src/relax/backend/vm/vm_shape_lower.cc | 5 +++++ tests/python/relax/test_backend_transform_shape_lower.py | 1 + 2 files changed, 6 insertions(+) diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 090bcf01b5a5..f4b272979bb6 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -531,6 +531,11 @@ class VMShapeLowerMutator // the shape_func to indicate that this is a host function // This could require us to attach target to the relax function here. tir::PrimFunc shape_func(params, body, ret_type, buffer_map); + if (shape_func->attrs.GetAttr(tvm::attr::kTarget) == nullptr) { + // kTarget and kIsHostFunc are mutually exclusive + shape_func = + WithAttr(std::move(shape_func), tvm::tir::attr::kIsHostFunc, Integer(1)); + } GlobalVar shape_func_var = builder_->AddFunction(shape_func, "shape_func"); builder_->Emit(Call(shape_func_var, {shape_heap_}), "_"); return to_compute.size(); diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py index 5cd104dd013f..9c11b352c831 100644 --- a/tests/python/relax/test_backend_transform_shape_lower.py +++ b/tests/python/relax/test_backend_transform_shape_lower.py @@ -178,6 +178,7 @@ class Expected: @T.prim_func def shape_func(H: T.Buffer(T.int64(4), "int64")): # generated compute function + T.func_attr({"tir.is_host_func": 1}) H[T.int64(sindex["k+1"])] = H[T.int64(sindex["k"])] + T.int64(1) @R.function From 4ca7107ac638ab4588287ae8901022962df99e3c Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 22 Feb 2023 15:45:03 -0800 Subject: [PATCH 63/81] [Unity] Fix typo in the comment (#14096) --- include/tvm/relax/transform.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index b42fb5864ef7..3c02871f6cbd 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -176,7 +176,7 @@ TVM_DLL Pass FoldConstant(); */ TVM_DLL Pass LegalizeOps(Optional> cmap); -/* +/*! * \brief Lift transformation of the parameters of a function. * * When some inputs of the function is marked as 'parameters' (the model weights), this pass From fc5981b09e7e4d35e65dd3f06de2b16558139970 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Wed, 22 Feb 2023 15:57:49 -0800 Subject: [PATCH 64/81] [Unity] Lower `shape_of` to a builtin (#14093) This PR lowers shape_of op to a Relax VM builtin, and changes a utility function to take StructInfo as input. Co-authored-by: Steven S. Lyubomirsky --- include/tvm/relax/utils.h | 8 +-- src/relax/backend/vm/vm_builtin_lower.cc | 10 ++++ src/relax/utils.cc | 5 +- tests/python/relax/test_relax_operators.py | 62 ++++++++++++++++++++++ 4 files changed, 79 insertions(+), 6 deletions(-) create mode 100644 tests/python/relax/test_relax_operators.py diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index b3cc76768dd4..dd0200623a9a 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -25,7 +25,6 @@ #define TVM_RELAX_UTILS_H_ #include -#include #include #include @@ -110,9 +109,10 @@ class NameTable { TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); /*! - * \brief Check if the given type is a boolean scalar type (tensor of rank 0 with a boolean dtype). + * \brief Check if the given StructInfo is for a boolean scalar (tensor of rank 0 with a boolean + * dtype). * - * \param ty The input type. + * \param sinfo The input StructInfo. * \param permit_unknown_rank If true, it will permit the input type to have unknown rank * (ndim of -1), which will require a dynamic check. * \param permit_unknown_dtype If true, it will permit the input type to have an unknown dtype @@ -121,7 +121,7 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); * \return True iff the input type is a boolean scalar type (or, depending on options, has unknown * rank or dtype) */ -TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true, +TVM_DLL bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank = true, bool permit_unknown_dtype = true); /*! diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/vm_builtin_lower.cc index 6613b39626da..00d8512dc6af 100644 --- a/src/relax/backend/vm/vm_builtin_lower.cc +++ b/src/relax/backend/vm/vm_builtin_lower.cc @@ -53,6 +53,8 @@ class VMBuiltinLowerMutator : public ExprMutator { return CallTIRDyn(call); } else if (call->op == reshape_op_) { return Reshape(call); + } else if (call->op == shape_of_op_) { + return ShapeOf(call); } else if (call->op == make_closure_op_) { return MakeClosure(call); } else if (call->op == invoke_closure_op_) { @@ -132,6 +134,12 @@ class VMBuiltinLowerMutator : public ExprMutator { return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)}); } + Expr ShapeOf(const Call& call_node) { + ICHECK(call_node->args.size() == 1); + ICHECK(call_node->struct_info_.defined()); + return Call(builtin_shape_of_, call_node->args, Attrs(), {GetStructInfo(call_node)}); + } + Expr MakeClosure(const Call& call_node) { ICHECK(call_node->args.size() == 2); ICHECK(call_node->args[0]->IsInstance()); @@ -173,6 +181,7 @@ class VMBuiltinLowerMutator : public ExprMutator { // object to pattern match. const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn"); const Op& reshape_op_ = Op::Get("relax.reshape"); + const Op& shape_of_op_ = Op::Get("relax.shape_of"); const Op& make_closure_op_ = Op::Get("relax.make_closure"); const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); const Op& alloc_tensor_op_ = Op::Get("relax.builtin.alloc_tensor"); @@ -187,6 +196,7 @@ class VMBuiltinLowerMutator : public ExprMutator { const ExternFunc builtin_compute_alloc_shape_{"vm.builtin.compute_alloc_shape"}; const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"}; const ExternFunc builtin_reshape_{"vm.builtin.reshape"}; + const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"}; const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"}; const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"}; }; diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 110bdb5c8c20..1cf64cbf64a4 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -67,8 +67,9 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { } } -bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dtype) { - const DynTensorTypeNode* tt = ty.as(); +bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank, + bool permit_unknown_dtype) { + const TensorStructInfoNode* tt = sinfo.as(); if (!tt) { return false; } diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py new file mode 100644 index 000000000000..7b0b98fea976 --- /dev/null +++ b/tests/python/relax/test_relax_operators.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys +import tempfile + +import numpy as np +import tvm +import tvm.testing +from tvm import relax +from tvm._ffi.base import TVMError +from tvm.script import relax as R + + +def run_cpu(mod, func_name, *input): + target = tvm.target.Target("llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + return vm[func_name](*input) + + +@tvm.script.ir_module +class ShapeOfTest: + @R.function + def get_shape(t: R.Tensor(ndim=-1, dtype="int32")) -> R.Shape(ndim=-1): + return R.shape_of(t) + + @R.function + def get_shape_const() -> R.Shape(ndim=-1): + x: R.Tensor((), "int32") = R.const(1, dtype="int32") + return R.shape_of(x) + + +def test_op_shape_of(): + const_shape = run_cpu(ShapeOfTest, "get_shape_const") + assert const_shape == tvm.runtime.ShapeTuple([]) + + scalar_shape = run_cpu(ShapeOfTest, "get_shape", tvm.nd.array(np.array(1, dtype="int32"))) + assert scalar_shape == tvm.runtime.ShapeTuple([]) + + tensor_shape = run_cpu( + ShapeOfTest, "get_shape", tvm.nd.array(np.zeros((1, 2, 3)).astype("int32")) + ) + assert tensor_shape == tvm.runtime.ShapeTuple([1, 2, 3]) + + +if __name__ == "__main__": + tvm.testing.main() From 3f4835c4c8bbb6338915b04483ae2d6907306e63 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Wed, 22 Feb 2023 16:00:06 -0800 Subject: [PATCH 65/81] [Unity] Relax Recursive function (#14092) This PR adds TVMScript local recursive function support. It also update lambda lifting pass. Removed CalledGlobalVars, it was not used anymore. It also updates well-form pass to allow un-defined vars for recursive call --- include/tvm/relax/analysis.h | 9 -- include/tvm/script/ir_builder/relax/ir.h | 7 ++ python/tvm/script/ir_builder/relax/ir.py | 17 +++- python/tvm/script/parser/relax/parser.py | 62 ++++++++++++-- src/relax/analysis/analysis.cc | 20 ----- src/relax/analysis/well_formed.cc | 11 ++- src/relax/transform/lambda_lift.cc | 84 ++++++++++++++----- src/script/ir_builder/relax/ir.cc | 9 ++ .../python/relax/test_analysis_well_formed.py | 27 ++++++ .../relax/test_transform_lambda_lift.py | 34 ++++---- tests/python/relax/test_utils.py | 6 +- 11 files changed, 213 insertions(+), 73 deletions(-) diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index b9866577e9b6..39ecfd9e13a7 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -296,15 +296,6 @@ TVM_DLL tvm::Array FreeVars(const Expr& expr); */ TVM_DLL tvm::Array AllVars(const Expr& expr); -/*! - * \brief Get all global variables used in calls in expression expr. - * - * \param expr the expression. - * - * \return List of all global variables called in expr. - */ -TVM_DLL tvm::Array CalledGlobalVars(const Expr& expr); - /*! * \brief Get all global variables from expression expr. * diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h index 72aab6684ebf..42aa591a95b7 100644 --- a/include/tvm/script/ir_builder/relax/ir.h +++ b/include/tvm/script/ir_builder/relax/ir.h @@ -110,6 +110,13 @@ TVM_DLL tvm::relax::Var Emit( TVM_DLL tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, const tvm::relax::StructInfo& struct_info); +/*! + * \brief Emit a binding to the last binding block frame. + * \param binding The binding to be emitted. + * \return The left side var of the emitted binding. + */ +TVM_DLL tvm::relax::Var EmitVarBinding(const tvm::relax::VarBinding& binding); + ///////////////////////////// If Then Else ///////////////////////////// /*! diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 43918ce7ec83..63efea135c15 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -25,7 +25,7 @@ import tvm from tvm import DataType, relax from tvm.ir import PrimExpr -from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, const +from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, VarBinding, const ############################### Operators ############################### from tvm.relax.op import ( @@ -342,6 +342,20 @@ def emit_match_cast(value: Expr, struct_info: StructInfo) -> Var: return _ffi_api.EmitMatchCast(value, struct_info) # type: ignore +def emit_var_binding(value: VarBinding) -> Var: + """Emit a binding to the last binding block frame. + Parameters + ---------- + value: VarBinding + The binding to be emitted. + Returns + ------- + var: Var + The left side var of the emitted binding. + """ + return _ffi_api.EmitVarBinding(value) # type: ignore + + ############################# If Then Else ############################# @@ -497,6 +511,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "divide", "dtype", "emit", + "emit_var_binding", "emit_match_cast", "equal", "ewise_fma", diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index e5e5bb2743e1..e1af1c1df346 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -96,8 +96,7 @@ def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy: annotation = annotation() if isinstance(annotation, StructInfoProxy): return annotation - else: - raise TypeError(f"Expected StructInfoProxy but got {type(annotation)}.") + raise TypeError(f"Expected StructInfoProxy but got {type(annotation)}.") except Exception as err: self.report_error(node, str(err)) raise err @@ -112,6 +111,38 @@ def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> St raise err +def is_called(node: Any, func_name: str) -> bool: + # Check if it calls into a func + if isinstance(node, doc.Call): + # Recursive call was found + if isinstance(node.func, doc.Name) and node.func.id == func_name: + return True + elif isinstance(node, (list, tuple)): + for stmt in node: + if is_called(stmt, func_name): + return True + elif isinstance(node, (doc.AnnAssign, doc.Assign, doc.Return, doc.Expr)): + return is_called(node.value, func_name) + elif isinstance(node, doc.With): + return is_called(node.body, func_name) + elif isinstance(node, doc.If): + smts = [] + if node.body is not None: + smts = smts + list(node.body) + if node.orelse is not None: + smts = smts + list(node.orelse) + return is_called(smts, func_name) + return False + + +def is_recursive(node: doc.FunctionDef) -> bool: + # Check if it is a recursive function + for stmt in node.body: + if is_called(stmt, node.name): + return True + return False + + def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> None: # Collect symbolic vars from parameters symbolic_vars = set() @@ -128,6 +159,24 @@ def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> Non @dispatch.register(token="relax", type_name="FunctionDef") def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: + # reserve a var for local function + func_val = self.var_table.get().get(node.name) + if not func_val and is_recursive(node): + collect_symbolic_var_from_params(self, node) + if node.returns is None: + ret_sinfo = relax.TupleStructInfo([]) + else: + ret_sinfo = eval_struct_info(self, node.returns, eval_str=True) + params_sinfo = [] + for arg in node.args.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) + params_sinfo.append(param_sinfo) + # created a var for the local function, the same var could be used for recursive call + local_func_var = relax.Var(node.name, relax.FuncStructInfo(params_sinfo, ret_sinfo)) + self.var_table.add(node.name, local_func_var) + with self.var_table.with_frame(): with self.with_dispatch_token("relax"): with R.function(): @@ -164,12 +213,10 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None: else: ret_sinfo = eval_struct_info(self, node.returns, eval_str=True) params = [] - params_sinfo = [] for arg in node.args.args: if arg.annotation is None: self.report_error(arg, "Type annotation is required for function parameters.") param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) - params_sinfo.append(param_sinfo) params.append(relax.Var(arg.arg, param_sinfo)) func_signature = relax.Function.create_empty(params, ret_sinfo) @@ -188,7 +235,12 @@ def post_token_switch(self: Parser, node: doc.Expr) -> None: ir_builder = IRBuilder.current() result = ir_builder.get() ir_builder.__exit__(None, None, None) - var = R.emit(result) + # reuse var if it is reserved + reserved_var = self.var_table.get().get(node.name) + if reserved_var: + var = R.emit_var_binding(relax.VarBinding(reserved_var, result)) + else: + var = R.emit(result) IRBuilder.name(node.name, var) self.var_table.add(node.name, var, allow_shadowing=False) diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc index 33197308fa1b..4132039a5e34 100644 --- a/src/relax/analysis/analysis.cc +++ b/src/relax/analysis/analysis.cc @@ -87,15 +87,6 @@ class VarVisitor : protected ExprVisitor { return ret; } - Array CalledGlobalVars(const Expr& expr) { - this->VisitExpr(expr); - Array ret; - for (const auto& v : called_global_vars_.data) { - ret.push_back(v); - } - return ret; - } - void MarkBounded(const Var& v) { bound_vars_.Insert(v); vars_.Insert(v); @@ -123,10 +114,6 @@ class VarVisitor : protected ExprVisitor { for (Expr arg : call_node->args) { VisitExpr(arg); } - - if (const GlobalVarNode* global_var_node = call_node->op.as()) { - called_global_vars_.Insert(GetRef(global_var_node)); - } } void VisitBinding_(const VarBindingNode* binding) final { @@ -144,7 +131,6 @@ class VarVisitor : protected ExprVisitor { InsertionSet vars_; InsertionSet bound_vars_; InsertionSet global_vars_; - InsertionSet called_global_vars_; }; tvm::Array FreeVars(const Expr& expr) { return VarVisitor().Free(expr); } @@ -155,10 +141,6 @@ tvm::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } tvm::Array AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); } -tvm::Array CalledGlobalVars(const Expr& expr) { - return VarVisitor().CalledGlobalVars(expr); -} - TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars); TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars); @@ -167,7 +149,5 @@ TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars); TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars); -TVM_REGISTER_GLOBAL("relax.analysis.called_global_vars").set_body_typed(CalledGlobalVars); - } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 05ad0954bbfc..25b9155d7740 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -177,7 +177,7 @@ class WellFormedChecker : public relax::ExprVisitor, void VisitExpr_(const VarNode* op) final { Var var = GetRef(op); - if (var_set_.count(var) == 0) { + if (var_set_.count(var) == 0 && recur_vars_.count(var) == 0) { Malformed(Diagnostic::Error(var) << "Var " << op->name_hint() << " is not defined."); } CheckStructInfo(op); @@ -316,12 +316,20 @@ class WellFormedChecker : public relax::ExprVisitor, } void VisitBinding_(const VarBindingNode* binding) final { + bool is_lambda = false; + if (binding->value->IsInstance()) { + is_lambda = true; + recur_vars_.insert(binding->var); + } if (binding->value->IsInstance()) { Malformed(Diagnostic::Error(binding->value) << "Inline PrimFunc is disallowed in Relax IR."); } else { this->VisitExpr(binding->value); } this->VisitVarDef(binding->var); + if (is_lambda) { + recur_vars_.erase(binding->var); + } } void VisitBinding_(const MatchCastNode* binding) final { @@ -451,6 +459,7 @@ class WellFormedChecker : public relax::ExprVisitor, VisitMode mode_ = VisitMode::kDefault; // set of context variables. std::unordered_set var_set_; + std::unordered_set recur_vars_; std::unordered_set dataflow_var_set_; std::unordered_set symbolic_var_set_; std::unordered_map param_var_func_map_; diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc index f08499036b1c..74920823100a 100644 --- a/src/relax/transform/lambda_lift.cc +++ b/src/relax/transform/lambda_lift.cc @@ -46,35 +46,72 @@ class LambdaLifter : public ExprMutator { using ExprMutator::VisitExpr_; + void VisitBinding_(const VarBindingNode* binding) final { + bool is_lambda = false; + if (binding->value->IsInstance()) { + is_lambda = true; + recur_vars_.push_back(binding->var); + } + Expr new_value = this->VisitExpr(binding->value); + if (new_value->struct_info_.defined() && + !new_value->struct_info_.same_as(binding->var->struct_info_)) { + binding->var->struct_info_ = GetStructInfo(new_value); + binding->var->checked_type_ = new_value->checked_type_; + } + if (new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + } else { + builder_->EmitNormalized(VarBinding(binding->var, new_value)); + } + if (is_lambda) { + recur_vars_.pop_back(); + } + } + Expr VisitExpr_(const CallNode* call_node) final { auto call = Downcast(ExprMutator::VisitExpr_(call_node)); - if (auto const* var = call_node->op.as()) { - bool has_closure = HasClosure(GetRef(var)); - auto val = builder_->LookupBinding(GetRef(var)); + if (const auto* var_node = call_node->op.as()) { + auto var = GetRef(var_node); + bool has_closure = HasClosure(var); + auto val = builder_->LookupBinding(var); + if (const auto* fsinfo_node = GetStructInfo(var).as()) { + auto fsinfo = GetRef(fsinfo_node); + if (!GetStructInfo(call).same_as(fsinfo)) { + call->struct_info_ = fsinfo->ret; + call->checked_type_ = GetStaticType(fsinfo->ret); + } + } // Call "relax.invoke_closure" to invoke closure - if (has_closure && val.as()) { - Var clo_arg = GetRef(var); + Var clo_arg = var; + if (has_closure && val->IsInstance()) { if (this->var_remap_.find(var->vid) != this->var_remap_.end()) { clo_arg = this->var_remap_.at(var->vid); } return Call(invoke_closure_op_, {clo_arg, Tuple(call_node->args)}, {}, {GetStructInfo(GetRef(call_node))}); } - } - if (auto global_var_node = call_node->op.as()) { - String rec_name = global_var_node->name_hint; - auto global_var = GetRef(global_var_node); - auto it = lambda_map_.find(global_var); + auto it = lambda_map_.find(var); if (it != lambda_map_.end()) { // flatten nested call, e.g. call(y)(x) -> call(x, y)) Array new_args; + Array params; for (const auto arg : call->args) { new_args.push_back(arg); + params.push_back(StructInfoFromType(arg->checked_type())); } if (const auto* nest_call = it->second.as()) { + // Update the StructInfo accordingly for (const auto arg : nest_call->args) { new_args.push_back(arg); + params.push_back(StructInfoFromType(arg->checked_type())); } + StructInfo new_func_sinfo; + if (const auto* fsinfo = GetStructInfo(nest_call->op).as()) { + auto func_sinfo = GetRef(fsinfo); + new_func_sinfo = FuncStructInfo(params, func_sinfo->ret); + } + nest_call->op->struct_info_ = new_func_sinfo; + nest_call->op->checked_type_ = GetStaticType(new_func_sinfo); return Call(nest_call->op, new_args, call_node->attrs, call_node->sinfo_args); } return Call(it->second, call->args, call_node->attrs, call_node->sinfo_args); @@ -89,11 +126,19 @@ class LambdaLifter : public ExprMutator { // TODO(@yongwww): consider appending inner func name into the lifted func name String lift_func_name = "lifted_func_" + std::to_string(lift_func_num_++); auto global = GlobalVar(lift_func_name); - Array captured_vars = FreeVars(func); - recur_vars_ = CalledGlobalVars(func); - auto all_global_vars = AllGlobalVars(func); + Array free_vars = FreeVars(func); + Array captured_vars; Array typed_captured_vars; + bool recursive = false; + for (const auto& var : free_vars) { + if (!recur_vars_.empty() && var == recur_vars_.back()) { + recursive = true; + } else { + captured_vars.push_back(var); + } + } + Map rebinding_map; for (auto free_var : captured_vars) { Var var = Var(free_var->name_hint(), GetStructInfo(free_var), free_var->span); @@ -102,12 +147,14 @@ class LambdaLifter : public ExprMutator { } // recursive call - if (!recur_vars_.empty()) { + if (recursive) { if (!captured_vars.empty()) { Array fvs; for (auto fv : captured_vars) { fvs.push_back(fv); } + // it is required by block_blocker, will be updated later + UpdateStructInfo(global, GetStructInfo(recur_vars_.back())); lambda_map_.emplace(recur_vars_.back(), Call(global, fvs)); } else { if (recur_vars_.size() > 0) { @@ -162,18 +209,17 @@ class LambdaLifter : public ExprMutator { /*attrs=*/new_func->attrs, /*span=*/func->span); - Array param_types; for (Var param : closure_params) { CHECK(param->checked_type_.defined()) << "relax.Function requires params to contain checked_type_"; - param_types.push_back(param->checked_type_); } } ICHECK(lifted_func.defined()); // Add the lifted function to the module. - UpdateStructInfo(global, GetStructInfo(lifted_func)); + global->struct_info_ = GetStructInfo(lifted_func); + global->checked_type_ = lifted_func->checked_type_; builder_->UpdateFunction(global, lifted_func); if (!is_closure) { @@ -242,8 +288,8 @@ class LambdaLifter : public ExprMutator { } private: - std::unordered_map lambda_map_; - Array recur_vars_; + std::unordered_map lambda_map_; + Array recur_vars_; IRModule mod_; size_t lift_func_num_ = 0; /*! \brief Cache ops that would be used later to reduce lookup overhead. */ diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index ece645243c82..ddfb1ddfa35f 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -203,8 +203,17 @@ tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, return var; } +tvm::relax::Var EmitVarBinding(const tvm::relax::VarBinding& binding) { + BlockFrame block_frame = CheckBlockFrameExistAndUnended(); + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + block_builder->EmitNormalized(binding); + block_frame->emitted_vars.push_back(binding->var); + return binding->var; +} + TVM_REGISTER_GLOBAL("script.ir_builder.relax.Emit").set_body_typed(Emit); TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchCast").set_body_typed(EmitMatchCast); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitVarBinding").set_body_typed(EmitVarBinding); ///////////////////////////// If Then Else ///////////////////////////// diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index 67da77274188..ee5814eb7bfc 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -173,6 +173,33 @@ def test_seq_expr(): assert not rx.analysis.well_formed(mod, check_struct_info=False) +def test_recursive(): + scalar_struct_info = rx.TensorStructInfo(shape=[], dtype="int32") + gv0 = rx.Var("gv0", scalar_struct_info) + f = rx.Var("f", rx.FuncStructInfo([scalar_struct_info], scalar_struct_info)) + ipt = rx.Var("ipt", scalar_struct_info) + x0 = rx.Var("x0", scalar_struct_info) + x1 = rx.Var("x1", scalar_struct_info) + x2 = rx.Var("x2", scalar_struct_info) + y = rx.Var("y", scalar_struct_info) + inner_block = rx.BindingBlock( + [rx.VarBinding(x0, rx.const(2, "int32")), rx.VarBinding(y, rx.Call(f, [x0]))] + ) + inner_func = rx.Function([ipt], rx.SeqExpr([inner_block], y), scalar_struct_info) + outer_block = rx.BindingBlock( + [ + rx.VarBinding(f, inner_func), + rx.VarBinding(x1, rx.const(1, "int32")), + rx.VarBinding(x2, rx.op.add(x1, rx.Call(f, [x1]))), + rx.VarBinding(gv0, x2), + ] + ) + func = rx.Function([], rx.SeqExpr([outer_block], gv0), scalar_struct_info) + mod = tvm.IRModule.from_expr(func) + normalized = rx.transform.Normalize()(mod) + assert rx.analysis.well_formed(normalized) + + def test_if(): # Error: Var defined in true/false branch is invisible in the outer scope # except the return Var, i.e the var in the last stmt diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py index c9bbc0fb91e7..5a137f22cb5f 100644 --- a/tests/python/relax/test_transform_lambda_lift.py +++ b/tests/python/relax/test_transform_lambda_lift.py @@ -114,7 +114,9 @@ def main( x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") ) -> R.Tensor((2, 3), "float32"): @R.function - def outer_func(c1: R.Tensor((2, 3), "float32")): + def outer_func( + c1: R.Tensor((2, 3), "float32") + ) -> R.Callable((R.Tensor((2, 3), "float32"),), R.Tensor((2, 3), "float32")): @R.function def inner_func(x1: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): s: R.Tensor((2, 3), "float32") = R.add(x1, c1) @@ -133,7 +135,6 @@ def inner_func(x1: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): _check_save_roundtrip(after) -@pytest.mark.skip(reason="Need fix after parser switch over") def test_recursive(): # the expected IRModule @tvm.script.ir_module @@ -149,18 +150,19 @@ def lifted_func_0( if cond: new_i: R.Tensor((), "int32") = R.add(i, c) new_s: R.Tensor((2, 3), "float32") = R.add(s, x) - r = lifted_func_0(new_i, new_s, x) + new_r = lifted_func_0(new_i, new_s, x) + r = new_r else: r = s return r @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor: + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), dtype="float32"): while_loop = R.make_closure(lifted_func_0, (x,)) - gv = R.invoke_closure( + gv: R.Tensor((2, 3), dtype="float32") = R.invoke_closure( while_loop, - (relax.const(0), x), - sinfo_args=(R.Tensor(ndim=2, dtype="float32")), + (R.const(0), x), + sinfo_args=(R.Tensor((2, 3), dtype="float32")), ) return gv @@ -185,11 +187,14 @@ def while_loop( r: R.Tensor((2, 3), "float32") = s return r - gv: R.Tensor((2, 3), "float32") = while_loop(relax.const(0), x) + gv: R.Tensor((2, 3), "float32") = while_loop(R.const(0), x) return gv before = Before expected = Expected + # check well-formness of recursive call + assert relax.analysis.well_formed(before) + # Perform Lambda Lifting after = transform.LambdaLift()(before) assert len(after.functions) == 2 @@ -198,7 +203,6 @@ def while_loop( _check_save_roundtrip(after) -@pytest.mark.skip(reason="Need fix after parser switch over") def test_multi_func(): # expected IRModule @tvm.script.ir_module @@ -207,29 +211,29 @@ class Expected: def glob_func_1( x1: R.Tensor((10, 5), "float32"), y1: R.Tensor((10, 5), "float32") ) -> R.Tensor(None, "float32", ndim=2): - inner = lifted_func_1 - gv1 = inner(x1, y1) + inner = lifted_func_0 + gv1: R.Tensor((10, 5), "float32") = inner(x1, y1) return gv1 @R.function def glob_func_2( x11: R.Tensor((10, 5), "float32"), y11: R.Tensor((10, 5), "float32") ) -> R.Tensor(None, "float32", ndim=2): - inner1 = lifted_func_0 - gv11 = inner1(x11, y11) + inner = lifted_func_1 + gv11: R.Tensor((10, 5), "float32") = inner(x11, y11) return gv11 @R.function def lifted_func_0( x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") - ) -> R.Tensor(None, "float32", ndim=2): + ) -> R.Tensor((10, 5), "float32"): s: R.Tensor((10, 5), "float32") = R.add(x2, y2) return s @R.function def lifted_func_1( x21: R.Tensor((10, 5), "float32"), y21: R.Tensor((10, 5), "float32") - ) -> R.Tensor(None, "float32", ndim=2): + ) -> R.Tensor((10, 5), "float32"): s1: R.Tensor((10, 5), "float32") = R.add(x21, y21) return s1 diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index fbeb57564fb5..15122dab3771 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -69,7 +69,7 @@ class Actual: @R.function def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): @R.function - def inner(x: R.Tensor((3,), "float32")): + def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"): gv = R.add(x, x) return gv @@ -81,7 +81,7 @@ class Expected: @R.function def func(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): @R.function - def inner(x: R.Tensor((3,), "float32")): + def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"): gv = R.add(x, x) return gv @@ -91,7 +91,7 @@ def inner(x: R.Tensor((3,), "float32")): @R.function def func_copied(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): @R.function - def inner(x: R.Tensor((3,), "float32")): + def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"): gv = R.add(x, x) return gv From 4d72dafa6b09ef75112fe675ef550da674b267c5 Mon Sep 17 00:00:00 2001 From: Prakalp Srivastava Date: Thu, 23 Feb 2023 05:21:20 -0500 Subject: [PATCH 66/81] [Unity][Layout] Add layout transformation analysis for PrimFunc (#14066) * [Layout] Add layout transformation analysis for PrimFunc. This change adds a PrimFunc level analysis to suggest layout transformations to block and buffers in the PrimFunc based on the layout transformations to PrimFunc outputs. * Add support for multiple blocks such as split op. * Add negative tests and increase coverage. * fix warning message * fix lint * remove unused header * Address comments. Moved some utility functions to support/array.h improve doc * fix deprecation warn T.var("int64") to T.int64() * address comments --- include/tvm/relax/analysis.h | 13 + python/tvm/relax/analysis/analysis.py | 32 +- src/relax/analysis/layout_transformation.cc | 621 +++++++++++++ src/support/array.h | 27 +- ...test_analysis_suggest_layout_transforms.py | 831 ++++++++++++++++++ 5 files changed, 1522 insertions(+), 2 deletions(-) create mode 100644 src/relax/analysis/layout_transformation.cc create mode 100644 tests/python/relax/test_analysis_suggest_layout_transforms.py diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 39ecfd9e13a7..2b771b9708ab 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -403,6 +403,19 @@ TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func); */ TVM_DLL bool WellFormed(IRModule m, bool check_struct_info = true); +/*! + * \brief Using the layout transforms on the outputs, suggest layout transformation on the blocks + * and buffers for the PrimFunc. + * + * \param fn The PrimFunc to be analyzed. + * \param write_buffer_transformations Array of IndexMap transformations on PrimFunc outputs. + * \return Suggested transforms per block in `fn`. For each block the returned value is a map + * from the object (block or buffer) to it's index map transformation. + */ + +TVM_DLL Map> SuggestLayoutTransforms( + const Function& fn, Array write_buffer_transformations); + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index ffcdaceb4076..efd1b51f11de 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -21,7 +21,7 @@ configuring the passes and scripting them in Python. """ -from typing import Dict, List +from typing import Dict, List, Union, Callable from enum import IntEnum from tvm import tir @@ -29,6 +29,7 @@ from tvm.relax.ty import Type from tvm.relax.struct_info import StructInfo, FuncStructInfo from tvm.relax.expr import DataflowBlock, Var, Expr, Function, Call, Binding +from tvm.tir import IndexMap, PrimFunc, Block, Buffer from . import _ffi_api @@ -289,3 +290,32 @@ def well_formed(mod: IRModule, check_struct_info: bool = True) -> bool: will be well tested and will not be blocked by not having structure info. """ return _ffi_api.well_formed(mod, check_struct_info) # type: ignore + + +def suggest_layout_transforms( + func: PrimFunc, write_buffer_transforms: List[Union[IndexMap, Callable]] +) -> Dict[Block, Dict[Union[Block, Buffer], IndexMap]]: + """Suggest Layout transformations of blocks and buffers in a PrimFunc. + + Parameters + ---------- + func: PrimFunc + PrimFunc on which analysis will be performed and transformations suggested. + + write_buffer_transforms: List[Union[IndexMap, Callable] + List of layout transformations on the output buffers. The number of layout + transformations must match the number of outputs of the PrimFunc. + + Returns + ------- + ret: Dict[Block, Dict[Union[Block, Buffer], IndexMap]] + Suggested transforms per block in `func`. For each block the returned value is a map + from the object (block or buffer) to it's index map transformation. + """ + write_buffer_index_maps = [] + for transform in write_buffer_transforms: + if callable(transform): + transform = IndexMap.from_func(transform) + assert isinstance(transform, IndexMap) + write_buffer_index_maps.append(transform) + return _ffi_api.suggest_layout_transforms(func, write_buffer_index_maps) # type: ignore diff --git a/src/relax/analysis/layout_transformation.cc b/src/relax/analysis/layout_transformation.cc new file mode 100644 index 000000000000..44538fea98e5 --- /dev/null +++ b/src/relax/analysis/layout_transformation.cc @@ -0,0 +1,621 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/analysis/layout_transormation.cc + * \brief Analyze the PrimFunc and suggest layout transformation on it's blocks and buffers based on + * the user provided layout transformations on it's outputs. + */ +#include +#include +#include +#include + +#include "../../support/array.h" + +namespace tvm { +namespace relax { + +using namespace tir; + +/********** Helper Functions **********/ + +/*! \brief Checks if a transformation is bijective affine over the given ranges */ +static bool IsBijectiveAffine(const IndexMap& m, const Array& ranges) { + Map input_iters; + ICHECK_EQ(m->initial_indices.size(), ranges.size()); + for (size_t i = 0; i < ranges.size(); i++) { + input_iters.Set(m->initial_indices[i], ranges[i]); + } + arith::Analyzer analyzer; + auto iter_map_result = DetectIterMap(m->final_indices, input_iters, /* predicate = */ 1, + /*check_level=*/arith::IterMapLevel::Bijective, &analyzer, + /*simplify_trivial_iterators=*/true); + return !iter_map_result->indices.empty(); +} + +/*! + * \brief Analyzer to collect iterators from IterSumExpr. + * \details Analyzes the indices from DetectIterMap analysis to collect the spatial iterators that + * are used in it. This is important to get which spatial iterators are accessed in each index + * of buffer access. + */ +class IndexAnalyzer : public ExprVisitor { + public: + Array Analyze(const arith::IterSumExpr& expr) { + VisitExpr(expr); + return iterators_; + } + + private: + /*! \brief Override VisitExpr for iter expr type processing */ + void VisitExpr(const PrimExpr& expr) override { + if (const auto* op = expr.as()) { + for (const auto& arg : op->args) VisitExpr(arg); + VisitExpr(op->base); + return; + } + if (const auto* op = expr.as()) { + VisitIterMark(op->source); + VisitExpr(op->lower_factor); + VisitExpr(op->extent); + VisitExpr(op->scale); + return; + } + return ExprVisitor::VisitExpr(expr); + } + + void VisitIterMark(const arith::IterMark& op) { + if (const auto* var = op->source.as()) + iterators_.push_back(GetRef(var)); + else + VisitExpr(op->source); + VisitExpr(op->extent); + } + + private: + Array iterators_; +}; + +/*! + * \brief Analyzes IterMapResult to get the Spatial Layout of buffer access. + * \details We define Spatial Layout of a buffer access as an array of length equal to the + * dimensions of the buffer. i-th element of Spatial Layout contains spatial iter var used from the + * block iteration domain. For indices, where no spatial iter vars are used, the spatial layout + * element is empty. If any of the buffer access indices use multiple spatial iter vars, the spatial + * layout is undefined. + * + * Here are a few examples of inferred spatial layout from buffer access. si denotes i-th spatial + * iter var, and ri denotes i-th reduction iter var. + * + * SpatialLayout(A[s0*constant, s1]) = {s0, s1} + * SpatialLayout(A[s0, constant, r0, s1]) = {s0, null, null, s1} + * SpatialLayout(A[s0 * c + s1]) = undefined + */ +using SpatialLayout = Array>; +static SpatialLayout GetSpatialLayout(const arith::IterMapResult& iter_map_result) { + ICHECK(!iter_map_result->indices.empty()); + SpatialLayout result; + for (const arith::IterSumExpr& index : iter_map_result->indices) { + IndexAnalyzer index_analyzer; + Array iter_vars = index_analyzer.Analyze(index); + if (iter_vars.size() >= 2) { + LOG(WARNING) << "[LayoutInference] Unable to get spatial layout of access: " + << arith::NormalizeIterMapToExpr(index); + return {}; + } + if (iter_vars.empty()) { + result.push_back({}); + continue; + } + result.push_back(iter_vars[0]); + } + return result; +} + +/*! + * \brief Checks if the two spatial layouts are identical. Two empty spatial layouts are treated as + * unequal. + */ +static bool AreIdenticalSpatialAccess(const SpatialLayout& s0, const SpatialLayout& s1) { + if (s0.empty() || s1.empty()) return false; + if (s0.size() != s1.size()) return false; + for (size_t i = 0; i < s0.size(); ++i) { + if ((!s0[i].defined() && s1[i].defined()) || (s0[i].defined() && !s1[i].defined())) + return false; + if (!s0[i].same_as(s1[i])) return false; + } + return true; +} + +/*! + * \brief Checks if the block accesses a buffer sequentially in terms of spatial dimensions + * (ignoring reduction dimensions). It checks that the order of spatial iter vars in spatial layout + * of a buffer access is same as the order of spatial iter vars in block domain. + */ +using VarToBlockIndexMap = std::unordered_map; +static bool IsSequentialAccess(const SpatialLayout& iterators, + const VarToBlockIndexMap& iter_to_block_index) { + int last_value = -1; + for (const auto& i : iterators) { + if (!i.defined()) continue; + auto it = iter_to_block_index.find(i.value()); + ICHECK(it != iter_to_block_index.end()); + int blk_index = it->second; + if (blk_index <= last_value) return false; + last_value = blk_index; + } + return true; +} + +/*! \brief Checks if two IndexMaps represent identical transforms */ +static bool AreIdenticalTransforms(const IndexMap& t0, const IndexMap& t1) { + if (t0->initial_indices.size() != t1->initial_indices.size()) return false; + if (t0->final_indices.size() != t1->final_indices.size()) return false; + + // Create a new shape expression. + Array t1_initial_indices = + t1->initial_indices.Map([](tir::Var i) -> PrimExpr { return i; }); + auto t0_output = t0->MapIndices(t1_initial_indices); + arith::Analyzer analyzer; + for (size_t i = 0; i < t0_output.size(); ++i) { + if (!analyzer.CanProveEqual(t0_output[i], t1->final_indices[i])) return false; + } + return true; +} + +/*! + * \brief Returns the layout transformation for a target spatial layout from the source spatial + * layout and transformation. + * \details Given the source buffer spatial layout \p src_spatial_layout and its transformation \p + * src_transformation, this function constructs the transformation for the target buffer whose + * spatial layout is given as \p tgt_spatial_layout. + * + * The algorithm is explained below using an example: + * + * Let's say the source transformation is lambda N, C, H, W -> (N, H, W, C // 4, C % + * 4), source spatial layout is 'NCHW' and target spatial layout is 'KCHW'. + * + * Step 1: Copy over the source transformation initial & final indices to target transformation + * initial and final indices. + * target transformation = lambda N, C, H, W -> (N, H, W, C // 4, C %4) + * + * Step 2: Drop any vars from initial indices which do not occur in target buffer using source and + * target spatial layouts. + * target transformation = lambda C, H, W -> (N, H, W, C // 4, C %4) + * + * Step 3: Erase any expression from final indices which is dependent on a var not present in + * initial indices. + * target transformation = lambda C, H, W -> (H, W, C // 4, C %4) + * + * Step 4: Go over the target spatial layout and add any missing dims to both initial and final + * indices. This is done by checking if any iterator in target spatial layout is not present in + * source spatial layout. + * target transformation = lambda dim, C, H, W -> (dim, H, W, C // 4, C %4) + */ +using VarSet = std::unordered_set; +static Optional InferLayoutTransformation(const SpatialLayout& src_spatial_layout, + const IndexMap& src_transformation, + const SpatialLayout& tgt_spatial_layout) { + // Copy over the src transformation intial and final indices + auto initial_indices = support::AsList(src_transformation->initial_indices); + auto final_indices = support::AsList(src_transformation->final_indices); + + // Get the iterator var set used in target spatial layout. + VarSet tgt_var_set; + for (const auto& i : tgt_spatial_layout) { + if (i.defined()) tgt_var_set.insert(i.value()); + } + + // Erase initial indices corresponding to iter vars that do not occur in target spatial layout. + // Also compute the var set of initial indices. + auto initial_indices_it = initial_indices.begin(); + VarSet initial_indices_var_set; + for (const auto& i : src_spatial_layout) { + ICHECK(i.defined()); + if (tgt_var_set.count(i.value())) { + initial_indices_var_set.insert(*initial_indices_it); + initial_indices_it++; + continue; + } + initial_indices_it = initial_indices.erase(initial_indices_it); + } + + // Erase any expressions in final indices that have undefined vars + auto final_indices_it = final_indices.begin(); + while (final_indices_it != final_indices.end()) { + // Collect all the vars used in this final index. + Array used_vars = tir::UndefinedVars(*final_indices_it); + ICHECK(!used_vars.empty()) + << "IndexMap expression must always contain tir::Var nodes but found none in: " + << *final_indices_it; + + bool has_undefined_vars = std::any_of(used_vars.begin(), used_vars.end(), + [&initial_indices_var_set](const tir::Var& v) { + return initial_indices_var_set.count(v) == 0; + }); + + // If all vars are from initial indices, nothing to do for this final index. + if (!has_undefined_vars) { + final_indices_it++; + continue; + } + // We are about to drop this expr from final indices since it has undefined vars. Check if it is + // dependent on any of the initial indices. If it is dependent, this cannot be dropped and we + // bail by returning null. + // This captures the scenario where the source transformation is unpacking a dimension (e.g, + // "H4h" -> "H*4+h" ) and the buffer we are trying to infer the transformation of has 'h' + // dimension, but not 'H'. So, it is dependent on undefined var 'H' and defined var 'h'. + bool depends_on_initial_indices = std::any_of(used_vars.begin(), used_vars.end(), + [&initial_indices_var_set](const tir::Var& v) { + return initial_indices_var_set.count(v) != 0; + }); + if (depends_on_initial_indices) { + LOG(WARNING) + << "[LayoutInference] Buffer access is dependent on both defined and undefined vars"; + return {}; + } + // It is ok to erase this final index expression as it only depends on undefined vars. + final_indices_it = final_indices.erase(final_indices_it); + } + + // Go over the target spatial layout and add any missing dims to both initial and final indices. + // This is done by checking if any iterator in target spatial layout is not present in source + // spatial layout. + VarSet src_var_set; + for (const auto& i : src_spatial_layout) { + ICHECK(i.defined()); + src_var_set.insert(i.value()); + } + + initial_indices_it = initial_indices.begin(); + final_indices_it = final_indices.begin(); + for (const auto& i : tgt_spatial_layout) { + if (i.defined() && src_var_set.count(i.value())) { + initial_indices_it++; + if (final_indices_it != final_indices.end()) final_indices_it++; + continue; + } + + auto new_dim = tir::Var("d"); + initial_indices.insert(initial_indices_it, new_dim); + final_indices.insert(final_indices_it, new_dim); + } + + return IndexMap(support::AsArray(initial_indices), support::AsArray(final_indices)); +} + +/*! + * \brief Analyzes the Block and given output buffer transformations to propose + * transformations of block and read buffers. + * \details It does a best effort analysis to propose transformations which would preserve + * sequential access to buffers (especially output buffers). Since this is best effort, it is + * possible that the Block is too complex for analysis. In such a case, no transformations are + * proposed. Limitations: + * 1. Expects exactly one write buffer in the block whose transformation is given by + * `write_transformation`. + * 2. Expects write buffer access to be affine and only use spatial iterators of the block. + * 3. Proposes transformations to a read buffer if all access to it are affine. + */ +class BlockAnalyzer : public StmtExprVisitor { + public: + explicit BlockAnalyzer(const Block& block, const Map& transformation_cache, + IndexMap write_transformation) + : can_transform_block_(true), + write_transformation_(write_transformation), + block_(block), + buffer_transformation_cache_(transformation_cache) { + ICHECK(block_->writes.size() == 1); + auto write_buffer = block_->writes[0]->buffer; + + ComputeBlockSpatialDomain(); + + // Visit the block body to collect load/store access patterns of different buffers. + VisitStmt(block_->body); + + // While visiting the load/store accesses it is possible we see an unexpected pattern, such as + // nested block or write access to multiple buffers. In such a case, we can return early as we + // would not be making any layout suggesstions. + if (!can_transform_block_) { + LOG(WARNING) << "[LayoutInference] Unable to transform block " << block->name_hint; + return; + } + + // Get iterator ordering and it's spatial layout. + VarToBlockIndexMap iter_var_to_block_index; + SpatialLayout block_spatial_layout; + int index = 0; + for (const auto& iter_var : block->iter_vars) { + auto var = iter_var->var; + iter_var_to_block_index[var] = index++; + block_spatial_layout.push_back(var); + } + + // Helper to get the spatial layout of buffer from buffer access map. + auto get_spatial_layout = [&](Buffer b) -> SpatialLayout { + auto it = buffer_access_info_.find(b); + if (it == buffer_access_info_.end()) { + return {}; + } + auto access_info = it->second; + return access_info.GetValidSpatialLayout(); + }; + + // Check that write has sequential access within the block. + SpatialLayout write_spatial_layout = get_spatial_layout(write_buffer); + if (write_spatial_layout.empty()) { + can_transform_block_ = false; + return; + } + if (!IsSequentialAccess(write_spatial_layout, iter_var_to_block_index)) { + can_transform_block_ = false; + return; + } + + // Infer Block transformation from write buffer transformation. + auto maybe_block_transformation = InferLayoutTransformation( + write_spatial_layout, write_transformation_, block_spatial_layout); + if (!maybe_block_transformation.defined()) { + can_transform_block_ = false; + return; + } + block_transformation_ = maybe_block_transformation.value(); + + Array block_ranges = block_->iter_vars.Map([](const IterVar& i) { return i->dom; }); + if (!IsBijectiveAffine(block_transformation_, block_ranges)) { + can_transform_block_ = false; + LOG(WARNING) << "[LayoutInference] Inferred block transformation is not bijective affine, " + "transformation: (" + << block_transformation_ << ") over range (" << block_ranges << ")"; + return; + } + + // Infer read buffer transformations from write buffer transformation. + for (const auto& r : block->reads) { + SpatialLayout read_spatial_layout = get_spatial_layout(r->buffer); + if (read_spatial_layout.empty()) continue; + if (!IsSequentialAccess(read_spatial_layout, iter_var_to_block_index)) continue; + + auto maybe_read_transformation = InferLayoutTransformation( + write_spatial_layout, write_transformation_, read_spatial_layout); + if (!maybe_read_transformation.defined()) continue; + IndexMap read_transformation = maybe_read_transformation.value(); + if (buffer_transformation_cache_.count(r->buffer) != 0) { + if (!AreIdenticalTransforms(read_transformation, buffer_transformation_cache_[r->buffer])) + LOG(WARNING) << "[LayoutInference] Buffer: " << r->buffer + << " has conflicting transform proposals -- (preferred) " + << buffer_transformation_cache_[r->buffer] << " vs. " << read_transformation; + continue; + } + read_buffer_transformations_.Set(r->buffer, read_transformation); + } + } + + private: + // Helper class to keep track of spatial layout of buffer as we visit multiple accesses to this + // buffer within the block. + class BufferAccessInfo { + public: + BufferAccessInfo() : is_valid_(true) {} + void Update(SpatialLayout s) { + if (!IsValid()) return; + if (spatial_layout_.empty()) spatial_layout_ = s; + if (!AreIdenticalSpatialAccess(s, spatial_layout_)) { + Invalidate(); + return; + } + } + bool IsValid() { return is_valid_; } + void Invalidate() { is_valid_ = false; } + SpatialLayout GetValidSpatialLayout() { + if (!IsValid()) return {}; + return spatial_layout_; + } + + private: + bool is_valid_; + SpatialLayout spatial_layout_; + }; + + // Helper to break down the indices of buffer access. + SpatialLayout DetectBufferAccessIterMap(Array indices) { + auto result = arith::DetectIterMap( + /*indices=*/indices, /*input_iters*/ spatial_dom_, + /*predicate*/ 1, /*check_level*/ arith::IterMapLevel::NoCheck, &arith_analyzer_); + if (result->indices.empty()) { + LOG(WARNING) << "[LayoutInference] Failed to analyze indices " << indices + << ", error: " << result->errors; + return {}; + } + return GetSpatialLayout(result); + } + + // Compute the spatial domain map of block + void ComputeBlockSpatialDomain() { + for (const IterVar& v : block_->iter_vars) { + if (v->iter_type == kDataPar) { + spatial_dom_.Set(v->var, v->dom); + continue; + } + if (v->iter_type == kCommReduce) continue; + LOG(WARNING) << "[LayoutInference] Cannot compute block spatial domain in presence of " + "unknown block iter_type : " + << v->iter_type; + can_transform_block_ = false; + return; + } + } + + void VisitStmt_(const BlockNode* op) final { + // Blocks with nested blocks cannot be handled yet. + LOG(WARNING) << "[LayoutInference] Nested blocks are not supported for layout inference yet"; + can_transform_block_ = false; + } + void VisitStmt_(const BufferStoreNode* op) final { + StmtExprVisitor::VisitStmt_(op); + + BufferAccessInfo& access_info = buffer_access_info_[op->buffer]; + + // Fast path to ignore further analysis if we know that the buffer access is invalid. + if (!access_info.IsValid()) return; + + // Only single write buffer is supported for each block. + if (!op->buffer.same_as(block_->writes[0]->buffer)) { + access_info.Invalidate(); + LOG(WARNING) << "[LayoutInference] Exactly one write buffer is supported for layout " + "inference, found two: " + << op->buffer << " and " << block_->writes[0]->buffer; + can_transform_block_ = false; + return; + } + + // If the write buffer access cannot be analyzed, no transformation to the block will be made. + auto detected_spatial_layout = DetectBufferAccessIterMap(op->indices); + if (detected_spatial_layout.empty()) { + access_info.Invalidate(); + return; + } + + // Check if we have access info for this buffer, if present, the two accesses must be + // identical. + access_info.Update(detected_spatial_layout); + } + + void VisitExpr_(const BufferLoadNode* op) final { + Buffer read_buffer = op->buffer; + BufferAccessInfo& access_info = buffer_access_info_[op->buffer]; + + auto detected_spatial_layout = DetectBufferAccessIterMap(op->indices); + + if (detected_spatial_layout.empty()) { + access_info.Invalidate(); + return; + } + access_info.Update(detected_spatial_layout); + } + + public: + bool CanBeTransformed() { return can_transform_block_; } + IndexMap GetBlockTransformation() { return block_transformation_; } + Map GetReadBufferTransformations() { return read_buffer_transformations_; } + + private: + bool can_transform_block_; + IndexMap write_transformation_; + Map spatial_dom_; + arith::Analyzer arith_analyzer_; + + Block block_; + IndexMap block_transformation_; + + Map read_buffer_transformations_; + const Map& buffer_transformation_cache_; + std::unordered_map buffer_access_info_; +}; + +/*! + * \brief Analyzes the PrimFunc and user provided output buffer transformations to propose + * transformations of block and buffers within the PrimFunc. + * \details It does a best effort analysis to propose transformations which would preserve + * sequential access to buffers (especially output buffers). Since this is best effort, it is + * possible that the PrimFunc is too complex for analysis. In such a case, no transformations are + * proposed. + */ +class PrimFuncAnalyzer : public StmtExprVisitor { + public: + explicit PrimFuncAnalyzer(const PrimFunc& func, Array write_transformations) { + ICHECK_LE(write_transformations.size(), func->params.size()) + << "Incompatible PrimFunc and write_transformations"; + + size_t first_write_index = func->params.size() - write_transformations.size(); + for (size_t i = 0; i < write_transformations.size(); ++i) { + auto param = func->params[first_write_index + i]; + Optional param_buf = func->buffer_map.Get(param); + ICHECK(param_buf.defined()); + ICHECK_EQ(param_buf.value()->shape.size(), write_transformations[i]->initial_indices.size()) + << "Mismatch between output buffer shape and index map"; + buffer_transformation_cache_.Set(param_buf.value(), write_transformations[i]); + } + VisitStmt(func->body); + } + Map> GetSuggestedTransforms() { + Map> result; + for (const auto& [block, index_map] : block_transformations_) { + Map block_transformations; + block_transformations.Set(block, index_map); + for (const auto& buffer : block_to_buffer_[block]) { + block_transformations.Set(buffer, buffer_transformation_cache_[buffer]); + } + result.Set(block, block_transformations); + } + return result; + } + + private: + void VisitStmt_(const BlockNode* op) final { + if (op->name_hint == "root") { + // Skip the root block + StmtVisitor::VisitStmt_(op); + return; + } + + Block block = GetRef(op); + // Get block write buffer transformation. + if (block->writes.size() != 1) return; + auto write_buffer = block->writes[0]->buffer; + block_to_buffer_[block].push_back(write_buffer); + BlockAnalyzer block_analyzer(block, buffer_transformation_cache_, + buffer_transformation_cache_[write_buffer]); + + if (!block_analyzer.CanBeTransformed()) return; + // Collect the suggested transformations + block_transformations_.Set(block, block_analyzer.GetBlockTransformation()); + + for (const auto& [buffer, index_map] : block_analyzer.GetReadBufferTransformations()) { + // BlockAnalyzer makes sure that it does not propose transformation for a buffer for which a + // transformation has already been proposed by other blocks or by write_transformations which + // are input to this analysis. + ICHECK_EQ(buffer_transformation_cache_.count(buffer), 0); + buffer_transformation_cache_.Set(buffer, index_map); + block_to_buffer_[block].push_back(buffer); + } + } + + private: + Map buffer_transformation_cache_; + Map block_transformations_; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> block_to_buffer_; +}; + +Map> SuggestLayoutTransforms( + const PrimFunc& prim_func, Array write_buffer_transformations) { + // No changes to the PrimFunc are required if no transformations on output buffers. + if (write_buffer_transformations.empty()) return {}; + + PrimFuncAnalyzer analyzer(prim_func, write_buffer_transformations); + return analyzer.GetSuggestedTransforms(); +} + +TVM_REGISTER_GLOBAL(("relax.analysis.suggest_layout_transforms")) + .set_body_typed([](PrimFunc fn, Array write_buffer_transformations) { + return SuggestLayoutTransforms(fn, write_buffer_transformations); + }); + +} // namespace relax +} // namespace tvm diff --git a/src/support/array.h b/src/support/array.h index 218150f9dba0..0ca57a2410c5 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -21,6 +21,7 @@ #include #include +#include #include namespace tvm { @@ -81,11 +82,35 @@ inline std::vector AsVector(const Array& vec); * \brief Convert a std::vector to tvm::runtime::Array * \tparam TSrc The type of elements in the source vector * \tparam TDst The type of elements in the result Array - * \return The result vector + * \return The result Array */ template inline Array AsArray(const std::vector& vec); +/*! + * \brief Convert a tvm::runtime::Array to std::list + * \tparam T The type of elements in the source array + * \return The result list + */ +template +inline std::list AsList(const Array& array) { + std::list list; + for (const auto& v : array) list.push_back(v); + return list; +} + +/*! + * \brief Convert a std::list to tvm::runtime::Array + * \tparam T The type of elements in the source list + * \return The result list + */ +template +inline Array AsArray(const std::list& list) { + Array array; + for (const auto& v : list) array.push_back(v); + return array; +} + /*! * \brief Get the shape tuple as array * \param shape The shape tuple diff --git a/tests/python/relax/test_analysis_suggest_layout_transforms.py b/tests/python/relax/test_analysis_suggest_layout_transforms.py new file mode 100644 index 000000000000..2850f0ed9f94 --- /dev/null +++ b/tests/python/relax/test_analysis_suggest_layout_transforms.py @@ -0,0 +1,831 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm.testing + +from tvm import relax, tir +from tvm.script import tir as T + + +def apply_transformations(func, suggested_transfoms, print_transformation=False): + sch = tir.Schedule(func) + for block, per_block_transformations in suggested_transfoms.items(): + blockrv = sch.get_block(block.name_hint) + for obj, index_map in per_block_transformations.items(): + if isinstance(obj, tir.Block): + block_name = obj.name_hint + if print_transformation: + print("Block transformation: ", block_name, " :: ", index_map) + sch.transform_block_layout(block_name, index_map) + else: + assert isinstance(obj, tir.Buffer) + buffer = obj + if print_transformation: + print("Buffer transformation: ", buffer, " :: ", index_map) + sch.transform_layout(blockrv, buffer, index_map) + return sch.mod["main"] + + +def test_nested_blocks(): + @T.prim_func + def nested_block( + arg: T.Buffer((32, 64, 224, 224), "float32"), + relu: T.Buffer((32, 64, 224, 224), "float32"), + ): + for i, j in T.grid(32, 64): + with T.block("outer"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(arg[v_i, v_j, 0:224, 0:224]) + T.writes(relu[v_i, v_j, 0:224, 0:224]) + for k, l in T.grid(224, 224): + with T.block("inner"): + v_k, v_l = T.axis.remap("SS", [k, l]) + T.reads(arg[v_i, v_j, v_k, v_l]) + T.writes(relu[v_i, v_j, v_k, v_l]) + relu[v_i, v_j, v_k, v_l] = T.max(arg[v_i, v_j, v_k, v_l], T.float32(0)) + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=nested_block, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)] + ) + # no suggestions for nested block. + assert len(suggested_transforms.items()) == 0 + + +def test_mismatch_transformations_and_num_params(): + @T.prim_func + def elemwise( + arg: T.Buffer((32, 64, 224, 224), "float32"), + relu: T.Buffer((32, 64, 224, 224), "float32"), + ): + for i0, i1, i2, i3 in T.grid(32, 64, 224, 224): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(arg[v_i0, v_i1, v_i2, v_i3]) + T.writes(relu[v_i0, v_i1, v_i2, v_i3]) + relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + with pytest.raises(tvm.TVMError, match="Incompatible PrimFunc and write_transformations"): + _ = relax.analysis.suggest_layout_transforms( + func=elemwise, + write_buffer_transforms=[ + lambda n, c, h, w: (n, h, w, c), + lambda n, c, h, w: (n, h, w, c), + lambda n, c, h, w: (n, h, w, c), + ], + ) + + +def test_empty_write_transformations(): + @T.prim_func + def elemwise( + arg: T.Buffer((32, 64, 224, 224), "float32"), + relu: T.Buffer((32, 64, 224, 224), "float32"), + ): + for i0, i1, i2, i3 in T.grid(32, 64, 224, 224): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(arg[v_i0, v_i1, v_i2, v_i3]) + T.writes(relu[v_i0, v_i1, v_i2, v_i3]) + relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=elemwise, write_buffer_transforms=[] + ) + assert len(suggested_transforms.items()) == 0 + + +def test_non_bijective_block_transform(): + @T.prim_func + def before( + arg: T.Buffer((32, 64), "float32"), + output: T.Buffer((32, 64), "float32"), + ): + for ax0, ax1 in T.grid(32, 64): + with T.block("compute"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(arg[v_ax0, v_ax1]) + T.writes(output[v_ax0, v_ax1]) + output[v_ax0, v_ax1] = arg[v_ax0, v_ax1] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c: (n, c // 5, c % 5)] + ) + assert len(suggested_transforms.items()) == 0 + + +def test_non_affine_access(): + @T.prim_func + def before( + arg: T.Buffer((32, 64), "float32"), + output: T.Buffer((32 * 64, 10), "float32"), + ): + for ax0, ax1, ax2 in T.grid(32, 64, 10): + with T.block("compute"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(arg[v_ax0, v_ax1]) + T.writes(output[v_ax0 * v_ax1, v_ax2]) + output[v_ax0 * v_ax1, v_ax2] = arg[v_ax0, v_ax1] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda a, b: (b, a)] + ) + assert len(suggested_transforms.items()) == 0 + + +def test_unsupported_write_spatial_layout(): + @T.prim_func + def before( + arg: T.Buffer((4, 4), "float32"), + output: T.Buffer((16), "float32"), + ): + for ax0, ax1 in T.grid(4, 4): + with T.block("flatten"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(arg[v_ax0, v_ax1]) + T.writes(output[v_ax0 * 4 + v_ax1]) + output[v_ax0 * 4 + v_ax1] = arg[v_ax0, v_ax1] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda a: (a // 4, a % 4)] + ) + assert len(suggested_transforms.items()) == 0 + + +def test_unpacked_iter_used_in_read_access(): + @T.prim_func + def before( + arg: T.Buffer((8, 4), "float32"), + output: T.Buffer((4, 8), "float32"), + ): + for ax0, ax1, ax2 in T.grid(4, 8, 4): + with T.block("compute"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(arg[v_ax1, v_ax2]) + T.writes(output[v_ax0, v_ax1]) + output[v_ax0, v_ax1] = arg[v_ax1, v_ax2] + + @T.prim_func + def expected( + arg: T.Buffer((8, 4), "float32"), + output: T.Buffer((32), "float32"), + ): + for ax0, ax2 in T.grid(32, 4): + with T.block("compute"): + v_ax0, v_ax2 = T.axis.remap("SS", [ax0, ax2]) + T.reads(arg[v_ax0 % 8, v_ax2]) + T.writes(output[v_ax0]) + output[v_ax0] = arg[v_ax0 % 8, v_ax2] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda a, b: (a * 8 + b)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_invalid_index_map(): + @T.prim_func + def elemwise( + arg: T.Buffer((32, 64, 224, 224), "float32"), + relu: T.Buffer((32, 64, 224, 224), "float32"), + ): + for i0, i1, i2, i3 in T.grid(32, 64, 224, 224): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(arg[v_i0, v_i1, v_i2, v_i3]) + T.writes(relu[v_i0, v_i1, v_i2, v_i3]) + relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + with pytest.raises(tvm.TVMError, match="Mismatch between output buffer shape and index map"): + _ = relax.analysis.suggest_layout_transforms( + func=elemwise, write_buffer_transforms=[lambda n, h, w: (n, w, h)] + ) + with pytest.raises(AssertionError): + _ = relax.analysis.suggest_layout_transforms(func=elemwise, write_buffer_transforms=[2]) + + +def test_SRSR_block(): + @T.prim_func + def before( + arg: T.Buffer((32, 224, 64, 224), "float32"), + sum: T.Buffer((32, 64), "float32"), + ): + for ax0, k2, ax1, k3 in T.grid(32, 224, 64, 224): + with T.block("rxplaceholder_red"): + v_ax0, v_k2, v_ax1, v_k3 = T.axis.remap("SRSR", [ax0, k2, ax1, k3]) + T.reads(arg[v_ax0, v_ax1, v_k2, v_k3]) + T.writes(sum[v_ax0, v_ax1]) + with T.init(): + sum[v_ax0, v_ax1] = T.float32(0) + sum[v_ax0, v_ax1] = sum[v_ax0, v_ax1] + arg[v_ax0, v_k2, v_ax1, v_k3] + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 16, 224, 4), "float32"), + sum: T.Buffer((32, 16, 4), "float32"), + ): + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 16, 224, 4): + with T.block("rxplaceholder_red"): + v0, v1, v2, v3, v4 = T.axis.remap("SRSRS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg[v0, v1, v2, v3, v4]) + T.writes(sum[v0, v2, v4]) + with T.init(): + sum[v0, v2, v4] = T.float32(0) + sum[v0, v2, v4] = sum[v0, v2, v4] + arg[v0, v1, v2, v3, v4] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c: (n, c // 4, c % 4)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_elemwise_symbolic(): + @T.prim_func + def before(arg: T.handle, relu: T.handle): + N = T.int64() + C = T.int64() + H = T.int64() + W = T.int64() + Arg = T.match_buffer(arg, (N, C, H, W)) + Relu = T.match_buffer(relu, (N, C, H, W)) + for i0, i1, i2, i3 in T.grid(N, C, H, W): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(Arg[v_i0, v_i1, v_i2, v_i3]) + T.writes(Relu[v_i0, v_i1, v_i2, v_i3]) + Relu[v_i0, v_i1, v_i2, v_i3] = T.max(Arg[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + @T.prim_func + def expected(arg: T.handle, relu: T.handle): + N = T.int64() + C = T.int64() + H = T.int64() + W = T.int64() + Arg = T.match_buffer(arg, (N, H, W, C)) + Relu = T.match_buffer(relu, (N, H, W, C)) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C): + with T.block("compute"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(Arg[v0, v1, v2, v3]) + T.writes(Relu[v0, v1, v2, v3]) + Relu[v0, v1, v2, v3] = T.max(Arg[v0, v1, v2, v3], T.float32(0)) + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_elemwise(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + relu: T.Buffer((32, 64, 224, 224), "float32"), + ): + for i0, i1, i2, i3 in T.grid(32, 64, 224, 224): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(arg[v_i0, v_i1, v_i2, v_i3]) + T.writes(relu[v_i0, v_i1, v_i2, v_i3]) + relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 64), "float32"), + relu: T.Buffer((32, 224, 224, 64), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 64): + with T.block("compute"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v0, v1, v2, v3]) + T.writes(relu[v0, v1, v2, v3]) + relu[v0, v1, v2, v3] = T.max(arg[v0, v1, v2, v3], T.float32(0)) + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_pool_nchw_nhwc(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + pool_max: T.Buffer((32, 64, 111, 223), "float32"), + ): + for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(32, 64, 111, 223, 2, 2): + with T.block("pool_max"): + v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap( + "SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1] + ) + T.reads( + arg[ + v_ax0, + v_ax1, + v_ax2 * 2 + v_rv0 * 2, + v_ax3 + v_rv1, + ] + ) + T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(-3.4028234663852886e38) + pool_max[v_ax0, v_ax1, v_ax2, v_ax3] = T.max( + pool_max[v_ax0, v_ax1, v_ax2, v_ax3], + arg[ + v_ax0, + v_ax1, + v_ax2 * 2 + v_rv0 * 2, + v_ax3 + v_rv1, + ], + ) + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 64), "float32"), + pool_max: T.Buffer((32, 111, 223, 64), "float32"), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(32, 111, 223, 64, 2, 2): + with T.block("pool_max"): + v0, v1, v2, v3, v4, v5 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, ax4, ax5]) + T.reads(arg[v0, v1 * 2 + v4 * 2, v2 + v5, v3]) + T.writes(pool_max[v0, v1, v2, v3]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v0, v1, v2, v3] = T.float32(-3.4028234663852886e38) + pool_max[v0, v1, v2, v3] = T.max( + pool_max[v0, v1, v2, v3], + arg[v0, v1 * 2 + v4 * 2, v2 + v5, v3], + ) + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, + write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)], + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_pool_nchw16c_nhwc(): + @T.prim_func + def before( + arg: T.Buffer( + (32, 4, 224, 224, 16), + "float32", + ), + pool_max: T.Buffer( + (32, 4, 110, 220, 16), + "float32", + ), + ): + for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid(32, 4, 110, 220, 16, 5, 5): + with T.block("pool_max"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap( + "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1] + ) + T.reads(arg[v_ax0, v_ax1, v_ax2 * 2 + v_rv0, v_ax3 + v_rv1, v_ax4]) + T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32(-3.4028234663852886e38) + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max( + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], + arg[v_ax0, v_ax1, v_ax2 * 2 + v_rv0, v_ax3 + v_rv1, v_ax4], + ) + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 64), "float32"), + pool_max: T.Buffer((32, 110, 220, 64), "float32"), + ): + for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(32, 110, 220, 64, 5, 5): + with T.block("pool_max"): + v0, v1, v2, v3, v4, v5 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, ax4, ax5]) + T.reads(arg[v0, v1 * 2 + v4, v2 + v5, v3]) + T.writes(pool_max[v0, v1, v2, v3]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v0, v1, v2, v3] = T.float32(-3.4028234663852886e38) + pool_max[v0, v1, v2, v3] = T.max( + pool_max[v0, v1, v2, v3], + arg[v0, v1 * 2 + v4, v2 + v5, v3], + ) + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, + write_buffer_transforms=[lambda n, C, h, w, c: (n, h, w, C * 16 + c)], + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_reduce(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + sum: T.Buffer((32, 64), "float32"), + ): + for ax0, ax1, k2, k3 in T.grid(32, 64, 224, 224): + with T.block("rxplaceholder_red"): + v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3]) + T.reads(arg[v_ax0, v_ax1, v_k2, v_k3]) + T.writes(sum[v_ax0, v_ax1]) + with T.init(): + sum[v_ax0, v_ax1] = T.float32(0) + sum[v_ax0, v_ax1] = sum[v_ax0, v_ax1] + arg[v_ax0, v_ax1, v_k2, v_k3] + + @T.prim_func + def expected( + arg: T.Buffer((32, 4, 224, 224, 16), "float32"), + sum: T.Buffer((32, 4, 16), "float32"), + ): + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 4, 224, 224, 16): + with T.block("rxplaceholder_red"): + v0, v1, v2, v3, v4 = T.axis.remap("SSRRS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg[v0, v1, v2, v3, v4]) + T.writes(sum[v0, v1, v4]) + with T.init(): + sum[v0, v1, v4] = T.float32(0) + sum[v0, v1, v4] = sum[v0, v1, v4] + arg[v0, v1, v2, v3, v4] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c: (n, c // 16, c % 16)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_upsampling(): + # relay materializes the layout if H, W or D dimensions are moved or tiled. + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + resize: T.Buffer((32, 64, 202, 246), "float32"), + ): + for i0, i1, i2, i3 in T.grid(32, 64, 202, 246): + with T.block("resize"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(arg[v_i0, v_i1, 0:224, 0:224]) + T.writes(resize[v_i0, v_i1, v_i2, v_i3]) + resize[v_i0, v_i1, v_i2, v_i3] = arg[ + v_i0, + v_i1, + T.max( + T.min( + T.Cast( + "int64", + T.floor( + T.float32(1.1089109182357788) * T.Cast("float32", v_i2) + + T.float32(1.0000000000000001e-05) + ), + ), + 223, + ), + 0, + ), + T.max( + T.min( + T.Cast( + "int64", + T.floor( + T.float32(0.91056913137435913) * T.Cast("float32", v_i3) + + T.float32(1.0000000000000001e-05) + ), + ), + 223, + ), + 0, + ), + ] + + @T.prim_func + def expected( + arg: T.Buffer((32, 64, 224, 224), "float32"), + resize: T.Buffer((32, 202, 246, 64), "float32"), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(32, 202, 246, 64): + with T.block("resize"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v0, v3, 0:224, 0:224]) + T.writes(resize[v0, v1, v2, v3]) + resize[v0, v1, v2, v3] = arg[ + v0, + v3, + T.max( + T.min( + T.Cast( + "int64", + T.floor( + T.float32(1.1089109182357788) * T.Cast("float32", v1) + + T.float32(1.0000000000000001e-05) + ), + ), + T.int64(223), + ), + T.int64(0), + ), + T.max( + T.min( + T.Cast( + "int64", + T.floor( + T.float32(0.91056913137435913) * T.Cast("float32", v2) + + T.float32(1.0000000000000001e-05) + ), + ), + T.int64(223), + ), + T.int64(0), + ), + ] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_strided_slice(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + T_strided_slice_with_axes: T.Buffer((32, 64, 10, 8), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 64, 10, 8): + with T.block("T_strided_slice_with_axes"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads( + arg[ + v_ax0, + v_ax1, + v_ax2 * 5 + 2, + v_ax3 * 7 + 4, + ] + ) + T.writes(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2, v_ax3]) + T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2, v_ax3] = arg[ + v_ax0, + v_ax1, + v_ax2 * 5 + 2, + v_ax3 * 7 + 4, + ] + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 16, 4), "float32"), + T_strided_slice_with_axes: T.Buffer((32, 10, 8, 16, 4), "float32"), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 10, 8, 16, 4): + with T.block("T_strided_slice_with_axes"): + v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg[v0, v1 * 5 + 2, v2 * 7 + 4, v3, v4]) + T.writes(T_strided_slice_with_axes[v0, v1, v2, v3, v4]) + T_strided_slice_with_axes[v0, v1, v2, v3, v4] = arg[ + v0, v1 * 5 + 2, v2 * 7 + 4, v3, v4 + ] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c // 4, c % 4)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_binary_broadcast(): + @T.prim_func + def before( + arg0: T.Buffer((32, 64, 224, 224), "float32"), + arg1: T.Buffer((64, 224, 224), "float32"), + T_add: T.Buffer((32, 64, 224, 224), "float32"), + ): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(32, 64, 224, 224): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads( + arg0[v_ax0, v_ax1, v_ax2, v_ax3], + arg1[v_ax1, v_ax2, v_ax3], + ) + T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add[v_ax0, v_ax1, v_ax2, v_ax3] = ( + arg0[v_ax0, v_ax1, v_ax2, v_ax3] + arg1[v_ax1, v_ax2, v_ax3] + ) + + @T.prim_func + def expected( + arg0: T.Buffer((32, 224, 224, 16, 4), "float32"), + arg1: T.Buffer((224, 224, 16, 4), "float32"), + T_add: T.Buffer((32, 224, 224, 16, 4), "float32"), + ): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 224, 16, 4): + with T.block("T_add"): + v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg0[v0, v1, v2, v3, v4], arg1[v1, v2, v3, v4]) + T.writes(T_add[v0, v1, v2, v3, v4]) + T_add[v0, v1, v2, v3, v4] = arg0[v0, v1, v2, v3, v4] + arg1[v1, v2, v3, v4] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c // 4, c % 4)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_transpose(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + T_transpose: T.Buffer((32, 224, 224, 64), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 64): + with T.block("T_transpose"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v_ax0, v_ax3, v_ax1, v_ax2]) + T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) + T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax3, v_ax1, v_ax2] + + @T.prim_func + def expected( + arg: T.Buffer((32, 64, 224, 224), "float32"), + T_transpose: T.Buffer((32, 224, 64, 224), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 224, 64, 224): + with T.block("T_transpose"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v0, v2, v3, v1]) + T.writes(T_transpose[v0, v1, v2, v3]) + T_transpose[v0, v1, v2, v3] = arg[v0, v2, v3, v1] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_pad(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + PadInput: T.Buffer((32, 64, 230, 230), "float32"), + ): + for i0, i1, i2, i3 in T.grid(32, 64, 230, 230): + with T.block("PadInput"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(arg[v_i0, v_i1, v_i2 - 2, v_i3 - 2]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else( + 2 <= v_i2 and v_i2 < 226 and 2 <= v_i3 and v_i3 < 226, + arg[v_i0, v_i1, v_i2 - 2, v_i3 - 2], + T.float32(2), + ) + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 16, 4), "float32"), + PadInput: T.Buffer((32, 230, 230, 16, 4), "float32"), + ): + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 230, 230, 16, 4): + with T.block("PadInput"): + v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg[v0, v1 - 2, v2 - 2, v3, v4]) + T.writes(PadInput[v0, v1, v2, v3, v4]) + PadInput[v0, v1, v2, v3, v4] = T.if_then_else( + 2 <= v1 and v1 < 226 and 2 <= v2 and v2 < 226, + arg[v0, v1 - 2, v2 - 2, v3, v4], + T.float32(2), + ) + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c // 4, c % 4)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_split(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + split0: T.Buffer((32, 32, 224, 224), "float32"), + split1: T.Buffer((32, 32, 224, 224), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224): + with T.block("T_split_sections"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(split0[v_ax0, v_ax1, v_ax2, v_ax3]) + split0[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1, v_ax2, v_ax3] + for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224): + with T.block("T_split_sections_1"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3]) + T.writes(split1[v_ax0, v_ax1, v_ax2, v_ax3]) + split1[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3] + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 64), "float32"), + split0: T.Buffer((32, 224, 224, 32), "float32"), + split1: T.Buffer((32, 224, 224, 32), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 32): + with T.block("T_split_sections"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v0, v1, v2, v3]) + T.writes(split0[v0, v1, v2, v3]) + split0[v0, v1, v2, v3] = arg[v0, v1, v2, v3] + for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 32): + with T.block("T_split_sections_1"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v0, v1, v2, v3 + 32]) + T.writes(split1[v0, v1, v2, v3]) + split1[v0, v1, v2, v3] = arg[v0, v1, v2, v3 + 32] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, + write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c), lambda n, c, h, w: (n, h, w, c)], + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_split_tiling_split_dim(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + split0: T.Buffer((32, 32, 224, 224), "float32"), + split1: T.Buffer((32, 32, 224, 224), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224): + with T.block("T_split_sections"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(split0[v_ax0, v_ax1, v_ax2, v_ax3]) + split0[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1, v_ax2, v_ax3] + for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224): + with T.block("T_split_sections_1"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3]) + T.writes(split1[v_ax0, v_ax1, v_ax2, v_ax3]) + split1[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3] + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 16, 4), "float32"), + split0: T.Buffer((32, 224, 224, 8, 4), "float32"), + split1: T.Buffer((32, 224, 224, 8, 4), "float32"), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 224, 8, 4): + with T.block("T_split_sections"): + v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg[v0, v1, v2, v3, v4]) + T.writes(split0[v0, v1, v2, v3, v4]) + split0[v0, v1, v2, v3, v4] = arg[v0, v1, v2, v3, v4] + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 224, 8, 4): + with T.block("T_split_sections_1"): + v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg[v0, v1, v2, v3 + 8, v4]) + T.writes(split1[v0, v1, v2, v3, v4]) + split1[v0, v1, v2, v3, v4] = arg[v0, v1, v2, v3 + 8, v4] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, + write_buffer_transforms=[ + lambda n, c, h, w: (n, h, w, c // 4, c % 4), + lambda n, c, h, w: (n, h, w, c // 4, c % 4), + ], + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +if __name__ == "__main__": + tvm.testing.main() From 3f12d4df596d0ff424edd821529e392854427595 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Thu, 23 Feb 2023 07:32:21 -0800 Subject: [PATCH 67/81] [Unity] Remove attributes of relax.print, assert and unique (#14101) Remove the attributes of operators assert, print and unique. Use PrimValue as substitute. Co-authored-by: Steven S. Lyubomirsky [slyubomirsky@gmail.com](mailto:slyubomirsky@gmail.com) Co-authored-by: Prakalp Srivastava [prakalp@octoml.ai](mailto:prakalp@octoml.ai) --- include/tvm/relax/attrs/set.h | 62 ----------- include/tvm/relax/op_attr_types.h | 21 ---- python/tvm/relax/op/base.py | 28 +++-- python/tvm/relax/op/builtin/builtin.py | 16 ++- python/tvm/relax/op/op_attrs.py | 5 - python/tvm/relax/op/set.py | 33 ++++-- src/relax/backend/vm/codegen_vm.cc | 24 ++-- src/relax/op/op.cc | 72 +++++++++++- src/relax/op/tensor/set.cc | 80 +++++++++---- src/relax/op/tensor/set.h | 7 +- tests/python/relax/test_relax_operators.py | 117 +++++++++++++++++++- tests/python/relax/test_tvmscript_parser.py | 4 +- 12 files changed, 307 insertions(+), 162 deletions(-) delete mode 100644 include/tvm/relax/attrs/set.h diff --git a/include/tvm/relax/attrs/set.h b/include/tvm/relax/attrs/set.h deleted file mode 100644 index 3fae7646ff8e..000000000000 --- a/include/tvm/relax/attrs/set.h +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/relax/attrs/set.h - * \brief Attributes for set operators. - */ -#ifndef TVM_RELAX_ATTRS_SET_H_ -#define TVM_RELAX_ATTRS_SET_H_ - -#include - -namespace tvm { -namespace relax { - -/*! \brief Attributes used in unique operator */ -struct UniqueAttrs : public tvm::AttrsNode { - bool sorted; - bool return_index; - bool return_inverse; - bool return_counts; - Optional axis; - - TVM_DECLARE_ATTRS(UniqueAttrs, "relax.attrs.UniqueAttrs") { - TVM_ATTR_FIELD(sorted).describe( - "Whether to sort the unique elements in ascending order before returning as output."); - TVM_ATTR_FIELD(return_index) - .describe( - "Whether to return an additional tensor with indices for where elements in the unique " - "tensor come from the original input."); - TVM_ATTR_FIELD(return_inverse) - .describe( - "Whether to return an additional tensor with indices for where elements in the " - "original input ended up in the returned unique list."); - TVM_ATTR_FIELD(return_counts) - .describe("Whether to return an additional tensor with counts of each unique elements"); - TVM_ATTR_FIELD(axis).describe( - "The dimension to apply unique. If it is NullOpt, the unique values of the flattened input " - "is are returned."); - } -}; // struct UniqueAttrs - -} // namespace relax -} // namespace tvm - -#endif // TVM_RELAX_ATTRS_SET_H_ diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h index a34cf251dc33..413d3e0499d0 100644 --- a/include/tvm/relax/op_attr_types.h +++ b/include/tvm/relax/op_attr_types.h @@ -58,27 +58,6 @@ using FCallPacked = String; */ using FLegalize = runtime::TypedPackedFunc; -struct PrintAttrs : public tvm::AttrsNode { - std::string format; - TVM_DECLARE_ATTRS(PrintAttrs, "relax.attrs.PrintAttrs") { - TVM_ATTR_FIELD(format) - .describe("Python-style format string to use for displaying the input. Ignored if empty.") - .set_default(""); - } -}; - -struct AssertOpAttrs : public tvm::AttrsNode { - std::string format; - TVM_DECLARE_ATTRS(AssertOpAttrs, "relax.attrs.AssertOpAttrs") { - TVM_ATTR_FIELD(format) - .describe( - "Python-style format string to use for displaying " - "an error message if the assert fails. " - "Ignored if empty.") - .set_default(""); - } -}; - } // namespace relax } // namespace tvm #endif // TVM_RELAX_OP_ATTR_TYPES_H_ diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index d76b155beb83..0b298679c1c5 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -22,7 +22,7 @@ from tvm.runtime.object import Object from . import _ffi_api -from ..expr import Expr, ShapeExpr, Call, ExternFunc +from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc from ..expr import Tuple as RxTuple from ..struct_info import StructInfo, TensorStructInfo from ...ir import PrimExpr @@ -199,7 +199,7 @@ def render_object(val: tvm.Object) -> str: ret: str A string representing the value, ideally human-readable """ - if isinstance(val, tvm.runtime.ndarray.NDArray): + if isinstance(val, tvm.nd.NDArray): return str(val) # no pretty-printer by default, so if we don't handle this, # then we can't look inside tuples @@ -211,6 +211,9 @@ def render_object(val: tvm.Object) -> str: if val.tag == 0: return f"({fields})" return f"ADT(tag={val.tag}, fields=[{fields}])" + if isinstance(val, tvm.ir.Array): + fields = ", ".join([render_object(val[i]) for i in range(len(val))]) + return f"({fields})" return str(val) @@ -240,7 +243,7 @@ def relax_print(format_str: str, *format_args: tvm.Object) -> None: py_print(format_str.format(*val_strs)) -def print(*values: List[Expr], format: str = "") -> Expr: +def print(*values: List[Expr], format: Union[str, Expr] = "") -> Expr: """Print op to print the values Parameters @@ -248,14 +251,17 @@ def print(*values: List[Expr], format: str = "") -> Expr: values : List[Expr] The values to print. - format_str: str - The format string. + format: Union[str, Expr] + The format string or StringImm. Returns ------- result : Expr A relax Call, which will print the value during runtime. """ + if isinstance(format, str): + format = StringImm(format) + return _ffi_api.print(values, format) # type: ignore # pylint: disable=no-member @@ -289,7 +295,7 @@ def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Ob ) # should be guaranteed by the type system - if not isinstance(condition, tvm.runtime.ndarray.NDArray): + if not isinstance(condition, tvm.nd.NDArray): raise ValueError(f"The condition must be an NDArray, but given a {type(condition)}.") # may happen if the original program had unknown shape or dtype for the tensor's type @@ -313,7 +319,9 @@ def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Ob def assert_op( - condition: Expr, format_args: Optional[Union[Expr, List[Expr]]] = None, format: str = "" + condition: Expr, + format_args: Optional[Union[Expr, List[Expr]]] = None, + format: Union[str, Expr] = "", ) -> Expr: """ Create a call to Relax's assert_op operation (`assert` is reserved in Python, @@ -327,8 +335,8 @@ def assert_op( format_args: Optional[Union[Expr, List[Expr]]] Format arguments for the error message if the condition fails. - format_str: str - The format string for the error message. + format: Union[str, Expr] + The format string or StringImm for the error message. Returns ------- @@ -339,6 +347,8 @@ def assert_op( format_args = [] if isinstance(format_args, Expr): # type: ignore format_args = [format_args] + if isinstance(format, str): + format = StringImm(format) return _ffi_api.assert_op(condition, format_args, format) # type: ignore diff --git a/python/tvm/relax/op/builtin/builtin.py b/python/tvm/relax/op/builtin/builtin.py index 0afe6a42d09a..43bbd461bca8 100644 --- a/python/tvm/relax/op/builtin/builtin.py +++ b/python/tvm/relax/op/builtin/builtin.py @@ -15,13 +15,16 @@ # specific language governing permissions and limitations """The builtin Relax operators.""" -from ...expr import Call, Expr +from typing import Union +from ...expr import Call, Expr, PrimValue, DataTypeImm from ...utils import args_converter from . import _ffi_api @args_converter.auto -def alloc_tensor(shape: Expr, dtype: str, runtime_device_index: int) -> Call: +def alloc_tensor( + shape: Expr, dtype: Union[str, Expr], runtime_device_index: Union[int, Expr] +) -> Call: """Construct a Call to allocate a tensor with specific shape, dtype, runtime_device_index. Parameters @@ -29,10 +32,10 @@ def alloc_tensor(shape: Expr, dtype: str, runtime_device_index: int) -> Call: shape : Expr The shape of the tensor to be allocated. - dtype : str + dtype : Union[str, Expr] The datatype of the tensor to be allocated. - runtime_device_index : int + runtime_device_index : Union[int, Expr] The device index indicating on which device the tensor is to be allocated at runtime. Index -1 is reserved for the host device. @@ -41,4 +44,9 @@ def alloc_tensor(shape: Expr, dtype: str, runtime_device_index: int) -> Call: result : Call A relax Call, which gets the allocated tensor. """ + if isinstance(dtype, str): + dtype = DataTypeImm(dtype) + if isinstance(runtime_device_index, int): + runtime_device_index = PrimValue(runtime_device_index) + return _ffi_api.alloc_tensor(shape, dtype, runtime_device_index) # type: ignore diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index efad5d98f01a..ff89d7c90327 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -122,8 +122,3 @@ class LayoutTransformAttrs(Attrs): @tvm._ffi.register_object("relax.attrs.Resize2DAttrs") class Resize2DAttrs(Attrs): """Attributes used in image resize2d operator""" - - -@tvm._ffi.register_object("relax.attrs.UniqueAttrs") -class UniqueAttrs(Attrs): - """Attributes used for the unique operator""" diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py index b7ee0f381169..4d106ad6d23c 100644 --- a/python/tvm/relax/op/set.py +++ b/python/tvm/relax/op/set.py @@ -16,22 +16,22 @@ # under the License. # pylint: disable=import-outside-toplevel, redefined-builtin, unused-argument """Set operators.""" -from typing import Optional +from typing import Optional, Union import numpy as np # type: ignore import tvm from . import _ffi_api -from ..expr import Expr +from ..expr import Expr, PrimValue def unique( x: Expr, - sorted: bool = True, - return_index: bool = False, - return_inverse: bool = False, - return_counts: bool = False, - axis: Optional[int] = None, + sorted: Union[bool, Expr] = True, + return_index: Union[bool, Expr] = False, + return_inverse: Union[bool, Expr] = False, + return_counts: Union[bool, Expr] = False, + axis: Optional[Union[int, Expr]] = None, ) -> Expr: """Find the unique elements in a given tensor. In addition, it optionally returns @@ -44,19 +44,19 @@ def unique( x : relax.Expr The input tensor. - sorted : bool + sorted : Union[bool, Expr] Whether to sort the unique elements in ascending order before returning as output. - return_index : bool + return_index : Union[bool, Expr] Whether to return an additional tensor with indices for where elements in the unique tensor come from the original input. - return_inverse : bool + return_inverse : Union[bool, Expr] Whether to return an additional tensor with indices for where elements in the original input ended up in the returned unique list. - return_counts : bool + return_counts : Union[bool, Expr] Whether to return an additional tensor with counts of each unique elements. axis : Optional @@ -69,6 +69,16 @@ def unique( The created relax call with """ + if isinstance(sorted, bool): + sorted = PrimValue(sorted) + if isinstance(return_index, bool): + return_index = PrimValue(return_index) + if isinstance(return_inverse, bool): + return_inverse = PrimValue(return_inverse) + if isinstance(return_counts, bool): + return_counts = PrimValue(return_counts) + if axis and isinstance(axis, int): + axis = PrimValue(axis) return _ffi_api.unique( # type: ignore x, sorted, return_index, return_inverse, return_counts, axis ) @@ -81,7 +91,6 @@ def numpy_unique( return_index: int, return_inverse: int, return_counts: int, - axis: Optional[int], ) -> tvm.nd.array: """Returns the unique elements of the input tensor. diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 1782f1107a5b..da0ca3a0b55e 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -148,7 +148,14 @@ class CodeGenVM : public ExprFunctor { // allocate dst register. RegName dst_reg = HasVoidStructInfo(call) ? Instruction::kVoidRegister : NewRegister(); if (call->op.as()) { - if (call_node->op == call_builtin_with_ctx_op_) { + // special case generate for the intrinsics whose attribute fields + // cannot be represented by args in the CallNode + FCallPacked name = GetPackedFuncName(call); + if (!name.empty()) { + // If the operator has a registered packed function implementation, emit call to that packed + // function. + EmitPackedFuncCall(call, name, dst_reg); + } else if (call_node->op == call_builtin_with_ctx_op_) { // TODO(relax-team) migrate most handling of op to // directly map to call_builtin_with_ctx before codegen and simplify vm codegen. EmitCallBuiltinWithCtx(call, dst_reg); @@ -355,22 +362,9 @@ class CodeGenVM : public ExprFunctor { builder_->EmitCall(func, args, dst_reg); } - // TODO(relax-team) revisit after PrimValue. - // Emit the `call_node` attributes as constants and append these constants to `args` vector. - void AppendAttrsAsConstants(const Call& call_node, std::vector& args) { - auto attrs = call_node->attrs; - if (!attrs.defined()) return; - - LOG(FATAL) << "Support for attributes of Op " << call_node->op - << " has not been implemented yet."; - return; - } - - // Emits call to packed function `name` with arguments copied over from `call_node` args and - // attributes. + // Emits call to packed function `name` with arguments copied over from `call_node` args void EmitPackedFuncCall(const Call& call_node, const FCallPacked& name, RegName dst_reg) { std::vector args = VisitArray(call_node->args); - AppendAttrsAsConstants(call_node, args); builder_->EmitCall(name, args, dst_reg); } diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index f478871e218f..21d692b6a460 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -141,6 +141,70 @@ Expr MakeCallNullValue() { TVM_REGISTER_GLOBAL("relax.op.null_value").set_body_typed(MakeCallNullValue); +// print + +RELAY_REGISTER_OP("relax.print") + .set_num_inputs(-1) + .add_argument("vals", "Array", + "The first value is Python-style format string to use to print. The others " + "are values to print") + .set_attr("FInferStructInfo", ReturnVoidStructInfo) + .set_attr("FCallPacked", "relax.run.print"); + +Expr MakePrint(Array vals, StringImm format) { + Array params; + params.push_back(format); + for (const auto val : vals) { + params.push_back(val); + } + static const Op& op = Op::Get("relax.print"); + return Call(op, params); +} + +TVM_REGISTER_GLOBAL("relax.op.print").set_body_typed(MakePrint); + +// assert_op + +// can't actually name it assert or else Python will consider it a syntax error + +StructInfo InferAssertStructInfo(const Call& call, const BlockBuilder& ctx) { + // Ensure that the condition argument is a boolean scalar. + // Also permitted is a tensor with unknown shape and unknown dtype + // (checked dynamically in that case). Returns void. + if (call->args.size() < 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Assert must have at least one argument (the condition)."); + } + StructInfo arg_struct_info = GetStructInfo(call->args[0]); + if (!IsBoolStructInfo(arg_struct_info)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The argument to assert must be a boolean scalar, but received " + << arg_struct_info); + } + return ReturnVoidStructInfo(call, ctx); +} + +RELAY_REGISTER_OP("relax.assert_op") + .set_num_inputs(-1) + .add_argument("vals", "Array", + "The first value is used as the assertion condition. The second value is " + "Python-style format string to use for displaying an error message, if the " + "assert fails. The others are used as format arguments if there is an error.") + .set_attr("FInferStructInfo", InferAssertStructInfo) + .set_attr("FCallPacked", "relax.run.assert_op"); + +Expr MakeAssertOp(Expr condition, Array vals, StringImm format) { + static const Op& op = Op::Get("relax.assert_op"); + Array args = {condition}; + args.push_back(format); + for (auto val : vals) { + args.push_back(val); + } + return Call(op, args); +} + +TVM_REGISTER_GLOBAL("relax.op.assert_op").set_body_typed(MakeAssertOp); + // make_closure RELAY_REGISTER_OP("relax.make_closure") @@ -213,15 +277,15 @@ StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& c RELAY_REGISTER_OP("relax.builtin.alloc_tensor") .set_num_inputs(3) .add_argument("shape", "Expr", "The shape of the tensor to allocate.") - .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.") - .add_argument("runtime_device_index", "int64_t", + .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.") + .add_argument("runtime_device_index", "PrimValue", "The device index indicating on which device the tensor is to be " "allocated at runtime. Index -1 is reserved for the host device.") .set_attr("FInferStructInfo", InferStructInfoAllocateTensor); -Expr MakeAllocTensor(Expr shape, DataType dtype, int64_t runtime_device_index) { +Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_index) { static const Op& op = Op::Get("relax.builtin.alloc_tensor"); - return Call(op, {shape, DataTypeImm(dtype), PrimValue::Int64(runtime_device_index)}, Attrs(), {}); + return Call(op, {shape, dtype, runtime_device_index}, Attrs(), {}); } TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor); diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index 4d5a274e17fa..8df0813ed2b5 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -31,34 +31,55 @@ namespace tvm { namespace relax { /* relax.unique */ -TVM_REGISTER_NODE_TYPE(UniqueAttrs); - -Expr unique(Expr x, bool sorted, bool return_index, bool return_inverse, bool return_counts, - Optional axis) { - ObjectPtr attrs = make_object(); - attrs->sorted = sorted; - attrs->return_index = return_index; - attrs->return_inverse = return_inverse; - attrs->return_counts = return_counts; - attrs->axis = std::move(axis); +Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue return_inverse, + PrimValue return_counts, Optional axis) { static const Op& op = Op::Get("relax.unique"); - return Call(op, {std::move(x)}, Attrs(attrs), {}); + Call call; + if (!axis) { + call = Call(op, {std::move(x), sorted, return_index, return_inverse, return_counts}); + } else { + PrimValue pv_axis = axis.value(); + call = Call(op, {std::move(x), sorted, return_index, return_inverse, return_counts, pv_axis}); + } + return call; } TVM_REGISTER_GLOBAL("relax.op.unique").set_body_typed(unique); StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - const auto* attrs = call->attrs.as(); - if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) { + TensorStructInfo data_sinfo = Downcast(call->args[0]->struct_info_); + PrimValue axis, return_index, return_inverse, return_counts; + if (call->args.size() == 6) { + if (auto* prim_value_node = call->args[5].as()) { + axis = GetRef(prim_value_node); + } + } + if (!data_sinfo->IsUnknownNdim() && axis.defined()) { // Normalize the axis for sanity check purpose. - NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()->value); + if (const auto* axis_int = axis->value.as()) { + NormalizeAxis(call, ctx, data_sinfo->ndim, axis_int->value); + } } - - int n_int_return = static_cast(attrs->return_index) + - static_cast(attrs->return_inverse) + - static_cast(attrs->return_counts); + ICHECK(call->args[2]->IsInstance()); + ICHECK(call->args[3]->IsInstance()); + ICHECK(call->args[4]->IsInstance()); + + return_index = Downcast(call->args[2]); + return_inverse = Downcast(call->args[3]); + return_counts = Downcast(call->args[4]); + + auto f_convert_to_int64 = [](const PrimExpr& value) { + CHECK(value->IsInstance()) + << value << " expects to be IntImm, but gets " << value->GetTypeKey(); + const auto* val_node = value.as(); + auto val_imm = GetRef(val_node); + return val_imm->value; + }; + + int64_t n_int_return = f_convert_to_int64(return_index->value) + + f_convert_to_int64(return_inverse->value) + + f_convert_to_int64(return_counts->value); std::vector output_sinfo; output_sinfo.reserve(1 + n_int_return); @@ -67,7 +88,7 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { if (data_sinfo->ndim == 0) { output_sinfo.push_back( TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), data_sinfo->dtype)); - } else if (attrs->axis.defined()) { + } else if (axis.defined()) { output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim)); } else { output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, /*ndim=*/1)); @@ -93,9 +114,24 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { } TVM_REGISTER_OP("relax.unique") - .set_attrs_type() - .set_num_inputs(1) + .set_num_inputs(6) .add_argument("x", "Tensor", "The input tensor") + .add_argument( + "sorted", "Tensor", + "Whether to sort the unique elements in ascending order before returning as output.") + .add_argument( + "return_index", "Tensor", + "Whether to return an additional tensor with indices for where elements in the unique " + "tensor come from the original input.") + .add_argument("return_inverse", "Tensor", + "Whether to return an additional tensor with indices for where elements in the " + "original input ended up in the returned unique list.") + .add_argument("return_counts", "Tensor", + "Whether to return an additional tensor with counts of each unique elements") + .add_argument( + "axis", "Tensor", + "The dimension to apply unique. If it is NullOpt, the unique values of the flattened input " + "are returned.") .set_attr("FInferStructInfo", InferStructInfoUnique) .set_attr("FCallPacked", "relax.run.unique"); diff --git a/src/relax/op/tensor/set.h b/src/relax/op/tensor/set.h index 83d2619e4d2c..a5c7ee85bfb2 100644 --- a/src/relax/op/tensor/set.h +++ b/src/relax/op/tensor/set.h @@ -24,16 +24,13 @@ #ifndef TVM_RELAX_OP_TENSOR_SET_H_ #define TVM_RELAX_OP_TENSOR_SET_H_ -#include - #include "../op_common.h" namespace tvm { namespace relax { -Expr unique(Expr x, bool sorted, bool return_index, bool return_inverse, bool return_counts, - Optional axis); - +Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue return_inverse, + PrimValue return_counts, Optional axis); } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py index 7b0b98fea976..c66a5729fd4d 100644 --- a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -26,13 +26,128 @@ from tvm.script import relax as R +@tvm.script.ir_module +class InputModule: + @R.function + def foo(x: R.Tensor(("m", "n"), "int64")): + y = R.unique(x, sorted=False) + y_sorted = R.unique(x) + return y, y_sorted + + def run_cpu(mod, func_name, *input): target = tvm.target.Target("llvm") - ex = relax.vm.build(mod, target) + ex = relax.build(mod, target) vm = relax.VirtualMachine(ex, tvm.cpu()) return vm[func_name](*input) +def test_unique(): + + # TODO(prakalp): also add test for compiling and running on cuda device. + data_numpy = np.random.randint(0, 16, (16, 16)) + data = tvm.nd.array(data_numpy) + result, result_sorted = run_cpu(InputModule, "foo", data) + + expected_output_sorted, indices = np.unique(data_numpy, return_index=True) + expected_output = [data_numpy.flatten()[index] for index in sorted(indices, reverse=True)] + + np.testing.assert_array_equal(expected_output_sorted, result_sorted.numpy()) + np.testing.assert_array_equal(expected_output, result.numpy()) + + +@tvm.script.ir_module +class PrintTest: + @R.function + def foo(x: R.Tensor((), "int32")): + # results have to be bound, but we don't use them + # TODO: We should allow calls whose results are not bound for side effects; + # it would be easy syntactic sugar to add. + p1 = R.print(x) + p2 = R.print(x, format="Number: {}") + t = (x, x) + p3 = R.print(t, format="Tuple: {}") + p4 = R.print(x, t) + p5 = R.print(x, x, format="Custom print: {} {}") + p6 = R.print(x, t, format="Another print: {} {}") + return x + + +def test_print(): + try: + stdout = sys.stdout + with tempfile.TemporaryFile(mode="w+") as test_out: + sys.stdout = test_out + run_cpu(PrintTest, "foo", tvm.nd.array(np.array(1).astype("int32"))) + test_out.seek(0) + printed_text = str(test_out.read()) + expected = "1\nNumber: 1\nTuple: (1, 1)\n1 (1, 1)\nCustom print: 1 1\nAnother print: 1 (1, 1)\n" + assert printed_text in expected, ("printed_text is ", printed_text) + finally: + sys.stdout = stdout + + +@tvm.script.ir_module +class AssertOpTest: + @R.function + def passes(x: R.Tensor((), "int32")): + p1 = R.assert_op(relax.const(True)) + return x + + @R.function + def pass_with_args(x: R.Tensor((), "int32")): + p1 = R.assert_op(relax.const(True), x, format="You won't see me") + return x + + @R.function + def simple_fail(x: R.Tensor((), "int32")): + p1 = R.assert_op(relax.const(False)) + return x + + @R.function + def fail_with_message(x: R.Tensor((), "int32")): + p1 = R.assert_op(relax.const(False), format="I failed...") + return x + + @R.function + def fail_with_args(x: R.Tensor((), "int32")): + # no format + p1 = R.assert_op(relax.const(False), [x, x]) + return x + + @R.function + def fail_with_formatted_message(x: R.Tensor((), "int32")): + p1 = R.assert_op(relax.const(False), x, format="Number: {}") + return x + + +def test_assert_op(): + def check_assertion_error(func_name, func_arg, expected_message): + passed = False + try: + run_cpu(AssertOpTest, func_name, func_arg) + passed = True + except TVMError as e: + # TVM will print out a TVMError that will contain the + # generated error at the bottom of a stack trace + assert "AssertionError" in e.args[0] + assert expected_message in e.args[0] + assert not passed + + run_cpu(AssertOpTest, "passes", tvm.nd.array(np.array(1).astype("int32"))) + run_cpu(AssertOpTest, "pass_with_args", tvm.nd.array(np.array(2).astype("int32"))) + check_assertion_error( + "simple_fail", tvm.nd.array(np.array(3).astype("int32")), "Assertion Failed" + ) + check_assertion_error( + "fail_with_message", tvm.nd.array(np.array(4).astype("int32")), "I failed..." + ) + check_assertion_error("fail_with_args", tvm.nd.array(np.array(5).astype("int32")), "5, 5") + check_assertion_error( + "fail_with_formatted_message", tvm.nd.array(np.array(6).astype("int32")), "Number: 6" + ) + + @tvm.script.ir_module class ShapeOfTest: @R.function diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index b458b290ec13..7724c8e761bf 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -213,8 +213,8 @@ def foo(x: R.Tensor((4, 4), "float32")): alloc = bb.emit(relax.op.builtin.alloc_tensor(relax.ShapeExpr((4, 4)), "float32", 0)) shape = bb.emit(relax.op.shape_of(alloc)) bb.emit_func_output(shape) - # todo(yongwww): comment this check because 0 was changed to R.prim_value(0) in the printed IR - # _check(foo, bb.get()["foo"]) + + _check(foo, bb.get()["foo"]) def test_symbolic_shape(): From d3a0e98b6d2b9a739fdc0d183f8f96873e9f7501 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 24 Feb 2023 04:16:47 -0500 Subject: [PATCH 68/81] [Unity][BYOC]Add relax backend pattern registry (#14106) * Add relax backend pattern registry * Add doc --- CMakeLists.txt | 1 + python/tvm/relax/backend/__init__.py | 20 +++ python/tvm/relax/backend/_ffi_api.py | 21 ++++ python/tvm/relax/backend/contrib/__init__.py | 20 +++ python/tvm/relax/backend/contrib/cutlass.py | 90 +++++++++++++ python/tvm/relax/backend/pattern_registry.py | 125 +++++++++++++++++++ python/tvm/relax/backend/patterns.py | 115 +++++++++++++++++ python/tvm/relax/dpl/pattern.py | 27 +--- src/relax/backend/pattern_registry.cc | 82 ++++++++++++ src/relax/backend/pattern_registry.h | 106 ++++++++++++++++ tests/python/relax/test_codegen_cutlass.py | 67 ++-------- 11 files changed, 598 insertions(+), 76 deletions(-) create mode 100644 python/tvm/relax/backend/__init__.py create mode 100644 python/tvm/relax/backend/_ffi_api.py create mode 100644 python/tvm/relax/backend/contrib/__init__.py create mode 100644 python/tvm/relax/backend/contrib/cutlass.py create mode 100644 python/tvm/relax/backend/pattern_registry.py create mode 100644 python/tvm/relax/backend/patterns.py create mode 100644 src/relax/backend/pattern_registry.cc create mode 100644 src/relax/backend/pattern_registry.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 18be118832ef..22e82e2fb74a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -295,6 +295,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/relax/transform/*.cc src/relax/backend/vm/*.cc src/relax/backend/task_extraction.cc + src/relax/backend/pattern_registry.cc src/relax/utils.cc ) diff --git a/python/tvm/relax/backend/__init__.py b/python/tvm/relax/backend/__init__.py new file mode 100644 index 000000000000..c3786591e310 --- /dev/null +++ b/python/tvm/relax/backend/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax backends""" + +from . import contrib +from .pattern_registry import get_pattern, get_patterns_with_prefix diff --git a/python/tvm/relax/backend/_ffi_api.py b/python/tvm/relax/backend/_ffi_api.py new file mode 100644 index 000000000000..d1378b2eacc2 --- /dev/null +++ b/python/tvm/relax/backend/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI API for Relax backend.""" + +import tvm._ffi + +tvm._ffi._init_api("relax.backend", __name__) diff --git a/python/tvm/relax/backend/contrib/__init__.py b/python/tvm/relax/backend/contrib/__init__.py new file mode 100644 index 000000000000..a094c97d24bf --- /dev/null +++ b/python/tvm/relax/backend/contrib/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""External backend codegen modules for Relax.""" + +from .cutlass import partition_for_cutlass diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py new file mode 100644 index 000000000000..20cf57a40a5c --- /dev/null +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pattern table for CUTLASS backend""" + +from tvm.relax import transform + +from ..pattern_registry import get_patterns_with_prefix, register_patterns +from ..patterns import make_fused_bias_activation_pattern, make_matmul_pattern + +register_patterns( + [ + ( + "cutlass.conv2d", + make_fused_bias_activation_pattern( + "relax.nn.conv2d", + with_bias=False, + activation=None, + ), + ), + ( + "cutlass.conv2d_bias_relu", + make_fused_bias_activation_pattern( + "relax.nn.conv2d", + with_bias=True, + activation="relax.nn.relu", + ), + ), + ( + "cutlass.matmul", + make_matmul_pattern( + with_bias=False, + ), + ), + ( + "cutlass.matmul_bias", + make_matmul_pattern( + with_bias=True, + ), + ), + ( + "cutlass.matmul_bias_relu", + make_matmul_pattern( + with_bias=True, + activation="relax.nn.relu", + ), + ), + ( + "cutlass.matmul_bias_gelu", + make_matmul_pattern( + with_bias=True, + activation="relax.nn.gelu", + ), + ), + ] +) + + +def partition_for_cutlass(mod): + """ + Partition the input module into CUTLASS-supported subgraphs. + + Parameters + ---------- + mod: tvm.IRModule + The IRModule to be partitioned. + + Returns + ------- + mod: tvm.IRModule + The resulting IRModule, containing partitioned subgraphs to be + compiled by the CUTLASS backend. + """ + + cutlass_patterns = get_patterns_with_prefix("cutlass") + return transform.FuseOpsByPattern(cutlass_patterns, annotate_codegen=True)(mod) diff --git a/python/tvm/relax/backend/pattern_registry.py b/python/tvm/relax/backend/pattern_registry.py new file mode 100644 index 000000000000..0016de0a50da --- /dev/null +++ b/python/tvm/relax/backend/pattern_registry.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pattern registry for BYOC backends""" + +from typing import List, Mapping, Optional, Tuple, Union + +import tvm +from tvm.relax.dpl import DFPattern +from tvm.runtime import Object + +from . import _ffi_api + + +@tvm._ffi.register_object("relax.backend.PatternRegistryEntry") +class PatternRegistryEntry(Object): + """ + An entry in the pattern registry. This represents a single pattern that + can be used to identify expressions that can be handled by external + backends, like CUTLASS and TensorRT. + + Parameters + ---------- + name: str + The name of pattern. Usually it starts with the name of backend, like 'cutlass.matmul'. + + pattern: DFPattern + The dataflow pattern that will be used to match expressions that can be handled + by external backends. + + arg_patterns: Mapping[str, DFPattern] + The mapping from arg name to its pattern. It can be used to extract arg expression + from match result. All DFPattern in this map should be part of the `pattern`. + """ + + name: str + pattern: DFPattern + arg_patterns: Mapping[str, DFPattern] + + def __init__(self, name: str, pattern: DFPattern, arg_patterns: Mapping[str, DFPattern]): + self.__init_handle_by_constructor__( + _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns # type: ignore + ) + + +Pattern = Union[ + PatternRegistryEntry, + Tuple[str, DFPattern], + Tuple[str, Tuple[DFPattern, Mapping[str, DFPattern]]], +] + + +def register_patterns(patterns: List[Pattern]): + """ + Register patterns which will be used to partition the DataflowBlock into + subgraphs that are supported by external backends. + + Parameters + ---------- + patterns: List[Pattern] + Patterns to be registered. Patterns that appear later in the list have + higher priority when partitioning DataflowBlock. + """ + entries = [] + for item in patterns: + if isinstance(item, PatternRegistryEntry): + entries.append(item) + elif isinstance(item, tuple): + name, pattern_or_tuple = item + if isinstance(pattern_or_tuple, tuple): + pattern, arg_patterns = pattern_or_tuple + else: + pattern, arg_patterns = pattern_or_tuple, {} + entries.append(PatternRegistryEntry(name, pattern, arg_patterns)) + else: + raise TypeError(f"Cannot register type {type(pattern)} as pattern") + _ffi_api.RegisterPatterns(entries) + + +def get_patterns_with_prefix(prefix: str) -> List[PatternRegistryEntry]: + """ + Get a list of patterns whose names startwith `prefix`. + + Parameters + ---------- + prefix: str + The prefix of pattern name. + + Returns + ------- + patterns: PatternRegistryEntry + Matched patterns, ordered by priority from high to low. + """ + return _ffi_api.GetPatternsWithPrefix(prefix) + + +def get_pattern(name: str) -> Optional[PatternRegistryEntry]: + """ + Find the pattern with a particular name. + + Parameters + ---------- + name: str + The pattern name. + + Returns + ------- + pattern: Optional[PatternRegistryEntry] + The matched pattern. Returns None if such pattern is not found. + """ + return _ffi_api.GetPattern(name) diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py new file mode 100644 index 000000000000..2f744af66002 --- /dev/null +++ b/python/tvm/relax/backend/patterns.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Common patterns used in BYOC""" + +from typing import Dict, Mapping, Tuple + +from tvm.relax.dpl.pattern import DFPattern, is_op, wildcard + + +def _with_bias_activation_pattern( + out: DFPattern, + args: Dict[str, DFPattern], + with_bias: bool = False, + activation: str = None, +) -> Tuple[DFPattern, Mapping[str, DFPattern]]: + if with_bias: + args["bias"] = bias = wildcard() + out = is_op("relax.add")(out, bias) + + if activation: + out = is_op(activation)(out) + + return out, args + + +def make_fused_bias_activation_pattern( + op_name: str, + with_bias: bool = False, + activation: str = None, +) -> Tuple[DFPattern, Mapping[str, DFPattern]]: + """ + A simple utility to create patterns for an operation fused with bias addition and activation. + + Parameters + ---------- + op_name: str + The name of a Relax op, such as "relax.nn.conv2d" + + with_bias: bool + Whether or not to include bias addition + + activation: str + The name of an activation Relax op, such as "relax.nn.relu" + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a fused operation + + args: Mapping[str, DFPattern] + The mapping from arg name to its pattern. It can be used to extract + arg expression from match result. + """ + lhs = wildcard() + rhs = wildcard() + args = {"lhs": lhs, "rhs": rhs} + out = is_op(op_name)(lhs, rhs) + + return _with_bias_activation_pattern(out, args, with_bias, activation) + + +def make_matmul_pattern( + with_bias: bool = False, + activation: str = None, + transposed_rhs: bool = False, +) -> Tuple[DFPattern, Mapping[str, DFPattern]]: + """ + Create pattern for matrix multiplication. + + Parameters + ---------- + with_bias: bool + Whether or not to include bias addition + + activation: str + The name of an activation Relax op, such as "relax.nn.relu" + + transposed_rhs: bool + Whether the right hand side of multiplication is transposed. + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a matrix multiplication. + + args: Mapping[str, DFPattern] + The mapping from arg name to its pattern. It can be used to extract + arg expression from match result. + """ + + lhs = wildcard() + rhs = wildcard() + args = {"lhs": lhs, "rhs": rhs} + + if transposed_rhs: + rhs = is_op("relax.permute_dims")(rhs) + + out = is_op("relax.matmul")(lhs, rhs) + + return _with_bias_activation_pattern(out, args, with_bias, activation) diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 44faa0c93a14..9e1963f7edfd 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -1046,17 +1046,6 @@ def _only_used_by( return ffi.only_used_by(lhs, rhs, index) # type: ignore -def _add_bias_activation_pattern(out, with_bias=False, activation=None): - if with_bias: - bias = wildcard() - out = is_op("relax.add")(out, bias) - - if activation: - return is_op(activation)(out) - - return out - - def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None): """ A simple utility to create patterns for an operation fused with bias addition and activation. @@ -1081,15 +1070,11 @@ def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None rhs = wildcard() out = is_op(op_name)(lhs, rhs) - return _add_bias_activation_pattern(out, with_bias, activation) - + if with_bias: + bias = wildcard() + out = is_op("relax.add")(out, bias) -def make_matmul_pattern(with_bias=False, activation=None, transposed_b=False): - lhs = wildcard() - if transposed_b: - rhs = is_op("relax.permute_dims")(wildcard()) - else: - rhs = wildcard() - out = is_op("relax.matmul")(lhs, rhs) + if activation: + return is_op(activation)(out) - return _add_bias_activation_pattern(out, with_bias, activation) + return out diff --git a/src/relax/backend/pattern_registry.cc b/src/relax/backend/pattern_registry.cc new file mode 100644 index 000000000000..3ca797336588 --- /dev/null +++ b/src/relax/backend/pattern_registry.cc @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./pattern_registry.h" + +#include "../../support/utils.h" + +namespace tvm { +namespace relax { +namespace backend { + +PatternRegistryEntry::PatternRegistryEntry(String name, DFPattern pattern, + Map arg_patterns) { + ObjectPtr n = make_object(); + n->name = std::move(name); + n->pattern = std::move(pattern); + n->arg_patterns = std::move(arg_patterns); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(PatternRegistryEntryNode); + +static std::vector* GetRegistryTable() { + static std::vector table; + return &table; +} + +void RegisterPatterns(Array entries) { + auto* table = GetRegistryTable(); + for (const auto& entry : entries) { + table->push_back(entry); + } +} + +Array GetPatternsWithPrefix(const String& prefix) { + auto* table = GetRegistryTable(); + Array result; + for (auto it = table->rbegin(); it != table->rend(); ++it) { + if (support::StartsWith((*it)->name, prefix.data())) { + result.push_back(*it); + } + } + return result; +} + +Optional GetPattern(const String& pattern_name) { + auto* table = GetRegistryTable(); + for (auto it = table->rbegin(); it != table->rend(); ++it) { + if ((*it)->name == pattern_name) { + return *it; + } + } + return NullOpt; +} + +TVM_REGISTER_GLOBAL("relax.backend.PatternRegistryEntry") + .set_body_typed([](String name, DFPattern pattern, Map arg_patterns) { + return PatternRegistryEntry(name, pattern, arg_patterns); + }); +TVM_REGISTER_GLOBAL("relax.backend.RegisterPatterns").set_body_typed(RegisterPatterns); +TVM_REGISTER_GLOBAL("relax.backend.GetPatternsWithPrefix").set_body_typed(GetPatternsWithPrefix); +TVM_REGISTER_GLOBAL("relax.backend.GetPattern").set_body_typed(GetPattern); + +} // namespace backend +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/pattern_registry.h b/src/relax/backend/pattern_registry.h new file mode 100644 index 000000000000..2e199a2bb1db --- /dev/null +++ b/src/relax/backend/pattern_registry.h @@ -0,0 +1,106 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/backend/contrib/pattern_registry.h + * \brief Functions related to registering and retrieving patterns for + * functions handled by backends. + */ +#ifndef TVM_RELAX_BACKEND_PATTERN_REGISTRY_H_ +#define TVM_RELAX_BACKEND_PATTERN_REGISTRY_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace relax { +namespace backend { + +/*! + * \brief An entry in the pattern registry. This represents a single pattern that + * can be used to identify expressions that can be handled by external + * backends, like CUTLASS and TensorRT. + */ +class PatternRegistryEntryNode : public Object { + public: + /*! + * \brief The name of pattern. Usually it starts with the name of backend, like + * 'cutlass.matmul'. + */ + String name; + /*! + * \brief The dataflow pattern that will be used to match expressions that can + * be handled by external backends. + */ + DFPattern pattern; + /*! + * \brief The mapping from arg name to its pattern. It can be used to extract + * arg expression from match result. All DFPattern in this map should be part of + * the `pattern`. + */ + Map arg_patterns; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("pattern", &pattern); + v->Visit("arg_patterns", &arg_patterns); + } + + static constexpr const char* _type_key = "relax.backend.PatternRegistryEntry"; + TVM_DECLARE_FINAL_OBJECT_INFO(PatternRegistryEntryNode, Object); +}; + +class PatternRegistryEntry : public ObjectRef { + public: + PatternRegistryEntry(String name, DFPattern pattern, Map arg_patterns); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternRegistryEntry, ObjectRef, + PatternRegistryEntryNode); +}; + +/*! + * \brief Register patterns which will be used to partition the DataflowBlock + * into subgraphs that are supported by external backends. + * \param patterns Patterns to be registered. Patterns that appear later in the list have + * higher priority when partitioning DataflowBlock. + */ +void RegisterPatterns(Array entries); + +/*! + * \brief Find patterns whose name starts with a particular prefix. + * \param prefx The pattern name prefix. + * \return Matched patterns, ordered by priority from high to low. + */ +Array GetPatternsWithPrefix(const String& prefix); + +/*! + * \brief Find the pattern with a particular name. + * \param name The pattern name. + * \return The matched pattern. NullOpt if not found. + */ +Optional GetPattern(const String& name); + +} // namespace backend +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_BACKEND_PATTERN_REGISTRY_H_ diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index 5556d1e5d9a8..673155342cbf 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -23,7 +23,7 @@ import tvm import tvm.testing from tvm import relax, relay -from tvm.relax.dpl import make_fused_bias_activation_pattern, make_matmul_pattern +from tvm.relax.backend import get_patterns_with_prefix from tvm.script import relax as R @@ -219,7 +219,11 @@ def main( pytestmark = [cutlass_enabled] -def get_result_with_relax_cutlass_offload(mod, patterns: List[Tuple], *args): +def get_result_with_relax_cutlass_offload(mod, *args): + patterns = [(entry.name, entry.pattern) for entry in get_patterns_with_prefix("cutlass")] + + assert len(patterns) != 0, "Cannot find cutlass patterns" + seq = tvm.transform.Sequential( [ relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True), @@ -243,15 +247,7 @@ def test_conv2d_offload(): weight = np.random.randn(32, 3, 3, 16).astype("float16") bias = np.random.randn(1, 1, 1, 32).astype("float16") - patterns = [ - ( - "cutlass.conv2d_bias_relu", - make_fused_bias_activation_pattern( - "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu" - ), - ) - ] - out = get_result_with_relax_cutlass_offload(Conv2dBiasReLU, patterns, data, weight, bias) + out = get_result_with_relax_cutlass_offload(Conv2dBiasReLU, data, weight, bias) ref_relay_expr = get_relay_conv2d_bias_relu(data.shape, weight.shape) ref = get_relay_ref(ref_relay_expr, data, weight, bias) @@ -327,17 +323,8 @@ def matmul_bias(matmul_size, target_dtype): def test_matmul_offload(matmul_x, matmul_y): x, y = matmul_x, matmul_y - patterns = [ - ( - "cutlass.matmul", - make_matmul_pattern( - with_bias=False, - ), - ), - ] - mod = get_relax_matmul_module(x, y) - out = get_result_with_relax_cutlass_offload(mod, patterns, x, y) + out = get_result_with_relax_cutlass_offload(mod, x, y) ref_relay_expr = get_relay_matmul(x.shape, y.shape[::-1]) ref = get_relay_ref(ref_relay_expr, x, y.transpose()) @@ -347,16 +334,8 @@ def test_matmul_offload(matmul_x, matmul_y): def test_matmul_bias_offload(matmul_x, matmul_y, matmul_bias): x, y, bias = matmul_x, matmul_y, matmul_bias - patterns = [ - ( - "cutlass.matmul_bias", - make_matmul_pattern( - with_bias=True, - ), - ), - ] mod = get_relax_matmul_module(x, y, with_bias=True) - out = get_result_with_relax_cutlass_offload(mod, patterns, x, y, bias) + out = get_result_with_relax_cutlass_offload(mod, x, y, bias) ref_relay_expr = get_relay_matmul_bias(x.shape, y.shape[::-1]) ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias) @@ -367,17 +346,8 @@ def test_matmul_bias_offload(matmul_x, matmul_y, matmul_bias): def test_matmul_bias_relu_offload(matmul_x, matmul_y, matmul_bias): x, y, bias = matmul_x, matmul_y, matmul_bias - patterns = [ - ( - "cutlass.matmul_bias_relu", - make_matmul_pattern( - with_bias=True, - activation="relax.nn.relu", - ), - ), - ] mod = get_relax_matmul_module(x, y, with_bias=True, activation=R.nn.relu) - out = get_result_with_relax_cutlass_offload(mod, patterns, x, y, bias) + out = get_result_with_relax_cutlass_offload(mod, x, y, bias) ref_relay_expr = get_relay_matmul_bias_relu(x.shape, y.shape[::-1]) ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias) @@ -388,17 +358,8 @@ def test_matmul_bias_relu_offload(matmul_x, matmul_y, matmul_bias): def test_matmul_bias_gelu_offload(matmul_x, matmul_y, matmul_bias): x, y, bias = matmul_x, matmul_y, matmul_bias - patterns = [ - ( - "cutlass.matmul_bias_gelu", - make_matmul_pattern( - with_bias=True, - activation="relax.nn.gelu", - ), - ), - ] mod = get_relax_matmul_module(x, y, with_bias=True, activation=R.nn.gelu) - out = get_result_with_relax_cutlass_offload(mod, patterns, x, y, bias) + out = get_result_with_relax_cutlass_offload(mod, x, y, bias) ref_relay_expr = get_relay_matmul_bias_gelu(x.shape, y.shape[::-1]) ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias) @@ -411,11 +372,7 @@ def test_kernel_sharing(): weight1_np = np.random.randn(16, 3, 3, 16).astype("float16") weight2_np = np.random.randn(16, 3, 3, 16).astype("float16") - pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None) - - out = get_result_with_relax_cutlass_offload( - Conv2dx2, [("cutlass.conv2d", pat)], data_np, weight1_np, weight2_np - ) + out = get_result_with_relax_cutlass_offload(Conv2dx2, data_np, weight1_np, weight2_np) relay_expr = get_relay_conv2d_relu_x2(data_np.shape, weight1_np.shape) ref = get_relay_ref(relay_expr, data_np, weight1_np, weight2_np) From cc5292c6cf812b8c687550fac51f39213beec5d1 Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Fri, 24 Feb 2023 21:23:36 +0800 Subject: [PATCH 69/81] [Unity] Update tests again to adapt to latest TVMScript syntax (#14115) * finished * fix * rollback merge_composite_functions --- python/tvm/relax/block_builder.py | 4 +- python/tvm/relax/transform/transform.py | 12 +++--- src/script/printer/relax/tir.cc | 7 +--- .../test_backend_transform_shape_lower.py | 8 ++-- .../test_transform_canonicalize_bindings.py | 12 +++--- .../test_transform_legalize_ops_manipulate.py | 8 ++-- .../relax/test_transform_legalize_ops_nn.py | 34 ++++++++--------- .../test_transform_remove_unused_funcs.py | 32 ++++++++-------- .../relax/test_tvmscript_printer_relax.py | 37 ++++++++++--------- tests/python/relax/test_vm_codegen_only.py | 4 +- .../python/unittest/test_arith_detect_cse.py | 6 +-- 11 files changed, 80 insertions(+), 84 deletions(-) diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index 783700847909..f219641c81df 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -516,8 +516,8 @@ def te_func(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_compute: T.handle) -> None: # function attr dict T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [n, m], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [n, m], dtype="float32") compute = T.match_buffer(var_compute, [128, 128], dtype="float32") diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 7044314e8581..263195a10027 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -422,9 +422,9 @@ def main( @T.prim_func def add( - A: T.Buffer[(2, 3), "float32"], - B: T.Buffer[(2, 3), "float32"], - T_add: T.Buffer[(2, 3), "float32"], + A: T.Buffer((2, 3), "float32"), + B: T.Buffer((2, 3), "float32"), + T_add: T.Buffer((2, 3), "float32"), ): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(2, 3): @@ -436,9 +436,9 @@ def add( @T.prim_func def multiply( - A: T.Buffer[(2, 3), "float32"], - B: T.Buffer[(2, 3), "float32"], - T_multiply: T.Buffer[(2, 3), "float32"], + A: T.Buffer((2, 3), "float32"), + B: T.Buffer((2, 3), "float32"), + T_multiply: T.Buffer((2, 3), "float32"), ): T.func_attr({"tir.noalias": True}) for ax0, ax1 in T.grid(2, 3): diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc index 2c8bb0f1da6c..9aed11895ac1 100644 --- a/src/script/printer/relax/tir.cc +++ b/src/script/printer/relax/tir.cc @@ -53,12 +53,7 @@ Doc PrintTIRVar(tir::Var n, ObjectPath n_p, IRDocsifier d) { } IdDoc var = d->Define(n, GetRef(f), n->name_hint.empty() ? "v" : n->name_hint); var->source_paths.push_back(n_p); - f->stmts.push_back(AssignDoc(var, - TIR(d, "Var")->Call({ - LiteralDoc::Str(var->name, n_p->Attr("name_hint")), - LiteralDoc::DataType(n->dtype, n_p->Attr("dtype")), - }), - NullOpt)); + f->stmts.push_back(AssignDoc(var, TIR(d, DType2Str(n->dtype))->Call({}), NullOpt)); } if (Optional doc = d->GetVarDoc(n)) { return doc.value(); diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py index 9c11b352c831..bd53bf8aecab 100644 --- a/tests/python/relax/test_backend_transform_shape_lower.py +++ b/tests/python/relax/test_backend_transform_shape_lower.py @@ -164,8 +164,8 @@ class Before: def main( x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None) ) -> R.Shape(ndim=3): - n = T.Var("n", "int64") - k = T.Var("k", "int64") + n = T.int64() + k = T.int64() z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None)) return R.shape([k + 1, m, 2]) @@ -185,8 +185,8 @@ def shape_func(H: T.Buffer(T.int64(4), "int64")): def main( x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None) ) -> R.Shape(ndim=3): - n = T.Var("n", "int64") - k = T.Var("k", "int64") + n = T.int64() + k = T.int64() shape_heap = R.call_builtin_with_ctx( "vm.builtin.alloc_shape_heap", [R.prim_value(4)], diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index 4694e98973f4..086c316ae817 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -142,7 +142,7 @@ class TestMatchCast: @R.function def main(x: R.Tensor): q = x - m, n = T.var("int64"), T.var("int64") + m, n = T.int64(), T.int64() z = R.match_cast(q, R.Tensor((m, n))) w = z return w @@ -153,7 +153,7 @@ class Expected: def main(x: R.Tensor): q = x # can't get rid of z because its shape_ is different from x's - m, n = T.var("int64"), T.var("int64") + m, n = T.int64(), T.int64() z = R.match_cast(x, R.Tensor((m, n))) w = z return z @@ -167,7 +167,7 @@ def test_same_shape(): class TestSameShape: @R.function def main(x: R.Tensor(("m", "n"), "float32")): - m, n = T.var("int64"), T.var("int64") + m, n = T.int64(), T.int64() y = x # trivial check z = R.match_cast(x, R.Tensor((m, n), "float32")) @@ -179,7 +179,7 @@ def main(x: R.Tensor(("m", "n"), "float32")): class Expected: @R.function def main(x: R.Tensor(("m", "n"), "float32")): - m, n = T.var("int64"), T.var("int64") + m, n = T.int64(), T.int64() y = x # canonicalized into a var binding z = x @@ -198,7 +198,7 @@ class TestChangeShape: def main(x: R.Tensor(("m", "n"))): y = x # not trivial: introduces new shape vars - o, p = T.var("int64"), T.var("int64") + o, p = T.int64(), T.int64() z = R.match_cast(x, R.Tensor((o, p))) w = z q = R.add(w, y) @@ -209,7 +209,7 @@ class Expected: @R.function def main(x: R.Tensor(("m", "n"))): y = x - o, p = T.var("int64"), T.var("int64") + o, p = T.int64(), T.int64() z = R.match_cast(x, R.Tensor((o, p))) w = z # the shape_ field on q will need to be updated diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index 8743261ee71e..7ae0eb359a6d 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -802,7 +802,7 @@ def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1, 3), "float32")) -> R.Te return gv @T.prim_func - def collapse_sum(rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)), "float32"], rxplaceholder_red: T.Buffer[(T.int64(1), T.int64(3)), "float32"]): + def collapse_sum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), rxplaceholder_red: T.Buffer((T.int64(1), T.int64(3)), "float32")): T.func_attr({"tir.noalias": True}) for i0, i1, i2 in T.grid(T.int64(1), T.int64(3), T.int64(2)): with T.block("rxplaceholder_red"): @@ -825,7 +825,7 @@ def test_collapse_sum_like_symbolic(): class CollapseSumLike: @R.function def main(x: R.Tensor(("a", "b", "a"), "float32"), y: R.Tensor(("b", 1), "float32")) -> R.Tensor(("b", 1), "float32"): - b = T.var("int64") + b = T.int64() gv: R.Tensor((b, 1), "float32") = R.collapse_sum_like(x, y) return gv @@ -855,7 +855,7 @@ def main( return gv @T.prim_func - def collapse_sum(rxplaceholder: T.Buffer[(T.int64(3), T.int64(2), T.int64(3)), "float32"], rxplaceholder_red: T.Buffer[(T.int64(2), T.int64(1)), "float32"]): + def collapse_sum(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32"), rxplaceholder_red: T.Buffer((T.int64(2), T.int64(1)), "float32")): T.func_attr({"tir.noalias": True}) for ax0, ax1, k0, k2 in T.grid(T.int64(2), T.int64(1), T.int64(3), T.int64(3)): with T.block("rxplaceholder_red"): @@ -878,7 +878,7 @@ def test_collapse_sum_to_symbolic(): class CollapseSumTo: @R.function def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("b", 1), "float32"): - b = T.var("int64") + b = T.int64() gv: R.Tensor((b, 1), "float32") = R.collapse_sum_to(x, (b, 1)) return gv diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 07d414980e30..698ad2727456 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -868,7 +868,7 @@ def main(x: R.Tensor((2, 3, 16, 32), dtype="float32")) -> R.Tensor((2, 3, 16, 32 return gv @T.prim_func - def log_softmax(rxplaceholder: T.Buffer[(T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32"], compute: T.Buffer[(T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32"],): + def log_softmax(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32"),): T.func_attr({"tir.noalias": True}) T_softmax_maxelem = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32") compute_1 = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32") @@ -907,9 +907,9 @@ def test_log_softmax_symbolic(): class LogSoftmax: @R.function def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", "b", "c"), "float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() gv: R.Tensor((a, b, c), "float32") = R.nn.log_softmax(x) return gv @@ -917,9 +917,9 @@ def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", "b", "c"), " class Expected: @R.function def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor(("a", "b", "c"), dtype="float32"): - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() # block 0 gv = R.call_tir(log_softmax, (x,), R.Tensor((a, b, c), dtype="float32")) return gv @@ -927,9 +927,9 @@ def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor(("a", "b", " @T.prim_func def log_softmax(var_rxplaceholder: T.handle, var_compute: T.handle): T.func_attr({"tir.noalias": True}) - a = T.var("int64") - b = T.var("int64") - c = T.var("int64") + a = T.int64() + b = T.int64() + c = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c], dtype="float32") compute = T.match_buffer(var_compute, [a, b, c], dtype="float32") T_softmax_maxelem = T.alloc_buffer([a, b], dtype="float32") @@ -980,7 +980,7 @@ def main(x: R.Tensor((3,), dtype="float32"), y: R.Tensor((3,), dtype="float32")) return gv @T.prim_func - def cross_entropy_with_logits(rxplaceholder: T.Buffer[T.int64(3), "float32"], rxplaceholder_1: T.Buffer[T.int64(3), "float32"], T_multiply: T.Buffer[(), "float32"]): + def cross_entropy_with_logits(rxplaceholder: T.Buffer(T.int64(3), "float32"), rxplaceholder_1: T.Buffer(T.int64(3), "float32"), T_multiply: T.Buffer((), "float32")): T.func_attr({"tir.noalias": True}) T_multiply_1 = T.alloc_buffer([T.int64(3)], dtype="float32") T_multiply_red = T.alloc_buffer([], dtype="float32") @@ -1026,7 +1026,7 @@ def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float3 return gv @T.prim_func - def cross_entropy_with_logits(rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)), "float32"], rxplaceholder_1: T.Buffer[(T.int64(2), T.int64(3)), "float32"], T_divide: T.Buffer[(), "float32"]): + def cross_entropy_with_logits(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_divide: T.Buffer((), "float32")): T.func_attr({"tir.noalias": True}) T_multiply = T.alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") T_multiply_red = T.alloc_buffer([], dtype="float32") @@ -1067,8 +1067,8 @@ def test_cross_entropy_with_logits_batch_symbolic(): class CrossEntropyWithLogits: @R.function def main(x: R.Tensor(("n", "m"), "float32"), y: R.Tensor(("n", "m"), "float32")) -> R.Tensor(None, "float32", ndim=2): - n = T.var("int64") - m = T.var("int64") + n = T.int64() + m = T.int64() gv: R.Tensor((), "float32") = R.nn.cross_entropy_with_logits(x, y) return gv @@ -1080,10 +1080,10 @@ def main(x: R.Tensor(("n", "m"), dtype="float32"), y: R.Tensor(("n", "m"), dtype return gv @T.prim_func - def cross_entropy_with_logits(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, T_divide: T.Buffer[(), "float32"]): + def cross_entropy_with_logits(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, T_divide: T.Buffer((), "float32")): T.func_attr({"tir.noalias": True}) - m = T.var("int64") - n = T.var("int64") + m = T.int64() + n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [n, m], dtype="float32") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [n, m], dtype="float32") T_multiply = T.alloc_buffer([n, m], dtype="float32") diff --git a/tests/python/relax/test_transform_remove_unused_funcs.py b/tests/python/relax/test_transform_remove_unused_funcs.py index 8a57b38508d0..fa07065ef043 100644 --- a/tests/python/relax/test_transform_remove_unused_funcs.py +++ b/tests/python/relax/test_transform_remove_unused_funcs.py @@ -34,9 +34,9 @@ def test_unused_relax_func(): class InputModule: @T.prim_func def tir_add( - x: T.Buffer[(16, 16), "float32"], - y: T.Buffer[(16, 16), "float32"], - z: T.Buffer[(16, 16), "float32"], + x: T.Buffer((16, 16), "float32"), + y: T.Buffer((16, 16), "float32"), + z: T.Buffer((16, 16), "float32"), ) -> None: for i, j in T.grid(16, 16): with T.block("add"): @@ -68,9 +68,9 @@ def test_unused_relax_func_custom_entry_func(): class InputModule: @T.prim_func def tir_add( - x: T.Buffer[(16, 16), "float32"], - y: T.Buffer[(16, 16), "float32"], - z: T.Buffer[(16, 16), "float32"], + x: T.Buffer((16, 16), "float32"), + y: T.Buffer((16, 16), "float32"), + z: T.Buffer((16, 16), "float32"), ) -> None: for i, j in T.grid(16, 16): with T.block("add"): @@ -105,9 +105,9 @@ def test_unused_relax_func_symbolic_shape(): class InputModule: @T.prim_func def tir_add( - x: T.Buffer[(16, 16), "float32"], - y: T.Buffer[(16, 16), "float32"], - z: T.Buffer[(16, 16), "float32"], + x: T.Buffer((16, 16), "float32"), + y: T.Buffer((16, 16), "float32"), + z: T.Buffer((16, 16), "float32"), ) -> None: for i, j in T.grid(16, 16): with T.block("add"): @@ -121,7 +121,7 @@ def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "flo @R.function def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")): - m, k = T.var("int64"), T.var("int64") + m, k = T.int64(), T.int64() gv0 = R.call_tir(tir_add, (x, w), R.Tensor((m + 1, k), dtype="float32")) return gv0 @@ -139,9 +139,9 @@ def test_unused_prim_func(): class InputModule: @T.prim_func def unused_func( - x: T.Buffer[(16, 16), "float32"], - y: T.Buffer[(16, 16), "float32"], - z: T.Buffer[(16, 16), "float32"], + x: T.Buffer((16, 16), "float32"), + y: T.Buffer((16, 16), "float32"), + z: T.Buffer((16, 16), "float32"), ) -> None: T.func_attr({"global_symbol": "tir_unused"}) for i, j in T.grid(16, 16): @@ -175,9 +175,9 @@ def test_multiple_unused_funcs(): class InputModule: @T.prim_func def unused_func1( - x: T.Buffer[(16, 16), "float32"], - y: T.Buffer[(16, 16), "float32"], - z: T.Buffer[(16, 16), "float32"], + x: T.Buffer((16, 16), "float32"), + y: T.Buffer((16, 16), "float32"), + z: T.Buffer((16, 16), "float32"), ) -> None: T.func_attr({"global_symbol": "tir_unused"}) for i, j in T.grid(16, 16): diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index db90c66422d0..464591f2592b 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -17,6 +17,7 @@ # pylint: disable=missing-docstring import tvm import pytest +import tvm.testing from tvm import IRModule, relax, tir from tvm.script import relax as R @@ -99,7 +100,7 @@ def test_shape_struct_info_2(): _assert_print( obj, """ -a = T.Var("a", "int64") +a = T.int64() R.Shape([1, a, 3])""", ) @@ -112,7 +113,7 @@ def test_tensor_struct_info(): _assert_print( obj, """ -a = T.Var("a", "int64") +a = T.int64() R.Tensor((1, a, 3), dtype="float32") """, ) @@ -134,7 +135,7 @@ def test_tuple_struct_info(): _assert_print( obj, """ -a = T.Var("a", "int64") +a = T.int64() R.Tuple(R.Prim("float32"), R.Object, R.Shape([1, a, 3])) """, ) @@ -155,7 +156,7 @@ def test_func_struct_info(): _assert_print( obj, """ -a = T.Var("a", "int64") +a = T.int64() R.Callable((R.Prim("float32"), R.Object, R.Shape([1, a, 3])), R.Tensor((1, 2, 3), dtype="float32")) """, ) @@ -226,7 +227,7 @@ def test_var(): _assert_print( obj, """ -x = T.Var("x", "int64") +x = T.int64() a: R.Tensor((1, x, 3), dtype="float32") a""", ) @@ -237,7 +238,7 @@ def test_dataflow_var(): _assert_print( obj, """ -x = T.Var("x", "int64") +x = T.int64() a: R.Tensor((1, x, 3), dtype="float32") a""", ) @@ -254,11 +255,11 @@ def test_tuple(): _assert_print( obj, """ -x = T.Var("x", "int64") +x = T.int64() a: R.Tensor((1, x, 3), dtype="float32") -y = T.Var("y", "int64") +y = T.int64() b: R.Tensor((1, y, 3), dtype="float32") -z = T.Var("z", "int64") +z = T.int64() c: R.Tensor((1, z, 3), dtype="float32") (a, b, c) """, @@ -279,11 +280,11 @@ def test_tuple_get_item(): _assert_print( obj, """ -x = T.Var("x", "int64") +x = T.int64() a: R.Tensor((1, x, 3), dtype="float32") -y = T.Var("y", "int64") +y = T.int64() b: R.Tensor((1, y, 3), dtype="float32") -z = T.Var("z", "int64") +z = T.int64() c: R.Tensor((1, z, 3), dtype="float32") (a, b, c)[0] """, @@ -302,7 +303,7 @@ def test_call(): _assert_print( obj, """ -x = T.Var("x", "int64") +x = T.int64() a: R.Tensor((1, x, 3), dtype="float32") R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=R.shape([x])) """, @@ -330,7 +331,7 @@ def test_seq_expr(): _assert_print( obj, """ -x = T.Var("x", "int64") +x = T.int64() a: R.Tensor((1, x, 3), dtype="float32") with R.dataflow(): b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a) @@ -356,7 +357,7 @@ def test_binding_block(): _assert_print( obj, """ -x = T.Var("x", "int64") +x = T.int64() a: R.Tensor((1, x, 3), dtype="float32") b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a) c: R.Tensor((1, x, 3), dtype="float32") = R.sin(b) @@ -379,7 +380,7 @@ def test_dataflow_block(): _assert_print( obj, """ -x = T.Var("x", "int64") +x = T.int64() a: R.Tensor((1, x, 3), dtype="float32") with R.dataflow(): b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a) @@ -401,7 +402,7 @@ def test_match_cast(): _assert_print( obj, """ -x = T.Var("x", "int64") +x = T.int64() a: R.Tensor((1, x, 3), dtype="float32") b: R.Tensor((1, 5, 3), dtype="float32") = R.match_cast(a, R.Tensor((1, 5, 3), dtype="float32")) """, @@ -417,7 +418,7 @@ def test_var_binding(): _assert_print( obj, """ -x = T.Var("x", "int64") +x = T.int64() a: R.Tensor((1, x, 3), dtype="float32") b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a) """, diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py index 679641de13be..b9904429f3b8 100644 --- a/tests/python/relax/test_vm_codegen_only.py +++ b/tests/python/relax/test_vm_codegen_only.py @@ -194,8 +194,8 @@ class TestVMShapeCheck: @R.function def main(x: R.Tensor(["n", "m"], "float32")) -> R.Shape(ndim=3): R.func_attr({"global_symbol": "main"}) - n = T.Var("n", "int64") - k = T.Var("k", "int64") + n = T.int64() + k = T.int64() shape_heap = R.call_builtin_with_ctx( "vm.builtin.alloc_shape_heap", [R.prim_value(3)], diff --git a/tests/python/unittest/test_arith_detect_cse.py b/tests/python/unittest/test_arith_detect_cse.py index eba0920cb2da..dd7362ff1b7c 100644 --- a/tests/python/unittest/test_arith_detect_cse.py +++ b/tests/python/unittest/test_arith_detect_cse.py @@ -20,9 +20,9 @@ def test_detect_cs(): - x = T.Var("x", dtype="int32") - y = T.Var("y", dtype="int32") - z = T.Var("z", dtype="int32") + x = T.int32() + y = T.int32() + z = T.int32() c = T.floor(x + y + 0.5) + x + z * (T.floor(x + y + 0.5)) m = tvm.arith.detect_common_subexpr(c, 2) assert c.a.a in m From cfce06f0736cd89c9f17eb8525e3f1f545f622bb Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Fri, 24 Feb 2023 23:27:25 +0800 Subject: [PATCH 70/81] [Unity][Fix] Fix bug in MergeCompositeFunctions (#14117) Currently `MergeCompositeFunctions` will modify the map while iterating over it, and that makes tests/python/relax/test_transform_merge_composite_functions.py does not pass. This PR fixes this bug. --- src/relax/transform/merge_composite_functions.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc index db73392b02e6..609dd173f21f 100644 --- a/src/relax/transform/merge_composite_functions.cc +++ b/src/relax/transform/merge_composite_functions.cc @@ -324,13 +324,17 @@ IRModule MergeCompositeFunctions(IRModule mod) { auto new_mod = MakeGroupedFunctions(mod, group_map); CompositeInliner inliner(mod); + std::vector> to_update; for (const auto& [gvar, func] : new_mod->functions) { if (func->GetAttr(attr::kCodegen)) { auto new_func = inliner.Run(Downcast(func)); new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol, gvar->name_hint); - new_mod->Update(gvar, new_func); + to_update.emplace_back(gvar, new_func); } } + for (const auto& [gvar, func] : to_update) { + new_mod->Update(gvar, func); + } // TODO(@tvm-team): Implicit pass dependency. Revisit when we have a better way to handle this. return RemoveUnusedFunctions(new_mod, {"main"}); } From 82578c394c17ef3aeced530fbcbd18ee66ce6425 Mon Sep 17 00:00:00 2001 From: Chaofan Lin <1713833595@qq.com> Date: Sat, 25 Feb 2023 14:08:12 +0800 Subject: [PATCH 71/81] [Unity][BlockBuilder] Add `name_hint` argument for `emit` and `emit_output` (#14126) This PR adds `name_hint` argument for `emit` and `emit_output` API of Relax blockbuilder. The argument exists in the C++ side but not exposed to Python side (So user who use the Python bb.emit will let `name_hint` be `""` by default). Co-authored-by: Yixin Dong --- python/tvm/relax/block_builder.py | 19 ++++++++++--------- src/relax/ir/block_builder.cc | 11 ++++++----- tests/python/relax/test_blockbuilder.py | 16 ++++++++++++++++ 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index f219641c81df..3421bd4d0982 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -365,7 +365,7 @@ def dataflow(self) -> DataflowScope: """ return DataflowScope(self) - def emit(self, expr: Expr) -> Var: + def emit(self, expr: Expr, name_hint: str = "") -> Var: """Emit an expr. This infers the shape and type of the expr, create a variable, and bind the expr to the variable. @@ -375,12 +375,15 @@ def emit(self, expr: Expr) -> Var: expr : tvm.relax.Expr The Expr to be emitted. + name_hint : str + Name hint for the bound variable. + Returns ------- ret : tvm.relax.Var A newly created variable that gets bound to the input expr. """ - return _ffi_api.BlockBuilderEmit(self, expr) # type: ignore + return _ffi_api.BlockBuilderEmit(self, expr, name_hint) # type: ignore def call_te(self, func: Callable, *args: Any, **kwargs: Any) -> Expr: """Generate a call node according to the te function. @@ -601,7 +604,7 @@ def match_cast(self, value: Expr, struct_info: StructInfo) -> Var: """ return _ffi_api.BlockBuilderEmitMatchCast(self, value, struct_info) # type: ignore - def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None: + def emit_output(self, output: Union[Expr, Tuple, List[Expr]], name_hint: str = "") -> Var: """Emit output for the current dataflow block or function. Parameters @@ -609,6 +612,9 @@ def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None: output : Expr | Tuple | List[Expr] The output of the current block/function. + name_hint : str + Name hint for the bound variable. + Returns ------- ret : tvm.relax.Var @@ -616,7 +622,7 @@ def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None: """ if isinstance(output, (list, tuple)): output = Tuple(output) - return _ffi_api.BlockBuilderEmitOutput(self, output) # type: ignore + return _ffi_api.BlockBuilderEmitOutput(self, output, name_hint) # type: ignore def emit_func_output( self, @@ -633,11 +639,6 @@ def emit_func_output( params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional The parameters of the function to be built. If params is None, it means the params have been initialized in the function with scope. - - Returns - ------- - ret : tvm.relax.Var - The return variable which gets bound to the output. """ if self._is_emit_func_output_called: raise RuntimeError("emit_func_output must be called exactly once in a relax function.") diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 5976cbb3f441..ac92114ef9cb 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -899,9 +899,10 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderEndBlock") TVM_REGISTER_GLOBAL("relax.BlockBuilderNormalize") .set_body_method(&BlockBuilderNode::Normalize); -TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit").set_body_typed([](BlockBuilder builder, Expr expr) { - return builder->Emit(expr); -}); +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit") + .set_body_typed([](BlockBuilder builder, Expr expr, String name_hint) { + return builder->Emit(expr, name_hint); + }); TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast") .set_body_typed([](BlockBuilder builder, Expr value, StructInfo struct_info) { @@ -909,8 +910,8 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast") }); TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput") - .set_body_typed([](BlockBuilder builder, const Expr& output) { - return builder->EmitOutput(output); + .set_body_typed([](BlockBuilder builder, const Expr& output, String name_hint) { + return builder->EmitOutput(output, name_hint); }); TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitNormalized") diff --git a/tests/python/relax/test_blockbuilder.py b/tests/python/relax/test_blockbuilder.py index e54e2b7bf943..9d9d28d7d615 100644 --- a/tests/python/relax/test_blockbuilder.py +++ b/tests/python/relax/test_blockbuilder.py @@ -57,6 +57,22 @@ def test_block_builder(): assert not isinstance(b2, rx.DataflowBlock) +def test_emit_with_name(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) + y = rx.Var("y", rx.TensorStructInfo([n], "float16")) + bb = rx.BlockBuilder() + + bb._begin_dataflow_block() + lv0 = bb.emit(rx.op.add(x, y), "add") + gv0 = bb.emit_output(rx.op.multiply(lv0, y), "multi") + b0 = bb._end_block() + + assert b0.bindings[0].var.name_hint == "add" + assert b0.bindings[1].var.name_hint == "multi" + + def test_function_single_block(): m = tir.Var("m", "int64") n = tir.Var("n", "int64") From 678d01dd4a4e75ef6186ce356bb1a20e584a7b24 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 25 Feb 2023 13:22:10 -0500 Subject: [PATCH 72/81] [Unity][WEB] Relax vm on web runtime (#14131) This PR brings initial relax vm support on web runtime --- include/tvm/runtime/relax_vm/vm.h | 4 + python/tvm/contrib/tvmjs.py | 119 ++++++++++ python/tvm/exec/rpc_proxy.py | 32 ++- python/tvm/relax/vm_build.py | 14 +- python/tvm/rpc/proxy.py | 21 +- src/runtime/relax_vm/vm.cc | 11 + web/.gitignore | 1 + web/apps/browser/rpc_server.html | 65 +++++- web/emcc/wasm_runtime.cc | 74 ++++++ web/src/rpc_server.ts | 29 ++- web/src/runtime.ts | 315 ++++++++++++++++++++++++-- web/tests/node/test_relax_vm.js | 67 ++++++ web/tests/python/prepare_test_libs.py | 30 ++- web/tests/python/relax_rpc_test.py | 87 +++++++ web/tests/python/webgpu_rpc_test.py | 4 +- web/tests/python/websock_rpc_test.py | 4 +- 16 files changed, 825 insertions(+), 52 deletions(-) create mode 100644 python/tvm/contrib/tvmjs.py create mode 100644 web/tests/node/test_relax_vm.js create mode 100644 web/tests/python/relax_rpc_test.py diff --git a/include/tvm/runtime/relax_vm/vm.h b/include/tvm/runtime/relax_vm/vm.h index d39de74f2dab..bd59106cc1cf 100644 --- a/include/tvm/runtime/relax_vm/vm.h +++ b/include/tvm/runtime/relax_vm/vm.h @@ -23,6 +23,10 @@ #ifndef TVM_RUNTIME_RELAX_VM_VM_H_ #define TVM_RUNTIME_RELAX_VM_VM_H_ +#ifndef TVM_RELAX_VM_ENABLE_PROFILER +#define TVM_RELAX_VM_ENABLE_PROFILER 1 +#endif + #include #include #include diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py new file mode 100644 index 000000000000..18cbf332c8fe --- /dev/null +++ b/python/tvm/contrib/tvmjs.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Namespace to store utilities for building web runtime.""" +# pylint: disable=unused-import +import sys +import os +import json +from typing import Mapping, Union + +import numpy as np + +import tvm +from .emcc import create_tvmjs_wasm + + +def _convert_f32_to_bf16(value): + cap = np.finfo("float32").max + assert -np.finfo("float32").max == np.finfo("float32").min + bf16_limit = ((np.array([cap.view("uint32")]) >> 16) << 16).view("float32")[0] + # When the value is in [-bf16_limit, bf16_limit], round to nearest even. + # We can afford to do it in dumping phase to reduce overall rounding error. + # + # When the value is out of bound(usually mask values in attention), use truncation + # so it is equivalent to clip to the limit values + data = value.view("uint32") + rounding_bias = np.where( + np.logical_and(value < bf16_limit, value > -bf16_limit), + ((data >> 16) & 1) + 0x7FFF, + np.zeros_like(data), + ) + return ((data + rounding_bias) >> 16).astype("uint16") + + +def dump_ndarray_cache( + params: Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]], + cachedir: str, + encode_format="f32-to-bf16", +): + """Dump parameters to NDArray cache. + + Parameters + ---------- + params: Mapping[str, tvm.runtime.NDArray], + The parameter dictionary + + cachedir: str + The path to the cache + + encode_format: {"f32-to-bf16", "raw"} + Encoding format. + """ + records = [] + total = len(params) + counter = 0 + max_out_length = 0 + + if not os.path.exists(cachedir): + os.makedirs(cachedir) + + f32_to_bf16_triggered = False + + print("Start storing to cache %s" % cachedir) + for k, v in params.items(): + fname = k + ".bin" + out_path = os.path.join(cachedir, fname) + shape = list(v.shape) + + if not isinstance(v, np.ndarray): + v = v.numpy() + + # convert fp32 to bf16 + if encode_format == "f32-to-bf16" and v.dtype == "float32": + _convert_f32_to_bf16(v).tofile(out_path) + dtype = "bfloat16" + f32_to_bf16_triggered = True + else: + v.tofile(out_path) + + dtype = str(v.dtype) + records.append( + {"name": k, "shape": shape, "dtype": dtype, "dataPath": fname, "format": encode_format} + ) + counter += 1 + last_cmd = "[%04d/%04d] saving %s" % (counter, total, out_path) + flush = "\r" + (" " * max_out_length) + "\r" + max_out_length = max(len(last_cmd), max_out_length) + sys.stdout.write(flush + last_cmd) + + nd_cache_json = os.path.join(cachedir, "ndarray-cache.json") + with open(nd_cache_json, "w") as outfile: + json.dump(records, outfile, indent=4) + print("\nAll finished, record saved to %s" % nd_cache_json) + + if f32_to_bf16_triggered: + rec_bf16 = [] + for item in records: + if item["dtype"] == "float32": + item["format"] = "raw" + item["dtype"] = "bfloat16" + rec_bf16.append(item) + b16_nd_cache_json = os.path.join(cachedir, "ndarray-cache-b16.json") + # also dump a file that contains bf16 + with open(b16_nd_cache_json, "w") as outfile: + json.dump(rec_bf16, outfile, indent=4) + print("Also saved a bf16 record to %s" % b16_nd_cache_json) diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py index 7eae4fe1742f..d340750785e4 100644 --- a/python/tvm/exec/rpc_proxy.py +++ b/python/tvm/exec/rpc_proxy.py @@ -19,6 +19,7 @@ import logging import argparse import os +import glob from tvm.rpc.proxy import Proxy @@ -28,16 +29,29 @@ def find_example_resource(): base_path = os.path.abspath(os.path.join(curr_path, "..", "..", "..")) index_page = os.path.join(base_path, "web", "apps", "browser", "rpc_server.html") resource_files = [ - os.path.join(base_path, "web", "dist", "tvmjs.bundle.js"), - os.path.join(base_path, "web", "dist", "wasm", "tvmjs_runtime.wasi.js"), + ("/", os.path.join(base_path, "web", "dist", "tvmjs.bundle.js")), + ("/", os.path.join(base_path, "web", "dist", "wasm", "tvmjs_runtime.wasi.js")), + ("/", index_page), ] - resource_base = os.path.join(base_path, "web", "dist", "www") - if os.path.isdir(resource_base): - for fname in os.listdir(resource_base): - full_name = os.path.join(resource_base, fname) - if os.path.isfile(full_name): - resource_files.append(full_name) - for fname in [index_page] + resource_files: + allow_format = ("json", "bin", "js", "wasm") + + # recursively apend things in www, up to two levels + resource_bases = [ + os.path.join(base_path, "web", "dist", "www"), + os.path.join(base_path, "web", ".ndarray_cache"), + ] + for base in resource_bases: + if not os.path.isdir(base): + continue + for full_name in glob.glob("%s/**" % base, recursive=True): + fname = os.path.relpath(full_name, base) + dirname = os.path.dirname(fname) + fmt = fname.rsplit(".", 1)[-1] + if os.path.isfile(full_name) and fmt in allow_format: + resource_files.append((dirname, full_name)) + + for item in resource_files: + fname = item[-1] if not os.path.exists(fname): raise RuntimeError("Cannot find %s" % fname) return index_page, resource_files diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py index 35fc65bdc6c0..0586bf9217a2 100644 --- a/python/tvm/relax/vm_build.py +++ b/python/tvm/relax/vm_build.py @@ -180,6 +180,18 @@ def _vmcodegen( raise ValueError("Unknown exec_mode %s" % exec_mode) +def _autodetect_system_lib_req(target: tvm.target.Target): + """Automatically detect system lib requirement""" + host = target if target.host is None else target.host + system_lib = False + if "wasm" in host.attrs.get("mtriple", ""): + system_lib = True + if system_lib: + # use packed-func to avoid relay dep. + return tvm.get_global_func("relay.backend.CreateRuntime")("cpp", {"system-lib": system_lib}) + return None + + def _vmlink( builder: "relax.ExecBuilder", target: Union[str, tvm.target.Target], @@ -224,7 +236,7 @@ def _vmlink( ext_libs = [] lib = None if tir_mod is not None: - lib = tvm.build(tir_mod, target=target) + lib = tvm.build(tir_mod, target=target, runtime=_autodetect_system_lib_req(target)) return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params)) # type: ignore diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py index d7027c88a4b5..59af53d4e164 100644 --- a/python/tvm/rpc/proxy.py +++ b/python/tvm/rpc/proxy.py @@ -203,11 +203,20 @@ def signal_close(self): self.close() +MIME_MAP = { + "js": "application/javascript", + "wasm": "application/wasm", + "json": "application/json", +} + + class RequestHandler(tornado.web.RequestHandler): """Handles html request.""" def __init__(self, *args, **kwargs): file_path = kwargs.pop("file_path") + self.format = file_path.split(".")[-1] + if file_path.endswith("html"): self.page = open(file_path).read() web_port = kwargs.pop("rpc_web_port", None) @@ -217,12 +226,15 @@ def __init__(self, *args, **kwargs): ) else: self.page = open(file_path, "rb").read() + super(RequestHandler, self).__init__(*args, **kwargs) def data_received(self, _): pass def get(self, *args, **kwargs): + if self.format in MIME_MAP: + self.set_header("Content-Type", MIME_MAP[self.format]) self.write(self.page) @@ -254,9 +266,14 @@ def __init__( ) logging.info("Serving RPC index html page at http://localhost:%d", web_port) resource_files = resource_files if resource_files else [] - for fname in resource_files: + for item in resource_files: + prefix, fname = item + if not prefix.endswith("/"): + prefix += "/" + if not prefix.startswith("/"): + prefix = "/" + prefix basename = os.path.basename(fname) - pair = (r"/%s" % basename, RequestHandler, {"file_path": fname}) + pair = (r"%s%s" % (prefix, basename), RequestHandler, {"file_path": fname}) handlers.append(pair) logging.info(pair) self.app = tornado.web.Application(handlers) diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc index 3b952c1ff597..8679b2a79330 100644 --- a/src/runtime/relax_vm/vm.cc +++ b/src/runtime/relax_vm/vm.cc @@ -827,6 +827,11 @@ void VirtualMachineImpl::RunLoop() { ObjectPtr VirtualMachine::Create() { return make_object(); } +//---------------------------------------------------------------- +// Profiler can be optionally disabled via a macro to reduce dep. +//---------------------------------------------------------------- +#if TVM_RELAX_VM_ENABLE_PROFILER + /*! * \brief An extension of VirtualMachineImpl to support per-op profiling * It overrides RunInstrCall to add instrumentations around it. @@ -927,6 +932,12 @@ ObjectPtr VirtualMachine::CreateProfiler() { return make_object(); } +#else +ObjectPtr VirtualMachine::CreateProfiler() { + LOG(FATAL) << "Profiler support is disabled"; + return nullptr; +} +#endif // TVM_RELAX_VM_ENABLE_PROFILER } // namespace relax_vm } // namespace runtime } // namespace tvm diff --git a/web/.gitignore b/web/.gitignore index 1f7cc0916a5f..69bf96a8a726 100644 --- a/web/.gitignore +++ b/web/.gitignore @@ -4,3 +4,4 @@ out node_modules build debug +.ndarray_cache diff --git a/web/apps/browser/rpc_server.html b/web/apps/browser/rpc_server.html index 6d353e29b08d..8fa50272b24d 100644 --- a/web/apps/browser/rpc_server.html +++ b/web/apps/browser/rpc_server.html @@ -15,38 +15,71 @@ + - + TVM RPC Test Page - +

TVM WebSocket RPC Server

To use this page
    @@ -59,20 +92,34 @@

    TVM WebSocket RPC Server

Options

- Proxy URL
- RPC Server Key
+ NDArrayCache - + + CacheDevice - + +
+
+
+ +
diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index 00d2a8c579f1..c90b917c5c8b 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -26,6 +26,7 @@ #define TVM_LOG_STACK_TRACE 0 #define TVM_LOG_DEBUG 0 #define TVM_LOG_CUSTOMIZE 1 + #define DMLC_USE_LOGGING_LIBRARY #include @@ -51,6 +52,12 @@ #include "src/runtime/rpc/rpc_session.cc" #include "src/runtime/system_library.cc" #include "src/runtime/workspace_pool.cc" +// relax setup +#include "src/runtime/relax_vm/builtin.cc" +#include "src/runtime/relax_vm/bytecode.cc" +#include "src/runtime/relax_vm/executable.cc" +#include "src/runtime/relax_vm/memory_manager.cc" +#include "src/runtime/relax_vm/vm.cc" // --- Implementations of backend and wasm runtime API. --- @@ -111,5 +118,72 @@ TVM_REGISTER_GLOBAL("testing.object_use_count").set_body([](TVMArgs args, TVMRet // and get another value. *ret = (obj.use_count() - 1); }); + +/*! + * A NDArray cache to store pre-loaded arrays in the system. + */ +class NDArrayCache { + public: + static NDArrayCache* Global() { + static NDArrayCache* inst = new NDArrayCache(); + return inst; + } + + static void Update(String name, NDArray arr, bool override) { + NDArrayCache* pool = Global(); + if (!override) { + ICHECK_EQ(pool->pool_.count(name), 0) << "Name " << name << " already exists in the cache"; + } + pool->pool_.Set(name, arr); + } + + static Optional Get(String name) { + NDArrayCache* pool = Global(); + auto it = pool->pool_.find(name); + if (it != pool->pool_.end()) { + return (*it).second; + } else { + return NullOpt; + } + } + + static void Remove(String name) { + NDArrayCache* pool = Global(); + pool->pool_.erase(name); + } + + static void Clear() { Global()->pool_.clear(); } + + private: + Map pool_; +}; + +TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.get").set_body_typed(NDArrayCache::Get); +TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.update").set_body_typed(NDArrayCache::Update); +TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.remove").set_body_typed(NDArrayCache::Remove); +TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.clear").set_body_typed(NDArrayCache::Clear); + +void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format) { + if (format == "f32-to-bf16") { + std::vector buffer(bytes.length() / 2); + std::memcpy(buffer.data(), bytes.data(), buffer.size() * 2); + // decode bf16 to f32 + const uint16_t* bf16 = reinterpret_cast(buffer.data()); + uint32_t* data = static_cast(cpu_arr->data); + ICHECK(cpu_arr.IsContiguous()); + size_t size = 1; + for (int i = 0; i < cpu_arr->ndim; ++i) { + size *= cpu_arr->shape[i]; + } + ICHECK_EQ(size, bytes.length() / 2); + for (size_t i = 0; i < size; ++i) { + data[i] = static_cast(bf16[i]) << 16; + } + } else { + cpu_arr.CopyFromBytes(bytes.data(), bytes.length()); + } +} + +TVM_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStorage); } // namespace runtime } // namespace tvm diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts index e37d1838d604..4dd7228d3cfe 100644 --- a/web/src/rpc_server.ts +++ b/web/src/rpc_server.ts @@ -82,6 +82,9 @@ export class RPCServer { state: RPCServerState = RPCServerState.InitHeader; logger: (msg: string) => void; getImports: () => Record; + private ndarrayCacheUrl: string; + private ndarrayCacheDevice: string; + private fetchProgressCallback?: runtime.FetchProgressCallback; private pendingSend: Promise = Promise.resolve(); private name: string; private inst?: runtime.Instance = undefined; @@ -98,13 +101,19 @@ export class RPCServer { url: string, key: string, getImports: () => Record, - logger: (msg: string) => void = console.log + logger: (msg: string) => void = console.log, + ndarrayCacheUrl: string = "", + ndarrayCacheDevice: string = "cpu", + fetchProgressCallback: runtime.FetchProgressCallback | undefined = undefined ) { this.url = url; this.key = key; this.name = "WebSocketRPCServer[" + this.key + "]: "; this.getImports = getImports; this.logger = logger; + this.ndarrayCacheUrl = ndarrayCacheUrl; + this.ndarrayCacheDevice = ndarrayCacheDevice; + this.fetchProgressCallback = fetchProgressCallback; this.checkLittleEndian(); this.socket = compact.createWebSocket(url); @@ -132,7 +141,9 @@ export class RPCServer { if (this.state == RPCServerState.ReceivePacketHeader) { this.log("Closing the server in clean state"); this.log("Automatic reconnecting.."); - new RPCServer(this.url, this.key, this.getImports, this.logger); + new RPCServer( + this.url, this.key, this.getImports, this.logger, + this.ndarrayCacheUrl, this.ndarrayCacheDevice, this.fetchProgressCallback); } else { this.log("Closing the server, final state=" + this.state); } @@ -272,6 +283,20 @@ export class RPCServer { // begin scope to allow handling of objects // the object should stay alive during all sessions. this.inst.beginScope(); + if (this.fetchProgressCallback !== undefined) { + this.inst.registerFetchProgressCallback(this.fetchProgressCallback); + } + + if (this.ndarrayCacheUrl.length != 0) { + if (this.ndarrayCacheDevice == "cpu") { + await this.inst.fetchNDArrayCache(this.ndarrayCacheUrl, this.inst.cpu()); + } else { + assert(this.ndarrayCacheDevice == "webgpu"); + await this.inst.fetchNDArrayCache(this.ndarrayCacheUrl, this.inst.webgpu()); + } + } + + assert(this.inst !== undefined); const fcreate = this.inst.getGlobalFunc("rpc.CreateEventDrivenServer"); const messageHandler = fcreate( diff --git a/web/src/runtime.ts b/web/src/runtime.ts index a24459ca29a0..463532762ec4 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -29,7 +29,6 @@ import { WebGPUContext } from "./webgpu"; import * as compact from "./compact"; import * as ctypes from "./ctypes"; -import { tsImportEqualsDeclaration } from "@babel/types"; /** * Type for PackedFunc inthe TVMRuntime. @@ -144,6 +143,11 @@ class RuntimeContext implements Disposable { arrayGetSize : PackedFunc; arrayMake : PackedFunc; getSysLib: PackedFunc; + arrayCacheGet: PackedFunc; + arrayCacheUpdate: PackedFunc; + arrayCacheRemove: PackedFunc; + arrayCacheClear: PackedFunc; + arrayDecodeStorage: PackedFunc; private autoDisposeScope: Array> = []; @@ -152,12 +156,25 @@ class RuntimeContext implements Disposable { this.arrayGetSize = getGlobalFunc("runtime.ArraySize"); this.arrayMake = getGlobalFunc("runtime.Array"); this.getSysLib = getGlobalFunc("runtime.SystemLib"); + this.arrayCacheGet = getGlobalFunc("tvmjs.ndarray_cache.get"); + this.arrayCacheRemove = getGlobalFunc("tvmjs.ndarray_cache.remove"); + this.arrayCacheUpdate = getGlobalFunc("tvmjs.ndarray_cache.update"); + this.arrayCacheClear = getGlobalFunc("tvmjs.ndarray_cache.clear"); + this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage"); + } dispose(): void { + // call array cache clear to clear all cached items + this.arrayCacheClear(); this.arrayGetItem.dispose(); this.arrayGetSize.dispose(); this.arrayMake.dispose(); + this.arrayCacheGet.dispose(); + this.arrayCacheRemove.dispose(); + this.arrayCacheUpdate.dispose(); + this.arrayCacheClear.dispose(); + this.arrayDecodeStorage.dispose(); } beginScope() : void { @@ -522,6 +539,9 @@ export class NDArray implements Disposable { * @returns this */ copyFromRawBytes(data: Uint8Array): this { + if (this.device.deviceType != DeviceStrToEnum.cpu) { + throw new Error("Can only sync copy CPU array, use cpu_arr.copyfrom(gpu_arr) then sync instead."); + } const size = this.shape.reduce((a, b) => { return a * b; }, 1); @@ -552,7 +572,7 @@ export class NDArray implements Disposable { */ toRawBytes(): Uint8Array { if (this.device.deviceType != DeviceStrToEnum.cpu) { - throw new Error("Can only synchronize copy for GPU array, use copyfrom instead."); + throw new Error("Can only sync copy CPU array, use cpu_arr.copyfrom(gpu_arr) then sync instead."); } const size = this.shape.reduce((a, b) => { return a * b; @@ -806,12 +826,70 @@ export class TVMArray extends TVMObject { } } +export const enum VMAllocatorKind { + NAIVE_ALLOCATOR = 1, + POOLED_ALLOCATOR = 2, +} + +/** + * VirtualMachine Executor. + * + * This is a thin wrapper of the underlying TVM module. + * you can also directly call set_input, run, and get_output + * of underlying module functions + */ +export class VirtualMachine implements Disposable { + private mod: Module; + /** + * Constructor + * @param mod The underlying module, need to be detached. + * @param device The main device ro run VM on. + */ + constructor(mod: Module, device: DLDevice) { + this.mod = mod; + this.mod.getFunction("vm_initialization")( + new Scalar(device.deviceType, "int"), + new Scalar(device.deviceId, "int"), + new Scalar(VMAllocatorKind.POOLED_ALLOCATOR, "int") + ); + } + + dispose(): void { + this.mod.dispose(); + } + /** + * Get a function in the VM module. + * @param name The name of the function. + * @returns The result function. + */ + getFunction(name: string): PackedFunc { + return this.mod.getFunction(name); + } +} + /** Code used as the first argument of the async callback. */ const enum AyncCallbackCode { kReturn = 4, kException = 5, } +export interface NDArrayCacheEntry { + name: string; + shape: Array; + dtype: string; + format: "f32-to-bf16" | "raw"; + dataPath: string; +} + +export interface FetchProgressReport { + fetchedBytes: number; + totalBytes: number; + timeElapsed: number; + text: string; +} + +export type FetchProgressCallback = (report: FetchProgressReport) => void; + /** * TVM runtime instance. * @@ -836,6 +914,7 @@ export class Instance implements Disposable { private env: Environment; private objFactory: Map; private ctx: RuntimeContext; + private fetchProgressCallback: Array = []; /** * Internal function(registered by the runtime) @@ -898,26 +977,26 @@ export class Instance implements Disposable { * @number The number of times to compute the average. * @repeat The number of times to repeat the run. */ - async benchmark(run: ()=>void, dev: DLDevice, number=10, repeat=4): Promise { - // Skip first run as it can involve GPU warmup and module loading time. - const perf = compact.getPerformance(); - const results = []; + async benchmark(run: ()=>void, dev: DLDevice, number=10, repeat=4): Promise { + // Skip first run as it can involve GPU warmup and module loading time. + const perf = compact.getPerformance(); + const results = []; - // run with new scope - this.withNewScope(run); - await dev.sync(); + // run with new scope + this.withNewScope(run); + await dev.sync(); - for (let k = 0; k < repeat; ++k) { - const tstart = perf.now(); - for (let i = 0; i < number; ++i) { - this.withNewScope(run); - } - await dev.sync(); - const tend = perf.now(); - results.push((tend - tstart) / number); + for (let k = 0; k < repeat; ++k) { + const tstart = perf.now(); + for (let i = 0; i < number; ++i) { + this.withNewScope(run); } - return results; + await dev.sync(); + const tend = perf.now(); + results.push((tend - tstart) / number); } + return results; + } dispose(): void { // order matters @@ -1131,9 +1210,9 @@ export class Instance implements Disposable { * @param func Input function. * @returns The converted function. */ - toPackedFunc(func: Function): PackedFunc { - return this.toPackedFuncInternal(func, true); - } + toPackedFunc(func: Function): PackedFunc { + return this.toPackedFuncInternal(func, true); + } private toPackedFuncInternal(func: Function, autoAttachToScope: boolean): PackedFunc { if (this.isPackedFunc(func)) return func as PackedFunc; @@ -1142,6 +1221,200 @@ export class Instance implements Disposable { return ret; } + /** + * Setup a virtual machine module with given device. + * + * @param dev DLDevice the device. + * @returns The created virtual machime. + */ + createVirtualMachine(dev: DLDevice): VirtualMachine { + const mod = this.ctx.detachFromCurrentScope( + this.systemLib().getFunction("vm_load_executable")() + ); + return this.ctx.attachToCurrentScope( + new VirtualMachine(mod, dev) + ); + } + + //----------------------------------------------- + // Native NDArray Cache Support + //----------------------------------------------- + /** + * Register a call back for fetch progress. + * + * @param cb the fetch progress callback. + */ + registerFetchProgressCallback(cb: FetchProgressCallback) { + this.fetchProgressCallback.push(cb); + } + + /** + * Get NDArray from cache. + * @param name The name of array. + * @returns The result. + */ + ndarrayCacheGet(name: string) : NDArray | undefined { + return this.ctx.arrayCacheGet(name); + } + + /** + * Get NDArray from cache. + * @param name The name of array. + * @returns The result. + */ + ndarrayCacheRemove(name: string) : NDArray | undefined { + return this.ctx.arrayCacheRemove(name); + } + + /** + * Update the ndarray cache. + * @param name The name of the array. + * @param arr The content. + */ + ndarrayCacheUpdate(name: string, arr: NDArray, override: boolean = false) { + this.ctx.arrayCacheUpdate(name, arr, this.scalar(override ? 1 : 0, "int32")); + } + + /** + * Update the ndarray cache. + * @param name The name of the array. + * @param arr The content. + */ + ndarrayCacheClear() { + this.ctx.arrayCacheClear(); + } + + /** + * Fetch NDArray cache from url. + * + * @param ndarrayCacheUrl The cache url. + * @param device The device to be fetched to. + */ + async fetchNDArrayCache(ndarrayCacheUrl: string, device: DLDevice) { + const jsonUrl = new URL("ndarray-cache.json", ndarrayCacheUrl).href; + var list; + try { + + list = await (await fetch(jsonUrl)).json(); + } catch(err) { + this.env.logger("Cannot fetch " + jsonUrl); + } + await this.fetchNDArrayCacheInternal(ndarrayCacheUrl, list as Array, device); + } + + /** + * Fetch list of NDArray into the NDArrayCache. + * + * @param ndarrayCacheUrl The cache url. + * @param list The list of array data. + * @param device The device to store the data to. + */ + private async fetchNDArrayCacheInternal(ndarrayCacheUrl: string, list: Array, device: DLDevice) { + const computeTotalBytes = (rec: NDArrayCacheEntry) => { + + const dtype = this.toDLDataType(rec.dtype); + const size = rec.shape.reduce((a, b) => { + return a * b; + }, 1); + if (rec.format == "f32-to-bf16" && rec.dtype == "float32") { + return size * 2; + } + return size * dtype.bits * dtype.lanes / 8; + }; + const perf = compact.getPerformance(); + let tstart = perf.now(); + + let totalBytes = 0; + for (let i = 0; i < list.length; ++i) { + totalBytes += computeTotalBytes(list[i]); + }; + let fetchedBytes = 0; + let timeElapsed = 0; + + const reportCallback = (iter: number)=> { + // report + for (let j = 0; j < this.fetchProgressCallback.length; ++j) { + let text = "Fetching NDArray Cache[" + iter + "/" + list.length+ "]:"; + text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB fetched " + text += "from " + Math.ceil(totalBytes / (1024 * 1024)).toString() + "MB, " + text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, " + text += timeElapsed + " secs elapsed"; + if (timeElapsed != 0){ + text += ", speed=" + (fetchedBytes / timeElapsed / (1024 * 1024)).toFixed(1) + " MB/sec"; + } + this.fetchProgressCallback[j]({ + fetchedBytes: fetchedBytes, + totalBytes: totalBytes, + timeElapsed: timeElapsed, + text: text + }); + } + }; + + for (let j = 0; j < this.fetchProgressCallback.length; ++j) { + this.fetchProgressCallback[j]({ + fetchedBytes: 0, + totalBytes: totalBytes, + timeElapsed: 0, + text: "Start to fetch " + ndarrayCacheUrl + }); + } + const cache = await caches.open("tvmjs"); + + for (let i = 0; i < list.length; ++i) { + const rec = list[i]; + reportCallback(i); + fetchedBytes += computeTotalBytes(rec); + const cpu_arr = this.withNewScope(() => { + return this.detachFromCurrentScope( + this.empty(rec.shape, rec.dtype, this.cpu()) + ) + }); + const dataUrl = new URL(rec.dataPath, ndarrayCacheUrl).href; + const request = new Request(dataUrl); + + let buffer; + try { + // use native cache + let result = await cache.match(request); + if (result === undefined) { + await cache.add(request); + result = await cache.match(request); + } + if (result == undefined) { + this.env.logger("Error: Cannot cache " + dataUrl + ", reloading will be slow"); + result = await fetch(request); + } + buffer = await result.arrayBuffer(); + } catch (err) { + this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err); + cpu_arr.dispose(); + throw err; + } + // first sync copy to cpu. + this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(buffer), rec.format); + // then async stream into GPU if needed + if (device.deviceType == DeviceStrToEnum.cpu) { + this.ndarrayCacheUpdate(rec.name, cpu_arr, false); + cpu_arr.dispose(); + } else { + // allocate a gpu arr and async copy to it. + const gpu_arr = this.withNewScope(() => { + return this.detachFromCurrentScope( + this.empty(rec.shape, rec.dtype, device) + ) + }); + gpu_arr.copyFrom(cpu_arr); + await device.sync(); + this.ndarrayCacheUpdate(rec.name, gpu_arr, false); + cpu_arr.dispose(); + gpu_arr.dispose(); + } + timeElapsed = Math.ceil((perf.now() - tstart) / 1000); + } + reportCallback(list.length); + } + /** * Convert dtype to {@link DLDataType} * diff --git a/web/tests/node/test_relax_vm.js b/web/tests/node/test_relax_vm.js new file mode 100644 index 000000000000..ceb47aa014ec --- /dev/null +++ b/web/tests/node/test_relax_vm.js @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* eslint-disable no-undef */ +// Load Emscripten Module, need to change path to root/lib +const path = require("path"); +const fs = require("fs"); +const assert = require("assert"); +const tvmjs = require("../../dist"); + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "test_relax.wasm")); + +const tvm = new tvmjs.Instance( + new WebAssembly.Module(wasmSource), + new EmccWASI() +); + + +function randomArray(length, max) { + return Array.apply(null, Array(length)).map(function () { + return Math.random() * max; + }); +} + +test("add one", () => { + tvm.beginScope(); + // Load system library + const vm = tvm.createVirtualMachine(tvm.cpu()); + // grab pre-loaded function + const fadd = vm.getFunction("main"); + + assert(tvm.isPackedFunc(fadd)); + const n = 124; + const A = tvm.empty(n).copyFrom(randomArray(n, 1)); + const B = tvm.empty(n).copyFrom(randomArray(n, 1)); + + // call the function. + const C = fadd(A, B); + const AA = A.toArray(); // retrieve values in js array + const BB = B.toArray(); // retrieve values in js array + const CC = C.toArray(); // retrieve values in js array + // verify + for (var i = 0; i < BB.length; ++i) { + assert(Math.abs(CC[i] - (AA[i] + BB[i])) < 1e-5); + } + tvm.endScope(); + // assert auto release scope behavior + assert(vm.mod.getHandle(false) == 0); + assert(fadd._tvmPackedCell.getHandle(false) == 0); +}); diff --git a/web/tests/python/prepare_test_libs.py b/web/tests/python/prepare_test_libs.py index 5c1f7c68c421..a63e0655b45d 100644 --- a/web/tests/python/prepare_test_libs.py +++ b/web/tests/python/prepare_test_libs.py @@ -18,12 +18,32 @@ import tvm from tvm import te -from tvm.contrib import emcc +from tvm.contrib import tvmjs from tvm.relay.backend import Runtime +from tvm import relax +from tvm.script import relax as R import os -def prepare_test_libs(base_path): +def prepare_relax_lib(base_path): + pipeline = relax.get_pipeline() + + @tvm.script.ir_module + class Mod: + @R.function + def main(x: R.Tensor(["n"], "float32"), y: R.Tensor(["n"], "float32")): + lv0 = R.add(x, y) + return lv0 + + target = tvm.target.Target("llvm -mtriple=wasm32-unknown-unknown-wasm") + + mod = pipeline(Mod) + ex = relax.build(mod, target) + wasm_path = os.path.join(base_path, "test_relax.wasm") + ex.export_library(wasm_path, tvmjs.create_tvmjs_wasm) + + +def prepare_tir_lib(base_path): runtime = Runtime("cpp", {"system-lib": True}) target = "llvm -mtriple=wasm32-unknown-unknown-wasm" if not tvm.runtime.enabled(target): @@ -35,9 +55,11 @@ def prepare_test_libs(base_path): fadd = tvm.build(s, [A, B], target, runtime=runtime, name="add_one") wasm_path = os.path.join(base_path, "test_addone.wasm") - fadd.export_library(wasm_path, emcc.create_tvmjs_wasm) + fadd.export_library(wasm_path, tvmjs.create_tvmjs_wasm) if __name__ == "__main__": curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - prepare_test_libs(os.path.join(curr_path, "../../dist/wasm")) + base_path = os.path.join(curr_path, "../../dist/wasm") + prepare_tir_lib(base_path) + prepare_relax_lib(base_path) diff --git a/web/tests/python/relax_rpc_test.py b/web/tests/python/relax_rpc_test.py new file mode 100644 index 000000000000..a347fe70b345 --- /dev/null +++ b/web/tests/python/relax_rpc_test.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test relax vm through rpc.""" + +import tvm +import numpy as np +from tvm import rpc, relax +from tvm.contrib import utils, tvmjs +from tvm.script import relax as R + +proxy_host = "127.0.0.1" +proxy_port = 9090 + + +def get_model(): + pipeline = relax.get_pipeline() + + @tvm.script.ir_module + class Mod: + @R.function + def main(x: R.Tensor([1024], "float32"), y: R.Tensor([1024], "float32")): + lv0 = R.add(x, y) + return lv0 + + mod = pipeline(Mod) + sch = tvm.tir.Schedule(mod) + # manually transform loop + sch.work_on("add") + (i,) = sch.get_loops(block=sch.get_block("T_add")) + i0, i1 = sch.split(i, [None, 128]) + sch.bind(i0, "blockIdx.x") + sch.bind(i1, "threadIdx.x") + return sch.mod + + +def test_rpc(): + if not tvm.runtime.enabled("rpc"): + return + n = 1024 + dtype = "float32" + temp = utils.tempdir() + wasm_path = temp.relpath("relax.wasm") + target = tvm.target.Target("webgpu", host="llvm -mtriple=wasm32-unknown-unknown-wasm") + + mod = get_model() + ex = relax.build(mod, target) + ex.export_library(wasm_path, tvmjs.create_tvmjs_wasm) + wasm_binary = open(wasm_path, "rb").read() + + remote = rpc.connect( + proxy_host, + proxy_port, + key="wasm", + session_constructor_args=["rpc.WasmSession", wasm_binary], + ) + + def check(remote): + dev = remote.webgpu(0) + # invoke the function + vm = relax.VirtualMachine(remote.system_lib(), device=dev) + adata = np.random.uniform(size=n).astype(dtype) + bdata = np.random.uniform(size=n).astype(dtype) + a = tvm.nd.array(adata, dev) + b = tvm.nd.array(bdata, dev) + vm.set_input("main", a, b) + vm.invoke_stateful("main") + c = vm.get_outputs("main") + np.testing.assert_equal(c.numpy(), a.numpy() + b.numpy()) + + check(remote) + + +test_rpc() diff --git a/web/tests/python/webgpu_rpc_test.py b/web/tests/python/webgpu_rpc_test.py index 6e34a8a2b36c..986393e9d41d 100644 --- a/web/tests/python/webgpu_rpc_test.py +++ b/web/tests/python/webgpu_rpc_test.py @@ -23,7 +23,7 @@ import tvm from tvm import te from tvm import rpc -from tvm.contrib import utils, emcc +from tvm.contrib import utils, tvmjs from tvm.relay.backend import Runtime import numpy as np @@ -52,7 +52,7 @@ def test_rpc(): temp = utils.tempdir() wasm_path = temp.relpath("addone_gpu.wasm") - fadd.export_library(wasm_path, emcc.create_tvmjs_wasm) + fadd.export_library(wasm_path, tvmjs.create_tvmjs_wasm) wasm_binary = open(wasm_path, "rb").read() remote = rpc.connect( diff --git a/web/tests/python/websock_rpc_test.py b/web/tests/python/websock_rpc_test.py index 7de5ee956ec8..19d5dc57480c 100644 --- a/web/tests/python/websock_rpc_test.py +++ b/web/tests/python/websock_rpc_test.py @@ -23,7 +23,7 @@ import tvm from tvm import te from tvm import rpc -from tvm.contrib import utils, emcc +from tvm.contrib import utils, tvmjs from tvm.relay.backend import Runtime import numpy as np @@ -48,7 +48,7 @@ def test_rpc(): temp = utils.tempdir() wasm_path = temp.relpath("addone.wasm") - fadd.export_library(wasm_path, emcc.create_tvmjs_wasm) + fadd.export_library(wasm_path, tvmjs.create_tvmjs_wasm) wasm_binary = open(wasm_path, "rb").read() From e62169cc8afb5b9062a40072d6d44fc817408d77 Mon Sep 17 00:00:00 2001 From: Hongyi Jin <3231950289@qq.com> Date: Sun, 26 Feb 2023 11:05:47 -0500 Subject: [PATCH 73/81] [Unity] Add Global info (#14132) --- include/tvm/ir/global_info.h | 80 +++++++++++++++++++++ include/tvm/ir/module.h | 16 ++++- include/tvm/script/ir_builder/base.h | 2 + include/tvm/script/ir_builder/ir/frame.h | 7 ++ python/tvm/ir/__init__.py | 1 + python/tvm/ir/global_info.py | 42 +++++++++++ python/tvm/ir/module.py | 30 +++++++- python/tvm/script/ir_builder/base.py | 11 +++ python/tvm/script/ir_builder/ir/__init__.py | 9 ++- python/tvm/script/ir_builder/ir/ir.py | 39 +++++++++- python/tvm/script/parser/ir/__init__.py | 4 +- python/tvm/script/parser/ir/parser.py | 11 ++- src/ir/global_info.cc | 32 +++++++++ src/ir/module.cc | 25 +++++-- src/script/ir_builder/base.cc | 6 ++ src/script/ir_builder/ir/frame.cc | 3 +- src/script/ir_builder/ir/ir.cc | 24 +++++++ src/script/printer/ir/ir.cc | 15 ++++ tests/python/relax/test_tvmscript_parser.py | 42 +++++++++++ 19 files changed, 383 insertions(+), 16 deletions(-) create mode 100644 include/tvm/ir/global_info.h create mode 100644 python/tvm/ir/global_info.py create mode 100644 src/ir/global_info.cc diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h new file mode 100644 index 000000000000..65b5e0a3d28d --- /dev/null +++ b/include/tvm/ir/global_info.h @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/global_info.h + * \brief GlobalInfo are globally static object that are referred by the IR itself. + */ + +#ifndef TVM_IR_GLOBAL_INFO_H_ +#define TVM_IR_GLOBAL_INFO_H_ + +#include "tvm/ir/expr.h" + +namespace tvm { + +/*! + * \brief GlobalInfo are globally static object that are referred by the IR itself. + * Base node for all global info that can appear in the IR + */ +class GlobalInfoNode : public Object { + public: + static constexpr const char* _type_key = "GlobalInfoNode"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(GlobalInfoNode, Object); +}; + +/*! + * \brief Managed reference to GlobalInfoNode. + * \sa GlobalInfoNode + */ +class GlobalInfo : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(GlobalInfo, ObjectRef, GlobalInfoNode); +}; + +/*! + * \brief A dummy global info sub-class for testing purpose. + */ +class DummyGlobalInfoNode : public GlobalInfoNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + static constexpr const char* _type_key = "DummyGlobalInfo"; + + TVM_DLL bool SEqualReduce(const DummyGlobalInfoNode* other, SEqualReducer equal) const { + return true; + } + + TVM_DLL void SHashReduce(SHashReducer hash_reduce) const {} + TVM_DECLARE_FINAL_OBJECT_INFO(DummyGlobalInfoNode, GlobalInfoNode); +}; + +/*! + * \brief Managed reference to DummyGlobalInfoNode. + * \sa DummyGlobalInfoNode + */ +class DummyGlobalInfo : public GlobalInfo { + public: + TVM_DEFINE_OBJECT_REF_METHODS(DummyGlobalInfo, GlobalInfo, DummyGlobalInfoNode); +}; + +} // namespace tvm + +#endif // TVM_IR_GLOBAL_INFO_H_ diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 538ff64ca3fb..4c2d5cd81264 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -63,6 +64,8 @@ class IRModuleNode : public Object { SourceMap source_map; /* \brief Additional attributes storing meta-data about the module. */ DictAttrs attrs; + /*! \brief Globally static object that are referred by the IR itself */ + Map> global_infos; /*! * \brief A map from string names to global variables that * ensures global uniqueness. @@ -151,6 +154,7 @@ class IRModuleNode : public Object { v->Visit("global_type_var_map_", &global_type_var_map_); v->Visit("source_map", &source_map); v->Visit("attrs", &attrs); + v->Visit("global_infos", &global_infos); } TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const; @@ -210,6 +214,13 @@ class IRModuleNode : public Object { */ TVM_DLL void UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type); + /*! + * \brief Update an array of global infos in the global environment. + * \param name The name of the global info. + * \param info The new array of global infos. + */ + TVM_DLL void UpdateGlobalInfo(const String& name, const Array& info); + /*! * \brief Remove a function from the global environment. * \param var The name of the global function to update. @@ -359,12 +370,13 @@ class IRModule : public ObjectRef { * \param type_definitions Type definitions in the module. * \param import_set Set of imported files in the module. * \param map The module source map. - * \param attrs The module attributes. + * \param attrs The module meta-data attributes. + * \param global_infos Global infos in the module. */ TVM_DLL explicit IRModule(Map functions, Map type_definitions = {}, std::unordered_set import_set = {}, SourceMap map = {}, - DictAttrs attrs = {}); + DictAttrs attrs = {}, Map> global_infos = {}); /*! \brief default constructor */ IRModule() : IRModule(Map({})) {} diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index 61ca3eb9f7eb..a00ea5768e23 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -237,6 +237,8 @@ class IRBuilder : public runtime::ObjectRef { * \sa tvm::support::With */ static IRBuilder Current(); + /*! \brief See if the current thread-local scope has an IRBuilder. */ + static bool IsInScope(); /*! * \brief Give a string name to the `obj` * \tparam TObjectRef The type of the object to name. diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index dacfc361a6c7..6e758372b94b 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -21,6 +21,7 @@ #include #include +#include #include #include @@ -45,11 +46,17 @@ class IRModuleFrameNode : public IRBuilderFrameNode { * \note Only defined functions are in the map, while declared functions are not included. */ Map functions; + /*! \brief IRModule's attributes. */ + Map attrs; + /*! \brief IRModule's global_infos */ + Map> global_infos; void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); v->Visit("global_vars", &global_var_map); v->Visit("functions", &functions); + v->Visit("attrs", &attrs); + v->Visit("global_infos", &global_infos); } static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame"; diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 4f63cbecd9d1..01fea2abbda7 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -34,6 +34,7 @@ from .container import Array, Map from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelayExpr from .function import BaseFunc, CallingConv +from .global_info import GlobalInfo, DummyGlobalInfo from .memory_pools import ( ConstantMemoryPools, ConstantPoolInfo, diff --git a/python/tvm/ir/global_info.py b/python/tvm/ir/global_info.py new file mode 100644 index 000000000000..17011e76a66c --- /dev/null +++ b/python/tvm/ir/global_info.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Global Info.""" +import tvm +from tvm.runtime.object import Object +from . import _ffi_api + + +class GlobalInfo(Object): + """Base node for all global info that can appear in the IR""" + + def __eq__(self, other): + """Compare two struct info for structural equivalence.""" + return tvm.ir.structural_equal(self, other) + + def __ne__(self, other): + return not self.__eq__(other) + + def same_as(self, other): + """Overload with structural equality.""" + return super().__eq__(other) + + +class DummyGlobalInfo(GlobalInfo): + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.DummyGlobalInfo, + ) diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 6a151d5a897c..707d46d0cdf8 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -42,7 +42,7 @@ class IRModule(Node, Scriptable): Map of global var to BaseFunc """ - def __init__(self, functions=None, type_definitions=None): + def __init__(self, functions=None, type_definitions=None, attrs=None, global_infos=None): if functions is None: functions = {} elif isinstance(functions, dict): @@ -65,7 +65,20 @@ def __init__(self, functions=None, type_definitions=None): raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]") mapped_type_defs[k] = v type_definitions = mapped_type_defs - self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions) + + attrs = None if not attrs else attrs + if attrs is not None: + attrs = ast.literal_eval(str(attrs)) + attrs = tvm.ir.make_node("DictAttrs", **attrs) + if global_infos is None: + global_infos = {} + self.__init_handle_by_constructor__( + _ffi_api.IRModule, + functions, + type_definitions, + attrs, + global_infos, + ) def __setitem__(self, var, val): """Add a mapping to the module. @@ -140,6 +153,19 @@ def update_func(self, var, func): """ return _ffi_api.Module_UpdateFunction(self, var, func) + def update_global_info(self, name, global_info): + """Update global info in the module + + Parameters + ---------- + name: str + The name for the global info. + + global_info: List[GlobalInfo] + The global info to be updated. + """ + return _ffi_api.Module_UpdateGlobalInfo(self, name, global_info) + def get_global_var(self, name): """Get a global variable in the function by name. diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index b35bbd0a7df5..1d5d050444f7 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -138,6 +138,17 @@ def current() -> "IRBuilder": """ return _ffi_api.IRBuilderCurrent() # type: ignore[attr-defined] # pylint: disable=no-member + @staticmethod + def is_in_scope() -> bool: + """See if the current thread-local scope has an IRBuilder. + + Returns + ------- + bool + Whether the current thread-local scope has an IRBuilder + """ + return _ffi_api.IRBuilderIsInScope() # type: ignore[attr-defined] # pylint: disable=no-member + def get(self) -> _Object: """Get the constructed IR.""" return _ffi_api.IRBuilderGet(self) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py index 946be263a779..68eda2cfeebf 100644 --- a/python/tvm/script/ir_builder/ir/__init__.py +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -16,4 +16,11 @@ # under the License. """Package tvm.script.ir_builder.ir""" from .frame import IRModuleFrame -from .ir import decl_function, def_function, ir_module +from .ir import ( + decl_function, + def_function, + ir_module, + module_attrs, + module_global_infos, + dummy_global_info, +) diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index 796d6f3aad04..53c48b4cc540 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -16,7 +16,11 @@ # under the License. """Package tvm.script.ir_builder.ir.ir""" -from tvm.ir import BaseFunc, GlobalVar +from typing import Dict, List + +from tvm.ir import BaseFunc, GlobalVar, GlobalInfo, DummyGlobalInfo +from tvm.runtime import Object as tvm_Object + from . import _ffi_api from .frame import IRModuleFrame @@ -67,3 +71,36 @@ def def_function(func_name: str, func: BaseFunc) -> None: The given function implementation """ return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member + + +def module_attrs(attrs: Dict[str, tvm_Object]) -> None: + """Specify the attrs of the ir_module frame. + Parameters + ---------- + attrs: Dict[str, Object] + The module attrs. + """ + return _ffi_api.ModuleAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member + + +def module_global_infos(global_infos: Dict[str, List[GlobalInfo]]) -> None: + """Specify the global infos of the ir_module frame. + Parameters + ---------- + global_infos: Dict[str, List[GlobalInfo]] + The module global infos. + """ + return _ffi_api.ModuleGlobalInfos(global_infos) # type: ignore[attr-defined] # pylint: disable=no-member + + +############################### GlobalInfo ############################### + + +def dummy_global_info() -> DummyGlobalInfo: + """Create a dummy global info expression. + Returns + ------- + res : DummyGlobalInfo + The result dummy global info. + """ + return DummyGlobalInfo() # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/parser/ir/__init__.py b/python/tvm/script/parser/ir/__init__.py index fedd2f0a14a8..f8c9d4f0afc9 100644 --- a/python/tvm/script/parser/ir/__init__.py +++ b/python/tvm/script/parser/ir/__init__.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. """The ir module parser""" - +from ...ir_builder.ir import * # pylint: disable=redefined-builtin from . import parser as _parser from .entry import ir_module -__all__ = ["ir_module"] +__all__ = ["ir_module", "module_attrs", "module_global_infos", "dummy_global_info"] diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py index 13b3e298590f..201c99074f20 100644 --- a/python/tvm/script/parser/ir/parser.py +++ b/python/tvm/script/parser/ir/parser.py @@ -35,11 +35,17 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: with self.var_table.with_frame(): with I.ir_module(): + with self.with_dispatch_token("ir"): + for stmt in node.body: + if not isinstance(stmt, doc.FunctionDef): + self.visit(stmt) for stmt in node.body: if isinstance(stmt, doc.FunctionDef): self.visit_tvm_declare_function(stmt) with self.with_dispatch_token("ir"): - self.visit_body(node.body) + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + self.visit(stmt) @dispatch.register(token="ir", type_name="Assign") @@ -57,7 +63,7 @@ def _visit_assign(_self: Parser, _node: doc.Assign) -> None: @dispatch.register(token="ir", type_name="Expr") -def _visit_expr(_self: Parser, _node: doc.Expr) -> None: +def _visit_expr(self: Parser, node: doc.Expr) -> None: """The expression visiting method for ir module. Parameters @@ -68,6 +74,7 @@ def _visit_expr(_self: Parser, _node: doc.Expr) -> None: node : doc.ClassDef The doc AST expression node. """ + self.eval_expr(node.value) @dispatch.register(token="default", type_name="Assign") diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc new file mode 100644 index 000000000000..48f56d60d68c --- /dev/null +++ b/src/ir/global_info.cc @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/ir/global_info.cc + * \brief Module global info. + */ + +#include +namespace tvm { +TVM_REGISTER_NODE_TYPE(DummyGlobalInfoNode); +TVM_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() { + auto n = DummyGlobalInfo(make_object()); + return n; +}); +} // namespace tvm diff --git a/src/ir/module.cc b/src/ir/module.cc index 8f23f19d352e..da1f3942c78f 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -34,7 +34,8 @@ namespace tvm { IRModule::IRModule(tvm::Map functions, tvm::Map type_definitions, - std::unordered_set import_set, SourceMap source_map, DictAttrs attrs) { + std::unordered_set import_set, SourceMap source_map, DictAttrs attrs, + Map> global_infos) { auto n = make_object(); n->functions = std::move(functions); n->type_definitions = std::move(type_definitions); @@ -44,6 +45,7 @@ IRModule::IRModule(tvm::Map functions, n->import_set_ = std::move(import_set); n->source_map = source_map; n->attrs = std::move(attrs); + n->global_infos = std::move(global_infos); for (const auto& kv : n->functions) { // set global var map @@ -64,7 +66,10 @@ IRModule::IRModule(tvm::Map functions, bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const { if (!equal(this->attrs, other->attrs)) return false; - + if (this->global_infos.size() != other->global_infos.size()) return false; + for (const auto& kv : this->global_infos) { + if (!equal(kv.second, other->global_infos[kv.first])) return false; + } if (functions.size() != other->functions.size()) return false; // Update GlobalVar remap for (const auto& gv : this->GetGlobalVars()) { @@ -116,6 +121,7 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const { } reduce_temp(); hash_reduce(this->attrs); + hash_reduce(this->global_infos); } bool IRModuleNode::ContainGlobalVar(const String& name) const { @@ -239,6 +245,10 @@ void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) this->AddTypeDef(var, type, true); } +void IRModuleNode::UpdateGlobalInfo(const String& name, const Array& info) { + this->global_infos.Set(name, info); +} + void IRModuleNode::Remove(const GlobalVar& var) { auto functions_node = this->functions.CopyOnWrite(); functions_node->erase(var); @@ -359,9 +369,9 @@ IRModule IRModule::FromText(const String& text, const String& source_path) { TVM_REGISTER_NODE_TYPE(IRModuleNode); TVM_REGISTER_GLOBAL("ir.IRModule") - .set_body_typed([](tvm::Map funcs, - tvm::Map types) { - return IRModule(funcs, types, {}); + .set_body_typed([](tvm::Map funcs, tvm::Map types, + tvm::DictAttrs attrs, Map> global_infos) { + return IRModule(funcs, types, {}, {}, attrs, global_infos); }); TVM_REGISTER_GLOBAL("ir.Module_Add") @@ -423,6 +433,11 @@ TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction") .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); }); +TVM_REGISTER_GLOBAL("ir.Module_UpdateGlobalInfo") + .set_body_typed([](IRModule mod, String name, Array global_info) { + mod->UpdateGlobalInfo(name, global_info); + }); + TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) { mod->Import(path); }); diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc index 8303efff4f20..879db4f3d713 100644 --- a/src/script/ir_builder/base.cc +++ b/src/script/ir_builder/base.cc @@ -77,6 +77,11 @@ IRBuilder IRBuilder::Current() { return stack->back(); } +bool IRBuilder::IsInScope() { + std::vector* stack = ThreadLocalBuilderStack(); + return !stack->empty(); +} + namespace details { Namer::FType& Namer::vtable() { @@ -106,6 +111,7 @@ TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderIsInScope").set_body_typed(IRBuilder::IsInScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet") .set_body_method(&IRBuilderNode::Get); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name); diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index addf12928435..3d917cee887b 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -38,7 +38,8 @@ void IRModuleFrameNode::ExitWithScope() { } IRBuilder builder = IRBuilder::Current(); ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; - builder->result = tvm::IRModule(func_map); + auto dict_attrs = attrs.empty() ? NullValue() : DictAttrs(attrs); + builder->result = tvm::IRModule(func_map, {}, {}, {}, dict_attrs, global_infos); } TVM_REGISTER_NODE_TYPE(IRModuleFrameNode); diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index da2330b5772b..148e90b28c05 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -69,9 +69,33 @@ void DefFunction(const String& func_name, const BaseFunc& func) { } } +void ModuleAttrs(Map attrs) { + if (IRBuilder::IsInScope()) { + // TODO(hongyi): add comments to explain why we need to check if the module frame is in scope + IRModuleFrame frame = FindModuleFrame("I.ModuleAttr"); + if (!frame->attrs.empty()) { + LOG(FATAL) << "ValueError: Duplicate module attrs, previous one is:\n" << frame->attrs; + } + frame->attrs = attrs; + } +} + +void ModuleGlobalInfos(Map> global_infos) { + if (IRBuilder::IsInScope()) { + IRModuleFrame frame = FindModuleFrame("I.ModuleGlobalInfos"); + if (!frame->global_infos.empty()) { + LOG(FATAL) << "ValueError: Duplicate module global_infos, previous one is:\n" + << frame->global_infos; + } + frame->global_infos = global_infos; + } +} + TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGlobalInfos").set_body_typed(ModuleGlobalInfos); } // namespace ir } // namespace ir_builder diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index e6f4a1eaee2c..62919246b073 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -64,6 +64,16 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) std::sort(functions.begin(), functions.end()); With f(d); (*f)->AddDispatchToken(d, "ir"); + if (mod->attrs.defined() && !mod->attrs->dict.empty()) { + (*f)->stmts.push_back( + ExprStmtDoc(IR(d, "module_attrs") // + ->Call({d->AsDoc(mod->attrs, p->Attr("attrs"))}))); + } + if (mod->global_infos.defined() && !mod->global_infos.empty()) { + (*f)->stmts.push_back(ExprStmtDoc( + IR(d, "module_global_infos") // + ->Call({d->AsDoc(mod->global_infos, p->Attr("global_infos"))}))); + } for (const auto& entry : functions) { const GlobalVar& gv = entry.gv; const BaseFunc& func = entry.func; @@ -92,6 +102,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return IR(d, "GlobalVar")->Call({LiteralDoc::Str(gv->name_hint, p->Attr("name_hint"))}); }); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](GlobalInfo ginfo, ObjectPath p, IRDocsifier d) -> Doc { + return IR(d, "dummy_global_info")->Call({}); + }); + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc { return IR(d, "Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))}); diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 7724c8e761bf..9636a98b41b8 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -22,6 +22,7 @@ import tvm.script import tvm.testing from tvm import IRModule, relax, tir, topi +from tvm.ir import DummyGlobalInfo from tvm.script.parser import ir as I from tvm.script.parser import relax as R from tvm.script.parser import tir as T @@ -183,6 +184,47 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): _check(TestModule, bb.get()) +def test_module_with_attr_and_global_info(): + @I.ir_module + class TestModule: + I.module_attrs({"attr": 10}) + I.module_global_infos( + { + "dummy": [ + I.dummy_global_info(), # dummy[0] + I.dummy_global_info(), # dummy[1] + ] + } + ) + + @T.prim_func + def tir_func( + x: T.Buffer((T.int64(128), T.int64(128)), "float32"), + y: T.Buffer((T.int64(128), T.int64(128)), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i, j in T.grid(T.int64(128), T.int64(128)): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + y[vi, vj] = x[vi, vj] + 1.0 + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): + # TODO(Siyuan): Need to change to `TestModule.tir_func` + gv0 = R.call_tir(tir_func, x, R.Tensor((128, 128), dtype="float32")) + return gv0 + + x = relax.Var("x", R.Tensor((128, 128), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func") + bb.emit_func_output(out) + mod = bb.get() + mod.update_global_info("dummy", [DummyGlobalInfo(), DummyGlobalInfo()]) + mod = mod.with_attr("attr", tvm.tir.IntImm("int32", 10)) + _check(TestModule, mod) + + def test_relax_tensor_op(): @R.function def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor((4, 4), "float32"): From d7a6285f473dad912dd90183248f05b07a18e7e4 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Mon, 27 Feb 2023 03:33:50 -0500 Subject: [PATCH 74/81] [Unity][BYOC] Add transposed matmul support to Relax CUTLASS BYOC (#14128) Add transposed matmul support for Relax CUTLASS --- python/tvm/contrib/cutlass/build.py | 88 ++++++++++++------ python/tvm/contrib/cutlass/gemm_operation.py | 11 +-- python/tvm/contrib/cutlass/gen_tensor_op.py | 93 +++++++++++++++----- python/tvm/relax/__init__.py | 1 + python/tvm/relax/backend/contrib/cutlass.py | 30 +++++++ python/tvm/relax/dpl/pattern.py | 25 ++++++ src/relax/ir/dataflow_matcher.cc | 2 + tests/python/relax/test_codegen_cutlass.py | 68 ++++++++++++-- 8 files changed, 259 insertions(+), 59 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index c6e5adacec86..954aef60c242 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -19,6 +19,7 @@ import logging import multiprocessing import os +from typing import Optional import tvm from tvm import relax, relay, runtime @@ -522,7 +523,19 @@ def tune_cutlass_function( ) -def _extract_relax_function_info(f): +def _get_call_node(expr: relax.Expr, op_name: str) -> Optional[relax.Call]: + node = None + + def fvisit(e): + nonlocal node + if isinstance(e, relax.Call) and e.op.name == op_name: + node = e + + relax.analysis.post_order_visit(expr, fvisit) + return node + + +def _extract_relax_function_signature(f): signature = {} for i, arg in enumerate(f.params): @@ -534,16 +547,26 @@ def _extract_relax_function_info(f): signature["ret_shape"] = list(ret_sinfo.shape) signature["ret_dtype"] = ret_sinfo.dtype - op_attrs = {} + return signature - def fvisit(e): - nonlocal op_attrs - if isinstance(e, relax.Call) and e.op.name in ["relax.nn.conv2d"]: - op_attrs = e.attrs - relax.analysis.post_order_visit(f.body, fvisit) +def _extract_arg_idx(pattern_name, f): + pattern_entry = relax.backend.get_pattern(pattern_name) + if pattern_entry is None: + raise ValueError(f"Unsupported op_type {pattern_name}") + var2val = relax.analysis.get_var2val(f) + matched_expr = pattern_entry.pattern.extract_matched_expr(f.body.body, var2val) - return signature, op_attrs + func_args = list(f.params) + + arg_idx = {} + for arg_name, arg_pattern in pattern_entry.arg_patterns.items(): + arg_expr = matched_expr[arg_pattern] + if arg_expr not in func_args: + raise ValueError(f"Cannot find arg {arg_name} in the fused function parameters") + arg_idx[arg_name] = func_args.index(arg_expr) + + return arg_idx @relax.expr_functor.mutator @@ -566,7 +589,8 @@ def __init__( def handle_conv2d(self, f, op_type): """Tune and annotate a conv2d op.""" - signature, op_attrs = _extract_relax_function_info(f) + signature = _extract_relax_function_signature(f) + op_attrs = _get_call_node(f.body, "relax.nn.conv2d").attrs d_shape = signature["arg0_shape"] w_shape = signature["arg1_shape"] @@ -622,18 +646,29 @@ def handle_conv2d(self, f, op_type): def handle_matmul(self, f, op_type): """Tune and annotate a dense op.""" - signature, _ = _extract_relax_function_info(f) + signature = _extract_relax_function_signature(f) + arg_idx = _extract_arg_idx(op_type, f) + + lhs_arg = f"arg{arg_idx['lhs']}" + rhs_arg = f"arg{arg_idx['rhs']}" - arg0_shape = signature["arg0_shape"] - arg1_shape = signature["arg1_shape"] + lhs_shape = signature[f"{lhs_arg}_shape"] + rhs_shape = signature[f"{rhs_arg}_shape"] out_shape = signature["ret_shape"] - arg0_dtype = signature["arg0_dtype"] - arg1_dtype = signature["arg1_dtype"] + lhs_dtype = signature[f"{lhs_arg}_dtype"] + rhs_dtype = signature[f"{rhs_arg}_dtype"] out_dtype = signature["ret_dtype"] - MM = arg0_shape[0] - KK = arg0_shape[1] - NN = arg1_shape[1] + MM = lhs_shape[0] + KK = lhs_shape[1] + if "transposed" in op_type: + NN = rhs_shape[0] + ldb = "K" + layout_b = LayoutType.ColumnMajor + else: + NN = rhs_shape[1] + ldb = "N" + layout_b = LayoutType.RowMajor use_3xtf32 = self.options.get("use_3xtf32", False) find_first_valid = self.options.get("find_first_valid", True) @@ -645,26 +680,29 @@ def handle_matmul(self, f, op_type): NN, KK, out_dtype, - arg0_dtype, - arg1_dtype, + lhs_dtype, + rhs_dtype, use_3xtf32, batched=False, find_first_valid=find_first_valid, use_multiprocessing=use_multiprocessing, - layout_b=LayoutType.RowMajor, + layout_b=layout_b, ) return f.with_attrs( { "op_type": op_type, - "arg0_dtype": arg0_dtype, - "arg1_dtype": arg1_dtype, + "lhs_arg_idx": arg_idx["lhs"], + "rhs_arg_idx": arg_idx["rhs"], + "bias_arg_idx": arg_idx.get("bias"), + "arg0_dtype": signature["arg0_dtype"], + "arg1_dtype": signature["arg1_dtype"], "ret_dtype": out_dtype, - "arg0_shape": arg0_shape, - "arg1_shape": arg1_shape, + "arg0_shape": signature["arg0_shape"], + "arg1_shape": signature["arg1_shape"], "ret_shape": out_shape, "lda": "K", - "ldb": "N", + "ldb": ldb, "ldc": "N", "cutlass_op_name": op_name, "cutlass_op_def": op_def, diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py index 58f5de6a9c9a..3e74cbaec8db 100644 --- a/python/tvm/contrib/cutlass/gemm_operation.py +++ b/python/tvm/contrib/cutlass/gemm_operation.py @@ -259,7 +259,7 @@ def emit(self, operation, no_beta_scaling=False, batched=False): return substitute_template(gemm_template, values) -def instantiate_gemm_template(attrs, func_args): +def instantiate_gemm_template(attrs): """Return CUTLASS host code for GEMM based on a template and the provided attribute map.""" template = """ @@ -277,8 +277,8 @@ def instantiate_gemm_template(attrs, func_args): cutlass::gemm::GemmCoord problem_size(M, N, K); ElementComputeEpilogue alpha = ElementComputeEpilogue(1); ElementComputeEpilogue beta = ElementComputeEpilogue(${beta}); - void* ptr_a = (void*)(${arg0}->data); - void* ptr_b = (void*)(${arg1}->data); + void* ptr_a = (void*)(${lhs_arg}->data); + void* ptr_b = (void*)(${rhs_arg}->data); ${bias_decl} void* ptr_out = (void*)(out0->data); @@ -310,7 +310,7 @@ def instantiate_gemm_template(attrs, func_args): if has_bias: aux_map.update( { - "bias_decl": "void* ptr_c_bias = (void*)(${arg2}->data);\n", + "bias_decl": "void* ptr_c_bias = (void*)(${bias_arg}->data);\n", "ptr_c": "ptr_c_bias", "c_stride": "0", } @@ -342,7 +342,4 @@ def instantiate_gemm_template(attrs, func_args): template = substitute_template(template, aux_map) - for i, arg in enumerate(func_args): - attrs["arg{}".format(i)] = arg - return substitute_template(template, attrs) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index d3ab020839f3..92bf04e863e3 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -17,27 +17,28 @@ # pylint: disable=invalid-name """Common functions and classes for CUTLASS GEMM and Conv2d geneator.""" import logging +import multiprocessing import os import re -import tempfile import subprocess -import multiprocessing +import tempfile + import tvm._ffi -from tvm.tir import IntImm from tvm.runtime import Object +from tvm.tir import IntImm + from . import _ffi_api as ffi +from .conv2d_operation import instantiate_conv2d_template +from .gemm_operation import instantiate_gemm_template from .library import ( - MathInstruction, DataType, DataTypeTag, - OpcodeClass, + EpilogueFunctor, + MathInstruction, MathOperation, + OpcodeClass, TileDescription, - EpilogueFunctor, ) -from .gemm_operation import instantiate_gemm_template -from .conv2d_operation import instantiate_conv2d_template - logger = logging.getLogger("cutlass") @@ -371,6 +372,10 @@ def get_tile_descriptions(math_inst): "cutlass.matmul_bias": (EpilogueFunctor.LinearCombinationBias, True), "cutlass.matmul_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True), "cutlass.matmul_bias_gelu": (EpilogueFunctor.LinearCombinationGelu, False), + "cutlass.matmul_transposed": (EpilogueFunctor.LinearCombination, False), + "cutlass.matmul_transposed_bias": (EpilogueFunctor.LinearCombinationBias, True), + "cutlass.matmul_transposed_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True), + "cutlass.matmul_transposed_bias_gelu": (EpilogueFunctor.LinearCombinationGelu, False), "cutlass.batch_matmul": (EpilogueFunctor.LinearCombination, False), "cutlass.conv2d_bias_hardswish": (EpilogueFunctor.LinearCombinationHardSwish, False), "cutlass.conv2d_bias_silu": (EpilogueFunctor.LinearCombinationSilu, False), @@ -454,6 +459,13 @@ def __init__(self, code, headers): self.__init_handle_by_constructor__(ffi.CodegenResult, code, headers) +def _get_optional_int_annotation(annotations, key, default=None): + value = annotations.get(key, default) + if value is not None: + return int(value) + return value + + @tvm._ffi.register_func("contrib.cutlass.instantiate_template") def instantiate_template(func_name, annotations, func_args): """Return CUTLASS host code based on a template and the provided annotations. @@ -519,32 +531,69 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_ if "dense" in func_name or "matmul" in func_name: batched = "batch_matmul" in func_name batched_offset = 1 if batched else 0 - attrs["K"] = str(int(arg0_shape[batched_offset + 1])) - attrs["M"] = get_dim(arg0_shape[batched_offset], func_args[0], 0, batched_offset) - - if annotations["ldb"] == "N": - attrs["N"] = get_dim(arg1_shape[batched_offset + 1], func_args[1], 1, batched_offset) + transposed = "transposed" in func_name + lhs_arg_idx = _get_optional_int_annotation(annotations, "lhs_arg_idx", 0) + rhs_arg_idx = _get_optional_int_annotation(annotations, "rhs_arg_idx", 1) + bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", 2) + lhs_arg = func_args[lhs_arg_idx] + rhs_arg = func_args[rhs_arg_idx] + lhs_shape = annotations[f"arg{lhs_arg_idx}_shape"] + rhs_shape = annotations[f"arg{rhs_arg_idx}_shape"] + + attrs["lhs_arg"] = lhs_arg + attrs["rhs_arg"] = rhs_arg + if len(func_args) > 2: + attrs["bias_arg"] = func_args[bias_arg_idx] + attrs["ElementInputA"] = DataTypeTag[dtype_map[annotations[f"arg{lhs_arg_idx}_dtype"]]] + attrs["ElementInputB"] = DataTypeTag[dtype_map[annotations[f"arg{rhs_arg_idx}_dtype"]]] + attrs["ElementOutput"] = DataTypeTag[dtype_map[annotations["ret_dtype"]]] + + attrs["K"] = str(int(lhs_shape[batched_offset + 1])) + attrs["M"] = get_dim(lhs_shape[batched_offset], lhs_arg, 0, batched_offset) + + if transposed: + attrs["N"] = get_dim(rhs_shape[batched_offset], rhs_arg, 0, batched_offset) else: - attrs["N"] = get_dim(arg1_shape[batched_offset], func_args[1], 0, batched_offset) + attrs["N"] = get_dim(rhs_shape[batched_offset + 1], rhs_arg, 1, batched_offset) if batched: headers.append("cutlass/gemm/device/gemm_batched.h") - attrs["batch"] = get_dim(arg0_shape[0], func_args[0], 0) - attrs["batch_stride_A"] = get_batch_stride(annotations["batch_stride_A"], 0, 0, 1, 2) - attrs["batch_stride_B"] = get_batch_stride(annotations["batch_stride_B"], 1, 1, 1, 2) + attrs["batch"] = get_dim(lhs_shape[0], lhs_arg, 0) + attrs["batch_stride_A"] = get_batch_stride( + annotations["batch_stride_A"], + lhs_arg_idx, + lhs_arg_idx, + 1, + 2, + ) + attrs["batch_stride_B"] = get_batch_stride( + annotations["batch_stride_B"], + rhs_arg_idx, + rhs_arg_idx, + 1, + 2, + ) - if annotations["ldb"] == "N": + if transposed: attrs["batch_stride_C"] = get_batch_stride( - annotations["batch_stride_C"], 0, 1, 1, 2 + annotations["batch_stride_C"], + lhs_arg_idx, + rhs_arg_idx, + 1, + 1, ) else: attrs["batch_stride_C"] = get_batch_stride( - annotations["batch_stride_C"], 0, 1, 1, 1 + annotations["batch_stride_C"], + lhs_arg_idx, + rhs_arg_idx, + 1, + 2, ) else: headers.append("cutlass/gemm/device/gemm.h") - code = instantiate_gemm_template(attrs, func_args) + code = instantiate_gemm_template(attrs) return CodegenResult(code, headers) elif "conv2d" in func_name: diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index d0a1942ebdcb..e86f8c607436 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -87,6 +87,7 @@ from . import block_builder from . import op from . import struct_info +from . import backend # VM from .vm_build import build, Executable diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index 20cf57a40a5c..51684abb06ee 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -66,6 +66,36 @@ activation="relax.nn.gelu", ), ), + ( + "cutlass.matmul_transposed", + make_matmul_pattern( + with_bias=False, + transposed_rhs=True, + ), + ), + ( + "cutlass.matmul_transposed_bias", + make_matmul_pattern( + with_bias=True, + transposed_rhs=True, + ), + ), + ( + "cutlass.matmul_transposed_bias_relu", + make_matmul_pattern( + with_bias=True, + activation="relax.nn.relu", + transposed_rhs=True, + ), + ), + ( + "cutlass.matmul_transposed_bias_gelu", + make_matmul_pattern( + with_bias=True, + activation="relax.nn.gelu", + transposed_rhs=True, + ), + ), ] ) diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 9e1963f7edfd..300b0af568c0 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -204,6 +204,31 @@ def match(self, expr, var2val: Optional[Dict[Var, Expr]] = None) -> bool: """ return ffi.match_expr(self, expr, var2val) # type: ignore + def extract_matched_expr( + self, expr, var2val: Optional[Dict[Var, Expr]] = None + ) -> Optional[Dict["DFPattern", Expr]]: + """ + Match a relax.Expr and return a map from matching patterns to matched expressions. + + Parameters + ---------- + expr : tvm.relax.Expr + The expression to match + var2val : Optional[Dict[tvm.relax.Var, tvm.relax.Expr]] + A mapping from relax.Var to relax.Expr for autojump. + + Returns + ------- + result: Optional[Dict[DFPattern, Expr]] + Map from matching patterns to matched expressions. + Return None if the pattern does not match expr. + + Note + ---- + Check the note of `match` for the meaning of var2val. + """ + return ffi.extract_matched_expr(self, expr, var2val) + def used_by(self, other: Union["DFPattern", "PatternSeq"], index=-1) -> "PatternSeq": """ The current pattern being used by another pattern (sequence) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 92eb452a0065..da8c6ce2da78 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -515,6 +515,8 @@ Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, return matching; } +TVM_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr); + bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings_opt) { return static_cast(ExtractMatchedExpr(pattern, expr, bindings_opt)); } diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index 673155342cbf..af3d40d9c40f 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -255,9 +255,12 @@ def test_conv2d_offload(): tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) -def get_relax_matmul_module(x, y, with_bias=False, activation=None): +def get_relax_matmul_module(x, y, transposed_y=False, with_bias=False, activation=None): m, k = x.shape - n = y.shape[-1] + if transposed_y: + n = y.shape[-2] + else: + n = y.shape[-1] dtype = str(x.dtype) from tvm.script.ir_builder import IRBuilder @@ -266,13 +269,15 @@ def get_relax_matmul_module(x, y, with_bias=False, activation=None): with IRBuilder() as builder: with relax_builder.function(): R.func_name("main") - x = R.arg("x", R.Tensor((m, k), dtype)) - y = R.arg("y", R.Tensor((k, n), dtype)) + x = R.arg("x", R.Tensor(x.shape, dtype)) + y = R.arg("y", R.Tensor(y.shape, dtype)) if with_bias: bias = R.arg("bias", R.Tensor((n,), dtype)) with R.dataflow() as frame: - result = R.emit(R.matmul(x, y)) + if transposed_y: + y = R.emit(R.permute_dims(y)) + result = R.emit(R.matmul(x, y, out_dtype=dtype)) if with_bias: result = R.emit(result + bias) if activation is not None: @@ -380,5 +385,58 @@ def test_kernel_sharing(): tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) +def test_matmul_transposed_offload(matmul_x, matmul_y): + x, y = matmul_x, matmul_y + + mod = get_relax_matmul_module(x, y.transpose(), transposed_y=True) + out = get_result_with_relax_cutlass_offload(mod, x, y.transpose()) + ref_relay_expr = get_relay_matmul(x.shape, y.shape[::-1]) + ref = get_relay_ref(ref_relay_expr, x, y.transpose()) + + tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-4) + + +def test_matmul_transposed_bias_offload(matmul_x, matmul_y, matmul_bias): + x, y, bias = matmul_x, matmul_y, matmul_bias + + mod = get_relax_matmul_module( + x, y.transpose(), transposed_y=True, with_bias=True, activation=None + ) + out = get_result_with_relax_cutlass_offload(mod, x, y.transpose(), bias) + + ref_relay_expr = get_relay_matmul_bias(x.shape, y.shape[::-1]) + ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias) + + tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-4) + + +def test_matmul_transposed_bias_relu_offload(matmul_x, matmul_y, matmul_bias): + x, y, bias = matmul_x, matmul_y, matmul_bias + + mod = get_relax_matmul_module( + x, y.transpose(), transposed_y=True, with_bias=True, activation=R.nn.relu + ) + out = get_result_with_relax_cutlass_offload(mod, x, y.transpose(), bias) + + ref_relay_expr = get_relay_matmul_bias_relu(x.shape, y.shape[::-1]) + ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias) + + tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-4) + + +def test_matmul_transposed_bias_gelu_offload(matmul_x, matmul_y, matmul_bias): + x, y, bias = matmul_x, matmul_y, matmul_bias + + mod = get_relax_matmul_module( + x, y.transpose(), transposed_y=True, with_bias=True, activation=R.nn.gelu + ) + out = get_result_with_relax_cutlass_offload(mod, x, y.transpose(), bias) + + ref_relay_expr = get_relay_matmul_bias_gelu(x.shape, y.shape[::-1]) + ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias) + + tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-3) + + if __name__ == "__main__": tvm.testing.main() From ff21d66ab80dabb13a3cc43e26de56b3047cf8c4 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Mon, 27 Feb 2023 06:09:40 -0800 Subject: [PATCH 75/81] [Unity][TVMScript] emit_te sugar (#14123) This PR adds R.emit_te meta-programming mechanism to emit a topi operator from TVMScript --- python/tvm/script/ir_builder/relax/ir.py | 34 +++++++++++++++++- src/script/ir_builder/relax/ir.cc | 2 +- .../python/relax/test_tvmscript_ir_builder.py | 36 +++++++++++++------ 3 files changed, 59 insertions(+), 13 deletions(-) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 63efea135c15..045fe9ddd99a 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -20,12 +20,13 @@ import builtins import functools import inspect -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Callable import tvm from tvm import DataType, relax from tvm.ir import PrimExpr from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, VarBinding, const +from tvm.relax.block_builder import BlockBuilder as rx_bb ############################### Operators ############################### from tvm.relax.op import ( @@ -304,6 +305,7 @@ def wrapped(*args, **kwargs): call_builtin_with_ctx = _sinfo_arg_wrapper(call_builtin_with_ctx) # pylint: disable=invalid-name + ############################### Bindings ############################### @@ -325,6 +327,35 @@ def emit(value: Expr, annotate_struct_info: Optional[StructInfo] = None) -> Var: return _ffi_api.Emit(value, annotate_struct_info) # type: ignore[attr-defined] # pylint: disable=no-member +def emit_te(func: Callable, *args: Any, **kwargs: Any) -> Var: + """Emit a call node according to the te function. + This function converts arguments from relax expression to te tensor, + The callback func should return a te tensor or a list of te tensors. + + Parameters + ---------- + func : Callable + A function that returns a te tensor or a list of te tensors. + + args : Any, optional + arguments passed to the function. + + kwargs : Any, optional + The keyword arguments passed to the function. + Note that the key "primfunc_name_hint" is reserved for passing name hint + to the PrimFunc that gets generated. + + Returns + ------- + var : Var + A newly created variable that gets bound to the call code. + """ + + # Levarage the util function call_te in Relax Block Blocker + emit_expr = rx_bb().call_te(func, *args, **kwargs) + return emit(emit_expr) + + def emit_match_cast(value: Expr, struct_info: StructInfo) -> Var: """Emit a match_cast binding to the last binding block frame. Parameters @@ -511,6 +542,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "divide", "dtype", "emit", + "emit_te", "emit_var_binding", "emit_match_cast", "equal", diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index ddfb1ddfa35f..71a0651de859 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -108,7 +108,7 @@ void FuncRetValue(const tvm::relax::Expr& value) { if (block_frame.defined()) { block_frame.value()->ExitWithScope(); ICHECK(!IRBuilder::Current()->FindFrame()) - << "All block frame are supposed to be popped out already"; + << "ValueError: Relax functions don't support return in true/false branch of If Node."; } // Step 2. Add the output value to the function frame. FunctionFrame frame = FindFunctionFrame("return"); diff --git a/tests/python/relax/test_tvmscript_ir_builder.py b/tests/python/relax/test_tvmscript_ir_builder.py index eb0aaf56040b..014b00af0097 100644 --- a/tests/python/relax/test_tvmscript_ir_builder.py +++ b/tests/python/relax/test_tvmscript_ir_builder.py @@ -16,7 +16,7 @@ # under the License. import tvm import tvm.testing -from tvm import relax, tir +from tvm import relax, tir, topi from tvm.script.ir_builder import relax as R from tvm.script.ir_builder.base import IRBuilder @@ -57,15 +57,19 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) assert func.body.body.name_hint == "out" -def test_match_cast(): - """ +def test_emits(): + """Tests for R.emit, R.emit_match_cast, R.emit_var_binding, R.emit_te + @R.function - def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")): + def foo(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")) -> R.Shape(ndim=2): m = T.int64() n = T.int64() - _ = R.match_cast(x, R.Tensor((m,), "float32")) - y1 = R.match_cast(x, R.Tensor((n,), "float32")) - return (m, n * 2) + gv: R.Tensor((m,), dtype="float32") = R.match_cast(x, R.Tensor((m,), dtype="float32")) + gv1: R.Tensor((n,), dtype="float32") = R.match_cast(y, R.Tensor((n,), dtype="float32")) + v: R.Tensor((n,), dtype="float32") = gv1 + gv2 = R.call_tir(add, (v, v), out_sinfo=R.Tensor((n,), dtype="float32")) + gv3: R.Tensor((n,), dtype="float32") = gv2 + return R.shape([m, n * 2]) """ # create with Script IRBuilder with IRBuilder() as ir_builder: @@ -77,23 +81,33 @@ def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")): n = tir.Var("n", dtype="int64") _ = R.emit_match_cast(x, relax.TensorStructInfo((m,), "float32")) y1 = R.emit_match_cast(y, relax.TensorStructInfo((n,), "float32")) - IRBuilder.name("y1", y1) + v = relax.Var("v", relax.TensorStructInfo((n,), "float32")) + vb = relax.VarBinding(v, y1) + v = R.emit_var_binding(vb) + v1 = R.emit_te(topi.add, v, v) + R.emit(v1) + + IRBuilder.name("v", v) R.func_ret_value(relax.ShapeExpr([m, n * 2])) func = ir_builder.get() # create with BlockBuilder - x = relax.Var("x", relax.TensorStructInfo(dtype="float32", ndim=-1)) - y = relax.Var("y", relax.TensorStructInfo(dtype="float32", ndim=-1)) m = tir.Var("m", dtype="int64") n = tir.Var("n", dtype="int64") + x = relax.Var("x", relax.TensorStructInfo(dtype="float32", ndim=-1)) + y = relax.Var("y", relax.TensorStructInfo(dtype="float32", ndim=-1)) + v = relax.Var("v", relax.TensorStructInfo((n,), "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x, y)): _ = bb.match_cast(x, relax.TensorStructInfo((m,), "float32")) y1 = bb.match_cast(y, relax.TensorStructInfo((n,), "float32")) + bb.emit_normalized(relax.VarBinding(v, y1)) + v1 = bb.emit_te(topi.add, v, v) + bb.emit(v1) bb.emit_func_output(relax.ShapeExpr([m, n * 2])) mod = bb.get() - tvm.ir.assert_structural_equal(func, mod["foo"]) + tvm.ir.assert_structural_equal(func, mod["foo"], map_free_vars=True) def test_dataflow_block(): From 15ba19fa78148e6d9146fdba4539a0d9ba1dbf47 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 27 Feb 2023 12:20:39 -0800 Subject: [PATCH 76/81] [Unity][BYOC] Assign group to unused bindings and ignroe PrimFunc (#14139) * [Unity][BYOC] Assign group to unused bindings and ignroe PrimFunc * Update fuse_ops.cc --- src/relax/transform/fuse_ops.cc | 46 +++---- src/relax/transform/run_codegen.cc | 3 + .../test_transform_fuse_ops_by_pattern.py | 121 +++++++++++++++++- 3 files changed, 144 insertions(+), 26 deletions(-) diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 813c0c8f0366..c5042d019110 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -890,14 +890,6 @@ IRModule MakeGroupedFunctions( return OperatorFusor(mod, partition, lift_constants).Transform(); } -static Map GetBindingInverse(const Map& binding) { - Map value_to_bound_var; - for (const auto& [var, val] : binding) { - value_to_bound_var.Set(val, var); - } - return value_to_bound_var; -} - /*! \brief Create a "partitioning", a map from interior / leaf expr to its representative group, * based on the provided pattern. The result can be passed to OperatorFusor above to fuse operations * in a group and create a grouped function. @@ -909,21 +901,26 @@ class PatternBasedPartitioner : ExprVisitor { using ExprVisitor::VisitExpr_; static GroupMap Run(String pattern_name, DFPattern pattern, Expr expr, support::Arena* arena) { - PatternBasedPartitioner part(pattern_name, pattern, AnalyzeVar2Value(expr)); - // Initialize each expr to have its own group - PostOrderVisit( - expr, [arena, &part](const Expr& e) { part.group_map_[e.get()] = arena->make(); }); + PatternBasedPartitioner part(pattern_name, pattern, arena); part.VisitExpr(expr); return part.group_map_; } - PatternBasedPartitioner(String pattern_name, DFPattern pattern, const Map& bindings) - : pat_name_(pattern_name), - pat_(pattern), - bindings_(bindings), - value_to_bound_var_(GetBindingInverse(bindings)) {} + PatternBasedPartitioner(String pattern_name, DFPattern pattern, support::Arena* arena) + : pat_name_(pattern_name), pat_(pattern), arena_(arena) {} + + void VisitVarDef(const Var& var) final { group_map_[var.get()] = arena_->make(); } + + void VisitBinding_(const VarBindingNode* binding) final { + bindings_.Set(binding->var, binding->value); + value_to_bound_var_.Set(binding->value, binding->var); + ExprVisitor::VisitBinding_(binding); + } + + void VisitExpr_(const ConstantNode* op) final { group_map_[op] = arena_->make(); } - void VisitExpr_(const CallNode* call) override { + void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final { + VisitVarDef(binding->var); if (auto matches_opt = ExtractMatchedExpr(pat_, GetRef(call), bindings_)) { // If a match is found, put all matching expressions into the same group. // OperatorFusor also requires that the bound variable be in the same group as the RHS value. @@ -939,15 +936,12 @@ class PatternBasedPartitioner : ExprVisitor { // conv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv) // parent_group corresponds to the group of "conv1" above. - auto parent_group = GetGroupForBoundVar(GetRef(call)); + auto parent_group = GetGroupForBoundVar(binding->var); ICHECK(parent_group); parent_group->attrs.Set(attr::kComposite, pat_name_); - for (const auto& [pat, match] : matches_opt.value()) { - ICHECK(group_map_.count(match.get())); // Put all matching call nodes into the parent group. if (pat->IsInstance() && match != GetRef(call)) { - AddToGroup(match, parent_group); // Put the bound variable on the LHS into the same parent group. AddToGroup(value_to_bound_var_[match], parent_group); } @@ -964,15 +958,14 @@ class PatternBasedPartitioner : ExprVisitor { } } - Group* GetGroupForBoundVar(Expr e) { - ICHECK(value_to_bound_var_.count(e)); - auto bound_var = value_to_bound_var_[e]; + Group* GetGroupForBoundVar(const Var& bound_var) { ICHECK(group_map_.count(bound_var.get())); return group_map_[bound_var.get()]->FindRoot(); } String pat_name_; DFPattern pat_; + support::Arena* arena_; Map bindings_; Map value_to_bound_var_; GroupMap group_map_; @@ -1055,6 +1048,9 @@ IRModule FuseOpsByPattern(const tvm::Array& pattern_names, for (size_t i = 0; i < pattern_names.size(); ++i) { OperatorFusor::GroupMap group_map; for (const auto& entry : mod->functions) { + if (entry.second->IsInstance()) { + continue; + } auto map = PatternBasedPartitioner::Run(pattern_names[i], patterns[i], entry.second, &arena); group_map.insert(map.begin(), map.end()); } diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index 114b7d2a345d..7deeb139d1a0 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -138,6 +138,9 @@ class CodeGenRunner : ExprMutator { std::unordered_map> target_functions; for (const auto& entry : mod->functions) { + if (entry.second->IsInstance()) { + continue; + } PostOrderVisit(entry.second, [&target_functions](Expr e) { if (e->IsInstance()) { auto f = Downcast(e); diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py index da5b92fb64e0..21f952096be9 100644 --- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py +++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py @@ -20,7 +20,7 @@ import tvm from tvm import relax -from tvm.script import relax as R +from tvm.script import relax as R, tir as T, ir as I from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern, is_op, wildcard @@ -460,5 +460,124 @@ def test_multiple_calls_same_extern(): check(Conv2dx2, [("cutlass.conv2d", pat)], Conv2dx2_partitioned, annoatate_codegen=True) +def test_ignore_call_tir(): + @I.ir_module + class Conv2dReLUCallTIR: + @T.prim_func + def relu( + data: T.Buffer((64, 64, 56, 56), "float32"), out: T.Buffer((64, 64, 56, 56), "float32") + ): + for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56): + with T.block("root"): + i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + out[i, j, k, l] = T.max(data[i, j, k, l], 0.0) + + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight1: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = R.nn.conv2d(data, weight1, padding=(1, 1)) + relu1 = R.call_tir(relu, (conv1,), R.Tensor((64, 64, 56, 56), "float32")) + R.output(relu1) + + return relu1 + + @I.ir_module + class Conv2dReLUCallTIR_partitioned: + @T.prim_func + def relu( + data: T.Buffer((64, 64, 56, 56), "float32"), out: T.Buffer((64, 64, 56, 56), "float32") + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56): + with T.block("root"): + i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(data[i, j, k, l]) + T.writes(out[i, j, k, l]) + out[i, j, k, l] = T.max(data[i, j, k, l], T.float32(0)) + + @R.function + def fused_relax_nn_conv2d( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1}) + with R.dataflow(): + gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data, + weight1, + padding=(1, 1), + ) + R.output(gv) + return gv + + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((64, 64, 56, 56), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d( + data, weight1 + ) + relu1 = R.call_tir( + relu, (lv,), out_sinfo=R.Tensor((64, 64, 56, 56), dtype="float32") + ) + R.output(relu1) + return relu1 + + pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None) + check(Conv2dReLUCallTIR, [("cutlass.conv2d", pat)], Conv2dReLUCallTIR_partitioned) + + +def test_unused(): + @I.ir_module + class Conv2dReLU: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight1: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = R.nn.conv2d(data, weight1, padding=(1, 1)) + relu = R.nn.relu(data) + R.output(conv1) + + return conv1 + + @I.ir_module + class Conv2dReLU_partitioned: + @R.function + def fused_relax_nn_conv2d( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1}) + with R.dataflow(): + gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data, weight1, padding=(1, 1) + ) + R.output(gv) + return gv + + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + with R.dataflow(): + gv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d( + data, weight1 + ) + relu: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(data) + R.output(gv) + return gv + + pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None) + check(Conv2dReLU, [("cutlass.conv2d", pat)], Conv2dReLU_partitioned) + + if __name__ == "__main__": pytest.main([__file__]) From 2d0c2e976127d7fa7d6ce4182e1a7cb17e7ec7fd Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Fri, 24 Feb 2023 10:20:58 -0800 Subject: [PATCH 77/81] [Unity][TVMScript] Multiple return support in Relax --- include/tvm/relax/expr.h | 46 +++++++++++++++++++ include/tvm/script/ir_builder/base.h | 1 + include/tvm/script/ir_builder/relax/ir.h | 2 +- python/tvm/script/ir_builder/relax/ir.py | 10 ++-- python/tvm/script/parser/core/parser.py | 1 + python/tvm/script/parser/relax/parser.py | 3 +- src/relax/ir/expr.cc | 11 +++++ src/script/ir_builder/relax/frame.cc | 13 ++++-- src/script/ir_builder/relax/ir.cc | 32 +++++++++++-- src/script/ir_builder/relax/utils.h | 8 +++- tests/python/relax/test_codegen_cutlass.py | 2 +- .../python/relax/test_tvmscript_ir_builder.py | 4 +- tests/python/relax/test_tvmscript_parser.py | 18 +++++++- 13 files changed, 130 insertions(+), 21 deletions(-) diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 0788193ee7c4..7b54fa1864e6 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -625,6 +625,52 @@ class PrimValue : public LeafExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimValueNode); }; +/*! + * \brief NullExpr. + * + * Expression representing a Null expression. + */ +class NullExprNode : public LeafExprNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("struct_info_", &struct_info_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const NullExprNode* other, SEqualReducer equal) const { + // struct info can be deterministically derived from data. + return equal(struct_info_, other->struct_info_); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(struct_info_); } + + static constexpr const char* _type_key = "relax.expr.NullExpr"; + TVM_DECLARE_FINAL_OBJECT_INFO(NullExprNode, LeafExprNode); +}; + +/*! + * \brief Managed reference to NullExprNode + * \sa NullExprNode + */ +class NullExpr : public LeafExpr { + public: + /*! + * \brief The constructor + * \param span The source span of the expression. + */ + TVM_DLL explicit NullExpr(Span span); + + /*! + * \brief Create a null expression. + * \param span The source span of the expression. + * \return The created prim value. + */ + + TVM_DEFINE_OBJECT_REF_METHODS(NullExpr, LeafExpr, NullExprNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(NullExprNode); +}; + /*! * \brief Represent a string literal constant. */ diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index a00ea5768e23..d263ea360d67 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -236,6 +236,7 @@ class IRBuilder : public runtime::ObjectRef { * \sa IRBuilder::ExitWithScope * \sa tvm::support::With */ + static std::vector All(); static IRBuilder Current(); /*! \brief See if the current thread-local scope has an IRBuilder. */ static bool IsInScope(); diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h index 42aa591a95b7..d36c839bd736 100644 --- a/include/tvm/script/ir_builder/relax/ir.h +++ b/include/tvm/script/ir_builder/relax/ir.h @@ -67,7 +67,7 @@ TVM_DLL void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo); * \brief Specify the return value of the last function frame. * \param value The return value. */ -TVM_DLL void FuncRetValue(const tvm::relax::Expr& value); +TVM_DLL void RetValue(const tvm::relax::Expr& value); ///////////////////////////// BindingBlock ////////////////////////////// diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 045fe9ddd99a..fc89c96c3bca 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -190,14 +190,14 @@ def func_ret_struct_info(ret_sinfo: StructInfo) -> None: return _ffi_api.FuncRetStructInfo(ret_sinfo) # type: ignore[attr-defined] # pylint: disable=no-member -def func_ret_value(value: Expr) -> None: - """Specify the return value of the last function frame. +def ret_value(value: Expr) -> None: + """Specify the return value of the last frame. Parameters ---------- value: Expr - The function return value. + The return value. """ - return _ffi_api.FuncRetValue(value) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.RetValue(value) # type: ignore[attr-defined] # pylint: disable=no-member ############################# BindingBlock ############################## @@ -557,7 +557,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "func_attr", "func_name", "func_ret_struct_info", - "func_ret_value", + "ret_value", "function", "greater", "greater_equal", diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 105164ed5ffc..619b57698c11 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -683,4 +683,5 @@ def visit_Return(self, node: doc.Return) -> Any: # pylint: disable=invalid-name res : Any The visiting result. """ + print("yes visiting return, value: ", node.value) return _dispatch(self, "Return")(self, node) diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index e1af1c1df346..02d935e08e44 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -321,7 +321,7 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: def visit_return(self: Parser, node: doc.Assign) -> None: value = self.eval_expr(node.value) value = convert_to_expr(value) - R.func_ret_value(value) + R.ret_value(value) @dispatch.register(token="relax", type_name="If") @@ -331,6 +331,7 @@ def visit_if(self: Parser, node: doc.If) -> None: with R.If(self.eval_expr(node.test)) as if_frame: with self.var_table.with_frame(): with R.Then(): + print("fuck here") self.visit_body(node.body) with self.var_table.with_frame(): with R.Else(): diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 5392be7cb69b..70cb4a5f80ff 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -309,6 +309,17 @@ TVM_REGISTER_GLOBAL("relax.PrimValue").set_body_typed([](PrimExpr value, Span sp return PrimValue(value, span); }); +NullExpr::NullExpr(Span span) { + ObjectPtr n = make_object(); + n->checked_type_ = ObjectType(); + n->struct_info_ = ObjectStructInfo(); + n->span = std::move(span); +} + +TVM_REGISTER_NODE_TYPE(NullExprNode); + +TVM_REGISTER_GLOBAL("relax.NullExpr").set_body_typed([](Span span) { return NullExpr(span); }); + StringImm::StringImm(String value, Span span) { ObjectPtr n = make_object(); n->value = std::move(value); diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index c78b9e73c534..bd3209cd6319 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -52,10 +52,17 @@ void FunctionFrameNode::ExitWithScope() { IRBuilder builder = IRBuilder::Current(); SeqExprFrameNode::ExitWithScope(); // Step 1: Create the function. - CHECK(output.defined()) << "ValueError: A Relax function must have a return value. Please use " - "`return` to return an Expr"; + // CHECK(output.defined()) << "ValueError: A Relax function must have a return value. Please use " + // "`return` to return an Expr"; this->block_builder->BeginScope(params); - Expr body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value())); + Expr body; + if (output.defined()) { + body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value())); + } else { + body = + this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, tvm::relax::NullExpr())); + } + auto dict_attrs = attrs.empty() ? NullValue() : DictAttrs(attrs); this->block_builder->EndScope(); tvm::relax::Function func(/*params=*/params, diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index 71a0651de859..f67bee9c4149 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -96,27 +96,49 @@ void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo) { frame->ret_struct_info = ret_sinfo; } -void FuncRetValue(const tvm::relax::Expr& value) { +void RetValue(const tvm::relax::Expr& value) { // Step 0. Normalize the value. const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); tvm::relax::Expr normalized_value = block_builder->Normalize(value); + LOG(INFO) << "normalized_value: " << normalized_value; // Step 1. The current Relax TVMScript syntax only allows function return appearing at the end of // a function body. Therefore if there is any unended block frame when dealing with function // return, we should end the block frame. Optional block_frame = IRBuilder::Current()->GetLastFrame(); + // Exit BlockFrame if (block_frame.defined()) { block_frame.value()->ExitWithScope(); ICHECK(!IRBuilder::Current()->FindFrame()) << "ValueError: Relax functions don't support return in true/false branch of If Node."; } // Step 2. Add the output value to the function frame. + Array all_frames = IRBuilder::Current()->frames; + int i = 0; + for (auto f : all_frames) { + LOG(INFO) << "yongwww frame_" << i++ << " = " << f; + } + + // IfFrame if_frame = IRBuilder::Current()->FindFrame().value(); + // LOG(INFO) << "return if_frame: " << if_frame; + + IRBuilderFrame last_frame = all_frames[all_frames.size() - 1]; + Optional then_frame = IRBuilder::Current()->GetLastFrame(); + if (then_frame.defined()) { + then_frame.value()->output = std::move(normalized_value); + return; + } + Optional else_frame = IRBuilder::Current()->GetLastFrame(); + if (else_frame.defined()) { + else_frame.value()->output = std::move(normalized_value); + return; + } + // Optional func_frame = IRBuilder::Current()->GetLastFrame(); FunctionFrame frame = FindFunctionFrame("return"); - CHECK(!frame->output.defined()) - << "ValueError: Relax functions don't support multiple return statement. Please make sure " - "the return statement appears at the end of function."; + LOG(INFO) << "return FunctionFrame frame: " << frame; frame->output = std::move(normalized_value); + // block_frame.value()->ExitWithScope(); } TVM_REGISTER_GLOBAL("script.ir_builder.relax.Function").set_body_typed(Function); @@ -124,7 +146,7 @@ TVM_REGISTER_GLOBAL("script.ir_builder.relax.Arg").set_body_typed(Arg); TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncName").set_body_typed(FuncName); TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncAttrs").set_body_typed(FuncAttrs); TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetStructInfo").set_body_typed(FuncRetStructInfo); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetValue").set_body_typed(FuncRetValue); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.RetValue").set_body_typed(RetValue); ///////////////////////////// BindingBlock ////////////////////////////// diff --git a/src/script/ir_builder/relax/utils.h b/src/script/ir_builder/relax/utils.h index ae91d05769bd..ac64c70a645e 100644 --- a/src/script/ir_builder/relax/utils.h +++ b/src/script/ir_builder/relax/utils.h @@ -79,8 +79,12 @@ inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String } // Step 1. Check non-empty block and last binding is non-dataflow - CHECK(!frame->binding_blocks.empty()) - << "Empty body is not allowed for '" << method << "' statements."; + // CHECK(!frame->binding_blocks.empty()) + // << "Empty body is not allowed for '" << method << "' statements."; + if (frame->binding_blocks.empty()) { + LOG(INFO) << " frame output : " << frame->output; + return tvm::relax::SeqExpr({}, frame->output.value()); + } const tvm::relax::BindingBlock& last_block = frame->binding_blocks.back(); CHECK(!last_block->bindings.empty()) << "Blocks are expected to be non-empty."; diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index af3d40d9c40f..fd8c2f0bc561 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -284,7 +284,7 @@ def get_relax_matmul_module(x, y, transposed_y=False, with_bias=False, activatio result = R.emit(activation(result)) R.output(result) - R.func_ret_value(frame.output_vars[0]) + R.ret_value(frame.output_vars[0]) func = builder.get() return tvm.IRModule({"main": func}) diff --git a/tests/python/relax/test_tvmscript_ir_builder.py b/tests/python/relax/test_tvmscript_ir_builder.py index 014b00af0097..298cfd9bf47e 100644 --- a/tests/python/relax/test_tvmscript_ir_builder.py +++ b/tests/python/relax/test_tvmscript_ir_builder.py @@ -39,7 +39,7 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) R.call_tir("extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32")) ) IRBuilder.name("out", out) - R.func_ret_value(out) + R.ret_value(out) func = ir_builder.get() # create with BlockBuilder x = relax.Var("x", relax.TensorStructInfo((128, 128), "float32")) @@ -137,7 +137,7 @@ def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2): IRBuilder.name("gv", gv) R.output(gv) (gv,) = df.output_vars - R.func_ret_value(gv) + R.ret_value(gv) func = ir_builder.get() # create with BlockBuilder diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 9636a98b41b8..dd8f3f813cad 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1162,5 +1162,21 @@ def mul_add(x: R.Tensor) -> R.Tensor: _check(InputModule, OutputModule) +def test_control_flow(): + @tvm.script.ir_module + class ControlFlowExample: + @R.function + def foo(x: R.Tensor) -> R.Tensor: + y: R.Tensor((), dtype="bool") = R.const(True, dtype="bool") + if y: + return R.add(x, x) + else: + return R.multiply(x, x) + return x + + ControlFlowExample.show() + + if __name__ == "__main__": - tvm.testing.main() + # tvm.testing.main() + test_control_flow() From 752c25343b2c6c1a38c67c0c5a5eb0f3be2b0717 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Tue, 28 Feb 2023 13:14:21 -0800 Subject: [PATCH 78/81] Use ReturnGlobalInfo --- include/tvm/ir/global_info.h | 26 +++++++++++ include/tvm/ir/module.h | 2 + include/tvm/script/ir_builder/relax/frame.h | 2 + python/tvm/ir/__init__.py | 2 +- python/tvm/ir/global_info.py | 30 ++++++++++++ python/tvm/ir/module.py | 10 ++++ python/tvm/script/ir_builder/ir/__init__.py | 3 ++ python/tvm/script/ir_builder/ir/ir.py | 45 +++++++++++++++++- python/tvm/script/parser/ir/__init__.py | 10 +++- python/tvm/script/parser/relax/parser.py | 51 ++++++++++++++++++++- src/ir/global_info.cc | 10 ++++ src/ir/module.cc | 8 ++++ src/script/ir_builder/ir/ir.cc | 17 +++++++ src/script/ir_builder/ir/utils.h | 9 ++-- src/script/ir_builder/relax/frame.cc | 1 + src/script/ir_builder/relax/ir.cc | 4 +- src/script/printer/ir/ir.cc | 11 +++++ src/script/printer/relax/binding.cc | 1 + tests/python/relax/test_tvmscript_parser.py | 23 ++++++++-- 19 files changed, 250 insertions(+), 15 deletions(-) diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h index 65b5e0a3d28d..3fd96dc37930 100644 --- a/include/tvm/ir/global_info.h +++ b/include/tvm/ir/global_info.h @@ -75,6 +75,32 @@ class DummyGlobalInfo : public GlobalInfo { TVM_DEFINE_OBJECT_REF_METHODS(DummyGlobalInfo, GlobalInfo, DummyGlobalInfoNode); }; +/*! + * \brief A return global info sub-class for expressions to return. + */ +class ReturnGlobalInfoNode : public GlobalInfoNode { + public: + Array return_exprs; + void VisitAttrs(tvm::AttrVisitor* v) {} + static constexpr const char* _type_key = "ReturnGlobalInfo"; + + TVM_DLL bool SEqualReduce(const ReturnGlobalInfoNode* other, SEqualReducer equal) const { + return equal(return_exprs, other->return_exprs); + } + + TVM_DLL void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(return_exprs); } + TVM_DECLARE_FINAL_OBJECT_INFO(ReturnGlobalInfoNode, GlobalInfoNode); +}; + +/*! + * \brief Managed reference to ReturnGlobalInfoNode. + * \sa ReturnGlobalInfoNode + */ +class ReturnGlobalInfo : public GlobalInfo { + public: + TVM_DEFINE_OBJECT_REF_METHODS(ReturnGlobalInfo, GlobalInfo, ReturnGlobalInfoNode); +}; + } // namespace tvm #endif // TVM_IR_GLOBAL_INFO_H_ diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 4c2d5cd81264..8e0cf201ef4e 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -221,6 +221,8 @@ class IRModuleNode : public Object { */ TVM_DLL void UpdateGlobalInfo(const String& name, const Array& info); + TVM_DLL Array GetGlobalInfo(const String& name); + /*! * \brief Remove a function from the global environment. * \param var The name of the global function to update. diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index 0f544d3abcc2..c99b9220520c 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -100,6 +100,8 @@ class FunctionFrameNode : public SeqExprFrameNode { /*! \brief The function attributes. */ Map attrs; + + // todo(yongwww) Add Map> global_infos; /*! \brief The block builder to create Relax function. */ tvm::relax::BlockBuilder block_builder; diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 01fea2abbda7..9ada942a9e2f 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -34,7 +34,7 @@ from .container import Array, Map from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelayExpr from .function import BaseFunc, CallingConv -from .global_info import GlobalInfo, DummyGlobalInfo +from .global_info import GlobalInfo, ReturnGlobalInfo, DummyGlobalInfo from .memory_pools import ( ConstantMemoryPools, ConstantPoolInfo, diff --git a/python/tvm/ir/global_info.py b/python/tvm/ir/global_info.py index 17011e76a66c..5dbf4476a501 100644 --- a/python/tvm/ir/global_info.py +++ b/python/tvm/ir/global_info.py @@ -16,6 +16,8 @@ # under the License. """Global Info.""" import tvm +from typing import List +from tvm.ir import RelayExpr as Expr from tvm.runtime.object import Object from . import _ffi_api @@ -35,6 +37,34 @@ def same_as(self, other): return super().__eq__(other) +class ReturnGlobalInfo(GlobalInfo): + """ReturnGlobalInfo in the IR. + + Parameters + ---------- + return_exprs : List[Expr] + The expressions to be returned. + """ + + return_exprs: List[Expr] + + def __init__(self, return_exprs: List[Expr]) -> None: + print("yes entering ReturnGlobalInfo in global_info.py, return_exprs: ", return_exprs) + self.__init_handle_by_constructor__(_ffi_api.ReturnGlobalInfo, return_exprs) + + def add(): + pass + + def update(return_exprs: List[Expr]): + pass + + def get() -> GlobalInfo: + pass + + def get_exprs(self): + return self.return_exprs + + class DummyGlobalInfo(GlobalInfo): def __init__(self) -> None: self.__init_handle_by_constructor__( diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 707d46d0cdf8..d04de0f9ad99 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -166,6 +166,16 @@ def update_global_info(self, name, global_info): """ return _ffi_api.Module_UpdateGlobalInfo(self, name, global_info) + def get_global_info(self, name): + """Get global info in the module + + Parameters + ---------- + name: str + The name for the global info. + """ + return _ffi_api.Module_GetGlobalInfo(self, name) + def get_global_var(self, name): """Get a global variable in the function by name. diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py index 68eda2cfeebf..91eb0535ffda 100644 --- a/python/tvm/script/ir_builder/ir/__init__.py +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -22,5 +22,8 @@ ir_module, module_attrs, module_global_infos, + module_get_global_infos, + module_update_global_infos, + return_global_info, dummy_global_info, ) diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index 53c48b4cc540..90c1568b69c5 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -18,7 +18,8 @@ from typing import Dict, List -from tvm.ir import BaseFunc, GlobalVar, GlobalInfo, DummyGlobalInfo +from tvm.ir import BaseFunc, GlobalVar, GlobalInfo, ReturnGlobalInfo, DummyGlobalInfo +from tvm.ir import RelayExpr as Expr from tvm.runtime import Object as tvm_Object @@ -93,9 +94,51 @@ def module_global_infos(global_infos: Dict[str, List[GlobalInfo]]) -> None: return _ffi_api.ModuleGlobalInfos(global_infos) # type: ignore[attr-defined] # pylint: disable=no-member +def module_get_global_infos() -> Dict[str, List[GlobalInfo]]: + """Get the global infos of the ir_module frame. + + Returns + ---------- + ret: Dict[str, List[GlobalInfo]] + The module global infos. + """ + ginfos = _ffi_api.ModuleGetGlobalInfos() # type: ignore[attr-defined] # pylint: disable=no-member + # Map -> Python Dict + ret = {} + for (k, v) in ginfos.items(): + ret[k] = v + return ret + + +def module_update_global_infos(global_infos: Dict[str, List[GlobalInfo]]) -> None: + """Update the global infos of the ir_module frame. + Parameters + ---------- + global_infos: Dict[str, List[GlobalInfo]] + The module global infos. + """ + return _ffi_api.ModuleUpdateGlobalInfos(global_infos) # type: ignore[attr-defined] # pylint: disable=no-member + + ############################### GlobalInfo ############################### +def return_global_info(return_exprs: List[Expr]) -> ReturnGlobalInfo: + """Create a return global info expression. + Parameters + ---------- + return_exprs : List[Expr] + The expressions to be returned. + + Returns + ------- + res : ReturnGlobalInfo + The result return global info. + """ + print("yes return_global_info in ir_builder/ir.py") + return ReturnGlobalInfo(return_exprs) # type: ignore[attr-defined] # pylint: disable=no-member + + def dummy_global_info() -> DummyGlobalInfo: """Create a dummy global info expression. Returns diff --git a/python/tvm/script/parser/ir/__init__.py b/python/tvm/script/parser/ir/__init__.py index f8c9d4f0afc9..4dbb615b89e4 100644 --- a/python/tvm/script/parser/ir/__init__.py +++ b/python/tvm/script/parser/ir/__init__.py @@ -19,4 +19,12 @@ from . import parser as _parser from .entry import ir_module -__all__ = ["ir_module", "module_attrs", "module_global_infos", "dummy_global_info"] +__all__ = [ + "ir_module", + "module_attrs", + "module_global_infos", + "module_get_global_infos", + "module_update_global_infos", + "return_global_info", + "dummy_global_info", +] diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index 02d935e08e44..45a05b3cb160 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -232,6 +232,7 @@ def pre_token_switch(self: Parser, node: doc.Expr) -> None: # pylint: disable=u @dispatch.register(token="relax", type_name="post_token_switch") def post_token_switch(self: Parser, node: doc.Expr) -> None: + print("it could be the last one") ir_builder = IRBuilder.current() result = ir_builder.get() ir_builder.__exit__(None, None, None) @@ -321,7 +322,52 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: def visit_return(self: Parser, node: doc.Assign) -> None: value = self.eval_expr(node.value) value = convert_to_expr(value) - R.ret_value(value) + """ + TODO (yongwww): + issue 1): Save all values into a global list, and add into global_info in the end of parsing -> Status: wip + => we can just have a single api like add_return_global_info into the ReturnGlobalInfo, + Solution: + o1: create one if no ReturnGlobalInfo found, therefore we can avoid saving values when parsing + o2: Create an IRModuleNode::GetGlobalInfo(String name), plus UpdateGlobalInfo should help do the modification + But how to expose it to parser? doesn't work, hard to expose to ir_builder + [x]o3: add ModuleGetGlobalInfos and ModuleUpdateGlobalInfos in src/script/ir_builder/ir/ir.cc + and python/tvm/script/ir_builder/ir/ir.py + how to reassembly the ReturnGlobalInfo and make sure just only one in the ir_module + + + issue 2): global issue was required explicitly at the beggining of the ir_module, + need to figure out a way to update/create a return global info at any point -> Status: todo + Solution: No matter if the tvmscript has explicitly feed the module_gloabl_info or not, and one for return! + + issue 3): need to hide the return global info, it shouldn't be visible to users, + it might crash the exiting test cases -> Status: todo + Solution: solution in 2) should help fix test cases, since we will have return_global_info anyway, + the only concern is that the ordering of return_exprs, topological ordering for relax func parsing + should fix it too. And it just potentially impact test structural_equal, no functionality impacted! + + Conclusion: + 1) The best way is to add "Bool return_body" in SeqExpr, but we need to keep IR constrained at this moment + 2) Introduce func_info in relax function level, similar to global info, but it will introduce return_func_info + into Function, and the IR is affected, then prefer option 1) + So, I decided to move forward with GlobalInfo, because it is already there. + """ + ginfo = I.module_get_global_infos() + print("the current global info: ", ginfo) + # ReturnGlobalInfo exists, append a new value + from tvm.ir.container import Array, Map + from tvm.ir.global_info import ReturnGlobalInfo + + ret_ginfo = I.return_global_info([value]) + if "return_exprs" in ginfo: + r_ginfos = [] + for rginfo in ginfo["return_exprs"]: + r_ginfos.append(rginfo) + r_ginfos.append(ret_ginfo) + ginfo["return_exprs"] = r_ginfos + else: + ginfo["return_exprs"] = [ret_ginfo] + I.module_update_global_infos(ginfo) + R.ret_value(value) # TODO(yongwww): probably we can remove R.ret_value as well @dispatch.register(token="relax", type_name="If") @@ -331,9 +377,10 @@ def visit_if(self: Parser, node: doc.If) -> None: with R.If(self.eval_expr(node.test)) as if_frame: with self.var_table.with_frame(): with R.Then(): - print("fuck here") + print("Entering R.Then") self.visit_body(node.body) with self.var_table.with_frame(): with R.Else(): + print("Entering R.Else") self.visit_body(node.orelse) self.var_table.add(if_frame.var_name, if_frame.var, allow_shadowing=True) diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc index 48f56d60d68c..f55e18cf8948 100644 --- a/src/ir/global_info.cc +++ b/src/ir/global_info.cc @@ -24,9 +24,19 @@ #include namespace tvm { + +TVM_REGISTER_NODE_TYPE(ReturnGlobalInfoNode); + +TVM_REGISTER_GLOBAL("ir.ReturnGlobalInfo").set_body_typed([](Array return_exprs) { + auto n = make_object(); + n->return_exprs = return_exprs; + return ReturnGlobalInfo(n); +}); + TVM_REGISTER_NODE_TYPE(DummyGlobalInfoNode); TVM_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() { auto n = DummyGlobalInfo(make_object()); return n; }); + } // namespace tvm diff --git a/src/ir/module.cc b/src/ir/module.cc index da1f3942c78f..f6d5bcb9956c 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -249,6 +249,10 @@ void IRModuleNode::UpdateGlobalInfo(const String& name, const Array& this->global_infos.Set(name, info); } +Array IRModuleNode::GetGlobalInfo(const String& name) { + return this->global_infos[name]; +} + void IRModuleNode::Remove(const GlobalVar& var) { auto functions_node = this->functions.CopyOnWrite(); functions_node->erase(var); @@ -433,6 +437,10 @@ TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction") .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); }); +TVM_REGISTER_GLOBAL("ir.Module_GetGlobalInfo").set_body_typed([](IRModule mod, String name) { + return mod->GetGlobalInfo(name); +}); + TVM_REGISTER_GLOBAL("ir.Module_UpdateGlobalInfo") .set_body_typed([](IRModule mod, String name, Array global_info) { mod->UpdateGlobalInfo(name, global_info); diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index 148e90b28c05..c415491a684f 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -91,11 +91,28 @@ void ModuleGlobalInfos(Map> global_infos) { } } +Map> ModuleGetGlobalInfos() { + CHECK(IRBuilder::IsInScope()); + IRModuleFrame frame = FindModuleFrame("I.ModuleGlobalInfos"); + return frame->global_infos; +} + +void ModuleUpdateGlobalInfos(Map> global_infos) { + if (IRBuilder::IsInScope()) { + IRModuleFrame frame = FindModuleFrame("I.ModuleGlobalInfos"); + frame->global_infos = global_infos; + } +} + TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs); TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGlobalInfos").set_body_typed(ModuleGlobalInfos); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGetGlobalInfos") + .set_body_typed(ModuleGetGlobalInfos); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleUpdateGlobalInfos") + .set_body_typed(ModuleUpdateGlobalInfos); } // namespace ir } // namespace ir_builder diff --git a/src/script/ir_builder/ir/utils.h b/src/script/ir_builder/ir/utils.h index 58d5e53f7032..f81592ba13c4 100644 --- a/src/script/ir_builder/ir/utils.h +++ b/src/script/ir_builder/ir/utils.h @@ -29,10 +29,11 @@ namespace ir { inline IRModuleFrame FindModuleFrame(const String& method) { IRBuilder builder = IRBuilder::Current(); if (Optional frame = builder->FindFrame()) { - const Optional& last_module_frame = builder->GetLastFrame(); - if (last_module_frame.defined() && last_module_frame.value() == frame) { - return frame.value(); - } + return frame.value(); + // const Optional& last_module_frame = builder->GetLastFrame(); + // if (last_module_frame.defined() && last_module_frame.value() == frame) { + // return frame.value(); + //} } else { LOG(FATAL) << "ValueError: IRModule frame not find. Please ensure '" << method << "' is called under I.ir_module()"; diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index bd3209cd6319..773d92291366 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -59,6 +59,7 @@ void FunctionFrameNode::ExitWithScope() { if (output.defined()) { body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value())); } else { + // todo (yongwww): handle null for no return for func's body body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, tvm::relax::NullExpr())); } diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index f67bee9c4149..37fa1b053784 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -109,8 +109,8 @@ void RetValue(const tvm::relax::Expr& value) { // Exit BlockFrame if (block_frame.defined()) { block_frame.value()->ExitWithScope(); - ICHECK(!IRBuilder::Current()->FindFrame()) - << "ValueError: Relax functions don't support return in true/false branch of If Node."; + //ICHECK(!IRBuilder::Current()->FindFrame()) + // << "ValueError: Relax functions don't support return in true/false branch of If Node."; } // Step 2. Add the output value to the function frame. Array all_frames = IRBuilder::Current()->frames; diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 62919246b073..d8478240f6d5 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -70,6 +70,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ->Call({d->AsDoc(mod->attrs, p->Attr("attrs"))}))); } if (mod->global_infos.defined() && !mod->global_infos.empty()) { + // todo(yongwww): global return_exprs for printer (*f)->stmts.push_back(ExprStmtDoc( IR(d, "module_global_infos") // ->Call({d->AsDoc(mod->global_infos, p->Attr("global_infos"))}))); @@ -102,6 +103,16 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return IR(d, "GlobalVar")->Call({LiteralDoc::Str(gv->name_hint, p->Attr("name_hint"))}); }); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](ReturnGlobalInfo rginfo, ObjectPath p, IRDocsifier d) -> Doc { + Array return_exprs; + for (const auto& ret_expr : rginfo->return_exprs) { + return_exprs.push_back(d->AsDoc(ret_expr, p->Attr("return_exprs"))); + } + return IR(d, "return_global_info")->Call({ListDoc(return_exprs)}); + }); + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](GlobalInfo ginfo, ObjectPath p, IRDocsifier d) -> Doc { return IR(d, "dummy_global_info")->Call({}); diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc index 8a50fe969850..ceaebba19b75 100644 --- a/src/script/printer/relax/binding.cc +++ b/src/script/printer/relax/binding.cc @@ -27,6 +27,7 @@ IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, const IRDocsifier& using relax::SeqExpr; ExprDoc cond = d->AsDoc(n->cond, n_p->Attr("cond")); std::vector> branches{ + // todo(yongwww): check if it's return PrintSeqExpr(Downcast(n->true_branch), n_p->Attr("true_branch"), d, false), PrintSeqExpr(Downcast(n->false_branch), n_p->Attr("false_branch"), d, false), }; diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index dd8f3f813cad..5cd0feb973af 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -22,7 +22,7 @@ import tvm.script import tvm.testing from tvm import IRModule, relax, tir, topi -from tvm.ir import DummyGlobalInfo +from tvm.ir import ReturnGlobalInfo, DummyGlobalInfo from tvm.script.parser import ir as I from tvm.script.parser import relax as R from tvm.script.parser import tir as T @@ -190,9 +190,17 @@ class TestModule: I.module_attrs({"attr": 10}) I.module_global_infos( { - "dummy": [ - I.dummy_global_info(), # dummy[0] - I.dummy_global_info(), # dummy[1] + "return_exprs": [ + I.return_global_info( + [ + R.prim_value(1), + ] + ), # dummy[0] + I.return_global_info( + [ + R.prim_value(2), + ] + ), # dummy[1] ] } ) @@ -214,6 +222,7 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): gv0 = R.call_tir(tir_func, x, R.Tensor((128, 128), dtype="float32")) return gv0 + TestModule.show() x = relax.Var("x", R.Tensor((128, 128), "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x,)): @@ -1165,6 +1174,8 @@ def mul_add(x: R.Tensor) -> R.Tensor: def test_control_flow(): @tvm.script.ir_module class ControlFlowExample: + I.module_global_infos({"return_exprs": []}) + @R.function def foo(x: R.Tensor) -> R.Tensor: y: R.Tensor((), dtype="bool") = R.const(True, dtype="bool") @@ -1175,8 +1186,12 @@ def foo(x: R.Tensor) -> R.Tensor: return x ControlFlowExample.show() + print("yongwww get_global_info:") + print(ControlFlowExample.get_global_info("return_exprs")) + print("yongwww get_global_info: done") if __name__ == "__main__": # tvm.testing.main() + # test_module_with_attr_and_global_info() test_control_flow() From 65b53e801ba6e6aaf657c21439d12789350d2f93 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Tue, 28 Feb 2023 21:04:48 -0800 Subject: [PATCH 79/81] update printer --- include/tvm/script/printer/ir_docsifier.h | 11 +++ python/tvm/ir/global_info.py | 1 + python/tvm/script/parser/core/parser.py | 1 - python/tvm/script/parser/relax/parser.py | 44 ++++++------ src/script/printer/ir/ir.cc | 22 +++--- src/script/printer/ir_docsifier.cc | 8 +++ src/script/printer/relax/binding.cc | 77 +++++++++++++++++++-- tests/python/relax/test_tvmscript_parser.py | 9 ++- 8 files changed, 135 insertions(+), 38 deletions(-) diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index 156daebf001f..cbac10d69198 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -19,8 +19,10 @@ #ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_ #define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_ +#include #include #include +#include #include #include @@ -143,8 +145,12 @@ class IRDocsifierNode : public Object { Array dispatch_tokens; /*! \brief Mapping from a var to its info */ std::unordered_map obj2info; + /*! \brief A binding table that maps var to value. */ + std::unordered_map binding_table_; /*! \brief Metadata printing */ std::unordered_map> metadata; + /*! \brief Return exprs used to help tell whether or not an expr is a return*/ + std::unordered_set return_exprs; /*! \brief The variable names used already */ std::unordered_set defined_names; /*! \brief Common prefixes of variable usages */ @@ -206,6 +212,11 @@ class IRDocsifierNode : public Object { Optional GetVarDoc(const ObjectRef& obj) const; /*! \brief Add a TVM object to the metadata section*/ ExprDoc AddMetadata(const ObjectRef& obj); + + Optional LookupBinding(const relax::Var& var); + + void AddReturnExpr(const RelayExpr& ret_expr); + /*! * \brief Check if a variable exists in the table. * \param obj The variable object. diff --git a/python/tvm/ir/global_info.py b/python/tvm/ir/global_info.py index 5dbf4476a501..b3936864b681 100644 --- a/python/tvm/ir/global_info.py +++ b/python/tvm/ir/global_info.py @@ -50,6 +50,7 @@ class ReturnGlobalInfo(GlobalInfo): def __init__(self, return_exprs: List[Expr]) -> None: print("yes entering ReturnGlobalInfo in global_info.py, return_exprs: ", return_exprs) + self.return_exprs = return_exprs self.__init_handle_by_constructor__(_ffi_api.ReturnGlobalInfo, return_exprs) def add(): diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 619b57698c11..105164ed5ffc 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -683,5 +683,4 @@ def visit_Return(self, node: doc.Return) -> Any: # pylint: disable=invalid-name res : Any The visiting result. """ - print("yes visiting return, value: ", node.value) return _dispatch(self, "Return")(self, node) diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index 45a05b3cb160..39f4fcb0c2d9 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -30,9 +30,14 @@ from ...ir_builder import relax as R from ...ir_builder.base import IRBuilder from .._core import Parser, dispatch, doc +from ..core.parser import VarTable as return_var_table from .entry import MatchCastPair, StructInfoProxy, TupleProxy +# An global list to record all exprs to return +return_expr_list = [] + + def bind_assign_value( self: Parser, node: doc.expr, @@ -232,7 +237,7 @@ def pre_token_switch(self: Parser, node: doc.Expr) -> None: # pylint: disable=u @dispatch.register(token="relax", type_name="post_token_switch") def post_token_switch(self: Parser, node: doc.Expr) -> None: - print("it could be the last one") + print("Entering post_token_switch") ir_builder = IRBuilder.current() result = ir_builder.get() ir_builder.__exit__(None, None, None) @@ -327,12 +332,16 @@ def visit_return(self: Parser, node: doc.Assign) -> None: issue 1): Save all values into a global list, and add into global_info in the end of parsing -> Status: wip => we can just have a single api like add_return_global_info into the ReturnGlobalInfo, Solution: - o1: create one if no ReturnGlobalInfo found, therefore we can avoid saving values when parsing + [x]o1: Save all return values in a global list, and assembly it in the end of parsing, + don't allow user to provide it. Ignore if it exists o2: Create an IRModuleNode::GetGlobalInfo(String name), plus UpdateGlobalInfo should help do the modification But how to expose it to parser? doesn't work, hard to expose to ir_builder - [x]o3: add ModuleGetGlobalInfos and ModuleUpdateGlobalInfos in src/script/ir_builder/ir/ir.cc + o3: add ModuleGetGlobalInfos and ModuleUpdateGlobalInfos in src/script/ir_builder/ir/ir.cc and python/tvm/script/ir_builder/ir/ir.py - how to reassembly the ReturnGlobalInfo and make sure just only one in the ir_module + how to reassembly the ReturnGlobalInfo is a problem, before the fetch returnGlobalInfo is a runtime.Object + seems there is no way to update it, so give up o3 + + Solution: expose get elements of ReturnGlobalInfo into IR-builder issue 2): global issue was required explicitly at the beggining of the ir_module, @@ -351,22 +360,17 @@ def visit_return(self: Parser, node: doc.Assign) -> None: into Function, and the IR is affected, then prefer option 1) So, I decided to move forward with GlobalInfo, because it is already there. """ - ginfo = I.module_get_global_infos() - print("the current global info: ", ginfo) - # ReturnGlobalInfo exists, append a new value - from tvm.ir.container import Array, Map - from tvm.ir.global_info import ReturnGlobalInfo - - ret_ginfo = I.return_global_info([value]) - if "return_exprs" in ginfo: - r_ginfos = [] - for rginfo in ginfo["return_exprs"]: - r_ginfos.append(rginfo) - r_ginfos.append(ret_ginfo) - ginfo["return_exprs"] = r_ginfos - else: - ginfo["return_exprs"] = [ret_ginfo] - I.module_update_global_infos(ginfo) + + return_expr_list.append(value) + print("Entering return visit") + # use var_table to record the return exprs + ginfos = I.module_get_global_infos() + print("the current global info: ", ginfos) + ret_ginfo = I.return_global_info(return_expr_list) + # str "relax_return_exprs" was reserved as key for return exprs in global_info + ginfos["return_exprs"] = [ret_ginfo] + I.module_update_global_infos(ginfos) + R.ret_value(value) # TODO(yongwww): probably we can remove R.ret_value as well diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index d8478240f6d5..2d1bcd7c34ad 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -104,14 +104,20 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( - "", [](ReturnGlobalInfo rginfo, ObjectPath p, IRDocsifier d) -> Doc { - Array return_exprs; - for (const auto& ret_expr : rginfo->return_exprs) { - return_exprs.push_back(d->AsDoc(ret_expr, p->Attr("return_exprs"))); - } - return IR(d, "return_global_info")->Call({ListDoc(return_exprs)}); - }); + .set_dispatch("", + [](ReturnGlobalInfo rginfo, ObjectPath p, + IRDocsifier d) -> Doc { + Array return_exprs; + for (const auto& ret_expr : rginfo->return_exprs) { + d->AddReturnExpr(ret_expr); + // return_exprs.push_back(d->AsDoc(ret_expr, + // p->Attr("return_exprs"))); + } + // return IR(d, + // "return_global_info")->Call({ListDoc(return_exprs)}); + + return IR(d, "return_global_info")->Call({}); + }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](GlobalInfo ginfo, ObjectPath p, IRDocsifier d) -> Doc { diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 936534480ffb..441448353ccf 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -63,6 +63,14 @@ ExprDoc IRDocsifierNode::AddMetadata(const ObjectRef& obj) { [{LiteralDoc::Int(index, NullOpt)}]; } +void IRDocsifierNode::AddReturnExpr(const RelayExpr& ret_expr) { return_exprs.insert(ret_expr); } + +Optional IRDocsifierNode::LookupBinding(const relax::Var& var) { + auto it = binding_table_.find(var->vid); + if (it == binding_table_.end()) return NullOpt; + return it->second; +} + bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); } void IRDocsifierNode::RemoveVar(const ObjectRef& obj) { diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc index ceaebba19b75..06d2d9053e0c 100644 --- a/src/script/printer/relax/binding.cc +++ b/src/script/printer/relax/binding.cc @@ -26,15 +26,76 @@ IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, const IRDocsifier& const Optional& var, const Optional& ann) { using relax::SeqExpr; ExprDoc cond = d->AsDoc(n->cond, n_p->Attr("cond")); - std::vector> branches{ - // todo(yongwww): check if it's return - PrintSeqExpr(Downcast(n->true_branch), n_p->Attr("true_branch"), d, false), - PrintSeqExpr(Downcast(n->false_branch), n_p->Attr("false_branch"), d, false), - }; + std::vector> branches; + // todo(yongwww): looks the return_exprs are the values, and normalizer adds a new binding + // need to figure out a way to get if the seqexpr.body was bound to one of return_exprs, too + // complicated! + for (auto ret_expr : d->return_exprs) { + LOG(INFO) << "yongwww 33 ret_expr: " << ret_expr; + } + auto true_seq_expr = Downcast(n->true_branch); + auto false_seq_expr = Downcast(n->false_branch); + if (const auto* var_node = true_seq_expr->body.as()) { + auto t_var = GetRef(var_node); + LOG(INFO) << "yongwww true_seq_expr->body: " << t_var << " -- val: " << d->LookupBinding(t_var); + } + + for (auto ele : d->binding_table_) { + LOG(INFO) << "ele k: " << ele.first << " - value: " << ele.second; + } + + if (const auto* var_node = false_seq_expr->body.as()) { + auto t_var = GetRef(var_node); + LOG(INFO) << "yongwww false_seq_expr->body: " << t_var + << " -- val: " << d->LookupBinding(t_var); + } + bool ret_true_branch = false; + bool ret_false_branch = false; + relax::BindingBlock last_block_true = true_seq_expr->blocks[true_seq_expr->blocks.size() - 1]; + relax::Binding last_binding_true = + last_block_true->bindings[last_block_true->bindings.size() - 1]; + if (auto* var_binding = last_binding_true.as()) { + auto last_var_binding_true = GetRef(var_binding); + if (last_var_binding_true->var.same_as(true_seq_expr->body) && + d->return_exprs.find(last_var_binding_true->value) != d->return_exprs.end()) { + ret_true_branch = true; + LOG(INFO) << "yongwww ret_true_branch true"; + } + } + + relax::BindingBlock last_block_false = false_seq_expr->blocks[false_seq_expr->blocks.size() - 1]; + relax::Binding last_binding_false = + last_block_false->bindings[last_block_false->bindings.size() - 1]; + if (auto* var_binding = last_binding_false.as()) { + auto last_var_binding_false = GetRef(var_binding); + if (last_var_binding_false->var.same_as(false_seq_expr->body) && + d->return_exprs.find(last_var_binding_false->value) != d->return_exprs.end()) { + ret_false_branch = true; + LOG(INFO) << "yongwww ret_false_branch true"; + } + } + + if (d->return_exprs.find(true_seq_expr->body) != d->return_exprs.end()) { + branches.push_back(PrintSeqExpr(true_seq_expr, n_p->Attr("true_branch"), d, ret_true_branch)); + } else { + branches.push_back(PrintSeqExpr(true_seq_expr, n_p->Attr("true_branch"), d, ret_true_branch)); + } + + if (d->return_exprs.find(false_seq_expr->body) != d->return_exprs.end()) { + branches.push_back( + PrintSeqExpr(false_seq_expr, n_p->Attr("false_branch"), d, ret_false_branch)); + } else { + branches.push_back( + PrintSeqExpr(false_seq_expr, n_p->Attr("false_branch"), d, ret_false_branch)); + } + if (var.defined()) { for (Array& stmts : branches) { - ExprDoc ret = Downcast(stmts.back())->expr; - stmts.Set(stmts.size() - 1, AssignDoc(var.value(), ret, ann)); + if (!stmts.back()->IsInstance()) { + ExprDoc ret = Downcast(stmts.back())->expr; + stmts.Set(stmts.size() - 1, AssignDoc(var.value(), ret, ann)); + } + LOG(INFO) << "yongwww stmts.back() key: " << stmts.back()->GetTypeKey(); } } return IfDoc(cond, branches[0], branches[1]); @@ -56,6 +117,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::VarBinding n, ObjectPath n_p, IRDocsifier d) -> Doc { + LOG(INFO) << "n->var: " << n->var << " --- value: " << n->value; + d->binding_table_[n->var->vid] = n->value; if (const auto if_ = n->value.as()) { Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 5cd0feb973af..e04f6657a451 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -190,6 +190,10 @@ class TestModule: I.module_attrs({"attr": 10}) I.module_global_infos( { + "dummy": [ + I.dummy_global_info(), # dummy[0] + I.dummy_global_info(), # dummy[1] + ], "return_exprs": [ I.return_global_info( [ @@ -201,7 +205,7 @@ class TestModule: R.prim_value(2), ] ), # dummy[1] - ] + ], } ) @@ -1183,7 +1187,7 @@ def foo(x: R.Tensor) -> R.Tensor: return R.add(x, x) else: return R.multiply(x, x) - return x + return R.subtract(x, x) ControlFlowExample.show() print("yongwww get_global_info:") @@ -1194,4 +1198,5 @@ def foo(x: R.Tensor) -> R.Tensor: if __name__ == "__main__": # tvm.testing.main() # test_module_with_attr_and_global_info() + # test_module_with_attr_and_global_info() test_control_flow() From e32d773ef9bd0549fcc4f09c98b3d156b4964dd0 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Wed, 1 Mar 2023 09:33:07 -0800 Subject: [PATCH 80/81] Remove null_expr --- include/tvm/ir/global_info.h | 25 +++---- include/tvm/relax/expr.h | 22 ------ include/tvm/script/printer/ir_docsifier.h | 2 +- python/tvm/ir/__init__.py | 2 +- python/tvm/ir/global_info.py | 27 ++----- python/tvm/script/ir_builder/ir/__init__.py | 2 +- python/tvm/script/ir_builder/ir/ir.py | 13 ++-- python/tvm/script/parser/core/parser.py | 6 ++ python/tvm/script/parser/ir/__init__.py | 2 +- python/tvm/script/parser/relax/parser.py | 30 ++++---- src/ir/global_info.cc | 13 ++-- src/relax/ir/expr.cc | 11 --- src/script/ir_builder/relax/frame.cc | 6 +- src/script/printer/ir/ir.cc | 33 ++++----- src/script/printer/ir_docsifier.cc | 4 +- src/script/printer/relax/binding.cc | 15 ++-- tests/python/relax/test_tvmscript_parser.py | 79 ++++++++++++++------- 17 files changed, 138 insertions(+), 154 deletions(-) diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h index 3fd96dc37930..63682dd89b64 100644 --- a/include/tvm/ir/global_info.h +++ b/include/tvm/ir/global_info.h @@ -76,29 +76,30 @@ class DummyGlobalInfo : public GlobalInfo { }; /*! - * \brief A return global info sub-class for expressions to return. + * \brief A return global info sub-class for return expressions. */ -class ReturnGlobalInfoNode : public GlobalInfoNode { +class RelaxReturnGlobalInfoNode : public GlobalInfoNode { public: - Array return_exprs; + Array relax_return_exprs; void VisitAttrs(tvm::AttrVisitor* v) {} - static constexpr const char* _type_key = "ReturnGlobalInfo"; + static constexpr const char* _type_key = "RelaxReturnGlobalInfo"; - TVM_DLL bool SEqualReduce(const ReturnGlobalInfoNode* other, SEqualReducer equal) const { - return equal(return_exprs, other->return_exprs); + TVM_DLL bool SEqualReduce(const RelaxReturnGlobalInfoNode* other, SEqualReducer equal) const { + // return equal(relax_return_exprs, other->relax_return_exprs) + return true; } - TVM_DLL void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(return_exprs); } - TVM_DECLARE_FINAL_OBJECT_INFO(ReturnGlobalInfoNode, GlobalInfoNode); + TVM_DLL void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(relax_return_exprs); } + TVM_DECLARE_FINAL_OBJECT_INFO(RelaxReturnGlobalInfoNode, GlobalInfoNode); }; /*! - * \brief Managed reference to ReturnGlobalInfoNode. - * \sa ReturnGlobalInfoNode + * \brief Managed reference to RelaxReturnGlobalInfoNode. + * \sa RelaxReturnGlobalInfoNode */ -class ReturnGlobalInfo : public GlobalInfo { +class RelaxReturnGlobalInfo : public GlobalInfo { public: - TVM_DEFINE_OBJECT_REF_METHODS(ReturnGlobalInfo, GlobalInfo, ReturnGlobalInfoNode); + TVM_DEFINE_OBJECT_REF_METHODS(RelaxReturnGlobalInfo, GlobalInfo, RelaxReturnGlobalInfoNode); }; } // namespace tvm diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 7b54fa1864e6..5ff8a6dc205f 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -649,28 +649,6 @@ class NullExprNode : public LeafExprNode { TVM_DECLARE_FINAL_OBJECT_INFO(NullExprNode, LeafExprNode); }; -/*! - * \brief Managed reference to NullExprNode - * \sa NullExprNode - */ -class NullExpr : public LeafExpr { - public: - /*! - * \brief The constructor - * \param span The source span of the expression. - */ - TVM_DLL explicit NullExpr(Span span); - - /*! - * \brief Create a null expression. - * \param span The source span of the expression. - * \return The created prim value. - */ - - TVM_DEFINE_OBJECT_REF_METHODS(NullExpr, LeafExpr, NullExprNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(NullExprNode); -}; - /*! * \brief Represent a string literal constant. */ diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index cbac10d69198..469d94d3993d 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -150,7 +150,7 @@ class IRDocsifierNode : public Object { /*! \brief Metadata printing */ std::unordered_map> metadata; /*! \brief Return exprs used to help tell whether or not an expr is a return*/ - std::unordered_set return_exprs; + std::unordered_set relax_return_exprs; /*! \brief The variable names used already */ std::unordered_set defined_names; /*! \brief Common prefixes of variable usages */ diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 9ada942a9e2f..71c6d9c3a238 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -34,7 +34,7 @@ from .container import Array, Map from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelayExpr from .function import BaseFunc, CallingConv -from .global_info import GlobalInfo, ReturnGlobalInfo, DummyGlobalInfo +from .global_info import GlobalInfo, RelaxReturnGlobalInfo, DummyGlobalInfo from .memory_pools import ( ConstantMemoryPools, ConstantPoolInfo, diff --git a/python/tvm/ir/global_info.py b/python/tvm/ir/global_info.py index b3936864b681..17f00c599c7a 100644 --- a/python/tvm/ir/global_info.py +++ b/python/tvm/ir/global_info.py @@ -37,33 +37,20 @@ def same_as(self, other): return super().__eq__(other) -class ReturnGlobalInfo(GlobalInfo): - """ReturnGlobalInfo in the IR. +class RelaxReturnGlobalInfo(GlobalInfo): + """RelaxReturnGlobalInfo in the IR. Parameters ---------- - return_exprs : List[Expr] + relax_return_exprs : List[Expr] The expressions to be returned. """ - return_exprs: List[Expr] + relax_return_exprs: List[Expr] - def __init__(self, return_exprs: List[Expr]) -> None: - print("yes entering ReturnGlobalInfo in global_info.py, return_exprs: ", return_exprs) - self.return_exprs = return_exprs - self.__init_handle_by_constructor__(_ffi_api.ReturnGlobalInfo, return_exprs) - - def add(): - pass - - def update(return_exprs: List[Expr]): - pass - - def get() -> GlobalInfo: - pass - - def get_exprs(self): - return self.return_exprs + def __init__(self, relax_return_exprs: List[Expr]) -> None: + self.relax_return_exprs = relax_return_exprs + self.__init_handle_by_constructor__(_ffi_api.RelaxReturnGlobalInfo, relax_return_exprs) class DummyGlobalInfo(GlobalInfo): diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py index 91eb0535ffda..6865c2fad4ed 100644 --- a/python/tvm/script/ir_builder/ir/__init__.py +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -24,6 +24,6 @@ module_global_infos, module_get_global_infos, module_update_global_infos, - return_global_info, + relax_return_global_info, dummy_global_info, ) diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index 90c1568b69c5..cf4bf13c0fb0 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -18,7 +18,7 @@ from typing import Dict, List -from tvm.ir import BaseFunc, GlobalVar, GlobalInfo, ReturnGlobalInfo, DummyGlobalInfo +from tvm.ir import BaseFunc, GlobalVar, GlobalInfo, RelaxReturnGlobalInfo, DummyGlobalInfo from tvm.ir import RelayExpr as Expr from tvm.runtime import Object as tvm_Object @@ -123,20 +123,21 @@ def module_update_global_infos(global_infos: Dict[str, List[GlobalInfo]]) -> Non ############################### GlobalInfo ############################### -def return_global_info(return_exprs: List[Expr]) -> ReturnGlobalInfo: +def relax_return_global_info(relax_return_exprs: List[Expr] = None) -> RelaxReturnGlobalInfo: """Create a return global info expression. Parameters ---------- - return_exprs : List[Expr] + relax_return_exprs : List[Expr] The expressions to be returned. Returns ------- - res : ReturnGlobalInfo + res : RelaxReturnGlobalInfo The result return global info. """ - print("yes return_global_info in ir_builder/ir.py") - return ReturnGlobalInfo(return_exprs) # type: ignore[attr-defined] # pylint: disable=no-member + if relax_return_exprs is None: + relax_return_exprs = [] + return RelaxReturnGlobalInfo(relax_return_exprs) # type: ignore[attr-defined] # pylint: disable=no-member def dummy_global_info() -> DummyGlobalInfo: diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 105164ed5ffc..56ca0299b250 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -231,16 +231,22 @@ class Parser(doc.NodeVisitor): var_table : VarTable The variable table for parsing. + + aux_dict: Dict[str, List[Any]] + The auxiliary dict for storing global info. like return exprs + of RelaxReturnGloablInfo """ diag: Diagnostics dispatch_tokens: List[str] var_table: VarTable + aux_dict: Dict[str, List[Any]] def __init__(self, source: Source) -> None: self.diag = Diagnostics(source) self.dispatch_tokens = ["default"] self.var_table = VarTable() + self.aux_dict = {} def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any: """The main parse method for parser. diff --git a/python/tvm/script/parser/ir/__init__.py b/python/tvm/script/parser/ir/__init__.py index 4dbb615b89e4..3ede928ef423 100644 --- a/python/tvm/script/parser/ir/__init__.py +++ b/python/tvm/script/parser/ir/__init__.py @@ -25,6 +25,6 @@ "module_global_infos", "module_get_global_infos", "module_update_global_infos", - "return_global_info", + "relax_return_global_info", "dummy_global_info", ] diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index 39f4fcb0c2d9..f7fad0e664a2 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -34,10 +34,6 @@ from .entry import MatchCastPair, StructInfoProxy, TupleProxy -# An global list to record all exprs to return -return_expr_list = [] - - def bind_assign_value( self: Parser, node: doc.expr, @@ -330,7 +326,7 @@ def visit_return(self: Parser, node: doc.Assign) -> None: """ TODO (yongwww): issue 1): Save all values into a global list, and add into global_info in the end of parsing -> Status: wip - => we can just have a single api like add_return_global_info into the ReturnGlobalInfo, + => we can just have a single api like add_relax_return_global_info into the RelaxReturnGlobalInfo, Solution: [x]o1: Save all return values in a global list, and assembly it in the end of parsing, don't allow user to provide it. Ignore if it exists @@ -338,10 +334,10 @@ def visit_return(self: Parser, node: doc.Assign) -> None: But how to expose it to parser? doesn't work, hard to expose to ir_builder o3: add ModuleGetGlobalInfos and ModuleUpdateGlobalInfos in src/script/ir_builder/ir/ir.cc and python/tvm/script/ir_builder/ir/ir.py - how to reassembly the ReturnGlobalInfo is a problem, before the fetch returnGlobalInfo is a runtime.Object + how to reassembly the RelaxReturnGlobalInfo is a problem, before the fetch returnGlobalInfo is a runtime.Object seems there is no way to update it, so give up o3 - Solution: expose get elements of ReturnGlobalInfo into IR-builder + Solution: expose get elements of RelaxReturnGlobalInfo into IR-builder issue 2): global issue was required explicitly at the beggining of the ir_module, @@ -350,8 +346,8 @@ def visit_return(self: Parser, node: doc.Assign) -> None: issue 3): need to hide the return global info, it shouldn't be visible to users, it might crash the exiting test cases -> Status: todo - Solution: solution in 2) should help fix test cases, since we will have return_global_info anyway, - the only concern is that the ordering of return_exprs, topological ordering for relax func parsing + Solution: solution in 2) should help fix test cases, since we will have relax_return_global_info anyway, + the only concern is that the ordering of relax_return_exprs, topological ordering for relax func parsing should fix it too. And it just potentially impact test structural_equal, no functionality impacted! Conclusion: @@ -361,14 +357,16 @@ def visit_return(self: Parser, node: doc.Assign) -> None: So, I decided to move forward with GlobalInfo, because it is already there. """ - return_expr_list.append(value) - print("Entering return visit") - # use var_table to record the return exprs + # "relax_return_exprs" was used as key for return exprs + return_expr_key = "relax_return_exprs" + if return_expr_key not in self.aux_dict: + self.aux_dict[return_expr_key] = [] + self.aux_dict[return_expr_key].append(value) ginfos = I.module_get_global_infos() - print("the current global info: ", ginfos) - ret_ginfo = I.return_global_info(return_expr_list) - # str "relax_return_exprs" was reserved as key for return exprs in global_info - ginfos["return_exprs"] = [ret_ginfo] + + ret_ginfo = I.relax_return_global_info(self.aux_dict[return_expr_key]) + + ginfos[return_expr_key] = [ret_ginfo] I.module_update_global_infos(ginfos) R.ret_value(value) # TODO(yongwww): probably we can remove R.ret_value as well diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc index f55e18cf8948..96587de74b03 100644 --- a/src/ir/global_info.cc +++ b/src/ir/global_info.cc @@ -25,13 +25,14 @@ #include namespace tvm { -TVM_REGISTER_NODE_TYPE(ReturnGlobalInfoNode); +TVM_REGISTER_NODE_TYPE(RelaxReturnGlobalInfoNode); -TVM_REGISTER_GLOBAL("ir.ReturnGlobalInfo").set_body_typed([](Array return_exprs) { - auto n = make_object(); - n->return_exprs = return_exprs; - return ReturnGlobalInfo(n); -}); +TVM_REGISTER_GLOBAL("ir.RelaxReturnGlobalInfo") + .set_body_typed([](Array relax_return_exprs) { + auto n = make_object(); + n->relax_return_exprs = relax_return_exprs; + return RelaxReturnGlobalInfo(n); + }); TVM_REGISTER_NODE_TYPE(DummyGlobalInfoNode); TVM_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() { diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 70cb4a5f80ff..5392be7cb69b 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -309,17 +309,6 @@ TVM_REGISTER_GLOBAL("relax.PrimValue").set_body_typed([](PrimExpr value, Span sp return PrimValue(value, span); }); -NullExpr::NullExpr(Span span) { - ObjectPtr n = make_object(); - n->checked_type_ = ObjectType(); - n->struct_info_ = ObjectStructInfo(); - n->span = std::move(span); -} - -TVM_REGISTER_NODE_TYPE(NullExprNode); - -TVM_REGISTER_GLOBAL("relax.NullExpr").set_body_typed([](Span span) { return NullExpr(span); }); - StringImm::StringImm(String value, Span span) { ObjectPtr n = make_object(); n->value = std::move(value); diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index 773d92291366..7c2282469931 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -60,8 +60,8 @@ void FunctionFrameNode::ExitWithScope() { body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value())); } else { // todo (yongwww): handle null for no return for func's body - body = - this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, tvm::relax::NullExpr())); + //binding_blocks.pop_back() + LOG(FATAL) << "ValueError: Cannot find the output for the function"; } auto dict_attrs = attrs.empty() ? NullValue() : DictAttrs(attrs); @@ -264,8 +264,6 @@ void ElseFrameNode::ExitWithScope() { output = GetSeqExprForBranch(GetRef(this), &var_name); IfFrame frame = FindIfFrame("R.Else"); frame->else_expr = output; - CHECK(frame->var_name == var_name) - << "This last binding of both branches must have the same variable."; } TVM_REGISTER_NODE_TYPE(FunctionFrameNode); diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 2d1bcd7c34ad..b07de75f8f6c 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -70,10 +70,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ->Call({d->AsDoc(mod->attrs, p->Attr("attrs"))}))); } if (mod->global_infos.defined() && !mod->global_infos.empty()) { - // todo(yongwww): global return_exprs for printer - (*f)->stmts.push_back(ExprStmtDoc( - IR(d, "module_global_infos") // - ->Call({d->AsDoc(mod->global_infos, p->Attr("global_infos"))}))); + // RelaxReturnGlobalInfo was not printed + ExprStmtDoc mod_ginfos = ExprStmtDoc( + IR(d, "module_global_infos") + ->Call({d->AsDoc(mod->global_infos, p->Attr("global_infos"))})); + if (mod->global_infos.size() > 1 || mod->global_infos.count("relax_return_exprs") == 0) { + (*f)->stmts.push_back(mod_ginfos); + } } for (const auto& entry : functions) { const GlobalVar& gv = entry.gv; @@ -104,20 +107,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", - [](ReturnGlobalInfo rginfo, ObjectPath p, - IRDocsifier d) -> Doc { - Array return_exprs; - for (const auto& ret_expr : rginfo->return_exprs) { - d->AddReturnExpr(ret_expr); - // return_exprs.push_back(d->AsDoc(ret_expr, - // p->Attr("return_exprs"))); - } - // return IR(d, - // "return_global_info")->Call({ListDoc(return_exprs)}); - - return IR(d, "return_global_info")->Call({}); - }); + .set_dispatch("", + [](RelaxReturnGlobalInfo rginfo, ObjectPath p, + IRDocsifier d) -> Doc { + for (const auto& ret_expr : rginfo->relax_return_exprs) { + d->AddReturnExpr(ret_expr); + } + return IR(d, "relax_return_global_info")->Call({}); + }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](GlobalInfo ginfo, ObjectPath p, IRDocsifier d) -> Doc { diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 441448353ccf..33ba2b78995c 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -63,7 +63,9 @@ ExprDoc IRDocsifierNode::AddMetadata(const ObjectRef& obj) { [{LiteralDoc::Int(index, NullOpt)}]; } -void IRDocsifierNode::AddReturnExpr(const RelayExpr& ret_expr) { return_exprs.insert(ret_expr); } +void IRDocsifierNode::AddReturnExpr(const RelayExpr& ret_expr) { + relax_return_exprs.insert(ret_expr); +} Optional IRDocsifierNode::LookupBinding(const relax::Var& var) { auto it = binding_table_.find(var->vid); diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc index 06d2d9053e0c..81e6d41c73c3 100644 --- a/src/script/printer/relax/binding.cc +++ b/src/script/printer/relax/binding.cc @@ -27,10 +27,10 @@ IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, const IRDocsifier& using relax::SeqExpr; ExprDoc cond = d->AsDoc(n->cond, n_p->Attr("cond")); std::vector> branches; - // todo(yongwww): looks the return_exprs are the values, and normalizer adds a new binding - // need to figure out a way to get if the seqexpr.body was bound to one of return_exprs, too + // todo(yongwww): looks the relax_return_exprs are the values, and normalizer adds a new binding + // need to figure out a way to get if the seqexpr.body was bound to one of relax_return_exprs, too // complicated! - for (auto ret_expr : d->return_exprs) { + for (auto ret_expr : d->relax_return_exprs) { LOG(INFO) << "yongwww 33 ret_expr: " << ret_expr; } auto true_seq_expr = Downcast(n->true_branch); @@ -57,7 +57,7 @@ IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, const IRDocsifier& if (auto* var_binding = last_binding_true.as()) { auto last_var_binding_true = GetRef(var_binding); if (last_var_binding_true->var.same_as(true_seq_expr->body) && - d->return_exprs.find(last_var_binding_true->value) != d->return_exprs.end()) { + d->relax_return_exprs.find(last_var_binding_true->value) != d->relax_return_exprs.end()) { ret_true_branch = true; LOG(INFO) << "yongwww ret_true_branch true"; } @@ -69,19 +69,19 @@ IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, const IRDocsifier& if (auto* var_binding = last_binding_false.as()) { auto last_var_binding_false = GetRef(var_binding); if (last_var_binding_false->var.same_as(false_seq_expr->body) && - d->return_exprs.find(last_var_binding_false->value) != d->return_exprs.end()) { + d->relax_return_exprs.find(last_var_binding_false->value) != d->relax_return_exprs.end()) { ret_false_branch = true; LOG(INFO) << "yongwww ret_false_branch true"; } } - if (d->return_exprs.find(true_seq_expr->body) != d->return_exprs.end()) { + if (d->relax_return_exprs.find(true_seq_expr->body) != d->relax_return_exprs.end()) { branches.push_back(PrintSeqExpr(true_seq_expr, n_p->Attr("true_branch"), d, ret_true_branch)); } else { branches.push_back(PrintSeqExpr(true_seq_expr, n_p->Attr("true_branch"), d, ret_true_branch)); } - if (d->return_exprs.find(false_seq_expr->body) != d->return_exprs.end()) { + if (d->relax_return_exprs.find(false_seq_expr->body) != d->relax_return_exprs.end()) { branches.push_back( PrintSeqExpr(false_seq_expr, n_p->Attr("false_branch"), d, ret_false_branch)); } else { @@ -117,7 +117,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::VarBinding n, ObjectPath n_p, IRDocsifier d) -> Doc { - LOG(INFO) << "n->var: " << n->var << " --- value: " << n->value; d->binding_table_[n->var->vid] = n->value; if (const auto if_ = n->value.as()) { Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index e04f6657a451..8a6419e5fe3e 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -22,7 +22,7 @@ import tvm.script import tvm.testing from tvm import IRModule, relax, tir, topi -from tvm.ir import ReturnGlobalInfo, DummyGlobalInfo +from tvm.ir import RelaxReturnGlobalInfo, DummyGlobalInfo from tvm.script.parser import ir as I from tvm.script.parser import relax as R from tvm.script.parser import tir as T @@ -193,19 +193,7 @@ class TestModule: "dummy": [ I.dummy_global_info(), # dummy[0] I.dummy_global_info(), # dummy[1] - ], - "return_exprs": [ - I.return_global_info( - [ - R.prim_value(1), - ] - ), # dummy[0] - I.return_global_info( - [ - R.prim_value(2), - ] - ), # dummy[1] - ], + ] } ) @@ -226,7 +214,6 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): gv0 = R.call_tir(tir_func, x, R.Tensor((128, 128), dtype="float32")) return gv0 - TestModule.show() x = relax.Var("x", R.Tensor((128, 128), "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x,)): @@ -234,8 +221,11 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): bb.emit_func_output(out) mod = bb.get() mod.update_global_info("dummy", [DummyGlobalInfo(), DummyGlobalInfo()]) + mod.update_global_info("relax_return_exprs", [RelaxReturnGlobalInfo([mod["foo"].body.body])]) mod = mod.with_attr("attr", tvm.tir.IntImm("int32", 10)) - _check(TestModule, mod) + roundtrip_mod = tvm.script.from_source(TestModule.script(show_meta=True)) + tvm.ir.assert_structural_equal(TestModule, roundtrip_mod, True) + tvm.ir.assert_structural_equal(TestModule, mod, True) def test_relax_tensor_op(): @@ -1175,11 +1165,11 @@ def mul_add(x: R.Tensor) -> R.Tensor: _check(InputModule, OutputModule) -def test_control_flow(): +def test_multi_return(): @tvm.script.ir_module - class ControlFlowExample: - I.module_global_infos({"return_exprs": []}) - + class MultiReturn: + """ + # foo was supported due to the body of SeqExpr of function is required @R.function def foo(x: R.Tensor) -> R.Tensor: y: R.Tensor((), dtype="bool") = R.const(True, dtype="bool") @@ -1187,16 +1177,53 @@ def foo(x: R.Tensor) -> R.Tensor: return R.add(x, x) else: return R.multiply(x, x) + + # noelse don't work. Main reason is SeqExpr is required + # in false branch relax.If + @R.function + def noelse(x: R.Tensor) -> R.Tensor: + y: R.Tensor((), dtype="bool") = R.const(True, dtype="bool") + if y: + return R.add(x, x) + return R.subtract(x, x) + """ + + @R.function + def foo0(x: R.Tensor) -> R.Tensor: + y: R.Tensor((), dtype="bool") = R.const(True, dtype="bool") + if y: + v = R.add(x, x) + return v + else: + v = R.multiply(x, x) + return v + + @R.function + def foo1(x: R.Tensor) -> R.Tensor: + y: R.Tensor((), dtype="bool") = R.const(True, dtype="bool") + if y: + return R.add(x, x) + else: + return R.multiply(x, x) return R.subtract(x, x) - ControlFlowExample.show() - print("yongwww get_global_info:") - print(ControlFlowExample.get_global_info("return_exprs")) - print("yongwww get_global_info: done") + @R.function + def foo2(x: R.Tensor) -> R.Tensor: + y: R.Tensor((), dtype="bool") = R.const(True, dtype="bool") + if y: + v = R.add(x, x) + else: + return R.multiply(x, x) + return v + + MultiReturn.show() + print("yongwww get_global_info:", MultiReturn.get_global_info("relax_return_exprs")) + _check(MultiReturn) + roundtrip_mod = tvm.script.from_source(MultiReturn.script(show_meta=True)) + tvm.ir.assert_structural_equal(MultiReturn, roundtrip_mod, True) if __name__ == "__main__": # tvm.testing.main() # test_module_with_attr_and_global_info() - # test_module_with_attr_and_global_info() - test_control_flow() + test_multi_return() From bfe9082080f209b08fcdef8b4042e9c3135da754 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Wed, 1 Mar 2023 16:47:14 -0800 Subject: [PATCH 81/81] update --- include/tvm/relax/expr.h | 24 ------- include/tvm/script/ir_builder/base.h | 1 - include/tvm/script/ir_builder/relax/frame.h | 1 - include/tvm/script/printer/ir_docsifier.h | 10 ++- python/tvm/script/parser/relax/parser.py | 47 ++---------- src/ir/global_info.cc | 1 - src/script/ir_builder/relax/frame.cc | 1 - src/script/ir_builder/relax/ir.cc | 8 +-- src/script/printer/ir_docsifier.cc | 6 -- src/script/printer/relax/binding.cc | 79 +++++---------------- tests/python/relax/test_tvmscript_parser.py | 29 +++++--- 11 files changed, 48 insertions(+), 159 deletions(-) diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 5ff8a6dc205f..0788193ee7c4 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -625,30 +625,6 @@ class PrimValue : public LeafExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimValueNode); }; -/*! - * \brief NullExpr. - * - * Expression representing a Null expression. - */ -class NullExprNode : public LeafExprNode { - public: - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("struct_info_", &struct_info_); - v->Visit("_checked_type_", &checked_type_); - v->Visit("span", &span); - } - - bool SEqualReduce(const NullExprNode* other, SEqualReducer equal) const { - // struct info can be deterministically derived from data. - return equal(struct_info_, other->struct_info_); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(struct_info_); } - - static constexpr const char* _type_key = "relax.expr.NullExpr"; - TVM_DECLARE_FINAL_OBJECT_INFO(NullExprNode, LeafExprNode); -}; - /*! * \brief Represent a string literal constant. */ diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index d263ea360d67..a00ea5768e23 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -236,7 +236,6 @@ class IRBuilder : public runtime::ObjectRef { * \sa IRBuilder::ExitWithScope * \sa tvm::support::With */ - static std::vector All(); static IRBuilder Current(); /*! \brief See if the current thread-local scope has an IRBuilder. */ static bool IsInScope(); diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index c99b9220520c..6d0810716753 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -101,7 +101,6 @@ class FunctionFrameNode : public SeqExprFrameNode { /*! \brief The function attributes. */ Map attrs; - // todo(yongwww) Add Map> global_infos; /*! \brief The block builder to create Relax function. */ tvm::relax::BlockBuilder block_builder; diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index 469d94d3993d..e65fa234e727 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -145,8 +145,6 @@ class IRDocsifierNode : public Object { Array dispatch_tokens; /*! \brief Mapping from a var to its info */ std::unordered_map obj2info; - /*! \brief A binding table that maps var to value. */ - std::unordered_map binding_table_; /*! \brief Metadata printing */ std::unordered_map> metadata; /*! \brief Return exprs used to help tell whether or not an expr is a return*/ @@ -212,11 +210,11 @@ class IRDocsifierNode : public Object { Optional GetVarDoc(const ObjectRef& obj) const; /*! \brief Add a TVM object to the metadata section*/ ExprDoc AddMetadata(const ObjectRef& obj); - - Optional LookupBinding(const relax::Var& var); - + /*! + * \brief Add an expression into return expression set. + * \param ret_expr The return expression. + */ void AddReturnExpr(const RelayExpr& ret_expr); - /*! * \brief Check if a variable exists in the table. * \param obj The variable object. diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index f7fad0e664a2..6058b0ccdfeb 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -323,53 +323,17 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: def visit_return(self: Parser, node: doc.Assign) -> None: value = self.eval_expr(node.value) value = convert_to_expr(value) - """ - TODO (yongwww): - issue 1): Save all values into a global list, and add into global_info in the end of parsing -> Status: wip - => we can just have a single api like add_relax_return_global_info into the RelaxReturnGlobalInfo, - Solution: - [x]o1: Save all return values in a global list, and assembly it in the end of parsing, - don't allow user to provide it. Ignore if it exists - o2: Create an IRModuleNode::GetGlobalInfo(String name), plus UpdateGlobalInfo should help do the modification - But how to expose it to parser? doesn't work, hard to expose to ir_builder - o3: add ModuleGetGlobalInfos and ModuleUpdateGlobalInfos in src/script/ir_builder/ir/ir.cc - and python/tvm/script/ir_builder/ir/ir.py - how to reassembly the RelaxReturnGlobalInfo is a problem, before the fetch returnGlobalInfo is a runtime.Object - seems there is no way to update it, so give up o3 - - Solution: expose get elements of RelaxReturnGlobalInfo into IR-builder - - - issue 2): global issue was required explicitly at the beggining of the ir_module, - need to figure out a way to update/create a return global info at any point -> Status: todo - Solution: No matter if the tvmscript has explicitly feed the module_gloabl_info or not, and one for return! - - issue 3): need to hide the return global info, it shouldn't be visible to users, - it might crash the exiting test cases -> Status: todo - Solution: solution in 2) should help fix test cases, since we will have relax_return_global_info anyway, - the only concern is that the ordering of relax_return_exprs, topological ordering for relax func parsing - should fix it too. And it just potentially impact test structural_equal, no functionality impacted! - - Conclusion: - 1) The best way is to add "Bool return_body" in SeqExpr, but we need to keep IR constrained at this moment - 2) Introduce func_info in relax function level, similar to global info, but it will introduce return_func_info - into Function, and the IR is affected, then prefer option 1) - So, I decided to move forward with GlobalInfo, because it is already there. - """ - # "relax_return_exprs" was used as key for return exprs return_expr_key = "relax_return_exprs" if return_expr_key not in self.aux_dict: self.aux_dict[return_expr_key] = [] self.aux_dict[return_expr_key].append(value) + # update the return global info ginfos = I.module_get_global_infos() - ret_ginfo = I.relax_return_global_info(self.aux_dict[return_expr_key]) - ginfos[return_expr_key] = [ret_ginfo] I.module_update_global_infos(ginfos) - - R.ret_value(value) # TODO(yongwww): probably we can remove R.ret_value as well + R.ret_value(value) @dispatch.register(token="relax", type_name="If") @@ -379,10 +343,11 @@ def visit_if(self: Parser, node: doc.If) -> None: with R.If(self.eval_expr(node.test)) as if_frame: with self.var_table.with_frame(): with R.Then(): - print("Entering R.Then") self.visit_body(node.body) with self.var_table.with_frame(): with R.Else(): - print("Entering R.Else") self.visit_body(node.orelse) - self.var_table.add(if_frame.var_name, if_frame.var, allow_shadowing=True) + if not if_frame.var_name: + self.var_table.add(str(if_frame.var), if_frame.var, allow_shadowing=True) + else: + self.var_table.add(if_frame.var_name, if_frame.var, allow_shadowing=True) diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc index 96587de74b03..22623b168b90 100644 --- a/src/ir/global_info.cc +++ b/src/ir/global_info.cc @@ -39,5 +39,4 @@ TVM_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() { auto n = DummyGlobalInfo(make_object()); return n; }); - } // namespace tvm diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index 7c2282469931..c4c57d0c2c2d 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -60,7 +60,6 @@ void FunctionFrameNode::ExitWithScope() { body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value())); } else { // todo (yongwww): handle null for no return for func's body - //binding_blocks.pop_back() LOG(FATAL) << "ValueError: Cannot find the output for the function"; } diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index 37fa1b053784..a0ae45677d89 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -109,18 +109,13 @@ void RetValue(const tvm::relax::Expr& value) { // Exit BlockFrame if (block_frame.defined()) { block_frame.value()->ExitWithScope(); - //ICHECK(!IRBuilder::Current()->FindFrame()) + // ICHECK(!IRBuilder::Current()->FindFrame()) // << "ValueError: Relax functions don't support return in true/false branch of If Node."; } // Step 2. Add the output value to the function frame. Array all_frames = IRBuilder::Current()->frames; - int i = 0; - for (auto f : all_frames) { - LOG(INFO) << "yongwww frame_" << i++ << " = " << f; - } // IfFrame if_frame = IRBuilder::Current()->FindFrame().value(); - // LOG(INFO) << "return if_frame: " << if_frame; IRBuilderFrame last_frame = all_frames[all_frames.size() - 1]; Optional then_frame = IRBuilder::Current()->GetLastFrame(); @@ -138,7 +133,6 @@ void RetValue(const tvm::relax::Expr& value) { LOG(INFO) << "return FunctionFrame frame: " << frame; frame->output = std::move(normalized_value); - // block_frame.value()->ExitWithScope(); } TVM_REGISTER_GLOBAL("script.ir_builder.relax.Function").set_body_typed(Function); diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 33ba2b78995c..4e4cf51592f2 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -67,12 +67,6 @@ void IRDocsifierNode::AddReturnExpr(const RelayExpr& ret_expr) { relax_return_exprs.insert(ret_expr); } -Optional IRDocsifierNode::LookupBinding(const relax::Var& var) { - auto it = binding_table_.find(var->vid); - if (it == binding_table_.end()) return NullOpt; - return it->second; -} - bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); } void IRDocsifierNode::RemoveVar(const ObjectRef& obj) { diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc index 81e6d41c73c3..8d48adc17147 100644 --- a/src/script/printer/relax/binding.cc +++ b/src/script/printer/relax/binding.cc @@ -27,67 +27,26 @@ IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, const IRDocsifier& using relax::SeqExpr; ExprDoc cond = d->AsDoc(n->cond, n_p->Attr("cond")); std::vector> branches; - // todo(yongwww): looks the relax_return_exprs are the values, and normalizer adds a new binding - // need to figure out a way to get if the seqexpr.body was bound to one of relax_return_exprs, too - // complicated! - for (auto ret_expr : d->relax_return_exprs) { - LOG(INFO) << "yongwww 33 ret_expr: " << ret_expr; - } - auto true_seq_expr = Downcast(n->true_branch); - auto false_seq_expr = Downcast(n->false_branch); - if (const auto* var_node = true_seq_expr->body.as()) { - auto t_var = GetRef(var_node); - LOG(INFO) << "yongwww true_seq_expr->body: " << t_var << " -- val: " << d->LookupBinding(t_var); - } - - for (auto ele : d->binding_table_) { - LOG(INFO) << "ele k: " << ele.first << " - value: " << ele.second; - } - - if (const auto* var_node = false_seq_expr->body.as()) { - auto t_var = GetRef(var_node); - LOG(INFO) << "yongwww false_seq_expr->body: " << t_var - << " -- val: " << d->LookupBinding(t_var); - } - bool ret_true_branch = false; - bool ret_false_branch = false; - relax::BindingBlock last_block_true = true_seq_expr->blocks[true_seq_expr->blocks.size() - 1]; - relax::Binding last_binding_true = - last_block_true->bindings[last_block_true->bindings.size() - 1]; - if (auto* var_binding = last_binding_true.as()) { - auto last_var_binding_true = GetRef(var_binding); - if (last_var_binding_true->var.same_as(true_seq_expr->body) && - d->relax_return_exprs.find(last_var_binding_true->value) != d->relax_return_exprs.end()) { - ret_true_branch = true; - LOG(INFO) << "yongwww ret_true_branch true"; - } - } - - relax::BindingBlock last_block_false = false_seq_expr->blocks[false_seq_expr->blocks.size() - 1]; - relax::Binding last_binding_false = - last_block_false->bindings[last_block_false->bindings.size() - 1]; - if (auto* var_binding = last_binding_false.as()) { - auto last_var_binding_false = GetRef(var_binding); - if (last_var_binding_false->var.same_as(false_seq_expr->body) && - d->relax_return_exprs.find(last_var_binding_false->value) != d->relax_return_exprs.end()) { - ret_false_branch = true; - LOG(INFO) << "yongwww ret_false_branch true"; + // normalizer adds a new binding, need to figure out if the seqexpr.body was bound + auto is_return = [](const SeqExpr& seq_expr, const IRDocsifier& dd) { + relax::BindingBlock last_block = seq_expr->blocks[seq_expr->blocks.size() - 1]; + relax::Binding last_binding = last_block->bindings[last_block->bindings.size() - 1]; + if (auto* var_binding = last_binding.as()) { + auto last_var_binding = GetRef(var_binding); + if (last_var_binding->var.same_as(seq_expr->body) && + dd->relax_return_exprs.find(last_var_binding->value) != dd->relax_return_exprs.end()) { + return true; + } } - } + return false; + }; - if (d->relax_return_exprs.find(true_seq_expr->body) != d->relax_return_exprs.end()) { - branches.push_back(PrintSeqExpr(true_seq_expr, n_p->Attr("true_branch"), d, ret_true_branch)); - } else { - branches.push_back(PrintSeqExpr(true_seq_expr, n_p->Attr("true_branch"), d, ret_true_branch)); - } - - if (d->relax_return_exprs.find(false_seq_expr->body) != d->relax_return_exprs.end()) { - branches.push_back( - PrintSeqExpr(false_seq_expr, n_p->Attr("false_branch"), d, ret_false_branch)); - } else { - branches.push_back( - PrintSeqExpr(false_seq_expr, n_p->Attr("false_branch"), d, ret_false_branch)); - } + auto true_seq_expr = Downcast(n->true_branch); + auto false_seq_expr = Downcast(n->false_branch); + bool ret_true_branch = is_return(true_seq_expr, d); + bool ret_false_branch = is_return(false_seq_expr, d); + branches.push_back(PrintSeqExpr(true_seq_expr, n_p->Attr("true_branch"), d, ret_true_branch)); + branches.push_back(PrintSeqExpr(false_seq_expr, n_p->Attr("false_branch"), d, ret_false_branch)); if (var.defined()) { for (Array& stmts : branches) { @@ -95,7 +54,6 @@ IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, const IRDocsifier& ExprDoc ret = Downcast(stmts.back())->expr; stmts.Set(stmts.size() - 1, AssignDoc(var.value(), ret, ann)); } - LOG(INFO) << "yongwww stmts.back() key: " << stmts.back()->GetTypeKey(); } } return IfDoc(cond, branches[0], branches[1]); @@ -117,7 +75,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::VarBinding n, ObjectPath n_p, IRDocsifier d) -> Doc { - d->binding_table_[n->var->vid] = n->value; if (const auto if_ = n->value.as()) { Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 8a6419e5fe3e..298d58218b6f 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -33,6 +33,7 @@ def _check( expect: Optional[Union[relax.Function, IRModule]] = None, ): test = parsed.script(show_meta=True) + print(test) roundtrip_mod = tvm.script.from_source(test) tvm.ir.assert_structural_equal(parsed, roundtrip_mod) if expect: @@ -1192,11 +1193,11 @@ def noelse(x: R.Tensor) -> R.Tensor: def foo0(x: R.Tensor) -> R.Tensor: y: R.Tensor((), dtype="bool") = R.const(True, dtype="bool") if y: - v = R.add(x, x) - return v + r = R.add(x, x) + return r else: - v = R.multiply(x, x) - return v + r = R.multiply(x, x) + return r @R.function def foo1(x: R.Tensor) -> R.Tensor: @@ -1211,19 +1212,27 @@ def foo1(x: R.Tensor) -> R.Tensor: def foo2(x: R.Tensor) -> R.Tensor: y: R.Tensor((), dtype="bool") = R.const(True, dtype="bool") if y: - v = R.add(x, x) + r = R.add(x, x) else: return R.multiply(x, x) - return v + return r MultiReturn.show() - print("yongwww get_global_info:", MultiReturn.get_global_info("relax_return_exprs")) + # print("yongwww get_global_info:", MultiReturn.get_global_info("relax_return_exprs")) _check(MultiReturn) roundtrip_mod = tvm.script.from_source(MultiReturn.script(show_meta=True)) tvm.ir.assert_structural_equal(MultiReturn, roundtrip_mod, True) +def test_meta_data(): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor(None, "float32", ndim=2): + a = R.const([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], "float32") + g = R.add(x, a) + return g + + _check(foo) + + if __name__ == "__main__": - # tvm.testing.main() - # test_module_with_attr_and_global_info() - test_multi_return() + tvm.testing.main()