diff --git a/3rdparty/cutlass b/3rdparty/cutlass index a3bcc6981d..8b42e751c6 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit a3bcc6981d5dad3afb212689e2c7853d1b1ee45d +Subproject commit 8b42e751c63ba219755c8ed91af5f6ec1ecc1ee6 diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake index afd5ef5302..4b4ef355b6 100644 --- a/cmake/modules/contrib/CUTLASS.cmake +++ b/cmake/modules/contrib/CUTLASS.cmake @@ -16,8 +16,8 @@ # under the License. if(USE_CUDA AND USE_CUTLASS) - tvm_file_glob(GLOB CUTLASS_RELAY_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc) - list(APPEND COMPILER_SRCS ${CUTLASS_RELAY_CONTRIB_SRC}) + tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc src/relax/backend/contrib/cutlass/*.cc) + list(APPEND COMPILER_SRCS ${CUTLASS_CONTRIB_SRC}) message(STATUS "Build with CUTLASS") endif() diff --git a/cmake/modules/contrib/DNNL.cmake b/cmake/modules/contrib/DNNL.cmake index 7547af81eb..857f7bdfd5 100644 --- a/cmake/modules/contrib/DNNL.cmake +++ b/cmake/modules/contrib/DNNL.cmake @@ -21,8 +21,8 @@ if(IS_DIRECTORY ${USE_DNNL}) message(WARNING "Cannot find DNNL library at ${USE_DNNL}.") else() add_definitions(-DUSE_JSON_RUNTIME=1) - tvm_file_glob(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc) - list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC}) + tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc src/relax/backend/contrib/dnnl/*.cc) + list(APPEND COMPILER_SRCS ${DNNL_CONTRIB_SRC}) list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL}) tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -34,8 +34,8 @@ if(IS_DIRECTORY ${USE_DNNL}) endif() elseif((USE_DNNL STREQUAL "ON") OR (USE_DNNL STREQUAL "JSON")) add_definitions(-DUSE_JSON_RUNTIME=1) - tvm_file_glob(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc) - list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC}) + tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc src/relax/backend/contrib/dnnl/*.cc) + list(APPEND COMPILER_SRCS ${DNNL_CONTRIB_SRC}) find_library(EXTERN_LIBRARY_DNNL dnnl) list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL}) diff --git a/gallery/how_to/work_with_relay/using_pipeline_executor.py b/gallery/how_to/work_with_relay/using_pipeline_executor.py index 8f61368656..4a28a59251 100755 --- a/gallery/how_to/work_with_relay/using_pipeline_executor.py +++ b/gallery/how_to/work_with_relay/using_pipeline_executor.py @@ -29,12 +29,8 @@ from tvm import relay from tvm.relay import testing import tvm.testing -from tvm.contrib.cutlass import ( - has_cutlass, - num_cutlass_partitions, - finalize_modules, - finalize_modules_vm, -) +from tvm.contrib.cutlass import finalize_modules + img_size = 8 ####################################################################### diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h index e394e9ff53..ae0c7e548c 100644 --- a/include/tvm/relax/dataflow_matcher.h +++ b/include/tvm/relax/dataflow_matcher.h @@ -45,6 +45,9 @@ namespace relax { */ bool MatchExpr(DFPattern pattern, Expr expr, Optional> var2val = NullOpt); +Optional> ExtractMatchedExpr( + DFPattern pattern, Expr expr, Optional> bindings_opt = NullOpt); + /** * \brief Match a sub-graph in a DataflowBlock with a graph of patterns and return the mapping. * \note This algorithm returns the first matched sub-graph. 
Use `start_hint` to specify the diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 68d8fe7cef..af95622d76 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -20,7 +20,7 @@ import os import multiprocessing import tvm -from tvm import runtime, relay +from tvm import runtime, relay, relax from tvm.contrib.nvcc import get_cuda_version from tvm._ffi.registry import register_func from .gen_gemm import CutlassGemmProfiler @@ -516,6 +516,22 @@ def tune_cutlass_function( ) +@register_func("contrib.cutlass.compile") +def compile_cutlass_module(c_source_module): + # TODO: Pass them as param + tmp_dir = "tmp" + compile_config = {"sm": 80, "threads": -1, "use_fast_math": False} + + function_names = c_source_module.get_function("get_func_names")() + compile_options = _get_cutlass_compile_options(**compile_config) + lib_path = os.path.join(tmp_dir, "cutlass.o") + logger.info("Compiling generated CUTLASS code") + c_source_module.export_library(lib_path, workspace_dir=tmp_dir, **compile_options) + + # Recover static library + return tvm.runtime.load_static_library(lib_path, function_names) + + @register_func("relay.ext.cutlass.compile_for_cutlass") def compile_for_cutlass(mod, cutlass_target): """Given an IRModule with at least one Compiler='cutlass' Relay function, return a @@ -558,6 +574,8 @@ def compile_for_cutlass(mod, cutlass_target): logger.info("Creating CSource module for CUTLASS") create_c_source_module = tvm._ffi.get_global_func("relay.ext.cutlass.create_c_source_module") c_module = create_c_source_module(mod) + + # TODO: use compile_cutlass_module above function_names = c_module.get_function("get_func_names")() compile_options = _get_cutlass_compile_options(**compile_config) lib_path = os.path.join(tmp_dir, "cutlass.o") @@ -633,3 +651,16 @@ def finalize_modules_vm(vm_exec, lib_path="compile.so", vmcode_path="vmcode.ro", fo.write(code) lib = tvm.runtime.load_module(lib_path) return tvm.runtime.vm.Executable.load_exec(code, lib) + + +def finalize_modules_relax( + vm_exec, lib_path="compile.so", vmcode_path="vmcode.ro", tmp_dir="./tmp" +): + lib_path = os.path.join(tmp_dir, lib_path) + vmcode_path = os.path.join(tmp_dir, vmcode_path) + + lib = vm_exec.mod + lib.export_library(lib_path, workspace_dir=tmp_dir, cc="nvcc") + lib = tvm.runtime.load_module(lib_path) + + return relax.vm.Executable(lib) diff --git a/python/tvm/relax/dpl/context.py b/python/tvm/relax/dpl/context.py index a621d31460..69a5e70ed0 100644 --- a/python/tvm/relax/dpl/context.py +++ b/python/tvm/relax/dpl/context.py @@ -20,7 +20,7 @@ from typing import Optional, Dict import tvm -from tvm.relax import DataflowBlock, Var +from ..expr import DataflowBlock, Var from .pattern import DFPattern from . import _ffi as ffi diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 31dbffda4a..7c360e57ab 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -25,10 +25,10 @@ import tvm import tvm._ffi as tvm_ffi from tvm.ir.expr import PrimExpr -from tvm.relax import Expr, Var from tvm.relay.op import get from tvm.ir.container import Array +from ..expr import Expr, Var from ...ir import make_node from ...runtime import Object from ...ir.base import Node @@ -198,7 +198,7 @@ def match(self, expr, var2val: Optional[Dict[Var, Expr]] = None) -> bool: Unlike Relay whose function is an expression, functions in Relax consists of blocks of bindings that they are not syntactically connected. 
We use a mapping (i.e., var2val) to migrate the gap. For example, to when matching - "relax.add(lv0, lv1)", given var2val, we match lv0's binded expression + "relax.add(lv0, lv1)", given var2val, we match lv0's bound expression when the recursive pattern matching goes to check lv0. The var2val mapping can be computed through the tvm.relax.analysis.get_var2val function. """ diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 323b4f3d1a..d657dd9e8b 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -25,6 +25,7 @@ import tvm.ir from tvm.runtime import NDArray from . import _ffi_api +from ..dpl import DFPattern @tvm._ffi.register_object("relax.FunctionPass") @@ -286,6 +287,11 @@ def FuseOps(fuse_opt_level=-1) -> tvm.ir.transform.Pass: return _ffi_api.FuseOps(fuse_opt_level) # type: ignore +def FuseOpsByPattern(pattern: DFPattern) -> tvm.ir.transform.Pass: + """TODO""" + return _ffi_api.FuseOpsByPattern(pattern) # type: ignore + + def FuseTIR() -> tvm.ir.transform.Pass: """Fuse primitive relax function into a larger TIR function if possible diff --git a/src/relax/analysis/var2value.cc b/src/relax/analysis/var2value.cc index d034afeb21..0e30427397 100644 --- a/src/relax/analysis/var2value.cc +++ b/src/relax/analysis/var2value.cc @@ -28,6 +28,7 @@ class Var2ValAnalysis : public relax::ExprVisitor { tvm::runtime::Map var2value_; void VisitBinding_(const VarBindingNode* binding) override { var2value_.Set(binding->var, binding->value); + VisitExpr(binding->value); } }; diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index 809156bfeb..7daa63f7b6 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -358,7 +358,7 @@ class JSONSerializer // TODO(@sunggg): Revisit when we have op naming convention. // Currently, simply remove "relax." prefix to make it work. - name = std::string("tensorrt.") + name.substr(6); + name = std::string("dnnl.") + name.substr(6); std::vector inputs; for (const auto& arg : cn->args) { diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc new file mode 100644 index 0000000000..f4380071ed --- /dev/null +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -0,0 +1,1074 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/contrib/cutlass/codegen.cc + * \brief Implementation of the CUTLASS JSON serializer. 
+ */ +#include +#include +#include +#include + +#include +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace relax { +namespace contrib { + +using Str2StrMap = std::unordered_map; + +static Str2StrMap dtype_map = {{"float16", "cutlass::half_t"}, + {"float32", "float"}, + {"int8", "int8_t"}, + {"uint8", "uint8_t"}, + {"int32", "int32_t"}}; + +constexpr const char* kAnyDim = "Any"; + +std::string GetDimAsStr(ObjectRef dim) { + if (auto d = dim.as()) { + return std::to_string(d->value); + } + return kAnyDim; +} + +inline void CutlassPrint(std::ostringstream& os, const std::string& stmt, int indent = 2) { + for (int i = 0; i < indent; ++i) { + os << " "; + } + os << stmt; +} + +Str2StrMap ArgsCommon(const Map& attrs) { + Str2StrMap args; + auto arg0_dtype = std::string(attrs["arg0_dtype"].as()->data); + auto arg1_dtype = std::string(attrs["arg1_dtype"].as()->data); + auto ret_dtype = std::string(attrs["ret_dtype"].as()->data); + args["ElementInputA"] = dtype_map.at(arg0_dtype); + args["ElementInputB"] = dtype_map.at(arg1_dtype); + args["ElementOutput"] = dtype_map.at(ret_dtype); + args["op_def"] = std::string(attrs["cutlass_op_def"].as()->data); + args["op_name"] = std::string(attrs["cutlass_op_name"].as()->data); + args["op_type"] = std::string(attrs["op_type"].as()->data); + return args; +} + +Str2StrMap GemmArgsCommon(const Map& attrs) { + Str2StrMap args = ArgsCommon(attrs); + args["lda"] = std::string(attrs["lda"].as()->data); + args["ldb"] = std::string(attrs["ldb"].as()->data); + args["ldc"] = std::string(attrs["ldc"].as()->data); + return args; +} + +Str2StrMap DenseArgs(const Map& attrs) { + Str2StrMap args = GemmArgsCommon(attrs); + auto arg0_shape = attrs["arg0_shape"].as(); + auto arg1_shape = attrs["arg1_shape"].as(); + args["M"] = GetDimAsStr(arg0_shape->at(0)); + args["K"] = GetDimAsStr(arg0_shape->at(1)); + args["N"] = GetDimAsStr(arg1_shape->at(0)); + return args; +} + +Str2StrMap BatchMatmulArgs(const Map& attrs) { + Str2StrMap args = GemmArgsCommon(attrs); + args["batch"] = GetDimAsStr(attrs["batch"]); + args["batch_stride_A"] = GetDimAsStr(attrs["batch_stride_A"]); + args["batch_stride_B"] = GetDimAsStr(attrs["batch_stride_B"]); + args["batch_stride_C"] = GetDimAsStr(attrs["batch_stride_C"]); + auto arg0_shape = attrs["arg0_shape"].as(); + auto arg1_shape = attrs["arg1_shape"].as(); + args["M"] = GetDimAsStr(arg0_shape->at(1)); + args["K"] = GetDimAsStr(arg0_shape->at(2)); + args["N"] = GetDimAsStr(arg1_shape->at(1)); + return args; +} + +void AppendPrologue(std::ostringstream& gemm_decl, const Str2StrMap& attrs, + const std::vector& func_args, const std::string& kernel, + bool has_bias, bool is_gelu, int m_axis_idx, int n_axis_idx, int k_axis_idx) { + CutlassPrint(gemm_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n"); + CutlassPrint(gemm_decl, "using ElementInputB = " + attrs.at("ElementInputB") + ";\n"); + CutlassPrint(gemm_decl, "using ElementOutput = " + attrs.at("ElementOutput") + ";\n"); + CutlassPrint(gemm_decl, "using ElementComputeEpilogue = " + attrs.at("ElementOutput") + ";\n"); + CutlassPrint(gemm_decl, attrs.at("op_def")); + CutlassPrint(gemm_decl, "using " + kernel + " = Operation_" + attrs.at("op_name") + ";\n"); + + auto get_dim = [&attrs, &func_args](const std::string& axis, int arg_idx, int axis_idx) { + if (attrs.at(axis) == kAnyDim) { + return func_args[arg_idx] + "->shape[" + std::to_string(axis_idx) + "]"; + } else { + return attrs.at(axis); + } + }; + CutlassPrint(gemm_decl, "int M = " + get_dim("M", 0, 
m_axis_idx) + ";\n"); + CutlassPrint(gemm_decl, "int N = " + get_dim("N", 1, n_axis_idx) + ";\n"); + CutlassPrint(gemm_decl, "int K = " + get_dim("K", 0, k_axis_idx) + ";\n"); + CutlassPrint(gemm_decl, "cutlass::gemm::GemmCoord problem_size(M, N, K);\n"); + CutlassPrint(gemm_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n"); + if (is_gelu) { + // GeLU epilogue does not compile with NoBetaScaling, so we explicitly specify the scale. + CutlassPrint(gemm_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(1);\n"); + } else { + CutlassPrint(gemm_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n"); + } + + ICHECK(func_args.size() >= 2); + CutlassPrint(gemm_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n"); + CutlassPrint(gemm_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n"); + if (has_bias) { + ICHECK(func_args.size() >= 3); + CutlassPrint(gemm_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + "->data);\n"); + } + + CutlassPrint(gemm_decl, "void* ptr_out = (void*)(out0->data);\n"); + + CutlassPrint(gemm_decl, "typename " + kernel + "::Arguments arguments{\n"); + CutlassPrint(gemm_decl, " problem_size,\n"); +} + +void AppendGemmExecute(std::ostringstream& gemm_decl, const std::string& kernel) { + // Using the arguments, query for extra workspace required for matrix multiplication computation + CutlassPrint(gemm_decl, + "size_t workspace_size = " + kernel + "::get_workspace_size(arguments);\n"); + // Allocate workspace memory + CutlassPrint(gemm_decl, + "cutlass::device_memory::allocation workspace(workspace_size);\n"); + // Instantiate CUTLASS kernel depending on template + CutlassPrint(gemm_decl, kernel + " gemm_op;\n"); + + // Check the problem size is supported or not + CutlassPrint(gemm_decl, "cutlass::Status status = gemm_op.can_implement(arguments);\n"); + CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); + // Initialize CUTLASS kernel with arguments and workspace pointer + CutlassPrint(gemm_decl, "status = gemm_op.initialize(arguments, workspace.get());\n"); + CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); + // Launch initialized CUTLASS kernel + CutlassPrint(gemm_decl, "status = gemm_op();\n"); + CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); +} + +std::string DenseOp(std::string id, const Str2StrMap& attrs, + const std::vector& func_args) { + bool has_bias = attrs.at("op_type").find("bias") != std::string::npos; + bool is_gelu = + attrs.at("op_type").find("cutlass.dense_bias_gelu") != std::string::npos; // fp32 or fp16 + std::ostringstream gemm_decl; + AppendPrologue(gemm_decl, attrs, func_args, "Gemm", has_bias, is_gelu, 0, 0, 1); + + CutlassPrint(gemm_decl, " {static_cast(ptr_a), " + attrs.at("lda") + "},\n"); + CutlassPrint(gemm_decl, " {static_cast(ptr_b), " + attrs.at("ldb") + "},\n"); + if (has_bias) { + CutlassPrint(gemm_decl, " {static_cast(ptr_c_bias), 0},\n"); + } else { + CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); + } + CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); + if (has_bias && !is_gelu) { + CutlassPrint(gemm_decl, " {alpha},\n"); + } else { + // For GeLU, we explicitly specify the scale. 
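+  // When the bias epilogue uses NoBetaScaling (bias without GeLU), the output op is
+  // constructed with {alpha} only; otherwise both {alpha, beta} are passed, matching
+  // the beta value chosen above.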
+ CutlassPrint(gemm_decl, " {alpha, beta},\n"); + } + CutlassPrint(gemm_decl, " 1};\n"); // split_k_slices + + AppendGemmExecute(gemm_decl, "Gemm"); + return gemm_decl.str(); +} + +std::string BatchMatmulOp(std::string id, const Str2StrMap& attrs, + const std::vector& func_args) { + std::ostringstream gemm_decl; + AppendPrologue(gemm_decl, attrs, func_args, "BatchedGemm", false, false, 1, 1, 2); + + auto get_batch_stride = [&attrs, &func_args](const std::string& name, int arg0_idx, int arg1_idx, + int arg0_axis_idx, int arg1_axis_idx) { + if (attrs.at(name) == kAnyDim) { + return func_args[arg0_idx] + "->shape[" + std::to_string(arg0_axis_idx) + "] * " + + func_args[arg1_idx] + "->shape[" + std::to_string(arg1_axis_idx) + "]"; + } else { + return attrs.at(name); + } + }; + + CutlassPrint(gemm_decl, " {static_cast(ptr_a), " + attrs.at("lda") + "},\n"); + CutlassPrint(gemm_decl, get_batch_stride("batch_stride_A", 0, 0, 1, 2) + ",\n"); + CutlassPrint(gemm_decl, " {static_cast(ptr_b), " + attrs.at("ldb") + "},\n"); + CutlassPrint(gemm_decl, get_batch_stride("batch_stride_B", 1, 1, 1, 2) + ",\n"); + CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); + CutlassPrint(gemm_decl, get_batch_stride("batch_stride_C", 0, 1, 1, 1) + ",\n"); + CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); + CutlassPrint(gemm_decl, get_batch_stride("batch_stride_C", 0, 1, 1, 1) + ",\n"); + CutlassPrint(gemm_decl, " {alpha, beta},\n"); + + if (attrs.at("batch") == kAnyDim) { + CutlassPrint(gemm_decl, func_args[0] + "->shape[0]" + "};\n"); + } else { + CutlassPrint(gemm_decl, attrs.at("batch") + "};\n"); + } + + AppendGemmExecute(gemm_decl, "BatchedGemm"); + return gemm_decl.str(); +} + +Str2StrMap Conv2dArgs(const Map& attrs, bool is_dgrad = false, + bool is_wgrad = false) { + Str2StrMap args = ArgsCommon(attrs); + auto arg0_shape = attrs["arg0_shape"].as(); + auto arg1_shape = attrs["arg1_shape"].as(); + auto ret_shape = attrs["ret_shape"].as(); + auto activation_shape = arg0_shape; + auto weight_shape = arg1_shape; + auto output_shape = ret_shape; + + if (is_dgrad) { + activation_shape = ret_shape; + output_shape = arg0_shape; + } else if (is_wgrad) { + activation_shape = arg1_shape; + weight_shape = ret_shape; + output_shape = arg0_shape; + } + + args["N"] = GetDimAsStr(activation_shape->at(0)); + args["H"] = GetDimAsStr(activation_shape->at(1)); + args["W"] = GetDimAsStr(activation_shape->at(2)); + args["C"] = GetDimAsStr(activation_shape->at(3)); + args["P"] = GetDimAsStr(output_shape->at(1)); + args["Q"] = GetDimAsStr(output_shape->at(2)); + args["K"] = GetDimAsStr(output_shape->at(3)); + args["R"] = GetDimAsStr(weight_shape->at(1)); + args["S"] = GetDimAsStr(weight_shape->at(2)); + args["pad_h"] = GetDimAsStr(attrs["padding"].as()->at(0)); + args["pad_w"] = GetDimAsStr(attrs["padding"].as()->at(1)); + args["stride_h"] = GetDimAsStr(attrs["strides"].as()->at(0)); + args["stride_w"] = GetDimAsStr(attrs["strides"].as()->at(1)); + args["dilation_h"] = GetDimAsStr(attrs["dilation"].as()->at(0)); + args["dilation_w"] = GetDimAsStr(attrs["dilation"].as()->at(1)); + + return args; +} + +std::string Conv2dOp(std::string id, const Str2StrMap& attrs, + const std::vector& func_args, bool has_residual_block = false) { + auto op_type = attrs.at("op_type"); + bool has_bias = op_type.find("bias") != std::string::npos; + bool no_bias_scaling = op_type != "cutlass.conv2d_bias_sigmoid" && + op_type != "cutlass.conv2d_bias_silu" && + op_type != 
"cutlass.conv2d_bias_hardswish"; + + const std::string op_name = attrs.at("op_name"); + std::ostringstream conv2d_decl; + CutlassPrint(conv2d_decl, attrs.at("op_def")); + CutlassPrint(conv2d_decl, "using Operation_" + op_name + + " = cutlass::conv::device::ImplicitGemmConvolution<" + op_name + + ">;\n"); + CutlassPrint(conv2d_decl, "using Conv2d = Operation_" + op_name + ";\n"); + CutlassPrint(conv2d_decl, "using ElementInputA = Conv2d::ElementA;\n"); + CutlassPrint(conv2d_decl, "using ElementInputB = Conv2d::ElementB;\n"); + CutlassPrint(conv2d_decl, "using ElementComputeEpilogue = Conv2d::ElementAccumulator;\n"); + + auto get_dim = [&attrs](const std::string& axis, const std::string& var_name, int axis_idx) { + if (attrs.at(axis) == kAnyDim) { + return var_name + "->shape[" + std::to_string(axis_idx) + "]"; + } else { + return attrs.at(axis); + } + }; + + CutlassPrint(conv2d_decl, "int N = " + get_dim("N", func_args[0], 0) + ";\n"); + CutlassPrint(conv2d_decl, "int H = " + get_dim("H", func_args[0], 1) + ";\n"); + CutlassPrint(conv2d_decl, "int W = " + get_dim("W", func_args[0], 2) + ";\n"); + CutlassPrint(conv2d_decl, "int C = " + attrs.at("C") + ";\n"); + CutlassPrint(conv2d_decl, "int K = " + attrs.at("K") + ";\n"); + CutlassPrint(conv2d_decl, "int R = " + attrs.at("R") + ";\n"); + CutlassPrint(conv2d_decl, "int S = " + attrs.at("S") + ";\n"); + CutlassPrint(conv2d_decl, "int P = " + get_dim("P", "out0", 1) + ";\n"); + CutlassPrint(conv2d_decl, "int Q = " + get_dim("Q", "out0", 2) + ";\n"); + CutlassPrint(conv2d_decl, "int pad_h = " + attrs.at("pad_h") + ";\n"); + CutlassPrint(conv2d_decl, "int pad_w = " + attrs.at("pad_w") + ";\n"); + CutlassPrint(conv2d_decl, "int stride_h = " + attrs.at("stride_h") + ";\n"); + CutlassPrint(conv2d_decl, "int stride_w = " + attrs.at("stride_w") + ";\n"); + CutlassPrint(conv2d_decl, "int dilation_h = " + attrs.at("dilation_h") + ";\n"); + CutlassPrint(conv2d_decl, "int dilation_w = " + attrs.at("dilation_w") + ";\n"); + + const bool use_split_k = op_name.find("splitk") != std::string::npos; + + if (use_split_k) { + std::string split_k_slices = op_name.substr(op_name.find_last_not_of("0123456789") + 1); + CutlassPrint(conv2d_decl, "int split_k_slices = " + split_k_slices + ";\n"); + } else { + CutlassPrint(conv2d_decl, "int split_k_slices = 1;\n"); + } + + CutlassPrint( + conv2d_decl, + "cutlass::conv::Conv2dProblemSize problem_size(N, H, W, C, K, R, S, P, Q, pad_h, pad_w, " + "stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, " + "split_k_slices);\n"); + + const std::string split_k_mode = use_split_k ? 
"kParallel" : "kSerial"; + CutlassPrint(conv2d_decl, + "const cutlass::conv::SplitKMode split_k_mode = cutlass::conv::SplitKMode::" + + split_k_mode + ";\n"); + + bool is_wgrad = op_type.find("backward_weight") != std::string::npos; + bool is_dgrad = op_type.find("conv2d_transpose") != std::string::npos; + + ICHECK(func_args.size() >= 2); + CutlassPrint(conv2d_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n"); + CutlassPrint(conv2d_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n"); + + if (has_residual_block) { + ICHECK(func_args.size() >= 4); + CutlassPrint(conv2d_decl, "void* ptr_bias = (void*)(" + func_args[2] + "->data);\n"); + CutlassPrint(conv2d_decl, "void* ptr_residual = (void*)(" + func_args[3] + "->data);\n"); + } else if (has_bias) { + ICHECK(func_args.size() >= 3); + CutlassPrint(conv2d_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + "->data);\n"); + } + + CutlassPrint(conv2d_decl, "void* ptr_out = (void*)(out0->data);\n"); + CutlassPrint(conv2d_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n"); + if ((!has_bias || no_bias_scaling) && !has_residual_block) { + CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n"); + } else { + CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(1);\n"); + } + CutlassPrint(conv2d_decl, "using cutlass::layout::TensorNHWC;\n"); + CutlassPrint(conv2d_decl, + "auto activation_shape = TensorNHWC::packed(cutlass::make_Coord(N, H, W, C));\n"); + CutlassPrint(conv2d_decl, + "auto weight_shape = TensorNHWC::packed(cutlass::make_Coord(K, R, S, C));\n"); + CutlassPrint(conv2d_decl, + "auto output_oshape = TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K));\n"); + + if (is_wgrad) { + CutlassPrint(conv2d_decl, "TensorNHWC layout_A(output_oshape);\n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_B(activation_shape);\n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_C(weight_shape);\n\n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_D(weight_shape);\n\n"); + } else if (is_dgrad) { + CutlassPrint(conv2d_decl, "TensorNHWC layout_A(output_oshape);\n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_B(weight_shape);\n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_C(activation_shape);\n\n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_D(activation_shape);\n\n"); + } else { + CutlassPrint(conv2d_decl, "TensorNHWC layout_A(activation_shape);\n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_B(weight_shape);\n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_C(output_oshape);\n\n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_D(output_oshape);\n\n"); + } + + if (use_split_k) { + CutlassPrint(conv2d_decl, "using ElementOutput = EpilogueOutputOp::ElementOutput;\n"); + } else { + CutlassPrint(conv2d_decl, "using ElementOutput = Conv2d::ElementC;\n"); + } + + std::string tensor_c_init = "{static_cast(ptr_out), layout_C}"; + if (has_residual_block) { + tensor_c_init = "{static_cast(ptr_residual), layout_C}"; + } else if (has_bias) { + tensor_c_init = + "{static_cast(ptr_c_bias), cutlass::layout::TensorNHWC::Stride(0)}"; + } + + CutlassPrint(conv2d_decl, + "cutlass::TensorRef tensor_c" + tensor_c_init + ";\n"); + CutlassPrint(conv2d_decl, + "cutlass::TensorRef " + "tensor_d{static_cast(ptr_out),layout_D};\n"); + + CutlassPrint(conv2d_decl, "typename Conv2d::Arguments arguments{\n"); + CutlassPrint(conv2d_decl, " problem_size,\n"); + CutlassPrint(conv2d_decl, " {static_cast(ptr_a), layout_A},\n"); + CutlassPrint(conv2d_decl, " 
{static_cast(ptr_b), layout_B},\n"); + + if (use_split_k) { + CutlassPrint(conv2d_decl, "{nullptr, TensorNHWC()},\n"); + CutlassPrint(conv2d_decl, "{nullptr, TensorNHWC()},\n"); + } else { + CutlassPrint(conv2d_decl, " tensor_c,\n"); + CutlassPrint(conv2d_decl, " tensor_d,\n"); + } + + if (has_residual_block) { + ICHECK(use_split_k == false) << "Split-k not supported for residual block fusion"; + CutlassPrint(conv2d_decl, "{alpha, beta},\n"); + CutlassPrint(conv2d_decl, "cutlass::conv::SplitKMode::kSerial,\n"); // split_k_slices + CutlassPrint(conv2d_decl, "static_cast(ptr_bias),\n"); + CutlassPrint(conv2d_decl, "nullptr, 0, K};\n"); + } else if (has_bias && no_bias_scaling) { + CutlassPrint(conv2d_decl, " {alpha},\n"); + CutlassPrint(conv2d_decl, "split_k_mode\n};\n"); + } else { + CutlassPrint(conv2d_decl, "{alpha, beta},\n"); + CutlassPrint(conv2d_decl, "split_k_mode\n};\n"); + } + + CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n"); + + CutlassPrint(conv2d_decl, "size_t workspace_size = conv2d_op.get_workspace_size(arguments);\n"); + // Allocate workspace memory + CutlassPrint(conv2d_decl, + "cutlass::device_memory::allocation workspace(workspace_size);\n"); + // Check the problem size is supported or not + CutlassPrint(conv2d_decl, "cutlass::Status status = conv2d_op.can_implement(arguments);\n"); + CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n"); + + if (use_split_k) { + CutlassPrint(conv2d_decl, + "arguments.ref_D.reset(reinterpret_cast(workspace.get())," + " layout_D);\n\n"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + CutlassPrint(conv2d_decl, "status = conv2d_op.initialize(arguments, workspace.get());\n"); + CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n"); + + if (use_split_k) { + CutlassPrint( + conv2d_decl, + "arguments.output_op = {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}; \n"); + CutlassPrint(conv2d_decl, "status = conv2d_op.update(arguments, workspace.get()); \n"); + CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n"); + } + + // Launch initialized CUTLASS kernel + CutlassPrint(conv2d_decl, "status = conv2d_op();\n"); + CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n"); + + if (use_split_k) { + CutlassPrint(conv2d_decl, "ReductionDevice reduction_op;\n"); + CutlassPrint(conv2d_decl, + "const static cutlass::conv::Operator kConvolutionalOperator = " + "Conv2d::kConvolutionalOperator;\n"); + CutlassPrint(conv2d_decl, "typename ReductionDevice::Arguments reduction_args(\n"); + CutlassPrint(conv2d_decl, + "cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, " + "problem_size).mn(),\n"); + CutlassPrint(conv2d_decl, "problem_size.split_k_slices,\n"); + CutlassPrint(conv2d_decl, + "cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, " + "problem_size),\n"); + CutlassPrint(conv2d_decl, "{\n"); + CutlassPrint(conv2d_decl, + " reinterpret_cast (workspace.get()),\n"); + CutlassPrint(conv2d_decl, + "ReductionStrideIndex(tensor_c.stride()[Conv2d::ImplicitGemmKernel::" + "kTensorCStrideIdx])\n"); + CutlassPrint(conv2d_decl, "},\n"); + CutlassPrint(conv2d_decl, "{\n"); + CutlassPrint(conv2d_decl, "tensor_d.data(),\n"); + CutlassPrint(conv2d_decl, + "ReductionStrideIndex(tensor_d.stride()[Conv2d::ImplicitGemmKernel::" + "kTensorCStrideIdx])\n"); + CutlassPrint(conv2d_decl, "},\n"); + CutlassPrint(conv2d_decl, "{\n"); + CutlassPrint(conv2d_decl, "tensor_c.data(),\n"); + CutlassPrint(conv2d_decl, + 
"ReductionStrideIndex(tensor_c.stride()[Conv2d::ImplicitGemmKernel::" + "kTensorCStrideIdx])\n"); + CutlassPrint(conv2d_decl, "},\n"); + CutlassPrint(conv2d_decl, " {alpha, beta}\n"); + CutlassPrint(conv2d_decl, ");\n\n"); + CutlassPrint(conv2d_decl, "status = reduction_op.initialize(reduction_args, nullptr);\n"); + CutlassPrint(conv2d_decl, "status = reduction_op();\n"); + } + + return conv2d_decl.str(); +} + +struct Output { + std::string name; + std::string dtype; + int size; + bool need_copy; +}; + +struct GenerateBodyOutput { + std::string decl; + std::vector buffers; + std::vector outputs; +}; + +inline bool IsOp(const CallNode* call, const std::string& op_name) { + const auto* op_node = call->op.as(); + if (!op_node) return false; + Op op = GetRef(op_node); + return op == Op::Get(op_name); +} + +class CodegenCutlass : public tvm::relax::backend::MemoizedExprTranslator> { + public: + CodegenCutlass(const std::string& id, const Map& attrs, const Expr& expr) { + // todo: clean up + this->ext_func_id_ = id; + this->attrs_ = attrs; + bindings_ = AnalyzeVar2Value(expr); + } + + std::vector VisitExpr_(const VarNode* node) final { + ext_func_args_.push_back(GetRef(node)); + Output output; + output.name = node->name_hint(); + return {output}; + } + + std::vector VisitExpr_(const CallNode* call) final { + const auto* fn_var = call->op.as(); + ICHECK(fn_var); + const auto func = Downcast(bindings_[GetRef(fn_var)]); + ICHECK(func.defined()) << "Only composite function is supported for CUTLASS."; + GenerateBodyOutput ret = GenerateCompositeFunctionCall(func, call); + ext_func_body_.push_back(ret.decl); + return ret.outputs; + } + + std::string JIT(const std::vector& out) { + CHECK(out.size() > 0); + code_stream_ << "void " << ext_func_id_ << "_("; + + for (const auto& arg : ext_func_args_) { + code_stream_ << "DLTensor* " << arg->name_hint() << ", "; + } + for (size_t i = 0; i < out.size() - 1; ++i) { + code_stream_ << "DLTensor* out" << i << ", "; + } + code_stream_ << "DLTensor* out" << out.size() - 1 << ") {\n"; + this->EnterScope(); + + // Function body + for (auto decl : buf_decl_) { + this->PrintIndents(); + code_stream_ << decl << "\n"; + } + code_stream_ << "\n"; + for (auto stmt : ext_func_body_) { + this->PrintIndents(); + code_stream_ << stmt << "\n"; + } + + this->ExitScope(); + code_stream_ << "}\n"; + + this->GenerateBackendCFunc(ext_func_id_, ext_func_args_, /*const_arr_name=*/"", out, true); + return code_stream_.str(); + } + + /*! \brief The external function source code stream. */ + std::ostringstream code_stream_; + + protected: + std::vector VisitExpr_(const FunctionNode* fn) { + ICHECK(fn->GetAttr(attr::kComposite).defined()) + << "JSON runtime only supports composite functions"; + // FunctionNode should be handled by the caller. 
+ return {}; + } + + std::vector VisitBinding_(const VarBindingNode* binding) { + ICHECK_EQ(memo_.count(binding->var), 0); + memo_[binding->var] = VisitExpr(binding->value); + return VisitExpr(binding->value); + } + + std::vector VisitBinding(const Binding& binding) { + std::vector nodes; + if (const auto* node = binding.as()) { + auto from_b = VisitBinding_(node); + nodes.insert(nodes.end(), from_b.begin(), from_b.end()); + } else { + LOG(FATAL) << "Unimplemented type: " << binding->GetTypeKey(); + } + return nodes; + } + + std::vector VisitBindingBlock(const BindingBlock& block) { + std::vector nodes; + if (const auto* node = block.as()) { + auto from_bb = VisitBindingBlock_(node); + nodes.insert(nodes.end(), from_bb.begin(), from_bb.end()); + } else if (const auto* node = block.as()) { + auto from_bb = VisitBindingBlock_(node); + nodes.insert(nodes.end(), from_bb.begin(), from_bb.end()); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + return nodes; + } + + std::vector VisitBindingBlock_(const BindingBlockNode* block) { + std::vector nodes; + for (Binding binding : block->bindings) { + auto from_b = VisitBinding(binding); + nodes.insert(nodes.end(), from_b.begin(), from_b.end()); + } + return nodes; + } + + std::vector VisitBindingBlock_(const DataflowBlockNode* block) { + std::vector nodes; + for (Binding binding : block->bindings) { + auto from_b = VisitBinding(binding); + nodes.insert(nodes.end(), from_b.begin(), from_b.end()); + } + return nodes; + } + + std::vector VisitExpr_(const SeqExprNode* op) { + std::vector nodes; + + for (BindingBlock block : op->blocks) { + auto from_bb = VisitBindingBlock(block); + } + + auto from_body = VisitExpr(op->body); + nodes.insert(nodes.end(), from_body.begin(), from_body.end()); + + return nodes; + } + + private: + std::vector GetArgumentNames(const CallNode* call) { + std::vector arg_names; + for (size_t i = 0; i < call->args.size(); ++i) { + auto res = VisitExpr(call->args[i]); + for (const auto& out : res) { + arg_names.push_back(out.name); + } + } + return arg_names; + } + + GenerateBodyOutput GenerateCompositeFunctionCall(Function callee, const CallNode* caller) { + const auto pattern_name = callee->GetAttr(attr::kComposite); + ICHECK(pattern_name.defined()) << "Only functions with composite attribute are supported."; + + if (pattern_name == "conv2d_bias_relu") { + const CallNode* conv2d_call = caller; + for (auto [var, val] : bindings_) { + if (val->IsInstance() && IsOp(val.as(), "relax.nn.conv2d")) { + conv2d_call = val.as(); + break; + } + } + return GenerateBody(conv2d_call, "cutlass_conv2d_bias_relu", GetArgumentNames(caller), + Conv2dArgs(std::ref(attrs_))); + } + + LOG(FATAL) << "Unknown composite function: " << pattern_name; + return {}; + } + + GenerateBodyOutput GenerateBody(const CallNode* root_call, const std::string& func_name, + const std::vector& func_args, + const Str2StrMap& attribute_args) { + // Make function call with input buffers when visiting arguements + ICHECK_GT(func_args.size(), 0); + std::ostringstream decl_stream; + decl_stream << "(" << func_args[0]; + for (size_t i = 1; i < func_args.size(); ++i) { + decl_stream << ", " << func_args[i]; + } + // Analyze the output buffers + auto struct_info = GetStructInfo(GetRef(root_call)); + + std::vector out_types; + if (const auto* tensor_sinfo = struct_info.as()) { + out_types.emplace_back(backend::DType2String(tensor_sinfo->dtype)); + } else { + LOG(FATAL) << "Unimplemented"; + } + + GenerateBodyOutput ret; + for (const auto& out_type 
: out_types) { + const std::string out = "out" + std::to_string(buf_idx_++); + decl_stream << ", " << out; + Output output; + output.name = out; + output.dtype = out_type; + output.need_copy = false; + ret.outputs.push_back(output); + } + decl_stream << ");"; + if (func_name.find("dense") != std::string::npos) { + ret.decl = DenseOp(ext_func_id_, attribute_args, func_args); + } else if (func_name == "cutlass_batch_matmul") { + ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args); + } else if (func_name.find("conv2d") != std::string::npos) { + ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args); + } + return ret; + } + + /*! \brief Print indents using spaces. */ + void PrintIndents() { + for (int i = 0; i < indent_; i++) { + code_stream_ << ' '; + } + } + + /*! + * \brief Enter a new scope. + */ + void EnterScope() { indent_ += 2; } + + /*! + * \brief Exit a scope. + */ + void ExitScope() { + ICHECK_GE(indent_, 2U) << "Wrong ident found."; + indent_ -= 2; + } + + /*! + * \brief Creates a runtime function header + */ + void PrintRuntimeFunctionHeader(std::string func_name) { + code_stream_ << "#ifdef __cplusplus\n"; + code_stream_ << "extern \"C\" {\n"; + code_stream_ << "#endif\n"; + code_stream_ << "TVM_DLL int32_t "; + code_stream_ << func_name << "("; + code_stream_ << "TVMValue* args, "; + code_stream_ << "int* type_code, "; + code_stream_ << "int num_args, "; + code_stream_ << "TVMValue* out_value, "; + code_stream_ << "int* out_type_code) {\n"; + } + + /*! + * \brief Adds a line to convert TVMValue args to DLTensors + */ + void PrintArgToData(int idx) { + PrintIndents(); + code_stream_ << "DLTensor* arg" << idx << " = "; + code_stream_ << "(DLTensor*)(((TVMValue*)args)[" << idx << "].v_handle);\n"; + } + + /*! + * \brief Adds a line to convert TVMValue rets to DLTensors + */ + void PrintRetToData(int idx) { + PrintIndents(); + code_stream_ << "DLTensor* ret" << idx << " = "; + code_stream_ << "(DLTensor*)(((TVMValue*)args)[" << idx << "].v_handle);\n"; + } + + /*! + * \brief Gerenate C code for the external function. + * + * \param func_name The name of the external function. + * \param args arguments to the external function. + * + * \code + * + * Array foo_consts; + * + * // An example code for the generated C function. + * int foo_wrapper_(DLTensor* arg0, + * DLTensor* arg1, + * DLTensor* out) { + * foo_((float*)(arg0->data), + * (float*)(arg1->data), + * (float*)(out->data)); + * return 0; + * } + * + * TVM_DLL_EXPORT_TYPED_FUNC(foo, foo_wrapper_); + * + * int foo_init_wrapper_(Array arr) { + * foo_consts = arr; + * return 0; + * } + * + * TVM_DLL_EXPORT_TYPED_FUNC(__init_foo, foo_init_wrapper_); + * + * \endcode + */ + void GenerateBackendCFunc(const std::string& func_name, const Array& args, + const std::string& const_arr_name, const std::vector& outs, + bool pass_dl_tensor = false) { + // Print signature + code_stream_ << "\n"; + + code_stream_ << "int " << func_name << "_wrapper_("; + for (size_t i = 0; i < args.size(); i++) { + code_stream_ << "DLTensor* arg" << i << ",\n"; + code_stream_ << "\t"; + } + for (size_t i = 0; i < outs.size() - 1; i++) { + code_stream_ << "DLTensor* out" << i << ",\n"; + code_stream_ << "\t"; + } + code_stream_ << "DLTensor* out" << outs.size() - 1 << ") {\n"; + + EnterScope(); + + // Generate the internal call. 
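+    // JIT() calls this with pass_dl_tensor = true, so DLTensor* arguments are forwarded
+    // as-is; otherwise each argument is cast to a typed data pointer before the call.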
+ PrintIndents(); + code_stream_ << func_name << "_("; + for (size_t i = 0; i < args.size(); i++) { + if (pass_dl_tensor) { + code_stream_ << "arg" << i << ",\n"; + } else { + const auto& dtype_str = GetDtypeString(args[i]); + code_stream_ << "(" << dtype_str << "*)(arg" << i << "->data),\n"; + } + PrintIndents(); + } + for (size_t i = 0; i < outs.size() - 1; i++) { + if (pass_dl_tensor) { + code_stream_ << "out" << i << ",\n"; + } else { + code_stream_ << "(" << outs[i].dtype << "*)(out" << i << "->data),\n"; + } + PrintIndents(); + } + if (pass_dl_tensor) { + code_stream_ << "out" << outs.size() - 1 << ");\n"; + } else { + code_stream_ << "(" << outs.back().dtype << "*)(out" << outs.size() - 1 << "->data));\n"; + } + PrintIndents(); + code_stream_ << "return 0;\n"; + ExitScope(); + code_stream_ << "}\n\n"; + + // Create the external function + PrintRuntimeFunctionHeader(func_name); + EnterScope(); + for (size_t i = 0; i < args.size(); i++) { + PrintArgToData(i); + } + for (size_t i = 0; i < outs.size(); i++) { + PrintRetToData(args.size() + i); + } + PrintIndents(); + code_stream_ << func_name << "_wrapper_("; + for (size_t i = 0; i < args.size(); i++) { + code_stream_ << "arg" << i << ","; + } + for (size_t i = 0; i < outs.size() - 1; i++) { + code_stream_ << "ret" << args.size() + i << ","; + } + code_stream_ << "ret" << args.size() + outs.size() - 1 << ");\n"; + PrintIndents(); + code_stream_ << "return 0;\n"; + ExitScope(); + code_stream_ << "}\n"; + code_stream_ << "#ifdef __cplusplus\n"; + code_stream_ << "}\n"; + code_stream_ << "#endif\n"; + + if (!const_arr_name.empty()) { + // If there are constants, insert the __init_ and the wrapper + // This segment would be generated in C++ because of the usage + // of tvm::runtime::Array. This is not ideal, but this to demonstrate + // constant copying process used packed imports in other external + // codegen. Moreover, in microTVM we dont expect this part to be generated. + code_stream_ << "#ifdef __cplusplus\n"; + code_stream_ << "int " << func_name + << "_init_wrapper_(tvm::runtime::Array arr) {\n"; + EnterScope(); + PrintIndents(); + code_stream_ << func_name << "_consts = arr;\n"; + code_stream_ << "return 0;\n"; + ExitScope(); + code_stream_ << "}\n\n"; + code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(__init_" << func_name << ", " << func_name + << "_init_wrapper_);\n\n"; + code_stream_ << "#endif\n"; + } + } + + std::string GetDtypeString(const Var& var) { + auto ttype = var->checked_type().as(); + ICHECK(ttype) << "Expect TensorTypeNode"; + return GetDtypeString(ttype); + } + + /*! + * \brief Returns dtype string + * + * \param ttype TensorTypeNode* to get the dtype of + * + * \return The dtype string. + */ + std::string GetDtypeString(const TensorTypeNode* ttype) { + std::string dtype; + if (runtime::TypeMatch(ttype->dtype, kDLFloat, 32)) { + dtype = "float"; + } else if (runtime::TypeMatch(ttype->dtype, kDLFloat, 16)) { + dtype = "half"; + } else if (runtime::TypeMatch(ttype->dtype, kDLBfloat, 16)) { + dtype = "bfloat"; + } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 32)) { + dtype = "int"; + } else if (runtime::TypeMatch(ttype->dtype, kDLInt, 64)) { + dtype = "int64_t"; + } else { + LOG(FATAL) << "Unsupported dtype " << ttype->dtype; + } + + return dtype; + } + + /*! \brief Indent of the source code. */ + int indent_{0}; + /*! \brief The id of the external cutlass ext_func. */ + std::string ext_func_id_; + /*! \brief The attrs of the external cutlass ext_func. */ + Map attrs_; + /*! 
+ * \brief The index to track the output buffer. Each kernel will redirect the + * output to a buffer that may be consumed by other kernels. + */ + int buf_idx_{0}; + /*! \brief The arguments used by a wrapped function that calls CUTLASS kernels. */ + Array ext_func_args_; + /*! \brief Statement of the function that will be compiled using CUTLASS kernels. */ + std::vector ext_func_body_; + /*! \brief The declaration of intermediate buffers. */ + std::vector buf_decl_; + + Map bindings_; +}; + +class CutlassModuleCodegen { + public: + runtime::Module CreateCSourceModule(Function f) { + EmitPreamble(); + GenCutlassFunc(f); + return Finalize(); + } + + private: + void EmitPreamble() { + // create header + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + // cutlass header + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + } + + void GenCutlassFunc(const Function& function) { + ICHECK(function.defined()) << "Input error: expect a Relay function."; + + // Record the external symbol for runtime lookup. + Optional opt_global_symbol = function->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(opt_global_symbol.defined()) + << "CUTLASS functions must have a " << tvm::attr::kGlobalSymbol << " attribute"; + std::string sid = opt_global_symbol.value(); + if (std::find(func_names_.begin(), func_names_.end(), sid) != func_names_.end()) { + // Already emitted. + return; + } + func_names_.push_back(sid); + + const auto* attrs = function->attrs.as(); + ICHECK(attrs != nullptr); + const auto dict = attrs->dict; + CodegenCutlass builder(sid, dict, function); + VLOG(1) << "Creating cutlass C code for '" << sid << "' from:\n" << PrettyPrint(function); + auto out = builder.VisitExpr(function->body); + code_stream_ << builder.JIT(out); + } + + runtime::Module Finalize() { + ICHECK(!func_names_.empty()) + << "Should only create CUTLASS CSourceModule if have at least one CUTLASS partition"; + const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); + ICHECK(pf != nullptr) << "Cannot find CSource module to create the external runtime module"; + VLOG(1) << "Generated CUTLASS code:" << std::endl << code_stream_.str(); + return (*pf)(code_stream_.str(), "cu", func_names_, /*const_vars=*/Array()); + } + + /*! + * \brief Returns \p expr as function if it is a \p Function with "Compiler" attribute + * value "cutlass". + */ + static const FunctionNode* GetCutlassFunctionNode(const Expr& expr) { + if (const auto* function_node = expr.as()) { + Optional opt_compiler = function_node->GetAttr(attr::kCodegen); + if (opt_compiler.defined() && opt_compiler.value() == "cutlass") { + return function_node; + } + } + return nullptr; + } + + /*! 
\brief The accumulated code stream that will be compiled by NVCC */ + std::ostringstream code_stream_; + /*! \brief The accumulated function names. */ + Array func_names_; +}; // CutlassModuleCodegen + +/*! + * \brief Create a runtime module for CUTLASS. + * \param ref The ext_func Relay expression/module to be executed using extern ops. + * \return A runtime module. + */ +runtime::Module CUTLASSCompiler(const ObjectRef& ref) { + ICHECK(ref->IsInstance()) << "The input ref is expected to be a Relax function."; + Function func = Downcast(ref); + std::string func_name = backend::GetExtSymbol(func); + auto source_mod = CutlassModuleCodegen().CreateCSourceModule(func); + const auto* pf = runtime::Registry::Get("contrib.cutlass.compile"); + ICHECK(pf != nullptr); + return (*pf)(source_mod); +} + +TVM_REGISTER_GLOBAL("relax.ext.cutlass").set_body_typed(CUTLASSCompiler); + +} // namespace contrib +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/contrib/dnnl/codegen.cc b/src/relax/backend/contrib/dnnl/codegen.cc new file mode 100644 index 0000000000..cf2379a0fa --- /dev/null +++ b/src/relax/backend/contrib/dnnl/codegen.cc @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/contrib/dnnl/codegen.cc + * \brief Implementation of the DNNL JSON serializer. + */ +#include +#include +#include +#include + +#include +#include +#include + +#include "../codegen_json/codegen_json.h" +#include "../utils.h" + +namespace tvm { +namespace relax { +namespace contrib { + +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; +using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; +using JSONSerializer = backend::contrib::JSONSerializer; + +inline bool IsOp(const CallNode* call, const std::string& op_name) { + const auto* op_node = call->op.as(); + if (!op_node) return false; + Op op = GetRef(op_node); + return op == Op::Get(op_name); +} + +/*! + * \brief Generates an DNNLModule from a relax expression by serializing the expression to a + * json representation. DNNL is not required here because use of DNNL APIs is deferred until + * runtime. 
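+ * Each call in the offloaded function is expected to target an inner "Composite"
+ * function bound to a Var; the callee is resolved through the var-to-value bindings
+ * computed by AnalyzeVar2Value.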
+ */ +class DNNLJSONSerializer : public JSONSerializer { + public: + DNNLJSONSerializer(const std::string& symbol, const Expr& expr) + : JSONSerializer(symbol, expr), bindings_(AnalyzeVar2Value(expr)) {} + + using JSONSerializer::VisitExpr_; + + std::vector VisitExpr_(const CallNode* call_node) final { + // The call must be to an inline "Composite" function + const auto* fn_var = call_node->op.as(); + ICHECK(fn_var); + const auto fn = Downcast(bindings_[GetRef(fn_var)]); + ICHECK(fn.defined()); + + auto opt_composite = fn->GetAttr(attr::kComposite); + ICHECK(opt_composite.defined()); + + std::string name = opt_composite.value(); + + const CallNode* root_call = call_node; + if (name.find("conv2d") != std::string::npos) { + for (auto [var, val] : bindings_) { + if (val->IsInstance() && IsOp(val.as(), "relax.nn.conv2d")) { + root_call = val.as(); + break; + } + } + ICHECK(root_call->op.as()) << "Not op node"; + } else { + LOG(FATAL) << "Unimplemented"; + } + + std::vector inputs; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + auto node = std::make_shared(name, /* name_ */ + "kernel", /* op_type_ */ + inputs, 1 /* num_outputs_ */); + SetCallNodeAttribute(node, root_call); + return AddNode(node, GetRef(call_node)); + } + + private: + Map bindings_; +}; + +/*! + * \brief Create a runtime module for DNNL. + * \param ref The ext_func Relay expression/module to be executed using extern ops. + * \return A runtime module. + */ +runtime::Module DNNLCompiler(const ObjectRef& ref) { + ICHECK(ref->IsInstance()) << "The input ref is expected to be a Relax function."; + Function func = Downcast(ref); + std::string func_name = backend::GetExtSymbol(func); + + DNNLJSONSerializer serializer(func_name, func); + serializer.serialize(); + std::string graph_json = serializer.GetJSON(); + auto param_names = serializer.GetParams(); + const auto* pf = runtime::Registry::Get("runtime.DNNLJSONRuntimeCreate"); + ICHECK(pf != nullptr) << "Cannot find DNNL runtime module create function."; + runtime::Module lib = (*pf)(func_name, graph_json, param_names); + return lib; +} + +TVM_REGISTER_GLOBAL("relax.ext.dnnl").set_body_typed(DNNLCompiler); + +} // namespace contrib +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 76bfdb12d2..d061f675e3 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -499,7 +499,7 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr } bool MatchExpr(DFPattern pattern, Expr expr, Optional> var2val) { - if (var2val.defined()) // autojump is enabled with var2val. + if (var2val) // autojump is enabled with var2val. return DFPatternMatcher(std::move(var2val.value())).Match(pattern, expr); else return DFPatternMatcher().Match(pattern, expr); @@ -507,6 +507,23 @@ bool MatchExpr(DFPattern pattern, Expr expr, Optional> v TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); +Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, + Optional> bindings_opt) { + auto bindings = bindings_opt ? 
bindings_opt.value() : Map{}; + DFPatternMatcher matcher(bindings); + + if (!matcher.Match(pattern, expr)) { + return NullOpt; + } + + Map matching; + for (const auto& [pat, matches] : matcher.GetMemo()) { + ICHECK(matches.size() == 1) << "More than one match for the pattern " << pat; + matching.Set(pat, matches[0]); + } + return matching; +} + struct PNode { const DFPatternNode* ptr; const VarNode* matched = nullptr; diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 0983db3989..85768c7174 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -27,6 +27,9 @@ * A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function. */ +#include +#include +#include #include #include #include @@ -372,15 +375,22 @@ class FunctionCreator : public ExprMutator { if (const auto* var_binding = binding.as()) { if (const auto* call = var_binding->value.as()) { - ICHECK(call->op == Op::Get("relax.call_tir")); - // Update the name of the function. - name_hint_ = name_hint_ + "_" + Downcast(call->args[0])->name_hint; + if (call->op == Op::Get("relax.call_tir")) { + // Update the name of the function. + name_hint_ = name_hint_ + "_" + Downcast(call->args[0])->name_hint; - const Tuple& args = Downcast(call->args[1]); - for (const Expr& arg : args->fields) { - CheckDefAndUpdateParam(arg); + const Tuple& args = Downcast(call->args[1]); + for (const Expr& arg : args->fields) { + CheckDefAndUpdateParam(arg); + } + // TODO(tvm-team): handle shape expr + } else { + ICHECK(call->op->IsInstance()); + name_hint_ = name_hint_ + "_" + Downcast(call->op)->name; + for (const Expr& arg : call->args) { + CheckDefAndUpdateParam(arg); + } } - // TODO(tvm-team): handle shape expr } else { const auto* tuple_item = var_binding->value.as(); ICHECK(tuple_item != nullptr); @@ -407,7 +417,7 @@ class FunctionCreator : public ExprMutator { /*! * \brief Create the grouped function according according to the collected bindings and parameters - * \note The created function won't be returned immediately. Tt's stored in the `function_` field. + * \note The created function won't be returned immediately. It's stored in the `function_` field. */ void CreateFunction() { // Step 1. Start constructing a new dataflow block. @@ -543,14 +553,16 @@ class OperatorFusor : public ExprMutator { } } + OperatorFusor(IRModule mod, + const std::unordered_map& obj2group) + : ExprMutator(mod), mod_(std::move(mod)), obj2group_(obj2group) {} + /*! * \brief The main transformation on the IRModule * \return The new IRModule after transformation */ IRModule Transform() { - for (const auto& kv : mod_->functions) { - const GlobalVar& gv = kv.first; - const BaseFunc& func = kv.second; + for (const auto& [gv, func] : mod_->functions) { // Only visit Relax function without attr kPrimitive. if (func->IsInstance() && !func->HasNonzeroAttr(attr::kPrimitive)) { auto updated_func = Downcast(VisitExpr(func)); @@ -579,8 +591,7 @@ class OperatorFusor : public ExprMutator { CollectFuncBoundary(block->bindings); // Step 3. Create the grouped function for each group. 
- for (auto& kv : group2func_) { - FunctionCreator& creator = kv.second; + for (auto& [_, creator] : group2func_) { creator.CreateFunction(); } @@ -757,6 +768,90 @@ IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) { return mod; } +static Map GetBindingInverse(const Map& binding) { + Map value_to_bound_var; + for (const auto& [var, val] : binding) { + value_to_bound_var.Set(val, var); + } + return value_to_bound_var; +} + +class PatternBasedPartitioner : ExprVisitor { + public: + using Group = GraphPartitioner::Group; + using ExprVisitor::VisitExpr_; + + static std::unordered_map Run(DFPattern pattern, Expr expr, + support::Arena* arena) { + PatternBasedPartitioner part(pattern, AnalyzeVar2Value(expr)); + PostOrderVisit( + expr, [arena, &part](const Expr& e) { part.group_map_[e.get()] = arena->make(); }); + part.VisitExpr(expr); + return part.group_map_; + } + + PatternBasedPartitioner(DFPattern pattern, const tvm::runtime::Map& bindings) + : pat_(pattern), bindings_(bindings), value_to_bound_var_(GetBindingInverse(bindings)) {} + + void VisitBindingBlock_(const DataflowBlockNode* block) final { + for (const auto& binding : block->bindings) { + auto it = group_map_.find(binding->var.get()); + ICHECK(it != group_map_.end()); + if (const auto* var_binding = binding.as()) { + VisitExpr(var_binding->value); + } + } + } + + void VisitExpr_(const CallNode* call) override { + if (auto matches_opt = ExtractMatchedExpr(pat_, GetRef(call), bindings_)) { + auto parent_group = GetGroupForBoundVar(GetRef(call)); + ICHECK(parent_group); + + for (const auto& [_, match] : matches_opt.value()) { + ICHECK(group_map_.count(match.get())); + if (!match->IsInstance()) { + AddToGroup(match, parent_group); + if (value_to_bound_var_.count(match) && GetGroupForBoundVar(match)->num_nodes == 1) { + AddToGroup(value_to_bound_var_[match], parent_group); + } + } + } + } + } + + private: + void AddToGroup(Expr e, Group* to) { + if (group_map_[e.get()] != to) { + --group_map_[e.get()]->num_nodes; + group_map_[e.get()] = to; + ++to->num_nodes; + } + } + + Group* GetGroupForBoundVar(Expr e) { + ICHECK(value_to_bound_var_.count(e)); + auto bound_var = value_to_bound_var_[e]; + ICHECK(group_map_.count(bound_var.get())); + return group_map_[bound_var.get()]; + } + + DFPattern pat_; + Map bindings_; + Map value_to_bound_var_; + std::unordered_map group_map_; +}; + +IRModule FuseOpsByPattern(DFPattern pattern, IRModule mod) { + std::unordered_map group_map; + support::Arena arena; + for (const auto& [gv, func] : mod->functions) { + auto map = PatternBasedPartitioner::Run(pattern, func, &arena); + group_map.insert(map.begin(), map.end()); + } + return OperatorFusor(mod, group_map).Transform(); +} + namespace transform { Pass FuseOps(int fuse_opt_level) { @@ -774,6 +869,17 @@ Pass FuseOps(int fuse_opt_level) { TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps); +Pass FuseOpsByPattern(DFPattern pattern) { + runtime::TypedPackedFunc pass_func = // + [=](IRModule m, PassContext pc) { return relax::FuseOpsByPattern(pattern, m); }; + return CreateModulePass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"FuseOpsByPattern", // + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.FuseOpsByPattern").set_body_typed(FuseOpsByPattern); + } // namespace transform } // namespace relax diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index 173dcf5e5f..0e36768c70 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc 
+++ b/src/relay/backend/contrib/cutlass/codegen.cc
@@ -801,8 +801,8 @@ class CutlassModuleCodegen {
   runtime::Module CreateCSourceModule() {
     EmitPreamble();
 
-    for (const auto& kv : mod_->functions) {
-      if (const auto* function_node = GetCutlassFunctionNode(kv.second)) {
+    for (const auto& [_, f] : mod_->functions) {
+      if (const auto* function_node = GetCutlassFunctionNode(f)) {
         GenCutlassFunc(GetRef<Function>(function_node));
       }
     }
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index ba06d082c4..a9e621117d 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -73,6 +73,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
 
   /* Thread safe implementation of Run. Keep runtime instance immutable */
   void Run(const TVMArgs& args) const {
+    LOG(INFO) << "Running DNNL";
     auto arg_data_provider = makeIODataProvider(args);
     auto mem_solver = tensor_registry_.MakeSolver(arg_data_provider);
     // Execute primitives one by one
@@ -316,7 +317,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     auto padding = GetNodeAttr<std::vector<int64_t>>(node, "padding");
     std::vector<int64_t> padding_l(padding.begin(), padding.begin() + padding.size() / 2);
     std::vector<int64_t> padding_r(padding.begin() + padding.size() / 2, padding.end());
-    auto groups = GetNodeAttr<int>(node, "groups");
+    // TODO: the "groups" attribute is missing in Relax conv2d
+    auto groups = 1;  // GetNodeAttr<int>(node, "groups");
     auto src_layout = GetNodeAttr<std::string>(node, "data_layout");
     auto dst_layout = GetNodeAttr<std::string>(node, "out_layout");
     auto wgh_layout = GetNodeAttr<std::string>(node, "kernel_layout");
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
new file mode 100644
index 0000000000..5225294f93
--- /dev/null
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -0,0 +1,263 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+ +import numpy as np +import tvm +import tvm.testing + +from tvm import relax, relay +from tvm.script import relax as R +from tvm.relax.dpl import * +from tvm.contrib.cutlass.build import finalize_modules_relax + + +op_name = "cutlass_tensorop_h1688fprop_optimized_256x128_32x2_nhwc_align8" + +op_def = """ + using cutlass_tensorop_h1688fprop_optimized_256x128_32x2_nhwc_align8 = + typename cutlass::conv::kernel::DefaultConv2dFprop< + cutlass::half_t, + cutlass::layout::TensorNHWC, + cutlass::half_t, + cutlass::layout::TensorNHWC, + cutlass::half_t, + cutlass::layout::TensorNHWC, + cutlass::half_t, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32 >, + cutlass::gemm::GemmShape<16, 8, 8>, + + cutlass::epilogue::thread::LinearCombinationRelu< + cutlass::half_t, + 8, + cutlass::half_t, + cutlass::half_t, + cutlass::epilogue::thread::ScaleType::NoBetaScaling + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kStrided, + 8, + 8 + >::Kernel; +""" + + +def make_conv_pattern(conv_name, with_bias=False, activation=None): + data = wildcard() + weight = wildcard() + conv = is_op(conv_name)(data, weight) + + if with_bias: + bias = wildcard() + conv_out = is_op("relax.add")(conv, bias) + else: + conv_out = conv + + if activation: + return is_op(activation)(conv_out) + + return conv_out + + +@tvm.script.ir_module +class Conv2dBiasReLU: + @R.function + def conv2d( + data: R.Tensor((16, 32, 32, 16), "float16"), + weight: R.Tensor((32, 3, 3, 16), "float16"), + bias: R.Tensor((1, 1, 1, 32), "float16"), + ): + with R.dataflow(): + conv1 = relax.op.nn.relu( + relax.op.add( + relax.op.nn.conv2d( + data, weight, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI" + ), + bias, + ) + ) + R.output(conv1) + + return conv1 + + +@tvm.script.ir_module +class Conv2dBiasReLUPartitioned: + @R.function + def main( + data: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight: R.Tensor((32, 3, 3, 16), dtype="float16"), + bias: R.Tensor((1, 1, 1, 32), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 32), dtype="float16"): + # block 0 + with R.dataflow(): + gv: R.Tensor( + (16, 32, 32, 32), dtype="float16" + ) = fused_relax_nn_conv2d_relax_add_relax_nn_relu(data, weight, bias) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d_relax_add_relax_nn_relu( + data1: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight1: R.Tensor((32, 3, 3, 16), dtype="float16"), + bias1: R.Tensor((1, 1, 1, 32), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 32), dtype="float16"): + R.func_attr( + {"Codegen": "cutlass", "global_symbol": "fused_relax_nn_conv2d_relax_add_relax_nn_relu"} + ) + + @R.function + def fused_relax_nn_conv2d_relax_add_relax_nn_relu_inner( + data1: R.Tensor((16, 32, 32, 16), dtype="float16"), + weight1: R.Tensor((32, 3, 3, 16), dtype="float16"), + bias1: R.Tensor((1, 1, 1, 32), dtype="float16"), + ) -> R.Tensor((16, 32, 32, 32), dtype="float16"): + # function attr dict + R.func_attr({"Primitive": 1, "Composite": "conv2d_bias_relu"}) + # block 0 + with R.dataflow(): + lv: R.Tensor((16, 32, 32, 32), dtype="float16") = R.nn.conv2d( + data1, + weight1, + strides=[1, 1], + padding=[1, 1], + dilation=[1, 1], + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="", + ) + lv1: R.Tensor((16, 32, 32, 
32), dtype="float16") = R.add(lv, bias1) + gv1: R.Tensor((16, 32, 32, 32), dtype="float16") = R.nn.relu(lv1) + R.output(gv1) + return gv1 + + return fused_relax_nn_conv2d_relax_add_relax_nn_relu_inner(data1, weight1, bias1) + + +def annotate_attributes(mod): + # TODO: automate + f_name = "fused_relax_nn_conv2d_relax_add_relax_nn_relu" + f = mod[f_name] + + for k, v in { + "arg0_dtype": "float16", + "arg1_dtype": "float16", + "ret_dtype": "float32", + "arg0_shape": "float16", + "arg1_dtype": "float16", + "ret_dtype": "float32", + "op_type": "conv2d_bias_relu", + "arg0_shape": [16, 32, 32, 16], + "arg1_shape": [32, 3, 3, 16], + "ret_shape": [16, 32, 32, 32], + "strides": [1, 1], + "padding": [1, 1], + "dilation": [1, 1], + "cutlass_op_name": op_name, + "cutlass_op_def": op_def, + }.items(): + f = f.with_attr(k, v) + + mod[f_name] = f + + return mod + + +def test_conv2d_partition(): + mod = Conv2dBiasReLU + pat = make_conv_pattern("relax.nn.conv2d", True, "relax.nn.relu") + mod = relax.transform.FuseOpsByPattern(pat)(mod) + + print(mod.script()) + + +def get_relay_conv2d_bias_relu(d_shape, w_shape): + data = relay.var("data", shape=d_shape) + weight = relay.var("weight", shape=w_shape) + bias = relay.var("bias", shape=(1, 1, 1, w_shape[0])) + return relay.nn.relu( + relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="OHWI", + ) + + bias + ) + + +def get_ref(data_np, weight_np, bias_np): + relay_mod = tvm.IRModule.from_expr(get_relay_conv2d_bias_relu(data_np.shape, weight_np.shape)) + + with tvm.transform.PassContext(opt_level=3): + seq = tvm.transform.Sequential( + [relay.transform.ConvertLayout({"nn.conv2d": ["NCHW", "OIHW"]})] + ) + relay_mod = seq(relay_mod) + + ref = ( + relay.create_executor("graph", mod=relay_mod, device=tvm.gpu(0), target="cuda") + .evaluate()(*[data_np, weight_np, bias_np]) + .numpy() + ) + + return ref + + +def test_conv2d_offload(): + data_np = np.random.randn(16, 32, 32, 16).astype("float16") + weight_np = np.random.randn(32, 3, 3, 16).astype("float16") + bias_np = np.random.randn(1, 1, 1, 32).astype("float16") + + seq = tvm.transform.Sequential( + [ + relax.transform.RunCodegen(), + relax.transform.RemoveUnusedFunctions(), + ] + ) + + mod = annotate_attributes(Conv2dBiasReLUPartitioned) + mod = seq(mod) + + target = tvm.target.Target("cuda") + ex = relax.vm.build(mod, target) + ex = finalize_modules_relax(ex) + + dev = tvm.gpu(0) + vm = relax.VirtualMachine(ex, dev) + + data = tvm.nd.array(data_np, dev) + weight = tvm.nd.array(weight_np, dev) + bias = tvm.nd.array(bias_np, dev) + out = vm["main"](data, weight, bias).numpy() + + ref = get_ref(data_np, weight_np, bias_np) + + print(np.max(np.abs(out - ref)), np.mean(np.abs(out - ref))) + + +if __name__ == "__main__": + test_conv2d_offload() diff --git a/tests/python/relax/test_codegen_dnnl.py b/tests/python/relax/test_codegen_dnnl.py new file mode 100644 index 0000000000..d7af00c0b4 --- /dev/null +++ b/tests/python/relax/test_codegen_dnnl.py @@ -0,0 +1,207 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import tvm +import tvm.testing + +from tvm import relax, relay +from tvm.script import relax as R +from tvm.relax.dpl import * + + +def make_conv_pattern(conv_name, with_bias=False, activation=None): + data = wildcard() + weight = wildcard() + conv = is_op(conv_name)(data, weight) + + if with_bias: + bias = wildcard() + conv_out = is_op("add")(conv, bias) + else: + conv_out = conv + + return is_op(activation)(conv_out) + + +@tvm.script.ir_module +class Conv2dReLUx2: + @R.function + def conv2d( + data: R.Tensor((1, 64, 56, 56), "float32"), + weight1: R.Tensor((64, 64, 3, 3), "float32"), + weight2: R.Tensor((64, 64, 3, 3), "float32"), + ): + with R.dataflow(): + conv1 = relax.op.nn.relu(relax.op.nn.conv2d(data, weight1, padding=(1, 1))) + conv2d = relax.op.nn.relu(relax.op.nn.conv2d(conv1, weight2, padding=(0, 0))) + R.output(conv2d) + + return conv2d + + +def get_relay_conv2d_relu_x2(d_shape, w_shape): + data = relay.var("data", shape=d_shape) + weight1 = relay.var("weight1", shape=w_shape) + weight2 = relay.var("weight2", shape=w_shape) + conv1 = relay.nn.relu( + relay.nn.conv2d( + data=data, + weight=weight1, + kernel_size=w_shape[2:], + padding=(1, 1), + ) + ) + return relay.nn.relu( + relay.nn.conv2d( + data=conv1, + weight=weight2, + kernel_size=w_shape[2:], + padding=(0, 0), + ) + ) + + +def test_conv2d_partition(): + mod = Conv2dReLUx2 + pat = make_conv_pattern("relax.nn.conv2d", False, "relax.nn.relu") + mod = relax.transform.FuseOpsByPattern(pat)(mod) + print(mod.script()) + + +@tvm.script.ir_module +class Conv2dReLUx2Partitioned: + @R.function + def main( + data: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), + weight2: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d_relax_nn_relu( + data, weight1 + ) + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_conv2d_relax_nn_relu1( + lv, weight2 + ) + R.output(gv) + return gv + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + R.func_attr({"Codegen": "dnnl", "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu"}) + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu_inner( + data1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): + # function attr dict + R.func_attr({"Primitive": 1, "Composite": "conv2d_relu"}) + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d( + data1, + weight11, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="", + ) + gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv1) + R.output(gv1) + return gv1 + + return fused_relax_nn_conv2d_relax_nn_relu_inner(data1, 
weight11) + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu1( + conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + R.func_attr({"Codegen": "dnnl", "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu1"}) + + @R.function + def fused_relax_nn_conv2d_relax_nn_relu1_inner( + conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), + weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), + ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): + # function attr dict + R.func_attr({"Primitive": 1, "Composite": "conv2d_relu"}) + # block 0 + with R.dataflow(): + lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d( + conv1, + weight21, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="", + ) + gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv2) + R.output(gv2) + return gv2 + + return fused_relax_nn_conv2d_relax_nn_relu1_inner(conv1, weight21) + + +def test_dnnl_offload(): + seq = tvm.transform.Sequential( + [ + relax.transform.RunCodegen(), + relax.transform.RemoveUnusedFunctions(), + ] + ) + + mod = seq(Conv2dReLUx2Partitioned) + print(mod.script()) + + target = tvm.target.Target("llvm") + ex = relax.vm.build(mod, target) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + f = vm["main"] + + data_np = np.random.randn(1, 64, 56, 56).astype("float32") + weight1_np = np.random.randn(64, 64, 3, 3).astype("float32") + weight2_np = np.random.randn(64, 64, 3, 3).astype("float32") + out = f(tvm.nd.array(data_np), tvm.nd.array(weight1_np), tvm.nd.array(weight2_np)).numpy() + + relay_mod = tvm.IRModule.from_expr(get_relay_conv2d_relu_x2(data_np.shape, weight1_np.shape)) + + ref = ( + relay.create_executor("graph", mod=relay_mod, device=tvm.cpu(0), target="llvm") + .evaluate()(*[data_np, weight1_np, weight2_np]) + .numpy() + ) + + print(np.max(np.abs(out - ref)), np.mean(np.abs(out - ref))) + + +if __name__ == "__main__": + test_conv2d_partition() + # test_dnnl_offload()
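
End-to-end usage sketch. The two new tests exercise the pieces of this patch separately (FuseOpsByPattern on an unpartitioned module, RunCodegen on a hand-partitioned one). The snippet below sketches how they are intended to compose once FuseOpsByPattern emits the Codegen-annotated structure that Conv2dReLUx2Partitioned spells out by hand; the composed flow is an assumption, not something the patch wires up yet. All names come from the tests above (Conv2dReLUx2 is the module in test_codegen_dnnl.py), and the CUTLASS path additionally needs finalize_modules_relax to compile the generated kernels with nvcc.

import tvm
from tvm import relax
from tvm.relax.dpl import is_op, wildcard

# Describe the region to offload as a dataflow pattern (same shape as make_conv_pattern).
pat = is_op("relax.nn.relu")(is_op("relax.nn.conv2d")(wildcard(), wildcard()))

# Group matching call chains into composite functions (the new pass in this patch).
mod = relax.transform.FuseOpsByPattern(pat)(Conv2dReLUx2)

# Hand the partitioned functions to the external codegen and drop the leftovers.
mod = tvm.transform.Sequential(
    [relax.transform.RunCodegen(), relax.transform.RemoveUnusedFunctions()]
)(mod)

# Build for the Relax VM and run (DNNL/CPU path shown here).
ex = relax.vm.build(mod, tvm.target.Target("llvm"))
vm = relax.VirtualMachine(ex, tvm.cpu())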